在一般的 seq2seq 問題中,如機器翻譯(第 10.5 節),輸入和輸出的長度不同且未對齊。處理這類數據的標準方法是設計一個編碼器-解碼器架構(圖 10.6.1),它由兩個主要組件組成:一個 編碼器,它以可變長度序列作為輸入,以及一個 解碼器,作為一個條件語言模型,接收編碼輸入和目標序列的向左上下文,并預測目標序列中的后續標記。
讓我們以從英語到法語的機器翻譯為例。給定一個英文輸入序列:“They”、“are”、“watching”、“.”,這種編碼器-解碼器架構首先將可變長度輸入編碼為一個狀態,然后對該狀態進行解碼以生成翻譯后的序列,token通過標記,作為輸出:“Ils”、“regardent”、“.”。由于編碼器-解碼器架構構成了后續章節中不同 seq2seq 模型的基礎,因此本節將此架構轉換為稍后將實現的接口。
import tensorflow as tf
from d2l import tensorflow as d2l
10.6.1。編碼器
在編碼器接口中,我們只是指定編碼器將可變長度序列作為輸入X
。實現將由繼承此基類的任何模型提供Encoder
。
class Encoder(tf.keras.layers.Layer): #@save
"""The base encoder interface for the encoder-decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def call(self, X, *args):
raise NotImplementedError
10.6.2。解碼器
在下面的解碼器接口中,我們添加了一個額外的init_state
方法來將編碼器輸出 ( enc_all_outputs
) 轉換為編碼狀態。請注意,此步驟可能需要額外的輸入,例如輸入的有效長度,這在 第 10.5 節中有解釋。為了逐個令牌生成可變長度序列令牌,每次解碼器都可以將輸入(例如,在先前時間步生成的令牌)和編碼狀態映射到當前時間步的輸出令牌。
class Decoder(nn.Module): #@save
"""The base decoder interface for the encoder-decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
class Decoder(nn.Block): #@save
"""The base decoder interface for the encoder-decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
class Decoder(nn.Module): #@save
"""The base decoder interface for the encoder-decoder architecture."""
def setup(self):
raise NotImplementedError
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def __call__(self, X, state):
raise NotImplementedError
class Decoder(tf.keras.layers.Layer): #@save
"""The base decoder interface for the encoder-decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def call(self, X, state):
raise NotImplementedError
10.6.3。將編碼器和解碼器放在一起
在前向傳播中,編碼器的輸出用于產生編碼狀態,解碼器將進一步使用該狀態作為其輸入之一。
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder-decoder architecture."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state)[0]
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder-decoder architecture."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state)[0]
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder-decoder architecture."""
encoder: nn.Module
decoder: nn.Module
training: bool
def __call__(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args, training=self.training)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state, training=self.training)[0]
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder-decoder architecture."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def call(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args, training=True)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state, training=True)[0]
在下一節中,我們將看到如何應用 RNN 來設計基于這種編碼器-解碼器架構的 seq2seq 模型。
10.6.4。概括
編碼器-解碼器架構可以處理由可變長度序列組成的輸入和輸出,因此適用于機器翻譯等 seq2seq 問題。編碼器將可變長度序列作為輸入,并將其轉換為具有固定形狀的狀態。解碼器將固定形狀的編碼狀態映射到可變長度序列。
10.6.5。練習
-
假設我們使用神經網絡來實現編碼器-解碼器架構。編碼器和解碼器必須是同一類型的神經網絡嗎?
-
除了機器翻譯,你能想到另一個可以應用編碼器-解碼器架構的應用程序嗎?
評論
查看更多