1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
| from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Activation, Dropout, Flatten, Conv2D, MaxPooling2D from tensorflow.keras.preprocessing.image import ImageDataGenerator from tensorflow.keras.callbacks import EarlyStopping from tensorflow.keras import optimizers, Input, models, layers, optimizers, metrics from tensorflow.keras.applications import VGG16, EfficientNetB7 import numpy as np import matplotlib.pyplot as plt
# Input geometry shared by both directory iterators: images are resized to
# 150x150 and served in batches of 5 with binary labels.
_IMAGE_SIZE = (150, 150)
_BATCH_SIZE = 5

# Training pipeline: rescale pixel values by 1/255 and apply light
# augmentation (random horizontal flips, random shifts of up to 0.1 of the
# image width/height).
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    horizontal_flip=True,
    width_shift_range=0.1,
    height_shift_range=0.1,
)
train_generator = train_datagen.flow_from_directory(
    './train',
    target_size=_IMAGE_SIZE,
    batch_size=_BATCH_SIZE,
    class_mode='binary',
)

# Evaluation pipeline: same pixel rescaling, no augmentation.
test_datagen = ImageDataGenerator(rescale=1.0 / 255)
test_generator = test_datagen.flow_from_directory(
    './test',
    target_size=_IMAGE_SIZE,
    batch_size=_BATCH_SIZE,
    class_mode='binary',
)
# Pretrained VGG16 convolutional base: ImageNet weights, classifier head
# dropped, input sized to the 150x150 3-channel images the generators yield.
transfer_model = VGG16(
    weights='imagenet',
    include_top=False,
    input_shape=(150, 150, 3),
)

# Freeze the whole base so that only layers stacked on top of it are trained
# (feature extraction rather than true fine-tuning of VGG16 itself).
transfer_model.trainable = False
# Full model: the VGG16 base, then a small classifier head. The convolutional
# features are flattened, passed through one 64-unit ReLU layer with 50%
# dropout, and reduced to a single sigmoid unit for binary classification.
finetune_model = Sequential([
    transfer_model,
    Flatten(),
    Dense(64),
    Activation('relu'),
    Dropout(0.5),
    Dense(1),
    Activation('sigmoid'),
])
# Binary cross-entropy loss, Adam with a small learning rate (2e-4), and
# accuracy reported alongside the loss.
finetune_model.compile(
    optimizer=optimizers.Adam(learning_rate=0.0002),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# Halt training once validation loss has failed to improve for 5 epochs.
# NOTE(review): restore_best_weights is left at its default, so the model
# keeps the weights of the last epoch run, not the best one — confirm intent.
early_stopping_callback = EarlyStopping(patience=5, monitor='val_loss')
# Train the classifier head for up to 20 epochs. EarlyStopping may end
# training sooner, based on val_loss.
#
# Validation covers the whole test set every epoch. len(test_generator) is the
# number of batches needed to go through the test directory once. A small
# fixed value (e.g. 10 batches = 50 images at batch_size=5) would score only a
# subset, and a different one each epoch because flow_from_directory shuffles
# by default. That makes the val_loss driving EarlyStopping (and the plotted
# test curve) noisy and unrepresentative.
history = finetune_model.fit(
    train_generator,
    epochs=20,
    validation_data=test_generator,
    validation_steps=len(test_generator),
    callbacks=[early_stopping_callback],
)
# Per-epoch loss values recorded by fit(): validation (test set) and training.
y_vloss = history.history['val_loss']
y_loss = history.history['loss']
x_len = np.arange(len(y_loss))  # epoch indices 0..N-1 for the x axis

# Draw the test-set curve first (red), then the training curve (blue).
for curve, colour, caption in (
    (y_vloss, 'red', 'testsetloss'),
    (y_loss, 'blue', 'trainsetloss'),
):
    plt.plot(x_len, curve, marker='.', c=colour, label=caption)

plt.legend(loc='upper right')
plt.grid()
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
|