Removing code duplications (#1093)
davidrabinowitz authored Oct 23, 2023
1 parent 46b4c47 commit cbcd5ae
Showing 12 changed files with 249 additions and 301 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/cpd.yaml
@@ -0,0 +1,24 @@
name: Duplicate Code Detection

on:
  push:
    branches: [ master ]
  pull_request:
    # The branches below must be a subset of the branches above
    branches: [ master ]

jobs:
  build:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3
    - uses: actions/setup-java@v3
      with:
        distribution: 'temurin'
        java-version: '8'
        cache: 'maven'

    - name: Running Duplicate Code Detection
      run: ./mvnw pmd:cpd-check -Pall -Daggregate=true
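The workflow gates every push and pull request to master on PMD's copy-paste detector (the `pmd:cpd-check` goal); `-Pall` presumably activates a profile covering all connector modules, and `-Daggregate=true` runs one aggregated check across the whole build. CPD matches repeated token sequences rather than raw text, so even reformatted copies are flagged. Below is a toy illustration of the kind of finding it reports and the fix this commit applies throughout; every name in it is invented for illustration and is not connector code:

```java
import java.util.Locale;

// Toy example only: CPD would flag two verbatim copies of the normalization logic as a
// duplicate block; keeping a single shared method is what makes pmd:cpd-check pass.
public class CpdExample {

  // The single remaining copy of the previously duplicated logic.
  private static String normalize(String raw) {
    return raw.trim().toLowerCase(Locale.ROOT).replace('-', '_');
  }

  // Both call sites delegate instead of repeating the same statements.
  static String normalizeTableName(String name) {
    return normalize(name);
  }

  static String normalizeColumnName(String name) {
    return normalize(name);
  }

  public static void main(String[] args) {
    System.out.println(normalizeTableName("  My-Table "));  // prints my_table
    System.out.println(normalizeColumnName(" Col-Name "));  // prints col_name
  }
}
```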
SparkBigQueryUtil.java
@@ -37,17 +37,22 @@
import java.util.Properties;
import java.util.ServiceLoader;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.validation.constraints.NotNull;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.types.Metadata;
import org.jetbrains.annotations.NotNull;
import org.apache.spark.sql.types.StructType;
import scala.collection.mutable.ListBuffer;

/** Spark related utilities */
public class SparkBigQueryUtil {
@@ -281,4 +286,20 @@ public static ImmutableMap<String, String> extractJobLabels(SparkConf sparkConf)
.ifPresent(tag -> labels.put("dataproc_job_uuid", tag.substring(tag.lastIndexOf('_') + 1)));
return labels.build();
}

public static List<AttributeReference> schemaToAttributeReferences(StructType schema) {
List<AttributeReference> result =
Stream.of(schema.fields())
.map(
field ->
new AttributeReference(
field.name(),
field.dataType(),
field.nullable(),
field.metadata(),
NamedExpression.newExprId(),
new ListBuffer<String>().toStream()))
.collect(Collectors.toList());
return result;
}
}
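The new `schemaToAttributeReferences` helper gives callers one place to turn a Spark `StructType` into the `AttributeReference` list that Catalyst expects: one reference per field, preserving name, type, nullability and metadata, with a fresh expression ID and an empty qualifier. A minimal usage sketch follows; the package of `SparkBigQueryUtil` is assumed to be `com.google.cloud.spark.bigquery`, matching the other connector classes imported elsewhere in this diff:

```java
import com.google.cloud.spark.bigquery.SparkBigQueryUtil; // package assumed
import java.util.List;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class SchemaToAttributesExample {
  public static void main(String[] args) {
    // A small Spark schema; the helper emits one AttributeReference per field.
    StructType schema =
        new StructType()
            .add("id", DataTypes.LongType, false)
            .add("name", DataTypes.StringType, true);

    List<AttributeReference> attrs = SparkBigQueryUtil.schemaToAttributeReferences(schema);
    attrs.forEach(attr -> System.out.println(attr.name() + ": " + attr.dataType()));
  }
}
```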
BigQueryRDDContext.java (new file)
@@ -0,0 +1,135 @@
/*
* Copyright 2022 Google Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.spark.bigquery.direct;

import com.google.cloud.bigquery.Schema;
import com.google.cloud.bigquery.connector.common.BigQueryClientFactory;
import com.google.cloud.bigquery.connector.common.BigQueryStorageReadRowsTracer;
import com.google.cloud.bigquery.connector.common.BigQueryTracerFactory;
import com.google.cloud.bigquery.connector.common.BigQueryUtil;
import com.google.cloud.bigquery.connector.common.ReadRowsHelper;
import com.google.cloud.bigquery.storage.v1.DataFormat;
import com.google.cloud.bigquery.storage.v1.ReadRowsRequest;
import com.google.cloud.bigquery.storage.v1.ReadRowsResponse;
import com.google.cloud.bigquery.storage.v1.ReadSession;
import com.google.cloud.spark.bigquery.InternalRowIterator;
import com.google.cloud.spark.bigquery.ReadRowsResponseToInternalRowIteratorConverter;
import com.google.cloud.spark.bigquery.SchemaConverters;
import com.google.cloud.spark.bigquery.SchemaConvertersConfiguration;
import com.google.cloud.spark.bigquery.SparkBigQueryConfig;
import com.google.cloud.spark.bigquery.metrics.SparkMetricsSource;
import com.google.common.base.Joiner;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import org.apache.spark.InterruptibleIterator;
import org.apache.spark.Partition;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.StructType;

class BigQueryRDDContext implements Serializable {

private static final long serialVersionUID = -2219993393692435055L;

private final Partition[] partitions;
private final ReadSession readSession;
private final String[] columnsInOrder;
private final Schema bqSchema;
private final SparkBigQueryConfig options;
private final BigQueryClientFactory bigQueryClientFactory;
private final BigQueryTracerFactory bigQueryTracerFactory;

private List<String> streamNames;

public BigQueryRDDContext(
Partition[] parts,
ReadSession readSession,
Schema bqSchema,
String[] columnsInOrder,
SparkBigQueryConfig options,
BigQueryClientFactory bigQueryClientFactory,
BigQueryTracerFactory bigQueryTracerFactory) {

this.partitions = parts;
this.readSession = readSession;
this.columnsInOrder = columnsInOrder;
this.bigQueryClientFactory = bigQueryClientFactory;
this.bigQueryTracerFactory = bigQueryTracerFactory;
this.options = options;
this.bqSchema = bqSchema;
this.streamNames = BigQueryUtil.getStreamNames(readSession);
}

public scala.collection.Iterator<InternalRow> compute(Partition split, TaskContext context) {
BigQueryPartition bigQueryPartition = (BigQueryPartition) split;
SparkMetricsSource sparkMetricsSource = new SparkMetricsSource();
SparkEnv.get().metricsSystem().registerSource(sparkMetricsSource);
BigQueryStorageReadRowsTracer tracer =
bigQueryTracerFactory.newReadRowsTracer(
Joiner.on(",").join(streamNames), sparkMetricsSource);

ReadRowsRequest.Builder request =
ReadRowsRequest.newBuilder().setReadStream(bigQueryPartition.getStream());

ReadRowsHelper readRowsHelper =
new ReadRowsHelper(
bigQueryClientFactory,
request,
options.toReadSessionCreatorConfig().toReadRowsHelperOptions(),
Optional.of(tracer));
Iterator<ReadRowsResponse> readRowsResponseIterator = readRowsHelper.readRows();

StructType schema =
options
.getSchema()
.orElse(
SchemaConverters.from(SchemaConvertersConfiguration.from(options))
.toSpark(bqSchema));

ReadRowsResponseToInternalRowIteratorConverter converter;
if (options.getReadDataFormat().equals(DataFormat.AVRO)) {
converter =
ReadRowsResponseToInternalRowIteratorConverter.avro(
bqSchema,
Arrays.asList(columnsInOrder),
readSession.getAvroSchema().getSchema(),
Optional.of(schema),
Optional.of(tracer),
SchemaConvertersConfiguration.from(options));
} else {
converter =
ReadRowsResponseToInternalRowIteratorConverter.arrow(
Arrays.asList(columnsInOrder),
readSession.getArrowSchema().getSerializedSchema(),
Optional.of(schema),
Optional.of(tracer));
}

return new InterruptibleIterator<InternalRow>(
context,
new ScalaIterator<InternalRow>(
new InternalRowIterator(readRowsResponseIterator, converter, readRowsHelper, tracer)));
}

public Partition[] getPartitions() {
return partitions;
}
}
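`BigQueryRDDContext` now owns the partitions, read session, schema, options and client/tracer factories and implements the whole `compute` flow once; it implements `Serializable` because it is shipped to executors as a field of the RDD. The Scala-version-specific RDD classes shrink to thin wrappers that delegate to it, which is exactly the duplication CPD had been flagging; the Scala 2.13 wrapper appears later in this diff, and the factory hunk just below builds one context and hands it to whichever wrapper is loaded reflectively. As a hedged sketch, a pre-2.13 wrapper could look like this — the class name and super-constructor details are assumptions, since that file is not shown here:

```java
// Hypothetical sketch of a Scala 2.12-style wrapper; only the 2.13 variant is visible in this diff.
package com.google.cloud.spark.bigquery.direct;

import org.apache.spark.Dependency;
import org.apache.spark.Partition;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.catalyst.InternalRow;
import scala.collection.mutable.ListBuffer;

class Scala212BigQueryRDDSketch extends RDD<InternalRow> {

  // All BigQuery-specific work lives in the shared, serializable context.
  private final BigQueryRDDContext ctx;

  public Scala212BigQueryRDDSketch(SparkContext sparkContext, BigQueryRDDContext ctx) {
    // Against Scala 2.12, RDD's constructor accepts scala.collection.Seq, so an empty
    // ListBuffer works here; the 2.13 wrapper shown later needs scala.collection.immutable.Seq.
    super(
        sparkContext,
        new ListBuffer<Dependency<?>>(),
        scala.reflect.ClassTag$.MODULE$.apply(InternalRow.class));
    this.ctx = ctx;
  }

  @Override
  public scala.collection.Iterator<InternalRow> compute(Partition split, TaskContext context) {
    return ctx.compute(split, context); // delegate so only one copy of the logic exists
  }

  @Override
  public Partition[] getPartitions() {
    return ctx.getPartitions();
  }
}
```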
@@ -202,26 +202,19 @@ RDD<InternalRow> createRDD(
Class<? extends RDD<InternalRow>> clazz =
(Class<? extends RDD<InternalRow>>) Class.forName(bigQueryRDDClassName);
Constructor<? extends RDD<InternalRow>> constructor =
clazz.getConstructor(
SparkContext.class,
Partition[].class,
ReadSession.class,
Schema.class,
String[].class,
SparkBigQueryConfig.class,
BigQueryClientFactory.class,
BigQueryTracerFactory.class);
clazz.getConstructor(SparkContext.class, BigQueryRDDContext.class);

RDD<InternalRow> bigQueryRDD =
constructor.newInstance(
sqlContext.sparkContext(),
partitions,
readSession,
bqSchema,
columnsInOrder,
options,
bigQueryClientFactory,
bigQueryTracerFactory);
new BigQueryRDDContext(
partitions,
readSession,
bqSchema,
columnsInOrder,
options,
bigQueryClientFactory,
bigQueryTracerFactory));

return bigQueryRDD;
} catch (Exception e) {
Scala213BigQueryRDD.java
@@ -16,36 +16,12 @@

package com.google.cloud.spark.bigquery.direct;

import com.google.cloud.bigquery.Schema;
import com.google.cloud.bigquery.connector.common.BigQueryClientFactory;
import com.google.cloud.bigquery.connector.common.BigQueryStorageReadRowsTracer;
import com.google.cloud.bigquery.connector.common.BigQueryTracerFactory;
import com.google.cloud.bigquery.connector.common.BigQueryUtil;
import com.google.cloud.bigquery.connector.common.ReadRowsHelper;
import com.google.cloud.bigquery.storage.v1.DataFormat;
import com.google.cloud.bigquery.storage.v1.ReadRowsRequest;
import com.google.cloud.bigquery.storage.v1.ReadRowsResponse;
import com.google.cloud.bigquery.storage.v1.ReadSession;
import com.google.cloud.spark.bigquery.InternalRowIterator;
import com.google.cloud.spark.bigquery.ReadRowsResponseToInternalRowIteratorConverter;
import com.google.cloud.spark.bigquery.SchemaConverters;
import com.google.cloud.spark.bigquery.SchemaConvertersConfiguration;
import com.google.cloud.spark.bigquery.SparkBigQueryConfig;
import com.google.cloud.spark.bigquery.metrics.SparkMetricsSource;
import com.google.common.base.Joiner;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import org.apache.spark.Dependency;
import org.apache.spark.InterruptibleIterator;
import org.apache.spark.Partition;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.StructType;
import scala.collection.immutable.Seq;
import scala.collection.immutable.Seq$;

@@ -54,94 +30,25 @@
// scala.collection.immutable.Seq.
class Scala213BigQueryRDD extends RDD<InternalRow> {

private final Partition[] partitions;
private final ReadSession readSession;
private final String[] columnsInOrder;
private final Schema bqSchema;
private final SparkBigQueryConfig options;
private final BigQueryClientFactory bigQueryClientFactory;
private final BigQueryTracerFactory bigQueryTracerFactory;
// Added suffix so that CPD wouldn't mark as duplicate
private final BigQueryRDDContext ctx213;

private List<String> streamNames;

public Scala213BigQueryRDD(
SparkContext sparkContext,
Partition[] parts,
ReadSession readSession,
Schema bqSchema,
String[] columnsInOrder,
SparkBigQueryConfig options,
BigQueryClientFactory bigQueryClientFactory,
BigQueryTracerFactory bigQueryTracerFactory) {
public Scala213BigQueryRDD(SparkContext sparkContext, BigQueryRDDContext ctx) {
super(
sparkContext,
(Seq<Dependency<?>>) Seq$.MODULE$.<Dependency<?>>newBuilder().result(),
scala.reflect.ClassTag$.MODULE$.apply(InternalRow.class));

this.partitions = parts;
this.readSession = readSession;
this.columnsInOrder = columnsInOrder;
this.bigQueryClientFactory = bigQueryClientFactory;
this.bigQueryTracerFactory = bigQueryTracerFactory;
this.options = options;
this.bqSchema = bqSchema;
this.streamNames = BigQueryUtil.getStreamNames(readSession);
this.ctx213 = ctx;
}

@Override
public scala.collection.Iterator<InternalRow> compute(Partition split, TaskContext context) {
BigQueryPartition bigQueryPartition = (BigQueryPartition) split;
SparkMetricsSource sparkMetricsSource = new SparkMetricsSource();
SparkEnv.get().metricsSystem().registerSource(sparkMetricsSource);
BigQueryStorageReadRowsTracer tracer =
bigQueryTracerFactory.newReadRowsTracer(
Joiner.on(",").join(streamNames), sparkMetricsSource);

ReadRowsRequest.Builder request =
ReadRowsRequest.newBuilder().setReadStream(bigQueryPartition.getStream());

ReadRowsHelper readRowsHelper =
new ReadRowsHelper(
bigQueryClientFactory,
request,
options.toReadSessionCreatorConfig().toReadRowsHelperOptions(),
Optional.of(tracer));
Iterator<ReadRowsResponse> readRowsResponseIterator = readRowsHelper.readRows();

StructType schema =
options
.getSchema()
.orElse(
SchemaConverters.from(SchemaConvertersConfiguration.from(options))
.toSpark(bqSchema));

ReadRowsResponseToInternalRowIteratorConverter converter;
if (options.getReadDataFormat().equals(DataFormat.AVRO)) {
converter =
ReadRowsResponseToInternalRowIteratorConverter.avro(
bqSchema,
Arrays.asList(columnsInOrder),
readSession.getAvroSchema().getSchema(),
Optional.of(schema),
Optional.of(tracer),
SchemaConvertersConfiguration.from(options));
} else {
converter =
ReadRowsResponseToInternalRowIteratorConverter.arrow(
Arrays.asList(columnsInOrder),
readSession.getArrowSchema().getSerializedSchema(),
Optional.of(schema),
Optional.of(tracer));
}

return new InterruptibleIterator<InternalRow>(
context,
new ScalaIterator<InternalRow>(
new InternalRowIterator(readRowsResponseIterator, converter, readRowsHelper, tracer)));
return ctx213.compute(split, context);
}

@Override
public Partition[] getPartitions() {
return partitions;
return ctx213.getPartitions();
}
}