In [1]:
import os, torch, sys, pickle, random
from torch import nn
from typing import List, Tuple

Multi-head Attention - Focusing on Mask

PyTorch version: 1.4.0

I follow the notation in the official PyTorch documentation.

Basically, multi-head attention runs multiple scaled-dot-product attention heads in parallel.
Scaled-dot-product attention works as follows.
Given [query, key, value]:

| name | dimension | description |
| --- | --- | --- |
| query | $T \times N \times d_q$ | embedded padded target sequence for a mini-batch |
| key | $S \times N \times d_k$ | embedded padded input sequence for a mini-batch |
| value | $S \times N \times d_v$ | embedded padded input sequence for a mini-batch |

Warning: constraints on multi-head attention inputs

  • $d_k = d_v$
  • $\frac{0}{0}$ produces nan values:
    1. src_len and trg_len should be at least 1, because masking every position of an example leads to nan values.
    2. A row or column of [src, tgt, memory] that is entirely float('-inf') also leads to nan values.
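
For instance, here is a minimal check of the second point (my own illustration, not a cell from the notebook): a row that is fully masked with -inf becomes all nan after the softmax, because the softmax computes 0/0.

fully_masked_row = torch.full((1, 4), -float('inf'))  # every position in this row is masked
print(torch.softmax(fully_masked_row, dim=-1))        # tensor([[nan, nan, nan, nan]])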

Find similarity scores between query and key, then apply the attention scores to value.
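
Concretely, this is the scaled dot-product attention of the original paper, where masked positions are set to $-\infty$ before the softmax:

\[\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V\]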

\[\begin{align} &B: \text{batch size} \\ &E: \text{embedding dimension} \\ &S: \text{max source sequence length} \\ &T: \text{max target sequence length} \\ \end{align}\]

To cover the complicated cases, I will use all of the masking options.

| mask name | dimension | description |
| --- | --- | --- |
| key_padding_mask | $N \times S$ | masks the padded positions of the source (or target) sequence for each example |
| attn_mask | $T \times S$ | masks attention weights at specific positions, shared across the whole batch |
In [2]:
B = 5    # batch size (same as N below)
V = 100  # vocabulary size
N, E, S, T = 5, 16, 10, 20  # batch size, embedding dim, max source length, max target length
nhead = 2
attn = nn.MultiheadAttention(embed_dim=E, num_heads=nhead)
emb = nn.Embedding(num_embeddings=V, embedding_dim=E, padding_idx=0)

Let’s prepare a toy example as follows.

In [34]:
seq = torch.LongTensor([[random.randint(1, V - 1) for _ in range(S)] for _ in range(N)])  # [bsz, srclen]
for b in range(N):
    seq[b][random.randint(S//5, S - 5):] = 0 
print(seq)
tensor([[75, 79, 46, 44, 68,  0,  0,  0,  0,  0],
        [84,  7, 75, 57,  0,  0,  0,  0,  0,  0],
        [65, 30, 41,  0,  0,  0,  0,  0,  0,  0],
        [88, 78,  0,  0,  0,  0,  0,  0,  0,  0],
        [98, 58, 34,  0,  0,  0,  0,  0,  0,  0]])

Step 1. Generate Masking

Q. Why use look-ahead masking?

Ans. From the original paper:
We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions.
This masking, combined with the fact that the output embeddings are offset by one position, ensures that the
predictions for position i can depend only on the known outputs at positions less than i.

In [67]:
# for self-attention masking
def sequence_mask(seq:torch.LongTensor, padding_idx:int=None) -> torch.BoolTensor:
    """ seq: [bsz, slen], a padded batch of sequences, so padding_idx may appear.
    Positions marked True will be filled with '-inf' before the scaled-dot attention softmax. """
    return seq == padding_idx

# for the decoder's look-ahead masking
# (the reason to use this masking: the decoder should predict the next word using only the preceding words)
def look_ahead_mask(tgt_len:int, src_len:int) -> torch.FloatTensor:
    """ This mask is added to the attention scores before the softmax, so '-inf' is placed at the masked positions.
    Look-ahead masking is used in the transformer decoder;
    it prevents future target positions from affecting past-step predictions. """
    mask = torch.triu(torch.ones(tgt_len, src_len), diagonal=1)
    mask[mask.bool()] = -float('inf')
    return mask
1. Key padding masking

| mask name | dimension | description |
| --- | --- | --- |
| key_padding_mask | $N \times S$ | masks the padded positions of the source (or target) sequence for each example |

This masking constrains the scope of self-attention for each example,
so the model applies attention scores only to the real tokens and ignores the padding positions.

Notice that,
inside the scaled-dot attention function,
key_padding_mask is expanded along a new dimension (dim=1) of size $T$ (so it becomes $N \times T \times S$)
in order to broadcast the masking values over the target positions.

In [68]:
key_padding_mask = sequence_mask(seq, padding_idx=0)
print(key_padding_mask)
print(key_padding_mask.shape)  # [bsz, src_len]
print()
padding_mask = key_padding_mask.unsqueeze(1).expand(-1, T, -1)  # [bsz, trg_len, src_len], which is mask for scaled-dot attention
print(padding_mask[0]), print()  # for example 0
# print(padding_mask[1]), print()  # for example 1

before_softmax = torch.rand(B, T, S).masked_fill(padding_mask, -float('inf'))
print(before_softmax[0])  # for example 0
# print(before_softmax[1])  # for example 1
tensor([[False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True]])
torch.Size([5, 10])

tensor([[False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True]])

tensor([[0.8794, 0.8458, 0.8516, 0.0447, 0.3812,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.1812, 0.4549, 0.7835, 0.8322, 0.6611,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.4658, 0.6365, 0.2442, 0.9905, 0.5118,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.5298, 0.9668, 0.2158, 0.1174, 0.8734,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.9967, 0.6162, 0.7594, 0.6391, 0.3663,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.8592, 0.8073, 0.2411, 0.2279, 0.6485,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.0724, 0.3111, 0.6000, 0.6570, 0.0180,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.7280, 0.6654, 0.3910, 0.3444, 0.9638,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.3494, 0.8876, 0.0506, 0.3111, 0.7949,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.9763, 0.7667, 0.4739, 0.5561, 0.2022,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.4080, 0.7799, 0.9461, 0.6745, 0.3950,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.2704, 0.6876, 0.6525, 0.1416, 0.9703,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.9985, 0.7673, 0.8472, 0.0541, 0.1273,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.5152, 0.4203, 0.3577, 0.1673, 0.4000,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.8417, 0.8949, 0.8560, 0.0523, 0.0803,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.0683, 0.0212, 0.1341, 0.6816, 0.7462,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.9225, 0.5103, 0.5363, 0.0149, 0.8378,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.6542, 0.2225, 0.4332, 0.9008, 0.9145,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.6001, 0.9155, 0.5877, 0.2048, 0.9796,   -inf,   -inf,   -inf,   -inf,
           -inf],
        [0.5510, 0.4078, 0.2633, 0.6340, 0.7613,   -inf,   -inf,   -inf,   -inf,
           -inf]])

2. Look-ahead masking

| mask name | dimension | description |
| --- | --- | --- |
| attn_mask | $T \times S$ | masks attention weights at specific positions, shared across the whole batch |

This masking constrains the scope of self-attention for all examples at once.

Look-ahead masking is used in the transformer decoder; it prevents future target positions from affecting past-step predictions.

In [69]:
attn_mask = look_ahead_mask(S, S)  # used for the decoder
print(attn_mask)  # [trg_len, src_len]
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

Step 2. Multi-head Attention

Q. Why is contiguous() called on [q, k, v]?

  • I think that, because q, k, v come from the same embedding, transpose() gives non-contiguous views, and contiguous() copies each of them into its own contiguous block of memory.
    This is just my opinion, so I will look into the exact reason later on.

Please note that
if key_padding_mask is passed to the forward call, each example’s attention weights become zero beyond its sequence length.
However, out does not contain zeros, because a linear layer is applied after the attention weights.

In [38]:
attn.train(mode=False)
MultiheadAttention(
  (out_proj): Linear(in_features=16, out_features=16, bias=True)
)
In [39]:
key_padding_mask
tensor([[False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True]])

If $q, k, v$ are produced from the same embedding (either the embedding itself or a linear projection of it),
then the sequence lengths of q, k, v are all the same.
In this case, multi-head attention is called self-attention.

Please note that if we use the mask options, some attention weights become zero:

  • key_padding_mask: for each example, the positions where key_padding_mask is True become zero
  • attn_mask: the positions where attn_mask is -inf become zero

I will show below that using the masking options leads to zero attention weights.

In [40]:
Eseq = emb(seq)  # [bsz, src_len, E]
q, k, v = Eseq.transpose(0, 1).contiguous(), Eseq.transpose(0, 1).contiguous(), Eseq.transpose(0, 1).contiguous()
print(q.shape, k.shape, v.shape)  # each is [src_len, bsz, E]
out, weights = attn(q, k, v, key_padding_mask=key_padding_mask, attn_mask=attn_mask)  # out: [src_len, bsz, E], weights: [bsz, tgt_len, src_len]
# out, weights = attn(q, k, v, key_padding_mask=None, attn_mask=attn_mask)  # out: [src_len, bsz, E], weights: [bsz, tgt_len, src_len]
print(out.shape, weights.shape), print()

print(seq.bool().sum(-1))
print(out[0])  
weights[0]  # weights past the sequence length and in the upper triangle are zero
torch.Size([10, 5, 16]) torch.Size([10, 5, 16]) torch.Size([10, 5, 16])
torch.Size([10, 5, 16]) torch.Size([5, 10, 10])

tensor([5, 4, 3, 2, 3])
tensor([[ 0.3376, -0.7838,  0.1124,  0.3856,  1.0881, -0.1623,  0.1072,  0.3590,
         -0.6419, -0.0237, -0.2214,  0.0281,  0.8012, -0.0309,  0.0032, -0.2124],
        [ 0.4690,  0.8251, -0.6962, -0.1230,  0.3652, -0.0564,  0.1260,  0.7676,
         -0.2868,  0.9124,  0.0811,  0.8795, -0.4150, -0.9839, -0.0982,  0.4887],
        [-0.1854,  0.4487,  0.1684, -0.3133, -0.7565,  0.0864, -0.7580, -0.2905,
          0.5740, -0.4998, -0.0651, -0.2894, -0.4343,  0.2809, -0.3145,  0.1744],
        [-0.1030,  0.2608,  0.1130,  0.0331, -0.4297, -0.2702, -0.7032, -0.4559,
         -0.0314, -0.8184,  0.4495, -0.9816,  0.1393,  1.1237, -0.1796, -0.0038],
        [-0.3896, -0.4187, -0.5751,  0.6582, -0.3250, -0.4222, -0.2628, -0.5191,
          0.0186,  0.2831, -0.6457,  0.3321,  0.0746,  1.0277, -0.3817, -0.0720]],
       grad_fn=<SelectBackward>)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.5078, 0.4922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2588, 0.4714, 0.2697, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1556, 0.3744, 0.1879, 0.2821, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2304, 0.1647, 0.2204, 0.1221, 0.2624, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000]], grad_fn=<SelectBackward>)

However, the outputs corresponding to zero attention weights are not zero, because the last step is a linear projection.
Therefore, [memory_key_padding_mask, memory_mask] are needed so that the decoder knows which encoder output positions are padding in a translation model.
I will explain these masks in the decoder part.
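
As a small illustration (a hypothetical probe I added, not part of the original notebook): even a context vector of all zeros is mapped to a non-zero output, because out_proj is a Linear layer with a bias term.

zero_context = torch.zeros(1, E)                       # pretend attention produced an all-zero context
projected = attn.out_proj(zero_context)
print(projected)                                       # non-zero; equals out_proj.bias
print(torch.allclose(projected, attn.out_proj.bias))   # True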


Transformer model

How do we use these masks in the transformer model?

The Transformer layers documentation describes them, but its examples are ambiguous about when each masking should be used.

Let’s do a simple example, but use all of the maskings.

Use Encoder module

For simplicity, I use only one layer of the transformer encoder, so I use nn.TransformerEncoderLayer.

Warning: dropout is used, so call model.train(mode=False) before forward propagation for test purposes.

Please note that the transformer encoder’s src and trg are the same, so scaled-dot attention over the same sequence means self-attention!

Eseq is used as the input sequence, where the dimension of Eseq is [B, S, E].

In [8]:
encoder = torch.nn.TransformerEncoderLayer(d_model=E, nhead=nhead, dim_feedforward=4*E)
encoder.train(mode=False)
TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (linear1): Linear(in_features=16, out_features=64, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=64, out_features=16, bias=True)
  (norm1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)

My view: although it is possible to obtain embeddings with self-attention, it is a pity that the layer does not return the self-attention scores.
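
As a workaround sketch (my own probe, calling the layer's internal self_attn submodule directly; it skips dropout, the residual connection, and layer normalization, so the weights are only indicative):

x = Eseq.transpose(0, 1)  # [S, bsz, E]
_, probe_weights = encoder.self_attn(x, x, x,
                                     attn_mask=torch.zeros(S, S),
                                     key_padding_mask=key_padding_mask)  # probe_weights: [bsz, S, S]
print(probe_weights.shape)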

In [9]:
src_mask = torch.zeros(S, S)  # [S, S]; all zeros, so no position is masked out
print(src_mask.shape)
print(key_padding_mask.shape)  # [bsz, S]
out = encoder(Eseq.transpose(0, 1), src_mask=src_mask, src_key_padding_mask=key_padding_mask)  # [S, bsz, E] 
print(out.shape)
out[0]
torch.Size([10, 10])
torch.Size([5, 10])
torch.Size([10, 5, 16])

tensor([[-1.2512, -1.1156, -1.0019,  0.0576, -0.5861, -1.0735, -0.1735,  0.5430,
          0.0823,  1.6950,  1.4143, -0.3691,  1.7459,  0.2939,  0.9126, -1.1740],
        [-0.6578, -0.4739,  0.8454, -1.3896, -0.5582,  0.8627,  0.2286,  0.2346,
          1.4226,  0.1236, -0.5403, -0.1841,  2.4831, -0.9266, -0.0057, -1.4644],
        [ 0.2509, -1.4570,  2.3970, -1.3271, -1.6661,  0.6993, -0.1136, -0.2787,
         -0.2379,  0.0055,  0.8749,  0.2460, -0.9369,  0.9882,  0.1042,  0.4514],
        [ 1.8043, -0.9929,  0.6597, -0.9512,  0.8177, -0.4477,  0.3470,  0.9663,
         -0.3772, -0.2731, -2.0912, -0.2372,  1.4346, -1.1369,  0.6771, -0.1993],
        [ 0.1023, -0.5971,  0.5727, -1.4589,  0.5111,  0.4181,  0.0157, -2.3575,
          1.7297, -0.1671,  1.0217, -0.6159, -0.6257, -0.4689,  0.5827,  1.3372]],
       grad_fn=<SelectBackward>)
In [44]:
torch.isnan(out).sum()  # sanity check
tensor(0)

Use Decoder module

Q. What is the memory mask? See this discussion in the PyTorch community.
Ans. It is an attention mask working on the second input of the transformer decoder layer.
To be more specific, the memory mask applies to the second multi-head attention block in the decoder,
i.e. the look-ahead mask between the input sequence and the target sequence.
Within the encoder-decoder architecture, it works on the output of the transformer encoder, which we call “memory”.

The encoder outputs embeddings with self-attention, called memory (or context) vectors, and passes them to the decoder.
However, which positions in the memory should the decoder focus on?
That is why [memory_key_padding_mask, memory_mask] are needed.

To be more specific, they are used in the second multi-head attention block of the decoder:

  • memory_key_padding_mask is just the encoder’s key_padding_mask, of shape [N, S].
  • memory_mask is the look-ahead mask from the decoder queries to the encoder key/value (memory), of shape [T, S].
In [70]:
mem_mask = look_ahead_mask(T, S)
mem_mask  # [T, S]
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
In [71]:
decoder = nn.TransformerDecoderLayer(d_model=E, nhead=nhead, dim_feedforward=4*E)
decoder.train(mode=False)
TransformerDecoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (multihead_attn): MultiheadAttention(
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (linear1): Linear(in_features=16, out_features=64, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=64, out_features=16, bias=True)
  (norm1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (norm3): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
  (dropout3): Dropout(p=0.1, inplace=False)
)

Generate target sequence

In [72]:
trg = torch.LongTensor([[random.randint(1, V - 1) for _ in range(T)] for _ in range(N)])  # [bsz, trg_len]
for b in range(N):
    trg[b][random.randint(T//5, T - 3):] = 0 
print(trg)  # [bsz, T]

Etrg = emb(trg)  # [bsz, T, E]
Etrg.shape
tensor([[34, 71, 38, 71,  3, 74, 24, 30, 55, 30, 56, 20,  0,  0,  0,  0,  0,  0,
          0,  0],
        [97, 88, 23, 98, 23, 98,  3,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [95,  2, 45, 31, 40, 10, 11, 71, 96, 32,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [47, 10, 62, 76, 24, 36, 93,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [95, 37, 57, 75,  3, 53, 24, 85,  4, 33, 23, 41, 83, 74, 18, 26, 51,  0,
          0,  0]])

torch.Size([5, 20, 16])
Following the PyTorch documentation:

| name | dimension |
| --- | --- |
| src | $(S, N, E)$ |
| tgt | $(T, N, E)$ |
| src_mask | $(S, S)$ |
| tgt_mask | $(T, T)$ |
| memory_mask | $(T, S)$ |
| src_key_padding_mask | $(N, S)$ |
| tgt_key_padding_mask | $(N, T)$ |
| memory_key_padding_mask | $(N, S)$ |

Note that memory_key_padding_mask is the same as src_key_padding_mask.
This is because the key and value of the decoder’s second multi-head attention block are the output of the encoder.
Therefore, the encoder’s memory (or context) information is delivered to the connected decoder as [key, value, and key_padding_mask].

In [75]:
memory = encoder(Eseq.transpose(0, 1), src_mask=src_mask, src_key_padding_mask=key_padding_mask)  # [S, bsz, E]
print(memory.shape)
output = decoder(Etrg.transpose(0, 1), memory=memory, 
                tgt_mask=look_ahead_mask(T, T), memory_mask=mem_mask, 
                 tgt_key_padding_mask=sequence_mask(trg, padding_idx=0), 
                memory_key_padding_mask=key_padding_mask) 
print(output.shape)  # [T, bsz, E]
torch.Size([10, 5, 16])
torch.Size([20, 5, 16])

In [76]:
out = output.transpose(0, 1)  # [bsz, T, E]
out[0]
tensor([[ 2.0530e-01,  1.5013e+00,  1.0501e+00,  3.5751e-01,  6.3936e-01,
          9.2592e-01,  6.6488e-01, -2.5640e-01, -4.9700e-01,  6.8812e-01,
         -1.7615e+00,  1.0349e+00, -8.9191e-01, -1.6123e+00, -6.2900e-01,
         -1.4193e+00],
        [-4.4284e-01,  1.9421e+00,  6.1783e-01, -4.9344e-01,  2.1229e+00,
          4.7583e-02,  1.6725e-01,  6.9770e-01, -1.1447e+00, -1.4100e+00,
         -7.2436e-01, -1.3630e+00,  6.1289e-01, -3.1280e-01, -4.5737e-01,
          1.4021e-01],
        [ 3.5275e-01, -5.5859e-01, -9.4811e-01,  3.3188e-01, -1.1998e-03,
         -6.0703e-01, -7.5305e-01,  2.1518e+00, -1.2262e+00, -9.4202e-02,
         -1.0129e+00,  4.2685e-01, -1.4186e+00,  1.6577e+00,  1.0287e+00,
          6.7026e-01],
        [-5.0308e-01,  1.9444e+00,  6.1748e-01, -3.7436e-01,  1.9657e+00,
          1.6569e-01,  3.2078e-02,  8.3904e-01, -1.5390e+00, -1.4493e+00,
         -6.0518e-01, -1.2885e+00,  4.4276e-01, -2.1376e-01, -2.5178e-01,
          2.1766e-01],
        [ 1.0492e+00,  1.5197e+00,  6.3454e-01,  1.0607e-01, -6.4794e-01,
          1.0796e+00, -5.2467e-01, -1.7117e-01,  1.5651e-01, -2.3230e+00,
         -1.2304e+00,  2.5204e-01, -2.5927e-01, -3.0521e-01, -8.7363e-01,
          1.5376e+00],
        [ 1.1820e+00, -1.0573e+00,  1.7224e+00,  1.1034e-01,  3.5638e-01,
          2.0074e-01, -2.0558e+00,  5.5258e-01,  7.8305e-01, -6.4929e-01,
         -1.4155e+00, -1.1916e-02, -3.8352e-01,  1.5048e+00, -4.7025e-01,
         -3.6871e-01],
        [ 1.8280e+00,  3.6645e-01,  1.6182e-01, -1.4546e+00,  3.5153e-01,
         -1.2442e+00, -1.1728e+00, -1.1383e+00,  1.3857e+00,  3.4732e-01,
         -3.3480e-01, -3.3509e-01,  1.7011e+00, -1.5128e-01,  4.8057e-01,
         -7.9143e-01],
        [ 8.4219e-01, -2.2292e-01,  1.6207e+00, -1.5880e-01, -1.9941e+00,
         -2.6402e-01, -1.0795e-01, -1.2542e+00, -3.7832e-01, -1.3619e+00,
         -3.0050e-01,  6.4807e-01,  8.8309e-01,  7.8289e-01, -4.2921e-01,
          1.6950e+00],
        [-5.2335e-01, -6.4032e-01, -6.4664e-01,  6.1068e-01, -5.7863e-02,
          2.1731e+00, -1.5652e+00,  3.1220e-01,  2.7384e-01, -1.1730e+00,
         -1.1549e-01,  1.0999e-02, -1.4816e+00,  1.5529e+00,  1.0676e+00,
          2.0224e-01],
        [ 8.0632e-01, -2.7476e-01,  1.7004e+00, -3.7164e-01, -1.9823e+00,
         -3.1552e-01, -8.3502e-02, -1.1603e+00, -4.2286e-01, -1.2626e+00,
         -1.8362e-01,  7.0035e-01,  9.3423e-01,  6.5968e-01, -4.8287e-01,
          1.7390e+00],
        [-1.2699e-01,  1.3019e+00,  2.5125e-01, -4.2549e-01,  1.7312e+00,
         -1.7435e+00,  1.8060e-01,  5.0046e-01, -8.4960e-01, -2.8272e-01,
         -1.3480e+00,  2.2478e-01,  2.4194e-01,  8.5712e-01, -1.6936e+00,
          1.1806e+00],
        [ 3.9400e-01,  1.7380e+00, -7.8590e-01,  4.9670e-01,  4.0187e-01,
          1.5195e+00, -1.9596e+00, -8.9633e-01, -1.1844e+00, -1.5570e-01,
         -6.3172e-01, -8.9354e-02, -1.0397e+00,  1.0618e+00,  8.5174e-01,
          2.7912e-01],
        [ 9.2153e-01,  4.3841e-02,  3.8642e-01,  7.7932e-01, -6.8301e-02,
          3.4165e-01, -1.3100e+00, -1.2595e+00, -9.2050e-01, -2.1033e+00,
          3.6177e-01,  1.3866e+00,  5.1212e-01,  1.5935e+00, -8.8966e-01,
          2.2454e-01],
        [ 9.2153e-01,  4.3841e-02,  3.8642e-01,  7.7932e-01, -6.8301e-02,
          3.4165e-01, -1.3100e+00, -1.2595e+00, -9.2050e-01, -2.1033e+00,
          3.6177e-01,  1.3866e+00,  5.1212e-01,  1.5935e+00, -8.8966e-01,
          2.2454e-01],
        [ 9.2153e-01,  4.3841e-02,  3.8642e-01,  7.7932e-01, -6.8301e-02,
          3.4165e-01, -1.3100e+00, -1.2595e+00, -9.2050e-01, -2.1033e+00,
          3.6177e-01,  1.3866e+00,  5.1212e-01,  1.5935e+00, -8.8966e-01,
          2.2454e-01],
        [ 9.2153e-01,  4.3841e-02,  3.8642e-01,  7.7932e-01, -6.8301e-02,
          3.4165e-01, -1.3100e+00, -1.2595e+00, -9.2050e-01, -2.1033e+00,
          3.6177e-01,  1.3866e+00,  5.1212e-01,  1.5935e+00, -8.8966e-01,
          2.2454e-01],
        [ 9.2153e-01,  4.3841e-02,  3.8642e-01,  7.7932e-01, -6.8301e-02,
          3.4165e-01, -1.3100e+00, -1.2595e+00, -9.2050e-01, -2.1033e+00,
          3.6177e-01,  1.3866e+00,  5.1212e-01,  1.5935e+00, -8.8966e-01,
          2.2454e-01],
        [ 9.2153e-01,  4.3841e-02,  3.8642e-01,  7.7932e-01, -6.8301e-02,
          3.4165e-01, -1.3100e+00, -1.2595e+00, -9.2050e-01, -2.1033e+00,
          3.6177e-01,  1.3866e+00,  5.1212e-01,  1.5935e+00, -8.8966e-01,
          2.2454e-01],
        [ 9.2153e-01,  4.3841e-02,  3.8642e-01,  7.7932e-01, -6.8301e-02,
          3.4165e-01, -1.3100e+00, -1.2595e+00, -9.2050e-01, -2.1033e+00,
          3.6177e-01,  1.3866e+00,  5.1212e-01,  1.5935e+00, -8.8966e-01,
          2.2454e-01],
        [ 9.2153e-01,  4.3841e-02,  3.8642e-01,  7.7932e-01, -6.8301e-02,
          3.4165e-01, -1.3100e+00, -1.2595e+00, -9.2050e-01, -2.1033e+00,
          3.6177e-01,  1.3866e+00,  5.1212e-01,  1.5935e+00, -8.8966e-01,
          2.2454e-01]], grad_fn=<SelectBackward>)

Warning: the decoder does not output attention weights, so it is hard to figure out which parts were masked.
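
As a rough probe (again my own sketch, calling the layer's internal multihead_attn submodule directly; it bypasses the decoder's first self-attention block, residuals and layer norms, so it only shows the masking pattern, not the exact weights used inside the layer):

_, cross_weights = decoder.multihead_attn(Etrg.transpose(0, 1), memory, memory,
                                          attn_mask=mem_mask,
                                          key_padding_mask=key_padding_mask)  # [bsz, T, S]
print(cross_weights.shape)
print(cross_weights[0])  # zeros exactly where mem_mask / memory_key_padding_mask apply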

In [77]:
trg_lengths = trg.bool().sum(-1)
print(trg_lengths)  # trg sequence lengths
out[0][:trg_lengths[0] + 3]  # beyond the trg sequence length the vectors are identical, but not zeros
tensor([12,  8, 10,  7, 17])

tensor([[ 2.0530e-01,  1.5013e+00,  1.0501e+00,  3.5751e-01,  6.3936e-01,
          9.2592e-01,  6.6488e-01, -2.5640e-01, -4.9700e-01,  6.8812e-01,
         -1.7615e+00,  1.0349e+00, -8.9191e-01, -1.6123e+00, -6.2900e-01,
         -1.4193e+00],
        [-4.4284e-01,  1.9421e+00,  6.1783e-01, -4.9344e-01,  2.1229e+00,
          4.7583e-02,  1.6725e-01,  6.9770e-01, -1.1447e+00, -1.4100e+00,
         -7.2436e-01, -1.3630e+00,  6.1289e-01, -3.1280e-01, -4.5737e-01,
          1.4021e-01],
        [ 3.5275e-01, -5.5859e-01, -9.4811e-01,  3.3188e-01, -1.1998e-03,
         -6.0703e-01, -7.5305e-01,  2.1518e+00, -1.2262e+00, -9.4202e-02,
         -1.0129e+00,  4.2685e-01, -1.4186e+00,  1.6577e+00,  1.0287e+00,
          6.7026e-01],
        [-5.0308e-01,  1.9444e+00,  6.1748e-01, -3.7436e-01,  1.9657e+00,
          1.6569e-01,  3.2078e-02,  8.3904e-01, -1.5390e+00, -1.4493e+00,
         -6.0518e-01, -1.2885e+00,  4.4276e-01, -2.1376e-01, -2.5178e-01,
          2.1766e-01],
        [ 1.0492e+00,  1.5197e+00,  6.3454e-01,  1.0607e-01, -6.4794e-01,
          1.0796e+00, -5.2467e-01, -1.7117e-01,  1.5651e-01, -2.3230e+00,
         -1.2304e+00,  2.5204e-01, -2.5927e-01, -3.0521e-01, -8.7363e-01,
          1.5376e+00],
        [ 1.1820e+00, -1.0573e+00,  1.7224e+00,  1.1034e-01,  3.5638e-01,
          2.0074e-01, -2.0558e+00,  5.5258e-01,  7.8305e-01, -6.4929e-01,
         -1.4155e+00, -1.1916e-02, -3.8352e-01,  1.5048e+00, -4.7025e-01,
         -3.6871e-01],
        [ 1.8280e+00,  3.6645e-01,  1.6182e-01, -1.4546e+00,  3.5153e-01,
         -1.2442e+00, -1.1728e+00, -1.1383e+00,  1.3857e+00,  3.4732e-01,
         -3.3480e-01, -3.3509e-01,  1.7011e+00, -1.5128e-01,  4.8057e-01,
         -7.9143e-01],
        [ 8.4219e-01, -2.2292e-01,  1.6207e+00, -1.5880e-01, -1.9941e+00,
         -2.6402e-01, -1.0795e-01, -1.2542e+00, -3.7832e-01, -1.3619e+00,
         -3.0050e-01,  6.4807e-01,  8.8309e-01,  7.8289e-01, -4.2921e-01,
          1.6950e+00],
        [-5.2335e-01, -6.4032e-01, -6.4664e-01,  6.1068e-01, -5.7863e-02,
          2.1731e+00, -1.5652e+00,  3.1220e-01,  2.7384e-01, -1.1730e+00,
         -1.1549e-01,  1.0999e-02, -1.4816e+00,  1.5529e+00,  1.0676e+00,
          2.0224e-01],
        [ 8.0632e-01, -2.7476e-01,  1.7004e+00, -3.7164e-01, -1.9823e+00,
         -3.1552e-01, -8.3502e-02, -1.1603e+00, -4.2286e-01, -1.2626e+00,
         -1.8362e-01,  7.0035e-01,  9.3423e-01,  6.5968e-01, -4.8287e-01,
          1.7390e+00],
        [-1.2699e-01,  1.3019e+00,  2.5125e-01, -4.2549e-01,  1.7312e+00,
         -1.7435e+00,  1.8060e-01,  5.0046e-01, -8.4960e-01, -2.8272e-01,
         -1.3480e+00,  2.2478e-01,  2.4194e-01,  8.5712e-01, -1.6936e+00,
          1.1806e+00],
        [ 3.9400e-01,  1.7380e+00, -7.8590e-01,  4.9670e-01,  4.0187e-01,
          1.5195e+00, -1.9596e+00, -8.9633e-01, -1.1844e+00, -1.5570e-01,
         -6.3172e-01, -8.9354e-02, -1.0397e+00,  1.0618e+00,  8.5174e-01,
          2.7912e-01],
        [ 9.2153e-01,  4.3841e-02,  3.8642e-01,  7.7932e-01, -6.8301e-02,
          3.4165e-01, -1.3100e+00, -1.2595e+00, -9.2050e-01, -2.1033e+00,
          3.6177e-01,  1.3866e+00,  5.1212e-01,  1.5935e+00, -8.8966e-01,
          2.2454e-01],
        [ 9.2153e-01,  4.3841e-02,  3.8642e-01,  7.7932e-01, -6.8301e-02,
          3.4165e-01, -1.3100e+00, -1.2595e+00, -9.2050e-01, -2.1033e+00,
          3.6177e-01,  1.3866e+00,  5.1212e-01,  1.5935e+00, -8.8966e-01,
          2.2454e-01],
        [ 9.2153e-01,  4.3841e-02,  3.8642e-01,  7.7932e-01, -6.8301e-02,
          3.4165e-01, -1.3100e+00, -1.2595e+00, -9.2050e-01, -2.1033e+00,
          3.6177e-01,  1.3866e+00,  5.1212e-01,  1.5935e+00, -8.8966e-01,
          2.2454e-01]], grad_fn=<SliceBackward>)
In [78]:
torch.isnan(out).sum()
tensor(0)

Summary

Masking helps multi-head attention behave properly. Here is a summary of all the maskings in the transformer module.

# for self-attention masking
def sequence_mask(seq:torch.LongTensor, padding_idx:int=None) -> torch.BoolTensor:
    """ seq: [bsz, slen], a padded batch of sequences, so padding_idx may appear.
    Positions marked True will be filled with '-inf' before the scaled-dot attention softmax. """
    return seq == padding_idx

# for the decoder's look-ahead masking
def look_ahead_mask(tgt_len:int, src_len:int) -> torch.FloatTensor:
    """ This mask is added to the attention scores before the softmax, so '-inf' is placed at the masked positions.
    Look-ahead masking is used in the transformer decoder;
    it prevents future target positions from affecting past-step predictions. """
    mask = torch.triu(torch.ones(tgt_len, src_len), diagonal=1)
    mask[mask.bool()] = -float('inf')
    return mask

Masking Summary for Translation Model

| used for | key_padding_mask | attn_mask |
| --- | --- | --- |
| Encoder (self-attn) | src_key_padding_mask $(N, S)$ = sequence_mask(seq, padding_idx=0) | src_mask $(S, S)$ = torch.zeros(S, S) |
| Decoder (cross-attn, 2nd part) | memory_key_padding_mask $(N, S)$ = sequence_mask(seq, padding_idx=0) | memory_mask $(T, S)$ = look_ahead_mask(T, S) |
| Decoder (self-attn, 1st part) | tgt_key_padding_mask $(N, T)$ = sequence_mask(trg, padding_idx=0) | tgt_mask $(T, T)$ = look_ahead_mask(T, T) |

Note that the (Encoder, key_padding_mask) cell and the (Decoder 2nd part, key_padding_mask) cell are actually the same.
This is because the encoder outputs the context vectors,
and memory_key_padding_mask tells the decoder which positions of those context vectors belong to the real (non-padding) sequence.
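
As a consolidated sketch (my own example built on the variables and helper functions above; names like full_model are hypothetical), nn.Transformer accepts all of these masks in a single forward call:

full_model = nn.Transformer(d_model=E, nhead=nhead, num_encoder_layers=1,
                            num_decoder_layers=1, dim_feedforward=4*E)
full_model.train(mode=False)  # disable dropout for a deterministic check
full_out = full_model(Eseq.transpose(0, 1), Etrg.transpose(0, 1),   # src: [S, N, E], tgt: [T, N, E]
                      src_mask=torch.zeros(S, S),                   # encoder self-attn: nothing masked
                      tgt_mask=look_ahead_mask(T, T),               # decoder self-attn: look-ahead
                      memory_mask=look_ahead_mask(T, S),            # decoder cross-attn: look-ahead
                      src_key_padding_mask=sequence_mask(seq, padding_idx=0),
                      tgt_key_padding_mask=sequence_mask(trg, padding_idx=0),
                      memory_key_padding_mask=sequence_mask(seq, padding_idx=0))
print(full_out.shape)               # [T, N, E]
print(torch.isnan(full_out).sum())  # sanity check (should be tensor(0) here)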
