zl程序教程

您现在的位置是:首页 >  其它

当前栏目

tensorflow 查看ckpt模型中参数值

查看 模型 Tensorflow 参数值
2023-09-11 14:17:14 时间

有时我们有查看tensor内部变量的值的变化情况,需要挖一下,现给出解析代码:

"""
@Date   :2021/5/18
@Author :xxx
"""
import os
from tensorflow.python import pywrap_tensorflow

base_model_dir = r'model-ckp-20210222'
target_model_dir = r'new_ckpt'
checkpoint_path_01 = os.path.join(base_model_dir, "model.ckpt-11383419")
checkpoint_path_02 = os.path.join(target_model_dir, "model.ckpt-94")
reader_01 = pywrap_tensorflow.NewCheckpointReader(checkpoint_path_01)
reader_02 = pywrap_tensorflow.NewCheckpointReader(checkpoint_path_02)
base_var_shape_map = reader_01.get_variable_to_shape_map()
target_shape_map = reader_02.get_variable_to_shape_map()
cnt = 0
tensor_name = 'output_weights'
print(reader_01.get_tensor(tensor_name))
print(reader_02.get_tensor(tensor_name))
for key1 in base_var_shape_map:
    if tensor_name in key1:
        print(reader_01.get_tensor(key1))
        print('tensor name:{}'.format(key1))

print('###############################')
for key1 in target_shape_map:
    if tensor_name in key1:
        print(reader_02.get_tensor(key1))
        print('tensor name:{}'.format(key1))

 

梯度没有更新

梯度更新:

 

 由上我们知道,可以通过此方式进行训练参数冻结(如bert 12 layer),fintuing 指定参数完成特定任务