Vision Transformer on Android
Introduction
We use two Vision Transformer (ViT) models:
- Facebook DeiT, a ViT model pre-trained on ImageNet, for image classification on Android;
- a ViT model trained on MNIST and converted to TorchScript, for handwritten digit recognition on Android.
Prerequisites
- PyTorch 1.7 or later (Optional)
- Python 3.8 (Optional)
- PyTorch Android library 1.7 or later
- Android Studio 4.0.1 or later
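To confirm that a local Python environment meets the PyTorch requirement above, a small check like the following can help. It is a hypothetical helper, not part of the demo code. It looks only at the leading numeric part of a version string such as `torch.__version__` (which may carry a suffix like `+cu101`) and compares the components as integers, so that 1.10 correctly counts as newer than 1.7.

```python
import re


def meets_minimum(version: str, minimum: tuple) -> bool:
    """Return True if a version string such as "1.7.1+cu101" is >= minimum.

    Only the leading dotted numeric part is compared; local build tags
    (e.g. "+cpu") and pre-release suffixes (e.g. "a0") are ignored.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    # Tuple comparison is element-wise, so (1, 10, 0) > (1, 7).
    return parts >= tuple(minimum)
```

For example, `meets_minimum(torch.__version__, (1, 7))` should return `True` on PyTorch 1.7 or later.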
Quick Start on Using Facebook DeiT
1. Prepare the Model (Optional)
To use a pre-trained Facebook DeiT model and convert it to TorchScript, first install PyTorch 1.7 or later, then install timm using `pip install timm==0.3.2`, and finally run the following script:
```
python convert_deit.py
```
This generates the quantized, scripted model `fbdeit.pt`, which can also be downloaded here. Note that the quantization code in the script reduces the model size from 346MB to 89MB.
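As a rough sanity check on those numbers (our own back-of-the-envelope arithmetic, not something `convert_deit.py` computes): a float32 weight takes 4 bytes and an int8 weight takes 1 byte, so 346MB corresponds to roughly 86.5 million float32 values, and storing every one of them as int8 would give about 86.5MB. The reported 89MB sits a little above that floor, which is consistent with some tensors staying in float32 (for example, if only the linear layers are quantized, as PyTorch's dynamic quantization typically does) plus quantization parameters and TorchScript overhead. Treat that explanation, and the assumption that 1MB = 10^6 bytes, as assumptions rather than a description of the script.

```python
BYTES_PER_FLOAT32 = 4
BYTES_PER_INT8 = 1


def estimate_params(fp32_size_mb: float) -> float:
    """Rough parameter count implied by a float32 model size in MB (10^6 bytes)."""
    return fp32_size_mb * 1e6 / BYTES_PER_FLOAT32


def int8_lower_bound_mb(fp32_size_mb: float) -> float:
    """Size in MB if every float32 weight were stored as int8 instead."""
    return estimate_params(fp32_size_mb) * BYTES_PER_INT8 / 1e6


params = estimate_params(346)            # about 86.5 million parameters
lower_bound = int8_lower_bound_mb(346)   # about 86.5 MB
print(f"~{params / 1e6:.1f}M params, int8 lower bound ~{lower_bound:.1f}MB")
```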
To train and convert your own DeiT model on ImageNet, first follow the instructions under Data Preparation and Training at the DeiT repo, then run the following code after the model is trained:
```py
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

model.eval()  # `model` is your trained DeiT model; freezing in optimize_for_mobile needs eval mode
ts_model = torch.jit.script(model)
optimized_torchscript_model = optimize_for_mobile(ts_model)
optimized_torchscript_model.save("fbdeit.pt")
```
2. Run the Model on Android
In `MainActivity.java`, change the model file name from:
```java
module = Module.load(assetFilePath(this, "model.pt"));
```
to:
```java
module = Module.load(assetFilePath(this, "fbdeit.pt"));
```
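Make sure `fbdeit.pt` has been copied into the app's assets folder, since `assetFilePath` loads the model from the app's assets. Then run the app in Android Studio; you should see the same image classification result as with the original `model.pt`.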
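On the device, the image has to be turned into the float tensor the model expects before it is passed to the model. The pure-Python sketch below mirrors what we understand the PyTorch Android helper `TensorImageUtils.bitmapToFloat32Tensor` to do: scale each 8-bit channel to [0, 1], normalize it with the ImageNet mean and standard deviation, and lay the values out channel by channel (the CHW order of a 1x3xHxW tensor). It assumes the app uses the standard torchvision normalization constants, which DeiT is also assumed to have been trained with; `to_chw_floats` and `argmax` are hypothetical helper names, handy for checking on a desktop what the Java code feeds the model.

```python
# ImageNet normalization constants (torchvision defaults; assumed to match
# what the app passes to the PyTorch Android image helpers).
MEAN_RGB = (0.485, 0.456, 0.406)
STD_RGB = (0.229, 0.224, 0.225)


def to_chw_floats(pixels, width, height):
    """Convert row-major (r, g, b) byte pixels into a normalized CHW float list.

    The output holds the whole red plane, then green, then blue, which is the
    memory layout of a 1x3xHxW input tensor.
    """
    if len(pixels) != width * height:
        raise ValueError("pixel count does not match width * height")
    planes = [[], [], []]
    for rgb in pixels:
        for c in range(3):
            planes[c].append((rgb[c] / 255.0 - MEAN_RGB[c]) / STD_RGB[c])
    return planes[0] + planes[1] + planes[2]


def argmax(scores):
    """Index of the highest score, i.e. the predicted class index."""
    return max(range(len(scores)), key=lambda i: scores[i])
```

With these helpers, a 2x1 image of one red and one green pixel becomes six floats (two red values, then two green, then two blue), and `argmax` over the model's 1000 ImageNet scores gives the class index the app displays.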
Quick Start on Using ViT for MNIST
To test-run the Android ViT4MNIST demo app, follow the steps below:
1. Prepare the Model (Optional)
In a terminal, with PyTorch 1.7.0 and einops installed, run:
```
python mnist_vit.py
```
The model definition in `vit_pytorch.py` and the training code in `mnist_vit.py` are mostly taken from the blog here.
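The core idea of a ViT is to cut the image into fixed-size patches and treat each flattened patch as a token; `vit_pytorch.py` presumably does this with an einops `rearrange` pattern such as `'b c (h p1) (w p2) -> b (h w) (p1 p2 c)'`. The dependency-free sketch below reproduces that ordering for a single one-channel image. The patch size of 7 is an assumption for illustration; check the `ViT(...)` arguments in `mnist_vit.py` for the value actually used.

```python
def image_to_patches(image, patch_size):
    """Split a single-channel H x W image (a list of rows) into flattened patches.

    Patches are read left to right, top to bottom, and each patch is flattened
    row by row, matching einops' 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)'
    pattern for one image with one channel.
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = []
            for row in image[top:top + patch_size]:
                patch.extend(row[left:left + patch_size])
            patches.append(patch)
    return patches


# A 28x28 "digit" whose pixel values are just their positions.
digit = [[row * 28 + col for col in range(28)] for row in range(28)]
patches = image_to_patches(digit, 7)  # 7x7 patches (assumed patch size)
print(len(patches), len(patches[0]))  # 16 49
```

In a standard ViT, each 49-value patch is then linearly projected to the embedding dimension, and a learnable class token plus position embeddings are added before the transformer encoder, so a 28x28 digit with 7x7 patches becomes a sequence of 17 tokens.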
2. Build and run with Android Studio
Build the app and run it on an Android Virtual Device (AVD) or a real Android device.