使用workflow一次完成多个模型的评价和比较
使用 模型 比较 多个 完成 一次 评价 workflow
2023-06-13 09:15:10 时间
前面给大家介绍了使用tidymodels
搞定二分类资料的模型评价和比较。
简介的语法、统一的格式、优雅的操作,让人欲罢不能!
但是太费事儿了,同样的流程来了4遍,那要是选择10个模型,就得来10遍!无聊,非常的无聊。
所以个大家介绍简便方法,不用重复写代码,一次搞定多个模型!
本期目录:
- 加载数据和R包
- 数据预处理
- 选择模型
- 选择重抽样方法
- 构建workflow
- 运行模型
- 查看结果
- 可视化结果
- 选择最好的模型用于测试集
加载数据和R包
首先还是加载数据和R包,和前面的一模一样的操作,数据也没变。
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(tidymodels))
library(kknn)
tidymodels_prefer()
all_plays <- read_rds("../000files/all_plays.rds")
set.seed(20220520)
split_pbp <- initial_split(all_plays, 0.75, strata = play_type)
train_data <- training(split_pbp)
test_data <- testing(split_pbp)
数据预处理
pbp_rec <- recipe(play_type ~ ., data = train_data) %>%
step_rm(half_seconds_remaining,yards_gained, game_id) %>%
step_string2factor(posteam, defteam) %>%
step_corr(all_numeric(), threshold = 0.7) %>%
step_center(all_numeric()) %>%
step_zv(all_predictors())
选择模型
直接选择4个模型,你想选几个都是可以的。
lm_mod <- logistic_reg(mode = "classification",engine = "glm")
knn_mod <- nearest_neighbor(mode = "classification", engine = "kknn")
rf_mod <- rand_forest(mode = "classification", engine = "ranger")
tree_mod <- decision_tree(mode = "classification",engine = "rpart")
选择重抽样方法
set.seed(20220520)
folds <- vfold_cv(train_data, v = 10)
folds
## # 10-fold cross-validation
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [62082/6899]> Fold01
## 2 <split [62083/6898]> Fold02
## 3 <split [62083/6898]> Fold03
## 4 <split [62083/6898]> Fold04
## 5 <split [62083/6898]> Fold05
## 6 <split [62083/6898]> Fold06
## 7 <split [62083/6898]> Fold07
## 8 <split [62083/6898]> Fold08
## 9 <split [62083/6898]> Fold09
## 10 <split [62083/6898]> Fold10
构建workflow
这一步就是不用重复写代码的关键,把所有模型和数据预处理步骤自动连接起来。
library(workflowsets)
four_mods <- workflow_set(list(rec = pbp_rec),
list(lm = lm_mod,
knn = knn_mod,
rf = rf_mod,
tree = tree_mod
),
cross = T
)
four_mods
## # A workflow set/tibble: 4 × 4
## wflow_id info option result
## <chr> <list> <list> <list>
## 1 rec_lm <tibble [1 × 4]> <opts[0]> <list [0]>
## 2 rec_knn <tibble [1 × 4]> <opts[0]> <list [0]>
## 3 rec_rf <tibble [1 × 4]> <opts[0]> <list [0]>
## 4 rec_tree <tibble [1 × 4]> <opts[0]> <list [0]>
运行模型
首先是一些运行过程中的参数设置:
keep_pred <- control_resamples(save_pred = T, verbose = T)
然后就是运行4个模型(目前一直是在训练集中),我们给它加速一下:
library(doParallel)
## Loading required package: foreach
##
## Attaching package: 'foreach'
## The following objects are masked from 'package:purrr':
##
## accumulate, when
## Loading required package: iterators
## Loading required package: parallel
cl <- makePSOCKcluster(12) # 加速,用12个线程
registerDoParallel(cl)
four_fits <- four_mods %>%
workflow_map("fit_resamples",
seed = 0520,
verbose = T,
resamples = folds,
control = keep_pred
)
## i 1 of 4 resampling: rec_lm
## ✔ 1 of 4 resampling: rec_lm (18.4s)
## i 2 of 4 resampling: rec_knn
## ✔ 2 of 4 resampling: rec_knn (3m 51.9s)
## i 3 of 4 resampling: rec_rf
## ✔ 3 of 4 resampling: rec_rf (1m 15.6s)
## i 4 of 4 resampling: rec_tree
## ✔ 4 of 4 resampling: rec_tree (6.1s)
four_fits
## # A workflow set/tibble: 4 × 4
## wflow_id info option result
## <chr> <list> <list> <list>
## 1 rec_lm <tibble [1 × 4]> <opts[2]> <rsmp[+]>
## 2 rec_knn <tibble [1 × 4]> <opts[2]> <rsmp[+]>
## 3 rec_rf <tibble [1 × 4]> <opts[2]> <rsmp[+]>
## 4 rec_tree <tibble [1 × 4]> <opts[2]> <rsmp[+]>
stopCluster(cl)
需要很长时间!大家笔记本如果内存不够可能会失败哦~
查看结果
查看模型在训练集中的表现:
collect_metrics(four_fits)
## # A tibble: 8 × 9
## wflow_id .config preproc model .metric .estimator mean n std_err
## <chr> <chr> <chr> <chr> <chr> <chr> <dbl> <int> <dbl>
## 1 rec_lm Preprocessor1_M… recipe logi… accura… binary 0.724 10 1.91e-3
## 2 rec_lm Preprocessor1_M… recipe logi… roc_auc binary 0.781 10 1.88e-3
## 3 rec_knn Preprocessor1_M… recipe near… accura… binary 0.671 10 7.31e-4
## 4 rec_knn Preprocessor1_M… recipe near… roc_auc binary 0.716 10 1.28e-3
## 5 rec_rf Preprocessor1_M… recipe rand… accura… binary 0.732 10 1.48e-3
## 6 rec_rf Preprocessor1_M… recipe rand… roc_auc binary 0.799 10 1.90e-3
## 7 rec_tree Preprocessor1_M… recipe deci… accura… binary 0.720 10 1.97e-3
## 8 rec_tree Preprocessor1_M… recipe deci… roc_auc binary 0.704 10 2.01e-3
查看每一个预测结果,这个就不运行了,毕竟好几万行,太多了。。。
collect_predictions(four_fits)
可视化结果
直接可视化4个模型的结果,感觉比ROC曲线更好看,还给出了可信区间。
这个图可以自己用ggplot2
语法修改。
four_fits %>% autoplot(metric = "roc_auc")+theme_bw()
image-20220704145235120
选择最好的模型用于测试集
选择表现最好的应用于测试集:
rand_res <- last_fit(rf_mod,pbp_rec,split_pbp)
查看在测试集的模型表现:
collect_metrics(rand_res) # test 中的模型表现
image-20220704144956748
使用其他指标查看模型表现:
metricsets <- metric_set(accuracy, mcc, f_meas, j_index)
collect_predictions(rand_res) %>%
metricsets(truth = play_type, estimate = .pred_class)
image-20220704145017664
可视化结果,喜闻乐见的混淆矩阵:
collect_predictions(rand_res) %>%
conf_mat(play_type,.pred_class) %>%
autoplot()
image-20220704145028522
喜闻乐见的ROC曲线:
collect_predictions(rand_res) %>%
roc_curve(play_type,.pred_pass) %>%
autoplot()
image-20220704145041578
还有非常多曲线和评价指标可选,大家可以看我之前的介绍推文~
是不是很神奇呢,完美符合一次挑选多个模型的要求,且步骤清稀,代码美观,非常适合进行多个模型的比较。
相关文章
- Fabric.js 使用纯色遮挡画布(前景色)
- 使用深度学习模型近似简单的大气环流模式
- Linux中PLSQL视频,PLSQL使用视频教程:PLSQL的使用方法「建议收藏」
- laravel 使用资源路由创建控制器关联模型获取不到实例??(坑)
- Flask 学习-48.Flask-RESTX 使用api.model() 模型工厂
- Pytorch中现有网络模型的使用及修改
- 使用ResNet101作为预训练模型训练Faster-RCNN-TensorFlow-Python3-master[通俗易懂]
- 使用阈值调优改进分类模型性能
- [Redis]laravel中使用Redis分布式锁解决并发问题
- 使用 TVMC 编译和优化模型
- 【Laravel系列4.3】模型Eloquent ORM的使用(一)
- 【Laravel系列4.4】模型Eloquent ORM的使用(二)
- 使用 Docker 来快速上手中文 Stable Diffusion 模型:太乙
- 如何在 Termius 中添加带有端口转发的 ssh 命令(使用 -L、-R、-D)
- [Chemical Science | 论文简读] 使用基于Transformer的模型和超图探索策略预测逆合成路径
- TP6.0 模型JSON字段的使用 【系统配置表 key-value】
- 【教程】使用 Captum 解释 GNN 模型预测
- 使用scikit-learn为PyTorch 模型进行超参数网格搜索
- 小孩子如何学会语言?科学家使用计算机模型解释儿童语言学习过程
- Python使用GARCH,EGARCH,GJR-GARCH模型和蒙特卡洛模拟进行股价预测|附代码数据
- Oracle 视图 ALL_GG_AUTO_CDR_COLUMNS 官方解释,作用,如何使用详细说明
- 使用FIO测试Linux性能(fiolinux)
- MySQL:如何使用条件筛选正确查询数据(mysql条件筛选)
- Linux下消息队列MQ的使用(linuxmq)
- 使用 MySQL 长度函数轻松获取数据字段长度(mysql长度函数)
- Oracle视图与多表联接的使用方法详解(oracle视图多表)
- SQL Server表:从字段结构到使用实例(sqlserver表的列)
- 发布基于Redis的订阅发布模型实现信息实时消息交换(使用redis的订阅)
- 直接在JS里创建JSON数据然后遍历使用