Notes on a model debugging problem: TextLSTM/RNN fails to learn, loss and acc stay flat

Date: 2023-09-14 09:08:40

Problem

On the Tsinghua news classification dataset, TextCNN performs well, but TextLSTM/RNN does not learn: loss and acc do not change during training.

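Locating the problem

  1. The CNN improves during training, so the training code and the data are fine.
  2. Changing the RNN/LSTM architecture and adding a loss function made no difference.
  3. Changing lr, embed_dim, and num_laber (presumably num_layers) made no difference.
  4. Stepping through the code locally in a debugger showed that many positions in the input are padding zeros. The suspicion: short texts are padded so heavily that the model learns no features from them.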
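The padding suspicion in step 4 can be checked with a few lines before changing the model. The following is a minimal sketch, assuming that token id 0 is the padding id and that sequences are padded on the right; the function names and the sample batch are hypothetical and do not come from the original repository.

```python
def padding_ratio(batch, pad_id=0):
    """Return the fraction of positions in `batch` that equal `pad_id`."""
    total = sum(len(seq) for seq in batch)
    pads = sum(1 for seq in batch for tok in seq if tok == pad_id)
    return pads / total if total else 0.0


def trailing_pad_counts(batch, pad_id=0):
    """For each sequence, count the consecutive pad tokens at its end."""
    counts = []
    for seq in batch:
        n = 0
        for tok in reversed(seq):
            if tok != pad_id:
                break
            n += 1
        counts.append(n)
    return counts


# Hypothetical batch: token ids right-padded with 0 to length 8
batch = [
    [5, 9, 2, 0, 0, 0, 0, 0],
    [7, 3, 0, 0, 0, 0, 0, 0],
    [4, 8, 6, 1, 2, 0, 0, 0],
]
print(padding_ratio(batch))        # 14 of the 24 positions are padding
print(trailing_pad_counts(batch))  # [5, 6, 3]
```

The trailing count matters for an RNN/LSTM that classifies from its last hidden state: every trailing pad is one more step the recurrent state has to carry the real tokens through before the classifier sees it.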

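Fixes

  1. Change seq_length from 50 to 32: shorter inputs are padded with 0, longer ones are truncated. After retraining, the problem is resolved and the model starts to learn.
# After changing seq_length, the model improves slowly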
epoch:0 item:50  loss:2.3013758659362793 train_acc:0.109375  dev_acc:0.1474609375
epoch:0 item:100 loss:2.2781167030334473 train_acc:0.1953125 dev_acc:0.17578125
epoch:0 item:150 loss:2.205510377883911  train_acc:0.21875   dev_acc:0.1767578125
epoch:0 item:200 loss:2.183446168899536  train_acc:0.2421875 dev_acc:0.294921875
epoch:0 item:250 loss:2.144759178161621  train_acc:0.2265625 dev_acc:0.2470703125
epoch:0 item:300 loss:2.143526792526245  train_acc:0.265625  dev_acc:0.28515625
epoch:0 item:350 loss:2.078019857406616  train_acc:0.34375   dev_acc:0.279296875
epoch:0 item:400 loss:2.096219301223755  train_acc:0.296875  dev_acc:0.3203125
epoch:0 item:450 loss:2.0016613006591797 train_acc:0.3984375 dev_acc:0.3974609375
epoch:0 item:500 loss:1.9698306322097778 train_acc:0.390625  dev_acc:0.4150390625
epoch:0 item:550 loss:1.9621188640594482 train_acc:0.4453125 dev_acc:0.47265625
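The "pad short inputs with 0, truncate long ones" rule from fix 1 can be sketched as follows. This is an illustration, not the code from the repository: the function name `pad_or_truncate` is hypothetical, and right padding with id 0 is an assumption.

```python
def pad_or_truncate(token_ids, seq_length=32, pad_id=0):
    """Return exactly `seq_length` ids: cut long inputs, right-pad short ones."""
    clipped = list(token_ids[:seq_length])
    return clipped + [pad_id] * (seq_length - len(clipped))


# Short input: padded on the right up to seq_length
print(pad_or_truncate([5, 9, 2], seq_length=6))           # [5, 9, 2, 0, 0, 0]
# Long input: only the first seq_length tokens are kept
print(pad_or_truncate(list(range(1, 11)), seq_length=6))  # [1, 2, 3, 4, 5, 6]
```

A smaller seq_length reduces the number of padded steps for short texts, at the cost of discarding the tail of long ones.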
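  2. Pool the hidden states of all tokens with mean or max instead of using only the last hidden state. The model also learns this way, and the results are better than with the last hidden state.
# mean pooling: better than the last hidden state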
epoch:0 item:50  loss:1.9982243776321411 train_acc:0.4375    dev_acc:0.474609375
epoch:0 item:100 loss:1.7818747758865356 train_acc:0.6015625 dev_acc:0.6044921875
epoch:0 item:150 loss:1.762993574142456  train_acc:0.5625    dev_acc:0.625
epoch:0 item:200 loss:1.71768057346344   train_acc:0.6484375 dev_acc:0.673828125
epoch:0 item:250 loss:1.6551015377044678 train_acc:0.65625   dev_acc:0.64453125
epoch:0 item:300 loss:1.661691427230835  train_acc:0.6640625 dev_acc:0.640625
epoch:0 item:350 loss:1.6576321125030518 train_acc:0.65625   dev_acc:0.6865234375
# max pooling: works well, with a clear improvement by the second logged step (item 100)
epoch:0 item:50  loss:1.9646885395050049 train_acc:0.578125  dev_acc:0.5859375
epoch:0 item:100 loss:1.7058160305023193 train_acc:0.7890625 dev_acc:0.7041015625
epoch:0 item:150 loss:1.743579626083374  train_acc:0.6328125 dev_acc:0.7392578125
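The difference between taking the last hidden state and pooling over all tokens can be sketched in PyTorch as below. Treat it as a sketch under stated assumptions rather than the repository's implementation: the class name, the layer sizes, and the 10-class output are hypothetical (10 classes is at least consistent with the initial loss of about 2.30, which is close to ln 10), and masking out the padded positions before pooling is an addition made here, which the original may not do.

```python
import torch
import torch.nn as nn


class PooledLSTMClassifier(nn.Module):
    """LSTM text classifier with selectable pooling (hypothetical names and sizes)."""

    def __init__(self, vocab_size=1000, embed_dim=64, hidden_dim=128,
                 num_classes=10, pooling="max", pad_id=0):
        super().__init__()
        if pooling not in ("mean", "max", "last"):
            raise ValueError(f"unknown pooling: {pooling}")
        self.pooling = pooling
        self.pad_id = pad_id
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len) integer ids, right-padded with pad_id
        outputs, _ = self.lstm(self.embedding(token_ids))  # (batch, seq_len, hidden)
        mask = (token_ids != self.pad_id).unsqueeze(-1)    # (batch, seq_len, 1)
        if self.pooling == "mean":
            # Average over real tokens only (assumption: pads are masked out)
            summed = (outputs * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)
            pooled = summed / counts
        elif self.pooling == "max":
            # Padded positions are set to -inf so they never win the max
            pooled = outputs.masked_fill(~mask, float("-inf")).max(dim=1).values
            # A row with no real token would be all -inf; map it to zeros
            pooled = pooled.masked_fill(torch.isinf(pooled), 0.0)
        else:
            # "last": the hidden state at the final position, pads included
            pooled = outputs[:, -1, :]
        return self.fc(pooled)


torch.manual_seed(0)
model = PooledLSTMClassifier(pooling="max")
ids = torch.tensor([[5, 9, 2, 0, 0, 0], [7, 3, 4, 8, 0, 0]])
print(model(ids).shape)  # torch.Size([2, 10])
```

With a unidirectional LSTM and right padding, the masked mean and max depend only on the real tokens, so the logits do not change when more padding is appended. The "last" variant has no such guarantee, which fits the observation that shortening seq_length helped it.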

Code and dataset

https://github.com/haibincoder/NlpSummary/tree/master/torchcode/classification