Could you explain the detailed implementation of `clustered_sparse_dot_product` in the `_topk_attention` function of the `ImprovedClusteredAttention` class? I am a little confused about how `QK` is computed in the snippet below:
```python
class ImprovedClusteredAttention(Module):
    ......
    def _topk_attention(self, Q, K, V,
                        clusters, counts,
                        topk, topk_values,
                        A_bottomk, softmax_temp,
                        query_lengths):
        N, H, L, E = Q.shape
        _, _, S, _ = K.shape
        _, _, C, k = topk.shape
        # We need to pass the output tensor to initialize to 0
        QK = clustered_sparse_dot_product(
            Q, K, topk,
            clusters, counts,
            query_lengths._lengths.int()
        )
        ......
```
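For context, my understanding of the semantics is: each query is assigned to a cluster, and `topk` holds the indices of the k keys selected for that cluster, so `QK[n, h, i, j]` is the dot product of query `i` with key `topk[n, h, c, j]`, where `c` is query `i`'s cluster. Below is a minimal dense NumPy sketch of that semantics. Note this is only my reading: the real CUDA kernel takes `counts` (per-cluster query counts, with queries grouped by cluster) rather than a per-query `clusters` array as used here, and handles variable `query_lengths`; the function name `clustered_sparse_dot_product_ref` is mine.

```python
import numpy as np

def clustered_sparse_dot_product_ref(Q, K, topk, clusters):
    """Dense reference for the clustered sparse dot product (my reading).

    Q:        (N, H, L, E) queries
    K:        (N, H, S, E) keys
    topk:     (N, H, C, k) indices of the k keys selected per cluster
    clusters: (N, H, L) cluster id of each query (simplified stand-in for
              the grouped-queries/counts layout used by the CUDA kernel)

    Returns QK of shape (N, H, L, k) where
      QK[n, h, i, j] = Q[n, h, i] . K[n, h, topk[n, h, clusters[n, h, i], j]]
    """
    N, H, L, E = Q.shape
    _, _, C, k = topk.shape
    QK = np.zeros((N, H, L, k), dtype=Q.dtype)
    for n in range(N):
        for h in range(H):
            for i in range(L):
                c = clusters[n, h, i]          # this query's cluster
                for j in range(k):
                    # dot product only against the cluster's top-k keys
                    QK[n, h, i, j] = Q[n, h, i] @ K[n, h, topk[n, h, c, j]]
    return QK
```

So each query attends to only k keys instead of all S, which is why the output has shape `(N, H, L, k)` rather than `(N, H, L, S)`.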