From 1455a71509fef24dfc71591730adffd71518b1a7 Mon Sep 17 00:00:00 2001 From: yennanliue Date: Mon, 6 Aug 2018 06:43:40 +0100 Subject: [PATCH] add DT model --- spark_/train_spark_DT.java | 104 +++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 spark_/train_spark_DT.java diff --git a/spark_/train_spark_DT.java b/spark_/train_spark_DT.java new file mode 100644 index 0000000..1ee9e0f --- /dev/null +++ b/spark_/train_spark_DT.java @@ -0,0 +1,104 @@ +// package nl.craftsmen.spark.iris; + +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.DecisionTreeRegressionModel; +import org.apache.spark.ml.regression.DecisionTreeRegressor; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; + +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.when; + +/**** + +# credit + +https://craftsmen.nl/an-introduction-to-machine-learning-with-apache-spark/ +https://github.com/Silfen66/SparkIris/blob/master/src/main/java/nl/craftsmen/spark/iris/SparkIris.java + +# ML spark feature vectore set up + +https://spark.apache.org/docs/latest/ml-classification-regression.html + +**** / + +/** + + * Apache Spark MLLib Java algorithm for classifying the Iris Species + * into three categories using a Random Forest Classification algorithm. + * + + **/ + + +public class train_spark_DT { + + private static final String PATH = "train_data_java.csv"; + + public static void main(String[] args) { + + // initialise Spark session + SparkSession sparkSession = SparkSession.builder().appName("train_spark_DT").config("spark.master", "local").getOrCreate(); + + // load dataset, which has a header at the first row + Dataset rawData = sparkSession.read().option("header", "true").csv(PATH); + + // cast the values of the features to doubles for usage in the feature column vector + Dataset transformedDataSet = rawData.withColumn("vendor_id", rawData.col("vendor_id").cast("double")) + .withColumn("passenger_count", rawData.col("passenger_count").cast("double")) + .withColumn("pickup_longitude", rawData.col("pickup_longitude").cast("double")) + .withColumn("pickup_latitude", rawData.col("pickup_latitude").cast("double")) + .withColumn("dropoff_longitude", rawData.col("dropoff_longitude").cast("double")) + .withColumn("dropoff_latitude", rawData.col("dropoff_latitude").cast("double")) + .withColumn("trip_duration", rawData.col("trip_duration").cast("double")); + + // add a numerical label column for the Random Forest Classifier + //transformedDataSet = transformedDataSet + // .withColumn("trip_duration", rawData.col("trip_duration").cast("double")); + + + // identify the feature colunms + String[] inputColumns = {"vendor_id","passenger_count", "pickup_longitude","pickup_latitude", "dropoff_longitude", "dropoff_latitude"}; + VectorAssembler assembler = new VectorAssembler().setInputCols(inputColumns).setOutputCol("features"); + Dataset featureSet = assembler.transform(transformedDataSet); + + // split data random in trainingset (70%) and testset (30%) using a seed so results can be reproduced + long seed = 5043; + Dataset[] trainingAndTestSet = featureSet.randomSplit(new double[]{0.7, 0.3}, seed); + Dataset trainingSet = trainingAndTestSet[0]; + Dataset testSet = trainingAndTestSet[1]; + + trainingSet.show(); + + // train the algorithm based on a Random Forest Classification Algorithm with default values + DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setLabelCol("trip_duration") + .setFeaturesCol("features"); + DecisionTreeRegressionModel dtModel = dt.fit(trainingSet); + Dataset predictions = dtModel.transform(testSet); + RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("trip_duration") + .setPredictionCol("prediction") + .setMetricName("rmse"); + double rmse = evaluator.evaluate(predictions); + + + // test the model against the testset and show results + System.out.println("----------------- prediction ----------------- "); + predictions.select("id", "trip_duration", "prediction").show(20); + System.out.println("----------------- prediction ----------------- "); + + // evaluate the model + //RandomForestRegressionModel rfModel = (RandomForestRegressionModel)(model.stages()[1]); + + System.out.println("----------------- accuracy ----------------- "); + System.out.println("Trained DT model:\n" + dtModel.toDebugString()); + System.out.println("accuracy: " + rmse ); + System.out.println("----------------- accuracy ----------------- "); + } +} \ No newline at end of file