zl程序教程

您现在的位置是:首页 >  数据库

当前栏目

基于spark源码做ml的自定义功能开发

2023-04-18 14:40:22 时间

spark的ml中已经封装了许多关于特征的处理方式:

极大方便了我们在做数据预处理时的使用。 但是这明显不够,在机器学习的领域中,还有许许多多的处理方式,这些都没有存在于feature包中。 那要如何去实现?

比较简单的方式:spark ml本质上就是对dataframe的操作,可以在代码中处理df以实现该功能。

但是实际应用中发现,这样的方式并不好用,我们所做的处理,纯粹是对df的转换提取等操作,这个过程无法进行落地,也无法加入pipeline做重复训练。

所以,我采用了另一种方式:基于saprk源代码开发 首先介绍一下本次想要实现的功能:WOE

woe的计算逻辑:

计算的逻辑还是比较清楚的,公式如下:

其中 i为数据离散后的组,good i 和 bad i 对应该组好坏的个数, good all 和bad all 对应好坏的总数。

编写代码:

对于woe转换的功能,有如下参数:

  • 输入字段:哪些字段需要做woe转换
  • 输出字段:字段做woe转换之后的新列名是什么
  • 标签列:label列的列名
  • 正类: positiveLabel 确定 1 为 good ,还是 0 为 good

1、自定义一个代码接口

方便transform和transformModel共同使用

trait woeTransformParams extends Params with HasInputCols with HasOutputCols with HasLabelCol{
  val positiveLabel: Param[String] = new Param(this,"positiveLabel","positiveLabel you want to choose",ParamValidators.inArray(Array(woeTransform.one,woeTransform.zero)))
  def getPositiveLabel = ${positiveLabel}
}

2、编写woeTransform

继承Estimator抽象类,实现copy,transformSchema,fit方法。

  • fit方法会生成一个代理df,并通过该代理df生成model。在使用该model进行转换的时候,实际上就是使用代理df里的规则对数据集进行处理
  • transformSchema :生成新的schema信息
  • copy:返回一个相同UID的实例,包含extraMap的信息。

代码实现过程如下:

class woeTransform(override val uid: String) extends Estimator[woeTransformModel] with woeTransformParams with DefaultParamsWritable{
  def this() = this(Identifiable.randomUID("woeTransform"))
  
  def setLabelCol(value:String) = set(labelCol,value)
  def setInputCols(value:Array[String]) = set(inputCols,value)
  def setOutputCols(value:Array[String]) = set(outputCols,value)
 def setPositiveLabel(value:String) = set(positiveLabel,value)

override def copy(extra: ParamMap): Estimator[woeTransformModel] = defaultCopy(extra)

 override def transformSchema(schema: StructType): StructType = {
    val tmpArr = $(inputCols).filter(schema.fieldNames.contains(_))
    require(tmpArr.length == ${inputCols}.length,"输入字段中有schema中不存在的字段")

    val addedFields = $(outputCols).map{outputCol =>
      StructField(outputCol,DoubleType,false)
    }
    StructType(schema ++ addedFields)
  }

  /**
    * dataset中包含训练数据,将该数据计算出surrogateDF并生成model
    */
  override def fit(dataset: Dataset[_]): woeTransformModel = {
    val schema = dataset.schema
    transformSchema(schema,logging = true)

    val lb = new ListBuffer[(String,String,Double)]()
    val cols: Array[Column] = schema.fieldNames.map(col(_))

   
    val newLabel = "new_"+$(labelCol)
    val labelColt = when(col($(labelCol)).equalTo(1.0),woeTransform.one)
      .when(col($(labelCol)).equalTo(1.0),woeTransform.zero)
      .otherwise(col($(labelCol)))
      .as(newLabel)

    val dataFrame = dataset.select((cols.+:(labelColt)):_*)

   //对每一个inputcol的每一个组做woe转换并且加入到listBuffer中
    $(inputCols).foreach{inputCol =>
    
      //crosstab 交叉表计算,具体公式可以问度娘
      val singleInfo = dataFrame.stat.crosstab(inputCol, newLabel)

      val analyseDF = $(positiveLabel) match {
          case woeTransform.zero => singleInfo.withColumnRenamed(woeTransform.zero.toString,"good").withColumnRenamed(woeTransform.one.toString,"bad")
          case woeTransform.one => singleInfo.withColumnRenamed(woeTransform.one.toString,"good").withColumnRenamed(woeTransform.zero.toString,"bad")
        }

      val row = analyseDF.select(sum("bad"),sum("good")).head()

      val (bad,good) = (row.getLong(0).toInt,row.getLong(1).toInt)

      analyseDF.collect().foreach{row =>
        val bi = row.getAs[Long]("bad").toDouble + 0.0000001
        val gi = row.getAs[Long]("good").toDouble + 0.0000001
        val woe = Math.log((gi / good) / ((bi / bad) + 0.0000001))

        lb.+=((inputCol,row.getString(0),woe))
      }

    }

    //将之前记录信息的listbuffer 转成代理 df ,并生成 woeTransformModel
    import dataset.sparkSession.implicits._
    val surrogateDF = lb.toList.toDF()
    copyValues(new woeTransformModel(uid,surrogateDF).setParent(this))
  }
}

object woeTransform extends DefaultParamsReadable[woeTransform]{

   val zero = "0"
   val one = "1"

  override def load(path: String): woeTransform = super.load(path)
}

3、编写woeTransformModel

class woeTransform:

  • 继承Model抽象类,实现copy 、 transformSchema 、 transform方法 。 前两个方法与之前一致。transform方法中主要实现的是,以surrogatedf 为转换逻辑,来处理新的数据集。
  • 实现MLWritable实现模型的写操作。

object woeTransformModel:

  • 实现MLReadable 对模型的 读操作。 读写过程要对应,否则在模型的落地与加载过程中会出错

代码如下:

class woeTransformModel(override val uid: String,val surrogateDF: DataFrame)
  extends Model[woeTransformModel] with woeTransformParams with MLWritable{

  import woeTransformModel._

  def setLabelCol(value:String) = set(labelCol,value)

  def setInputCols(value:Array[String]) = set(inputCols,value)

  def setOutputCols(value:Array[String]) = set(outputCols,value)

  def setPositiveLabel(value:String) = set(positiveLabel,value)

  override def copy(extra: ParamMap): woeTransformModel = {
    val copied = new woeTransformModel(uid,surrogateDF)
    copyValues(copied, extra).setParent(parent)
  }


  /**
    * Transforms the input dataset.
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val newSchema = transformSchema (dataset.schema, logging = true)
    val inArray = $(inputCols)
    val outArray = $(outputCols)

    var ruleMap: Map[String,Double] = Map()
    surrogateDF.rdd.collect().foreach{row=>
      val colName: String = row.getString(0)
      val bucket: String = row.getString(1)
      val woe: Double = row.getDouble(2)
      ruleMap += (colName+"-"+bucket -> woe)
    }

   val newRdd =  dataset.toDF.rdd.map{ row=>

      val ab = new ArrayBuffer[Double]()
       for(i <- 0 to inArray.length-1){
         val colName= inArray(i)
         val bucket = row.getAs[Object](colName).toString

         val woe = ruleMap.apply(colName+"-"+bucket)
         ab += woe
       }

      Row.merge(row,Row.fromSeq(ab))
    }

    dataset.sparkSession.createDataFrame(newRdd,newSchema)
  }

  override def transformSchema(schema: StructType): StructType = {
    val tmpArr = $(inputCols).filter(schema.fieldNames.contains(_))
    require(tmpArr.length == $(inputCols).length,"输入字段中有schema中不存在的字段")

    val addedFields = $(outputCols).map{outputCol =>
      StructField(outputCol,DoubleType,false)
    }

    StructType(schema ++ addedFields)
  }

  /**
    * Returns an `MLWriter` instance for this ML instance.
    */
  override def write: MLWriter = new woeTransformModelWriter(this)
}

object woeTransformModel extends MLReadable[woeTransformModel]{

  class woeTransformModelWriter(instance: woeTransformModel) extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val dataPath = new Path(path, "data").toString
      instance.surrogateDF.repartition(1).write.parquet(dataPath)
    }
  }


  class woeTransformReader extends MLReader[woeTransformModel]{

    private val className = classOf[woeTransformModel].getName
    /**
      * Loads the ML component from the input path.
      */
    override def load(path: String): woeTransformModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString
      val surrogateDF = sqlContext.read.parquet(dataPath)
      val model = new woeTransformModel(metadata.uid, surrogateDF)
      metadata.getAndSetParams(model)
      model
    }
  }


  /**
    * Returns an `MLReader` instance for this class.
    */
  override def read: MLReader[woeTransformModel] = new woeTransformReader

  override def load(path: String): woeTransformModel = super.load(path)
}

检验功能正确性

我使用了一个简单的数据来做检验,下面是使用我们的计算公式来计算得到的结果.

然后来测试下,我们编写的代码的结果。 将我们刚编写的代码放入org.apache.spark.ml.feature包下,重新编译打包,引入工程.

使用同样的数据集,得到的结果如下:

与之前结果一致。

这里只是为了实现逻辑,并没有对特殊情况做完善。 各位若有想法,可以指出共同探讨