Flowers102 Dataset - Training with Transfer Learning + Batch Normalization in a CNN
Preface
I recently took an AI course. This is the fourth assignment, and the main topics covered are:
- Selecting a dataset and training a model on it.
- Transfer learning - fine-tuning.
- Batch normalization in CNNs.
The main references are the following websites:
- Flowers102 dataset
- Transfer Learning
- PyTorch datasets
- Transfer Learning models
- Shannon’s Transfer Learning blog
- ResNet18
Assignment Requirements
Tasks
- Choose a dataset: look at PyTorch's torchvision datasets and decide which dataset you want to use (excluding CIFAR, ImageNet, FashionMNIST).
- Print images and dataset sizes: show some sample images of the dataset in your notebook and print the size of the dataset.
- Construct a CNN using batch normalization: design a CNN to make predictions on the dataset. Use a similar architecture as last time, but this time also include batch normalization layers.
- Train a model on the dataset and print the test accuracy: train a model on the dataset and measure the accuracy on held-out test data.
- Use ResNet18 for transfer learning: now use transfer learning to apply a pre-trained ResNet18 to the dataset in two ways:
  - Without changing the pre-trained weights: ResNet18 is used as a fixed feature extractor.
  - Fine-tuning with ResNet18: ResNet18 is fine-tuned on the training data.
- Fine-tuning with EfficientNet_B5: repeat the previous step, but now use EfficientNet_B5 instead of ResNet18.
- Compare these different methods and print the accuracy: compare the accuracy of the different methods on the test data and print the training time for each method.
Task 0 - Importing Packages
Let's start by importing the required packages:

```python
# Task 0 - import packages
# CNN
```
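The import cell above is collapsed; a minimal sketch of the packages used throughout the rest of this post might look like the following (the `device` variable is an assumption that the later sketches reuse):

```python
# Minimal imports assumed by the sketches in this post
import time
import json

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```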
Task 1 - Selecting a Dataset
Select a dataset: check out PyTorch's torchvision datasets and decide on one dataset you want to use (no CIFAR, no ImageNet, no FashionMNIST).
To experience transfer learning and train quickly, we use Flowers102 as our dataset. Since Flowers102 doesn't ship with readable label names, most of what I found on the web are ready-made .json or .txt files that describe each label index.
```python
# Specify the dataset to download, the path, and the batch size (how much data to train on at once).
```
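A sketch of what such a configuration might look like (the path and batch size here are placeholder assumptions, not necessarily the values used in the original notebook):

```python
# Hypothetical configuration values for the sketches below
data_dir = "./data/flowers-102"   # where torchvision will download Flowers102
batch_size = 32                   # how many images are processed per training step
num_classes = 102                 # Flowers102 has 102 flower categories
```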
```python
# Build the classes_name of the dataset
```
Result
```
{0: 'pink primrose', 1: 'hard-leaved pocket orchid', 2: 'canterbury bells', 3: 'sweet pea', 4: 'english marigold', 5: 'tiger lily', 6: 'moon orchid', 7: 'bird of paradise', 8: 'monkshood', 9: 'globe thistle', 10: 'snapdragon', 11: "colt's foot", 12: 'king protea', 13: 'spear thistle', 14: 'yellow iris', 15: 'globe-flower', 16: 'purple coneflower', 17: 'peruvian lily', 18: 'balloon flower', 19: 'giant white arum lily', 20: 'fire lily', 21: 'pincushion flower', 22: 'fritillary', 23: 'red ginger', 24: 'grape hyacinth', 25: 'corn poppy', 26: 'prince of wales feathers', 27: 'stemless gentian', 28: 'artichoke', 29: 'sweet william', 30: 'carnation', 31: 'garden phlox', 32: 'love in the mist', 33: 'mexican aster', 34: 'alpine sea holly', 35: 'ruby-lipped cattleya', 36: 'cape flower', 37: 'great masterwort', 38: 'siam tulip', 39: 'lenten rose', 40: 'barbeton daisy', 41: 'daffodil', 42: 'sword lily', 43: 'poinsettia', 44: 'bolero deep blue', 45: 'wallflower', 46: 'marigold', 47: 'buttercup', 48: 'oxeye daisy', 49: 'common dandelion', 50: 'petunia', 51: 'wild pansy', 52: 'primula', 53: 'sunflower', 54: 'pelargonium', 55: 'bishop of llandaff', 56: 'gaura', 57: 'geranium', 58: 'orange dahlia', 59: 'pink-yellow dahlia', 60: 'cautleya spicata', 61: 'japanese anemone', 62: 'black-eyed susan', 63: 'silverbush', 64: 'californian poppy', 65: 'osteospermum', 66: 'spring crocus', 67: 'bearded iris', 68: 'windflower', 69: 'tree poppy', 70: 'gazania', 71: 'azalea', 72: 'water lily', 73: 'rose', 74: 'thorn apple', 75: 'morning glory', 76: 'passion flower', 77: 'lotus lotus', 78: 'toad lily', 79: 'anthurium', 80: 'frangipani', 81: 'clematis', 82: 'hibiscus', 83: 'columbine', 84: 'desert-rose', 85: 'tree mallow', 86: 'magnolia', 87: 'cyclamen', 88: 'watercress', 89: 'canna lily', 90: 'hippeastrum', 91: 'bee balm', 92: 'ball moss', 93: 'foxglove', 94: 'bougainvillea', 95: 'camellia', 96: 'mallow', 97: 'mexican petunia', 98: 'bromelia', 99: 'blanket flower', 100: 'trumpet creeper', 101: 'blackberry lily'}
```
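A sketch of how such a mapping can be built from one of those label files. The filename `cat_to_name.json` is a hypothetical example; the commonly shared files map 1-based label indices to flower names, so they are shifted to 0-based here:

```python
# Build a {0-based index: flower name} dictionary from a label file found online
import json

with open("cat_to_name.json", "r", encoding="utf-8") as f:
    cat_to_name = json.load(f)   # e.g. {"1": "pink primrose", "2": "hard-leaved pocket orchid", ...}

classes_name = {int(k) - 1: v for k, v in cat_to_name.items()}
print(classes_name)
```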
Here I mainly follow the official PyTorch Transfer Learning tutorial, adapting its code to the dataset I want, and start downloading the files:

```python
# Data augmentation and normalization for training
```
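A sketch of the data pipeline, adapted from the official Transfer Learning tutorial; the transform values follow the tutorial, while the variable names and DataLoader settings are assumptions (this reuses the imports and configuration from the sketches above):

```python
# Data augmentation and normalization for training; only normalization for testing
data_transforms = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    "test": transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# torchvision can download Flowers102 directly; split is "train", "val", or "test"
train_set = datasets.Flowers102(root=data_dir, split="train",
                                transform=data_transforms["train"], download=True)
test_set = datasets.Flowers102(root=data_dir, split="test",
                               transform=data_transforms["test"], download=True)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Training set size: {len(train_set)}, test set size: {len(test_set)}")
```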
This completes the first Task, which is to download the dataset we want.
Task 2 - Printing Images and Dataset Sizes
- Print images and dataset sizes: display some sample images of the dataset in the notebook and print the dataset size.
Following the official Transfer Learning tutorial, we first create an imshow helper:

```python
def imshow(inp, title=None):
```
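A sketch of that helper, following the official tutorial: it un-normalizes a tensor image, displays it, and is then used to show one batch of training images with their class names (it assumes `train_loader` and `classes_name` from above):

```python
def imshow(inp, title=None):
    """Display a (C, H, W) tensor as an image."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean        # undo the normalization
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

# Show a few training images with their labels
inputs, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(inputs[:4])
imshow(grid, title=[classes_name[l.item()] for l in labels[:4]])
```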
Result
Task 3 & 4 - CNN + Batch Normalization
- Construct a CNN using batch normalization: design a CNN to predict on the dataset. Use a similar architecture to last time, but this time also include batch normalization layers.
- Train the model on the dataset and measure the accuracy on held-out test data.
According to what Prof. Hung-yi Lee mentioned in his lecture, batch normalization is usually performed before the activation function; you can refer to that part of the lecture if you are interested. Batch normalization is simply feature normalization performed per batch.
Why do we need feature normalization?
It gives different features similar value ranges, so that when the model performs gradient descent, weights such as w1 and w2 affect the loss evenly instead of one dominating the other. Without normalization, a feature with a much larger range can make w1 influence the loss far more than w2; with it, the effect looks something like the following.
(Figure: loss contours before and after feature normalization)
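Concretely, for each feature (channel), batch normalization standardizes the values using the statistics of the current mini-batch and then applies a learnable scale and shift:

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y_i = \gamma \hat{x}_i + \beta$$

where $\mu_B$ and $\sigma_B^2$ are the mean and variance computed over the mini-batch, $\epsilon$ is a small constant for numerical stability, and $\gamma$, $\beta$ are learned parameters.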
Build the Network
Note that depending on your dataset and the number of hidden layers, you have to make two adjustments:
- In the fully connected layer, the input size is determined by how many times your hidden layers perform max-pooling and convolution.
- You also have to adjust the number of outputs in the final layer to match the number of categories in your dataset.
Please note the parts of the code marked with the comment arrow <====.
Our CNN architecture is as follows. You can decide whether to enable dropout by uncommenting those lines; in my case, dropout did not result in higher accuracy.
```python
import torch.nn as nn
```
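The collapsed block above contains the network definition; a sketch of a CNN in this spirit, with batch normalization placed before each activation, looks like the following. The layer sizes are illustrative assumptions (224x224 RGB inputs, three conv blocks), not the exact architecture from the assignment, and `device` comes from the earlier sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self, num_classes=102):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)              # batch norm before the activation
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)
        # self.dropout = nn.Dropout(0.5)           # optional; did not improve accuracy here
        # 224 -> 112 -> 56 -> 28 after three poolings  <==== adjust if the input size changes
        self.fc1 = nn.Linear(64 * 28 * 28, 256)
        self.fc2 = nn.Linear(256, num_classes)      # <==== adjust to the number of categories

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = torch.flatten(x, 1)
        # x = self.dropout(x)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

model = CNN(num_classes=102).to(device)
```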
Then we specify the optimizer and loss function:

```python
import torch.optim as optim
```
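A minimal sketch of this step, assuming cross-entropy loss and SGD (the exact optimizer and hyperparameters in the original notebook may differ):

```python
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
```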
Create a Training Func
We need to create a function that runs the training loop, as follows:

```python
def train(epoch, start_time):
```
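A sketch of what this function might look like, assuming the `model`, `criterion`, `optimizer`, `train_loader`, and `device` defined above; it logs progress every 100 batches:

```python
def train(epoch, start_time):
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if (batch_idx + 1) % 100 == 0:
            elapsed = time.time() - start_time
            print(f"Epoch {epoch} | batch {batch_idx + 1} | "
                  f"loss {running_loss / 100:.4f} | {elapsed:.0f}s elapsed")
            running_loss = 0.0
```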
Build Testing Func
```python
def test():
```
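A sketch of the evaluation function: top-1 accuracy over the held-out test set, assuming the `test_loader` defined above:

```python
def test():
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100.0 * correct / total
    print(f"Accuracy on test data (top-1): {acc:.4f}%")
    return acc
```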
Run Training
So that we can monitor progress during training, we print the training status every 100 batches and evaluate on the test set every 5 epochs.

```python
num_epochs = 100
```
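A sketch of the training loop described above:

```python
num_epochs = 100
start_time = time.time()
for epoch in range(num_epochs):
    train(epoch, start_time)
    if (epoch + 1) % 5 == 0:
        test()
```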
Result
```
Accuracy on test data (top-1): 0.0019%
```
Task 5 & 4 - Transfer Learning: ResNet18
- Train a model on the dataset and print the test accuracy: train a model on the dataset and measure the accuracy on held-out test data.
- Use ResNet18 for transfer learning: now use transfer learning to apply a pre-trained ResNet18 to the dataset in two ways:
  - Fixed feature extractor: without changing the pre-trained weights, fix the parameters of the ResNet18 model and only train the last layer.
  - Fine-tuning: fine-tune the ResNet18 model.
Build up Training & Testing Func
Following the official Transfer Learning tutorial, we first create the training and testing function:

```python
# Set the start time for logging to monitor the training time for each epoch
```
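A condensed sketch of a train_model function in the style of the official Transfer Learning tutorial: it trains for a number of epochs, steps the learning rate scheduler, and logs the elapsed time. It assumes the `train_loader` and `device` defined earlier:

```python
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    # Set the start time for logging to monitor the training time for each epoch
    since = time.time()
    for epoch in range(num_epochs):
        print(f"Epoch {epoch}/{num_epochs - 1}")
        model.train()
        running_loss, running_corrects = 0.0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (outputs.argmax(1) == labels).sum().item()
        scheduler.step()
        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects / len(train_loader.dataset)
        print(f"train loss: {epoch_loss:.4f} acc: {epoch_acc:.4f} "
              f"({time.time() - since:.0f}s elapsed)")
    return model
```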
Use Transfer Learning
As the assignment requires, I want to use ResNet18 for transfer learning. According to the [official description](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet18.html#torchvision.models.ResNet18_Weights), the default pre-trained weights for resnet18 are IMAGENET1K_V1; to make it explicit which weights we are using, we still pass the parameter.

```python
model_ft = models.resnet18(weights='IMAGENET1K_V1')
```
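A sketch of the rest of the fine-tuning setup, following the official tutorial (the hyperparameter values below are the tutorial's defaults, an assumption about the original notebook):

```python
model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 102)   # 102 flower categories
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
# All parameters are being optimized (full fine-tuning)
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Decay the learning rate by a factor of 0.1 every 7 epochs
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
```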
Why do we need to adjust the learning rate?
Adjusting the learning rate every certain number of epochs is a common strategy called learning rate decay or learning rate scheduling. Its effects are:
- Improved stability: using a relatively large learning rate at the beginning helps the model converge quickly, but near the optimum a large learning rate may cause the model to oscillate or overshoot. By periodically decreasing the learning rate, the model becomes more stable and gets closer to the optimum in the later stages of training.
- Preventing overfitting: periodically decreasing the learning rate helps prevent the model from overfitting the training set. With a smaller learning rate, the model adjusts its parameters more carefully and is less likely to fit the noise in the training data.
In practice, the specific settings of the learning rate schedule (e.g., the values of step_size and gamma) are usually tuned by trial and experience. They typically depend on the size of your dataset, the model architecture, the difficulty of the problem, and other factors.
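As a standalone illustration of what a schedule like StepLR(step_size=7, gamma=0.1) does (this is not part of the training code, just a quick demonstration):

```python
import torch
from torch import optim

# Dummy parameter and optimizer just to show how the learning rate decays
params = [torch.zeros(1, requires_grad=True)]
opt = optim.SGD(params, lr=0.001)
sched = optim.lr_scheduler.StepLR(opt, step_size=7, gamma=0.1)

for epoch in range(21):
    opt.step()       # ... a real training step would go here ...
    sched.step()
    if (epoch + 1) % 7 == 0:
        print(f"after epoch {epoch}: lr = {sched.get_last_lr()[0]:.7f}")
# after epoch 6:  lr = 0.0001000
# after epoch 13: lr = 0.0000100
# after epoch 20: lr = 0.0000010
```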
Start Training
```python
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
```
Result: Accuracy on test data (top-1): 89.41%
```
Epoch 0/24
```
Using ResNet18 as a fixed feature extractor
The homework requires using ResNet18 as a fixed feature extractor, so we make all parameters untrainable and train only the last layer. Simply put, we don't change the weights of the pre-trained model: we set requires_grad to False on every existing parameter, and only the newly added final layer gets optimized. This way, ResNet18 acts as a fixed feature extractor.

```python
model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
```
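A sketch of the fixed-feature-extractor setup in the style of the official tutorial: freeze all pre-trained parameters, then replace the final layer (whose new parameters are trainable by default) and optimize only that layer. Hyperparameter values are the tutorial's defaults, an assumption here:

```python
model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():
    param.requires_grad = False            # freeze the pre-trained weights

num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 102)   # new layer: requires_grad=True by default
model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()
# Only the parameters of the final layer are being optimized
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler,
                         num_epochs=25)
```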
Result: Accuracy on test data: 79.11%
```
Epoch 0/24
```
Task 6 & 4 - Transfer Learning: EfficientNet_B5
- Train a model on the dataset and print the test accuracy: train a model on the dataset and measure the accuracy on held-out test data.
- Fine-tuning with EfficientNet_B5: repeat the previous step, but now fine-tune EfficientNet_B5 on the training data instead of ResNet18.
We need to install the efficientnet_pytorch package first:

```bash
pip install efficientnet_pytorch
```
Then we can import the package and use it:

```python
from efficientnet_pytorch import EfficientNet
```
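A sketch of the EfficientNet-B5 fine-tuning setup, reusing the train_model function and settings from the ResNet18 section. Reusing the 224x224 pipeline and replacing the classification head via `from_pretrained(..., num_classes=102)` are assumptions about the original notebook:

```python
from efficientnet_pytorch import EfficientNet

# Load ImageNet-pretrained EfficientNet-B5 with a new 102-class head
model_eff = EfficientNet.from_pretrained('efficientnet-b5', num_classes=102)
model_eff = model_eff.to(device)

criterion = nn.CrossEntropyLoss()
optimizer_eff = optim.SGD(model_eff.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_eff, step_size=7, gamma=0.1)

model_eff = train_model(model_eff, criterion, optimizer_eff, exp_lr_scheduler,
                        num_epochs=25)
```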
Result: Accuracy on test data: 82.15%
```
Epoch 0/24
```
Task 7 - Discussion
Compare these different methods and print out the accuracy: Compare the accuracy of the different methods on the test data and print out the training time for each method.
From the results of the experiments above, we can compare the following models:
- The self-built CNN (Task 3 & 4)
- Transfer learning with ResNet18 (Task 5 & 4)
- Transfer learning with EfficientNet-B5 (Task 6 & 4)
Their results are as follows:

Model | Accuracy | Training Time | Remark
---|---|---|---
Self-built CNN | 35% | more than 1 hour | Worst accuracy
ResNet18 (fine-tuned) | 89.41% | 21 minutes | Highest accuracy
ResNet18 (fixed feature extractor) | 79.11% | 19 minutes | Shortest training time
EfficientNet-B5 | 82.15% | 80 minutes | So-so
Conclusion
- With transfer learning, the accuracy improves significantly and the training time is greatly reduced.
- Moreover, in this case, accuracy is higher when the pre-trained parameters are not frozen (full fine-tuning), although training takes longer because gradients are computed for every layer.