什么是决策树?

决策树及其集成是分类和回归机器学习任务的流行方法。决策树被广泛使用,因为它们易于解释,处理分类特征,扩展到多类分类设置,不需要特征缩放,并且能够捕获非线性和特征相互作用。随机森林和增强算法等树集成算法在分类和回归任务中表现最佳。

常应用于以下类型的场景:

  1. 预测用户贷款是否能够按时还款;
  2. 预测邮件是否是垃圾邮件;
  3. 预测用户是否会购买某件商品等等

官网:分类和回归

决策树的优缺点

优点:

  1. 决策树算法易理解,机理解释起来简单。

  2. 决策树算法可以用于小数据集。

  3. 决策树算法的时间复杂度较小,为用于训练决策树的数据点的对数。

  4. 相比于其他算法智能分析一种类型变量,决策树算法可处理数字和数据的类别。

  5. 能够处理多输出的问题。

  6. 对缺失值不敏感。

  7. 可以处理不相关特征数据。

  8. 效率高,决策树只需要一次构建,反复使用,每一次预测的最大计算次数不超过决策树的深度。

缺点:

  1. 对连续性的字段比较难预测。

  2. 容易出现过拟合。

  3. 当类别太多时,错误可能就会增加的比较快。

  4. 在处理特征关联性比较强的数据时表现得不是太好。

  5. 对于各类别样本数量不一致的数据,在决策树当中,信息增益的结果偏向于那些具有更多数值的特征。

参考博客:决策树算法优缺点

决策树示例——鸢尾花分类

数据集下载:

链接:
https://pan.baidu.com/s/1AshgNxx1wOWhLgKxgjrZww?pwd=lz3l 

提取码:
lz3l

数据集介绍:

iris.data 数据集中共有五个字段,逗号分隔,前四个为特征字段,最后一个为标签字段。

标签字段列一共有三种值,分别是:Iris-setosaIris-versicolorIris-virginica

将数据集中的随机百分之70作为训练集,剩余的作为测试集。

需求实现:

import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object Iris {

    // TODO 鸢尾花种类判断

    def main(args: Array[String]): Unit = {

        val sc: SparkSession = SparkSession
                .builder()
                .appName("Iris")
                .master("local[*]")
                .getOrCreate()

        // 1.加载鸢尾花数据
        val train_data: RDD[String] = sc
                .read
                .textFile("iris.data")
                .rdd

        // 2.将随机百分之70的数据设置为训练集,其余为测试集
        val data: Array[RDD[String]] = train_data.randomSplit(Array(0.7, 0.3))

        // 3.向量转换
        import sc.implicits._

        val trainDF: DataFrame = data(0).map(lines => {
            val arr: Array[String] = lines.split(",")
            LabeledPoint(
                if (arr(4).equals("Iris-setosa")) {
                    1D
                } else if (arr(4).equals("Iris-versicolor")) {
                    2D
                } else {
                    3D
                },
                Vectors.dense(arr.take(4).map(_.toDouble))
            )
        }).toDF("label", "features")

        // 4.创建决策树对象
        val classifier = new DecisionTreeClassifier()

        // 设置最大深度、分支、质量、特征列
        classifier.setMaxDepth(5).setMaxBins(32).setImpurity("gini").setFeaturesCol("features")

        // 5.训练模型
        val model: DecisionTreeClassificationModel = classifier.fit(trainDF)

        // 打印模型
        println(model.toDebugString)

        // 6.将测试集转换成向量
        val testDF: DataFrame = data(1).map(lines => {
            val arr: Array[String] = lines.split(",")
            LabeledPoint(
                if (arr(4).equals("Iris-setosa")) {
                    1D
                } else if (arr(4).equals("Iris-versicolor")) {
                    2D
                } else {
                    3D
                },
                Vectors.dense(arr.take(4).map(_.toDouble))
            )
        }).toDF("label", "features")

        // 7.模型预测
        val result: DataFrame = model.transform(testDF.select("label", "features"))

        // 8.模型预测评估
        result.select("label", "features","prediction").show(100)

        // 9.计算错误率
        val error: Double = result.where("label = prediction").count.toDouble/result.count
        println("错误率为:"+(1-error))

    }

}
11-30 10:37