TF之RNN:实现利用scope.reuse_variables()告诉TF想重复利用RNN的参数的案例
案例 实现 利用 参数 重复 告诉 TF variables
时间:2023-09-14 09:04:51
TF之RNN:实现利用scope.reuse_variables()告诉TF想重复利用RNN的参数的案例
目录
输出结果
后期更新……
代码设计
# __future__ imports must precede every other import/statement in the module.
from __future__ import print_function

# 22 scope (name_scope/variable_scope)
import tensorflow as tf
class TrainConfig:
    """Default hyper-parameters shared by the training and testing RNN graphs."""
    # optimizer step size
    learning_rate = 0.01
    # mini-batch geometry
    batch_size = 20
    time_steps = 20
    # layer widths
    input_size = 10
    cell_size = 11
    output_size = 2
class TestConfig(TrainConfig):
    """Same hyper-parameters as training, but unrolled a single step at a time."""
    # feed one time step per call during testing
    time_steps = 1
class RNN(object):
    """TF1-style RNN graph: input projection -> manually unrolled LSTM ->
    ReLU output layer, with an MSE cost and an Adam training op.

    All sizes come from the ``config`` object passed to ``__init__``.
    Constructing this class twice inside one ``tf.variable_scope`` — with
    ``scope.reuse_variables()`` called between the two constructions — makes
    the second graph share every weight with the first, which is the point
    of this example.
    """

    def __init__(self, config):
        # Hyper-parameters copied from the config (TrainConfig / TestConfig).
        self._batch_size = config.batch_size
        self._time_steps = config.time_steps
        self._input_size = config.input_size
        self._output_size = config.output_size
        self._cell_size = config.cell_size
        self._lr = config.learning_rate
        # The whole graph is built eagerly at construction time.
        self._built_RNN()

    def _built_RNN(self):
        """Assemble placeholders, the RNN unroll, output layer, cost and train op."""
        with tf.variable_scope('inputs'):
            # Placeholders shaped (batch, time, feature); values fed at run time.
            self._xs = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._input_size], name='xs')
            self._ys = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._output_size], name='ys')
        with tf.name_scope('RNN'):
            with tf.variable_scope('input_layer'):
                # Fold time into the batch axis so one matmul projects every step.
                l_in_x = tf.reshape(self._xs, [-1, self._input_size], name='2_2D')  # (batch*n_step, in_size)
                # Ws (in_size, cell_size)
                Wi = self._weight_variable([self._input_size, self._cell_size])
                # Prints the fully scoped variable name — handy to see reuse in action.
                print(Wi.name)
                # bs (cell_size, )
                bi = self._bias_variable([self._cell_size, ])
                # l_in_y = (batch * n_steps, cell_size)
                with tf.name_scope('Wx_plus_b'):
                    l_in_y = tf.matmul(l_in_x, Wi) + bi
                # Back to (batch, time, cell) for per-step slicing below.
                l_in_y = tf.reshape(l_in_y, [-1, self._time_steps, self._cell_size], name='2_3D')
            with tf.variable_scope('cell'):
                cell = tf.contrib.rnn.BasicLSTMCell(self._cell_size)
                with tf.name_scope('initial_state'):
                    # Zero LSTM state for a fixed-size batch.
                    self._cell_initial_state = cell.zero_state(self._batch_size, dtype=tf.float32)
                self.cell_outputs = []
                cell_state = self._cell_initial_state
                # Manual unroll: call the cell once per time step, threading state.
                for t in range(self._time_steps):
                    # After step 0 the LSTM kernel/bias already exist in this scope,
                    # so reuse must be switched on before calling the cell again.
                    if t > 0: tf.get_variable_scope().reuse_variables()
                    cell_output, cell_state = cell(l_in_y[:, t, :], cell_state)
                    self.cell_outputs.append(cell_output)
                self._cell_final_state = cell_state
            with tf.variable_scope('output_layer'):
                # cell_outputs_reshaped (BATCH*TIME_STEP, CELL_SIZE)
                cell_outputs_reshaped = tf.reshape(tf.concat(self.cell_outputs, 1), [-1, self._cell_size])
                Wo = self._weight_variable((self._cell_size, self._output_size))
                bo = self._bias_variable((self._output_size,))
                product = tf.matmul(cell_outputs_reshaped, Wo) + bo
                # _pred shape (batch*time_step, output_size)
                self._pred = tf.nn.relu(product)  # for displacement
        with tf.name_scope('cost'):
            _pred = tf.reshape(self._pred, [self._batch_size, self._time_steps, self._output_size])
            # NOTE(review): arguments arrive as (pred, labels) although ms_error is
            # declared (y_target, y_pre); harmless because squared error is
            # symmetric, but worth confirming if the metric ever changes.
            mse = self.ms_error(_pred, self._ys)
            # Mean over the batch axis, then sum across time.
            mse_ave_across_batch = tf.reduce_mean(mse, 0)
            mse_sum_across_time = tf.reduce_sum(mse_ave_across_batch, 0)
            self._cost = mse_sum_across_time
            self._cost_ave_time = self._cost / self._time_steps
        # NOTE(review): scope name 'trian' is a typo of 'train'; cosmetic only
        # (it just names the optimizer's variables), left as-is to preserve
        # checkpoint/graph compatibility.
        with tf.variable_scope('trian'):
            self._lr = tf.convert_to_tensor(self._lr)
            self.train_op = tf.train.AdamOptimizer(self._lr).minimize(self._cost)

    @staticmethod
    def ms_error(y_target, y_pre):
        # Element-wise squared error; no reduction here (done in the cost scope).
        return tf.square(tf.subtract(y_target, y_pre))

    @staticmethod
    def _weight_variable(shape, name='weights'):
        # tf.get_variable (rather than tf.Variable) is what makes the weights
        # shareable via scope.reuse_variables().
        initializer = tf.random_normal_initializer(mean=0., stddev=0.5, )
        return tf.get_variable(shape=shape, initializer=initializer, name=name)

    @staticmethod
    def _bias_variable(shape, name='biases'):
        # Small positive constant bias, a common default for ReLU outputs.
        initializer = tf.constant_initializer(0.1)
        return tf.get_variable(name=name, shape=shape, initializer=initializer)
if __name__ == '__main__':
    train_config = TrainConfig()  # define train_config
    test_config = TestConfig()
    # # the wrong method to reuse parameters in train rnn
    # # (separate scopes create two independent sets of variables)
    # with tf.variable_scope('train_rnn'):
    #     train_rnn1 = RNN(train_config)
    # with tf.variable_scope('test_rnn'):
    #     test_rnn1 = RNN(test_config)
    # the right method to reuse parameters in train rnn
    # Goal: let the training RNN create the parameters first, then share them
    # via variable_scope so the test RNN reuses exactly the same parameters.
    with tf.variable_scope('rnn') as scope:
        sess = tf.Session()
        train_rnn2 = RNN(train_config)
        scope.reuse_variables()  # tell TF we want to reuse the RNN's parameters
        test_rnn2 = RNN(test_config)
    # tf.initialize_all_variables() no long valid from
    # 2017-03-02 if using tensorflow >= 0.12
    if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
        init = tf.initialize_all_variables()
    else:
        init = tf.global_variables_initializer()
    sess.run(init)
相关文章
- python实现K近邻算法案例
- java局域网发送文件_Java如何实现局域网文件传输代码案例分享
- BootStrap案例
- Hadoop入门(八)——本地运行模式+完全分布模式案例详解,实现WordCount和集群分发脚本xsync快速配置环境变量 (图文详解步骤2021)[通俗易懂]
- php案例:用GD库生成单色图案
- php案例:解压一个压缩包中多个文件
- JavaScript案例:筋斗云
- 案例:实现简易的模板引擎
- 业务用例元模型-软件方法(下)第9章分析类图案例篇Part08
- jpa实现增删改查_hibernate入门案例
- VB6.0 支持鼠标滚轮教程的案例分享
- VB实现的倒计时类代码详解案例分享
- 【愚公系列】2022年12月 .NET CORE工具案例-滑块验证码和拼图验证功能实现
- PyQt5可视化 7 饼图和柱状图实操案例 ④层叠柱状图和百分比柱状图及饼图的实现【超详解图文教程】
- postgresql 实现得到时间对应周的周一案例
- MongoDB聚合分组取第一条记录的案例与实现方法
- Redis解决优惠券秒杀应用案例
- 基于Redis结合SpringBoot的秒杀案例详解
- Mysql中DATEDIFF函数的基础语法及练习案例
- LAMP实战案例:实现PowerDNS 应用部署
- SAMBA实战案例:利用SAMBA实现指定目录共享
- 实战案例:CentOS 7 实现基于cobbler实现自动化安装
- 探索Oracle的XML技术:应用与案例解析(oraclexml)
- Oracle 维度管理的实现方式和应用案例分析(oracle维度)
- 使用Oracle存储过程实现数据查询(oracle存储过程案例)
- 优化Oracle内存管理优化一个典型案例(oracle内存 典型)
- TP5实现Redis缓存技术案例研究(tp5redis案例)
- Redis实现超卖预防成功案例(redis防止超卖实例)
- Ajax案例集下载:新增分页查询案例(包括《Ajax开发精要》中的两个综合案例)下载
- Android列表实现(2)_游标列表案例讲解
- jqueryajax局部无刷新更新数据的实现案例
- PHP+memcache实现消息队列案例分享