val numFeatures: Int, //特征数量
val numExamples: Long, //数据数量
val numClasses: Int, //类别数量
val maxBins: Int, //最大箱数
val featureArity: Map[Int, Int], //特征元数
val unorderedFeatures: Set[Int], //无序特征集合
val numBins: Array[Int], //箱数
val impurity: Impurity, //不纯度(指使用哪种信息差异公式计算)
val quantileStrategy: QuantileStrategy, //分位数策略
val maxDepth: Int, //最大数深度
val minInstancesPerNode: Int, //每个节点最小的实例数
val minInfoGain: Double, //最小的信息增益
val numTrees: Int, //数的数量
val numFeaturesPerNode: Int //每个节点的特征数
/**
 * Converts an input dataset into its TreePoint representation by replacing every raw
 * feature value with the index of the bin it falls into, in preparation for DecisionTree
 * training.
 *
 * Two per-feature lookup tables are built up front and captured by the map closure:
 *  - an arity table: `metadata.featureArity(f)` when feature `f` is present in that map
 *    (categorical), otherwise 0 (continuous);
 *  - a threshold table: for continuous features, the `threshold` of each split in
 *    `splits(f)`, in the order they appear there; an empty array for categorical features.
 *    NOTE(review): the bin lookup uses binary search, so these thresholds are assumed to be
 *    sorted ascending — confirm against how `splits` is produced.
 *
 * @param input Input dataset.
 * @param splits Splits for features, of size (numFeatures, numSplits). Each split of a
 *               continuous feature is cast to `ContinuousSplit`.
 * @param metadata Learning and dataset metadata.
 * @return TreePoint dataset representation.
 */
def convertToTreeRDD(
    input: RDD[LabeledPoint],
    splits: Array[Array[Split]],
    metadata: DecisionTreeMetadata): RDD[TreePoint] = {
  // Plain arrays instead of the metadata Map keep the per-point inner loop cheap.
  val arities: Array[Int] =
    Array.tabulate(metadata.numFeatures)(f => metadata.featureArity.getOrElse(f, 0))

  // Only continuous features (arity 0) carry thresholds; categorical ones get an empty array.
  val featureThresholds: Array[Array[Double]] = Array.tabulate(arities.length) { f =>
    if (arities(f) != 0) {
      Array.empty[Double]
    } else {
      splits(f).map(split => split.asInstanceOf[ContinuousSplit].threshold)
    }
  }

  // Map every point to its per-feature bin indices.
  input.map(point => TreePoint.labeledPointToTreePoint(point, featureThresholds, arities))
}
/**
 * Converts one LabeledPoint into its TreePoint representation: the label is kept unchanged
 * and each feature value is replaced by the bin index returned by `findBin`.
 *
 * NOTE(review): one bin index is produced per entry of `labeledPoint.features`, and
 * `featureArity` / `thresholds` are indexed by the same position, so a point with more
 * features than those tables have entries fails with an ArrayIndexOutOfBoundsException.
 *
 * @param labeledPoint Data point to convert.
 * @param thresholds For each feature, split thresholds for continuous features,
 *                   empty for categorical features.
 * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
 *                     for categorical features.
 * @return A TreePoint holding the point's label and its per-feature bin indices.
 */
private def labeledPointToTreePoint(
    labeledPoint: LabeledPoint,
    thresholds: Array[Array[Double]],
    featureArity: Array[Int]): TreePoint = {
  val binIndices = new Array[Int](labeledPoint.features.size)
  // Kept as a plain index loop: this runs for every feature of every training point.
  var f = 0
  while (f < binIndices.length) {
    binIndices(f) = findBin(f, labeledPoint, featureArity(f), thresholds(f))
    f += 1
  }
  new TreePoint(labeledPoint.label, binIndices)
}
/**
 * Finds the discretized value (bin index) for one (labeledPoint, feature) pair.
 *
 * Continuous features (`featureArity == 0`): the thresholds are searched with
 * `java.util.Arrays.binarySearch`, so they must be sorted ascending. For distinct
 * thresholds, the result is the number of thresholds strictly below the value:
 *  - a value equal to `thresholds(i)` lands in bin `i`;
 *  - a value above every threshold lands in bin `thresholds.length`.
 * Java's double ordering places NaN above every number, so a NaN continuous value lands in
 * that last bin.
 *
 * Categorical features: the feature value itself is the bin index.
 *
 * NOTE: We cannot use Bucketizer since it handles split thresholds differently than the old
 * (mllib) tree API. We want to maintain the same behavior as the old tree API.
 *
 * @param featureIndex Index of the feature within `labeledPoint.features`.
 * @param labeledPoint Data point whose feature value is discretized.
 * @param featureArity 0 for continuous features; number of categories for categorical features.
 * @param thresholds Split thresholds for a continuous feature; unused for categorical ones.
 * @return The bin index of the feature value.
 * @throws IllegalArgumentException if the feature is categorical and its value is NaN or
 *                                  lies outside [0, featureArity).
 */
private def findBin(
    featureIndex: Int,
    labeledPoint: LabeledPoint,
    featureArity: Int,
    thresholds: Array[Double]): Int = {
  val featureValue = labeledPoint.features(featureIndex)
  if (featureArity == 0) {
    // Continuous feature. On a miss, binarySearch returns -(insertionPoint) - 1.
    val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
    if (idx >= 0) idx else -idx - 1
  } else {
    // Categorical feature bins are indexed by feature values.
    // Written as a negated "in range" test: every comparison with NaN is false, so the old
    // `v < 0 || v >= arity` form let NaN through and NaN.toInt silently produced category 0.
    if (!(featureValue >= 0 && featureValue < featureArity)) {
      throw new IllegalArgumentException(
        s"DecisionTree given invalid data:" +
        s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}}," +
        s" but a data point gives it value $featureValue.\n" +
        " Bad data point: " + labeledPoint.toString)
    }
    // In-range but non-integral values are truncated toward zero (e.g. 2.7 -> 2).
    featureValue.toInt
  }
}