用pytorch进行文本分类,数据集为keras内置的imdb影评数据(二分类),代码包含六个部分(详见代码)

使用环境:

pytorch:1.1.0

cuda:10.0

gpu:RTX2070

(1)导入相应的库、定义常量以及加载imdb数据

Pytorch文本分类(imdb数据集),含DataLoader数据加载,最优模型保存-LMLPHP

(2)使用DataLoader加载数据

Pytorch文本分类(imdb数据集),含DataLoader数据加载,最优模型保存-LMLPHP

(3)定义LSTM模型用于文本二分类

Pytorch文本分类(imdb数据集),含DataLoader数据加载,最优模型保存-LMLPHP

(4)定义训练函数和测试函数

Pytorch文本分类(imdb数据集),含DataLoader数据加载,最优模型保存-LMLPHP

(5)开始模型的训练(并保存最优模型权重),训练较快,2min左右

Pytorch文本分类(imdb数据集),含DataLoader数据加载,最优模型保存-LMLPHP

(6)加载模型权重并测试

Pytorch文本分类(imdb数据集),含DataLoader数据加载,最优模型保存-LMLPHP

05-11 13:56