
Commit

fix spark script code layout
yennanliu committed Mar 17, 2019
1 parent c061900 commit f0d1cef
Showing 7 changed files with 12 additions and 107 deletions.
23 changes: 0 additions & 23 deletions spark_/data_preprocess.java
@@ -1,33 +1,21 @@

import java.io.*;
import java.sql.*;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.when;
import static org.apache.spark.sql.functions.avg;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.max;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;



/*
Credit
1) https://github.com/jleetutorial/sparkTutorial/blob/master/src/main/java/com/sparkTutorial/sparkSql/HousePriceSolution.java
*/



public class data_preprocess {

public static final String delimiter = ",";
@@ -38,11 +26,7 @@ public static void main(String[] args) {
// step 2 : read csv via spark
//String args = "some arg";
read_csv("test");

}



public static void read_csv (String csvFile)

{
@@ -69,11 +53,9 @@ public static void read_csv (String csvFile)
.agg(avg("pickup_longitude"), max("pickup_latitude"))
.show();


transformedDataSet.groupBy("passenger_count")
.agg(avg("trip_duration"), max("pickup_latitude"),max("pickup_longitude"))
.show();

// PART 3 : filter trip_duration < 500 data
System.out.println(" ---------------- PART 3 ----------------");
transformedDataSet.filter(col("trip_duration").$less(500)).show();
@@ -83,11 +65,6 @@ public static void read_csv (String csvFile)
Dataset<Row> transformedDataSet_ = transformedDataSet.withColumn(
"trip_duration_", col("trip_duration").divide(10).cast("double"));
transformedDataSet_.select( col("trip_duration_"),col("trip_duration")).show();


}

}
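
Aside: the logic this file keeps boils down to three DataFrame operations — groupBy/agg with avg and max (PART 2), a filter on trip_duration (PART 3), and a derived column via withColumn (PART 4). A minimal, hedged sketch of the same steps follows; the class name, local master, and input path are illustrative, while the column names follow the taxi schema used in read_csv() above.

import static org.apache.spark.sql.functions.avg;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.max;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class AggSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("AggSketch").master("local[*]").getOrCreate();

        // header/inferSchema mirror the options used by read_csv() above
        Dataset<Row> trips = spark.read()
                .option("header", "true")
                .option("inferSchema", "true")
                .csv("train_data.csv");  // illustrative path

        // PART 2: averages and maxima per passenger_count
        trips.groupBy("passenger_count")
                .agg(avg("trip_duration"), max("pickup_latitude"), max("pickup_longitude"))
                .show();

        // PART 3: keep only trips with trip_duration below 500
        trips.filter(col("trip_duration").$less(500)).show();

        // PART 4: derived column, scaled down by 10 and cast to double
        trips.withColumn("trip_duration_", col("trip_duration").divide(10).cast("double"))
                .select(col("trip_duration_"), col("trip_duration"))
                .show();

        spark.stop();
    }
}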



15 changes: 0 additions & 15 deletions spark_/prepare.py
@@ -1,13 +1,9 @@
# python 2.7


import os

from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import count, avg


sc =SparkContext()
SparkContext.getOrCreate()
conf = SparkConf().setAppName("building a warehouse")
@@ -17,9 +13,6 @@
print (sc)
print ("==================")




def run():
df_train = sqlCtx.read.format('com.databricks.spark.csv')\
.options(header='true', inferschema='true')\
@@ -72,15 +65,7 @@ def filter_column():
order by 2 desc
limit 10""").show()





if __name__ == '__main__':
#run()
#test()
filter_column()
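
Aside: prepare.py loads a CSV through the external com.databricks.spark.csv package and then aggregates with SQL. That reader was folded into Spark core as of 2.0, so the same flow needs no extra package. A hedged sketch in Java (kept in one language with the other sketches on this commit); the path, view name, and query are illustrative stand-ins for the ones above.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class WarehouseSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("building a warehouse").master("local[*]").getOrCreate();

        // Spark 2.x+ ships the csv source; no com.databricks.spark.csv needed
        Dataset<Row> dfTrain = spark.read()
                .option("header", "true")
                .option("inferSchema", "true")
                .csv("train.csv");  // illustrative path

        dfTrain.createOrReplaceTempView("df_train");

        // top-10 aggregation, echoing the "order by 2 desc limit 10" query above
        spark.sql("SELECT passenger_count, COUNT(*) AS cnt "
                + "FROM df_train GROUP BY passenger_count "
                + "ORDER BY 2 DESC LIMIT 10").show();

        spark.stop();
    }
}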




5 changes: 0 additions & 5 deletions spark_/spark_SQL.java
@@ -16,7 +16,6 @@
*/
//package org.apache.spark.examples.sql;


/*
credit
https://github.com/apache/spark/blob/master/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java
@@ -169,9 +168,6 @@ private static void runBasicDataFrameExample(SparkSession spark) throws Analysis
Dataset<Row> sqlDF2 = spark.sql("SELECT id, pickup_datetime, date(pickup_datetime) as date,month(pickup_datetime) as month, date_format(pickup_datetime, 'EEEE') as dow FROM df limit 10 ");
sqlDF2.show();




// $example on:global_temp_view$
// Register the DataFrame as a global temporary view
df.createGlobalTempView("df");
@@ -199,6 +195,5 @@ private static void runBasicDataFrameExample(SparkSession spark) throws Analysis


}


}
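
Aside: the query kept above — date(), month(), and date_format(..., 'EEEE') for day-of-week — is the heart of this example. A self-contained, hedged sketch; the input path is illustrative, while the view name df and the pickup_datetime column come from the diff.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class DateSqlSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("DateSqlSketch").master("local[*]").getOrCreate();

        Dataset<Row> df = spark.read()
                .option("header", "true")
                .option("inferSchema", "true")
                .csv("train_data.csv");  // illustrative path
        df.createOrReplaceTempView("df");

        // date() truncates to the day, month() extracts the month number,
        // and date_format(..., 'EEEE') yields the weekday name
        Dataset<Row> sqlDF = spark.sql(
                "SELECT id, pickup_datetime, "
                + "date(pickup_datetime) AS date, "
                + "month(pickup_datetime) AS month, "
                + "date_format(pickup_datetime, 'EEEE') AS dow "
                + "FROM df LIMIT 10");
        sqlDF.show();

        spark.stop();
    }
}
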
37 changes: 0 additions & 37 deletions spark_/train_spark_KNN.py
@@ -1,8 +1,6 @@
# python 3
# -*- coding: utf-8 -*-



"""
* modify from
@@ -14,11 +12,8 @@
https://weiminwang.blog/2016/06/09/pyspark-tutorial-building-a-random-forest-binary-classifier-on-unbalanced-dataset/
https://github.com/notthatbreezy/nyc-taxi-spark-ml/blob/master/python/generate-model.py
"""



# load basics library
import csv
import os
@@ -33,8 +28,6 @@
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import KMeans



# ---------------------------------
# config
sc =SparkContext()
@@ -49,20 +42,6 @@






### ================================================ ###

# feature engineering


# HELP FUNC


### ================================================ ###


if __name__ == '__main__':
# load data with spark way
trainNYC = sc.textFile('train_data_java.csv').map(lambda line: line.split(","))
@@ -93,19 +72,3 @@
print (output_data.take(30))
#print (output_data.toDF().head(30))
print (' ------- KNN model output ------- ' )
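
Aside: despite the KNN in the file name, the imports and output show KMeans clustering over assembled feature vectors. A hedged Java sketch of that shape; the coordinate columns are a guess from the taxi schema used elsewhere in this repo, and k is arbitrary.

import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class KMeansSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("KMeansSketch").master("local[*]").getOrCreate();

        Dataset<Row> trips = spark.read()
                .option("header", "true")
                .option("inferSchema", "true")
                .csv("train_data_java.csv");

        // pack the raw coordinate columns into a single "features" vector
        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(new String[] {
                        "pickup_longitude", "pickup_latitude",
                        "dropoff_longitude", "dropoff_latitude"})
                .setOutputCol("features");
        Dataset<Row> assembled = assembler.transform(trips);

        // k is an assumption; tune for the actual data
        KMeans kmeans = new KMeans().setK(5).setSeed(1L);
        KMeansModel model = kmeans.fit(assembled);

        // attach a cluster id ("prediction") to every row
        model.transform(assembled).select("features", "prediction").show(30);

        spark.stop();
    }
}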

8 changes: 0 additions & 8 deletions spark_/train_spark_RF.py
@@ -24,8 +24,6 @@
#
######################################################################################################################################################



# load basics library
import csv
import os
@@ -53,8 +51,6 @@
print ("==================")
# ---------------------------------



### ================================================ ###
#
# feature engineering
@@ -63,8 +59,6 @@
#
### ================================================ ###



if __name__ == '__main__':
# load data with spark way
trainNYC = sc.textFile('train_data_java.csv').map(lambda line: line.split(","))
@@ -130,5 +124,3 @@
rfModel = model.stages[1]
print(' *** : RF MODEL SUMMARY : ', rfModel) # summary only
print ('='*100)
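
Aside: the tail of this script digs the fitted forest back out of the pipeline with model.stages[1]. A hedged Java sketch of equivalent wiring, assuming a two-stage pipeline (assembler, then forest) like the Python script's; feature and label columns follow the taxi schema above.

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class RfPipelineSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("RfPipelineSketch").master("local[*]").getOrCreate();

        Dataset<Row> trips = spark.read()
                .option("header", "true")
                .option("inferSchema", "true")
                .csv("train_data_java.csv");

        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(new String[] {"passenger_count", "pickup_longitude", "pickup_latitude"})
                .setOutputCol("features");

        RandomForestRegressor rf = new RandomForestRegressor()
                .setLabelCol("trip_duration")
                .setFeaturesCol("features");

        // stage 0 = assembler, stage 1 = forest — hence model.stages[1] above
        Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[] {assembler, rf});
        PipelineModel model = pipeline.fit(trips);

        RandomForestRegressionModel rfModel =
                (RandomForestRegressionModel) model.stages()[1];
        System.out.println(" *** : RF MODEL SUMMARY : " + rfModel.toDebugString());

        spark.stop();
    }
}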


4 changes: 0 additions & 4 deletions spark_/train_spark_RF_2.java
@@ -1,4 +1,3 @@

import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
@@ -8,10 +7,8 @@
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.when;

import java.util.ArrayList;
import java.util.List;
import java.util.Arrays;
@@ -30,7 +27,6 @@
import org.apache.spark.sql.AnalysisException;
import static org.apache.spark.sql.functions.col;


/*
Run the RF model again with pre-process data via
27 changes: 12 additions & 15 deletions spark_/train_spark_RF_3.java
@@ -1,5 +1,4 @@
// package nl.craftsmen.spark.iris;

import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
@@ -10,10 +9,8 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.when;

import java.util.Arrays;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
@@ -23,8 +20,6 @@
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;



/****
# credit
@@ -46,7 +41,6 @@
**/


public class train_spark_RF_3 {

private static final String PATH = "train_data_java.csv";
@@ -94,14 +88,20 @@ public static void main(String[] args) {
// --------------------------- TUNE RF MODEL ------------------------------

// hash make data fit pipeline form
HashingTF hashingTF = new HashingTF()
/* HashingTF hashingTF = new HashingTF()
.setNumFeatures(1000)
.setInputCol(trainingSet)
.setInputCol(trainingSet.getOutputCol())
.setOutputCol("features");

*/

// set up pipeline
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {trainingSet,hashingTF ,rf});
//Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {trainingSet,rf});
RegressionEvaluator evaluator = new RegressionEvaluator()
.setLabelCol("trip_duration")
.setPredictionCol("prediction")
.setMetricName("rmse");



// grid search
ParamMap[] paramGrid = new ParamGridBuilder()
@@ -110,7 +110,7 @@ public static void main(String[] args) {


CrossValidator cv = new CrossValidator()
.setEstimator(pipeline)
.setEstimator(evaluator)
.setEvaluator(new RegressionEvaluator())
.setEstimatorParamMaps(paramGrid).setNumFolds(2); // Use 3+ in practice

@@ -127,10 +127,7 @@ public static void main(String[] args) {

//RandomForestRegressionModel rfModel = rf.fit(trainingSet);
//Dataset<Row> predictions = rfModel.transform(testSet);
RegressionEvaluator evaluator = new RegressionEvaluator()
.setLabelCol("trip_duration")
.setPredictionCol("prediction")
.setMetricName("rmse");

double rmse = evaluator.evaluate(predictions);
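
Aside: one caution on the TUNE RF MODEL hunk. After this commit, setEstimator receives the evaluator, but CrossValidator.setEstimator expects an Estimator — typically the Pipeline — while the RegressionEvaluator belongs in setEvaluator (a Dataset such as trainingSet likewise cannot be a pipeline stage). A hedged, self-contained sketch of the conventional wiring; the split, grid values, and feature columns are illustrative.

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.regression.RandomForestRegressor;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class CvSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("CvSketch").master("local[*]").getOrCreate();

        Dataset<Row> data = spark.read()
                .option("header", "true").option("inferSchema", "true")
                .csv("train_data_java.csv");
        Dataset<Row>[] splits = data.randomSplit(new double[] {0.8, 0.2});
        Dataset<Row> trainingSet = splits[0];
        Dataset<Row> testSet = splits[1];

        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(new String[] {"passenger_count", "pickup_longitude", "pickup_latitude"})
                .setOutputCol("features");
        RandomForestRegressor rf = new RandomForestRegressor()
                .setLabelCol("trip_duration").setFeaturesCol("features");
        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {assembler, rf});

        RegressionEvaluator evaluator = new RegressionEvaluator()
                .setLabelCol("trip_duration")
                .setPredictionCol("prediction")
                .setMetricName("rmse");

        // grid over forest size; extend with more params as needed
        ParamMap[] paramGrid = new ParamGridBuilder()
                .addGrid(rf.numTrees(), new int[] {10, 20})
                .build();

        CrossValidator cv = new CrossValidator()
                .setEstimator(pipeline)    // the Pipeline, not the evaluator, is the Estimator
                .setEvaluator(evaluator)   // the evaluator scores each fold
                .setEstimatorParamMaps(paramGrid)
                .setNumFolds(2);           // use 3+ in practice

        CrossValidatorModel cvModel = cv.fit(trainingSet);
        Dataset<Row> predictions = cvModel.transform(testSet);
        System.out.println("RMSE = " + evaluator.evaluate(predictions));

        spark.stop();
    }
}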


