f1 score 代码_在pytorch 中计算精度、回归率、F1 score等指标的实例「建议收藏」
pytorch中训练完网络后,需要对学习的结果进行测试。官网上例程用的方法统统都是正确率,使用的是torch.eq()这个函数。
但是为了更精细的评价结果,我们还需要计算其他各个指标。在把官网API翻了一遍之后发现并没有用于计算TP,TN,FP,FN的函数。。。
在动了无数歪脑筋之后,心想pytorch完全支持numpy,那能不能直接进行判断,试了一下果然可以,上代码:
# TP predict 和 label 同时为1
TP += ((pred_choice == 1) & (target.data == 1)).cpu().sum()
# TN predict 和 label 同时为0
TN += ((pred_choice == 0) & (target.data == 0)).cpu().sum()
# FN predict 0 label 1
FN += ((pred_choice == 0) & (target.data == 1)).cpu().sum()
# FP predict 1 label 0
FP += ((pred_choice == 1) & (target.data == 0)).cpu().sum()
p = TP / (TP + FP)
r = TP / (TP + FN)
F1 = 2 * r * p / (r + p)
acc = (TP + TN) / (TP + TN + FP + FN
这样就能看到各个指标了。
因为target是Variable所以需要用target.data取到对应的tensor,又因为是在gpu上算的,需要用 .cpu() 移到cpu上。
因为这是一个batch的统计,所以需要用+=累计出整个epoch的统计。当然,在epoch开始之前需要清零
以上这篇在pytorch 中计算精度、回归率、F1 score等指标的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/182193.html原文链接:https://javaforall.cn
相关文章
- Python装饰器工程实例及关键点总结
- 谷歌云T2A实例发布,TOP公有云都有ARM主机了
- LSTM应用场景以及pytorch实例
- 【Groovy】Groovy 扩展方法 ( 实例扩展方法配置 | 扩展方法示例 | 编译实例扩展类 | 打包实例扩展类字节码到 jar 包中 | 测试使用 Thread 实例扩展方法 )
- MySQL正则表达式regexp_replace函数的用法实例
- Oracle触发器实例代码
- jQuery Ajax 实例 ($.ajax、$.post、$.get)详解编程语言
- 实战Java搭配MySQL:从零开始的数据库操作实践(javamysql实例)
- 使用Redis单实例实现分布式锁
- Oracle如何关闭多余实例(oracle关闭多余实例)
- JQuery入门实例1
- JavaScript正则表达式验证中文实例讲解
- ASP.NET动态生成静态页面的实例代码
- java实现哈弗曼编码与反编码实例分享(哈弗曼算法)
- android自定义进度条渐变色View的实例代码
- php分页函数完整实例代码