Skip to content

Commit

Permalink
fix feature vector setting and make spark java RF model work successf…
Browse files Browse the repository at this point in the history
…ully
  • Loading branch information
yennanliu committed Aug 3, 2018
1 parent fb08d4c commit 6681684
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions spark_/train_spark_RF.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,27 @@
import static org.apache.spark.sql.functions.when;

/****
credit
# credit
https://craftsmen.nl/an-introduction-to-machine-learning-with-apache-spark/
https://github.com/Silfen66/SparkIris/blob/master/src/main/java/nl/craftsmen/spark/iris/SparkIris.java
# Spark ML feature vector setup
https://spark.apache.org/docs/latest/ml-classification-regression.html
**** /
/**
* Apache Spark MLlib Java algorithm for predicting trip duration
* using a Random Forest regression algorithm (adapted from the Iris
* classification example credited above).
*/
*
**/


public class train_spark_RF {

private static final String PATH = "train_data_java.csv";
Expand Down Expand Up @@ -50,7 +63,7 @@ public static void main(String[] args) {


// identify the feature columns
String[] inputColumns = {"vendor_id","passenger_count", "pickup_longitude","pickup_latitude", "dropoff_longitude", "dropoff_latitude","trip_duration"};
String[] inputColumns = {"vendor_id","passenger_count", "pickup_longitude","pickup_latitude", "dropoff_longitude", "dropoff_latitude"};
VectorAssembler assembler = new VectorAssembler().setInputCols(inputColumns).setOutputCol("features");
Dataset<Row> featureSet = assembler.transform(transformedDataSet);

Expand All @@ -63,7 +76,9 @@ public static void main(String[] args) {
trainingSet.show();

// train a Random Forest regression model (RandomForestRegressor) with default values
RandomForestRegressor rf = new RandomForestRegressor();
RandomForestRegressor rf = new RandomForestRegressor()
.setLabelCol("trip_duration")
.setFeaturesCol("features");
RandomForestRegressionModel rfModel = rf.fit(trainingSet);
Dataset<Row> predictions = rfModel.transform(testSet);
RegressionEvaluator evaluator = new RegressionEvaluator()
Expand Down

0 comments on commit 6681684

Please sign in to comment.