大家好我們今天來講一講如何用Matlab做一個新的遷移學習您可能需要的基礎知識
Matlab編程Deep learning的基礎知識
一、什么是遷移學習?
以圖像識別為例。如果你想構建一個神經網絡,讓它能夠識別馬匹,但是手上又沒有任何公開的算法可以完成這項任務。這時,借助遷移學習,你可以從一個原本是用來識別其它動物的現成的卷積神經網絡(CNN)入手,對其進行調整并訓練它識別馬匹。深度學習應用中常常用到遷移學習??梢圆捎妙A訓練的網絡,基于它學習新任務。與使用隨機初始化的權重從頭訓練網絡相比,通過遷移學習微調網絡要更快更簡單。我們可以使用較少數量的訓練圖像快速地將已學習的特征遷移到新任務。
二、網絡的創建和數據的導入
加載數據
解壓縮新圖像并加載這些圖像作為圖像數據存儲。imageDatastore 根據文件夾名稱自動標注圖像,并將數據存儲為 ImageDatastore 對象。通過圖像數據存儲可以存儲大圖像數據,包括無法放入內存的數據,并在卷積神經網絡的訓練過程中高效分批讀取圖像。
unzip(‘MerchData.zip’);imds = imageDatastore(‘MerchData’, 。.. ‘IncludeSubfolders’,true, 。.. ‘LabelSource’,‘foldernames’);
將數據劃分為訓練數據集和驗證數據集。將 70% 的圖像用于訓練,30% 的圖像用于驗證。splitEachLabel 將 images 數據存儲拆分為兩個新的數據存儲。
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,‘randomized’);
這個非常小的數據集現在包含 55 個訓練圖像和 20 個驗證圖像。
numTrainImages = numel(imdsTrain.Labels);idx = randperm(numTrainImages,16);figurefor i = 1:16 subplot(4,4,i) I = readimage(imdsTrain,idx(i)); imshow(I)end
加載預訓練網絡
加載預訓練的 AlexNet 神經網絡。如果未安裝 Deep Learning Toolbox Model for AlexNet Network,則軟件會提供下載鏈接。AlexNet 已基于超過一百萬個圖像進行訓練,可以將圖像分為 1000 個對象類別(例如鍵盤、鼠標、鉛筆和多種動物)。因此,該模型已基于大量圖像學習了豐富的特征表示。
net = alexnet;
使用 analyzeNetwork 可以交互可視方式呈現網絡架構以及有關網絡層的詳細信息。
analyzeNetwork(net)
第一層(圖像輸入層)需要大小為 227×227×3 的輸入圖像
其中 3 是顏色通道數
inputSize = 1×3 227 227 3
三、網絡的訓練
替換最終層
預訓練網絡 net 的最后三層針對 1000 個類進行配置。必須針對新分類問題微調這三個層。從預訓練網絡中提取除最后三層之外的所有層。
layersTransfer = net.Layers(1:end-3);
通過將最后三層替換為全連接層、softmax 層和分類輸出層,將層遷移到新分類任務。根據新數據指定新的全連接層的選項。將全連接層設置為大小與新數據中的類數相同。要使新層中的學習速度快于遷移的層,請增大全連接層的 WeightLearnRateFactor 和 BiasLearnRateFactor 值。
numClasses = numel(categories(imdsTrain.Labels))numClasses = 5
layers = [ layersTransfer fullyConnectedLayer(numClasses,‘WeightLearnRateFactor’,20,‘BiasLearnRateFactor’,20) softmaxLayer classificationLayer];
訓練網絡
網絡要求輸入圖像的大小為 227×227×3,但圖像數據存儲中的圖像具有不同大小。使用增強的圖像數據存儲可自動調整訓練圖像的大小。指定要對訓練圖像額外執行的增強操作:沿垂直軸隨機翻轉訓練圖像,以及在水平和垂直方向上隨機平移訓練圖像最多 30 個像素。數據增強有助于防止網絡過擬合和記憶訓練圖像的具體細節。
pixelRange = [-30 30];imageAugmenter = imageDataAugmenter( 。.. ‘RandXReflection’,true, 。.. ‘RandXTranslation’,pixelRange, 。.. ‘RandYTranslation’,pixelRange);augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, 。.. ‘DataAugmentation’,imageAugmenter);
對驗證圖像進行分類
使用經過微調的網絡對驗證圖像進行分類
[YPred,scores] = classify(netTransfer,augimdsValidation);
顯示四個示例驗證圖像及預測的標簽。
idx = randperm(numel(imdsValidation.Files),4);figurefor i = 1:4 subplot(2,2,i) I = readimage(imdsValidation,idx(i)); imshow(I) label = YPred(idx(i)); title(string(label));end
計算針對驗證集的分類準確度。準確度是網絡預測正確的標簽的比例
YValidation = imdsValidation.Labels;accuracy = mean(YPred == YValidation)
accuracy = 1
今天你學廢了嗎???
編輯:lyn
-
matlab
+關注
關注
185文章
2977瀏覽量
230601 -
神經網絡
+關注
關注
42文章
4772瀏覽量
100855 -
圖像識別
+關注
關注
9文章
520瀏覽量
38282 -
遷移學習
+關注
關注
0文章
74瀏覽量
5566
原文標題:【圖像識別】基于Matlab的遷移學習的圖像分類案例
文章出處:【微信號:vision263com,微信公眾號:新機器視覺】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論