zl程序教程

您现在的位置是:首页 >  其它

当前栏目

tf.tensor_scatter_nd_update

update TF Tensor scatter
2023-09-14 09:09:28 时间
对应位置的索引赋值
这里是一维坐标赋值
import tensorflow as tf

tensor = [0, 0, 0, 0, 0, 0, 0, 0]    # tf.rank(tensor) == 1
indices =[  [1],  [3],[4],    [7]]       # num_updates == 4, index_depth == 1
updates = [  9,   10, 11,     12]            # num_updates == 4
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor([ 0  9  0 10 11  0  0 12], shape=(8,), dtype=int32)
这里是二维坐标赋值
tensor = [[1, 1], [1, 1], [1, 1]]    # tf.rank(tensor) == 2
indices = [[0, 1], [2, 0]]           # num_updates == 2, index_depth == 2
updates = [5, 10]                    # num_updates == 2
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor(
[[ 1  5]
 [ 1  1]
 [10  1]], shape=(3, 2), dtype=int32)
行索引赋值整行
tensor = tf.zeros([6, 3], dtype=tf.int32)
indices = tf.constant([[2], [4]])     # num_updates == 2, index_depth == 1
# num_updates == 2, inner_shape==3
updates = tf.constant([[1, 2, 3],
                       [4, 5, 6]])
print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy())

[[0 0 0]
 [0 0 0]
 [1 2 3]
 [0 0 0]
 [4 5 6]
 [0 0 0]]