PySpark Gradient-Boosted Trees
Time: 2023-09-14 09:09:29
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 7 18:15:30 2018
@author: luogan
"""
from pyspark.ml import Pipeline
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql import SparkSession
spark = (
    SparkSession.builder
    .appName("dataFrame")
    .getOrCreate()
)
# Load and parse the data file, converting it to a DataFrame.
data = spark.read.format("libsvm").load("/home/luogan/lg/softinstall/spark-2.2.0-bin-hadoop2.7/data/mllib/sample_libsvm_data.txt")
# Index labels, adding metadata to the label column.
# Fit on whole dataset to include all labels in index.
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)
# Automatically identify categorical features, and index them.
# Set maxCategories so features with > 4 distinct values are treated as continuous.
featureIndexer = VectorIndexer(
    inputCol="features", outputCol="indexedFeatures", maxCategories=4
).fit(data)
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])
# Train a GBT model.
gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10)
# Chain indexers and GBT in a Pipeline
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt])
# Train model. This also runs the indexers.
model = pipeline.fit(trainingData)
# Make predictions.
predictions = model.transform(testData)
# Select example rows to display.
predictions.select("prediction", "indexedLabel", "features").show(5)
# Select (prediction, true label) and compute test error
evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Test Error = %g" % (1.0 - accuracy))
# The fitted GBT model is the third pipeline stage (after the two indexers).
gbtModel = model.stages[2]
print(gbtModel)  # summary only
Output:
+----------+------------+--------------------+
|prediction|indexedLabel|            features|
+----------+------------+--------------------+
|       1.0|         1.0|(692,[95,96,97,12...|
|       1.0|         1.0|(692,[121,122,123...|
|       1.0|         1.0|(692,[122,123,124...|
|       1.0|         1.0|(692,[124,125,126...|
|       1.0|         1.0|(692,[124,125,126...|
+----------+------------+--------------------+
only showing top 5 rows
Test Error = 0.0571429
GBTClassificationModel (uid=GBTClassifier_483a8eddb2c54d041fae) with 10 trees
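
The summary line reports 10 trees, which matches maxIter=10 in the script: each boosting iteration adds one tree. To make that idea concrete, here is a minimal pure-Python sketch of gradient boosting for binary classification. It is not Spark's implementation. The one-dimensional toy data, the single-split "stump" learner, the logistic loss log(1 + exp(-y * F)), and the helper names fit_stump, fit_gbt and predict are all assumptions chosen for illustration.

```python
import math


def fit_stump(xs, residuals):
    """Fit a one-split regression stump to the residuals by least squares."""
    best = None
    values = sorted(set(xs))
    # Candidate thresholds are midpoints between consecutive distinct values.
    for lo, hi in zip(values, values[1:]):
        threshold = (lo + hi) / 2.0
        left = [r for x, r in zip(xs, residuals) if x <= threshold]
        right = [r for x, r in zip(xs, residuals) if x > threshold]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        sse = (sum((r - left_mean) ** 2 for r in left)
               + sum((r - right_mean) ** 2 for r in right))
        if best is None or sse < best[0]:
            best = (sse, threshold, left_mean, right_mean)
    return best[1:]  # (threshold, left_mean, right_mean)


def stump_predict(stump, x):
    threshold, left_mean, right_mean = stump
    return left_mean if x <= threshold else right_mean


def fit_gbt(xs, ys, num_trees=10, step_size=0.1):
    """Boost num_trees stumps; labels ys must be -1 or +1."""
    stumps = []
    scores = [0.0] * len(xs)  # current ensemble score F(x) per sample
    for _ in range(num_trees):
        # Negative gradient of log(1 + exp(-y * F)) with respect to F.
        residuals = [y / (1.0 + math.exp(y * f)) for y, f in zip(ys, scores)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        scores = [f + step_size * stump_predict(stump, x)
                  for f, x in zip(scores, xs)]
    return step_size, stumps


def predict(model, x):
    """Return -1 or +1 from the sign of the summed, scaled stump outputs."""
    step_size, stumps = model
    score = sum(step_size * stump_predict(s, x) for s in stumps)
    return 1 if score > 0 else -1


# Hypothetical toy data: small x values are class -1, large ones class +1.
xs = [0.1, 0.4, 0.5, 0.9, 1.5, 1.8, 2.2, 2.6]
ys = [-1, -1, -1, -1, 1, 1, 1, 1]

model = fit_gbt(xs, ys, num_trees=10)
accuracy = sum(predict(model, x) == y for x, y in zip(xs, ys)) / len(xs)
print("trees:", len(model[1]), "training error:", 1.0 - accuracy)
```

Each round fits a small tree to the negative gradient of the loss at the current predictions and adds it to the ensemble with a small step size. The script above relies on GBTClassifier to do the same kind of thing at scale, with deeper trees and distributed data.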