使用 PyTorch 动手实现 LSTM 模块
时间：2023-09-14 09:15:50
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter
from torch.nn import init
class NaiveLSTM(nn.Module):
    """Naive single-layer LSTM with explicit per-gate parameters.

    Implements the same gate equations as ``nn.LSTM`` (input gate ``i``,
    forget gate ``f``, candidate cell ``g``, output gate ``o``), but keeps a
    separate weight matrix and bias column for every gate instead of one
    fused ``4 * hidden_size`` tensor.

    Tensor layout used by :meth:`forward`:

    * ``inputs``: ``(batch, seq_len, input_size)`` -- time runs along dim 1.
    * ``state``: ``(h, c)``, each ``(1, batch, hidden_size)``.
    * returned sequence: ``(seq_len, batch, hidden_size)`` -- time along dim 0.
    """

    def __init__(self, input_size: int, hidden_size: int):
        """Create the per-gate parameters and initialise them.

        Args:
            input_size: number of features per input time step.
            hidden_size: size of the hidden and cell states.
        """
        super(NaiveLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Biases are (hidden_size, 1) columns so they broadcast over the
        # batch dimension of the column-major states used in forward().
        # input gate
        self.w_ii = Parameter(Tensor(hidden_size, input_size))
        self.w_hi = Parameter(Tensor(hidden_size, hidden_size))
        self.b_ii = Parameter(Tensor(hidden_size, 1))
        self.b_hi = Parameter(Tensor(hidden_size, 1))
        # forget gate
        self.w_if = Parameter(Tensor(hidden_size, input_size))
        self.w_hf = Parameter(Tensor(hidden_size, hidden_size))
        self.b_if = Parameter(Tensor(hidden_size, 1))
        self.b_hf = Parameter(Tensor(hidden_size, 1))
        # output gate
        self.w_io = Parameter(Tensor(hidden_size, input_size))
        self.w_ho = Parameter(Tensor(hidden_size, hidden_size))
        self.b_io = Parameter(Tensor(hidden_size, 1))
        self.b_ho = Parameter(Tensor(hidden_size, 1))
        # cell (candidate) gate
        self.w_ig = Parameter(Tensor(hidden_size, input_size))
        self.w_hg = Parameter(Tensor(hidden_size, hidden_size))
        self.b_ig = Parameter(Tensor(hidden_size, 1))
        self.b_hg = Parameter(Tensor(hidden_size, 1))
        # Tensor(...) allocates uninitialised memory; fill it right away.
        self.reset_weigths()

    def reset_weigths(self):
        """Re-initialise every parameter from U(-1/sqrt(H), 1/sqrt(H)).

        ``H`` is ``hidden_size``; this matches the initialisation documented
        for ``nn.LSTM``.
        """
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def forward(self, inputs: Tensor,
                state: Optional[Tuple[Tensor, Tensor]] = None) \
            -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """Run the LSTM over every time step of ``inputs``.

        Args:
            inputs: ``(batch, seq_len, input_size)``; time runs along dim 1.
            state: optional ``(h_0, c_0)``, each ``(1, batch, hidden_size)``.
                ``None`` starts from zero states with the dtype/device of
                ``inputs``.

        Returns:
            ``(output, (h_n, c_n))``. ``output`` is
            ``(seq_len, batch, hidden_size)`` and holds the hidden state
            after every step. ``h_n`` and ``c_n`` are the final hidden and
            cell states, each ``(1, batch, hidden_size)``.

        Raises:
            ValueError: if ``inputs`` contains no time steps.
        """
        batch_size, seq_size = inputs.size(0), inputs.size(1)
        if seq_size == 0:
            raise ValueError("inputs must contain at least one time step")

        # States are kept column-major, (hidden_size, batch), so that
        # ``W @ x`` lines up and the (hidden_size, 1) biases broadcast.
        if state is None:
            h_t = inputs.new_zeros(self.hidden_size, batch_size)
            c_t = inputs.new_zeros(self.hidden_size, batch_size)
        else:
            h, c = state
            h_t = h.squeeze(0).t()
            c_t = c.squeeze(0).t()

        hidden_seq = []
        for t in range(seq_size):
            x = inputs[:, t, :].t()  # (input_size, batch)
            # input gate
            i = torch.sigmoid(self.w_ii @ x + self.b_ii
                              + self.w_hi @ h_t + self.b_hi)
            # forget gate
            f = torch.sigmoid(self.w_if @ x + self.b_if
                              + self.w_hf @ h_t + self.b_hf)
            # candidate cell
            g = torch.tanh(self.w_ig @ x + self.b_ig
                           + self.w_hg @ h_t + self.b_hg)
            # output gate
            o = torch.sigmoid(self.w_io @ x + self.b_io
                              + self.w_ho @ h_t + self.b_ho)
            # Carry the state into the next step (the recurrence).
            c_t = f * c_t + i * g
            h_t = o * torch.tanh(c_t)
            # Back to (1, batch, hidden_size) for the stacked output.
            hidden_seq.append(h_t.t().unsqueeze(0))

        hidden_seq = torch.cat(hidden_seq, dim=0)
        h_n = h_t.t().unsqueeze(0)
        c_n = c_t.t().unsqueeze(0)
        return hidden_seq, (h_n, c_n)
def reset_weigths(model):
    """Overwrite every parameter of ``model`` in place with the value 0.5.

    Args:
        model: any ``nn.Module``; all tensors yielded by
            ``model.parameters()`` are filled.

    Returns:
        None. The module is modified in place.
    """
    fill_value = 0.5
    for param in model.parameters():
        init.constant_(param, fill_value)
# Smoke test: one NaiveLSTM step (input_size=10, hidden_size=20) on all-ones
# input and initial states, after every parameter is forced to 0.5.
inputs = torch.ones(1, 1, 10)
h0, c0 = torch.ones(1, 1, 20), torch.ones(1, 1, 20)
print(h0.shape, h0)
print(c0.shape, c0)
print(inputs.shape, inputs)

naive_lstm = NaiveLSTM(input_size=10, hidden_size=20)
reset_weigths(naive_lstm)
output1, (hn1, cn1) = naive_lstm(inputs, (h0, c0))

# Shapes first, then the final hidden state, final cell state and output.
print(hn1.shape, cn1.shape, output1.shape)
print(hn1)
print(cn1)
print(output1)
相关文章
- 如何在 React Native 中写一个自定义模块
- ansible使用script模块在受控机上执行脚本(ansible2.9.5)
- 原已经安装好的nginx,现在需要添加一个未被编译安装的模块--echo-nginx-module-0.56
- OpenCV每日函数 计算摄影模块(6) 非真实感渲染算法
- Atitit ES6 新特性、ES7、ES8 新特性（目录：1.1 ECMAScript 的历史；2 新特性；2.1 全面的 class 模型；2.2 模块 import、export…）
- 基于simulink的PN码伪码匹配的同步仿真,包括解调,伪码匹配,fft等模块
- spring boot & maven 多模块 ---心得
- 【CSS】浮动 ④ ( 浮动布局案例 - 电商布局模块 | 案例分析 | 布局测量摆放 | 浮动布局代码示例 )
- python——random模块
- VC++给软件添加异常捕获模块生成dump文件(附源码)
- PyTorch学习笔记(三):PyTorch主要组成模块
- 【PyTorch】 多进程队列中传入pytorch处理后的tensor张量
- 【PyTorch】numpy数组与pytorch的tensor相互转化
- Python 正则表达模块详解
- Linux启动流程与模块管理