情感分析系列(二)——使用BiLSTM进行情感分析
分析 系列 进行 情感 使用
2023-09-27 14:19:49 时间
一、数据预处理
先前我们已经进行了数据预处理:情感分析系列(一)——IMDb数据集及其预处理,这里不再过多介绍,本文将聚焦于模型的搭建与使用。
二、搭建BiLSTM
首先导入本文需要的所有包
import random
import torch
import torch.nn as nn
import numpy as np
from data_preprocess import load_imdb
from torch.utils.data import DataLoader
from torchtext.vocab import GloVe
我们使用双向LSTM来处理文本序列,将正向和反向LSTM在最后一个时间步上的隐状态连接在一起送进分类层(全连接层):
class BiLSTM(nn.Module):
def __init__(self, vocab, embed_size=100, hidden_size=256, num_layers=2, dropout=0.1, use_glove=False):
super().__init__()
self.embedding = nn.Embedding(len(vocab), embed_size, padding_idx=vocab['<pad>'])
self.rnn = nn.LSTM(embed_size, hidden_size, num_layers=num_layers, bidirectional=True, dropout=dropout)
self.fc = nn.Linear(2 * hidden_size, 2)
# xavier初始化参数
self._reset_parameters()
# 该参数决定是否使用预训练的词向量,如果使用,则将其冻结,训练期间不再更新
if use_glove:
glove = GloVe(name="6B", dim=100)
self.embedding = nn.Embedding.from_pretrained(glove.get_vecs_by_tokens(vocab.get_itos()),
padding_idx=vocab['<pad>'],
freeze=True)
def forward(self, x):
x = self.embedding(x).transpose(0, 1)
_, (h_n, _) = self.rnn(x)
output = self.fc(torch.cat((h_n[-1], h_n[-2]), dim=-1))
return output
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
至于为什么是连接 h_n[-1]
和 h_n[-2]
,可参考这篇文章:深入剖析多层双向LSTM的输入输出。
定义一个函数用来固定随机种子
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
三、训练&测试
因为是在3090-24GB上训练,所以这里batch size开到了512,此外学习率为0.001,训练50个epoch。
set_seed()
BATCH_SIZE = 512
LEARNING_RATE = 0.001
NUM_EPOCHS = 50
train_data, test_data, vocab = load_imdb()
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BiLSTM(vocab).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
for epoch in range(1, NUM_EPOCHS + 1):
print(f'Epoch {epoch}\n' + '-' * 32)
avg_train_loss = 0
for batch_idx, (X, y) in enumerate(train_loader):
X, y = X.to(device), y.to(device)
pred = model(X)
loss = criterion(pred, y)
avg_train_loss += loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (batch_idx + 1) % 5 == 0:
print(f"[{(batch_idx + 1) * BATCH_SIZE:>5}/{len(train_loader.dataset):>5}] train loss: {loss:.4f}")
print(f"Avg train loss: {avg_train_loss/(batch_idx + 1):.4f}\n")
acc = 0
for X, y in test_loader:
with torch.no_grad():
X, y = X.to(device), y.to(device)
pred = model(X)
acc += (pred.argmax(1) == y).sum().item()
print(f"Accuracy: {acc / len(test_loader.dataset):.4f}")
不使用预训练词向量的结果:
Accuracy: 0.7911
使用了GloVe词向量的结果:
Accuracy: 0.8761
🧑💻 创作不易,如需源码可前往 SA-IMDb 进行查看,下载时还请您随手给一个follow和star,谢谢!
相关文章
- Java入门系列之集合ArrayList源码分析
- 小甲鱼 OllyDbg 教程系列 (七) :VB 程序逆向分析
- JVM系列之:从汇编角度分析NullCheck
- 使用react全家桶制作博客后台管理系统 网站PWA升级 移动端常见问题处理 循序渐进学.Net Core Web Api开发系列【4】:前端访问WebApi [Abp 源码分析]四、模块配置 [Abp 源码分析]三、依赖注入
- (《机器学习》完整版系列)第10章 降维与度量学习——10.3 主成分分析的优化目标(坐标变换的魔力)
- 中国云计算市场分析:多云能力成抢滩关键
- 性能分析系列1:小命令保证大性能
- 人之将死其言也善?30年来死囚遗言分析
- RHCSA 系列(四): 编辑文本文件及分析文本
- 网络安全系列-三十六:使用Suricata IDS分析pcap文件
- 网络安全系列-三十五:公开的用于网络流量分析的pcap文件
- 22 面向对象编程 对象的创建分析 类与对象的关系 创建与初始化对象
- Adaboost算法原理分析和实例+代码(简明易懂)
- Java小白进阶系列——Java锁框架AQS源码分析目录大纲
- Android应用程序线程消息循环模型分析
- 申论存在问题分析
- Java集合迭代器 Iterator分析
- 【历史上的今天】1 月 10 日:算法分析之父出生;史上最失败的世纪并购;含冤 50 年的计算机先驱
- 相关分析和回归分析
- 从Line的招股书风险分析看看这家公司的现状怎么样
- PHP性能分析——xhprof(window 安装xhporf)
- 【Hadoop】Hadoop生态系列之InputForamt.class与OutputFormat.class分析
- springboot依赖管理和自动配置源码分析