如果你是攝影愛好者,你可能對濾鏡并不陌生。它可以改變照片的色彩風格,使風景照片變得更清晰或肖像照片皮膚變白。但是,一個濾鏡通常只會改變照片的一個方面。要為照片應用理想的風格,您可能需要嘗試多種不同的濾鏡組合。這個過程與調整模型的超參數一樣復雜。
在本節中,我們將利用 CNN 的分層表示將一幅圖像的風格自動應用到另一幅圖像,即 風格遷移 (Gatys等人,2016 年)。此任務需要兩張輸入圖像:一張是內容圖像,另一張是風格圖像。我們將使用神經網絡修改內容圖像,使其在風格上接近風格圖像。例如 圖14.12.1中的內容圖片是我們在西雅圖郊區雷尼爾山國家公園拍攝的風景照,而風格圖是一幅以秋天的橡樹為主題的油畫。在輸出的合成圖像中,應用了樣式圖像的油畫筆觸,使顏色更加鮮艷,同時保留了內容圖像中對象的主要形狀。
圖 14.12.1給定內容和風格圖像,風格遷移輸出合成圖像。
14.12.1。方法
圖 14.12.2用一個簡化的例子說明了基于 CNN 的風格遷移方法。首先,我們將合成圖像初始化為內容圖像。這張合成圖像是風格遷移過程中唯一需要更新的變量,即訓練期間要更新的模型參數。然后我們選擇一個預訓練的 CNN 來提取圖像特征并在訓練期間凍結其模型參數。這種深度 CNN 使用多層來提取圖像的層次特征。我們可以選擇其中一些層的輸出作為內容特征或樣式特征。如圖14.12.2舉個例子。這里的預訓練神經網絡有 3 個卷積層,其中第二層輸出內容特征,第一層和第三層輸出風格特征。
圖 14.12.2基于 CNN 的風格遷移過程。實線表示正向傳播方向,虛線表示反向傳播。
接下來,我們通過正向傳播(實線箭頭方向)計算風格遷移的損失函數,并通過反向傳播(虛線箭頭方向)更新模型參數(輸出的合成圖像)。風格遷移中常用的損失函數由三部分組成:(i)內容損失使合成圖像和內容圖像在內容特征上接近;(ii)風格損失使得合成圖像和風格圖像在風格特征上接近;(iii) 總變差損失有助于減少合成圖像中的噪聲。最后,當模型訓練結束后,我們輸出風格遷移的模型參數,生成最終的合成圖像。
下面,我們將通過一個具體的實驗來解釋風格遷移的技術細節。
14.12.2。閱讀內容和樣式圖像
首先,我們閱讀內容和樣式圖像。從它們打印的坐標軸,我們可以看出這些圖像具有不同的尺寸。
%matplotlib inline import torch import torchvision from torch import nn from d2l import torch as d2l d2l.set_figsize() content_img = d2l.Image.open('../img/rainier.jpg') d2l.plt.imshow(content_img);
style_img = d2l.Image.open('../img/autumn-oak.jpg') d2l.plt.imshow(style_img);
%matplotlib inline from mxnet import autograd, gluon, image, init, np, npx from mxnet.gluon import nn from d2l import mxnet as d2l npx.set_np() d2l.set_figsize() content_img = image.imread('../img/rainier.jpg') d2l.plt.imshow(content_img.asnumpy());
style_img = image.imread('../img/autumn-oak.jpg') d2l.plt.imshow(style_img.asnumpy());
14.12.3。預處理和后處理
下面,我們定義了兩個用于預處理和后處理圖像的函數。該preprocess函數對輸入圖像的三個 RGB 通道中的每一個進行標準化,并將結果轉換為 CNN 輸入格式。該postprocess函數將輸出圖像中的像素值恢復為標準化前的原始值。由于圖像打印功能要求每個像素都有一個從0到1的浮點值,我們將任何小于0或大于1的值分別替換為0或1。
rgb_mean = torch.tensor([0.485, 0.456, 0.406]) rgb_std = torch.tensor([0.229, 0.224, 0.225]) def preprocess(img, image_shape): transforms = torchvision.transforms.Compose([ torchvision.transforms.Resize(image_shape), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)]) return transforms(img).unsqueeze(0) def postprocess(img): img = img[0].to(rgb_std.device) img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1) return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))
rgb_mean = np.array([0.485, 0.456, 0.406]) rgb_std = np.array([0.229, 0.224, 0.225]) def preprocess(img, image_shape): img = image.imresize(img, *image_shape) img = (img.astype('float32') / 255 - rgb_mean) / rgb_std return np.expand_dims(img.transpose(2, 0, 1), axis=0) def postprocess(img): img = img[0].as_in_ctx(rgb_std.ctx) return (img.transpose(1, 2, 0) * rgb_std + rgb_mean).clip(0, 1)
14.12.4。提取特征
我們使用在 ImageNet 數據集上預訓練的 VGG-19 模型來提取圖像特征( Gatys et al. , 2016 )。
pretrained_net = torchvision.models.vgg19(pretrained=True)
pretrained_net = gluon.model_zoo.vision.vgg19(pretrained=True)
為了提取圖像的內容特征和風格特征,我們可以選擇VGG網絡中某些層的輸出。一般來說,越靠近輸入層越容易提取圖像的細節,反之越容易提取圖像的全局信息。為了避免在合成圖像中過度保留內容圖像的細節,我們選擇了一個更接近輸出的VGG層作為內容層來輸出圖像的內容特征。我們還選擇不同 VGG 層的輸出來提取局部和全局風格特征。這些圖層也稱為樣式圖層。如第 8.2 節所述,VGG 網絡使用 5 個卷積塊。在實驗中,我們選擇第四個卷積塊的最后一個卷積層作為內容層,每個卷積塊的第一個卷積層作為樣式層。這些層的索引可以通過打印pretrained_net實例來獲得。
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
當使用 VGG 層提取特征時,我們只需要使用從輸入層到最接近輸出層的內容層或樣式層的所有那些。讓我們構建一個新的網絡實例net,它只保留所有用于特征提取的 VGG 層。
net = nn.Sequential(*[pretrained_net.features[i] for i in range(max(content_layers + style_layers) + 1)])
net = nn.Sequential() for i in range(max(content_layers + style_layers) + 1): net.add(pretrained_net.features[i])
給定輸入X,如果我們簡單地調用前向傳播 net(X),我們只能得到最后一層的輸出。由于我們還需要中間層的輸出,因此我們需要逐層計算并保留內容層和樣式層的輸出。
def extract_features(X, content_layers, style_layers): contents = [] styles = [] for i in range(len(net)): X = net[i](X) if i in style_layers: styles.append(X) if i in content_layers: contents.append(X) return contents, styles
def extract_features(X, content_layers, style_layers): contents = [] styles = [] for i in range(len(net)): X = net[i](X) if i in style_layers: styles.append(X) if i in content_layers: contents.append(X) return contents, styles
下面定義了兩個函數:get_contents函數從內容圖像中提取內容特征,函數get_styles從風格圖像中提取風格特征。由于在訓練期間不需要更新預訓練 VGG 的模型參數,我們甚至可以在訓練開始之前提取內容和風格特征。由于合成圖像是一組需要更新的模型參數以進行風格遷移,因此我們只能extract_features 在訓練時通過調用函數來提取合成圖像的內容和風格特征。
def get_contents(image_shape, device): content_X = preprocess(content_img, image_shape).to(device) contents_Y, _ = extract_features(content_X, content_layers, style_layers) return content_X, contents_Y def get_styles(image_shape, device): style_X = preprocess(style_img, image_shape).to(device) _, styles_Y = extract_features(style_X, content_layers, style_layers) return style_X, styles_Y
def get_contents(image_shape, device): content_X = preprocess(content_img, image_shape).copyto(device) contents_Y, _ = extract_features(content_X, content_layers, style_layers) return content_X, contents_Y def get_styles(image_shape, device): style_X = preprocess(style_img, image_shape).copyto(device) _, styles_Y = extract_features(style_X, content_layers, style_layers) return style_X, styles_Y
14.12.5。定義損失函數
現在我們將描述風格遷移的損失函數。損失函數包括內容損失、風格損失和全變損失。
14.12.5.1。內容丟失
類似于線性回歸中的損失函數,內容損失通過平方損失函數衡量合成圖像和內容圖像之間內容特征的差異。平方損失函數的兩個輸入都是該extract_features函數計算的內容層的輸出。
def content_loss(Y_hat, Y): # We detach the target content from the tree used to dynamically compute # the gradient: this is a stated value, not a variable. Otherwise the loss # will throw an error. return torch.square(Y_hat - Y.detach()).mean()
def content_loss(Y_hat, Y): return np.square(Y_hat - Y).mean()
14.12.5.2。風格損失
風格損失與內容損失類似,也是使用平方損失函數來衡量合成圖像與風格圖像之間的風格差異。為了表達任何樣式層的樣式輸出,我們首先使用函數extract_features來計算樣式層輸出。假設輸出有 1 個示例,c渠道,高度 h, 和寬度w,我們可以將此輸出轉換為矩陣 X和c行和hw列。這個矩陣可以被認為是串聯c載體 x1,…,xc, 其中每一個的長度為hw. 在這里,矢量xi表示頻道的風格特征i.
在這些向量的 Gram 矩陣中XX?∈Rc×c, 元素 xij在排隊i和專欄j是向量的點積xi和xj. 表示渠道風格特征的相關性i和 j. 我們使用這個 Gram 矩陣來表示任何樣式層的樣式輸出。請注意,當值hw越大,它可能會導致 Gram 矩陣中的值越大。還要注意,Gram矩陣的高和寬都是通道數c. 為了讓風格損失不受這些值的影響,gram 下面的函數將 Gram 矩陣除以其元素的數量,即chw.
def gram(X): num_channels, n = X.shape[1], X.numel() // X.shape[1] X = X.reshape((num_channels, n)) return torch.matmul(X, X.T) / (num_channels * n)
def gram(X): num_channels, n = X.shape[1], d2l.size(X) // X.shape[1] X = X.reshape((num_channels, n)) return np.dot(X, X.T) / (num_channels * n)
顯然,風格損失的平方損失函數的兩個格拉姆矩陣輸入是基于合成圖像和風格圖像的風格層輸出。這里假設 gram_Y基于風格圖像的 Gram 矩陣已經預先計算好了。
def style_loss(Y_hat, gram_Y): return torch.square(gram(Y_hat) - gram_Y.detach()).mean()
def style_loss(Y_hat, gram_Y): return np.square(gram(Y_hat) - gram_Y).mean()
14.12.5.3。總變異損失
有時,學習到的合成圖像有很多高頻噪聲,即特別亮或特別暗的像素。一種常見的降噪方法是全變差去噪。表示為 xi,j坐標處的像素值(i,j). 減少總變異損失
(14.12.1)∑i,j|xi,j?xi+1,j|+|xi,j?xi,j+1|
使合成圖像上相鄰像素的值更接近。
def tv_loss(Y_hat): return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() + torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
def tv_loss(Y_hat): return 0.5 * (np.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() + np.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
14.12.5.4。損失函數
風格遷移的損失函數是內容損失、風格損失和總變異損失的加權和。通過調整這些權重超參數,我們可以在合成圖像的內容保留、風格遷移和降噪之間取得平衡。
content_weight, style_weight, tv_weight = 1, 1e4, 10 def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram): # Calculate the content, style, and total variance losses respectively contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip( contents_Y_hat, contents_Y)] styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip( styles_Y_hat, styles_Y_gram)] tv_l = tv_loss(X) * tv_weight # Add up all the losses l = sum(styles_l + contents_l + [tv_l]) return contents_l, styles_l, tv_l, l
content_weight, style_weight, tv_weight = 1, 1e4, 10 def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram): # Calculate the content, style, and total variance losses respectively contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip( contents_Y_hat, contents_Y)] styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip( styles_Y_hat, styles_Y_gram)] tv_l = tv_loss(X) * tv_weight # Add up all the losses l = sum(styles_l + contents_l + [tv_l]) return contents_l, styles_l, tv_l, l
14.12.6. 初始化合成圖像
在風格遷移中,合成圖像是訓練期間唯一需要更新的變量。因此,我們可以定義一個簡單的模型, SynthesizedImage并將合成圖像作為模型參數。在這個模型中,前向傳播只返回模型參數。
class SynthesizedImage(nn.Module): def __init__(self, img_shape, **kwargs): super(SynthesizedImage, self).__init__(**kwargs) self.weight = nn.Parameter(torch.rand(*img_shape)) def forward(self): return self.weight
class SynthesizedImage(nn.Block): def __init__(self, img_shape, **kwargs): super(SynthesizedImage, self).__init__(**kwargs) self.weight = self.params.get('weight', shape=img_shape) def forward(self): return self.weight.data()
接下來,我們定義get_inits函數。此函數創建一個合成圖像模型實例并將其初始化為 image X。styles_Y_gram在訓練之前計算各種樣式層的樣式圖像的 Gram 矩陣 。
def get_inits(X, device, lr, styles_Y): gen_img = SynthesizedImage(X.shape).to(device) gen_img.weight.data.copy_(X.data) trainer = torch.optim.Adam(gen_img.parameters(), lr=lr) styles_Y_gram = [gram(Y) for Y in styles_Y] return gen_img(), styles_Y_gram, trainer
def get_inits(X, device, lr, styles_Y): gen_img = SynthesizedImage(X.shape) gen_img.initialize(init.Constant(X), ctx=device, force_reinit=True) trainer = gluon.Trainer(gen_img.collect_params(), 'adam', {'learning_rate': lr}) styles_Y_gram = [gram(Y) for Y in styles_Y] return gen_img(), styles_Y_gram, trainer
14.12.7. 訓練
在訓練風格遷移模型時,我們不斷提取合成圖像的內容特征和風格特征,并計算損失函數。下面定義了訓練循環。
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch): X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y) scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8) animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[10, num_epochs], legend=['content', 'style', 'TV'], ncols=2, figsize=(7, 2.5)) for epoch in range(num_epochs): trainer.zero_grad() contents_Y_hat, styles_Y_hat = extract_features( X, content_layers, style_layers) contents_l, styles_l, tv_l, l = compute_loss( X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram) l.backward() trainer.step() scheduler.step() if (epoch + 1) % 10 == 0: animator.axes[1].imshow(postprocess(X)) animator.add(epoch + 1, [float(sum(contents_l)), float(sum(styles_l)), float(tv_l)]) return X
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch): X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y) animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[10, num_epochs], ylim=[0, 20], legend=['content', 'style', 'TV'], ncols=2, figsize=(7, 2.5)) for epoch in range(num_epochs): with autograd.record(): contents_Y_hat, styles_Y_hat = extract_features( X, content_layers, style_layers) contents_l, styles_l, tv_l, l = compute_loss( X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram) l.backward() trainer.step(1) if (epoch + 1) % lr_decay_epoch == 0: trainer.set_learning_rate(trainer.learning_rate * 0.8) if (epoch + 1) % 10 == 0: animator.axes[1].imshow(postprocess(X).asnumpy()) animator.add(epoch + 1, [float(sum(contents_l)), float(sum(styles_l)), float(tv_l)]) return X
現在我們開始訓練模型。我們將內容和樣式圖像的高度和寬度重新調整為 300 x 450 像素。我們使用內容圖像來初始化合成圖像。
device, image_shape = d2l.try_gpu(), (300, 450) # PIL Image (h, w) net = net.to(device) content_X, contents_Y = get_contents(image_shape, device) _, styles_Y = get_styles(image_shape, device) output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
device, image_shape = d2l.try_gpu(), (450, 300) net.collect_params().reset_ctx(device) content_X, contents_Y = get_contents(image_shape, device) _, styles_Y = get_styles(image_shape, device) output = train(content_X, contents_Y, styles_Y, device, 0.9, 500, 50)
我們可以看到,合成圖保留了內容圖的景物和物體,同時傳遞了風格圖的顏色。例如,合成圖像具有風格圖像中的顏色塊。其中一些塊甚至具有筆觸的微妙紋理。
14.12.8。概括
風格遷移中常用的損失函數由三部分組成:(i)內容損失使合成圖像和內容圖像在內容特征上接近;(ii) 風格損失使得合成圖像和風格圖像在風格特征上接近;(iii) 總變差損失有助于減少合成圖像中的噪聲。
我們可以使用預訓練的 CNN 提取圖像特征并最小化損失函數,以在訓練期間不斷更新合成圖像作為模型參數。
我們使用 Gram 矩陣來表示樣式層的樣式輸出。
14.12.9。練習
當您選擇不同的內容和樣式層時,輸出如何變化?
調整損失函數中的權重超參數。輸出是保留更多內容還是噪音更少?
使用不同的內容和樣式圖像。你能創造出更有趣的合成圖像嗎?
我們可以對文本應用樣式轉換嗎?提示:您可以參考Hu等人的調查論文。( 2022 )。
-
cnn
+關注
關注
3文章
353瀏覽量
22246 -
pytorch
+關注
關注
2文章
808瀏覽量
13249
發布評論請先 登錄
相關推薦
評論