基于CelebA数据集的GAN模型-2
数据 基于 模型 GAN
2023-06-13 09:16:59 时间
前两篇我们介绍了 CelebA 数据集
直接上代码咯
导入依赖:
# example of a gan for generating faces
from numpy import load
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Dropout
from matplotlib import pyplot
然后是定义判别器:
# define the standalone discriminator model
def define_discriminator(in_shape=(80,80,3)):
    """Build and compile the CNN that classifies images as real or fake.

    Args:
        in_shape: shape of one input image; defaults to 80x80 RGB.

    Returns:
        A compiled Keras ``Sequential`` model with a single sigmoid output,
        trained with binary cross-entropy and reporting accuracy.
    """
    model = Sequential()
    # normal (no downsampling); spatial sizes below assume the default 80x80 input
    model.add(Conv2D(128, (5,5), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    # downsample to 40x40
    model.add(Conv2D(128, (5,5), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample to 20x20
    model.add(Conv2D(128, (5,5), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample to 10x10
    model.add(Conv2D(128, (5,5), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample to 5x5
    model.add(Conv2D(128, (5,5), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # classifier
    model.add(Flatten())
    model.add(Dropout(0.4))
    model.add(Dense(1, activation='sigmoid'))
    # compile model; `learning_rate` replaces the old `lr` alias, which newer
    # Keras optimizers reject
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model
接下来定义生成器:
# define the standalone generator model
def define_generator(latent_dim):
    """Build the (uncompiled) generator that maps latent vectors to images.

    Args:
        latent_dim: length of each latent input vector.

    Returns:
        A Keras ``Sequential`` model whose output is an 80x80x3 image with
        tanh-activated pixels in [-1, 1].
    """
    net = Sequential()
    # project the latent vector and reshape it into a 5x5 stack of 128 maps
    net.add(Dense(128 * 5 * 5, input_dim=latent_dim))
    net.add(LeakyReLU(alpha=0.2))
    net.add(Reshape((5, 5, 128)))
    # four stride-2 transposed convolutions double the size each time:
    # 5x5 -> 10x10 -> 20x20 -> 40x40 -> 80x80
    for _stage in range(4):
        net.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        net.add(LeakyReLU(alpha=0.2))
    # collapse to 3 colour channels; tanh matches the [-1, 1] training data
    net.add(Conv2D(3, (5,5), activation='tanh', padding='same'))
    return net
定义GAN:
# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):
    """Stack the generator and a frozen discriminator into one compiled model.

    Training this combined model pushes gradients through the discriminator
    into the generator, so only the generator's weights should change.

    Args:
        g_model: generator model (latent vector -> image).
        d_model: discriminator model (image -> probability of "real").

    Returns:
        A compiled Keras ``Sequential`` model (binary cross-entropy loss).

    Side effects:
        Sets ``d_model.trainable = False``.
    """
    # make weights in the discriminator not trainable inside the combined model.
    # NOTE(review): the standalone d_model is compiled before this flag flips
    # (define_discriminator compiles it); whether it keeps training on its own
    # afterwards depends on the Keras version - confirm for the version in use.
    d_model.trainable = False
    # connect them: generator output feeds the discriminator
    model = Sequential()
    model.add(g_model)
    model.add(d_model)
    # compile model; `learning_rate` replaces the old `lr` alias, which newer
    # Keras optimizers reject
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model
加载上一篇文章中保存的 NumPy 压缩数据文件（以 .npz 结尾）：
# load and prepare training images
def load_real_samples(filename='img_align_celeba.npz'):
    """Load the saved face images and scale their pixels to [-1, 1].

    Args:
        filename: path of the ``.npz`` archive; the images are read from its
            ``'arr_0'`` entry.

    Returns:
        float32 array of images, mapped from [0, 255] to [-1, 1] so they match
        the range of the generator's tanh output.
    """
    # the NpzFile returned by numpy.load holds the archive open until closed;
    # the context manager releases it once the array has been read into memory
    with load(filename) as data:
        X = data['arr_0']
    # convert pixel values to floats
    X = X.astype('float32')
    # scale from [0,255] to [-1,1]
    X = (X - 127.5) / 127.5
    return X


# select real samples
def generate_real_samples(dataset, n_samples):
    """Randomly pick ``n_samples`` images from ``dataset``, labelled as real.

    Indices are drawn uniformly with replacement.

    Args:
        dataset: array of images indexed along its first axis.
        n_samples: number of images to draw.

    Returns:
        (X, y) where X is the selected images and y is an (n_samples, 1)
        array of ones (the 'real' class label).
    """
    ix = randint(0, dataset.shape[0], n_samples)
    X = dataset[ix]
    y = ones((n_samples, 1))
    return X, y
从潜在空间（latent space）中采样点，作为生成器的输入:
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
    """Sample a batch of standard-normal latent vectors for the generator.

    Args:
        latent_dim: length of each latent vector.
        n_samples: number of vectors in the batch.

    Returns:
        Array of shape (n_samples, latent_dim).
    """
    # a 2-D randn call draws the same values, in the same row-major order,
    # as drawing a flat vector and reshaping it
    batch = randn(n_samples, latent_dim)
    return batch
生成假的样本:
# use the generator to generate n fake examples, with class labels
def generate_fake_samples(g_model, latent_dim, n_samples):
    """Generate ``n_samples`` images with the generator, labelled as fake.

    Args:
        g_model: generator model (latent vector -> image).
        latent_dim: length of each latent vector.
        n_samples: number of images to generate.

    Returns:
        (X, y) where X is the generator's output batch and y is an
        (n_samples, 1) array of zeros (the 'fake' class label).
    """
    # generate points in latent space
    x_input = generate_latent_points(latent_dim, n_samples)
    # predict outputs; verbose=0 because this runs once per training batch and
    # Keras's default ('auto') verbosity can print a progress bar every call
    X = g_model.predict(x_input, verbose=0)
    # create 'fake' class labels (0)
    y = zeros((n_samples, 1))
    return X, y
定义一个画图的函数
# create and save a plot of generated images
def save_plot(examples, epoch, n=10):
    """Save an n-by-n grid of images to 'generated_plot_eNNN.png'.

    Args:
        examples: batch of images with pixel values in [-1, 1]; at least
            n * n of them are read.
        epoch: zero-based epoch index; the file name uses ``epoch + 1``.
        n: number of rows and columns in the grid.
    """
    # imshow expects floats in [0, 1], so shift and halve the [-1, 1] pixels
    scaled = (examples + 1) / 2.0
    for row in range(n):
        for col in range(n):
            flat_idx = row * n + col
            # subplot positions are 1-based and filled row by row
            pyplot.subplot(n, n, flat_idx + 1)
            pyplot.axis('off')
            pyplot.imshow(scaled[flat_idx])
    out_path = 'generated_plot_e%03d.png' % (epoch+1)
    pyplot.savefig(out_path)
    pyplot.close()
定义评估函数:
# evaluate the discriminator, plot generated images, save generator model
def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples=100):
    """Report discriminator accuracy, save a sample grid, and checkpoint the generator.

    Args:
        epoch: zero-based epoch index (file names use ``epoch + 1``).
        g_model: generator model.
        d_model: compiled discriminator with an accuracy metric.
        dataset: real training images.
        latent_dim: length of each latent vector.
        n_samples: number of real and of fake images to evaluate.
    """
    # discriminator accuracy on a batch of real images
    real_imgs, real_labels = generate_real_samples(dataset, n_samples)
    _, acc_real = d_model.evaluate(real_imgs, real_labels, verbose=0)
    # discriminator accuracy on a batch of generated images
    fake_imgs, fake_labels = generate_fake_samples(g_model, latent_dim, n_samples)
    _, acc_fake = d_model.evaluate(fake_imgs, fake_labels, verbose=0)
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))
    # the generated batch doubles as the preview grid
    save_plot(fake_imgs, epoch)
    # save the generator model to file
    g_model.save('generator_model_%03d.h5' % (epoch+1))
最后是我们的训练函数:
# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=100, n_batch=128):
    """Alternate discriminator and generator updates over the dataset.

    Each step trains the discriminator on half a batch of real images and
    half a batch of fake images. It then trains the generator through
    ``gan_model`` with a full batch of latent points labelled as real.
    Every 10th epoch, ``summarize_performance`` is called.

    Args:
        g_model: generator model.
        d_model: compiled standalone discriminator.
        gan_model: compiled generator+discriminator stack.
        dataset: real training images, indexed along the first axis.
        latent_dim: length of each latent vector.
        n_epochs: number of passes over the data.
        n_batch: full batch size (the discriminator sees half of it per class).
    """
    steps_per_epoch = int(dataset.shape[0] / n_batch)
    half = int(n_batch / 2)
    # the generator is trained to make the discriminator answer "real" (1);
    # these target labels are the same at every step
    target_real = ones((n_batch, 1))
    for epoch in range(n_epochs):
        for step in range(steps_per_epoch):
            # discriminator update on real images
            real_imgs, real_labels = generate_real_samples(dataset, half)
            loss_real, _ = d_model.train_on_batch(real_imgs, real_labels)
            # discriminator update on generated images
            fake_imgs, fake_labels = generate_fake_samples(g_model, latent_dim, half)
            loss_fake, _ = d_model.train_on_batch(fake_imgs, fake_labels)
            # generator update via the discriminator's error
            latent_batch = generate_latent_points(latent_dim, n_batch)
            loss_gen = gan_model.train_on_batch(latent_batch, target_real)
            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' %
                  (epoch+1, step+1, steps_per_epoch, loss_real, loss_fake, loss_gen))
        # periodic evaluation and checkpoint
        if not (epoch + 1) % 10:
            summarize_performance(epoch, g_model, d_model, dataset, latent_dim)
# size of the latent space (length of each generator input vector)
latent_dim = 100
# discriminator: compiled standalone real/fake classifier
d_model = define_discriminator()
# generator: latent vector -> 80x80x3 image
g_model = define_generator(latent_dim=latent_dim)
# combined model used to train the generator (define_gan freezes d_model inside it)
gan_model = define_gan(g_model, d_model)
# real face images, scaled to [-1, 1]
dataset = load_real_samples()
# run training with the default schedule (n_epochs=100, n_batch=128)
train(g_model, d_model, gan_model, dataset, latent_dim=latent_dim)
最后生成的图像:
相关文章
- Java项目毕业设计:基于springboot+vue的电影视频网站系统「建议收藏」
- 工具推荐|基于天气雷达数据的冰雹反演工具
- 基于Proxy从0到1实现响应式数据
- 基于多源观测数据的降水估计新方法
- 车道线识别之 tusimple 数据集介绍
- CentOS搭建基于ZIPKIN的数据追踪系统
- kafka批量删除topic_kafka清空数据
- R语言中基于混合数据抽样(MIDAS)回归的HAR-RV模型预测GDP增长|附代码数据
- AACL2022 | “讲好中国故事” ! 一种基于数据增强的中文故事生成框架(竟然有源码)
- 探索Wiredtiger引擎基于B-Tree数据写入分析
- 基于『成交数据』的股票联动研究
- 基于用户的协同过滤来构建推荐系统详解大数据
- django做form表单的数据验证详解编程语言
- MongoDB如何实现随机选取数据(mongodb随机)
- Redis:内存中的数据刷新不可或缺(redis 刷盘)
- 优化应用性能:使用Redis实现高效数据存储和缓存(appredis)
- Oracle共享集锁解决数据安全性的重要一步(oracle 共享集锁)
- 大容量数据存储基于Redis类技术(类redis大容量存储)
- 简单快速实现基于Redis的数据存储(简易redis实现)
- 基于App与MySQL交互实现数据读写示例(app读写mysql例子)
- 基于VC与Redis的高效数据序列化方案(vc redis序列化)
- 事件利用Redis集群掌控数据过期从监听到深度交互(redis集群监听过期)
- 使用redis队列lpush管理任务数据(redis队列lpush)