简介
短文本语义匹配(SimilarityNet, SimNet)是一个计算短文本相似度的框架,可以根据用户输入的两个文本,计算出相似度得分。SimNet框架在百度各产品上广泛应用,主要包括BOW、CNN、RNN、MMDNN等核心网络结构形式,提供语义相似度计算训练和预测框架,适用于信息检索、新闻推荐、智能客服等多个应用场景,帮助企业解决语义匹配问题。模型结构如下:
SimNet讲解参考优秀博客https://www.jiqizhixin.com/articles/2017-06-15-5
下载安装命令
## CPU版本安装命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle
## GPU版本安装命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu
注意
本项目代码需要使用GPU环境来运行:
并且检查相关参数设置, 例如use_gpu, fluid.CUDAPlace(0)等处是否设置正确.
数据格式说明
训练模式一共分为pairwise
和pointwise
两种模式。
pairwise模式:
训练集格式如下: query \t pos_query \t neg_query。 query、pos_query和neg_query是以空格分词的中文文本,中间使用制表符'\t'隔开,pos_query表示与query相似的正例,neg_query表示与query不相似的随机负例,文本编码为utf-8。
现在 安卓模拟器 哪个 好 用 电脑 安卓模拟器 哪个 更好 电信 手机 可以 用 腾讯 大王 卡 吗 ?
土豆 一亩地 能 收 多少 斤 一亩 地土豆 产 多少 斤 一亩 地 用 多少 斤 土豆 种子
开发集和测试集格式:query1 \t query2 \t label。
query1和query2表示以空格分词的中文文本,label为0或1,1表示query1与query2相似,0表示query1与query2不相似,query1、query2和label中间以制表符'\t'隔开,文本编码为utf-8。
现在 安卓模拟器 哪个 好 用 电脑 安卓模拟器 哪个 更好 1
为什么 头发 掉 得 很厉害 我 头发 为什么 掉 得 厉害 1
常喝 薏米 水 有 副 作用 吗 女生 可以 长期 喝 薏米 水养生 么 0
长 的 清新 是 什么 意思 小 清新 的 意思 是 什么 0
pointwise模式:
训练集、开发集和测试集数据格式相同:query1和query2表示以空格分词的中文文本,label为0或1,1表示query1与query2相似,0表示query1与query2不相似,query1、query2和label中间以制表符'\t'隔开,文本编码为utf-8。
现在 安卓模拟器 哪个 好 用 电脑 安卓模拟器 哪个 更好 1
为什么 头发 掉 得 很厉害 我 头发 为什么 掉 得 厉害 1
常喝 薏米 水 有 副 作用 吗 女生 可以 长期 喝 薏米 水养生 么 0
长 的 清新 是 什么 意思 小 清新 的 意思 是 什么 0
infer数据集:
pairwise
和pointwise
的infer数据集格式相同:query1 \t query2。
query1和query2为以空格分词的中文文本。
怎么 调理 湿热 体质 ? 湿热 体质 怎样 调理 啊
搞笑 电影 美国 搞笑 的 美国 电影
注:本项目额外提供了分词预处理脚本(在preprocess目录下),可供用户使用,具体使用方法如下:
python tokenizer.py --test_data_dir ./test.txt.utf8 --batch_size 1 > test.txt.utf8.seg
其中test.txt.utf8为待分词的文件,一条文本数据一行,utf8编码,分词结果存放在test.txt.utf8.seg文件中
代码结构说明
.
├── run_classifier.py:该项目的主函数,封装包括训练、预测、评估的部分
├── config.py:定义该项目模型的配置类,读取具体模型类别、以及模型的超参数等
├── reader.py:定义了读入数据的相关函数
├── utils.py:定义了其他常用的功能函数
├── Config: 定义多种模型的配置文件
文件介绍
similarity_net: 存放短文本匹配的主要执行文件
similarity_net/data目录下存放训练集数据示例、集数据示例、测试集数据示例,以及对应词索引字典(term2id.dict)
similarity_net/model_files下存放训练好的模型数据,其中基于大规模数据训练好的pairwise模型(基于bow模型训练),保存在model_files/simnet_bow_pairwise_pretrained_model/下。训练好的pointwise模型保存在bow_pointwise/100下,其中100指的是训练步数,可修改。
models:共享的模型集合,本例的bow模型文件在models/matching/bow.py中
preprocess:共享的数据预处理流程
本例算法运行基于GPU,若采用CPU,请将run.sh文件中的参数use_cuda改为false
#可以通过以下语句查看数据集内容,此处仅展示部分
f = open("similarity_net/data/ecom","r")
i=1
for line in f:
i=i+1
print(line)
if i >10:
break
柏格曼 橱柜 有 吗 除了 橱柜 , 其他 可以 定制 吗 ? 0 皮 表带 断了 能 修 吗 ? 扣 表带 的 地方 断了 , 能 修 吗 ? 1 你好 、 请问 你们 有 实木 橱柜 门板 批发 吗 ? 除了 橱柜 , 其他 可以 定制 吗 ? 0 手表 不动 了 怎么 办 手表 不走 了 怎么 办 ? 1 个体 营业执照 代办 你们 这 可以 办理 营业执照 吗 1 我 的 驾照 2018年 到期 , 是 到 上海 去 换 还是 在 老家 换 可以 办 国内 驾照 吗 0 世界上 最 贵 的 手机 是 什么样 的 你们 公司 手机 贵 吗 1 厂房 出租 ? 你们 在 松江区 有 厂房 出租 吗 1 怎么 检验 二手车 是否 漏油 ? 怎样 买到 好 的 二手车 0 你好 我 想 承包 地食堂 食堂 托管 后 是 什么样 0
我们公开了自建的测试集,包括百度知道、ECOM、QQSIM、UNICOM四个数据集,基于上面的预训练模型,用户可以进入evaluate目录下依次执行下列命令获取测试集评估结果,本例展示的是基于大规模数据训练好的pairwise模型,若想测试pointwise模型,修改.sh文件中的INIT_CHECKPOINT到模型目录即可。
sh evaluate_ecom.sh
sh evaluate_qqsim.sh
sh evaluate_zhidao.sh
sh evaluate_unicom.sh
用户可以基于示例数据构建训练集和开发集,可以运行下面的命令,进行模型训练和开发集验证,关于训练的相关参数可在run.sh文件中修改。
!cd similarity_net && sh run.sh train
----------- Configuration Arguments ----------- batch_size: 128 compute_accuracy: False config_path: ./config/bow_pointwise.json do_infer: False do_test: True do_train: True do_valid: True enable_ce: False epoch: 120 infer_data_dir: ./data/infer_data infer_result_path: infer_result init_checkpoint: examples/cnn_pointwise.json lamda: 0.958 output_dir: ./model_files save_steps: 100 skip_steps: 10 task_mode: pointwise task_name: simnet test_data_dir: ./data/test_pointwise_data test_result_path: test_result train_data_dir: ./data/train_pointwise_data use_cuda: True valid_data_dir: ./data/test_pointwise_data validation_steps: 100 verbose_result: True vocab_path: ./data/term2id.dict ------------------------------------------------ W0828 17:01:55.135362 179 device_context.cc:259] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0 W0828 17:01:55.139626 179 device_context.cc:267] device: 0, cuDNN Version: 7.3. start train process ... epoch: 0, loss: 0.694876, used time: 0.018286 sec epoch: 1, loss: 0.666920, used time: 0.007737 sec epoch: 2, loss: 0.644924, used time: 0.007653 sec epoch: 3, loss: 0.621647, used time: 0.007453 sec epoch: 4, loss: 0.596388, used time: 0.007460 sec epoch: 5, loss: 0.568347, used time: 0.010349 sec epoch: 6, loss: 0.538345, used time: 0.007091 sec epoch: 7, loss: 0.506268, used time: 0.083083 sec epoch: 8, loss: 0.471838, used time: 0.007405 sec epoch: 9, loss: 0.435825, used time: 0.007289 sec epoch: 10, loss: 0.398896, used time: 0.006648 sec epoch: 11, loss: 0.361460, used time: 0.006303 sec epoch: 12, loss: 0.323817, used time: 0.006800 sec epoch: 13, loss: 0.286733, used time: 0.006457 sec epoch: 14, loss: 0.250810, used time: 0.007178 sec epoch: 15, loss: 0.216461, used time: 0.007067 sec epoch: 16, loss: 0.184460, used time: 0.006960 sec epoch: 17, loss: 0.155453, used time: 0.007412 sec epoch: 18, loss: 0.129435, used time: 0.007023 sec epoch: 19, loss: 0.106711, used time: 0.007316 sec epoch: 20, loss: 0.087089, used time: 0.006915 sec epoch: 21, loss: 0.070319, used time: 0.006860 sec epoch: 22, loss: 0.056391, used time: 0.007401 sec epoch: 23, loss: 0.045026, used time: 0.006812 sec epoch: 24, loss: 0.035867, used time: 0.006653 sec epoch: 25, loss: 0.028555, used time: 0.006753 sec epoch: 26, loss: 0.022770, used time: 0.006276 sec epoch: 27, loss: 0.018210, used time: 0.007046 sec epoch: 28, loss: 0.014633, used time: 0.006643 sec epoch: 29, loss: 0.011825, used time: 0.006803 sec epoch: 30, loss: 0.009617, used time: 0.007638 sec epoch: 31, loss: 0.007880, used time: 0.006894 sec epoch: 32, loss: 0.006511, used time: 0.007558 sec epoch: 33, loss: 0.005429, used time: 0.006743 sec epoch: 34, loss: 0.004567, used time: 0.006586 sec epoch: 35, loss: 0.003877, used time: 0.006981 sec epoch: 36, loss: 0.003320, used time: 0.006359 sec epoch: 37, loss: 0.002869, used time: 0.010105 sec epoch: 38, loss: 0.002502, used time: 0.007091 sec epoch: 39, loss: 0.002200, used time: 0.006626 sec epoch: 40, loss: 0.001950, used time: 0.007295 sec epoch: 41, loss: 0.001742, used time: 0.006636 sec epoch: 42, loss: 0.001567, used time: 0.006919 sec epoch: 43, loss: 0.001420, used time: 0.006564 sec epoch: 44, loss: 0.001295, used time: 0.006673 sec epoch: 45, loss: 0.001188, used time: 0.007324 sec epoch: 46, loss: 0.001097, used time: 0.006637 sec epoch: 47, loss: 0.001017, used time: 0.007091 sec epoch: 48, loss: 0.000949, used time: 0.006471 sec epoch: 49, loss: 0.000889, used time: 0.006439 sec epoch: 50, loss: 0.000837, used time: 0.006889 sec epoch: 51, loss: 0.000791, used time: 0.006620 sec epoch: 52, loss: 0.000750, used time: 0.007734 sec epoch: 53, loss: 0.000714, used time: 0.007153 sec epoch: 54, loss: 0.000682, used time: 0.006997 sec epoch: 55, loss: 0.000653, used time: 0.007598 sec epoch: 56, loss: 0.000628, used time: 0.006924 sec epoch: 57, loss: 0.000605, used time: 0.007502 sec epoch: 58, loss: 0.000584, used time: 0.007194 sec epoch: 59, loss: 0.000566, used time: 0.007589 sec epoch: 60, loss: 0.000549, used time: 0.008174 sec epoch: 61, loss: 0.000533, used time: 0.007172 sec epoch: 62, loss: 0.000519, used time: 0.007691 sec epoch: 63, loss: 0.000506, used time: 0.007020 sec epoch: 64, loss: 0.000494, used time: 0.007219 sec epoch: 65, loss: 0.000483, used time: 0.007499 sec epoch: 66, loss: 0.000473, used time: 0.006457 sec epoch: 67, loss: 0.000464, used time: 0.009499 sec epoch: 68, loss: 0.000455, used time: 0.010381 sec epoch: 69, loss: 0.000447, used time: 0.010359 sec epoch: 70, loss: 0.000439, used time: 0.011210 sec epoch: 71, loss: 0.000432, used time: 0.009208 sec epoch: 72, loss: 0.000426, used time: 0.007774 sec epoch: 73, loss: 0.000419, used time: 0.007008 sec epoch: 74, loss: 0.000413, used time: 0.006941 sec epoch: 75, loss: 0.000408, used time: 0.007444 sec epoch: 76, loss: 0.000402, used time: 0.006495 sec epoch: 77, loss: 0.000397, used time: 0.006951 sec epoch: 78, loss: 0.000392, used time: 0.006649 sec epoch: 79, loss: 0.000388, used time: 0.006763 sec epoch: 80, loss: 0.000383, used time: 0.007741 sec epoch: 81, loss: 0.000379, used time: 0.006991 sec epoch: 82, loss: 0.000375, used time: 0.007846 sec epoch: 83, loss: 0.000371, used time: 0.007176 sec epoch: 84, loss: 0.000367, used time: 0.006994 sec epoch: 85, loss: 0.000364, used time: 0.007500 sec epoch: 86, loss: 0.000360, used time: 0.006729 sec epoch: 87, loss: 0.000357, used time: 0.007332 sec epoch: 88, loss: 0.000353, used time: 0.006722 sec epoch: 89, loss: 0.000350, used time: 0.006866 sec epoch: 90, loss: 0.000347, used time: 0.007011 sec epoch: 91, loss: 0.000343, used time: 0.006456 sec epoch: 92, loss: 0.000340, used time: 0.007103 sec epoch: 93, loss: 0.000337, used time: 0.006939 sec epoch: 94, loss: 0.000334, used time: 0.006878 sec epoch: 95, loss: 0.000332, used time: 0.007788 sec epoch: 96, loss: 0.000329, used time: 0.007171 sec epoch: 97, loss: 0.000326, used time: 0.009529 sec epoch: 98, loss: 0.000323, used time: 0.006718 sec global_steps: 100, valid_auc: 1.000000 saving infer model in ./model_files/bow_pointwise/100 epoch: 99, loss: 0.000321, used time: 2.018070 sec epoch: 100, loss: 0.000318, used time: 0.009685 sec epoch: 101, loss: 0.000316, used time: 0.008802 sec epoch: 102, loss: 0.000313, used time: 0.007791 sec epoch: 103, loss: 0.000310, used time: 0.007369 sec epoch: 104, loss: 0.000308, used time: 0.007419 sec epoch: 105, loss: 0.000306, used time: 0.007456 sec epoch: 106, loss: 0.000303, used time: 0.007173 sec epoch: 107, loss: 0.000301, used time: 0.007799 sec epoch: 108, loss: 0.000298, used time: 0.007263 sec epoch: 109, loss: 0.000296, used time: 0.006812 sec epoch: 110, loss: 0.000294, used time: 0.007546 sec epoch: 111, loss: 0.000292, used time: 0.006687 sec epoch: 112, loss: 0.000289, used time: 0.006969 sec epoch: 113, loss: 0.000287, used time: 0.007132 sec epoch: 114, loss: 0.000285, used time: 0.007028 sec epoch: 115, loss: 0.000283, used time: 0.007616 sec epoch: 116, loss: 0.000281, used time: 0.006946 sec epoch: 117, loss: 0.000279, used time: 0.007951 sec epoch: 118, loss: 0.000276, used time: 0.007090 sec epoch: 119, loss: 0.000274, used time: 0.006912 sec AUC of test is 1.000000
对训练好的pointwise模型进行评估
!cd similarity_net && sh run.sh eval
----------- Configuration Arguments ----------- batch_size: 128 compute_accuracy: False config_path: ./config/bow_pointwise.json do_infer: False do_test: True do_train: False do_valid: False enable_ce: False epoch: 10 infer_data_dir: None infer_result_path: infer_result.txt init_checkpoint: ./model_files/bow_pointwise/100 lamda: 0.958 output_dir: None save_steps: 200 skip_steps: 10 task_mode: pointwise task_name: simnet test_data_dir: ./data/test_pointwise_data test_result_path: ./test_result.txt train_data_dir: None use_cuda: True valid_data_dir: None validation_steps: 100 verbose_result: True vocab_path: ./data/term2id.dict ------------------------------------------------ W0828 17:09:10.953100 404 device_context.cc:259] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0 W0828 17:09:10.956790 404 device_context.cc:267] device: 0, cuDNN Version: 7.3. start test process ... AUC of test is 1.000000 test result saved in /home/aistudio/similarity_net/./test_result.txt
基于已有的预训练模型,可以运行下面的命令进行推测,并保存推测结果到本地。
!cd similarity_net && sh run.sh infer
----------- Configuration Arguments ----------- batch_size: 128 compute_accuracy: False config_path: ./config/bow_pointwise.json do_infer: True do_test: False do_train: False do_valid: False enable_ce: False epoch: 10 infer_data_dir: ./data/infer_data infer_result_path: ./infer_result.txt init_checkpoint: ./model_files/bow_pointwise/100 lamda: 0.91 output_dir: None save_steps: 200 skip_steps: 10 task_mode: pointwise task_name: simnet test_data_dir: None test_result_path: test_result train_data_dir: None use_cuda: True valid_data_dir: None validation_steps: 100 verbose_result: True vocab_path: ./data/term2id.dict ------------------------------------------------ W0828 17:07:39.618010 359 device_context.cc:259] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0 W0828 17:07:39.622139 359 device_context.cc:267] device: 0, cuDNN Version: 7.3. start test process ... infer result saved in /home/aistudio/similarity_net/./infer_result.txt
进阶使用
如何组建自己的模型
用户可以根据自己的需求,组建自定义的模型,具体方法如下所示:
i. 定义自己的网络结构
用户可以在../models/matching
下定义自己的模型;
ii. 更改模型配置
用户仿照config
中的文件生成自定义模型的配置文件。
用户需要保留配置文件中的net
、loss
、optimizer
、task_mode
和model_path
字段。net
为用户自定义的模型参数,task_mode
表示训练模式,为pairwise
或pointwise
,要与训练命令中的--task_mode
命令保持一致,model_path
为模型保存路径,loss
和optimizer
依据自定义模型的需要仿照config
下的其他文件填写。
iii.模型训练,运行训练、评估、预测脚本即可(具体方法同上)。
点击链接,使用AI Studio一键上手实践项目吧:https://aistudio.baidu.com/aistudio/projectdetail/124373
下载安装命令
## CPU版本安装命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle
## GPU版本安装命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu
>> 访问 PaddlePaddle 官网,了解更多相关内容。