Accumulators in Spark: Usage and Scenarios


I. What Is an Accumulator

Accumulators and broadcast variables are the two kinds of shared variables that Spark provides. They are mainly used to share data across the nodes of a cluster, working around the fact that data cannot otherwise be shared between the executors.

An accumulator is defined on the driver, updated on the executors, and finally aggregated back on the driver, where the result can be processed further.

II. Common Accumulators

Spark provides three common built-in accumulators: LongAccumulator (for Integer and Long values), DoubleAccumulator (for Float and Double values), and CollectionAccumulator (for values of any type).
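
All three are created and registered through factory methods on SparkContext. A minimal sketch (the accumulator names and the SparkSession variable below are illustrative, not from the original article):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("accumulator-sketch").master("local[2]").getOrCreate()
val sc = spark.sparkContext

// Each factory method creates the accumulator and registers it under the given name
val longAcc   = sc.longAccumulator("myLongAcc")               // sums Integer/Long values
val doubleAcc = sc.doubleAccumulator("myDoubleAcc")           // sums Float/Double values
val listAcc   = sc.collectionAccumulator[String]("myListAcc") // collects values of any type into a list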

If these three accumulators still do not meet your needs, don't worry: Spark also lets you build your own user-defined accumulator by writing a class that extends AccumulatorV2. In fact, the three built-in accumulators above also extend AccumulatorV2, so they are custom accumulators themselves; Spark simply ships implementations of the most common ones for us.
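
To make the AccumulatorV2 contract concrete before the full example in section III, here is a small illustrative custom accumulator (not from the original article) that tracks the largest Long value seen across all tasks:

import org.apache.spark.util.AccumulatorV2

// Illustrative sketch: a custom accumulator that keeps the maximum Long value seen
class MaxAccumulator extends AccumulatorV2[Long, Long] {
  private var _max: Long = Long.MinValue

  override def isZero: Boolean = _max == Long.MinValue             // still the "empty" value?
  override def copy(): MaxAccumulator = {                          // copy sent to each task
    val acc = new MaxAccumulator
    acc._max = _max
    acc
  }
  override def reset(): Unit = { _max = Long.MinValue }            // back to the empty value
  override def add(v: Long): Unit = { _max = math.max(_max, v) }   // called on the executors
  override def merge(other: AccumulatorV2[Long, Long]): Unit = {   // driver merges task copies
    _max = math.max(_max, other.value)
  }
  override def value: Long = _max                                  // final result read on the driver
}

// Usage: register it before use, just like the built-in accumulators
// val maxAcc = new MaxAccumulator
// sc.register(maxAcc, "maxAcc")
// rdd.foreach(x => maxAcc.add(x))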

III. Using Accumulators

1. Usage Scenarios for Accumulators

Accumulators are typically used in statistics-style scenarios, for example counting how many users or IPs visited in the last hour, or monitoring gray- and black-market actors who exploit a platform loophole to farm rewards at scale and then notifying the people responsible by SMS or email so they can react in time.

In most cases the data is huge and is processed as tasks running on the executors across the cluster. Once processing finishes we need the final result for the next step. We cannot send the SMS or email from the executors: every node would report only its own partial result and we would receive a pile of messages, which is clearly not what we want.

This is where accumulators come in: define the accumulator on the driver, accumulate on each executor, and let the driver perform the final aggregation. That gives us exactly what we need.

2. Accumulator Examples

LongAccumulator 

Sum the numbers 1, 2, 3 and 4

import org.apache.spark.sql.SparkSession

object TestAccumulator {
  def main(args: Array[String]): Unit = {
    val ss: SparkSession = SparkSession.builder().appName("test-accumulator").master("local[2]").getOrCreate()
    val sc = ss.sparkContext
    // Define the accumulator on the driver
    val longAccumulator = sc.longAccumulator("My longAccumulator")
    sc.makeRDD(Seq(1, 2, 3, 4)).foreach(v => {
      // Accumulate on the executors; .value here only shows this task's local partial sum
      longAccumulator.add(v)
      println(Thread.currentThread().getName + " longAccumulator=" + longAccumulator.value)
    })
    // The driver reads the fully merged result
    println("Driver " + Thread.currentThread().getName + " longAccumulator=" + longAccumulator.value)
  }
}

Output:

Executor task launch worker for task 1 longAccumulator=3
Executor task launch worker for task 0 longAccumulator=1
Executor task launch worker for task 0 longAccumulator=3
Executor task launch worker for task 1 longAccumulator=7

Driver main longAccumulator=10

You can see that task 0 performed two additions, for 1 and 2, with partial results 1 and 3, while task 1 also performed two additions, for 3 and 4, with partial results 3 and 7. Finally the driver merges them: 3 (task 0) + 7 (task 1) = 10.
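
One caveat worth noting (it is spelled out in the accumulator section of the RDD programming guide linked at the end): updates performed inside a lazy transformation such as map only take effect once an action runs, and Spark only guarantees exactly-once accumulator updates for updates made inside actions; a retried task may re-apply updates made inside a transformation. A minimal sketch, using a fresh illustrative accumulator rather than the one above:

val lazyAcc = sc.longAccumulator("lazyAcc")
val mapped = sc.makeRDD(Seq(1, 2, 3, 4)).map { v => lazyAcc.add(v); v }
println(lazyAcc.value) // still 0: map is lazy, so no task has run yet
mapped.count()         // the action triggers the tasks and the adds happen now
println(lazyAcc.value) // 10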

CollectionAccumulator

Send an SMS reminder to every user whose phone balance is below 5 yuan

import org.apache.spark.sql.SparkSession
import org.apache.spark.util.CollectionAccumulator

import scala.collection.JavaConversions._

object TestAccumulator {
  def main(args: Array[String]): Unit = {
    val ss: SparkSession = SparkSession.builder().appName("test-accumulator").master("local[2]").getOrCreate()
    val sc = ss.sparkContext
    // Define the accumulator on the driver
    val collectionAccumulator: CollectionAccumulator[Student] = sc.collectionAccumulator("My collectionAccumulator")

    sc.makeRDD(Seq(new Student("18123451234", "张三", 3), new Student("17123451235", "李四", 2))).foreach(v => {
      if (v.balance < 5) {
        // Collect low-balance users on the executors
        collectionAccumulator.add(v)
        println(Thread.currentThread().getName + " collectionAccumulator=" + v.name)
      }
    })
    // Iterate the merged java.util.List on the driver (JavaConversions supplies the implicit conversion)
    for (obj <- collectionAccumulator.value) {
      println("Driver " + Thread.currentThread().getName + " collectionAccumulator=" + obj.name)
      println("Dear user " + obj.phone + ", your balance is low; to avoid service interruption, please top up soon")
    }
  }
}
class Student(phonex: String, namex: String, balancex: Int) extends Serializable {
  var phone: String = phonex
  var name: String = namex
  var balance: Int = balancex
}

Output:

Executor task launch worker for task 1 collectionAccumulator=李四
Executor task launch worker for task 0 collectionAccumulator=张三


Driver main collectionAccumulator=李四
Dear user 17123451235, your balance is low; to avoid service interruption, please top up soon
Driver main collectionAccumulator=张三
Dear user 18123451234, your balance is low; to avoid service interruption, please top up soon

User-Defined Accumulator

In a coupon campaign that limits how many coupons each user may claim, detect the users who claimed multiple coupons and send an email alert so that the loophole can be fixed in time and unexpected coupon abuse is stopped.

Use Spark Streaming micro-batches to read the near-real-time coupon-claim table in Hive and report all detected coupon abuse in a single alert email.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.streaming.dstream.InputDStream
import org.apache.spark.streaming.{Minutes, StreamingContext}
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable
import scala.collection.JavaConversions._
import org.apache.spark.util.AccumulatorV2
import java.util.ArrayList
import java.util.Collections


/**
 * Desc:
 * Coupon-abuse monitoring
 */
object TestAccumulator {
  private val logger: Logger = LoggerFactory.getLogger(TestAccumulator.getClass)

  private lazy val ss: SparkSession = SparkSession.builder()
    .appName("test-ccumulator")
    //.master("local[2]")
    .config("spark.sql.adaptive.enabled", true)
    .config("spark.sql.autoBroadcastJoinThreshold", "-1")
    .config("spark.sql.crossJoin.enabled", true)
    .config("spark.task.maxFailures", 5)
    .config("spark.files.ignoreCorruptFiles", true)
    .config("spark.files.ignoreMissingFiles", true)
    .config("spark.sql.storeAssignmentPolicy", "LEGACY")
    .config("dfs.client.block.write.replace-datanode-on-failure.policy", "NEVER")
    .config("mapred.output.compress", true)
    .config("hive.exec.compress.output", true)
    .config("mapreduce.map.output.compress.codec", "org.apache.hadoop.io.compress.SnappyCodec")
    .config("mapreduce.output.fileoutputformat.compress", true)
    .config("mapreduce.output.fileoutputformat.compress.codec", "org.apache.hadoop.io.compress.SnappyCodec")
    .enableHiveSupport()
    .getOrCreate()

  def main(args: Array[String]): Unit = {
    val sc = ss.sparkContext
    val ssc: StreamingContext = new StreamingContext(sc, Minutes(5)) // run one batch every 5 minutes
    
    // Define the accumulator
    val myListAccumulator: MyListAccumulator[CpsActivity] = new MyListAccumulator
    // Register the accumulator with the SparkContext
    sc.register(myListAccumulator, "MyListAccumulator")

    // The (empty) queue stream is only used here as a periodic trigger for the SQL query
    val queue: mutable.Queue[RDD[String]] = new mutable.Queue[RDD[String]]()
    val queueDS: InputDStream[String] = ssc.queueStream(queue)
    queueDS.foreachRDD(rdd => {
      // Find, per limited-claim campaign, the users who claimed more than one coupon
      val sql = "select cps_id,batch_id,activity_id,activity_name,user_id,count(*) from cps_coupon_record group by cps_id,batch_id,activity_id,activity_name,user_id having count(*) > 1"
      println(sql)
      val dataFrame: DataFrame = ss.sql(sql)
      dataFrame.foreachPartition { partition: Iterator[Row] =>
        while (partition.hasNext) {
          val row = partition.next()
          val cpsId = row.getAs[String]("cps_id")
          val batchId = row.getAs[String]("batch_id")
          val activityId = row.getAs[String]("activity_id")
          val activityName = row.getAs[String]("activity_name")
          // Collect the suspicious records on the executors
          myListAccumulator.add(new CpsActivity(cpsId, batchId, activityId, activityName))
        }
      }
      // Build the alert on the driver from the merged accumulator value;
      // actually sending the email is omitted here
      val msg = "Hello, the monitoring job detected the following coupon activities in which a single user claimed multiple coupons. This may be coupon abuse; please verify and handle it promptly\n"
      val sb = new StringBuilder()
      sb.append(msg)
      sb.append(String.format("%s,%s,%s,%s", "activity name", "activity ID", "batch ID", "coupon ID")).append("\n")
      for (obj <- myListAccumulator.value) {
        sb.append(String.format("%s,%s,%s,%s\n", obj.activityName, obj.activityId, obj.batchId, obj.cpsId))
      }
      println(sb.toString())
      // Clear the accumulator after each batch so results do not leak into the next one
      myListAccumulator.reset()
      ss.sqlContext.clearCache()
    })

    ssc.start()
    ssc.awaitTermination()
  }
}
/**
 * Coupon information
 */
class CpsActivity(cpsIdx: String, batchIdx: String,activityIdx: String, activityNamex: String) extends Serializable{
   var cpsId: String = cpsIdx
   var batchId: String = batchIdx
   var activityId: String = activityIdx
   var activityName: String = activityNamex
}

/**
 * Custom accumulator backed by a thread-safe java.util.List
 */
class MyListAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
  private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]())

  override def isZero: Boolean = _list.isEmpty

  override def copyAndReset(): MyListAccumulator[T] = new MyListAccumulator

  override def copy(): MyListAccumulator[T] = {
    val newAcc = new MyListAccumulator[T]
    _list.synchronized {
      newAcc._list.addAll(_list)
    }
    newAcc
  }

  override def reset(): Unit = _list.clear()

  override def add(v: T): Unit = _list.add(v)

  override def merge(other: AccumulatorV2[T, java.util.List[T]]): Unit = other match {
    case o: MyListAccumulator[T] => _list.addAll(o.value)
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  override def value: java.util.List[T] = _list.synchronized {
    java.util.Collections.unmodifiableList(new ArrayList[T](_list))
  }
}

Output:

Hello, the monitoring job detected the following coupon activities in which a single user claimed multiple coupons. This may be coupon abuse; please verify and handle it promptly

activity name,activity ID,batch ID,coupon ID
New User Coupon,12345,8888,1234445
New User Coupon,12345,8888,1234446

Reference: RDD Programming Guide - Spark 2.3.0 Documentation (Accumulators): https://spark.apache.org/docs/2.3.0/rdd-programming-guide.html#accumulators