The Transformer Decoder
As with the encoder, the decoder is a stack of identical blocks, so let's again start by defining a single decoder block. You can refer to the image (right) in Section 1 to see all the components we need to code to make the decoder work.
The decoder block is composed of:
a Masked Multi-Head Attention layer. Here, “masked” refers to the autoregressive property of self-attention in the decoder. Specifically, we want an arbitrary token at position i to attend only to the tokens that precede it (and to itself), with zero attention weight on future positions. We don’t have to take care of this inside the layer itself: the trick 💡 is to use a particular attention mask, sketched right after this list and used again later.
a Cross-Attention layer. Here is where the magic happens. The decoder takes its Keys and Values from the encoder output, which we will call “memory,” while the Queries come from the decoder itself.
a Feed-Forward layer with element-wise non-linear activation.
Skip connections and Layer Normalization after each sub-layer.
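To make the masking concrete, here is a minimal sketch of a causal (look-ahead) mask in JAX. The (1, seq_len, seq_len) layout is an assumption about how the MultiheadAttention module from the previous section broadcasts its mask, so adapt it to your own convention:
import jax.numpy as jnp

def causal_mask(seq_len):
    # Lower-triangular matrix of ones: position i may attend to positions j <= i.
    # Entries equal to 1 are kept; entries equal to 0 are masked out.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.int32))
    return mask[None, :, :]  # leading axis so the mask broadcasts over the batch

print(causal_mask(4)[0])
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]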
Decoder Block
class DecoderBlock(hk.Module):
    """
    Transformer decoder block.

    :param d_model: dimension of the model.
    :param num_heads: number of attention heads.
    :param d_ff: dimension of the feedforward network model.
    :param p_dropout: dropout rate.
    """

    def __init__(self, d_model, num_heads, d_ff, p_dropout, name=None):
        super().__init__(name=name)
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.p_dropout = p_dropout
        # self-attention sub-layer
        self.self_attn = MultiheadAttention(
            d_model=self.d_model, num_heads=self.num_heads
        )
        # src-target cross-attention sub-layer
        self.cross_attn = MultiheadAttention(
            d_model=self.d_model, num_heads=self.num_heads
        )
        # positionwise feedforward sub-layer
        self.ff = PositionwiseFeedForward(
            d_model=self.d_model, d_ff=self.d_ff, p_dropout=self.p_dropout
        )
        self.norm1 = hk.LayerNorm(
            axis=-1, param_axis=-1, create_scale=True, create_offset=True
        )
        self.norm2 = hk.LayerNorm(
            axis=-1, param_axis=-1, create_scale=True, create_offset=True
        )
        self.norm3 = hk.LayerNorm(
            axis=-1, param_axis=-1, create_scale=True, create_offset=True
        )

    def __call__(self, x, memory, src_mask, tgt_mask, is_train):
        """
        The forward pass of the decoder block.

        :param x: the input sequence for the decoder block.
        :param memory: the memory from the encoder.
        :param src_mask: the mask for the src sequence.
        :param tgt_mask: the mask for the tgt sequence.
        :param is_train: boolean flag to indicate training mode.
        :return: the output of the decoder block.
        """
        # masked self-attention sub-layer (tgt_mask enforces causality)
        sub_x, _ = self.self_attn(x, x, x, tgt_mask)
        if is_train:
            sub_x = hk.dropout(hk.next_rng_key(), self.p_dropout, sub_x)
        x = self.norm1(x + sub_x)  # residual connection + layer norm
        # cross-attention sub-layer (queries from decoder, keys/values from memory)
        sub_x, _ = self.cross_attn(x, memory, memory, src_mask)
        if is_train:
            sub_x = hk.dropout(hk.next_rng_key(), self.p_dropout, sub_x)
        x = self.norm2(x + sub_x)
        # feedforward sub-layer
        sub_x = self.ff(x, is_train=is_train)
        if is_train:
            sub_x = hk.dropout(hk.next_rng_key(), self.p_dropout, sub_x)
        x = self.norm3(x + sub_x)
        return x
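Before stacking blocks, it can help to check shapes with a quick usage sketch. This assumes the MultiheadAttention and PositionwiseFeedForward modules from the previous sections are in scope and reuses the causal_mask helper sketched above; the mask shapes are again assumptions about how they broadcast:
import jax
import jax.numpy as jnp
import haiku as hk

def block_fn(x, memory, src_mask, tgt_mask):
    block = DecoderBlock(d_model=16, num_heads=4, d_ff=64, p_dropout=0.1)
    return block(x, memory, src_mask, tgt_mask, is_train=True)

block = hk.transform(block_fn)
rng = jax.random.PRNGKey(0)
x = jnp.ones((1, 5, 16))        # (batch, tgt_len, d_model)
memory = jnp.ones((1, 7, 16))   # (batch, src_len, d_model)
src_mask = jnp.ones((1, 1, 7))  # toy source with no padding
tgt_mask = causal_mask(5)       # causal mask from the sketch above

params = block.init(rng, x, memory, src_mask, tgt_mask)
out = block.apply(params, rng, x, memory, src_mask, tgt_mask)
print(out.shape)  # (1, 5, 16): same shape as the decoder input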
class TransformerDecoder(hk.Module):
    """
    The Transformer decoder model.

    :param num_layers: number of decoder layers.
    :param num_heads: number of attention heads.
    :param d_model: dimension of the model.
    :param d_ff: dimension of the feedforward network model.
    :param p_dropout: dropout rate.
    """

    def __init__(self, num_layers, num_heads, d_model, d_ff, p_dropout, name=None):
        super().__init__(name=name)
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_ff = d_ff
        self.p_dropout = p_dropout
        self.layers = [
            DecoderBlock(self.d_model, self.num_heads, self.d_ff, self.p_dropout)
            for _ in range(self.num_layers)
        ]

    def __call__(self, x, memory, src_mask, tgt_mask, is_train):
        """
        The forward pass of the decoder.

        :param x: the input sequence for the decoder.
        :param memory: the memory from the encoder.
        :param src_mask: the mask for the src sequence.
        :param tgt_mask: the mask for the tgt sequence.
        :param is_train: boolean flag to indicate training mode.
        :return: the output of the transformer decoder.
        """
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask, is_train)
        return x
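In practice the decoder needs two masks: src_mask hides the padding positions of the source inside cross-attention, and tgt_mask combines target padding with the causal constraint. A hedged helper along these lines could build both from token ids, assuming pad_id marks padding, the causal_mask sketch above, and the same broadcastable mask convention:
import jax.numpy as jnp

def make_masks(src_ids, tgt_ids, pad_id=0):
    # 1 where the source token is real, 0 where it is padding: (batch, 1, src_len)
    src_mask = (src_ids != pad_id).astype(jnp.int32)[:, None, :]
    # target padding mask, (batch, 1, tgt_len), combined with the causal mask
    tgt_pad = (tgt_ids != pad_id).astype(jnp.int32)[:, None, :]
    tgt_mask = tgt_pad * causal_mask(tgt_ids.shape[-1])  # (batch, tgt_len, tgt_len)
    return src_mask, tgt_mask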
class Transformer(hk.Module):
    """
    Complete Transformer model including encoder and decoder.

    :param d_model: dimension of the model.
    :param d_ff: dimension of the feedforward network model.
    :param src_vocab_size: size of the source vocabulary.
    :param tgt_vocab_size: size of the target vocabulary.
    :param num_layers: number of encoder and decoder layers.
    :param num_heads: number of attention heads.
    :param p_dropout: dropout rate.
    :param max_seq_len: maximum sequence length.
    :param tie_embeddings: share the embedding table between source and target
        (assumes a shared vocabulary, i.e. src_vocab_size == tgt_vocab_size).
    """

    def __init__(
        self,
        d_model,
        d_ff,
        src_vocab_size,
        tgt_vocab_size,
        num_layers,
        num_heads,
        p_dropout,
        max_seq_len,
        name=None,
        tie_embeddings=False,
    ):
        super().__init__(name=name)
        self.d_model = d_model
        self.d_ff = d_ff
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.p_dropout = p_dropout
        self.max_seq_len = max_seq_len
        self.src_emb = Embeddings(d_model, src_vocab_size)
        if tie_embeddings:
            # reuse the source embedding table for the target side
            self.tgt_emb = self.src_emb
        else:
            self.tgt_emb = Embeddings(d_model, tgt_vocab_size)
        self.encoder = TransformerEncoder(
            num_layers, num_heads, d_model, d_ff, p_dropout
        )
        self.decoder = TransformerDecoder(
            num_layers, num_heads, d_model, d_ff, p_dropout
        )

    def encode(self, src, src_mask, is_train):
        """
        The forward pass for the encoder.

        :param src: the source sequence.
        :param src_mask: the mask for the src sequence.
        :param is_train: boolean flag to indicate training mode.
        :return: the encoded sequence.
        """
        pe = PositionalEncoding(self.d_model, self.max_seq_len, self.p_dropout)
        src = self.src_emb(src)
        src = src[None, :, :] if len(src.shape) == 2 else src  # ensure a batch axis
        src = pe(src, is_train=is_train)
        return self.encoder(src, src_mask, is_train)

    def decode(self, memory, src_mask, tgt, tgt_mask, is_train):
        """
        The forward pass for the decoder.

        :param memory: the memory from the encoder.
        :param src_mask: the mask for the src sequence.
        :param tgt: the target sequence.
        :param tgt_mask: the mask for the tgt sequence.
        :param is_train: boolean flag to indicate training mode.
        :return: the output of the decoder.
        """
        pe = PositionalEncoding(self.d_model, self.max_seq_len, self.p_dropout)
        tgt = self.tgt_emb(tgt)
        tgt = tgt[None, :, :] if len(tgt.shape) == 2 else tgt  # ensure a batch axis
        tgt = pe(tgt, is_train=is_train)
        return self.decoder(tgt, memory, src_mask, tgt_mask, is_train)

    def __call__(self, src, src_mask, tgt, tgt_mask, is_train):
        """
        The forward pass of the whole transformer model.

        :param src: the source sequence.
        :param src_mask: the mask for the src sequence.
        :param tgt: the target sequence.
        :param tgt_mask: the mask for the tgt sequence.
        :param is_train: boolean flag to indicate training mode.
        :return: the output of the transformer model (encoder + decoder).
        """
        memory = self.encode(src, src_mask, is_train)
        return self.decode(memory, src_mask, tgt, tgt_mask, is_train)
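Finally, a minimal end-to-end sketch of how the full model could be transformed and run on dummy token ids. The hyperparameters are arbitrary, the make_masks helper comes from the sketch above, and the Embeddings, PositionalEncoding, and TransformerEncoder modules from the previous sections are assumed to be in scope:
import jax
import jax.numpy as jnp
import haiku as hk

def model_fn(src, src_mask, tgt, tgt_mask, is_train):
    model = Transformer(
        d_model=32, d_ff=64,
        src_vocab_size=100, tgt_vocab_size=100,
        num_layers=2, num_heads=4,
        p_dropout=0.1, max_seq_len=50,
    )
    return model(src, src_mask, tgt, tgt_mask, is_train)

model = hk.transform(model_fn)
rng = jax.random.PRNGKey(42)
src = jnp.array([[5, 8, 2, 0, 0]])  # (batch, src_len), 0 is the padding id
tgt = jnp.array([[1, 7, 3, 0]])     # (batch, tgt_len)
src_mask, tgt_mask = make_masks(src, tgt)

params = model.init(rng, src, src_mask, tgt, tgt_mask, True)
out = model.apply(params, rng, src, src_mask, tgt, tgt_mask, True)
print(out.shape)  # (1, 4, 32): one d_model-sized vector per target position
Note that the model returns d_model-dimensional decoder states: to train on next-token prediction, you would typically add a final linear projection onto the target vocabulary (the “generator”) followed by a softmax.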