zl程序教程

您现在的位置是:首页 >  其它

当前栏目

TensorFlow2-高阶操作(二):张量分割【split(分割后:rank不变)】【unstack(分割后:rank-1)】

操作 分割 split 高阶 不变 RANK 张量 tensorflow2
2023-09-27 14:20:41 时间

一、split

split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name='split'
)

将张量分割成子张量.

  • 如果 num_or_size_splits 是整数类型,num_split,则 value 沿维度 axis 分割成为 num_split 更小的张量.要求 num_split 均匀分配 value.shape[axis]。
  • 如果 num_or_size_splits 不是整数类型,则它被认为是一个张量 size_splits,然后将 value 分割成 len(size_splits) 块.第 i 部分的形状与 value 的大小相同,除了沿维度 axis 之外的大小 size_splits[i]。
import pandas as pd
import tensorflow as tf

x = tf.Variable(tf.random.uniform([5, 30], -1, 1))
print("x = \n", pd.DataFrame(x.numpy()))
print("-" * 200)

# Split `x` into 3 tensors along dimension 1
s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
print("s0 = \n", pd.DataFrame(s0.numpy()))
print("-" * 50)
print("s1 = \n", pd.DataFrame(s1.numpy()))
print("-" * 50)
print("s2 = \n", pd.DataFrame(s2.numpy()))
print("-" * 200)

# Split `x` into 3 tensors with sizes [4, 15, 11] along dimension 1
t0, t1, t2 = tf.split(x, num_or_size_splits=[4, 15, 11], axis=1)
print("t0 = \n", pd.DataFrame(t0.numpy()))
print("-" * 50)
print("t1 = \n", pd.DataFrame(t1.numpy()))
print("-" * 50)
print("t2 = \n", pd.DataFrame(t2.numpy()))
print("-" * 200)

打印结果:

x = 
         0         1         2     ...           27        28        29
0 -0.888679  0.882839  0.739282    ...    -0.688343 -0.930151 -0.875597
1 -0.153850 -0.319729 -0.098402    ...     0.489693 -0.170844 -0.091632
2  0.003379  0.187339  0.795501    ...     0.379071 -0.256689  0.564788
3 -0.372030  0.340384 -0.875375    ...    -0.214336  0.717279  0.092451
4 -0.495783  0.257741 -0.358638    ...    -0.921029 -0.830439  0.507138

[5 rows x 30 columns]
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
s0 = 
           0         1         2    ...            7         8         9
0 -0.888679  0.882839  0.739282    ...    -0.403924  0.196670 -0.098327
1 -0.153850 -0.319729 -0.098402    ...     0.418904  0.081062  0.173876
2  0.003379  0.187339  0.795501    ...     0.615282 -0.385442 -0.311836
3 -0.372030  0.340384 -0.875375    ...    -0.252203 -0.587342  0.321012
4 -0.495783  0.257741 -0.358638    ...     0.552696  0.620588  0.132702

[5 rows x 10 columns]
--------------------------------------------------
s1 = 
           0         1         2    ...            7         8         9
0  0.509016  0.740289 -0.964265    ...     0.459772 -0.697755 -0.540041
1  0.904286  0.986134 -0.409174    ...     0.187198 -0.445747  0.813097
2 -0.137152  0.934053 -0.751823    ...     0.309953  0.716927  0.848913
3  0.096014  0.069597  0.777320    ...    -0.907295 -0.384888  0.764411
4 -0.706331 -0.901017 -0.529774    ...    -0.301620  0.066731  0.770751

[5 rows x 10 columns]
--------------------------------------------------
s2 = 
           0         1         2    ...            7         8         9
0 -0.356173 -0.040504  0.150185    ...    -0.688343 -0.930151 -0.875597
1 -0.436071 -0.224807  0.383009    ...     0.489693 -0.170844 -0.091632
2  0.169518  0.384529 -0.600068    ...     0.379071 -0.256689  0.564788
3  0.038849  0.754196 -0.049200    ...    -0.214336  0.717279  0.092451
4  0.245371 -0.548065  0.338353    ...    -0.921029 -0.830439  0.507138

[5 rows x 10 columns]
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
t0 = 
           0         1         2         3
0 -0.888679  0.882839  0.739282  0.454827
1 -0.153850 -0.319729 -0.098402 -0.764573
2  0.003379  0.187339  0.795501 -0.467434
3 -0.372030  0.340384 -0.875375  0.350312
4 -0.495783  0.257741 -0.358638  0.301579
--------------------------------------------------
t1 = 
          0         1         2     ...           12        13        14
0 -0.484341 -0.429574  0.999090    ...     0.634394  0.459772 -0.697755
1  0.325134 -0.227807 -0.890493    ...     0.152983  0.187198 -0.445747
2 -0.074674 -0.037023  0.830544    ...    -0.993245  0.309953  0.716927
3  0.044287  0.245083 -0.858829    ...    -0.583070 -0.907295 -0.384888
4 -0.105187  0.293733  0.783647    ...     0.397994 -0.301620  0.066731

[5 rows x 15 columns]
--------------------------------------------------
t2 = 
          0         1         2     ...           8         9         10
0 -0.540041 -0.356173 -0.040504    ...    -0.688343 -0.930151 -0.875597
1  0.813097 -0.436071 -0.224807    ...     0.489693 -0.170844 -0.091632
2  0.848913  0.169518  0.384529    ...     0.379071 -0.256689  0.564788
3  0.764411  0.038849  0.754196    ...    -0.214336  0.717279  0.092451
4  0.770751  0.245371 -0.548065    ...    -0.921029 -0.830439  0.507138

[5 rows x 11 columns]
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Process finished with exit code 0

二、unstack

将秩为 R 的张量的给定维度出栈为秩为 (R-1) 的张量.

通过沿 axis 维度将 num 张量从 value 中分离出来.如果没有指定 num(默认值),则从 value 的形状推断.如果 value.shape[axis] 不知道,则引发 ValueError.

例如,给定一个具有形状 (A, B, C, D) 的张量.

  • 如果 axis == 0,那么 output 中的第 i 个张量就是切片 value[i, :, :, :],并且 output 中的每个张量都具有形状 (B, C, D).(请注意,出栈的维度已经消失,不像split).
  • 如果 axis == 1,那么 output 中的第 i 个张量就是切片 value[:, i, :, :],并且 output 中的每个张量都具有形状 (A, C, D).
tf.unstack(value, num=None, axis=0, name='unstack')
  • value: A rank R > 0 Tensor to be unstacked.
  • num: An int. The length of the dimension axis. Automatically inferred if None (the default).
  • axis: An int. The axis to unstack along. Defaults to the first dimension. Negative - values: wrap around, so the valid range is [-R, R).
  • name: A name for the operation (optional).
import tensorflow as tf

x = tf.reshape(tf.range(12), (3, 4))
print("x = \n", x)
print("-" * 200)

p, q, r = tf.unstack(x)
print("p = ", p)
print("-" * 50)
print("q = ", q)
print("-" * 50)
print("r = ", r)
print("-" * 200)

打印结果:

x = 
 tf.Tensor(
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]], shape=(3, 4), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
p =  tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
--------------------------------------------------
q =  tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
--------------------------------------------------
r =  tf.Tensor([ 8  9 10 11], shape=(4,), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Process finished with exit code 0