PytorchKotlinDemo

This is an Android project, written in Kotlin, that demonstrates a simple image classification application using the PyTorch Android API and a trained PyTorch model.

In this demo application, the user can either upload a picture or take a photo, and then run image analysis on it.
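
On the device, the app runs the analysis with the PyTorch Android API. Purely to illustrate what that analysis involves, here is a dependency-free Python sketch of the usual pre- and post-processing around an image classifier. The ImageNet normalization constants, the 0-255 RGB pixel format, the labels, and the helper names `to_input_tensor`, `softmax`, and `top_k` are assumptions for illustration and are not taken from this project:

```python
import math

# Assumption: the per-channel ImageNet mean/std commonly used with
# torchvision models; this project's actual values may differ.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def to_input_tensor(pixels, width, height):
    """Convert row-major RGB pixels (0-255 ints) into a normalized
    1 x 3 x height x width nested list (batch, channel, row, column)."""
    chw = []
    for c in range(3):
        channel = []
        for y in range(height):
            row = []
            for x in range(width):
                value = pixels[y * width + x][c] / 255.0  # scale to [0, 1]
                row.append((value - MEAN[c]) / STD[c])     # normalize
            channel.append(row)
        chw.append(channel)
    return [chw]  # add the leading batch dimension of 1


def softmax(scores):
    """Turn raw class scores (logits) into probabilities."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(scores, labels, k=3):
    """Return the k most likely (label, probability) pairs, best first."""
    probs = softmax(scores)
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


if __name__ == "__main__":
    # Hypothetical 2 x 2 image and hypothetical class labels.
    image = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (128, 128, 128)]
    batch = to_input_tensor(image, width=2, height=2)
    print(len(batch), len(batch[0]), len(batch[0][0]), len(batch[0][0][0]))  # 1 3 2 2
    print(top_k([2.0, 0.5, 1.0], ["cat", "dog", "bird"], k=2))
```

For a model serialized with a 1 x 3 x 224 x 224 example input, the matching call would use `width=224, height=224`.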

Architecture

I followed one of the Google Android Architecture Components samples.

This sample showcases the following Architecture Components:

Serialize a PyTorch model

The Android demo GitHub repository describes in detail how the PyTorch model is generated.

The model saved in the notebook cannot be used directly by the Android app; it first has to be serialized (converted to TorchScript).

In the Jupyter notebook where I trained the model, I did the following:

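```python
import torch

# use_cuda and loaded_model come from the training notebook;
# loaded_model is an instance of the trained network architecture.
use_cuda = torch.cuda.is_available()

# load the model
ckp_path = './best_model.pt'
if use_cuda:
    checkpoint = torch.load(ckp_path)
else:
    checkpoint = torch.load(ckp_path, map_location=torch.device('cpu'))
loaded_model.load_state_dict(checkpoint['state_dict'])

# serialize the model
# trace on the CPU, because the Android app runs the model on the CPU
loaded_model = loaded_model.cpu()
loaded_model.eval()  # put dropout/batch-norm layers in inference mode
example = torch.rand(1, 3, 224, 224)  # dummy input: one 3-channel 224x224 image
traced_script_module = torch.jit.trace(loaded_model, example)
traced_script_module.save("./serialized_model.pt")
```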

After these operations, we should have a usable serialized model, serialized_model.pt.
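
Before using the file in the app, it can be worth checking that the serialized model reproduces the original model's outputs. The following is a minimal sketch of such a check, assuming everything runs on the CPU; `TinyNet` is a hypothetical stand-in for the trained network, and `check_round_trip` is a helper name introduced here for illustration:

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the trained classifier."""

    def __init__(self, num_classes=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(torch.flatten(self.features(x), 1))


def check_round_trip(model, path, atol=1e-5):
    """Trace `model`, save it to `path`, reload it, and compare the
    reloaded module's output with the eager model on a fresh input."""
    model.eval()
    example = torch.rand(1, 3, 224, 224)
    traced = torch.jit.trace(model, example)
    traced.save(path)
    reloaded = torch.jit.load(path, map_location="cpu")
    test_input = torch.rand(1, 3, 224, 224)  # different from the tracing input
    with torch.no_grad():
        expected = model(test_input)
        actual = reloaded(test_input)
    return torch.allclose(expected, actual, atol=atol)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(check_round_trip(TinyNet(), os.path.join(tmp, "serialized_model.pt")))
```

In the notebook, `loaded_model` would take the place of `TinyNet()`. Using a fresh random input, rather than the tracing example, gives a little more confidence that the trace does not depend on the specific input it was recorded with.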

References

Android demo GitHub repository

PyTorch Mobile

Google Android Architecture Components sample