-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtrain_spark_DT.java
104 lines (76 loc) · 4.7 KB
/
train_spark_DT.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// package nl.craftsmen.spark.iris;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.when;
/*
 * Credit:
 * https://craftsmen.nl/an-introduction-to-machine-learning-with-apache-spark/
 * https://github.com/Silfen66/SparkIris/blob/master/src/main/java/nl/craftsmen/spark/iris/SparkIris.java
 *
 * Spark ML feature vector setup:
 * https://spark.apache.org/docs/latest/ml-classification-regression.html
 */
/**
 * Trains and evaluates an Apache Spark ML decision tree regressor that predicts
 * the {@code trip_duration} column of a CSV dataset.
 *
 * <p>The program reads a CSV file with a header row, casts the feature and label
 * columns to double, assembles the feature columns into a single vector column,
 * splits the rows 70/30 into a training and a test set, fits a
 * {@link DecisionTreeRegressor} with default settings, and prints sample
 * predictions, the fitted tree and the root-mean-square error on the test set.
 */
public class train_spark_DT {

    /** Default CSV input path, used when no path is passed on the command line. */
    private static final String PATH = "train_data_java.csv";

    /** Name of the column the regressor is trained to predict. */
    private static final String LABEL_COLUMN = "trip_duration";

    /** Name of the assembled feature-vector column. */
    private static final String FEATURES_COLUMN = "features";

    /** Columns that are cast to double and assembled into the feature vector. */
    private static final String[] FEATURE_COLUMNS = {
        "vendor_id",
        "passenger_count",
        "pickup_longitude",
        "pickup_latitude",
        "dropoff_longitude",
        "dropoff_latitude"
    };

    /** Fixed seed for the train/test split so that results can be reproduced. */
    private static final long SPLIT_SEED = 5043L;

    /**
     * Runs the full load / train / evaluate pipeline and prints the results to
     * standard output.
     *
     * @param args optional; when present, {@code args[0]} is the path of the CSV
     *             file to read instead of the default {@link #PATH}
     */
    public static void main(String[] args) {
        String inputPath = (args != null && args.length > 0) ? args[0] : PATH;

        // initialise the Spark session (local master, as in the original program)
        SparkSession sparkSession = SparkSession.builder()
                .appName("train_spark_DT")
                .config("spark.master", "local")
                .getOrCreate();

        try {
            // load the dataset, which has a header in the first row
            Dataset<Row> rawData = sparkSession.read().option("header", "true").csv(inputPath);

            // CSV columns are read as strings here (no schema inference is requested),
            // so cast every feature column and the label column to double
            Dataset<Row> transformedDataSet = rawData;
            for (String column : FEATURE_COLUMNS) {
                transformedDataSet = castToDouble(transformedDataSet, column);
            }
            transformedDataSet = castToDouble(transformedDataSet, LABEL_COLUMN);

            // combine the individual feature columns into one vector column
            VectorAssembler assembler = new VectorAssembler()
                    .setInputCols(FEATURE_COLUMNS)
                    .setOutputCol(FEATURES_COLUMN);
            Dataset<Row> featureSet = assembler.transform(transformedDataSet);

            // split the data randomly into a training set (70%) and a test set (30%)
            Dataset<Row>[] trainingAndTestSet =
                    featureSet.randomSplit(new double[] {0.7, 0.3}, SPLIT_SEED);
            Dataset<Row> trainingSet = trainingAndTestSet[0];
            Dataset<Row> testSet = trainingAndTestSet[1];
            trainingSet.show();

            // train a decision tree regressor with default hyper-parameters
            DecisionTreeRegressor dt = new DecisionTreeRegressor()
                    .setLabelCol(LABEL_COLUMN)
                    .setFeaturesCol(FEATURES_COLUMN);
            DecisionTreeRegressionModel dtModel = dt.fit(trainingSet);

            // predict on the held-out test set and compute the root-mean-square error
            Dataset<Row> predictions = dtModel.transform(testSet);
            RegressionEvaluator evaluator = new RegressionEvaluator()
                    .setLabelCol(LABEL_COLUMN)
                    .setPredictionCol("prediction")
                    .setMetricName("rmse");
            double rmse = evaluator.evaluate(predictions);

            // show a sample of the predictions next to the actual values
            System.out.println("----------------- prediction ----------------- ");
            predictions.select("id", LABEL_COLUMN, "prediction").show(20);
            System.out.println("----------------- prediction ----------------- ");

            // report the fitted tree and the error metric; RMSE is an error measure
            // (lower is better), so it is labelled as such rather than as "accuracy"
            System.out.println("----------------- evaluation ----------------- ");
            System.out.println("Trained DT model:\n" + dtModel.toDebugString());
            System.out.println("RMSE (test set): " + rmse);
            System.out.println("----------------- evaluation ----------------- ");
        } finally {
            // always release the Spark session, even when training or evaluation fails
            sparkSession.stop();
        }
    }

    /** Returns {@code data} with the named column replaced by its value cast to double. */
    private static Dataset<Row> castToDouble(Dataset<Row> data, String column) {
        return data.withColumn(column, data.col(column).cast("double"));
    }
}