zl程序教程

您现在的位置是：首页 > 后端

当前栏目

PyTorch 中各个优化函数的对比

PyTorch 函数 优化 对比 各个
2023-09-14 09:05:37 时间

import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Seed PyTorch's global RNG so every random draw below (the noise added to
# `y`, the DataLoader's shuffle order) repeats identically between runs.
torch.manual_seed(1)  # reproducible

# Hyper-parameters.
LR = 0.01  # learning rate
BATCH_SIZE = 32  # samples per mini-batch yielded by `loader`
# Presumably the number of training epochs; consumed by code outside this
# block -- confirm.
EPOCH = 12

# fake dataset
# 1000 evenly spaced points on [-1, 1]; unsqueeze(dim=1) reshapes the 1-D
# linspace into a (1000, 1) column. Targets follow y = x^2 plus Gaussian
# noise: torch.normal(mean) draws N(mean, 1), and the 0.1 factor scales the
# noise to a standard deviation of 0.1.
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))

# plot dataset
plt.scatter(x.numpy(), y.numpy())
plt.show()

# put dataset into torch dataset
torch_dataset = Data.TensorDataset(x, y)
# NOTE(review): num_workers > 0 starts worker processes. Under the "spawn"
# start method (the default on Windows and macOS) each worker re-imports the
# main module, so any module-level code that iterates `loader` must be
# wrapped in an `if __name__ == "__main__":` guard -- confirm for the
# training code that follows.
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# default network

class Net(torch.nn.Module):
def init(self):