This repository demonstrates an end-to-end image classification pipeline built around a Vision Transformer (ViT) model in PyTorch, optimized for Apple's MPS (Metal Performance Shaders) backend to use GPU acceleration on M1, M2, and M3 Macs. A Flask-based API serves the model for inference, and a web client lets users interact with it. Together, these pieces combine deep learning, a RESTful API, and web deployment for real-time image classification.
- PyTorch-based ViT image classification model, optimized with Apple MPS.
- Flask server with RESTful API for model inference.
- A simple web interface for interacting with the model.
- Apple Silicon Mac (M1, M2, or M3) for MPS acceleration (or use CPU). Windows GPU (CUDA) support will be added soon.
- Python 3.8 or higher.
- If using a Conda environment, Conda needs to be installed.
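
The MPS-or-CPU choice mentioned in the requirements can be made at runtime. The following is a minimal sketch, not the repository's actual code: `pick_device` and `DEVICE` are hypothetical names, and the CUDA branch is an assumption included only because CUDA support is planned.

```python
def pick_device(mps_available: bool, cuda_available: bool = False) -> str:
    """Return the preferred device name given which backends are usable."""
    if mps_available:
        return "mps"   # Apple Silicon GPU via Metal Performance Shaders
    if cuda_available:
        return "cuda"  # assumption: not yet supported by this project
    return "cpu"


try:
    import torch

    # torch.backends.mps is available in PyTorch 1.12 and later.
    DEVICE = pick_device(
        torch.backends.mps.is_available(),
        torch.cuda.is_available(),
    )
except (ImportError, AttributeError):
    # PyTorch missing or too old to expose the MPS backend: fall back to CPU.
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")
```

A model or tensor would then be moved with `model.to(DEVICE)`; on a machine without MPS the same code falls back to the CPU.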
- Clone the repository:

      git clone https://github.com/mehradnia/PyTorch-ViT-Image-Classification-with-Apple-MPS-Flask-APIs-Web-Client.git
      cd PyTorch-ViT-Image-Classification-with-Apple-MPS-Flask-APIs-Web-Client

- Create and activate a virtual environment (choose one of the methods below):
  - Using venv:

        python3 -m venv venv
        source venv/bin/activate  # On Windows use venv\Scripts\activate

  - Using Conda (make sure Conda is installed):

        conda create --name myenv python=3.x
        conda activate myenv

- Install dependencies:

      pip install -r requirements.txt

- Set up your dataset:
  - Create a `data/[YOUR_DATASET]` directory.
  - Add your dataset structured in class-based folders:

        data/[YOUR_DATASET]/
        ├── class1/
        │   ├── 1.jpg
        │   └── 2.jpg
        ├── class2/
        │   ├── 1.jpg
        │   └── 2.jpg

    An example for an animals dataset:

        data/animals/
        ├── cat/
        │   ├── 1.jpg
        │   └── 2.jpg
        ├── dog/
        │   ├── 1.jpg
        │   └── 2.jpg

- Open the `config.yaml` file and replace `path/to/your/data` with your data directory (e.g., `data/animals`).
- Open the `notebooks/vit_image_classifier.ipynb` file and follow the instructions in the notebook to train your model.
- Once training is complete, you can find the trained model in the `/models` directory.
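
Before training, it can help to confirm that the dataset folder follows the class-per-folder layout described above. The sketch below is not part of the repository: `count_images_per_class` and `IMAGE_SUFFIXES` are hypothetical names, and the set of accepted file extensions is an assumption.

```python
from pathlib import Path

# Assumption: which file extensions count as images.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(data_dir):
    """Map each class folder name under data_dir to its number of image files."""
    root = Path(data_dir)
    counts = {}
    # Each immediate subdirectory is treated as one class.
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


if __name__ == "__main__":
    data_dir = Path("data/animals")  # the example dataset path used above
    if data_dir.is_dir():
        for name, n in count_images_per_class(data_dir).items():
            print(f"{name}: {n} images")
    else:
        print(f"{data_dir} not found")
```

An empty or missing class shows up as a count of zero or as an absent key, which is easier to spot here than partway through a training run.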
- Start the Flask server using the command:

      python3 app/server.py

- Flask will run the server at `http://localhost:8000`.
- The API exposes:
  - `POST /predict`: Upload an image and get the classification result.

Example POST request:

    curl -X POST -F "file=@path_to_image.jpg" http://localhost:8000/predict
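
The same request can be sent from Python using only the standard library. This is a sketch under stated assumptions: the form field name `file` and the URL are taken from the curl example, while `build_multipart` and `predict` are hypothetical helper names, and the response body is returned as raw text because its format is not documented here.

```python
import urllib.request
import uuid
from pathlib import Path


def build_multipart(field, filename, payload, content_type="image/jpeg"):
    """Encode one file as a multipart/form-data body; return (body, header value)."""
    boundary = uuid.uuid4().hex
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + payload + tail, f"multipart/form-data; boundary={boundary}"


def predict(image_path, url="http://localhost:8000/predict"):
    """POST an image to the /predict endpoint and return the raw response text."""
    path = Path(image_path)
    body, header = build_multipart("file", path.name, path.read_bytes())
    request = urllib.request.Request(
        url, data=body, headers={"Content-Type": header}, method="POST"
    )
    with urllib.request.urlopen(request) as response:
        return response.read().decode("utf-8")
```

With the server running, `print(predict("path_to_image.jpg"))` mirrors the curl command.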
Early stopping monitors the model's performance on the validation set during training. If the validation loss stops improving for a specified number of epochs (the patience), training is halted to prevent overfitting and save time.
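
The patience rule described above can be sketched in a few lines. This is an illustrative implementation, not necessarily the one used in the notebook: the `EarlyStopping` class, its `min_delta` parameter, and the loss values are all hypothetical.

```python
class EarlyStopping:
    """Signal a stop once validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # smallest decrease that counts as improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses: improvement stalls after the third epoch.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.50]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(f"Stopped at epoch index {stopped_at}, best loss {stopper.best_loss}")
```

In a real training loop, the model weights would usually be checkpointed whenever `best_loss` improves, so that the best model rather than the last one is kept.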
Data augmentation applies random transformations to the training data to increase the variety of the dataset. Techniques such as random resizing, rotation, and color jittering help the model generalize better by exposing it to a broader range of input variations.
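
A pipeline of this kind could be expressed with torchvision's `transforms` module. The sketch below assumes torchvision is installed; the function name `build_train_transforms`, the 224-pixel image size, and every parameter value are illustrative assumptions rather than the project's actual settings.

```python
def build_train_transforms(image_size=224):
    """Build a random augmentation pipeline for training images."""
    # Imported lazily so this module can be loaded without torchvision installed.
    from torchvision import transforms

    return transforms.Compose([
        # Random resizing: crop 80-100% of the image area, then resize.
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        # Random rotation of up to 15 degrees in either direction.
        transforms.RandomRotation(degrees=15),
        # Color jittering: small random brightness, contrast, and saturation changes.
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
    ])
```

These random transforms belong only on the training set; validation and test images would normally get a deterministic resize so that evaluation results stay comparable between runs.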
Reproducible data splitting is the systematic partitioning of a dataset into training, validation, and test subsets in a way that yields the same split on every run. It makes model evaluations reliable, lets others replicate the experiments, and ensures that different models are compared on the same data.
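
One common way to get a repeatable split is to shuffle with a fixed seed. The following is a minimal sketch of that idea, not the project's implementation: `split_dataset`, the 70/15/15 proportions, and the seed value are assumptions chosen for illustration.

```python
import random


def split_dataset(items, val_frac=0.15, test_frac=0.15, seed=42):
    """Split items into (train, val, test) lists that are identical on every run."""
    # Sort first so the result does not depend on the input order
    # (for example, the order in which the filesystem lists files).
    ordered = sorted(items)
    # A dedicated, seeded generator leaves the global random state untouched.
    random.Random(seed).shuffle(ordered)

    n_test = round(len(ordered) * test_frac)
    n_val = round(len(ordered) * val_frac)
    test = ordered[:n_test]
    val = ordered[n_test:n_test + n_val]
    train = ordered[n_test + n_val:]
    return train, val, test


# Hypothetical file names standing in for a real dataset.
files = [f"img_{i:03d}.jpg" for i in range(100)]
train, val, test = split_dataset(files)
print(len(train), len(val), len(test))
```

Calling `split_dataset` again with the same items and seed returns the same three lists, even if the items are passed in a different order.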