-
Notifications
You must be signed in to change notification settings - Fork 0
/
dual_attention_transformer.py
1554 lines (1310 loc) · 69.3 KB
/
dual_attention_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
This file contains a self-contained single-file implementation of DualAttention and the DualAttention Transformer
as proposed in the paper:
"Disentangling and Integrating Relational and Sensory Information in Transformer Architectures"
Awni Altabaa, John Lafferty (2024). https://arxiv.org/abs/2405.16727
Author: Awni Altabaa
License: MIT License
"""
import torch
import torch.nn as nn
import math
from einops import rearrange
# An implementation of Dual Attention as proposed in the paper
# "Disentangling and Integrating Relational and Sensory Information in Transformer Architectures"
# Awni Altabaa, John Lafferty (2024). https://arxiv.org/abs/2405.16727
# The DualAttention module is a form of multi-head attention involving a composition of two distinct types of attention heads.
# The first type is standard self-attention, which captures object-level (i.e., sensory) features, and
# the second type is relational attention, which captures relational features.
# DualAttention is a concatenation of self-attention and relational attention heads.
class DualAttention(nn.Module):
    def __init__(self,
            d_model: int,
            n_heads_sa: int,
            n_heads_ra: int,
            dropout: float,
            sa_kwargs: dict = None,
            ra_kwargs: dict = None,
            share_attn_params: bool = False,
            ra_type: str = 'relational_attention'
            ):
        """An implementation of Dual Attention.

        The DualAttention module is a form of multi-head attention involving a composition of two distinct types of attention heads.
        The first type is standard self-attention, which captures object-level (i.e., sensory) features, and
        the second type is relational attention, which captures relational features.
        The outputs of the two groups of heads are concatenated along the feature dimension.

        Parameters
        ----------
        d_model : int
            model dimension
        n_heads_sa : int
            number of self-attention heads (0 disables self-attention)
        n_heads_ra : int
            number of relational attention heads (0 disables relational attention)
        dropout : float
            dropout rate
        sa_kwargs : dict, optional
            self-attention kwargs, by default None
        ra_kwargs : dict, optional
            relational attention kwargs, by default None
        share_attn_params : bool, optional
            whether to share attention parameters between self-attention and relational attention.
            If True, w{q,k} in sensory attention and w{q,k}_attn in relational attention are shared.
            number of heads in each must be the same. By default False
        ra_type : str, optional
            type of relational attention module. Only 'relational_attention' is implemented here, and the value
            is only checked when n_heads_ra > 0. By default 'relational_attention'.

        Raises
        ------
        ValueError
            if share_attn_params is True and n_heads_sa != n_heads_ra; if neither n_heads_sa nor n_heads_ra
            is positive; or if n_heads_ra > 0 and ra_type is not supported.
        """
        super(DualAttention, self).__init__()
        self.d_model = d_model
        self.n_heads_sa = n_heads_sa
        self.n_heads_ra = n_heads_ra
        self.dropout = dropout
        self.sa_kwargs = sa_kwargs if sa_kwargs is not None else {}
        self.ra_kwargs = ra_kwargs if ra_kwargs is not None else {}
        self.ra_type = ra_type
        self.share_attn_params = share_attn_params

        if self.share_attn_params and n_heads_sa != n_heads_ra:
            raise ValueError("Number of heads in self-attention and relational attention must be the same if sharing attention parameters")

        self.use_self_attn = n_heads_sa > 0
        self.use_rel_attn = n_heads_ra > 0
        self.total_n_heads = n_heads_sa + n_heads_ra

        if not (self.use_self_attn or self.use_rel_attn):
            raise ValueError("At least one of self-attention or relational attention must be used")

        # ra_type is only meaningful when relational heads are requested; checking it unconditionally would make
        # a self-attention-only configuration (n_heads_ra=0) impossible to construct.
        # Other variants (e.g. 'rca', 'disrca' ablations) are not implemented in this file.
        if self.use_rel_attn and ra_type != 'relational_attention':
            raise ValueError(f"Invalid relational attention type: {ra_type}")

        if self.use_self_attn:
            self.self_attention = Attention(
                d_model=d_model, n_heads=n_heads_sa,
                total_n_heads=self.total_n_heads, dropout=dropout,
                **self.sa_kwargs)

        if self.use_rel_attn:
            self.relational_attention = RelationalAttention(
                d_model=d_model, n_heads=n_heads_ra,
                total_n_heads=self.total_n_heads, dropout=dropout,
                **self.ra_kwargs)

        if self.share_attn_params:
            # tie self-attention's query/key projections to the projections that produce
            # relational attention's attention scores (both modules exist here: head counts are equal and > 0)
            self.self_attention.wq = self.relational_attention.wq_attn
            self.self_attention.wk = self.relational_attention.wk_attn

    def forward(
            self,
            x: torch.Tensor,
            symbols: torch.Tensor,
            attn_mask: torch.Tensor = None, # boolean attention mask: True indicates corresponding position *should* be attended to
            is_causal: bool = False, # indicates causal mask; should only set one of is_causal and attn_mask
            freqs_cos: torch.Tensor = None,
            freqs_sin: torch.Tensor = None,
            need_weights: bool = False # applies only to self-attention; determines whether FlashAttention is used or not
            ):
        """
        Apply dual attention (self-attention heads and relational attention heads) to x.

        Parameters
        ----------
        x : torch.Tensor
            input tensor; presumably [bsz, len, d_model] (TODO confirm against Attention).
        symbols : torch.Tensor
            symbols consumed by relational attention (see RelationalAttention.forward).
        attn_mask : torch.Tensor, optional
            boolean mask forwarded to both attention modules; True means "may attend".
        is_causal : bool, optional
            forwarded to both attention modules; set at most one of is_causal and attn_mask.
        freqs_cos, freqs_sin : torch.Tensor, optional
            RoPE frequencies forwarded to both attention modules.
        need_weights : bool, optional
            forwarded to the self-attention module only.

        Returns
        -------
        tuple
            (out, self_attn_scores, rel_attn_scores).
            out is the concatenation (last dim) of the self-attention and relational attention outputs,
            or just one of them when the other kind of head is disabled.
            self_attn_scores is the second output of the self-attention module, or None if it is disabled.
            rel_attn_scores is a list of the remaining outputs of the relational attention module
            (for RelationalAttention: [attention scores, relations]), or None if it is disabled.
        """
        self_attn_scores = None
        rel_attn_scores = None

        # self-attention (sensory heads)
        if self.use_self_attn:
            self_attn_out, self_attn_scores = self.self_attention(
                query=x, key=x, value=x,
                freqs_cos=freqs_cos, freqs_sin=freqs_sin,
                attn_mask=attn_mask, is_causal=is_causal,
                need_weights=need_weights)

        # relational attention (relational heads)
        if self.use_rel_attn:
            rel_attn_out, *rel_attn_scores = self.relational_attention(
                x, symbols,
                attn_mask=attn_mask, is_causal=is_causal,
                freqs_cos=freqs_cos, freqs_sin=freqs_sin)

        # combine: concat self-attention output (E) and relational attention output (A)
        if self.use_self_attn and self.use_rel_attn:
            out = torch.concat((self_attn_out, rel_attn_out), dim=-1)
        elif self.use_rel_attn:
            out = rel_attn_out
        else:
            out = self_attn_out

        return out, self_attn_scores, rel_attn_scores
# Implementation of RelationalAttention as proposed in
# > "Disentangling and Integrating Relational and Sensory Information in Transformer Architectures"
# > Awni Altabaa, John Lafferty (2024). https://arxiv.org/abs/2405.16727
# Relational attention defines a differentiable information-retrieval operation where the information retrieved
# is the relations between objects. The "message" sent from one object to another is the relations between the
# sender and the receiver, tagged with a symbol identifying the sender. These messages are aggregated based on the
# receiver's features via softmax attention scores.
# Relational attention takes the form
# Math: \mathrm{RelAttn}(x_1, ..., x_n) = \sum_{j} \alpha_{ij} (r(x_i, x_j) W_r + s_j W_s)
# Math: \alpha = \mathrm{Softmax}((x W_q^{attn}) (x W_k^{attn})^\intercal)
# Math: r(x_i, x_j) = (\langle x_i W_{q, \ell}^{rel}, x_j W_{k, \ell}^{rel}\rangle)_{\ell \in [d_r]}
# Math: (s_1, ..., s_n) = \mathrm{SymbolRetriever}(x_1, ..., x_n)
# TODO: add support for sharing single key-proj for all relations (similar to MQA)
# TODO: should default rel_proj_dim be s.t. rel_proj_dim = head_dim * n_h^ra // n_relations?
# (so that param count is constant as n_relations varies)
class RelationalAttention(nn.Module):
    def __init__(self,
            d_model: int,
            n_heads: int,
            n_relations: int = None,
            dropout: float = 0.0,
            key_dim: int = None,
            n_kv_heads: int = None,
            rel_activation: str = 'identity',
            rel_proj_dim: int = None,
            add_bias_kv: bool = False,
            add_bias_out: bool = False,
            total_n_heads: int = None,
            symmetric_rels: bool = False,
            use_relative_positional_symbols: bool = False
            ):
        """
        An implementation of Relational Attention (RA).

        Relational attention defines a differentiable information-retrieval operation where the information retrieved
        is the relations between objects. The "message" sent from one object to another is the relations between the
        sender and the receiver, tagged with a symbol identifying the sender. These messages are aggregated based on the
        receiver's features via softmax attention scores.

        The learnable parameters include a set of query/key projections which determine the attention scores, and hence
        the ``selection criteria'', as well as a set of query/key projections for computing relations between objects.
        They also include per-head projections for the symbols and relations, as well as a final output projection.

        This module supports symmetric relations, position-relative symbolic embeddings,
        multi-query attention/grouped query attention, and control over total number of heads (for use with "dual attention").

        Parameters
        ----------
        d_model : int
            model dimension
        n_heads : int
            number of attention heads (query heads if n_kv_heads is set)
        n_relations : int, optional
            number of relations. If None, n_relations = n_heads. By default None
        dropout : float, optional
            dropout rate, applied to the attention scores and to the output. By default 0.0
        key_dim : int, optional
            per-head dimension of the attention query/key projections. If None, key_dim = head_dim. By default None
        n_kv_heads : int, optional
            number of key/value heads. used to implement multi-query attention or grouped query attention.
            n_kv_heads=1 corresponds to MQA, n_kv_heads > 1 corresponds to grouped query attention.
            n_kv_heads=n_heads is standard MHA. uses MHA when None. By default None
        rel_activation : str, optional
            name of activation function applied to relations. By default 'identity'.
        rel_proj_dim : int, optional
            dimension of relation projections. If None, rel_proj_dim = head_dim * n_heads // n_relations,
            which equals head_dim when n_relations = n_heads. By default None.
        add_bias_kv : bool, optional
            whether to use bias in key/value projections, by default False
        add_bias_out : bool, optional
            whether to use bias in out projection, by default False
        total_n_heads : int, optional
            total number of heads in dual attention (if using dual attention).
            used to ensure that concat(A, E) is of dimension d_model after concatenation.
            hence, output dimension is (d_model // total_heads) * n_heads.
            if None, total_heads = n_heads and output dimension is d_model
        symmetric_rels : bool, optional
            if True, the relation query and key projections share weights (wk_rel is wq_rel), so the
            pre-activation relations satisfy r(x_i, x_j) = r(x_j, x_i). By default False
        use_relative_positional_symbols : bool, optional
            if True, the symbols passed to forward are position-relative embeddings of shape [len, len, d_model]
            instead of per-object symbols of shape [bsz, len, d_model]. By default False
        """
        super().__init__()
        self.d_model = d_model # model dimension
        self.n_heads = n_heads # number of heads (for query)
        self.n_relations = n_relations if n_relations is not None else n_heads # number of relations
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads # n_kv_heads = 1 corresponds to multi-query attn
        self.rel_activation = rel_activation # "relation activation function"
        self.rel_activation_ = get_activation_function(rel_activation)
        self.symmetric_rels = symmetric_rels # whether to use symmetric relations
        self.dropout = dropout
        self.add_bias_kv = add_bias_kv # whether to add bias to key/value projections
        self.add_bias_out = add_bias_out # whether to add bias to output projection
        self.use_relative_positional_symbols = use_relative_positional_symbols # whether to use relative positional symbols
        self.total_n_heads = n_heads if total_n_heads is None else total_n_heads # total number of heads in dual attention
        self.head_dim = self.d_model // self.total_n_heads # dim of projections
        self.n_rep_kv = self.n_heads // self.n_kv_heads # use same kv heads for several query heads
        self.key_dim = key_dim if key_dim is not None else self.head_dim # key dimension
        # default chosen so that the size check below holds for any n_relations dividing n_heads * head_dim;
        # when n_relations == n_heads this is head_dim, as before
        if rel_proj_dim is None:
            rel_proj_dim = (self.head_dim * self.n_heads) // self.n_relations
        self.rel_proj_dim = rel_proj_dim # dimension of relation projections

        # make sure relative sizes of parameters and dimensions make sense
        assert self.n_heads % self.n_kv_heads == 0, f"n_heads={self.n_heads}, n_kv_heads = {self.n_kv_heads}"
        assert self.n_rep_kv * self.n_kv_heads == self.n_heads, f"n_rep_kv={self.n_rep_kv}, n_kv_heads={self.n_kv_heads}, n_heads={self.n_heads}"
        assert self.total_n_heads * self.head_dim == self.d_model, f"total_n_heads={self.total_n_heads}, head_dim={self.head_dim}, d_model={self.d_model}"
        assert self.rel_proj_dim * self.n_relations == self.head_dim * self.n_heads, f"rel_proj_dim={self.rel_proj_dim}, n_relations={self.n_relations}, head_dim={self.head_dim}"

        # NOTE(review): attention logits contract over key_dim, but are scaled by 1/sqrt(head_dim).
        # The two coincide for the default key_dim; left unchanged to preserve behaviour of existing models.
        self.attn_scale = 1 / math.sqrt(self.head_dim) # for scaled dot product attention
        self.rel_scale = 1 / math.sqrt(self.rel_proj_dim) # for relations

        # Wq, Wk projections for attention
        self.wq_attn = nn.Linear(self.d_model, self.n_heads * self.key_dim, bias=False)
        self.wk_attn = nn.Linear(self.d_model, self.n_kv_heads * self.key_dim, bias=self.add_bias_kv)

        # Wq, Wk projections for relation (shared when relations are symmetric)
        self.wq_rel = nn.Linear(self.d_model, self.n_relations * self.rel_proj_dim, bias=False)
        if self.symmetric_rels:
            self.wk_rel = self.wq_rel
        else:
            self.wk_rel = nn.Linear(self.d_model, self.n_relations * self.rel_proj_dim, bias=False)
        # NOTE: attn Wq, Wk have n_kv_heads param for multi-query/grouped-query attention
        # but rel Wq, Wk do not have this param. TODO: think about whether we want to adjust implementation?

        # W_r = (W_r^h)_h projection mapping r(x_i, x_j) to common dimension with symbols
        self.wr = nn.Parameter(torch.empty(self.n_heads, self.head_dim, self.n_relations))
        torch.nn.init.kaiming_uniform_(self.wr, a=math.sqrt(5))

        # W_s = (W_s^h)_h = W_v projection for symbols
        self.wv = nn.Linear(self.d_model, self.n_kv_heads * self.head_dim, bias=self.add_bias_kv)

        # Final output projection
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.n_heads * self.head_dim, bias=self.add_bias_out)
        self.attn_dropout = nn.Dropout(self.dropout) # dropout for attention scores
        self.resid_dropout = nn.Dropout(self.dropout) # dropout for output

    def forward(
            self,
            x: torch.Tensor,
            symbols: torch.Tensor,
            freqs_cos: torch.Tensor = None,
            freqs_sin: torch.Tensor = None,
            attn_mask: torch.Tensor = None, # boolean attention mask: True indicates corresponding position *should* be attended to
            is_causal: bool = False # indicates causal mask (will be computed automatically); should only set one of is_causal and attn_mask
            ):
        """
        Compute relational attention over x, tagging retrieved relations with the given symbols.

        If freqs_cos and freqs_sin are given, rotary positional embeddings are applied to the attention query/key.
        If attn_mask is given, it is applied to the attention logits.
        If is_causal is True, a causal mask is built (attn_mask must then be None).
        If use_relative_positional_symbols is True, symbols are treated as position-relative embeddings
        of shape [len, len, d_model] where len is the length of the sequence x.

        Parameters
        ----------
        x : torch.Tensor
            input tensor of shape [bsz, len, d_model]
        symbols : torch.Tensor
            input tensor of shape [bsz, len, d_model] or [len, len, d_model] if use_relative_positional_symbols is True
        freqs_cos : torch.Tensor, optional
            cosine of frequencies for RoPE. RoPE is applied if both freqs_cos and freqs_sin are given. By default None
        freqs_sin : torch.Tensor, optional
            sine of frequencies for RoPE. RoPE is applied if both freqs_cos and freqs_sin are given. By default None
        attn_mask : torch.Tensor, optional
            boolean attention mask, broadcastable to [bsz, n_heads, len, len] (e.g. [len, len]).
            True at [i,j] indicates i is allowed to attend to j. By default None
        is_causal : bool, optional
            whether to apply a causal mask. If True, attn_mask must be None. By default False

        Returns
        -------
        tuple[torch.Tensor]
            outputs [bsz, len, n_heads * head_dim] (= d_model when used standalone),
            attention scores [bsz, n_heads, len, len], relations [bsz, len, len, n_relations]
        """
        bsz, seqlen, _ = x.shape
        assert not (attn_mask is not None and is_causal) # attn_mask must not be given if is_causal

        ## compute attention scores
        # Math: \alpha_{ij}^h = \langle W_q^{attn,h} x_i, W_k^{attn,h} x_j \rangle
        xq_attn = rearrange(self.wq_attn(x), 'b l (nh hd) -> b l nh hd', nh=self.n_heads) # [bsz, seqlen, n_heads, key_dim]
        xk_attn = rearrange(self.wk_attn(x), 'b l (nh hd) -> b l nh hd', nh=self.n_kv_heads) # [bsz, seqlen, n_kv_heads, key_dim]

        # apply RoPE relative positional embeddings (if given)
        if freqs_cos is not None and freqs_sin is not None:
            xq_attn, xk_attn = apply_rotary_emb(xq_attn, xk_attn, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys
        if self.n_rep_kv != 1:
            xk_attn = repeat_kv(xk_attn, self.n_rep_kv) # [bsz, seqlen, n_heads, key_dim]

        # make heads a batch dimension for matmul
        xq_attn = xq_attn.transpose(1, 2) # [bsz, n_heads, seqlen, key_dim]
        xk_attn = xk_attn.transpose(1, 2) # [bsz, n_heads, seqlen, key_dim]

        # TODO: instead of creating a mask each time, it can be added to the buffer using a max_seq_len argument
        # e.g., see: https://github.com/karpathy/llama2.c/blob/master/model.py
        if is_causal and attn_mask is None:
            attn_mask = compute_causal_mask(seqlen, device=xq_attn.device)

        attn_scores = torch.matmul(xq_attn, xk_attn.transpose(2, 3)) * self.attn_scale # [bsz, n_heads, seqlen, seqlen]
        if attn_mask is not None:
            # disallowed positions get -inf logits; the mask broadcasts against [bsz, n_heads, seqlen, seqlen]
            attn_scores = attn_scores.masked_fill(attn_mask.logical_not(), float('-inf'))
        attn_scores = nn.functional.softmax(attn_scores, dim=-1) # [bsz, n_heads, seqlen, seqlen]
        # NOTE: dropout on attention scores follows Vaswani et al., though standard dropout is not "closed under" the simplex
        attn_scores = self.attn_dropout(attn_scores)

        ## compute relations
        # Math: r(x_i, x_j) = (\langle W_q^{rel,\ell} x_i, W_k^{rel,\ell} x_j \rangle)_{\ell \in [d_r]}
        xq_rel = rearrange(self.wq_rel(x), 'b l (nr rd) -> b nr l rd', nr=self.n_relations) # [bsz, n_relations, seqlen, rel_proj_dim]
        xk_rel = rearrange(self.wk_rel(x), 'b l (nr rd) -> b nr l rd', nr=self.n_relations) # [bsz, n_relations, seqlen, rel_proj_dim]
        relations = torch.matmul(xq_rel, xk_rel.transpose(2, 3)) * self.rel_scale
        relations = self.rel_activation_(relations) # [bsz, n_relations, seqlen, seqlen]
        relations = rearrange(relations, 'b nr i j -> b i j nr') # [bsz, seqlen, seqlen, n_relations]

        ## project symbols to per-head values (W_s)
        if self.use_relative_positional_symbols:
            assert symbols.shape[0] == symbols.shape[1] == seqlen, f"symbols must be of shape [len, len, dim], received {symbols.shape}"
            sv = rearrange(self.wv(symbols), 'l1 l2 (nh hd) -> l1 l2 nh hd', nh=self.n_kv_heads) # [seqlen, seqlen, n_kv_heads, head_dim]
            symbol_einsum = 'bhij,ijhd->bihd' # s_{j-i}: symbol depends on the (i, j) pair
        else:
            sv = rearrange(self.wv(symbols), 'b l (nh hd) -> b l nh hd', nh=self.n_kv_heads) # [bsz, seqlen, n_kv_heads, head_dim]
            symbol_einsum = 'bhij,bjhd->bihd' # s_j: symbol depends on the sender only
        if self.n_rep_kv != 1:
            sv = repeat_kv(sv, self.n_rep_kv) # n_kv_heads -> n_heads along dim 2

        ## aggregate messages
        # Math: A_i^h = \sum_j \alpha_{ij}^h (r(x_i, x_j) W_r^h + s_j W_s^h)
        # Attention is applied to the (N * N * R) relation tensor first and only then projected by W_r^h, which avoids
        # materialising an (N * N * head_dim) tensor (a quadratic-in-N memory bottleneck of an earlier implementation).
        attended_symbols = torch.einsum(symbol_einsum, attn_scores, sv) # [bsz, seqlen, n_heads, head_dim]
        # Math: \sum_j \alpha_{ij}^h r(x_i, x_j)
        attended_relations = torch.einsum('bhij,bijr->bihr', attn_scores, relations) # [bsz, seqlen, n_heads, n_relations]
        # Math: W_r^h (attended_relations)
        attended_relations = torch.einsum('bihr,hdr->bihd', attended_relations, self.wr) # [bsz, seqlen, n_heads, head_dim]
        output = attended_symbols + attended_relations # [bsz, seqlen, n_heads, head_dim]

        # concat heads and project into the residual stream
        # Math: A_i \gets W_o \mathrm{concat}(A_i^1, ..., A_i^{n_h})
        output = rearrange(output, 'b l nh hd -> b l (nh hd)') # [bsz, seqlen, n_heads * head_dim]
        output = self.wo(output)
        output = self.resid_dropout(output)

        return output, attn_scores, relations
# NOTE: position-relative symbol variant is very memory hungry because it involves a large (N * N * D)-dimensional tensor
# TODO: can we obtain a more memory-efficient implementation?
# we don't really need N*N, since there are only 2N - 1 distinct position-relative symbols (just N in the causal case)
# can we improve this?
# region Symbol Assignment Mechanisms
class SymbolicAttention(nn.Module):
def __init__(self,
d_model: int,
n_heads: int,
n_symbols: int,
dropout: float = 0.0,
scale: float = None,
trainable_symbols: bool = True):
"""
Symbolic Attention.
Learns a library of "symbols" and corresponding template features.
For a given input, retrieves a symbol from the symbol library via attention.
Parameters
----------
d_model : int
model dimension. this is the dimension of the input and the dimension of the symbols and template features.
n_heads : int
number of heads in symbolic attention.
n_symbols : int
number of symbols in the symbol library.
dropout : float, optional
dropout probability, by default 0.0
scale : float, optional
scaling factor in scaled_dot_product_attention, by default None
trainable_symbols: bool, optional
whether to make the symbol library trainable, by default True
"""
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.n_symbols = n_symbols
self.dropout = dropout
self.scale = scale
self.trainable_symbols = trainable_symbols
self.q_proj = nn.Linear(self.d_model, self.d_model)
self.template_features = nn.Parameter(torch.empty(self.n_symbols, self.d_model))
self.symbol_library = nn.Parameter(torch.empty(self.n_symbols, self.d_model), requires_grad=trainable_symbols)
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.normal_(self.template_features)
torch.nn.init.normal_(self.symbol_library)
def forward(self, x):
batch_size, seq_len, dim = x.size()
# create query from input
query = self.q_proj(x)
query = query.view(batch_size, seq_len, self.n_heads, dim // self.n_heads).transpose(1, 2)
# create keys from template features
key = self.template_features.view(self.n_symbols, self.n_heads, self.d_model // self.n_heads).transpose(0, 1)
key = self._repeat_kv(key, batch_size)
# create values from symbol library
value = self.symbol_library.view(self.n_symbols, self.n_heads, self.d_model // self.n_heads).transpose(0, 1)
value = self._repeat_kv(value, batch_size)
retrieved_symbols = torch.nn.functional.scaled_dot_product_attention(
query, key, value,
scale=self.scale, dropout_p=self.dropout, attn_mask=None, is_causal=False)
retrieved_symbols = retrieved_symbols.transpose(1, 2).contiguous().view(batch_size, seq_len, dim)
return retrieved_symbols
def _repeat_kv(self, x, batch_size):
"""
template_features and symbol_library are of shape (n_heads, n_s, d_s//n_heads).
repeat for each input and add a batch dimension of size batch_size.
"""
return x.unsqueeze(0).repeat(batch_size, 1, 1, 1)
class PositionalSymbolRetriever(nn.Module):
    def __init__(self, symbol_dim, max_length, sinusoidal=False):
        """
        Positional Symbol Retriever.

        Learns a library of "symbols", one per position.
        Retrieves a symbol for each object based on its position in the sequence.

        Parameters
        ----------
        symbol_dim : int
            dimension of the symbols.
        max_length : int
            maximum supported sequence length (number of symbols in the library).
        sinusoidal : bool, optional
            stored but currently unused: learned symbols are always used. By default False.
        """
        super().__init__()
        self.symbol_dim = symbol_dim
        self.max_length = max_length
        self.sinusoidal = sinusoidal

        self.symbol_library = nn.Embedding(self.max_length, self.symbol_dim)
        # TODO: implement sinusoidal symbols?

    def forward(self, x):
        """
        Return the positional symbols for every object in x.

        Parameters
        ----------
        x : torch.Tensor
            input whose first two dimensions are (batch_size, seq_len); its values are not used.

        Returns
        -------
        torch.Tensor
            symbols of shape [batch_size, seq_len, symbol_dim]; entry [b, i] is the symbol for position i.

        Raises
        ------
        IndexError
            if seq_len exceeds max_length.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        if seq_len > self.max_length:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_length={self.max_length} of the positional symbol library")

        pos = torch.arange(0, seq_len, dtype=torch.long, device=x.device)
        # same symbols for every batch element
        retrieved_symbols = self.symbol_library(pos).unsqueeze(0).repeat(batch_size, 1, 1)
        return retrieved_symbols
# TODO: add support for causal-only position-relative symbols?
# (this would roughly halve the parameter count)
class PositionRelativeSymbolRetriever(nn.Module):
    def __init__(self, symbol_dim, max_rel_pos):
        """
        Position-Relative Symbol Retriever.

        The symbol s_{ij} attached to the pair i -> j is a learned embedding of the relative position j - i.

        Parameters
        ----------
        symbol_dim : int
            dimension of the symbols.
        max_rel_pos : int
            largest relative offset (in either direction) that gets its own symbol;
            offsets beyond it are truncated.
        """
        super().__init__()
        self.symbol_dim, self.max_rel_pos = symbol_dim, max_rel_pos
        self.rel_pos_enc = RelativePositionalEncoding(dim=self.symbol_dim, max_rel_pos=self.max_rel_pos)

    def forward(self, x):
        """
        Return the position-relative symbols for the sequence length of x.

        Parameters
        ----------
        x : torch.Tensor
            input whose second dimension is the sequence length; its values are not used.

        Returns
        -------
        torch.Tensor
            the relative positional encodings for a length-by-length grid (no batch dimension),
            computed with device=x.device.
        """
        n_objects = x.size(1)
        return self.rel_pos_enc(n_objects, device=x.device)
class RelativePositionalEncoding(nn.Module):
    def __init__(self, dim: int, max_rel_pos: int):
        """
        Module which returns relative positional embeddings for a given pair of sequences.

        I.e., returns a tensor whose [i,j]-th entry is the embedding of the relative position "j-i",
        clipped to [-max_rel_pos, max_rel_pos].

        Parameters
        ----------
        dim : int
            dimension of embeddings
        max_rel_pos : int
            maximum relative position in either direction (used for clipping)
        """
        super().__init__()
        self.num_units = dim
        self.max_relative_position = max_rel_pos
        # one learned embedding per clipped offset in [-max_rel_pos, max_rel_pos]
        self.rel_pos_embeddings = nn.Parameter(torch.empty(max_rel_pos * 2 + 1, dim))
        nn.init.xavier_uniform_(self.rel_pos_embeddings)

    def forward(self, length_q, length_k=None, device=None):
        """
        Parameters
        ----------
        length_q : int
            length of query sequence
        length_k : int, optional
            length of key sequence, by default None (same as length_q)
        device : torch.device, optional
            device on which to build the index tensors, by default None
            (the device of the embedding table, which avoids cross-device indexing)

        Returns
        -------
        Tensor
            tensor of shape [length_q, length_k, dim] where [i,j] is the embedding of the relative position "j-i"
        """
        if length_k is None:
            length_k = length_q
        if device is None:
            device = self.rel_pos_embeddings.device

        range_q = torch.arange(length_q, device=device)
        range_k = torch.arange(length_k, device=device)
        distance_mat = range_k[None, :] - range_q[:, None] # [i, j] = j - i
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        final_mat = distance_mat_clipped + self.max_relative_position # shift offsets into [0, 2 * max_rel_pos]
        embeddings = self.rel_pos_embeddings[final_mat]

        return embeddings
# endregion
# region Dual-Attention Blocks
class DualAttnEncoderBlock(nn.Module):
    """Encoder block combining self-attention and relational attention heads, followed by a feed-forward block."""

    def __init__(self,
            d_model: int,
            n_heads_sa: int,
            n_heads_ra: int,
            dff: int,
            activation: str,
            dropout_rate: float,
            norm_first: bool,
            norm_type: str = 'layernorm',
            sa_kwargs: dict = None,
            ra_kwargs: dict = None,
            ra_type: str = 'relational_attention',
            share_attn_params: bool = False,
            bias: bool = True,
            causal: bool = False):
        """
        Dual Attention Encoder Block.

        A Dual Attention Encoder is a variant of the Transformer Encoder that uses a combination of two distinct types of attention heads.
        The first type is standard self-attention, which captures object-level (i.e., sensory) features, and
        the second type is relational attention, which captures relational features.

        Parameters
        ----------
        d_model : int
            model dimension.
        n_heads_sa : int
            number of standard self-attention heads.
        n_heads_ra : int
            number of relational attention heads.
        dff : int
            intermediate dimension of feed-forward block.
        activation : str
            name of activation function to use in feedforward block.
        dropout_rate : float
            dropout rate.
        norm_first : bool
            whether to apply normalization before or after attention. norm_first=True means pre-norm otherwise post-norm.
        norm_type : 'layernorm' or 'rmsnorm', optional
            type of normalization to use, by default 'layernorm'
        sa_kwargs : dict, optional
            self-attention kwargs, by default None
        ra_kwargs : dict, optional
            relational attention kwargs, by default None
        ra_type : str, optional
            type of relational attention module (e.g., whether to use RCA for an ablation experiment), by default 'relational_attention'
        share_attn_params : bool, optional
            whether to share attention parameters between self-attention and relational attention.
            If True, w{q,k} in sensory attention and w{q,k}_attn in relational attention are shared.
            number of heads in each must be the same. By default False
        bias : bool, optional
            whether the feed-forward block uses bias terms (forwarded as ``use_bias``), by default True.
            NOTE(review): this flag is not forwarded to the dual-attention module here.
        causal : bool, optional
            whether attention operations should be causal, by default False
        """
        super().__init__()

        self.d_model = d_model
        self.n_heads_sa = n_heads_sa
        self.n_heads_ra = n_heads_ra
        self.dff = dff
        self.dropout_rate = dropout_rate
        self.activation = activation
        self.norm_first = norm_first
        self.norm_type = norm_type
        self.ra_type = ra_type
        self.share_attn_params = share_attn_params
        self.bias = bias
        self.causal = causal

        self.dropout = nn.Dropout(self.dropout_rate)
        self.norm1 = create_norm(self.d_model, self.norm_type)
        self.dual_attn = DualAttention(
            d_model=d_model, n_heads_sa=n_heads_sa, n_heads_ra=n_heads_ra,
            dropout=dropout_rate, sa_kwargs=sa_kwargs, ra_kwargs=ra_kwargs,
            ra_type=ra_type, share_attn_params=share_attn_params)

        self.norm2 = create_norm(self.d_model, self.norm_type)
        self.ff_block = FeedForwardBlock(self.d_model, dff=self.dff, activation=self.activation, use_bias=self.bias)

    # TODO: make attn_mask input so it only needs to be computed once?
    def forward(self, x, symbols, freqs_cos=None, freqs_sin=None):
        """
        Apply dual attention and the feed-forward block, each with a residual connection.

        Parameters
        ----------
        x : Tensor
            input sequence (presumably [batch, seq_len, d_model] -- TODO confirm).
        symbols : Tensor
            symbols passed to the dual-attention module alongside ``x``.
        freqs_cos, freqs_sin : Tensor, optional
            rotary-embedding frequencies forwarded to the dual-attention module, by default None.

        Returns
        -------
        Tensor
            output of the block, same shape as ``x``.
        """
        if self.norm_first:
            x = x + self._compute_dual_attn(self.norm1(x), symbols, freqs_cos=freqs_cos, freqs_sin=freqs_sin)
            x = x + self._apply_ff_block(self.norm2(x))
        else:
            # dropout is already applied to each sublayer output inside the helpers;
            # no extra dropout on the residual stream (matches the pre-norm path)
            x = self.norm1(x + self._compute_dual_attn(x, symbols, freqs_cos=freqs_cos, freqs_sin=freqs_sin))
            x = self.norm2(x + self._apply_ff_block(x))
        return x

    def _compute_dual_attn(self, x, symbols, freqs_cos=None, freqs_sin=None):
        """Run dual attention (without returning weights) and apply dropout to its output."""
        x, *_ = self.dual_attn(x, symbols,
            need_weights=False, is_causal=self.causal,
            freqs_cos=freqs_cos, freqs_sin=freqs_sin)
        x = self.dropout(x)
        return x

    def _apply_ff_block(self, x):
        """Run the feed-forward block and apply dropout to its output."""
        x = self.ff_block(x)
        x = self.dropout(x)
        return x
class DualAttnDecoderBlock(nn.Module):
    """Decoder block: dual attention, cross-attention to a context sequence, then a feed-forward block."""

    def __init__(self,
            d_model: int,
            n_heads_sa: int,
            n_heads_ra: int,
            n_heads_cross: int,
            dff: int,
            activation: str,
            dropout_rate: float,
            norm_first: bool,
            norm_type: str = 'layernorm',
            sa_kwargs: dict = None,
            ra_kwargs: dict = None,
            cross_kwargs: dict = None,
            ra_type: str = 'relational_attention',
            share_attn_params: bool = False,
            bias: bool = True,
            causal: bool = True):
        """
        Dual Attention Decoder Block.

        A Dual Attention Decoder is a variant of the Transformer Decoder that uses a combination of two distinct types of attention heads.
        The first type is standard self-attention, which captures object-level (i.e., sensory) features, and
        the second type is relational attention, which captures relational features.

        Parameters
        ----------
        d_model : int
            model dimension.
        n_heads_sa : int
            number of standard self-attention heads.
        n_heads_ra : int
            number of relational attention heads.
        n_heads_cross : int
            number of cross-attention heads.
        dff : int
            intermediate dimension of feed-forward block.
        activation : str
            name of activation function to use in feedforward block.
        dropout_rate : float
            dropout rate.
        norm_first : bool
            whether to apply normalization before or after attention. norm_first=True means pre-norm otherwise post-norm.
        norm_type : 'layernorm' or 'rmsnorm', optional
            type of normalization to use, by default 'layernorm'
        sa_kwargs : dict, optional
            self-attention kwargs, by default None
        ra_kwargs : dict, optional
            relational attention kwargs, by default None
        cross_kwargs : dict, optional
            cross-attention kwargs, by default None
        ra_type : str, optional
            type of relational attention module (e.g., whether to use RCA for an ablation experiment), by default 'relational_attention'
        share_attn_params : bool, optional
            whether to share attention parameters between self-attention and relational attention.
            If True, w{q,k} in sensory attention and w{q,k}_attn in relational attention are shared.
            number of heads in each must be the same. By default False
        bias : bool, optional
            whether the feed-forward block uses bias terms (forwarded as ``use_bias``), by default True.
            NOTE(review): this flag is not forwarded to the dual-attention or cross-attention modules here.
        causal : bool, optional
            whether the dual-attention (self/relational) operations should be causal, by default True.
            Cross-attention is never causal.
        """
        super().__init__()

        self.d_model = d_model
        self.n_heads_sa = n_heads_sa
        self.n_heads_ra = n_heads_ra
        self.n_heads_cross = n_heads_cross
        self.dff = dff
        self.dropout_rate = dropout_rate
        self.activation = activation
        self.norm_first = norm_first
        self.norm_type = norm_type
        self.ra_type = ra_type
        self.share_attn_params = share_attn_params
        self.bias = bias
        self.causal = causal

        self.use_self_attn = n_heads_sa > 0
        self.use_rel_attn = n_heads_ra > 0

        self.dropout = nn.Dropout(self.dropout_rate)
        self.norm1 = create_norm(self.d_model, self.norm_type)
        self.dual_attn = DualAttention(
            d_model=d_model, n_heads_sa=n_heads_sa, n_heads_ra=n_heads_ra,
            dropout=dropout_rate, sa_kwargs=sa_kwargs, ra_kwargs=ra_kwargs,
            ra_type=ra_type, share_attn_params=share_attn_params)

        self.norm2 = create_norm(self.d_model, self.norm_type)
        cross_kwargs = cross_kwargs if cross_kwargs is not None else {}
        self.cross_attn = Attention(
            self.d_model, self.n_heads_cross, dropout=self.dropout_rate,
            **cross_kwargs)
        self.norm3 = create_norm(self.d_model, self.norm_type)
        self.ff_block = FeedForwardBlock(self.d_model, dff=self.dff, activation=self.activation, use_bias=self.bias)

    def forward(self, x, context, symbols, freqs_cos=None, freqs_sin=None):
        """
        Apply dual attention, cross-attention to ``context``, and the feed-forward block,
        each with a residual connection.

        Parameters
        ----------
        x : Tensor
            decoder input sequence (presumably [batch, seq_len, d_model] -- TODO confirm).
        context : Tensor
            sequence attended to by cross-attention (used as both key and value).
        symbols : Tensor
            symbols passed to the dual-attention module alongside ``x``.
        freqs_cos, freqs_sin : Tensor, optional
            rotary-embedding frequencies forwarded to the dual-attention module, by default None.

        Returns
        -------
        Tensor
            output of the block, same shape as ``x``.
        """
        if self.norm_first:
            x = x + self._compute_dual_attn(self.norm1(x), symbols, freqs_cos=freqs_cos, freqs_sin=freqs_sin)
            x = x + self._compute_cross_attn(self.norm2(x), context)
            # use the helper so the feed-forward output gets dropout like the other sublayers
            x = x + self._apply_ff_block(self.norm3(x))
        else:
            x = self.norm1(x + self._compute_dual_attn(x, symbols, freqs_cos=freqs_cos, freqs_sin=freqs_sin))
            x = self.norm2(x + self._compute_cross_attn(x, context))
            x = self.norm3(x + self._apply_ff_block(x))
        return x

    def _compute_dual_attn(self, x, symbols, freqs_cos=None, freqs_sin=None):
        """Run dual attention (without returning weights) and apply dropout to its output."""
        x, *_ = self.dual_attn(x, symbols,
            need_weights=False, is_causal=self.causal,
            freqs_cos=freqs_cos, freqs_sin=freqs_sin)
        x = self.dropout(x)
        return x

    def _compute_cross_attn(self, x, context):
        """Run non-causal cross-attention from ``x`` to ``context`` and apply dropout to its output."""
        x = self.cross_attn(query=x, key=context, value=context, need_weights=False, is_causal=False)[0]
        x = self.dropout(x)
        return x

    def _apply_ff_block(self, x):
        """Run the feed-forward block and apply dropout to its output."""
        x = self.ff_block(x)
        x = self.dropout(x)
        return x
# endregion
# region Dual-Attention language model
class DualAttnTransformerLM(nn.Module):
"""Dual Attention Transformer Language Model"""
def __init__(self,
vocab_size: int,
d_model: int,
n_layers: int,
n_heads_sa: int,
n_heads_ra: int,
symbol_retrieval_kwargs: dict,
dff: int,
dropout_rate: float,
activation: str,
norm_first: bool,
max_block_size: int,
norm_type: str = 'layernorm',
sa_kwargs: dict = None,
ra_kwargs: dict = None,
ra_type: str = 'relational_attention',
share_attn_params: bool = False,
symbol_retrieval: str = 'symbolic_attention',
symbol_retriever_config: dict = None, # dict with keys: shared_symbol_retriever, weight_tie_symbol_library
pos_enc_type: str = 'pos_emb',
bias: bool = True):
"""
Dual Attention Transformer Language Model.
Parameters
----------
vocab_size : int
vocabulary size.
d_model : int
model dimension.
n_layers : int
number of layers.
n_heads_sa : int
number of self-attention heads in dual-attention.
n_heads_ra : int
number of relational attention heads in dual-attention.
symbol_retrieval_kwargs : dict
keyword arguments for symbol retrieval module.
dff : int
size of intermediate layer in feedforward blocks.
dropout_rate : float
dropout rate.
activation : str
name of activation function (e.g., 'relu', 'gelu', or 'swiglu').
norm_first : bool
whether to apply layer normalization before or after attention.
max_block_size : int
maximum context size.
sa_kwargs : dict, optional
keyword arguments for self-attention, by default None
ra_kwargs : dict, optional
keyword arguments for relational attention, by default None
ra_type : 'relational_attention', 'rca', or 'disrca', optional
type of relational attention module (e.g., whether to use RCA for an ablation experiment), by default 'relational_attention'
share_attn_params : bool, optional
whether to share attention parameters between self-attention and relational attention.
If True, w{q,k} in sensory attention and w{q,k}_attn in relational attention are shared.
number of heads in each must be the same. By default False
symbol_retrieval : 'symbolic_attention', 'position_relative', 'positional_symbols', optional
type of symbol retrieval module to use. this is shared across layers, by default 'symbolic_attention'
pos_enc_type : 'pos_emb' or 'RoPE', optional
type of positional encoding to use, by default 'pos_emb'
bias : bool, optional
whether to use bias in attention, by default True
"""
super().__init__()
self.vocab_size = vocab_size
self.d_model = d_model
self.n_layers = n_layers
self.n_heads_sa = n_heads_sa
self.n_heads_ra = n_heads_ra
self.dff = dff
self.dropout_rate = dropout_rate
self.activation = activation
self.norm_first = norm_first
self.norm_type = norm_type
self.block_size = max_block_size
self.ra_type = ra_type
self.share_attn_params = share_attn_params
self.symbol_retriever = symbol_retrieval
self.pos_enc_type = pos_enc_type
self.bias = bias
self.symbol_retriever_config = symbol_retriever_config if symbol_retriever_config is not None else {}
shared_symbol_retriever = self.symbol_retriever_config.setdefault('shared_symbol_retriever', True)
weight_tie_symbol_library = self.symbol_retriever_config.setdefault('weight_tie_symbol_library', False)
self.n_heads = n_heads_sa + n_heads_ra
if symbol_retrieval == 'symbolic_attention':
if shared_symbol_retriever:
symbol_retrievers = [SymbolicAttention(**symbol_retrieval_kwargs)] * n_layers
else:
symbol_retrievers = [SymbolicAttention(**symbol_retrieval_kwargs) for _ in range(n_layers)]
# elif symbol_retrieval == 'rel_sym_attn':
# symbol_retriever = RelationalSymbolicAttention(**symbol_retrieval_kwargs)
elif symbol_retrieval == 'positional_symbols':
if shared_symbol_retriever:
symbol_retrievers = [PositionalSymbolRetriever(**symbol_retrieval_kwargs)] * n_layers