This PyTorch project builds and trains a ResNet-18 model on the CIFAR-10 dataset. Its features include data augmentation with the Albumentations library, a custom dataset loader, plots of the training and test loss curves, Grad-CAM visualizations of randomly sampled misclassified images, and plots of misclassified images with their labels and legends.
Follow the steps below to get started with the project:
- Clone the repository:
  git clone https://github.com/acharyasunil/cifar10_grad_cam.git
  cd cifar10_grad_cam
- Install the required packages:
  pip install -r requirements.txt
- Download and preprocess the CIFAR-10 dataset:
  Before training, download the CIFAR-10 dataset, for example with the torchvision library, which can also preprocess it. Apply data augmentation with Albumentations to increase the diversity of the training data.
- Train the model:
  To train the model on CIFAR-10 for 20 epochs, run:
  python main.py
  Alternatively, run the S11.ipynb notebook in Jupyter.
- Evaluate the model:
  After training, the model is evaluated on the test set automatically.
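What python main.py does in each epoch is not spelled out here. The following is a minimal sketch, assuming standard cross-entropy training; train_one_epoch and evaluate are hypothetical helpers, not the repository's functions. They return per-epoch losses and, for evaluation, a list of misclassified test samples of the kind the visualizations rely on.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader


def train_one_epoch(model: nn.Module, loader: DataLoader,
                    optimizer: torch.optim.Optimizer, device: str = "cpu") -> float:
    """Run one training epoch and return the mean per-sample training loss."""
    model.train()
    total_loss, seen = 0.0, 0
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(images), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * targets.size(0)  # undo the per-batch mean
        seen += targets.size(0)
    return total_loss / seen


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str = "cpu",
             max_misclassified: int = 20):
    """Return (mean test loss, accuracy, list of (image, true_label, predicted_label))."""
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    misclassified = []
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        logits = model(images)
        # Sum per-sample losses so the final mean is exact even if the last batch is smaller
        total_loss += F.cross_entropy(logits, targets, reduction="sum").item()
        preds = logits.argmax(dim=1)
        correct += (preds == targets).sum().item()
        seen += targets.size(0)
        # Keep a bounded number of wrong predictions for later plotting / Grad-CAM
        for img, t, p in zip(images, targets, preds):
            if t != p and len(misclassified) < max_misclassified:
                misclassified.append((img.cpu(), t.item(), p.item()))
    return total_loss / seen, correct / seen, misclassified
```

Calling both once per epoch and appending the two returned losses to lists gives the data for the training and test loss curves (for example with matplotlib's plt.plot).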
Data augmentation is performed with the Albumentations library: transforms such as RandomCrop and Cutout are applied to the training data to improve the model's generalization.
The project includes a custom dataset loader for CIFAR-10. The CustomCIFAR10 class extends the PyTorch Dataset class and lets you customize the data loading process.
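The repository's CustomCIFAR10 code is not shown here. Below is a hypothetical sketch of such a class, assuming it holds the images as a uint8 array of shape (N, 32, 32, 3) and applies an Albumentations-style transform (called as transform(image=...) and returning a dict with an "image" key). The from_torchvision helper is an illustrative addition that reads torchvision's CIFAR10.data and CIFAR10.targets attributes; the real class may differ.

```python
from typing import Callable, Optional, Sequence

import numpy as np
import torch
from torch.utils.data import Dataset


class CustomCIFAR10(Dataset):
    """Hypothetical CIFAR-10 dataset that applies an Albumentations-style transform."""

    def __init__(self, images: np.ndarray, labels: Sequence[int],
                 transform: Optional[Callable] = None):
        # images: uint8 array of shape (N, 32, 32, 3); labels: N class indices
        if len(images) != len(labels):
            raise ValueError("images and labels must have the same length")
        self.images = images
        self.labels = list(labels)
        self.transform = transform

    @classmethod
    def from_torchvision(cls, root: str, train: bool = True,
                         transform: Optional[Callable] = None, download: bool = True):
        # torchvision stores the raw HWC uint8 array in .data and the labels in .targets
        from torchvision.datasets import CIFAR10
        base = CIFAR10(root=root, train=train, download=download)
        return cls(base.data, base.targets, transform)

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform is not None:
            # Albumentations transforms take keyword arguments and return a dict
            image = self.transform(image=image)["image"]
        else:
            # Fallback: float CHW tensor scaled to [0, 1]
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        return image, label
```

Wrapped in a torch.utils.data.DataLoader with any A.Compose pipeline ending in ToTensorV2 as the transform, this yields image batches of shape (B, 3, 32, 32). Keeping the raw array in memory is practical because the CIFAR-10 training set is small (50,000 images of 32x32x3 bytes, roughly 150 MB).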
Grad-CAM highlights the image regions that contributed most to the model's prediction, which helps explain why an image was misclassified. The gradcam.py script applies Grad-CAM to randomly sampled misclassified images.
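How gradcam.py is implemented is not described here. The sketch below shows the Grad-CAM computation itself with plain PyTorch hooks: the gradients of the chosen class score with respect to a convolutional layer's activations are averaged over space to give one weight per channel, and the ReLU of the weighted sum of activation maps, upsampled to the input size, is the heatmap. The function name grad_cam and its arguments are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn


def grad_cam(model: nn.Module, target_layer: nn.Module, image: torch.Tensor,
             class_idx=None) -> torch.Tensor:
    """Return a Grad-CAM heatmap in [0, 1] with the same H x W as `image` (C, H, W)."""
    store = {}

    def forward_hook(_module, _inputs, output):
        store["activations"] = output
        # Capture d(score)/d(activations) when backward() reaches this tensor
        # (requires autograd to be enabled during the forward pass)
        output.register_hook(lambda grad: store.__setitem__("gradients", grad))

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        model.eval()
        logits = model(image.unsqueeze(0))
        if class_idx is None:
            class_idx = int(logits.argmax(dim=1))  # default: explain the predicted class
        model.zero_grad()
        logits[0, class_idx].backward()
    finally:
        handle.remove()  # always detach the hook so the model is left unchanged

    acts = store["activations"].detach()[0]  # (K, h, w)
    grads = store["gradients"][0]            # (K, h, w)
    weights = grads.mean(dim=(1, 2))         # global-average-pooled gradient per channel
    cam = F.relu((weights[:, None, None] * acts).sum(dim=0))  # keep positive evidence only
    cam = F.interpolate(cam[None, None], size=image.shape[1:], mode="bilinear",
                        align_corners=False)[0, 0]
    return cam / cam.max() if cam.max() > 0 else cam
```

For a ResNet-18 laid out like torchvision's, model.layer4[-1] is a common target layer. Passing a misclassified test image with class_idx set to its predicted label shows which regions drove the wrong prediction; the heatmap is usually overlaid semi-transparently on the de-normalized image.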
This PyTorch project provides a complete pipeline for training and evaluating deep learning models on CIFAR-10. Together, the data augmentation, custom dataset loading, loss curves, Grad-CAM visualizations, and misclassified-image plots give insight into the model's performance and show where it makes mistakes. Feel free to explore and modify the code, and to experiment with other model architectures and hyperparameters.