Flower102 Dataset - 使用 Transfer Learning 訓練 + 使用 Batch Normalization 於 CNN
前言
最近選了一堂AI課程,這是第四個作業,主要教授內容為以下主題:
- Pick a dataset and train a model on it.
- Transfer Learning - Fine Tuning.
- Batch Normalization in CNN.
主要參考以下網站:
- Flower102 Dataset
- Transfer Learning
- DataSet of Pytorch
- Models for transfer learning
- Shannon’s Blog of Transfer Learning
- Resnet18
作業要求
Task:
- 選擇一個DataSet: Check out the torchvision DataSet of Pytorch and decide one dataset that you want to use (no
CIFAR, no ImageNet, no FashionMNIST). - 印出圖片和資料大小:Show some example images of the dataset in the notebook and print the dataset size.
- 建構使用Batch Normalization的CNN:Design a CNN to predict on the dataset. Use a similar architecture like last time, but this time
also include batch normalization layers. - 使用dataset訓練模型並印出Testing的準確率:Train the model on the dataset and measure the accuracy on hold out test data.
- 使用ResNet18來進行Transfer-Learning:Now use transfer learning to use a pre-trained ResNet18 on the dataset as follows:
- 不改變別人模型訓練好的權重:ResNet18 as fixed feature extractor.
- 使用RestNet進行Fineturned:ResNet18 finetuned on the training data.
- 使用EfficientNet_B5進行Fineturned:Repeat step 4 but now use EfficientNet_B5 instead of RestNet18.
- 比較這些不同的方法,並列印出準確度:Compare the accuracy of the different approaches on the test data and print out the training
times for each approach.
Task 0 - import package
先來導入所需的套件:
1 | # CNN |
Task 1 - 選擇一個DataSet
選擇一個DataSet: Check out the torchvision DataSet of Pytorch and decide one dataset that you want to use (no
CIFAR, no ImageNet, no FashionMNIST).
為了體驗 Transfer Learning,並且快速訓練。我們這邊使用 flower102 來作為我們的資料集。因為 flower102 沒有提供中文的 Label,我網路上找大部分都是讀取已經寫好的 .json
或是 .txt
檔案,該檔案會描述每一個 label index 對應的中文。
1 | # 指定你要下載的資料及路徑 和 btach size 一次訓練的量 |
1 | # 建立 dataset 的 classes_name |
結果如下
1 | {0: 'pink primrose', 1: 'hard-leaved pocket orchid', 2: 'canterbury bells', 3: 'sweet pea', 4: 'english marigold', 5: 'tiger lily', 6: 'moon orchid', 7: 'bird of paradise', 8: 'monkshood', 9: 'globe thistle', 10: 'snapdragon', 11: "colt's foot", 12: 'king protea', 13: 'spear thistle', 14: 'yellow iris', 15: 'globe-flower', 16: 'purple coneflower', 17: 'peruvian lily', 18: 'balloon flower', 19: 'giant white arum lily', 20: 'fire lily', 21: 'pincushion flower', 22: 'fritillary', 23: 'red ginger', 24: 'grape hyacinth', 25: 'corn poppy', 26: 'prince of wales feathers', 27: 'stemless gentian', 28: 'artichoke', 29: 'sweet william', 30: 'carnation', 31: 'garden phlox', 32: 'love in the mist', 33: 'mexican aster', 34: 'alpine sea holly', 35: 'ruby-lipped cattleya', 36: 'cape flower', 37: 'great masterwort', 38: 'siam tulip', 39: 'lenten rose', 40: 'barbeton daisy', 41: 'daffodil', 42: 'sword lily', 43: 'poinsettia', 44: 'bolero deep blue', 45: 'wallflower', 46: 'marigold', 47: 'buttercup', 48: 'oxeye daisy', 49: 'common dandelion', 50: 'petunia', 51: 'wild pansy', 52: 'primula', 53: 'sunflower', 54: 'pelargonium', 55: 'bishop of llandaff', 56: 'gaura', 57: 'geranium', 58: 'orange dahlia', 59: 'pink-yellow dahlia', 60: 'cautleya spicata', 61: 'japanese anemone', 62: 'black-eyed susan', 63: 'silverbush', 64: 'californian poppy', 65: 'osteospermum', 66: 'spring crocus', 67: 'bearded iris', 68: 'windflower', 69: 'tree poppy', 70: 'gazania', 71: 'azalea', 72: 'water lily', 73: 'rose', 74: 'thorn apple', 75: 'morning glory', 76: 'passion flower', 77: 'lotus lotus', 78: 'toad lily', 79: 'anthurium', 80: 'frangipani', 81: 'clematis', 82: 'hibiscus', 83: 'columbine', 84: 'desert-rose', 85: 'tree mallow', 86: 'magnolia', 87: 'cyclamen', 88: 'watercress', 89: 'canna lily', 90: 'hippeastrum', 91: 'bee balm', 92: 'ball moss', 93: 'foxglove', 94: 'bougainvillea', 95: 'camellia', 96: 'mallow', 97: 'mexican petunia', 98: 'bromelia', 99: 'blanket flower', 100: 'trumpet creeper', 101: 'blackberry lily'} |
這邊我主要是參考官方網站Transfer Learning的寫法,改成自己想要的 dataSet,並開始下載檔案:
1 | # Data augmentation and normalization for training |
這樣我們就完成了第一個Task,也就是下載好我們想要的 dataset。
Task 2 - 印出圖片和資料大小
- 印出圖片和資料大小:Show some example images of the dataset in the notebook and print the dataset size.
參考官方網站Transfer Learning的寫法,我們先建立
1 | def imshow(inp, title=None): |
結果如下:
Task 3 & 4 - CNN + Batch Normalization
- 建構使用Batch Normalization的CNN:Design a CNN to predict on the dataset. Use a similar architecture like last time, but this time also include batch normalization layers.
- 使用dataset訓練模型並印出Testing的準確率:Train the model on the dataset and measure the accuracy on hold out test data.
根據李鴻毅教授在 Transfer Learning 提到…
通常會在 Activation Function
之前執行 Batch Normalization
,有興趣可以參考這個章節。Batch Normalization
簡單來說就是以 Batch
的方式,執行 feature normalization
。
為什麼要做 feature normalization ?
他就是為了讓不同的 feature 有類似接近的數值範圍,這樣模型在執行Gradient Descent的時候,w1, w2 對 loss 的影響才不會太大,他們擁有相似的數值範圍,才能夠平均的影響 loss,而不是某個 w1 對 loss 的影響遠大於 w2。
大概是下圖這種效果。
建立 Network
請注意,根據不同的 dataset 其尺寸大小還有 hidden layer的數量,你要做兩個調整!!
- 在 fully connection layer 中,input 要根據你的 hidden layer 執行
max-pooling
跟convolution
的次數來決定。 - 然後你要根據你 dataset 的 categories 數量,調整最後一層 output layer 的output數量。
請注意程式碼中標示註解箭頭<====
的部分
所以我們這邊的 CNN 架構如下,你可以根據需求決定是否要執行 dropout,來解開註解
。
但是在我的情境中,我測試 dropout 並沒有帶來比較高的準確率。
1 | import torch.nn as nn |
並且指定 optimizer 和 loss function:
1 | import torch.optim as optim |
建立 Training Func
我們需要建立一個fucntion來執行訓練模型的動作如下:
1 | def train(epoch, start_time): |
建立 Testing Func
1 | def test(): |
執行 Training
為了讓訓練過程中,我們可以看到訓練的狀況,所以我們每 100 個 batch 就印出一次訓練的狀況,並且每 5 個 epoch 就印出一次 test 的狀況。
1 | num_epochs = 100 |
結果如下
1 | Accuracy on test data (top-1): 0.0019% |
Task 5 & 4 - Transfer Learning:Resnet18
- 使用dataset訓練模型並印出Testing的準確率:Train the model on the dataset and measure the accuracy on hold out test data.
- 使用ResNet18來進行Transfer-Learning:Now use transfer learning to use a pre-trained ResNet18 on the dataset as follows:
- 把參數fixed:ResNet18 as fixed feature extractor.
- 使用RestNet進行Fineturned:ResNet18 finetuned on the training data.
建立 Trainning & Testing Func
根據官方的範例Transfer Learning,我是直接複製過來的。
1 | def train_model(model, criterion, optimizer, scheduler, num_epochs=25): |
使用 Transfer Learning
根據老師的要求,要使用 resnet18 來進行 Transfer Learning,目前根據官方說明,resent18
如果不給予參數,則預設就是 IMAGENET1K_V1
,為了清楚我們到底使用哪一個 model 的參數,我們還是給予參數。
1 | model_ft = models.resnet18(weights='IMAGENET1K_V1') |
大概是這種感覺來進行Transfer Learning
為什麼要調整lr?
將學習率每隔一定的 epoch 進行調整是一種常見的學習率調整策略,稱為學習率衰減(learning rate decay)或學習率調度(learning rate scheduling)。這樣的效果是:
-
提高模型的穩定性:在訓練過程中,
一開始使用相對較大的學習率,有助於快速收斂
。但當訓練靠近最佳解時,較大的學習率可能導致模型在最佳解附近震盪或過度調整
。透過週期性地降低學習率,模型在訓練的後期會更穩定,更接近最佳解。 -
防止過度擬合:
週期性地降低學習率有助於防止模型在訓練集上過度擬合
。當學習率降低時,模型更謹慎地調整參數,不太容易陷入訓練集中的噪聲。
在實際應用中,學習率調整策略的具體設置(例如,step_size
和 gamma
的值)通常是根據試驗和經驗來調整的,以達到最佳性能。通常,這些參數的設置取決於你的數據集大小、模型架構、問題的難度和其他因素。
開始訓練
1 | model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, |
結果如下: 準確率 89.41% 挺好的
1 | Epoch 0/24 |
使用 ResNet18 作為 fixed feature extractor
因為作業有要求,要使用 ResNet18 作為 fixed feature extractor,所以我們要把所有的參數都設定為不可訓練,只有最後一層的參數是可以訓練的,簡單來說就是別人訓練好的 model 你就不要改人家的 weight 了拉。要改的地方就是,把 model 的每個 parameters 的 requires_grad
都設定為 False。這樣,我們就可以把 ResNet18 當作 fixed feature extractor 來使用。
1 | # 其他都老樣子 |
結果如下:準確率 79.11% 比較差一點,但是這樣的訓練速度會比較快
1 | Epoch 0/24 |
Task 6 & 4 - Transfer Learning:EfficientNet_B5
- 使用dataset訓練模型並印出Testing的準確率:Train the model on the dataset and measure the accuracy on hold out test data.
- 使用EfficientNet_B5進行Fineturned:Repeat step 4 but now use EfficientNet_B5 instead of RestNet18.
接下來,我們需要把 RestNet18 根據題目要求換成別的訓練好的模型,你可能會需要先透過 pip 安裝 efficientnet_pytorch
:
1 | pip install efficientnet_pytorch |
然後再執行下面的程式碼:
1 | from efficientnet_pytorch import EfficientNet |
結果如下:準確率 73.33% 好像沒有比較好。
1 | Epoch 0/24 |
Task 7 - 討論
- 比較這些不同的方法,並列印出準確度:Compare the accuracy of the different approaches on the test data and print out the training
從前面開始,我們測試了幾個方法:
他們的數據大概如下:
Model | Accuracy | Training Time | Result |
---|---|---|---|
自建 CNN | 35% |
超過1小時 |
最差 |
Resnet18 | 89.41% | 21分鐘 | 準確率最高 |
Resnet18 (fixed feature extractor) | 79.11% | 19分鐘 | 時間最短 |
EfficientNet-B5 | 82.15% | 80分鐘 | 普普 |
結論
- 如果使用 Transfer Learning 可明顯感受到,準確率明顯提高,並且訓練時間大幅縮短。
- 再者,以目前的案例來說,不要fixed model 的參數,準確率比較好,雖然相對的時間也會比較長,因為要做gradient descent。