package rdd.ml.clustering
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.{SparkConf, SparkContext}
/**
 * Created by d00454735 on 2018/7/25.
 */
object KMeansTest {
  def main(args: Array[String]): Unit = {
    // Run in local mode and initialize the SparkContext
    val masterURL = "local[*]"
    val conf = new SparkConf().setAppName("KMeans Test").setMaster(masterURL)
    val sc = new SparkContext(conf)
    // Load the data and parse every line. Note: the official guide parses with map; here it is
    // rewritten with mapPartitions, which processes a whole partition (data block) at a time and
    // is considerably faster than the per-line map approach
    val data = sc.textFile("file:/d:/data/kmeans_data.txt")
    val parsedData = data.mapPartitions(partition => parseData(partition)).cache()
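    // For comparison, the per-line version from the official guide (a sketch; equivalent result,
    // one line at a time instead of one partition at a time) would be:
    // val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache()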
    // Set the KMeans parameters and train the model: here we cluster into 2 groups over 20 iterations
    val numClusters = 2
    val numIterations = 20
    val clusters = KMeans.train(parsedData, numClusters, numIterations)
    // Compute the Within Set Sum of Squared Errors (WSSSE) to evaluate the clustering quality
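    // WSSSE is the sum, over all points, of the squared Euclidean distance to the nearest cluster center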
    val WSSSE = clusters.computeCost(parsedData)
    println(s"Within Set Sum of Squared Errors = $WSSSE")
    // Save the model, and how to load it back afterwards
    clusters.save(sc, "target/org/apache/spark/KMeansExample/KMeansModel")
    val sameModel = KMeansModel.load(sc, "target/org/apache/spark/KMeansExample/KMeansModel")
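    // The reloaded model behaves like the original one; e.g. sameModel.predict(somePoint)
    // returns the index of the nearest cluster center for a given Vector (somePoint is hypothetical)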
    sc.stop()
  }
  // Parse the data: split each line on spaces, convert the elements to Double, and build a dense vector
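  // e.g. a line such as "1.0 2.0 3.0" becomes Vectors.dense(1.0, 2.0, 3.0)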
  def parseData(lines: Iterator[String]): Iterator[Vector] = {
    lines.map(s => Vectors.dense(s.split(' ').map(_.toDouble)))
  }
}
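
// The KMeans.train call above ends up in this private run method of
// org.apache.spark.mllib.clustering.KMeans (annotated excerpt from the Spark source):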
private[spark] def run(
    data: RDD[Vector],
    instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
  // Uncached input data can hurt performance, so a warning is logged here
  if (data.getStorageLevel == StorageLevel.NONE) {
    logWarning("The input data is not directly cached, which may hurt performance if its"
      + " parent RDDs are also uncached.")
  }
  // Compute the L2 norm of every row and cache it: square the row's elements, sum them, then take the square root
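  // i.e. norm(v) = sqrt(v(0)*v(0) + v(1)*v(1) + ... + v(n-1)*v(n-1))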
  val norms = data.map(Vectors.norm(_, 2.0))
  norms.persist()
  // zip pairs elements positionally: here every row of the data is paired with the norm just computed
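  // e.g. the row [3.0, 4.0] is paired with its norm 5.0 and wrapped as a VectorWithNorm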
  val zippedData = data.zip(norms).map { case (v, norm) =>
    new VectorWithNorm(v, norm)
  }
  // This is where the actual clustering algorithm runs
  val model = runAlgorithm(zippedData, instr)
  // Release the cached norms
  norms.unpersist()
  // Warn at the end of the run as well, for increased visibility.
  if (data.getStorageLevel == StorageLevel.NONE) {
    logWarning("The input data was not directly cached, which may hurt performance if its"
      + " parent RDDs are also uncached.")
  }
  // Return the trained model
  model
}
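
// For reference: the three-argument KMeans.train used in the example above is essentially a thin
// wrapper that builds a KMeans instance and calls its public run method, which in turn delegates
// to the private run shown here. A rough sketch (the real overloads also set the initialization
// mode and the random seed):
//   new KMeans()
//     .setK(numClusters)
//     .setMaxIterations(numIterations)
//     .run(parsedData)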