使用mlr3搞定二分类资料的多个模型评价和比较
前面介绍了使用tidymodels
进行二分类资料的模型评价和比较,不知道大家学会了没?
我之前详细介绍过mlr3
这个包,也是目前R语言机器学习领域比较火的R包了,今天说下这么用mlr3
进行二分类资料的模型评价和比较。
本期目录:
- 加载R包
- 建立任务
- 数据预处理
- 选择多个模型
- 建立benchmark_grid
- 开始计算
- 查看模型表现
- 结果可视化
- 选择最好的模型
加载R包
首先还是加载数据和R包,和之前的数据一样的。
library(mlr3verse)
## Loading required package: mlr3
library(mlr3pipelines)
library(mlr3filters)
建立任务
然后是对数据进行划分训练集和测试集,对数据进行预处理,为了和之前的tidymodels
进行比较,这里使用的数据和预处理步骤都是和之前一样的。
# 读取数据
all_plays <- readRDS("../000files/all_plays.rds")
# 建立任务
pbp_task <- as_task_classif(all_plays, target="play_type")
# 数据划分
split_task <- partition(pbp_task, ratio=0.75)
task_train <- pbp_task$clone()$filter(split_task$train)
task_test <- pbp_task$clone()$filter(split_task$test)
数据预处理
建立任务后就是建立数据预处理步骤,这里采用和上篇推文tidymodels
中一样的预处理步骤:
# 数据预处理
pbp_prep <- po("select", # 去掉3列
selector = selector_invert(
selector_name(c("half_seconds_remaining","yards_gained","game_id")))
) %>>%
po("colapply", # 把这两列变成因子类型
affect_columns = selector_name(c("posteam","defteam")),
applicator = as.factor) %>>%
po("filter", # 去除高度相关的列
filter = mlr3filters::flt("find_correlation"), filter.cutoff=0.3) %>>%
po("scale", scale = F) %>>% # 中心化
po("removeconstants") # 去掉零方差变量
可以看到mlr3
的数据预处理与tidymodels
相比,在语法上确实是有些复杂了,而且由于使用的R6
,很多语法看起来很别扭,文档也说的不清楚,对于新手来说还是tidymodels
更好些。目前来说最大的优势可能就是速度了吧。。。
如果你想把预处理步骤应用于数据,得到预处理之后的数据,可以用以下代码:
task_prep <- pbp_prep$clone()$train(pbp_task)[[1]]
dim(task_train$data())
## 68982 26
task_prep$feature_types
## id type
## 1: defteam factor
## 2: defteam_score numeric
## 3: defteam_timeouts_remaining factor
## 4: down ordered
## 5: goal_to_go factor
## 6: in_fg_range factor
## 7: in_red_zone factor
## 8: no_huddle factor
## 9: posteam factor
## 10: posteam_score numeric
## 11: posteam_timeouts_remaining factor
## 12: previous_play factor
## 13: qtr ordered
## 14: score_differential numeric
## 15: shotgun factor
## 16: total_pass numeric
## 17: two_min_drill factor
## 18: yardline_100 numeric
## 19: ydstogo numeric
这样就得到了处理好的数据,但是对于mlr3pipelines
来说,这一步做不做都可以。
选择多个模型
还是选择和之前一样的4个模型:逻辑回归、随机森林、决策树、k最近邻:
# 随机森林
rf_glr <- as_learner(pbp_prep %>>% lrn("classif.ranger", predict_type="prob"))
rf_glr$id <- "randomForest"
# 逻辑回归
log_glr <-as_learner(pbp_prep %>>% lrn("classif.log_reg", predict_type="prob"))
log_glr$id <- "logistic"
# 决策树
tree_glr <- as_learner(pbp_prep %>>% lrn("classif.rpart", predict_type="prob"))
tree_glr$id <- "decisionTree"
# k近邻
kknn_glr <- as_learner(pbp_prep %>>% lrn("classif.kknn", predict_type="prob"))
kknn_glr$id <- "kknn"
建立benchmark_grid
类似于tidymodels
中的workflow_set
。
接下来就是选择10折交叉验证,建立多个模型,语法也是很简单了。
set.seed(0520)
# 10折交叉验证
cv <- rsmp("cv",folds=10)
set.seed(0520)
# 建立多个模型
design <- benchmark_grid(
tasks = task_train,
learners = list(rf_glr,log_glr,tree_glr,kknn_glr),
resampling = cv
)
在训练集中,使用10折交叉验证,运行4个模型,看这语法是不是也很简单清晰?
开始计算
下面就是开始计算,和tidymodels
相比,这一块语法更加简单一点,就是建立benchmark_grid
,然后使用benchmark()
函数即可。
# 加速
library(future)
plan("multisession",workers=12)
# 减少屏幕输出
lgr::get_logger("mlr3")$set_threshold("warn")
lgr::get_logger("bbotk")$set_threshold("warn")
# 开始运行
bmr <- benchmark(design,store_models = T)
Growing trees.. Progress: 29%. Estimated remaining time: 1 minute, 14 seconds.
Growing trees.. Progress: 61%. Estimated remaining time: 39 seconds.
Growing trees.. Progress: 92%. Estimated remaining time: 8 seconds.
Growing trees.. Progress: 29%. Estimated remaining time: 1 minute, 16 seconds.
Growing trees.. Progress: 60%. Estimated remaining time: 40 seconds.
Growing trees.. Progress: 91%. Estimated remaining time: 8 seconds.
Growing trees.. Progress: 43%. Estimated remaining time: 40 seconds.
Growing trees.. Progress: 83%. Estimated remaining time: 12 seconds.
Growing trees.. Progress: 42%. Estimated remaining time: 42 seconds.
Growing trees.. Progress: 90%. Estimated remaining time: 7 seconds.
Growing trees.. Progress: 30%. Estimated remaining time: 1 minute, 10 seconds.
Growing trees.. Progress: 62%. Estimated remaining time: 38 seconds.
Growing trees.. Progress: 93%. Estimated remaining time: 7 seconds.
Growing trees.. Progress: 30%. Estimated remaining time: 1 minute, 10 seconds.
Growing trees.. Progress: 61%. Estimated remaining time: 38 seconds.
Growing trees.. Progress: 92%. Estimated remaining time: 7 seconds.
Growing trees.. Progress: 29%. Estimated remaining time: 1 minute, 15 seconds.
Growing trees.. Progress: 60%. Estimated remaining time: 41 seconds.
Growing trees.. Progress: 91%. Estimated remaining time: 9 seconds.
Growing trees.. Progress: 32%. Estimated remaining time: 1 minute, 7 seconds.
Growing trees.. Progress: 73%. Estimated remaining time: 22 seconds.
Growing trees.. Progress: 42%. Estimated remaining time: 42 seconds.
Growing trees.. Progress: 84%. Estimated remaining time: 11 seconds.
Growing trees.. Progress: 32%. Estimated remaining time: 1 minute, 7 seconds.
Growing trees.. Progress: 63%. Estimated remaining time: 36 seconds.
Growing trees.. Progress: 94%. Estimated remaining time: 6 seconds.
# 结果
bmr
<BenchmarkResult> of 40 rows with 4 resampling runs
nr task_id learner_id resampling_id iters warnings errors
1 all_plays randomForest cv 10 0 0
2 all_plays logistic cv 10 0 0
3 all_plays decisionTree cv 10 0 0
4 all_plays kknn cv 10 0 0
查看模型表现
查看结果:
# 默认结果
bmr$aggregate()
nr resample_result task_id learner_id resampling_id iters classif.ce
1: 1 <ResampleResult[22]> all_plays randomForest cv 10 0.2695630
2: 2 <ResampleResult[22]> all_plays logistic cv 10 0.2770287
3: 3 <ResampleResult[22]> all_plays decisionTree cv 10 0.2799570
4: 4 <ResampleResult[22]> all_plays kknn cv 10 0.3220549
也是支持同时查看多个结果的:
measures <- msrs(c("classif.auc","classif.acc","classif.bbrier"))
bmr_res <- bmr$aggregate(measures)
bmr_res[,c(4,7:9)]
learner_id classif.auc classif.acc classif.bbrier
1: randomForest 0.7978436 0.7304370 0.1790968
2: logistic 0.7798504 0.7229713 0.1866577
3: decisionTree 0.7034790 0.7200430 0.2003303
4: kknn 0.7322762 0.6779451 0.2210171
结果可视化
支持ggplot2
语法,使用起来和tidymodels
差不多,也是对结果直接autoplot()
即可。
library(ggplot2)
autoplot(bmr)+theme(axis.text.x = element_text(angle = 45))
喜闻乐见的ROC曲线:
autoplot(bmr,type = "roc")
选择最好的模型
通过比较结果可以发现还是随机森林效果最好~,下面选择随机森林,在训练集上训练,在测试集上测试结果。
这一步并没有使用10折交叉验证,如果你想用,也是可以的~
# 训练
rf_glr$train(task_train)
训练好之后就是在测试集上测试并查看结果:
# 测试
prediction <- rf_glr$predict(task_test)
head(as.data.table(prediction))
row_ids truth response prob.pass prob.run
1: 4 run pass 0.7649998 0.23500021
2: 6 run run 0.4168520 0.58314804
3: 11 pass pass 0.7199717 0.28002834
4: 13 run pass 0.9406333 0.05936668
5: 17 run run 0.4073665 0.59263354
6: 24 pass pass 0.6243693 0.37563072
混淆矩阵:
prediction$confusion
truth
response pass run
pass 10629 3175
run 2955 6235
可视化混淆矩阵:
autoplot(prediction)
当然也是支持多个指标的:
prediction$score(msrs(c("classif.auc","classif.acc","classif.bbrier")))
classif.auc classif.acc classif.bbrier
0.8011720 0.7334087 0.1775684
喜闻乐见ROC曲线:
autoplot(prediction,type = "roc")
image-20220704162604466
总体来看mlr3
和tidymodels
相比有优势也有劣势,基本步骤大同小异,除了预处理步骤比较复杂外,其他地方都比较简单~
初学者还是推荐使用tidymodels
,熟悉了可以试一下mlr3
,集成化程度更高,目前也更加稳定,tidymodels
目前还处于快速开发中,经常出现各种小问题,但是说明文档比较详细。
mlr3
相比之下更稳定一些,速度明显更快!尤其是数据量比较大的时候!但是mlr3
的说明文档并不是很详细,只有mlr3 book
,而且很多用法并没有介绍!经常得自己琢磨。
mlr3 book中文翻译版 可以翻看我之前的推文!
相关文章
- 语言模型如何产品落地?《GPT-3:使用大型语言模型构建创新的NLP产品》新书带你实操
- Java面向对象之创建和使用对象——定义学生/教师类并输出相关信息
- 使用预训练模型,在Jetson NANO上预测公交车到站时间
- 如何使用 SAP UI5 V2 ODataModel 模型 API 实现 deepCreate 的场景以及局限性
- 使用阈值调优改进分类模型性能
- 使用阈值调优改进分类模型性能
- 使用workflow一次完成多个模型的评价和比较
- 使用 TVMC 编译和优化模型
- 使用Visual Python自动生成代码
- 使用腾讯云IM搭建应用内类微信社交聊天模块实践
- 在 PyTorch 中使用梯度检查点在GPU 上训练更大的模型
- ThinkPHP6.0 模型搜索器的使用
- Python使用GARCH,EGARCH,GJR-GARCH模型和蒙特卡洛模拟进行股价预测|附代码数据
- Linux中的进程守护supervisor安装配置及使用
- iseMySQL模型设计:使用最新技术高效实现(mysqlmdl)
- Oracle系统中使用fopen函数(oraclefopen)
- 使用Oracle优化表空间大小(oracle表空间太大)
- MySQL循环变量的使用技巧(mysql循环变量)
- 使用 Cockpit 管理你的树莓派
- 保护你的数据安全:使用Linux系统加密狗(linux系统加密狗)
- Linux下使用Yum:最简便的软件安装方式(linux下使用yum)
- 深入MySQL学习avg函数的使用方法(mysql中avg 用法)
- 使用Redis高级工具实现数据持久化(redis高级工具)
- 方法如何优雅地使用Redis分布式锁(redis锁的最佳使用)
- php教程之phpize使用方法
- MongoDB中MapReduce编程模型使用实例
- 使用远程桌面连接Windows2003&2008服务器详细图文教程