diff --git a/livy/session.py b/livy/session.py
index 85d7198..40adee8 100644
--- a/livy/session.py
+++ b/livy/session.py
@@ -43,11 +43,37 @@ def _spark_serialise_dataframe_code(
     )
     return template.format(spark_dataframe_name)
 
 
-def _deserialise_dataframe(text: str) -> pandas.DataFrame:
+def _handle_duplication_in_column(ordered_pairs, prefix):
+    """Rename duplicated keys by prepending ``prefix``.
+
+    Used as an ``object_pairs_hook`` for ``json.loads`` so that duplicated
+    Spark column names survive deserialisation instead of collapsing.
+    """
+    d = {}
+    for k, v in ordered_pairs:
+        if k in d:
+            d["{0}{1}".format(prefix, k)] = v
+        else:
+            d[k] = v
+    return d
+
+
+def _deserialise_dataframe(
+    text: str, manage_duplication: bool = False, duplication_prefix: str = "*_"
+) -> pandas.DataFrame:
     rows = []
     for line in text.split("\n"):
-        if line:
+        if line and manage_duplication:
+            rows.append(
+                json.loads(
+                    line,
+                    object_pairs_hook=lambda pairs: _handle_duplication_in_column(
+                        pairs, duplication_prefix
+                    ),
+                )
+            )
+        elif line:
             rows.append(json.loads(line))
     return pandas.DataFrame.from_records(rows)
 
@@ -292,22 +318,26 @@ def run(self, code: str) -> Output:
         output.raise_for_status()
         return output
 
-    def download(self, dataframe_name: str) -> pandas.DataFrame:
+    def download(self, dataframe_name: str, manage_duplication: bool = False, duplication_prefix: str = "*_") -> pandas.DataFrame:
         """Evaluate and download a Spark dataframe from the managed session.
 
         :param dataframe_name: The name of the Spark dataframe to download.
+        :param manage_duplication: Preserve duplicated columns from the Spark dataframe.
+        :param duplication_prefix: String prefixed to duplicated column names. Defaults to ``'*_'``.
         """
         code = _spark_serialise_dataframe_code(dataframe_name, self.kind)
         output = self._execute(code)
         output.raise_for_status()
         if output.text is None:
             raise RuntimeError("statement had no text output")
-        return _deserialise_dataframe(output.text)
+        return _deserialise_dataframe(output.text, manage_duplication, duplication_prefix)
 
-    def read(self, dataframe_name: str) -> pandas.DataFrame:
+    def read(self, dataframe_name: str, manage_duplication: bool = False, duplication_prefix: str = "*_") -> pandas.DataFrame:
         """Evaluate and retrieve a Spark dataframe in the managed session.
 
         :param dataframe_name: The name of the Spark dataframe to read.
+        :param manage_duplication: Preserve duplicated columns from the Spark dataframe.
+        :param duplication_prefix: String prefixed to duplicated column names. Defaults to ``'*_'``.
 
         .. deprecated:: 0.8.0
             Use :meth:`download` instead.
@@ -317,7 +347,7 @@ def read(self, dataframe_name: str) -> pandas.DataFrame:
             "version. Use LivySession.download instead.",
             DeprecationWarning,
         )
-        return self.download(dataframe_name)
+        return self.download(dataframe_name, manage_duplication, duplication_prefix)
 
     def download_sql(self, query: str) -> pandas.DataFrame:
         """Evaluate a Spark SQL query and download the result.
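The duplicate-column handling above hinges on the `object_pairs_hook` parameter of `json.loads`: the hook receives the decoded key/value pairs in document order, before any dict is built, so a repeated key can be renamed instead of silently overwritten (a plain dict keeps only the last value). A minimal standalone sketch of the mechanism, with illustrative names that are not part of the patch:

    import json

    def rename_duplicates(pairs, prefix="*_"):
        # Receives [(key, value), ...] in document order from json.loads.
        out = {}
        for key, value in pairs:
            # Prefix the key only if we have already seen it.
            out[prefix + key if key in out else key] = value
        return out

    row = json.loads(
        '{"value": 1, "dup_value": 1, "dup_value": 2}',
        object_pairs_hook=rename_duplicates,
    )
    assert row == {"value": 1, "dup_value": 1, "*_dup_value": 2}

Note that the patch shares this sketch's limitation: a key occurring three or more times still collides, because the prefixed name produced for the third occurrence duplicates the one produced for the second.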
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 56fa78e..ccfd32b 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -23,9 +23,13 @@ class Parameters:
     error_code: str
     dataframe_multiply_code: str
     dataframe_trim_code: str
+    dataframe_dup_code: str
 
 
 RANGE_DATAFRAME = pandas.DataFrame({"value": range(100)})
+DUP_DATAFRAME = pandas.DataFrame({"value": range(100), "dup_value": range(100), "*_dup_value": range(100)})
+DUP_DATAFRAME_PREFIX = pandas.DataFrame({"value": range(100), "dup_value": range(100), "test_dup_value": range(100)})
+
 SPECIAL_CHARACTER_EXAMPLES = [
     # Single and double quotes can terminate string literals in Scala/Python/R
     "'",
@@ -60,6 +64,9 @@ class Parameters:
 SPARK_TRIM_DATAFRAME = """
 val trimmed = text.select(trim($"text") alias "text")
 """
+SPARK_DUP_DATAFRAME = """
+val df_dup = df.withColumn("dup_value", col("value")).join(df.withColumn("dup_value", col("value")), "value").sort(col("value"))
+"""
 SPARK_TEST_PARAMETERS = Parameters(
     print_foo_code='println("foo")',
     print_foo_output="foo\n\n",
@@ -69,6 +76,7 @@ class Parameters:
     error_code="1 / 0",
     dataframe_multiply_code=SPARK_MULTIPLY_DATAFRAME,
     dataframe_trim_code=SPARK_TRIM_DATAFRAME,
+    dataframe_dup_code=SPARK_DUP_DATAFRAME,
 )
 
 PYSPARK_CREATE_RANGE_DATAFRAME = """
@@ -82,6 +90,10 @@ class Parameters:
 from pyspark.sql.functions import trim
 trimmed = text.select(trim(text.text).alias("text"))
 """
+PYSPARK_DUP_DATAFRAME = """
+df_dup = df.withColumn('dup_value', df.value).join(df.withColumn('dup_value', df.value), 'value').sort('value')
+"""
+
 PYSPARK_TEST_PARAMETERS = Parameters(
     print_foo_code='print("foo")',
     print_foo_output="foo\n",
@@ -91,6 +103,7 @@ class Parameters:
     error_code="1 / 0",
     dataframe_multiply_code=PYSPARK_MULTIPLY_DATAFRAME,
     dataframe_trim_code=PYSPARK_TRIM_DATAFRAME,
+    dataframe_dup_code=PYSPARK_DUP_DATAFRAME,
 )
 
 SPARKR_CREATE_RANGE_DATAFRAME = """
@@ -102,6 +115,9 @@ class Parameters:
 SPARKR_TRIM_DATAFRAME = """
 trimmed <- select(text, alias(trim(text$text), "text"))
 """
+SPARKR_DUP_DATAFRAME = """
+df_dup <- withColumnRenamed(withColumn(withColumn(df, "dup_value", df$value), "dup_value*", df$value), "dup_value*", "dup_value")
+"""
 SPARKR_TEST_PARAMETERS = Parameters(
     print_foo_code='print("foo")',
     print_foo_output='[1] "foo"\n',
@@ -111,6 +127,7 @@ class Parameters:
     error_code="missing_function()",
     dataframe_multiply_code=SPARKR_MULTIPLY_DATAFRAME,
     dataframe_trim_code=SPARKR_TRIM_DATAFRAME,
+    dataframe_dup_code=SPARKR_DUP_DATAFRAME,
 )
 
 SQL_CREATE_VIEW = """
@@ -159,6 +176,10 @@ def test_session(integration_url, capsys, session_kind, params):
         TEXT_DATAFRAME.applymap(lambda s: s.strip())
     )
 
+    session.run(params.dataframe_dup_code)
+    assert session.download("df_dup", True).equals(DUP_DATAFRAME)
+    assert session.download("df_dup", True, "test_").equals(DUP_DATAFRAME_PREFIX)
+
     assert _session_stopped(integration_url, session.session_id)
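Taken together, the patched `download` can be exercised against a running Livy server roughly as follows. The server URL and the uploaded dataframe code are illustrative; the sketch assumes pylivy's `LivySession.create` entry point and a pyspark session, where Livy exposes a preconfigured `spark` handle:

    from livy import LivySession

    with LivySession.create("http://localhost:8998") as session:
        session.run("df = spark.range(100).withColumnRenamed('id', 'value')")
        session.run(
            "df_dup = df.withColumn('dup_value', df.value)"
            ".join(df.withColumn('dup_value', df.value), 'value')"
        )
        # Default prefix: the second 'dup_value' arrives as '*_dup_value'.
        local = session.download("df_dup", manage_duplication=True)
        # Custom prefix: the second 'dup_value' arrives as 'test_dup_value'.
        local = session.download(
            "df_dup", manage_duplication=True, duplication_prefix="test_"
        )

Without `manage_duplication=True` the behaviour is unchanged: `json.loads` builds a plain dict per row, so one of the duplicated columns is silently dropped.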