From f0d1cef8e74005e460a5199093445a689711ab6d Mon Sep 17 00:00:00 2001
From: yennanliue
Date: Sun, 17 Mar 2019 12:07:05 +0000
Subject: [PATCH] fix spark script code layout

---
 spark_/data_preprocess.java  | 23 ----------------------
 spark_/prepare.py            | 15 ---------------
 spark_/spark_SQL.java        |  5 -----
 spark_/train_spark_KNN.py    | 37 ------------------------------------
 spark_/train_spark_RF.py     |  8 --------
 spark_/train_spark_RF_2.java |  4 ----
 spark_/train_spark_RF_3.java | 27 ++++++++++++--------------
 7 files changed, 12 insertions(+), 107 deletions(-)

diff --git a/spark_/data_preprocess.java b/spark_/data_preprocess.java
index b4ef691..b2492ff 100644
--- a/spark_/data_preprocess.java
+++ b/spark_/data_preprocess.java
@@ -1,33 +1,21 @@
-
 import java.io.*;
 import java.sql.*;
-
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
-
 import static org.apache.spark.sql.functions.col;
 import static org.apache.spark.sql.functions.when;
 import static org.apache.spark.sql.functions.avg;
 import static org.apache.spark.sql.functions.col;
 import static org.apache.spark.sql.functions.max;
-
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
-
-
 /*
-
 Credit
-
 1) https://github.com/jleetutorial/sparkTutorial/blob/master/src/main/java/com/sparkTutorial/sparkSql/HousePriceSolution.java
-
-
 */
-
-
 public class data_preprocess {
 
     public static final String delimiter = ",";
@@ -38,11 +26,7 @@ public static void main(String[] args) {
         // step 2 : read csv via spark
         //String args = "some arg";
         read_csv("test");
-
     }
-
-
-
     public static void read_csv (String csvFile)
     {
@@ -69,11 +53,9 @@
             .agg(avg("pickup_longitude"), max("pickup_latitude"))
             .show();
-
         transformedDataSet.groupBy("passenger_count")
             .agg(avg("trip_duration"), max("pickup_latitude"),max("pickup_longitude"))
             .show();
-
         // PART 3 : filter trip_duration < 500 data
         System.out.println(" ---------------- PART 3 ----------------");
         transformedDataSet.filter(col("trip_duration").$less(500)).show();
@@ -83,11 +65,6 @@
         Dataset transformedDataSet_ = transformedDataSet.withColumn(
             "trip_duration_", col("trip_duration").divide(10).cast("double"));
         transformedDataSet_.select( col("trip_duration_"),col("trip_duration")).show();
-
-
     }
 }
-
-
-
diff --git a/spark_/prepare.py b/spark_/prepare.py
index d61ecac..02e6a77 100644
--- a/spark_/prepare.py
+++ b/spark_/prepare.py
@@ -1,13 +1,9 @@
 # python 2.7
-
-
 import os
-
 from pyspark import SparkConf, SparkContext
 from pyspark.sql import SQLContext
 from pyspark.sql.functions import count, avg
-
 sc =SparkContext()
 SparkContext.getOrCreate()
 conf = SparkConf().setAppName("building a warehouse")
@@ -17,9 +13,6 @@
 print (sc)
 print ("==================")
-
-
-
 def run():
     df_train = sqlCtx.read.format('com.databricks.spark.csv')\
                .options(header='true', inferschema='true')\
@@ -72,15 +65,7 @@
                order by 2 desc
                limit 10""").show()
-
-
-
-
 if __name__ == '__main__':
     #run()
     #test()
     filter_column()
-
-
-
-
diff --git a/spark_/spark_SQL.java b/spark_/spark_SQL.java
index 2dbb279..b8816c8 100644
--- a/spark_/spark_SQL.java
+++ b/spark_/spark_SQL.java
@@ -16,7 +16,6 @@
  */
 
 //package org.apache.spark.examples.sql;
-
 /*
 credit
 https://github.com/apache/spark/blob/master/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java
@@ -169,9 +168,6 @@ private static void runBasicDataFrameExample(SparkSession spark) throws Analysis
     Dataset sqlDF2 =
        spark.sql("SELECT id, pickup_datetime, date(pickup_datetime) as date,month(pickup_datetime) as month, date_format(pickup_datetime, 'EEEE') as dow FROM df limit 10 ");
     sqlDF2.show();
-
-
-
     // $example on:global_temp_view$
     // Register the DataFrame as a global temporary view
     df.createGlobalTempView("df");
@@ -199,6 +195,5 @@ private static void runBasicDataFrameExample(SparkSession spark) throws Analysis
 
 
 
   }
-
 }
\ No newline at end of file
diff --git a/spark_/train_spark_KNN.py b/spark_/train_spark_KNN.py
index b56281c..a5f9d45 100644
--- a/spark_/train_spark_KNN.py
+++ b/spark_/train_spark_KNN.py
@@ -1,8 +1,6 @@
 # python 3
 # -*- coding: utf-8 -*-
-
-
 """
 * modify from
@@ -14,11 +12,8 @@
 https://weiminwang.blog/2016/06/09/pyspark-tutorial-building-a-random-forest-binary-classifier-on-unbalanced-dataset/
 https://github.com/notthatbreezy/nyc-taxi-spark-ml/blob/master/python/generate-model.py
-
 """
-
-
 # load basics library
 import csv
 import os
@@ -33,8 +28,6 @@
 from pyspark.ml.feature import VectorAssembler
 from pyspark.ml.clustering import KMeans
-
-
 # ---------------------------------
 # config
 sc =SparkContext()
@@ -49,20 +42,6 @@
 
-
-
-
-
-### ================================================ ###
-
-# feature engineering
-
-
-# HELP FUNC
-
-
-### ================================================ ###
-
 if __name__ == '__main__':
     # load data with spark way
     trainNYC = sc.textFile('train_data_java.csv').map(lambda line: line.split(","))
@@ -93,19 +72,3 @@
     print (output_data.take(30))
     #print (output_data.toDF().head(30))
     print (' ------- KNN model output ------- ' )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/spark_/train_spark_RF.py b/spark_/train_spark_RF.py
index dd941f5..30d7765 100644
--- a/spark_/train_spark_RF.py
+++ b/spark_/train_spark_RF.py
@@ -24,8 +24,6 @@
 #
 ######################################################################################################################################################
-
-
 # load basics library
 import csv
 import os
@@ -53,8 +51,6 @@
 print ("==================")
 # ---------------------------------
-
-
 ### ================================================ ###
 #
 # feature engineering
 #
 ### ================================================ ###
-
-
 if __name__ == '__main__':
     # load data with spark way
     trainNYC = sc.textFile('train_data_java.csv').map(lambda line: line.split(","))
@@ -130,5 +124,3 @@
 rfModel = model.stages[1]
 print(' *** : RF MODEL SUMMARY : ', rfModel)  # summary only
 print ('='*100)
-
-
diff --git a/spark_/train_spark_RF_2.java b/spark_/train_spark_RF_2.java
index 223b493..4c6e91f 100644
--- a/spark_/train_spark_RF_2.java
+++ b/spark_/train_spark_RF_2.java
@@ -1,4 +1,3 @@
-
 import org.apache.spark.ml.evaluation.RegressionEvaluator;
 import org.apache.spark.ml.feature.VectorIndexer;
 import org.apache.spark.ml.feature.VectorIndexerModel;
@@ -8,10 +7,8 @@
 import org.apache.spark.ml.feature.VectorAssembler;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-
 import static org.apache.spark.sql.functions.col;
 import static org.apache.spark.sql.functions.when;
-
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Arrays;
@@ -30,7 +27,6 @@
 import org.apache.spark.sql.AnalysisException;
 import static org.apache.spark.sql.functions.col;
-
 /*
 Run the RF model again with pre-process data via
diff --git a/spark_/train_spark_RF_3.java b/spark_/train_spark_RF_3.java
index 47e4de3..5aeed5b 100644
--- a/spark_/train_spark_RF_3.java
+++ b/spark_/train_spark_RF_3.java
@@ -1,5 +1,4 @@
 // package nl.craftsmen.spark.iris;
-
 import org.apache.spark.ml.evaluation.RegressionEvaluator;
 import org.apache.spark.ml.feature.VectorIndexer;
 import org.apache.spark.ml.feature.VectorIndexerModel;
@@ -10,10 +9,8 @@
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
-
 import static org.apache.spark.sql.functions.col;
 import static org.apache.spark.sql.functions.when;
-
 import java.util.Arrays;
 import org.apache.spark.ml.Pipeline;
 import org.apache.spark.ml.PipelineStage;
@@ -23,8 +20,6 @@
 import org.apache.spark.ml.tuning.CrossValidatorModel;
 import org.apache.spark.ml.tuning.ParamGridBuilder;
-
-
 /****
 
 # credit
@@ -46,7 +41,6 @@
 
 **/
-
 public class train_spark_RF_3 {
 
     private static final String PATH = "train_data_java.csv";
@@ -94,14 +88,20 @@ public static void main(String[] args) {
 
         // --------------------------- TUNE RF MODEL ------------------------------
         // hash make data fit pipeline form
-        HashingTF hashingTF = new HashingTF()
+        /* HashingTF hashingTF = new HashingTF()
                 .setNumFeatures(1000)
-                .setInputCol(trainingSet)
+                .setInputCol(trainingSet.getOutputCol())
                 .setOutputCol("features");
-
+        */
         // set up pipeline
-        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {trainingSet,hashingTF ,rf});
+        //Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {trainingSet,rf});
+        RegressionEvaluator evaluator = new RegressionEvaluator()
+                .setLabelCol("trip_duration")
+                .setPredictionCol("prediction")
+                .setMetricName("rmse");
+
+
         // grid search
         ParamMap[] paramGrid = new ParamGridBuilder()
@@ -110,7 +110,7 @@
 
         CrossValidator cv = new CrossValidator()
-                .setEstimator(pipeline)
+                .setEstimator(rf)  // estimator must be the model (or pipeline), not the evaluator
                 .setEvaluator(new RegressionEvaluator())
                 .setEstimatorParamMaps(paramGrid).setNumFolds(2);  // Use 3+ in practice
@@ -127,10 +127,7 @@
 
         //RandomForestRegressionModel rfModel = rf.fit(trainingSet);
         //Dataset predictions = rfModel.transform(testSet);
-        RegressionEvaluator evaluator = new RegressionEvaluator()
-                .setLabelCol("trip_duration")
-                .setPredictionCol("prediction")
-                .setMetricName("rmse");
+
         double rmse = evaluator.evaluate(predictions);
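
---

For reference, below is a minimal, self-contained sketch of the cross-validation wiring the hunks above are converging on. It is not part of the patch: the "features" / "trip_duration" column names and the use of RandomForestRegressor follow train_spark_RF_3.java, while the class name CrossValidateRFSketch, the tune() helper, and the grid values are illustrative assumptions only.

// CrossValidateRFSketch.java -- illustrative sketch, not part of this patch
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.regression.RandomForestRegressor;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class CrossValidateRFSketch {

    // trainingSet is assumed to already carry the assembled "features"
    // vector and the "trip_duration" label, as built earlier in
    // train_spark_RF_3.java.
    public static CrossValidatorModel tune(Dataset<Row> trainingSet) {

        // The estimator: the model (or pipeline) that is re-fitted once
        // per parameter combination and fold.
        RandomForestRegressor rf = new RandomForestRegressor()
                .setFeaturesCol("features")
                .setLabelCol("trip_duration");

        // The evaluator: scores each fitted model by RMSE. It goes to
        // setEvaluator(), never to setEstimator().
        RegressionEvaluator evaluator = new RegressionEvaluator()
                .setLabelCol("trip_duration")
                .setPredictionCol("prediction")
                .setMetricName("rmse");

        // A small illustrative grid (values are assumptions, not tuned).
        ParamMap[] paramGrid = new ParamGridBuilder()
                .addGrid(rf.numTrees(), new int[] {10, 50})
                .addGrid(rf.maxDepth(), new int[] {5, 10})
                .build();

        CrossValidator cv = new CrossValidator()
                .setEstimator(rf)
                .setEvaluator(evaluator)
                .setEstimatorParamMaps(paramGrid)
                .setNumFolds(2);  // use 3+ in practice, as the patch notes

        return cv.fit(trainingSet);
    }
}

Design note: CrossValidator.setEstimator() must receive the model or Pipeline being fitted, while the metric is supplied through setEvaluator(). A RegressionEvaluator cannot serve as the estimator, since it extends Evaluator rather than Estimator and has no fit() method.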