如果你是攝影愛好者,你可能對濾鏡并不陌生。它可以改變照片的色彩風格,使風景照片變得更清晰或肖像照片皮膚變白。但是,一個濾鏡通常只會改變照片的一個方面。要為照片應用理想的風格,您可能需要嘗試多種不同的濾鏡組合。這個過程與調整模型的超參數一樣復雜。
在本節中,我們將利用 CNN 的分層表示將一幅圖像的風格自動應用到另一幅圖像,即 風格遷移 (Gatys等人,2016 年)。此任務需要兩張輸入圖像:一張是內容圖像,另一張是風格圖像。我們將使用神經網絡修改內容圖像,使其在風格上接近風格圖像。例如 圖14.12.1中的內容圖片是我們在西雅圖郊區雷尼爾山國家公園拍攝的風景照,而風格圖是一幅以秋天的橡樹為主題的油畫。在輸出的合成圖像中,應用了樣式圖像的油畫筆觸,使顏色更加鮮艷,同時保留了內容圖像中對象的主要形狀。
14.12.1。方法
圖 14.12.2用一個簡化的例子說明了基于 CNN 的風格遷移方法。首先,我們將合成圖像初始化為內容圖像。這張合成圖像是風格遷移過程中唯一需要更新的變量,即訓練期間要更新的模型參數。然后我們選擇一個預訓練的 CNN 來提取圖像特征并在訓練期間凍結其模型參數。這種深度 CNN 使用多層來提取圖像的層次特征。我們可以選擇其中一些層的輸出作為內容特征或樣式特征。如圖14.12.2舉個例子。這里的預訓練神經網絡有 3 個卷積層,其中第二層輸出內容特征,第一層和第三層輸出風格特征。
接下來,我們通過正向傳播(實線箭頭方向)計算風格遷移的損失函數,并通過反向傳播(虛線箭頭方向)更新模型參數(輸出的合成圖像)。風格遷移中常用的損失函數由三部分組成:(i)內容損失使合成圖像和內容圖像在內容特征上接近;(ii)風格損失使得合成圖像和風格圖像在風格特征上接近;(iii) 總變差損失有助于減少合成圖像中的噪聲。最后,當模型訓練結束后,我們輸出風格遷移的模型參數,生成最終的合成圖像。
下面,我們將通過一個具體的實驗來解釋風格遷移的技術細節。
14.12.2。閱讀內容和樣式圖像
首先,我們閱讀內容和樣式圖像。從它們打印的坐標軸,我們可以看出這些圖像具有不同的尺寸。
%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
d2l.set_figsize()
content_img = d2l.Image.open('../img/rainier.jpg')
d2l.plt.imshow(content_img);
%matplotlib inline
from mxnet import autograd, gluon, image, init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
d2l.set_figsize()
content_img = image.imread('../img/rainier.jpg')
d2l.plt.imshow(content_img.asnumpy());
14.12.3。預處理和后處理
下面,我們定義了兩個用于預處理和后處理圖像的函數。該preprocess
函數對輸入圖像的三個 RGB 通道中的每一個進行標準化,并將結果轉換為 CNN 輸入格式。該postprocess
函數將輸出圖像中的像素值恢復為標準化前的原始值。由于圖像打印功能要求每個像素都有一個從0到1的浮點值,我們將任何小于0或大于1的值分別替換為0或1。
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])
def preprocess(img, image_shape):
transforms = torchvision.transforms.Compose([
torchvision.transforms.Resize(image_shape),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
return transforms(img).unsqueeze(0)
def postprocess(img):
img = img[0].to(rgb_std.device)
img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))
rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])
def preprocess(img, image_shape):
img = image.imresize(img, *image_shape)
img = (img.astype('float32') / 255 - rgb_mean) / rgb_std
return np.expand_dims(img.transpose(2, 0, 1), axis=0)
def postprocess(img):
img = img[0].as_in_ctx(rgb_std.ctx)
return (img.transpose(1, 2, 0) * rgb_std + rgb_mean).clip(0, 1)
14.12.4。提取特征
我們使用在 ImageNet 數據集上預訓練的 VGG-19 模型來提取圖像特征( Gatys et al. , 2016 )。
為了提取圖像的內容特征和風格特征,我們可以選擇VGG網絡中某些層的輸出。一般來說,越靠近輸入層越容易提取圖像的細節,反之越容易提取圖像的全局信息。為了避免在合成圖像中過度保留內容圖像的細節,我們選擇了一個更接近輸出的VGG層作為內容層來輸出圖像的內容特征。我們還選擇不同 VGG 層的輸出來提取局部和全局風格特征。這些圖層也稱為樣式圖層。如第 8.2 節所述,VGG 網絡使用 5 個卷積塊。在實驗中,我們選擇第四個卷積塊的最后一個卷積層作為內容層,每個卷積塊的第一個卷積層作為樣式層。這些層的索引可以通過打印pretrained_net
實例來獲得。
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
當使用 VGG 層提取特征時,我們只需要使用從輸入層到最接近輸出層的內容層或樣式層的所有那些。讓我們構建一個新的網絡實例net
,它只保留所有用于特征提取的 VGG 層。
net = nn.Sequential(*[pretrained_net.features[i] for i in
range(max(content_layers + style_layers) + 1)])
給定輸入X
,如果我們簡單地調用前向傳播 net(X)
,我們只能得到最后一層的輸出。由于我們還需要中間層的輸出,因此我們需要逐層計算并保留內容層和樣式層的輸出。
def extract_features(X, content_layers, style_layers):
contents = []
styles = []
for i in range(len(net)):
X = net[i](X)
if i in style_layers:
styles.append(X)
if i in content_layers:
contents.append(X)
return contents, styles
下面定義了兩個函數:get_contents
函數從內容圖像中提取內容特征,函數get_styles
從風格圖像中提取風格特征。由于在訓練期間不需要更新預訓練 VGG 的模型參數,我們甚至可以在訓練開始之前提取內容和風格特征。由于合成圖像是一組需要更新的模型參數以進行風格遷移,因此我們只能extract_features
在訓練時通過調用函數來提取合成圖像的內容和風格特征。
def get_contents(image
評論
查看更多