我有一个进行多标签文本分类的机器学习模型。我有一个预测对象,可以成功预测用作输入的文本字符串的分类。它将其预测分配给单个预测作为如下列表:

[('unrelated', 0.9684208035469055), ('curated', 0.02895800955593586)]


我觉得这可能很简单,但是本质上我只需要
为策划的比赛创建一个阈值。

因此,如果策划的置信度高于.90或类似的水平,我可以打印一条声明。

但是,我不知道如何指定此条件。

这是一个列表对象,因此我尝试指定索引。但是,每个索引都输出两个['label', confidence]。此外,索引的顺序取决于置信度。它始终始终首先显示最高级别的置信度标签。因此,指定索引号将无济于事,因为它会更改。

single_prediction = predictor.predict(result)
df.at[0,'prediction'] = single_prediction
if single_prediction[0] >= .95:
    print('this is a match')
print(single_prediction)

最佳答案

您可以使用列表推导来实现:

results = [ [('curated', 0.6), ('unrelated', 0.4)],
           [('unrelated', 0.55), ('curated', 0.45)],
          [('unrelated', 0.7), ('curated', 0.3)]]

threshold = 0.4
for result in results:
    if [x[1] for x in result if x[0] == 'curated'][0] > threshold:
        print(result)




输出:

[('curated', 0.6), ('unrelated', 0.4)]
[('unrelated', 0.55), ('curated', 0.45)]

关于python - 如何为输出预测设置条件阈值?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/56840085/

10-14 22:07