
Question about transformer decoder #63

Open
CoinCheung opened this issue Mar 27, 2022 · 2 comments

Comments

@CoinCheung

Hi,

I am trying to learn about the code, and I found the following line:

tgt = torch.zeros_like(query_embed)

The input tgt of the decoder is all zeros, and I see this all-zeros tensor used as input in the decoder layer:
q = k = self.with_pos_embed(tgt, query_pos)

Here tgt is all zeros and query_pos is a learnable embedding, so q and k end up equal in value to query_pos (non-zero), while tgt itself is still all zeros and is used as v. According to the computation rule of qkv attention, if v is all zeros, the output of the attention is also all zeros. Thus the self-attention module does not contribute to the model. Am I correct about this?
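
A toy check of that claim outside the repo (sizes are made up, and I pass bias=False so the projection biases do not add a constant offset):

```python
import torch
import torch.nn as nn

num_queries, batch_size, embed_dim = 100, 2, 256   # illustrative sizes only

attn = nn.MultiheadAttention(embed_dim, num_heads=8, bias=False)

tgt       = torch.zeros(num_queries, batch_size, embed_dim)   # all-zero decoder input
query_pos = torch.randn(num_queries, batch_size, embed_dim)   # stand-in for the learnable query embedding

q = k = tgt + query_pos        # what with_pos_embed(tgt, query_pos) computes
out, _ = attn(q, k, value=tgt) # value is the all-zero tgt

print(out.abs().max().item())  # 0.0 -- softmax(QK^T) applied to an all-zero V gives zero
```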

@bowenc0221
Contributor

This is correct only for the first self-attention layer. tgt is no longer a zero vector after cross-attention.
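
To spell this out, here is a rough sketch of the first decoder layer's data flow in a DETR-style layer ordering (simplified: norms, dropout and the FFN are omitted, and all names and sizes are placeholders, not the exact code in this repo):

```python
import torch
import torch.nn as nn

# Illustrative sizes only
L, S, N, E = 100, 400, 2, 256   # num queries, memory length (H*W), batch, channels

self_attn  = nn.MultiheadAttention(E, num_heads=8, bias=False)
cross_attn = nn.MultiheadAttention(E, num_heads=8, bias=False)

tgt       = torch.zeros(L, N, E)   # decoder input, all zeros
query_pos = torch.randn(L, N, E)   # stand-in for the learnable query embedding
memory    = torch.randn(S, N, E)   # stand-in for the encoder / pixel features
pos       = torch.randn(S, N, E)   # stand-in for the positional encoding of memory

# 1) Self-attention: the value is the all-zero tgt, so the residual update is zero
#    (with projection biases it would only add a constant, input-independent offset).
q = k = tgt + query_pos
tgt = tgt + self_attn(q, k, value=tgt)[0]
print(tgt.abs().max().item())   # 0.0 -- still zero after the first self-attention

# 2) Cross-attention: the value is the non-zero memory, so tgt becomes non-zero here.
q = tgt + query_pos
tgt = tgt + cross_attn(q, memory + pos, value=memory)[0]
print(tgt.abs().max().item())   # clearly non-zero after cross-attention
```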

@CoinCheung
Author

CoinCheung commented Apr 5, 2022

Thanks for replying! There is another part of the code that I cannot understand:

return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

If we use the default batch_first=False for nn.MultiheadAttention, the above hs tensor should be LNE, where L is the sequence length (the number of queries here), N is the batch size and E is the feature dimension. After transpose(1, 2), hs would become LEN, with the batch size as the last dimension.
However, according to this line:
outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)

The output hs should be a 4D tensor? Would you please tell me what I missed here?
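
For reference, these are the kinds of shapes that einsum expects (the sizes below are made up just to check the contraction):

```python
import torch

# The einsum needs mask_embed to be 4-D (l, b, q, c) and mask_features to be
# 4-D (b, c, h, w); all sizes here are placeholders.
mask_embed    = torch.randn(6, 2, 100, 256)
mask_features = torch.randn(2, 256, 32, 32)

outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
print(outputs_seg_masks.shape)   # torch.Size([6, 2, 100, 32, 32])
```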
