zl程序教程

您现在的位置是:首页 >  数据库

当前栏目

torch 的 dim 和 numpy 的axis 表示方向不同

2023-04-18 12:33:03 时间

1. torch中以index_select为例子

torch.index_select(input, dim, index, out=None) - 功能:在维度dim上,按index索引数据 - 返回值:依index索引数据拼接的张量 - index:要索引的张量 - dim:要索引的维度 - index:要索引数据的序号

x = torch.randn(3, 4)
print(x)
indices = torch.tensor([0, 2])
torch.index_select(x, 1, indices)



#把1改为0
y = torch.randn(3, 4)
print(y)
indices = torch.tensor([0, 2])
torch.index_select(y, 0, indices)

输出如下,可以看出,dim=1时按照列索引;dim=0时,按照行索引

tensor([[ 1.9626,  0.1007, -1.2005,  1.2650],
        [ 0.3603,  0.6343, -0.6197,  0.5740],
        [-0.0798,  0.9674, -0.7761,  0.5552]])
tensor([[ 1.9626, -1.2005],
        [ 0.3603, -0.6197],
        [-0.0798, -0.7761]])


tensor([[ 0.2274, -2.1934, -0.3129,  0.3869],
        [ 0.3831, -0.7156, -1.0765, -2.1098],
        [-0.8007, -0.0095,  0.8703, -0.8797]])
tensor([[ 0.2274, -2.1934, -0.3129,  0.3869],
        [-0.8007, -0.0095,  0.8703, -0.8797]])

2.numpy 中 以mean为例

x = numpy.random.randint(1,10,(3,4))
print(x)
print(x.mean(0))


y = numpy.random.randint(1,10,(3,4))
print(y)
print(y.mean(1))

输出如下,axis = 0时,按照竖直方向从上往下计算均值,输出4个数;axis=1时,按照水平方向从左往右计算均值,输出三个数。

[[6 8 4 9]
 [7 5 9 3]
 [1 7 6 1]]
[4.66666667 6.66666667 6.33333333 4.33333333]


[[3 3 6 5]
 [4 3 1 5]
 [7 2 2 5]]
[4.25 3.25 4.  ]