diff --git a/notebooks/ai-ml/SparkXGBoostCustomerChurn.ipynb b/notebooks/ai-ml/SparkXGBoostCustomerChurn.ipynb index b3bb823..0cac8ec 100644 --- a/notebooks/ai-ml/SparkXGBoostCustomerChurn.ipynb +++ b/notebooks/ai-ml/SparkXGBoostCustomerChurn.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "c674ba30-373e-4320-b7b4-d6435988e3f1", "metadata": { "tags": [] @@ -131,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "4cdbeb05-806e-438a-9010-3e9bea1b5cca", "metadata": { "tags": [] @@ -171,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "f579cf36-351b-41ae-8f6f-97b75bde26c4", "metadata": { "tags": [] @@ -198,31 +198,12 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "ff8fcf6e-6f07-4e8b-ae87-7fef5f132409", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+--------------+------------+\n", - "|MonthlyCharges|TotalCharges|\n", - "+--------------+------------+\n", - "| 0| 0|\n", - "+--------------+------------+\n", - "\n", - "+--------------+------------+\n", - "|MonthlyCharges|TotalCharges|\n", - "+--------------+------------+\n", - "| 0| 11|\n", - "+--------------+------------+\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "df.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in [\"MonthlyCharges\", \"TotalCharges\"]]).show()" ] @@ -237,25 +218,12 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "4ccd3420-e805-4ee4-8edd-e040e20228ef", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+--------------+------------+\n", - "|MonthlyCharges|TotalCharges|\n", - "+--------------+------------+\n", - "| 0| 11|\n", - "+--------------+------------+\n", - "\n" - ] - } - ], + "outputs": [], "source": [ 
"df.select([count(when(col(c).cast(\"float\").isNull(), c)).alias(c) for c in [\"MonthlyCharges\", \"TotalCharges\"]]).show()" ] @@ -270,35 +238,12 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "55346696-18fb-4399-b7f7-f5fa3f3f2003", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+\n", - "|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService| MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies| Contract|PaperlessBilling| PaymentMethod|MonthlyCharges|TotalCharges|Churn|\n", - "+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+\n", - "|7590-VHVEG|Female| 0| Yes| No| 1| No|No phone service| DSL| No| Yes| No| No| No| No|Month-to-month| Yes| Electronic check| 29.85| 29.85| No|\n", - "|5575-GNVDE| Male| 0| No| No| 34| Yes| No| DSL| Yes| No| Yes| No| No| No| One year| No| Mailed check| 56.95| 1889.5| No|\n", - "|3668-QPYBK| Male| 0| No| No| 2| Yes| No| DSL| Yes| Yes| No| No| No| No|Month-to-month| Yes| Mailed check| 53.85| 108.15| Yes|\n", - "|7795-CFOCW| Male| 0| No| No| 45| No|No phone service| DSL| Yes| No| Yes| Yes| No| No| One year| No|Bank transfer (au...| 42.3| 1840.75| No|\n", - "|9237-HQITU|Female| 0| No| No| 2| Yes| No| Fiber optic| No| No| No| No| No| No|Month-to-month| Yes| Electronic check| 70.7| 151.65| Yes|\n", - "|9305-CDSKC|Female| 0| No| No| 8| Yes| Yes| Fiber optic| No| No| Yes| 
No| Yes| Yes|Month-to-month| Yes| Electronic check| 99.65| 820.5| Yes|\n", - "|1452-KIOVK| Male| 0| No| Yes| 22| Yes| Yes| Fiber optic| No| Yes| No| No| Yes| No|Month-to-month| Yes|Credit card (auto...| 89.1| 1949.4| No|\n", - "|6713-OKOMC|Female| 0| No| No| 10| No|No phone service| DSL| Yes| No| No| No| No| No|Month-to-month| No| Mailed check| 29.75| 301.9| No|\n", - "|7892-POOKP|Female| 0| Yes| No| 28| Yes| Yes| Fiber optic| No| No| Yes| Yes| Yes| Yes|Month-to-month| Yes| Electronic check| 104.8| 3046.05| Yes|\n", - "|6388-TABGU| Male| 0| No| Yes| 62| Yes| No| DSL| Yes| Yes| No| No| No| No| One year| No|Bank transfer (au...| 56.15| 3487.95| No|\n", - "+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+\n", - "only showing top 10 rows\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "# Drop records where TotalCharges has null or string values that can't be cast to float\n", "df = df.filter(col(\"TotalCharges\").cast(\"float\").isNotNull())\n", @@ -323,40 +268,12 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "b05a2d91-77ae-4bea-9b4d-a8f709219705", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "# Aggregate data\n", "df_plot_data = df.groupBy(\"Contract\", \"Churn\").count()\n", @@ -383,33 +300,12 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "id": "628616a0-992f-4b6c-8668-6a3e6b74e179", "metadata": { "tags": [] }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "# Aggregate data\n", "histogram_data = df.groupBy(\"Tenure\", \"Churn\").count().orderBy(\"Tenure\")\n", @@ -484,115 +380,12 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "8abf9370-938b-4368-a300-e2a9064a48b6", "metadata": { "tags": [] }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3cd67f4d13d94f6bb2d427f7a02e00a9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "config.json: 0%| | 0.00/629 [00:00 (0 + 1) / 1]\r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+----------+--------------------+---------+\n", - "|customerID| text|sentiment|\n", - "+----------+--------------------+---------+\n", - "|7590-VHVEG|The customer was ...| NEGATIVE|\n", - "|5575-GNVDE|The customer was ...| NEGATIVE|\n", - "|3668-QPYBK|The customer was ...| POSITIVE|\n", - "|7795-CFOCW|The customer was ...| NEGATIVE|\n", - "|9237-HQITU|A quick and frien...| POSITIVE|\n", - "|9305-CDSKC|A pleasant conver...| POSITIVE|\n", - "|1452-KIOVK|The customer was ...| NEGATIVE|\n", - "|6713-OKOMC|The customer was ...| NEGATIVE|\n", - "|7892-POOKP|A friendly exchan...| POSITIVE|\n", - "|6388-TABGU|The customer was ...| NEGATIVE|\n", - "|9763-GRSKD|The customer was ...| NEGATIVE|\n", - "|7469-LKBCI|The customer was ...| NEGATIVE|\n", - "|8091-TTVAX|The customer was ...| NEGATIVE|\n", - "|0280-XJGEX|A friendly chat a...| POSITIVE|\n", - "|5129-JLPIS|The customer was ...| NEGATIVE|\n", - "|3655-SNQYZ|The customer was ...| NEGATIVE|\n", - "|8191-XWSZG|The customer was ...| NEGATIVE|\n", - "|9959-WOFKT|The customer was ...| NEGATIVE|\n", - "|4190-MFLUW|A friendly exchan...| POSITIVE|\n", - "|4183-MYFRB|The customer was ...| NEGATIVE|\n", - "+----------+--------------------+---------+\n", - "only showing top 20 rows\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": 
"stream", - "text": [ - " \r" - ] - } - ], + "outputs": [], "source": [ "# Specify the model name and revision - downloads from Huggingface\n", "model_name = \"distilbert-base-uncased-finetuned-sst-2-english\"\n", @@ -640,7 +433,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "95e1b691-d216-43bd-9269-55c6a146c774", "metadata": { "tags": [] @@ -673,7 +466,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "235b509b-0a8c-40d8-8305-15a0930c4015", "metadata": {}, "outputs": [], @@ -692,18 +485,10 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "7a361ece-9bc9-4e66-8c09-f6c6c3373bb4", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - } - ], + "outputs": [], "source": [ "df_train = spark.read.parquet(training_data_path)" ] @@ -725,7 +510,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "id": "4d7c39eb-076d-44d6-874d-e1f727953330", "metadata": {}, "outputs": [], @@ -761,24 +546,12 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "8bedfb96-f79f-4eb1-a163-b1fc719089de", "metadata": { "tags": [] }, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'SparkXGBClassifier' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m xgb \u001b[38;5;241m=\u001b[39m \u001b[43mSparkXGBClassifier\u001b[49m(label_col\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mchurn\u001b[39m\u001b[38;5;124m\"\u001b[39m, features_col\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfeatures\u001b[39m\u001b[38;5;124m\"\u001b[39m, 
num_workers\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, missing\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.0\u001b[39m)\n\u001b[1;32m 3\u001b[0m pipeline \u001b[38;5;241m=\u001b[39m Pipeline(stages\u001b[38;5;241m=\u001b[39m[xgb])\n", - "\u001b[0;31mNameError\u001b[0m: name 'SparkXGBClassifier' is not defined" - ] - } - ], + "outputs": [], "source": [ "xgb = SparkXGBClassifier(label_col=\"churn\", features_col=\"features\", num_workers=2, missing=0.0)\n", "\n", @@ -797,24 +570,12 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "8d648b41-a6a8-4b67-9fce-ca5c4475c790", "metadata": { "tags": [] }, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'transformed_data' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m train_data, test_data \u001b[38;5;241m=\u001b[39m \u001b[43mtransformed_data\u001b[49m\u001b[38;5;241m.\u001b[39mrandomSplit([\u001b[38;5;241m0.8\u001b[39m, \u001b[38;5;241m0.2\u001b[39m], seed\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m42\u001b[39m) \n", - "\u001b[0;31mNameError\u001b[0m: name 'transformed_data' is not defined" - ] - } - ], + "outputs": [], "source": [ "train_data, test_data = transformed_data.randomSplit([0.8, 0.2], seed=42) " ] @@ -1004,8 +765,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "\n", "# Construct a BigQuery client object.\n", "client = bigquery.Client()\n", "\n", @@ -1024,7 +783,7 @@ }, { "cell_type": "markdown", - "id": "2ece491f", + "id": "b5ff34de-baed-4ee5-ae24-57b850235028", "metadata": {}, "source": [ "Save the predictions directly from Spark to BigQuery." 
@@ -1040,7 +799,7 @@ "# Define table name to save predictions\n", "table_name = 'predictions'\n", "\n", - "# Write predictions DataFrame to a Delta table\n", + "# Write predictions DataFrame to a BigQuery table\n", "predictions.write \\\n", " .format(\"bigquery\") \\\n", " .mode(\"overwrite\") \\\n",