zl程序教程

您现在的位置是:首页 >  后端

当前栏目

tensorflow 动态数组 TensorArray

数组 动态 Tensorflow
2023-09-14 09:09:28 时间
TensorFlow 动态数组（TensorArray）在创建时设置 `clear_after_read=False`，已写入的元素就可以随时重复读取：

import tensorflow as tf
# A float32 TensorArray that starts empty and grows as indices are written
# (dynamic_size=True). clear_after_read=False keeps each element after it is
# read, so the same slot can be read again and the array stacked afterwards.
ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=False)

# write() returns an updated TensorArray, so rebind `ta` after every write.
for idx, value in enumerate((10, 20, 30)):
  ta = ta.write(idx, value)

# Read every slot individually...
for idx in range(3):
  print(ta.read(idx))

# ...and then all of them at once as a single 1-D tensor.
print(ta.stack())
tf.Tensor(10.0, shape=(), dtype=float32)
tf.Tensor(20.0, shape=(), dtype=float32)
tf.Tensor(30.0, shape=(), dtype=float32)
tf.Tensor([10. 20. 30.], shape=(3,), dtype=float32)
@tf.function
def fibonacci(n):
  """Return the first ``n`` Fibonacci numbers as a float32 tensor.

  The array is seeded with [0., 1.] via ``unstack``. Each later element is
  the sum of the two elements before it. The loop uses Python ``range``, so
  for a Python int ``n`` it is unrolled when the function is traced.

  Args:
    n: number of elements to produce (Python int). For ``n < 2`` the loop
      does not run and the two seed elements are returned.

  Returns:
    A 1-D float32 tensor of length ``max(n, 2)``.
  """
  ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  ta = ta.unstack([0., 1.])

  for i in range(2, n):
    # write() returns an updated TensorArray; rebind to keep the result.
    ta = ta.write(i, ta.read(i - 1) + ta.read(i - 2))

  return ta.stack()

fibonacci(7)
# <tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 1., 1., 2., 3., 5., 8.], dtype=float32)>
v = tf.Variable(1)

@tf.function
def f(x):
  """Add 0..x-1 into ``v`` step by step and record ``v`` after each step.

  Each iteration adds the loop counter to the module-level variable ``v``
  and writes v's current value into a growing int32 TensorArray. The result
  stacks one entry per iteration.
  """
  acc = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  for step in tf.range(x):
    # Side effect: v is module-level, so its value carries over between calls.
    v.assign_add(step)
    acc = acc.write(step, v)
  return acc.stack()

f(5)

<tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1,  2,  4,  7, 11], dtype=int32)>