
Commit

Address comments from review and also add text labels for each image in the end

Deependra Patel committed Dec 4, 2024
1 parent 82f754e commit 1bbe6d9
Showing 2 changed files with 138 additions and 69 deletions.
205 changes: 137 additions & 68 deletions ai-ml-samples/interactive/ImageClassificationInSpark.ipynb
@@ -23,6 +23,27 @@
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"source": [
"<table align=\"left\">\n",
"</td>\n",
"<td style=\"text-align: center\">\n",
"<a href=\"https://console.cloud.google.com/vertex-ai/workbench/instances/create?download_url=https://raw.githubusercontent.com/GoogleCloudDataproc/cloud-dataproc/ai-ml-samples/interactive/ImageClassificationInSpark.ipynb\">\n",
"<img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
"</a>\n",
"</td>\n",
"<td style=\"text-align: center\">\n",
"<a href=\"https://github.com/GoogleCloudDataproc/cloud-dataproc/ai-ml-samples/interactive/ImageClassificationInSpark.ipynb\">\n",
"<img src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
"</a>\n",
"</td>\n",
"</table>"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"metadata": {
@@ -31,16 +52,7 @@
"source": [
"## Overview\n",
"\n",
"With this notebook, we learn how to do distributed ML inference (image classification) using Dataproc Spark Serverless interactively.\n",
"\n",
"Following steps are performed:\n",
"1. Create a dataproc-enabled [Vertex workbench](https://cloud.google.com/vertex-ai/docs/workbench/instances/create-dataproc-enabled) instance\n",
"2. This notebook fetches pretrained model from torch-hub but serverless by default has private IPs hence can't access public internet [\"External network access\"](https://cloud.google.com/dataproc-serverless/docs/concepts/network#subnetwork_requirements). Setup [Public NAT](https://cloud.google.com/nat/docs/set-up-manage-network-address-translation#create-nat-gateway) for your \"network\".\n",
"3. Create a serverless runtime template with \"network\" (for which Cloud NAT is setup) and connect to remote kernel using instructions on same [page](https://cloud.google.com/vertex-ai/docs/workbench/instances/create-dataproc-enabled#serverless-spark)\n",
"4. Write code in the above notebook which runs on multiple Spark executors\n",
"5. We then create a Spark DataFrame of the urls of the images we want to classify. We download a pre-trained Resnet50 model from on driver, broadcast it to all the executors. Inference is written as a pandas UDF, that runs on each partition of the URLs.\n",
"\n",
"Note: You should first create a notebook mentioned in steps above, then import this entire notebook there."
"In this tutorial, you perform distributed ML inference via image classification using Apache Spark."
]
},
{
@@ -49,7 +61,11 @@
"id": "61RBz8LLbxCR"
},
"source": [
"## Get started"
"## Get started\n",
"\n",
"1. Create a dataproc-enabled [Vertex workbench](https://cloud.google.com/vertex-ai/docs/workbench/instances/create-dataproc-enabled) instance or use an existing instance\n",
"2. Setup [Public NAT](https://cloud.google.com/nat/docs/set-up-manage-network-address-translation#create-nat-gateway) for your \"network\". This notebook fetches pretrained model from torch-hub but serverless by default has private IPs hence can't access public internet [\"External network access\"](https://cloud.google.com/dataproc-serverless/docs/concepts/network#subnetwork_requirements).\n",
"3. Create a [serverless runtime template](https://cloud.google.com/dataproc-serverless/docs/quickstarts/jupyterlab-sessions#dataproc_create_serverless_runtime_template-JupyterLab) with \"network\" and connect to remote kernel using instructions on [page](https://cloud.google.com/vertex-ai/docs/workbench/instances/create-dataproc-enabled#serverless-spark)"
]
},
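As a quick optional sanity check for step 2 (a sketch, not part of the notebook as committed), you can run a tiny Spark job from the attached kernel to confirm that executors can reach the public internet; `sc` is the SparkContext created further down in this notebook, and `https://pypi.org` is just an arbitrary public URL:

```python
# Hypothetical sanity check: confirm executors have outbound internet access
# (needed for the torch-hub model download). Assumes `sc` is the SparkContext
# created later in this notebook and that Cloud NAT is already configured.
from urllib import request, error

def can_reach(url="https://pypi.org"):
    try:
        request.urlopen(url, timeout=10)
        return True
    except error.HTTPError:
        return True  # any HTTP response means the network path works
    except Exception:
        return False

# Run the check on a single executor partition; expect [True] if NAT is set up.
print(sc.parallelize([0], numSlices=1).map(lambda _: can_reach()).collect())
```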
{
@@ -58,7 +74,8 @@
"id": "EdvJRUWRNGHE"
},
"source": [
"## Notebook tutorial"
"## Notebook tutorial\n",
"We create a Spark DataFrame of the urls of the images (listing a public bucket) we want to classify. We download a pre-trained Resnet50 model from torch-hub on driver, broadcast it to all the executors. Inference is written as a pandas UDF, that runs on each partition of data."
]
},
{
@@ -71,15 +88,15 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6fc324893334"
},
"outputs": [],
"cell_type": "markdown",
"source": [
"# All libraries needed in this notebook like torch, torchvision, google-cloud-storage are already installed in serverless. If you need something extra, feel free to do `pip install <library>`"
]
"`torch`, `torchvision`, `google-cloud-storage`\n",
"\n",
"All libraries needed in this notebook are already installed in serverless. If you need something extra, feel free to do `pip install <library>`"
],
"metadata": {
"collapsed": false
}
},
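For example, a cell like the following (using `timm` purely as a placeholder package name) would install an extra library into the session:

```python
# Hypothetical example: install an extra package into the session.
# `timm` is only a placeholder; substitute whatever library you need.
%pip install timm
```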
{
"cell_type": "markdown",
@@ -129,35 +146,26 @@
"outputs": [],
"source": [
"from pyspark.sql import SparkSession\n",
"spark = SparkSession.builder.appName(\"classficationDemo\").getOrCreate()\n",
"spark = SparkSession.builder.appName(\"classificationDemo\").getOrCreate()\n",
"sc = spark.sparkContext\n",
"\n",
"# Set to True for GPU enabled serverless sessions/dataproc clusters\n",
"cuda = False\n",
"\n",
"# Enable Arrow support.\n",
"spark.conf.set(\"spark.sql.execution.arrow.enabled\", \"true\")\n",
"spark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"64\")\n",
"\n",
"import pandas as pd\n",
"\n",
"import torch\n",
"from torch.utils.data import Dataset\n",
"from torchvision import datasets, models, transforms\n",
"from torchvision.datasets.folder import default_loader # private API\n",
"\n",
"from pyspark.sql.functions import col, pandas_udf\n",
"from torchvision.models import ResNet50_Weights\n",
"from pyspark.sql.types import ArrayType, FloatType, StringType\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"use_cuda = cuda and torch.cuda.is_available()\n",
"device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
"\n",
"files_df = spark.createDataFrame(map(lambda file : get_blob_uri(file), blobs), StringType()).repartition(10)\n",
"files_df = spark.createDataFrame(map(lambda file : get_blob_uri(file), blobs), StringType()).toDF(*[\"inputFile\"]).repartition(10)\n",
"\n",
"# Downloads and broadcasts the model weights to all the workers\n",
"model_state = models.resnet50(pretrained=True).state_dict()\n",
"bc_model_state = sc.broadcast(model_state)"
"weights = ResNet50_Weights.DEFAULT\n",
"bc_model_weights = sc.broadcast(weights)"
],
"metadata": {
"collapsed": false
@@ -166,7 +174,7 @@
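The cells collapsed in this diff define `blobs` and `get_blob_uri`, which the `files_df` DataFrame above is built from. A minimal sketch of what such helpers might look like, assuming the images live under the public `gs://cloud-samples-data/generative-ai/image/` prefix referenced later in this notebook:

```python
# Hypothetical sketch of the collapsed listing cell, not the notebook's exact code.
# Assumes the public cloud-samples-data bucket referenced later in this notebook.
from google.cloud import storage

bucket_name = "cloud-samples-data"
prefix = "generative-ai/image/"  # assumed prefix; adjust for your own data

client = storage.Client()
# Keep only blobs with common image extensions.
blobs = [
    b for b in client.list_blobs(bucket_name, prefix=prefix)
    if b.name.lower().endswith((".jpg", ".jpeg", ".png"))
]

def get_blob_uri(blob):
    """Build a gs:// URI for a blob, matching how files_df is created above."""
    return f"gs://{blob.bucket.name}/{blob.name}"
```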
{
"cell_type": "markdown",
"source": [
"### Wrapper class and define Pandas UDF"
"### Implement Dataset class for image dataset, so it can be fed into PyTorch's DataLoader"
],
"metadata": {
"collapsed": false
@@ -177,6 +185,9 @@
"execution_count": null,
"outputs": [],
"source": [
"from torch.utils.data import Dataset\n",
"from torchvision.datasets.folder import default_loader\n",
"\n",
"class ImageDataset(Dataset):\n",
" def __init__(self, paths, transform=None):\n",
" self.paths = paths\n",
@@ -186,44 +197,76 @@
" def __getitem__(self, index):\n",
" client = storage.Client()\n",
" path = self.paths[index]\n",
" # Downloads file locally from GCS before feeding into model\n",
" # Download file from GCS as image loader needs local file\n",
" blob = Blob.from_string(path, client=client)\n",
" local_file = \"/tmp/\" + path.split(\"/\")[-1]\n",
" blob.download_to_file(open(local_file, \"wb\"))\n",
" image = default_loader(local_file)\n",
" if self.transform is not None:\n",
" image = self.transform(image)\n",
" return image\n",
"\n",
" return image"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from torchvision.models import resnet50\n",
"def get_model_for_eval():\n",
" \"\"\"Gets the broadcasted model to each python worker\"\"\"\n",
" torch.hub.set_dir(\"/tmp/models\")\n",
" model = models.resnet50(pretrained=True)\n",
" model.load_state_dict(bc_model_state.value)\n",
" model.eval()\n",
" return model\n",
" \"\"\"Gets the broadcasted model to each python worker\"\"\"\n",
" torch.hub.set_dir(\"/tmp/models\")\n",
" model = resnet50(bc_model_weights.value)\n",
" model.eval()\n",
" return model"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Define Pandas UDF for parallel run on each partition. Learn [more](https://spark.apache.org/docs/3.4.2/api/python/reference/pyspark.sql/api/pyspark.sql.functions.pandas_udf.html)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from pyspark.sql.functions import pandas_udf\n",
"from torchvision import transforms\n",
"\n",
"# Using Pandas UDF for parallel run on each partition\n",
"@pandas_udf(ArrayType(FloatType()))\n",
"def predict_batch_udf(paths: pd.Series) -> pd.Series:\n",
" transform = transforms.Compose([\n",
" transforms.Resize(224),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225])\n",
" ])\n",
" images = ImageDataset(paths, transform=transform)\n",
" loader = torch.utils.data.DataLoader(images, batch_size=500, num_workers=8)\n",
" model = get_model_for_eval()\n",
" model.to(device)\n",
" all_predictions = []\n",
" with torch.no_grad():\n",
" for batch in loader:\n",
" predictions = list(model(batch.to(device)).cpu().numpy())\n",
" for prediction in predictions:\n",
" all_predictions.append(prediction)\n",
" return pd.Series(all_predictions)"
"\n",
" #Transformation needed on input by Resnet model\n",
" transform = transforms.Compose([\n",
" transforms.Resize(224),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225])\n",
" ])\n",
" images = ImageDataset(paths, transform=transform)\n",
" #Tune batch_size/num_workers based on your workload\n",
" loader = torch.utils.data.DataLoader(images, batch_size=500, num_workers=8)\n",
" model = get_model_for_eval()\n",
" model.to(device)\n",
" all_predictions = []\n",
" with torch.no_grad():\n",
" for batch in loader:\n",
" predictions = list(model(batch.to(device)).numpy())\n",
" for prediction in predictions:\n",
" all_predictions.append(prediction)\n",
" return pd.Series(all_predictions)"
],
"metadata": {
"collapsed": false
@@ -232,7 +275,7 @@
{
"cell_type": "markdown",
"source": [
"### Call UDF on the DataFrame and show"
"### Call UDF on the DataFrame and convert output into Pandas dataframe"
],
"metadata": {
"collapsed": false
@@ -243,9 +286,35 @@
"execution_count": null,
"outputs": [],
"source": [
"predictions_df = files_df.select(col(\"value\"),\n",
" predict_batch_udf(col(\"value\"))).alias(\"predictions\")\n",
"predictions_df.show()"
"from pyspark.sql.functions import col\n",
"\n",
"predictions_df = files_df.select(col(\"inputFile\"),\n",
" predict_batch_udf(col(\"inputFile\")).alias(\"predictions\"))\n",
"predictions = predictions_df.toPandas()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Check output for few input files\n",
"Eg. \"gs://cloud-samples-data/generative-ai/image/pixel-tablet.jpg\" gets the tag \"analog clock\" with the highest probability"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"predictions[\"label\"] = predictions[\"predictions\"].map(lambda x: weights.meta[\"categories\"][x.argmax()])\n",
"\n",
"pd.set_option('display.max_colwidth', None)\n",
"predictions.head(10)[[\"inputFile\", \"label\"]]"
],
"metadata": {
"collapsed": false
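Beyond the argmax label above, you could also inspect the top few classes with softmax probabilities; here is a small illustrative extension (not part of the notebook as committed) that reuses the `predictions` and `weights` objects defined earlier:

```python
import numpy as np

# Illustrative extension: top-5 labels with softmax probabilities for one row
# of the predictions DataFrame built above.
def top_k_labels(logits, k=5):
    logits = np.asarray(logits, dtype=np.float32)
    probs = np.exp(logits - logits.max())  # numerically stable softmax
    probs /= probs.sum()
    top = np.argsort(probs)[::-1][:k]
    return [(weights.meta["categories"][i], float(probs[i])) for i in top]

print(predictions.loc[0, "inputFile"])
print(top_k_labels(predictions.loc[0, "predictions"]))
```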
2 changes: 1 addition & 1 deletion ai-ml-samples/interactive/README.md
@@ -4,4 +4,4 @@ There are three ways to run notebooks inside Dataproc:
3. [Jupyter optional component](https://cloud.google.com/dataproc/docs/concepts/components/jupyter)

## Contributing
Please follow this template (https://github.com/GoogleCloudPlatform/generative-ai/blob/main/notebook_template.ipynb) to contribute more scenarios
Please follow this [template](https://github.com/GoogleCloudPlatform/generative-ai/blob/main/notebook_template.ipynb) to contribute more scenarios
