zl程序教程 (zl Programming Tutorials)

torch.spmm Matrix Multiplication

Time: 2023-02-18 16:32:57

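Example: torch.spmm(s, d) multiplies the sparse matrix s by the dense matrix d and returns a dense tensor. Here s is a 2×2 sparse COO matrix whose only nonzeros are 2 at (0, 0) and 3 at (1, 1).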

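import torch

# COO indices: first row holds the row indices, second row the column indices,
# so the nonzeros sit at (0, 0) and (1, 1)
indices = torch.tensor([[0, 1],
                        [0, 1]])
values = torch.tensor([2, 3])   # values stored at (0, 0) and (1, 1)
shape = torch.Size((2, 2))
# torch.sparse.FloatTensor is a deprecated legacy constructor;
# torch.sparse_coo_tensor is the current way to build a COO tensor
s = torch.sparse_coo_tensor(indices, values, shape)
print(s)

d = torch.tensor([[1, 2],
                  [3, 4]])   # dense right-hand matrix

print(d)
print(torch.spmm(s, d))   # sparse @ dense -> dense result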
"""
tensor(indices=tensor([[0, 1],
                       [0, 1]]),
       values=tensor([2, 3]),
       size=(2, 2), nnz=2, layout=torch.sparse_coo)
tensor([[1, 2],
        [3, 4]])
tensor([[ 2,  4],
        [ 9, 12]])
"""