zl程序教程

您现在的位置是:首页 >  硬件

当前栏目

R-多分类logistic回归(机器学习)

机器学习 分类 回归 Logistic
2023-06-13 09:13:57 时间

多分类logistic回归

在之前文章介绍了,如何在R里面处理多分类的回归模型,得到的是各个因素的系数及相对OR,但是解释性,比二元logistic回归方程要冗杂的多。

那么今天继续前面的基础上,用机器学习的方法来解释多分类问题。 其实最终回归到这类分类问题的本质:有了一系列的影响因素x,那么根据这些影响因素来判断最终y属于哪一类别。

image.png

1.数据案例

这里主要用到DALEX包里面包含的HR数据,里面记录了职工在工作岗位的状态与年龄,性别,工作时长,评价及薪水有关。根据7847条记录来评估,如果一个职工属于男性,68岁,薪水及评价处于3等级,那么该职工可能会处于什么状态。

library(DALEX)
library(iBreakDown)
library(car)
library(questionr)
try(data(package="DALEX"))
data(HR)

# split
set.seed(543)
ind = sample(2,nrow(HR),replace=TRUE,prob=c(0.9,0.1))
trainData = HR[ind==1,]
testData = HR[ind==2,]

# randforest
m_rf = randomForest(status ~ . , data = trainData)

2.随机森林模型

我们根据上述数据,分成训练集与测试集(Train and Test)测试集用来估计随机森林模型的效果。

2.1模型评估

通过对Train数据构建rf模型后,我们对Train数据进行拟合,看一下模型的效果,Accuracy : 0.9357 显示很好,kappa一致性为90%。 那再用该fit去预测test数据, Accuracy : 0.7166 , Kappa : 56% ,显示效果不怎么理想。

# Prediction and Confusion Matrix - Training data 
pred1 <- predict(m_rf, trainData)
head(pred1)
confusionMatrix(pred1, trainData$status)  #

pred2 <- predict(m_rf, testData)
head(pred2)
confusionMatrix(pred2, testData$status)  #

> confusionMatrix(pred1, trainData$status)  #
Confusion Matrix and Statistics

          Reference
Prediction fired   ok promoted
  fired     2478  194       49
  ok          43 1738       80
  promoted    25   64     2375

Overall Statistics
                                          
               Accuracy : 0.9354          
                 95% CI : (0.9294, 0.9411)
    No Information Rate : 0.3613          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.9024          
                                          
 Mcnemar's Test P-Value : < 2.2e-16       

Statistics by Class:

                     Class: fired Class: ok Class: promoted
Sensitivity                0.9733    0.8707          0.9485
Specificity                0.9460    0.9756          0.9804
Pos Pred Value             0.9107    0.9339          0.9639
Neg Pred Value             0.9843    0.9502          0.9718
Prevalence                 0.3613    0.2833          0.3554
Detection Rate             0.3517    0.2467          0.3371
Detection Prevalence       0.3862    0.2641          0.3497
Balanced Accuracy          0.9596    0.9232          0.9644
> 
> pred2 <- predict(m_rf, testData)
> head(pred2)
    1    20    36    42    49    56 
fired fired fired fired fired    ok 
Levels: fired ok promoted
> confusionMatrix(pred2, testData$status)  #
Confusion Matrix and Statistics

          Reference
Prediction fired  ok promoted
  fired      246  62       19
  ok          37 117       37
  promoted    26  46      211

Overall Statistics
                                         
               Accuracy : 0.7166         
                 95% CI : (0.684, 0.7476)
    No Information Rate : 0.3858         
    P-Value [Acc > NIR] : < 2e-16        
                                         
                  Kappa : 0.5692         
                                         
 Mcnemar's Test P-Value : 0.03881        

Statistics by Class:

                     Class: fired Class: ok Class: promoted
Sensitivity                0.7961    0.5200          0.7903
Specificity                0.8354    0.8715          0.8652
Pos Pred Value             0.7523    0.6126          0.7456
Neg Pred Value             0.8671    0.8230          0.8919
Prevalence                 0.3858    0.2809          0.3333
Detection Rate             0.3071    0.1461          0.2634
Detection Prevalence       0.4082    0.2385          0.3533
Balanced Accuracy          0.8157    0.6958          0.8277

2.2变量重要性

我们看到,对影响因素进行重要性排序,等同于P值。在预测时候,哪些因素对y占影响比重较大。这里的variable_importance(),可以有好几种方式对变量进行衡量,这里采用默认的MeanDecreaseGini.

# vip
vip(m_rf)
var=randomForest::importance(m_rf)
var

image.png

2.2边际效应

我们知道了hours,age比较重要,那么是如何重要的,譬如年龄在什么阶段,会导致升职或者开除。 当工作小时在45以内,被开除/离职的概率较大,当工作时常超过60以后,很有可能会被提升。得到升职加薪的机会。 当然了,也可以绘制2D的边际效应,两个因素相互作用的Partial plot

# partial plot
partialPlot(m_rf, HR, age)
head(partial(m_rf, pred.var = "age"))  # returns a data frame

# for all varibles
nm=rownames(var)
# Get partial depedence values for top predictors
pd_df <- partial_dependence(fit = m_rf,
                            vars = nm,
                            data = df_rf,
                            n = c(100, 200))
                        
# Plot partial dependence using edarf
plot_pd(pd_df)

image.png

image.png

2.3个体预测

现在假如有一个员工的信息如下,

      gender      age    hours evaluation salary   status
10000 female 57.96254 54.78624          4      4 promoted

去预测该职工最后的状态: 该预测结果显示,这个职工,有97%的可能性要升职加薪。而他的实际状态也是Promoted。

new_observation=tail(HR,1)
p_fun <- function(object, newdata){predict(object, newdata = newdata, type = "prob")}
bd_rf <- local_attributions(m_rf,
                            data = HR_test,
                            new_observation =  new_observation,
                            predict_function = p_fun)

bd_rf
plot(bd_rf)

image.png

> sessionInfo()
R version 3.6.2 (2019-12-12)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS Mojave 10.14

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  utils     datasets  grDevices methods   base     

other attached packages:
 [1] edarf_1.1.1         ranger_0.12.1       questionr_0.7.0     car_3.0-7          
 [5] carData_3.0-3       nnet_7.3-14         DALEX_1.2.1         vip_0.2.2          
 [9] ggpubr_0.3.0        rstatix_0.5.0       caret_6.0-86        lattice_0.20-41    
[13] pdp_0.7.0           randomForest_4.6-14 iBreakDown_1.2.0    hrbrthemes_0.8.0   
[17] reshape2_1.4.4      RColorBrewer_1.1-2  forcats_0.5.0       stringr_1.4.0      
[21] dplyr_0.8.5         purrr_0.3.4         readr_1.3.1         tidyr_1.0.3        
[25] tibble_3.0.1        ggplot2_3.3.0       tidyverse_1.3.0    

参考

  1. iBreakDown plots for classification models
  2. prediction 预测结果输出为概率
  3. pdp 边际效应
  4. Partial dependence (PD) plots For Random Forests
  5. Explaining Black-Box Machine Learning Models
  6. Interpretable Machine Learning