Skip to content

Commit

Permalink
Demo app updates to use PyTorch 1.9 org.pytorch:pytorch_android_lite:…
Browse files Browse the repository at this point in the history
…1.9.0 (pytorch#151)

* initial commit

* Revert "initial commit"

This reverts commit 5a65775.

* main readme and helloworld/demo app readme updates

* build.gradle, README and code update for PT1.9 for HelloWorld and Object Detection using pytorch_android_lite:1.9.0

* build.gradle, README and code update for PT1.9 for Question Answering using pytorch_android_lite:1.9.0

* build.gradle, README and code update for PT1.9 for TorchVideo using pytorch_android_lite:1.9.0

* HelloWorld script fix

* README update for TorchVideo
  • Loading branch information
jeffxtang authored Jun 17, 2021
1 parent 367d2d9 commit 3d31a03
Show file tree
Hide file tree
Showing 18 changed files with 134 additions and 80 deletions.
12 changes: 6 additions & 6 deletions HelloWorldApp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ This application runs TorchScript serialized TorchVision pretrained [MobileNet v
Let’s start with model preparation. If you are familiar with PyTorch, you probably should already know how to train and save your model. In case you don’t, we are going to use a pre-trained image classification model(MobileNet v3), which is packaged in [TorchVision](https://pytorch.org/docs/stable/torchvision/index.html).
To install it, run the command below:
```
pip install torchvision
pip install torch torchvision
```

To serialize and optimize the model for Android, you can use the Python [script](https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/trace_model.py) in the root folder of HelloWorld app:
Expand All @@ -22,12 +22,12 @@ model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model.save("app/src/main/assets/model.pt")
optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.ptl")
```
If everything works well, we should have our scripted and optimized model - `model.pt` generated in the assets folder of android application.
That will be packaged inside android application as `asset` and can be used on the device.

By using the new MobileNet v3 model instead of the old Resnet18 model, and by calling the `optimize_for_mobile` method on the traced model, the model inference time on a Pixel 3 gets decreased from over 230ms to about 40ms.
By using the new MobileNet v3 model instead of the old Resnet18 model, and by calling the `optimize_for_mobile` method on the traced model, the model inference time on a Pixel 3 gets decreased from over 230ms to about 40ms.

More details about TorchScript you can find in [tutorials on pytorch.org](https://pytorch.org/docs/stable/jit.html)

Expand All @@ -54,8 +54,8 @@ repositories {
}
dependencies {
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
}
```
Where `org.pytorch:pytorch_android` is the main dependency with PyTorch Android API, including libtorch native library for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64).
Expand All @@ -73,7 +73,7 @@ Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));

#### 5. Loading TorchScript Module
```
Module module = Module.load(assetFilePath(this, "model.pt"));
Module module = LiteModuleLoader.load(assetFilePath(this, "model.pt"));
```
`org.pytorch.Module` represents `torch::jit::script::Module` that can be loaded with `load` method specifying file path to the serialized to file model.

Expand Down
11 changes: 2 additions & 9 deletions HelloWorldApp/app/build.gradle
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
apply plugin: 'com.android.application'

repositories {
jcenter()
maven {
url "https://oss.sonatype.org/content/repositories/snapshots"
}
}

android {
compileSdkVersion 28
buildToolsVersion "29.0.2"
Expand All @@ -26,6 +19,6 @@ android {

dependencies {
implementation 'androidx.appcompat:appcompat:1.1.0'
implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
}
Binary file modified HelloWorldApp/app/src/main/assets/model.pt
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import android.widget.TextView;

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
Expand Down Expand Up @@ -37,7 +38,7 @@ protected void onCreate(Bundle savedInstanceState) {
bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
// loading serialized torchscript module from packaged into app android asset model.pt,
// app/src/model/assets/model.pt
module = Module.load(assetFilePath(this, "model.pt"));
module = LiteModuleLoader.load(assetFilePath(this, "model.pt"));
} catch (IOException e) {
Log.e("PytorchHelloWorld", "Error reading assets", e);
finish();
Expand Down
2 changes: 1 addition & 1 deletion HelloWorldApp/trace_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model.save("app/src/main/assets/model.pt")
optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.pt")
20 changes: 12 additions & 8 deletions ObjectDetection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

## Prerequisites

* PyTorch 1.7 or later (Optional)
* PyTorch 1.9.0 or later (Optional)
* Python 3.8 (Optional)
* Android Pytorch library 1.7.0
* Android Pytorch library pytorch_android_lite:1.9.0 and pytorch_android_torchvision:1.9.0
* Android Studio 4.0.1 or later

## Quick Start
Expand All @@ -17,9 +17,9 @@ To Test Run the Object Detection Android App, follow the steps below:

### 1. Prepare the Model

If you don't have the PyTorch environment set up to run the script, you can download the model file [here](https://drive.google.com/file/d/1QOxNfpy_j_1KbuhN8INw2AgAC82nEw0F/view?usp=sharing) to the `android-demo-app/ObjectDetection/app/src/main/assets` folder, then skip the rest of this step and go to step 2 directly.
If you don't have the PyTorch environment set up to run the script, you can download the model file `yolov5s.torchscript.ptl` [here](https://drive.google.com/u/1/uc?id=1_MF7NVi9Csm1lizoSCp1wCtUUUpuhwet&export=download) to the `android-demo-app/ObjectDetection/app/src/main/assets` folder, then skip the rest of this step and go to step 2 directly.

Be aware that the downloadable model file was created with PyTorch 1.7.0, matching the PyTorch Android library 1.7.0 specified in the project's `build.gradle` file as `implementation 'org.pytorch:pytorch_android:1.7.0'`. If you use a different version of PyTorch to create your model by following the instructions below, make sure you specify the same PyTorch Android library version in the `build.gradle` file to avoid possible errors caused by the version mismatch. Furthermore, if you want to use the latest PyTorch master code to create the model, follow the steps at [Building PyTorch Android from Source](https://pytorch.org/mobile/android/#building-pytorch-android-from-source) and [Using the PyTorch Android Libraries Built](https://pytorch.org/mobile/android/#using-the-pytorch-android-libraries-built-from-source-or-nightly) on how to use the model in Android.
Be aware that the downloadable model file was created with PyTorch 1.9.0, matching the PyTorch Android library 1.9.0 specified in the project's `build.gradle` file as `implementation 'org.pytorch:pytorch_android_lite:1.9.0'`. If you use a different version of PyTorch to create your model by following the instructions below, make sure you specify the same PyTorch Android library version in the `build.gradle` file to avoid possible errors caused by the version mismatch. Furthermore, if you want to use the latest PyTorch master code to create the model, follow the steps at [Building PyTorch Android from Source](https://pytorch.org/mobile/android/#building-pytorch-android-from-source) and [Using the PyTorch Android Libraries Built](https://pytorch.org/mobile/android/#using-the-pytorch-android-libraries-built-from-source-or-nightly) on how to use the model in Android.

The Python script `export.py` in the `models` folder of the [YOLOv5 repo](https://github.com/ultralytics/yolov5) is used to generate a TorchScript-formatted YOLOv5 model named `yolov5s.torchscript.pt` for mobile apps.

Expand All @@ -31,20 +31,24 @@ cd yolov5
pip install -r requirements.txt
```

Then edit `models/export.py` to make two changes:
Then edit `models/export.py` to make the following four changes:

* Change the line 50 from `model.model[-1].export = True` to `model.model[-1].export = False`
* Change line 50 from `model.model[-1].export = True` to `model.model[-1].export = False`

* Change line 56 from `f = opt.weights.replace('.pt', '.torchscript.pt')` to `f = opt.weights.replace('.pt', '.torchscript.ptl')`

* Add the following two lines of model optimization code after line 57, between `ts = torch.jit.trace(model, img)` and `ts.save(f)`:

```
from torch.utils.mobile_optimizer import optimize_for_mobile
ts = optimize_for_mobile(ts)
ts = optimize_for_mobile(ts)
```

* Replace the line `ts.save(f)` with `ts._save_for_lite_interpreter(f)`.

If you ignore this step, you can still create a TorchScript model for mobile apps to use, but the inference on a non-optimized model can take twice as long as the inference on an optimized model - using the Android app test images, the average inference time on an optimized and non-optimized model is 0.6 seconds and 1.18 seconds, respectively. See [SCRIPT AND OPTIMIZE FOR MOBILE RECIPE](https://pytorch.org/tutorials/recipes/script_optimized.html) for more details.

Finally, run the script below to generate the optimized TorchScript model and copy the generated model file `yolov5s.torchscript.pt` to the `android-demo-app/ObjectDetection/app/src/main/assets` folder:
Now run the script below to generate the optimized TorchScript model and copy the generated model file `yolov5s.torchscript.ptl` to the `android-demo-app/ObjectDetection/app/src/main/assets` folder:

```
python models/export.py
Expand Down
4 changes: 2 additions & 2 deletions ObjectDetection/app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ dependencies {
implementation "androidx.camera:camera-core:$camerax_version"
implementation "androidx.camera:camera-camera2:$camerax_version"

implementation 'org.pytorch:pytorch_android:1.7.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.7.0'
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import android.widget.ProgressBar;

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

Expand All @@ -57,6 +57,25 @@ public class MainActivity extends AppCompatActivity implements Runnable {
private Module mModule = null;
private float mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY;

public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}

try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}

@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
Expand Down Expand Up @@ -162,7 +181,7 @@ public void onClick(View v) {
});

try {
mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), "yolov5s.torchscript.pt");
mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "yolov5s.torchscript.ptl"));
BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open("classes.txt")));
String line;
List<String> classes = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import android.graphics.Rect;
import android.graphics.YuvImage;
import android.media.Image;
import android.os.Bundle;
import android.util.Log;
import android.view.TextureView;
import android.view.ViewStub;
Expand All @@ -17,8 +16,8 @@
import androidx.camera.core.ImageProxy;

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

Expand Down Expand Up @@ -85,8 +84,13 @@ private Bitmap imgToBitmap(Image image) {
@WorkerThread
@Nullable
protected AnalysisResult analyzeImage(ImageProxy image, int rotationDegrees) {
if (mModule == null) {
mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), "yolov5s.torchscript.pt");
try {
if (mModule == null) {
mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "yolov5s.torchscript.ptl"));
}
} catch (IOException e) {
Log.e("Object Detection", "Error reading assets", e);
return null;
}
Bitmap bitmap = imgToBitmap(image.getImage());
Matrix matrix = new Matrix();
Expand Down
31 changes: 7 additions & 24 deletions QuestionAnswering/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

## Introduction

Question Answering (QA) is one of the common and challenging Natural Language Processing tasks. With the revolutionary transformed-based [Bert](https://arxiv.org/abs/1810.04805) model coming out in October 2018, question answering models have reached their state of art accuracy by fine-tuning Bert-like models on QA datasets such as [Squad](https://rajpurkar.github.io/SQuAD-explorer). [Huggingface](https://huggingface.co)'s [DistilBert](https://huggingface.co/transformers/model_doc/distilbert.html) is a smaller and faster version of BERT - DistilBert "has 40% less parameters than bert-base-uncased, runs 60% faster while preserving over 95% of BERT’s performances as measured on the GLUE language understanding benchmark."
Question Answering (QA) is one of the common and challenging Natural Language Processing tasks. With the revolutionary transformed-based [BERT](https://arxiv.org/abs/1810.04805) model coming out in October 2018, question answering models have reached their state of art accuracy by fine-tuning BERT-like models on QA datasets such as [Squad](https://rajpurkar.github.io/SQuAD-explorer). [Huggingface](https://huggingface.co)'s [DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html) is a smaller and faster version of BERT - DistilBERT "has 40% less parameters than bert-base-uncased, runs 60% faster while preserving over 95% of BERT’s performances as measured on the GLUE language understanding benchmark."

In this demo app, written in Kotlin, we'll show how to quantize and convert the Huggingface's DistilBert QA model to TorchScript and how to use the scripted model on an Android demo app to perform question answering.

## Prerequisites

* PyTorch 1.7 or later (Optional)
* PyTorch 1.9.0 or later (Optional)
* Python 3.8 (Optional)
* Android Pytorch library 1.7 or later
* Android Pytorch library org.pytorch:pytorch_android_lite:1.9.0
* Android Studio 4.0.1 or later

## Quick Start
Expand All @@ -19,32 +19,15 @@ To Test Run the Android QA App, run the following commands on a Terminal:

### 1. Prepare the Model

If you don't have PyTorch installed or want to have a quick try of the demo app, you can download the scripted QA model compressed in a zip file [here](https://drive.google.com/file/d/1RWZa_5oSQg5AfInkn344DN3FJ5WbbZbq/view?usp=sharing), then unzip it to the assets folder, and continue to Step 2.
If you don't have PyTorch installed or want to have a quick try of the demo app, you can download the scripted QA model `qa360_quantized.ptl` [here](https://drive.google.com/file/d/1PgD3pAEf0riUiT3BfwHOm6UEGk8FfJzI/view?usp=sharing) and save it to the `QuestionAnswering/app/src/main/assets` folder, then continue to Step 2.

Be aware that the downloadable model file was created with PyTorch 1.7.0, matching the PyTorch Android library 1.7.0 specified in the project's `build.gradle` file as `implementation 'org.pytorch:pytorch_android:1.7.0'`. If you use a different version of PyTorch to create your model by following the instructions below, make sure you specify the same PyTorch Android library version in the `build.gradle` file to avoid possible errors caused by the version mismatch. Furthermore, if you want to use the latest PyTorch master code to create the model, follow the steps at [Building PyTorch Android from Source](https://pytorch.org/mobile/android/#building-pytorch-android-from-source) and [Using the PyTorch Android Libraries Built](https://pytorch.org/mobile/android/#using-the-pytorch-android-libraries-built-from-source-or-nightly) on how to use the model in Android.
Be aware that the downloadable model file was created with PyTorch 1.9.0, matching the PyTorch Android library 1.9.0 specified in the project's `build.gradle` file as `implementation 'org.pytorch:pytorch_android:1.9.0'`. If you use a different version of PyTorch to create your model by following the instructions below, make sure you specify the same PyTorch Android library version in the `build.gradle` file to avoid possible errors caused by the version mismatch. Furthermore, if you want to use the latest PyTorch master code to create the model, follow the steps at [Building PyTorch Android from Source](https://pytorch.org/mobile/android/#building-pytorch-android-from-source) and [Using the PyTorch Android Libraries Built](https://pytorch.org/mobile/android/#using-the-pytorch-android-libraries-built-from-source-or-nightly) on how to use the model in Android.

With PyTorch 1.7 installed, first install the Huggingface `transformers` by running `pip install transformers` (the versions that have been tested are 4.0.0 and 4.1.1), then run `python convert_distilbert_qa.py`.
With PyTorch 1.9.0 installed, first install the Huggingface `transformers` by running `pip install transformers`, then run `python convert_distilbert_qa.py`.

Note that a pre-defined question and text, resulting in the size of the input tokens (of question and text) being 360, is used in the `convert_distilbert_qa.py`, and 360 is the maximum token size for the user text and question in the app. If the token size of the inputs of the text and question is less than 360, padding will be needed to make the model work correctly.

After the script completes, copy the model file qa360_quantized.pt to the Android app's assets folder. [Dynamic quantization](https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html) is used to quantize the model to reduce its size to half, without causing inference difference in question answering - you can verify this by changing the last 4 lines of code in `convert_distilbert_qa.py` from:

```
model_dynamic_quantized = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
traced_model = torch.jit.trace(model_dynamic_quantized, inputs['input_ids'], strict=False)
optimized_traced_model = optimize_for_mobile(traced_model)
torch.jit.save(optimized_traced_model, "qa360_quantized.pt")
```

to

```
traced_model = torch.jit.trace(model, inputs['input_ids'], strict=False)
optimized_traced_model = optimize_for_mobile(traced_model)
torch.jit.save(optimized_traced_model, "qa360.pt")
```

and rerun `python convert_distilbert_qa.py` to generate a non-quantized model `qa360.pt` and use it in the app to compare with the quantized version `qa360_quantized.pt`.
After the script completes, copy the model file `qa360_quantized.ptl` to the Android app's assets folder.


### 2. Build and run with Android Studio
Expand Down
2 changes: 1 addition & 1 deletion QuestionAnswering/app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies {
implementation 'androidx.appcompat:appcompat:1.2.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.4'

implementation 'org.pytorch:pytorch_android:1.7.0'
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation "androidx.core:core-ktx:+"
implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version"

Expand Down
Loading

0 comments on commit 3d31a03

Please sign in to comment.