深度学习 对抗生成网络 使用生成对抗网络生成图片
2023-09-11 14:19:20 时间
这是最新找到的 对抗生成网络的代码,亲测可以跑通。前几天也上传了一个网上找到的代码,但是这回这个代码中判别网络的假数据中加入了 detach() 函数, 网上查找说这个函数可以切断神经网络的反向传导,虽然不是很理解,但总是感觉这个更对一些。对于 detach 这个函数在这里面的作用网上怎么说的都有,不过个人感觉最有说服力的说法是 减少没有必要的运算,毕竟在判别网络中我们是不需要修改生成网络的参数的,也就是说这个时候求解生成网络的梯度,对其进行反向求导是没有必要的,而这个说法和代码中的注释部分相合。
#encoding:UTF-8 #读入CIFAR-10数据 from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 import torchvision.transforms as transforms from torchvision.utils import save_image dataset = CIFAR10(root='./data', download=True, transform=transforms.ToTensor()) dataloader = DataLoader(dataset, batch_size=64, shuffle=True) for batch_idx, data in enumerate(dataloader): real_images, _ = data batch_size = real_images.size(0) print ('#{} has {} images.'.format(batch_idx, batch_size)) if batch_idx % 100 == 0: path = './data/CIFAR10_shuffled_batch{:03d}.png'.format(batch_idx) save_image(real_images, path, normalize=True) #搭建生成网络和鉴别网络 import torch.nn as nn # 搭建生成网络 latent_size = 64 # 潜在大小 n_channel = 3 # 输出通道数 n_g_feature = 64 # 生成网络隐藏层大小 gnet = nn.Sequential( # 输入大小 = (64, 1, 1) nn.ConvTranspose2d(latent_size, 4 * n_g_feature, kernel_size=4, bias=False), nn.BatchNorm2d(4 * n_g_feature), nn.ReLU(), # 大小 = (256, 4, 4) nn.ConvTranspose2d(4 * n_g_feature, 2 * n_g_feature, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(2 * n_g_feature), nn.ReLU(), # 大小 = (128, 8, 8) nn.ConvTranspose2d(2 * n_g_feature, n_g_feature, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(n_g_feature), nn.ReLU(), # 大小 = (64, 16, 16) nn.ConvTranspose2d(n_g_feature, n_channel, kernel_size=4, stride=2, padding=1), nn.Sigmoid(), # 图片大小 = (3, 32, 32) ) print (gnet) # 搭建鉴别网络 n_d_feature = 64 # 鉴别网络隐藏层大小 dnet = nn.Sequential( # 图片大小 = (3, 32, 32) nn.Conv2d(n_channel, n_d_feature, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2), # 大小 = (64, 16, 16) nn.Conv2d(n_d_feature, 2 * n_d_feature, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(2 * n_d_feature), nn.LeakyReLU(0.2), # 大小 = (128, 8, 8) nn.Conv2d(2 * n_d_feature, 4 * n_d_feature, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(4 * n_d_feature), nn.LeakyReLU(0.2), # 大小 = (256, 4, 4) nn.Conv2d(4 * n_d_feature, 1, kernel_size=4), # 对数赔率张量大小 = (1, 1, 1) ) print(dnet) import torch.nn.init as init #初始化权重值 def weights_init(m): # 用于初始化权重值的函数 if type(m) in [nn.ConvTranspose2d, nn.Conv2d]: init.xavier_normal_(m.weight) elif type(m) == nn.BatchNorm2d: init.normal_(m.weight, 1.0, 0.02) init.constant_(m.bias, 0) gnet.apply(weights_init) dnet.apply(weights_init) #主程序 import torch import torch.optim # 损失 criterion = nn.BCEWithLogitsLoss() # 优化器 goptimizer = torch.optim.Adam(gnet.parameters(), lr=0.0002, betas=(0.5, 0.999)) doptimizer = torch.optim.Adam(dnet.parameters(), lr=0.0002, betas=(0.5, 0.999)) # 用于测试的固定噪声,用来查看相同的潜在张量在训练过程中生成图片的变换 batch_size = 64 fixed_noises = torch.randn(batch_size, latent_size, 1, 1) # 训练过程 epoch_num = 10 for epoch in range(epoch_num): for batch_idx, data in enumerate(dataloader): # 载入本批次数据 real_images, _ = data batch_size = real_images.size(0) # 训练鉴别网络 labels = torch.ones(batch_size) # 真实数据对应标签为1 preds = dnet(real_images) # 对真实数据进行判别 outputs = preds.reshape(-1) dloss_real = criterion(outputs, labels) # 真实数据的鉴别器损失 dmean_real = outputs.sigmoid().mean() # 计算鉴别器将多少比例的真数据判定为真,仅用于输出显示 noises = torch.randn(batch_size, latent_size, 1, 1) # 潜在噪声 fake_images = gnet(noises) # 生成假数据 labels = torch.zeros(batch_size) # 假数据对应标签为0 fake = fake_images.detach() # 使得梯度的计算不回溯到生成网络,可用于加快训练速度.删去此步结果不变 preds = dnet(fake) # 对假数据进行鉴别 outputs = preds.view(-1) dloss_fake = criterion(outputs, labels) # 假数据的鉴别器损失 dmean_fake = outputs.sigmoid().mean() # 计算鉴别器将多少比例的假数据判定为真,仅用于输出显示 dloss = dloss_real + dloss_fake # 总的鉴别器损失 dnet.zero_grad() dloss.backward() doptimizer.step() # 训练生成网络 labels = torch.ones(batch_size) # 生成网络希望所有生成的数据都被认为是真数据 preds = dnet(fake_images) # 把假数据通过鉴别网络 outputs = preds.view(-1) gloss = criterion(outputs, labels) # 真数据看到的损失 gmean_fake = outputs.sigmoid().mean() # 计算鉴别器将多少比例的假数据判定为真,仅用于输出显示 gnet.zero_grad() gloss.backward() goptimizer.step() # 输出本步训练结果 print('[{}/{}]'.format(epoch, epoch_num) + '[{}/{}]'.format(batch_idx, len(dataloader)) + '鉴别网络损失:{:g} 生成网络损失:{:g}'.format(dloss, gloss) + '真数据判真比例:{:g} 假数据判真比例:{:g}/{:g}'.format( dmean_real, dmean_fake, gmean_fake)) if batch_idx % 100 == 0: fake = gnet(fixed_noises) # 由固定潜在张量生成假数据 save_image(fake, # 保存假数据 './data/images_epoch{:02d}_batch{:03d}.png'.format( epoch, batch_idx))
相关文章
- 常用网络名词mark & 网络学习笔记
- 学习Docker容器网络模型 - 搭建分布式Zookeeper集群
- 【跟唐老师学习云网络】 - 第7篇 Tcpdump大杀器抓包
- 《跟唐老师学习云网络》 - ip命令
- 《跟唐老师学习云网络》 - 什么是VLAN和VXLAN
- IOS开发之网络编程--文件压缩和解压缩
- 网络推送通知:及时,相关和准确 (navigator.serviceWorker.register(), window.PushManager, new Notification)
- 深度学习基础(常见的网络2)
- 【学习总结】网络-应用层-DNS解析概述
- 南京邮电大学网络攻防平台WEB题
- 深度学习Bible学习笔记:第六章 深度前馈网络
- 网络01:双无线路由器无缝对接设置
- 高性能网络服务器编程:为什么linux下epoll是最好,Netty要比NIO.2好?
- socket网络编程【python】
- kubernetes-v1.20.4 二进制部署-Calico网络组件、Dashboard和CoreDNS
- Centos7 k8s v1.5.2二进制部署安装-网络插件Flannel的安装
- Keras之TCN:基于keras框架利用时间卷积网络TCN算法对上海最高气温实现回归预测(把时间序列数据集转化为有监督学习数据集)案例
- DL之VGG16:基于VGG16(Keras)利用Knifey-Spoony数据集对网络架构进行迁移学习
- 鲲鹏性能优化十板斧(三)——网络子系统性能调优
- 网络基础——知识生活化会变得如此简单
- html css 网络 页面布局 颜色 参考 拾取器网址
- NASNet学习笔记—— 核心一:延续NAS论文的核心机制使得能够自动产生网络结构; 核心二:采用resnet和Inception重复使用block结构思想; 核心三:利用迁移学习将生成的网络迁移到大数据集上提出一个new search space。
- 【算力网络】算力网络的技术创新——算网一体关键技术
- Android的即时通讯(db文件无网络),建议肯定要学习
- WSL无法运行大型的深度学习?运行大型网络出错
- 【深度学习】语义分割实验:Unet网络/MSRC2数据集
- 物理层 一看就懂的网络传输介质介绍