zl程序教程

您现在的位置是:首页 >  云平台

当前栏目

torch.nn中的网络模型介绍

网络 介绍 模型 torch NN
2023-09-11 14:21:06 时间

一、torch.nn.Linear

torch.nn.Linear(in_features,out_features,bias=True)

nn.linear()是用来设置网络中的全连接层的,也可以说是线性映射,这里面没有激活函数。而在全连接层中的输入与输出都是二维张量,输入输出的形状为[batch_size, size]

import torch
from IPython.core.interactiveshell import  InteractiveShell
InteractiveShell.ast_node_interactivity='all'

#输入=【128,20】,128个样本,维度20
x = torch.randn(128, 20) 
#通过全连接做线性变换,从20维转化为30维度
m = torch.nn.Linear(20, 30)  # 20,30是指维度
#输出的size=【128,30】,128个样本,每个样本维度30
output = m(x)


output.shape
m.weight.shape
m.bias.shape