// Spark使用Pipeline构造机器学习任务【Java】
package ml.dataframe.clustering;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.evaluation.ClusteringEvaluator;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.feature.*;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import java.util.Arrays;
/**
 * K-means clustering over a JDBC-backed DataFrame, assembled as a Spark ML
 * {@link Pipeline}.
 *
 * <p>The pipeline has three stages, in order:
 * <ol>
 *   <li>a {@link StringIndexer} that maps {@code start_time} to the numeric
 *       column {@code start_time_countize};</li>
 *   <li>a {@link VectorAssembler} that packs {@code record_id},
 *       {@code user_id}, {@code voice_time} and {@code start_time_countize}
 *       into the {@code features} vector column;</li>
 *   <li>a {@link KMeans} estimator with {@code k = 2} and a fixed seed.</li>
 * </ol>
 *
 * <p>Created by d00454735 on 2018/8/29.
 */
public class SPKMeans {

    /**
     * Loads the table {@code test.test} from a local MySQL server, fits the
     * pipeline on it, prints the Silhouette score of the resulting clustering
     * and then prints every cluster center.
     *
     * @param args command-line arguments; not used
     */
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .master("local")
                .appName("Java Spark SQL basic example")
                .config("spark.some.config.option", "some-value")
                .getOrCreate();
        try {
            // Loads data.
            // NOTE(review): host and credentials are hard-coded for a local demo
            // database; externalize them before using this anywhere else.
            Dataset<Row> dataset = spark.read()
                    .format("jdbc")
                    .option("driver", "com.mysql.cj.jdbc.Driver")
                    // The '?' separates the (empty) database name from the
                    // connection properties. Without it Connector/J treats
                    // "serverTimezone=UTC" as the database name. The table is
                    // schema-qualified below, so no default database is needed.
                    .option("url", "jdbc:mysql://127.0.0.1/?serverTimezone=UTC")
                    .option("dbtable", "test.test")
                    .option("user", "root")
                    .option("password", "root")
                    .load();

            // Encode start_time as a numeric index so it can be assembled
            // into the feature vector.
            StringIndexer startTimeIndexer = new StringIndexer()
                    .setInputCol("start_time")
                    .setOutputCol("start_time_countize");

            VectorAssembler assembler = new VectorAssembler()
                    .setInputCols(new String[] {
                        "record_id", "user_id", "voice_time", "start_time_countize"})
                    .setOutputCol("features");

            // Trains a k-means model.
            KMeans kmeans = new KMeans().setK(2).setSeed(1L).setFeaturesCol("features");

            // k-means is deliberately the last stage; its fitted model is
            // looked up by that position further down.
            PipelineStage[] stages = new PipelineStage[] {startTimeIndexer, assembler, kmeans};
            Pipeline pipeline = new Pipeline().setStages(stages);
            PipelineModel model = pipeline.fit(dataset);

            // Make predictions
            Dataset<Row> predictions = model.transform(dataset);

            // Evaluate clustering by computing Silhouette score
            ClusteringEvaluator evaluator = new ClusteringEvaluator();
            double silhouette = evaluator.evaluate(predictions);
            System.out.println("Silhouette with squared euclidean distance = " + silhouette);

            // Shows the result. The fitted model keeps one transformer per
            // pipeline stage in the same order, so the k-means model sits at
            // the index of the last configured stage.
            KMeansModel kmeansModel = (KMeansModel) model.stages()[stages.length - 1];
            Vector[] centers = kmeansModel.clusterCenters();
            System.out.println("Cluster Centers: ");
            for (Vector center : centers) {
                System.out.println(center);
            }
        } finally {
            // Always release the session, even when loading or fitting fails.
            spark.stop();
        }
    }
}
// 欢迎大家关注DataLearner官方微信，接收最新的AI技术推送
