Wednesday, January 17, 2024

FastAI Course Part 1: Making an Anime Recognizer with PyTorch and FastAI

Beginning the FastAI Practical Deep Learning course

I have recently set out to complete the FastAI course "Practical Deep Learning", which covers topics ranging from image recognition to Stable Diffusion. The course is available as a free e-book, published as a series of Jupyter Notebooks.

In fact, both the course and the FastAI library itself are built on Jupyter Notebooks, a fitting environment for casual deep learning experimentation. However, this leads to some quirks in the FastAI library, such as its many monkey-patched classes, which make debugging difficult, and its reliance on "import *", which obscures which module each name comes from. But these annoyances are well worth it for FastAI's many premade, boilerplate-saving systems for loading and labeling data, training models, and previewing results in Jupyter Notebooks.

The course's first chapter introduces machine learning, model training, and convolutional neural networks (CNNs). At the end of the chapter, the course asks you to find your own application for CNN image recognizers to test your new knowledge. I opted to create an Anime Recognizer that uses a pre-trained ResNet34 CNN to distinguish between screenshots from 3 different anime series. While much more advanced anime recognition systems already exist online, I thought this would still be a good learning project on a fun subject.

Other students in the course have built impressive CNN applications using techniques such as:

  • Converting audio waveforms to images and recognizing sounds
  • Converting mouse movements to images and distinguishing human from bot users of websites
  • Recognizing cities from aerial photographs

Project plan

My assumption was that a CNN image classifier could identify a handful of anime shows based on their stylistic differences. For instance, a show like Samurai Champloo has different edges, gradients, and contours from the animation style of Blue Eye Samurai. My hope was that the CNN would learn to recognize these features and accurately guess which show a screenshot was taken from. The deeper layers of the CNN might also learn features that respond to the faces of recurring characters.

My implementation plan was as follows:

  1. Torrent video files of 3 shows
  2. Create a system to capture standardized screenshots from the video files using ffmpeg-python and Pillow
  3. Load the screenshots as a labeled training and validation data set
  4. Train both a blank (randomly initialized) and a pre-trained ResNet34 model on the training data
  5. Evaluate their performance
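
Step 3 is where FastAI saves a lot of boilerplate: as far as I know, a call along the lines of ImageDataLoaders.from_folder(path, valid_pct=0.2, seed=42, item_tfms=Resize(224)) labels each image by the name of its parent folder and holds out a random 20% for validation. To make that idea concrete without depending on FastAI, here is a minimal plain-Python sketch. The folder layout (screenshots/<show>/<file>.png), the placeholder show names, the helper name label_and_split, and the 20% split are all assumptions for illustration, not FastAI's actual implementation.

```python
import random
from pathlib import PurePosixPath


def label_and_split(paths, valid_pct=0.2, seed=42):
    """Label each path by its parent folder name and split into (train, valid).

    Hypothetical helper: mirrors the idea behind FastAI's folder-based
    labeling and random validation split, not its actual code.
    """
    items = sorted(paths)               # sort first so a given seed is reproducible
    random.Random(seed).shuffle(items)  # seeded shuffle -> same split every run
    n_valid = int(len(items) * valid_pct)

    def labelled(ps):
        # e.g. "screenshots/show_a/0001.png" -> label "show_a"
        return [(p, PurePosixPath(p).parent.name) for p in ps]

    return labelled(items[n_valid:]), labelled(items[:n_valid])


# Hypothetical layout: 3 shows x 600 screenshots each, as in this post
shows = ["show_a", "show_b", "show_c"]
paths = [f"screenshots/{s}/{i:04d}.png" for s in shows for i in range(600)]
train, valid = label_and_split(paths)
```

With 1800 screenshots this yields 1440 training items and 360 validation items, each paired with its show name.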

Creating a training data set

As with any AI project, the most difficult step of this process was acquiring and creating the training data for the CNN. I was able to find complete video files for all 3 shows online, and after some troubleshooting with ffmpeg, I was able to reliably capture square screenshots from the center of each video frame. I opted to take one screenshot every 10 seconds of video to ensure a variety of scenes. I also capped the number of screenshots taken from any single video file, so that each show's data drew on at least 6 different video files. Overall, my data set was 1800 images, or 600 per show.
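
The capture schedule can be sketched roughly as follows. The script described above used ffmpeg-python and Pillow; to keep this sketch self-contained it instead builds an argument list for the plain ffmpeg command line (the standard -ss, -i, -frames:v and -vf options), so treat it as an illustration rather than the original code. The cap of 100 screenshots per file is an assumption (600 images per show spread over at least 6 files), as are the 224-pixel output size, the file names, and the helper names. The crop=ih:ih filter keeps a square as tall as the frame, which ffmpeg centers by default; this assumes widescreen video.

```python
def screenshot_times(duration_s, interval_s=10, max_shots=100):
    """Seconds at which to grab a frame: one every interval_s, capped per file."""
    return list(range(0, int(duration_s), interval_s))[:max_shots]


def ffmpeg_screenshot_cmd(video, t, out, size=224):
    """Build an ffmpeg CLI command that saves one centered square frame at time t."""
    return [
        "ffmpeg", "-y",        # overwrite the output file if it already exists
        "-ss", str(t),         # seek to t seconds (placed before -i: fast input seek)
        "-i", video,
        "-frames:v", "1",      # write exactly one video frame
        # crop=ih:ih keeps a height x height square, centered by default
        # (assumes width >= height), then scale resizes it to size x size
        "-vf", f"crop=ih:ih,scale={size}:{size}",
        out,
    ]


# Hypothetical 24-minute episode: 144 candidate timestamps, capped at 100
times = screenshot_times(24 * 60)
cmds = [ffmpeg_screenshot_cmd("episode01.mkv", t, f"show_a/ep01_{t:05d}.png") for t in times]
```

Each command could then be run with subprocess.run(cmd, check=True). Keeping the schedule separate from the ffmpeg call makes it easy to check how many screenshots each file will contribute before extracting anything.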

Training the model

While training the model, I was interested in the difference between a "blank" (randomly initialized) ResNet34 CNN and a pre-trained one. (In this case, PyTorch's torchvision library provides a ResNet34 model pre-trained on the ImageNet-1K dataset.) I trained the blank model for 10 epochs and fine-tuned the pre-trained model for 5 epochs.
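
It's worth spelling out what "fine-tuned for 5 epochs" means in FastAI. Assuming the fine-tuning used learn.fine_tune(5) with its default arguments (base_lr=2e-3, freeze_epochs=1, lr_mult=100, as I read the fastai source; worth double-checking for your version), fine_tune first trains one epoch with the pretrained body frozen, then halves the learning rate, unfreezes everything, and trains 5 more epochs with discriminative learning rates. The sketch below only models that schedule in plain Python; it does not call FastAI, and the helper name and dictionary layout are made up for illustration.

```python
def fine_tune_plan(epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100):
    """Model the two phases of FastAI's Learner.fine_tune (default arguments).

    Plain-Python description only; it does not call FastAI.
    """
    unfrozen_lr = base_lr / 2  # fine_tune halves the base LR after the frozen phase
    return [
        # Phase 1: pretrained body frozen, only the new classification head trains
        {"frozen": True, "epochs": freeze_epochs, "head_lr": base_lr},
        # Phase 2: whole network unfrozen with discriminative learning rates,
        # from unfrozen_lr / lr_mult (earliest layers) up to unfrozen_lr (head)
        {"frozen": False, "epochs": epochs,
         "lr_range": (unfrozen_lr / lr_mult, unfrozen_lr)},
    ]


plan = fine_tune_plan(5)
total_epochs = sum(phase["epochs"] for phase in plan)  # 6 passes over the data, not 5
```

Under these defaults, the comparison is closer to 6 epochs for the fine-tuned model versus 10 for the blank one. For the blank model, a comparable FastAI setup would be something like vision_learner(dls, resnet34, pretrained=False) followed by learn.fit_one_cycle(10); that pairing is an illustrative assumption, not a record of the exact calls.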



Results

As you can see from the results in the screenshot below, the pretrained model (pretrained on IMAGENET1K_V2) outperformed the blank model, even though the blank model was trained for more epochs. This is likely due to the blank model overfitting on a relatively small data set of 1800 images, to transfer learning (the pretrained model reuses general visual features such as edges, textures, and shapes that it already learned from ImageNet), or to a combination of both.

The fine-tuned model achieved 95% accuracy after 5 epochs and the blank model achieved 81% accuracy after 10 epochs.
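
To put those percentages in concrete terms, here is a quick back-of-the-envelope calculation. It assumes a 20% validation split (FastAI's common valid_pct=0.2, which this post doesn't state) and treats the reported accuracies as exact, although they are probably rounded.

```python
def misclassified(n_valid, accuracy_pct):
    """Approximate number of wrong predictions implied by an accuracy percentage."""
    return round(n_valid * (100 - accuracy_pct) / 100)


n_valid = 1800 * 20 // 100                      # assumed 20% validation split -> 360 images
fine_tuned_errors = misclassified(n_valid, 95)  # 18
blank_errors = misclassified(n_valid, 81)       # 68
ratio = blank_errors / fine_tuned_errors        # roughly 3.8
```

Under these assumptions, the blank model gets about 68 of 360 validation screenshots wrong versus about 18 for the fine-tuned model, roughly 3.8 times as many mistakes.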