from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
import numpy as np
import matplotlib.pyplot as plt
# Generator: maps a 100-dim noise vector to a 28x28x1 image in [-1, 1]
# (tanh output), upsampling 7x7 -> 14x14 -> 28x28.
generator = Sequential()
generator.add(Dense(128 * 7 * 7, input_dim=100, activation=LeakyReLU(0.2)))
generator.add(BatchNormalization())
generator.add(Reshape((7, 7, 128)))
generator.add(UpSampling2D())
generator.add(Conv2D(64, kernel_size=5, padding='same'))
generator.add(BatchNormalization())
generator.add(Activation(LeakyReLU(0.2)))
generator.add(UpSampling2D())
generator.add(Conv2D(1, kernel_size=5, padding='same', activation='tanh'))
# Discriminator: binary classifier over 28x28x1 images
# (sigmoid output: 1 = real, 0 = generated).
discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=5, strides=2,
                         input_shape=(28, 28, 1), padding="same"))
discriminator.add(Activation(LeakyReLU(0.2)))
discriminator.add(Dropout(0.3))
discriminator.add(Conv2D(128, kernel_size=5, strides=2, padding="same"))
discriminator.add(Activation(LeakyReLU(0.2)))
discriminator.add(Dropout(0.3))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer='adam')
# Freeze after compiling: discriminator.train_on_batch still updates its
# weights, but the combined GAN model (built below) will not.
discriminator.trainable = False
# Combined model: noise -> generator -> (frozen) discriminator.
# Training it pushes the generator to make the discriminator output "real".
g_input = Input(shape=(100,))
dis_output = discriminator(generator(g_input))
gan = Model(g_input, dis_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')
def gan_train(epoch, batch_size, saving_interval):
    """Train the GAN on MNIST, saving sample image grids periodically.

    Args:
        epoch: number of training iterations (one batch each for D and G).
        batch_size: number of samples per training batch.
        saving_interval: every this many iterations, save a 5x5 grid of
            generated digits to ./gan_mnist_<i>.png.
    """
    # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh range.
    (X_train, _), (_, _) = mnist.load_data()
    X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
    X_train = (X_train - 127.5) / 127.5

    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for i in range(epoch):
        # --- Discriminator step: one real batch, one generated batch. ---
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        d_loss_real = discriminator.train_on_batch(imgs, real)

        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Generator step: train the combined model to label fakes "real".
        # The discriminator is frozen inside `gan`, so only G updates here. ---
        g_loss = gan.train_on_batch(noise, real)
        print(f'epoch: {i}, d_loss:{d_loss:.4f}, g_loss: {g_loss:.4f}')

        if i % saving_interval == 0:
            noise = np.random.normal(0, 1, (25, 100))
            gen_imgs = generator.predict(noise)
            # Rescale from [-1, 1] back to [0, 1] for display.
            gen_imgs = 0.5 * gen_imgs + 0.5

            fig, axs = plt.subplots(5, 5)
            count = 0
            for j in range(5):
                for k in range(5):
                    axs[j, k].imshow(gen_imgs[count, :, :, 0], cmap='gray')
                    axs[j, k].axis('off')
                    count += 1
            fig.savefig(f'./gan_mnist_{i}.png')
            # Fix: close the figure after saving; otherwise matplotlib keeps
            # every figure alive and memory grows with each save interval.
            plt.close(fig)
# Guard the training run so importing this module does not start training;
# behavior when executed as a script is unchanged.
if __name__ == "__main__":
    gan_train(2001, 32, 200)