本文介绍了星火随机森林模型的交叉验证的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我要评估一个随机森林的一些数据被训练。是否有阿帕奇星火任何实用程序执行相同或做我必须手动进行交叉验证?任何帮助/指导方针将大大AP preciated。
I want to evaluate a random forest being trained on some data. Is there any utility in Apache Spark to do the same or do I have to perform cross validation manually? Any help/guidelines would be greatly appreciated.
推荐答案
ML提供<$c$c>CrossValidator$c$c>类可用于执行交叉验证和参数的搜索。假设你的数据已经preprocessed您可以添加交叉验证如下:
ML provides CrossValidator
class which can be used to perform cross-validation and parameter search. Assuming your data is already preprocessed you can add cross-validation as follows:
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
// [label: double, features: vector]
trainingData org.apache.spark.sql.DataFrame = ???
val nFolds: Int = ???
val NumTrees: Int = ???
val rf = new RandomForestClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
.setNumTrees(NumTrees)
val pipeline = new Pipeline().setStages(Array(rf))
val paramGrid = new ParamGridBuilder().build() // No parameter search
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
// "f1", "precision", "recall", "weightedPrecision", "weightedRecall"
.setMetricName("precision")
val cv = new CrossValidator()
// ml.Pipeline with ml.classification.RandomForestClassifier
.setEstimator(pipeline)
// ml.evaluation.MulticlassClassificationEvaluator
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(nFolds)
val model = cv.fit(trainingData) // trainingData: DataFrame
使用PySpark:
Using PySpark:
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
trainingData = ... # DataFrame[label: double, features: vector]
numFolds = ... # Integer
rf = RandomForestClassifier(labelCol="label", featuresCol="features")
evaluator = MulticlassClassificationEvaluator() # + other params as in Scala
pipeline = Pipeline(stages=[rf])
crossval = CrossValidator(
estimator=pipeline,
estimatorParamMaps=paramGrid,
evaluator=evaluator,
numFolds=numFolds)
model = crossval.fit(trainingData)
这篇关于星火随机森林模型的交叉验证的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!