This project implements an image classification system using Convolutional Neural Networks (CNNs). The system is built with PyTorch and Flask, allowing for both training models on the CIFAR-10 dataset and serving a web interface for image classification.
The project is organized as follows:
- models/: Contains the model definitions for both the simple and the advanced CNN.
  - simple_cnn_model.py: Defines the SimpleCNN model.
  - advanced_cnn_model.py: Defines the AdvancedCNN model.
- controllers/: Contains the Flask application.
  - prediction_controller.py: Defines the Flask app and its routes.
- services/: Contains the prediction service.
  - prediction_service.py: Provides methods for image preprocessing and prediction.
- templates/: Contains the HTML templates for the web interface.
  - index.html: Main page for image upload and URL input.
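The layout above follows the three-layer split suggested by the repository name: controllers handle HTTP, services hold preprocessing and prediction logic, and models define the networks. The stdlib-only sketch below illustrates that separation under simplifying assumptions. PredictionService, handle_classify_request, the toy model and the [0, 1] pixel scaling are hypothetical and are not taken from the project code.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence

# Model layer (hypothetical stand-in): any callable mapping a preprocessed
# input to one score per class. In the project, SimpleCNN or AdvancedCNN
# fills this role.
Model = Callable[[List[float]], List[float]]


@dataclass
class PredictionService:
    """Service layer: preprocessing and inference, with no HTTP concerns."""

    model: Model
    class_names: Sequence[str]

    def preprocess(self, pixels: Sequence[int]) -> List[float]:
        # Hypothetical preprocessing: scale 0-255 pixel values to [0, 1].
        return [p / 255.0 for p in pixels]

    def predict(self, pixels: Sequence[int]) -> str:
        scores = self.model(self.preprocess(pixels))
        # Pick the index of the highest score (argmax).
        best = max(range(len(scores)), key=scores.__getitem__)
        return self.class_names[best]


def handle_classify_request(service: PredictionService, pixels: Sequence[int]) -> dict:
    """Controller layer: turn request data into a service call and a response."""
    return {"prediction": service.predict(pixels)}


def toy_model(x: List[float]) -> List[float]:
    # Toy model: class 0 scores the mean brightness, class 1 its complement.
    mean = sum(x) / len(x)
    return [mean, 1.0 - mean]


service = PredictionService(model=toy_model, class_names=["bright", "dark"])
print(handle_classify_request(service, [250, 240, 230]))  # {'prediction': 'bright'}
```

Keeping the service free of web-framework imports means it can be unit-tested or reused from a script without starting a server.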
To set up the project:
- Clone the repository:
  git clone https://github.com/svbuh/showcase_architecture_3-layer.git
  cd showcase_architecture_3-layer
- Create and activate a virtual environment:
  python3 -m venv venv
  source venv/bin/activate  # On Windows use `venv\Scripts\activate`
- Install the dependencies:
  pip install -r requirements.txt
To train a model, run the train_model.py script. You can choose between the SimpleCNN and AdvancedCNN models by uncommenting the desired model in the script.
python train_model.py
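Uncommenting code to switch models works, but a command-line flag is a common alternative. The sketch below is hypothetical: the --model flag, MODEL_REGISTRY, and the placeholder SimpleCNN/AdvancedCNN classes are illustrative and are not part of train_model.py.

```python
import argparse
from typing import Callable, Dict, Optional, Sequence


class SimpleCNN:
    """Placeholder for the model defined in models/simple_cnn_model.py."""


class AdvancedCNN:
    """Placeholder for the model defined in models/advanced_cnn_model.py."""


# Hypothetical registry mapping a command-line name to a model constructor.
MODEL_REGISTRY: Dict[str, Callable[[], object]] = {
    "simple": SimpleCNN,
    "advanced": AdvancedCNN,
}


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train a CNN on CIFAR-10.")
    parser.add_argument(
        "--model",
        choices=sorted(MODEL_REGISTRY),
        default="simple",
        help="which architecture to train (default: simple)",
    )
    return parser.parse_args(argv)


def build_model(name: str) -> object:
    # Look up the constructor and instantiate the chosen architecture.
    return MODEL_REGISTRY[name]()


# Example: equivalent to running `python train_model.py --model advanced`.
model = build_model(parse_args(["--model", "advanced"]).model)
print(type(model).__name__)  # AdvancedCNN
```

With a registry, adding a third architecture means adding one dictionary entry instead of editing the training logic.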
Running train_model.py will:
- Load the CIFAR-10 dataset.
- Train the selected model.
- Save the trained model to a .pth file.
- Evaluate the model on the test set and print the accuracy.
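The accuracy printed in the last step is typically top-1 accuracy: the share of test images whose highest-scoring class matches the label. A minimal stdlib sketch of that computation, assuming the model outputs one list of class scores per image (top1_accuracy is a hypothetical helper, not a function from the project):

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the true label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = 0
    for scores, label in zip(logits, labels):
        # argmax over the class scores of one sample
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


# Three samples, two classified correctly.
acc = top1_accuracy([[0.1, 2.0], [1.5, 0.2], [0.3, 0.9]], [1, 0, 0])
print(f"Accuracy: {100 * acc:.2f}%")  # Accuracy: 66.67%
```

In PyTorch the same reduction is usually written with tensors, for example (outputs.argmax(dim=1) == labels).float().mean().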
To run the web application:
- Ensure that the model path set in services/prediction_service.py points to your trained .pth file.
- Run the Flask app:
  python controllers/prediction_controller.py
- Open your browser and navigate to http://127.0.0.1:5000/ to access the web interface.
The web interface allows users to classify images either by uploading a file or by providing an image URL.
- Provide an image URL:
  - Enter the image URL in the provided input field.
  - Click the "Classify" button to see the predicted class.
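Classifying by URL requires the service to download the image and convert it into the input format the network expects. The sketch below covers the conversion step with Pillow, assuming a 32x32 RGB input as in CIFAR-10 and a per-channel normalization with mean and standard deviation 0.5. preprocess_image_bytes and these constants are assumptions, so check services/prediction_service.py for the values the project actually uses.

```python
import io
from typing import List

from PIL import Image

# Assumed normalization constants; the project's actual values may differ.
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)


def preprocess_image_bytes(data: bytes, size: int = 32) -> List[List[List[float]]]:
    """Decode image bytes into a normalized channel-first (C, H, W) nested list."""
    image = Image.open(io.BytesIO(data)).convert("RGB").resize((size, size))
    pixels = list(image.getdata())  # row-major list of (R, G, B) tuples
    channels = []
    for c in range(3):
        # Scale to [0, 1], then normalize with the assumed mean/std.
        values = [(p[c] / 255.0 - MEAN[c]) / STD[c] for p in pixels]
        # Split the flat channel into `size` rows of `size` values each.
        channels.append([values[row * size:(row + 1) * size] for row in range(size)])
    return channels
```

In a URL flow, the bytes could come from requests.get(url, timeout=10).content, and torch.tensor(result).unsqueeze(0) turns the returned nested list into a batch of shape (1, 3, 32, 32).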
The train_model.py script can be run to train the model and evaluate it on the CIFAR-10 test set.
The project depends on the following pinned package versions:
torch==2.3.1
torchvision==0.18.1
numpy==1.26.4
Pillow==10.3.0
matplotlib==3.9.0
Werkzeug==3.0.3
Flask==3.0.3
requests==2.32.3
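Because the versions are pinned exactly, a quick way to spot a drifted environment is to compare the pins with what is installed. A stdlib sketch using importlib.metadata; parse_pins and check_pins are hypothetical helpers, and the parser only understands simple name==version lines.

```python
from importlib import metadata
from pathlib import Path
from typing import Dict, List, Tuple


def parse_pins(text: str) -> Dict[str, str]:
    """Parse simple `name==version` lines, ignoring blanks and comments."""
    pins: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()
        name, sep, version = line.partition("==")
        if sep:
            pins[name.strip()] = version.strip()
    return pins


def check_pins(pins: Dict[str, str]) -> List[Tuple[str, str, str]]:
    """Return (package, pinned, installed) for every package that does not match."""
    mismatches = []
    for name, pinned in pins.items():
        try:
            # Drop local version labels such as "+cu121" before comparing.
            installed = metadata.version(name).split("+", 1)[0]
        except metadata.PackageNotFoundError:
            installed = "missing"
        if installed != pinned:
            mismatches.append((name, pinned, installed))
    return mismatches


requirements = Path("requirements.txt")
if requirements.exists():
    for name, pinned, installed in check_pins(parse_pins(requirements.read_text())):
        print(f"{name}: pinned {pinned}, installed {installed}")
```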
- The CIFAR-10 dataset is used for training and evaluation.
- The project uses pre-trained models from torchvision.models.
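CIFAR-10 has ten classes, and a classifier trained on it outputs one score per class. The sketch below maps the highest score to a class name using the label order of torchvision.datasets.CIFAR10 (its classes attribute). decode_prediction is a hypothetical helper, and the project may keep its label list elsewhere.

```python
from typing import Sequence, Tuple

# Label order used by torchvision.datasets.CIFAR10.classes.
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def decode_prediction(scores: Sequence[float]) -> Tuple[int, str]:
    """Return (index, class name) of the highest of ten class scores."""
    if len(scores) != len(CIFAR10_CLASSES):
        raise ValueError(f"expected {len(CIFAR10_CLASSES)} scores, got {len(scores)}")
    index = max(range(len(scores)), key=scores.__getitem__)
    return index, CIFAR10_CLASSES[index]


print(decode_prediction([0.0, 0.1, 0.0, 3.2, 0.5, 1.1, 0.0, 0.2, 0.0, 0.0]))  # (3, 'cat')
```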