From 5947ca231ee3cb5ef0fe80d1fe21a126d2865dee Mon Sep 17 00:00:00 2001
From: Puyan Lotfi
Date: Tue, 20 Aug 2024 18:21:07 -0700
Subject: [PATCH] [WIP] TMA Version of HSTU (Autotuned)

Based on #57, this version uses the autotuner to toggle the use of TMA.
---
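Note on the approach: enable_tma is carried as a constexpr in every autotune
config, so the autotuner benchmarks both the TMA and the non-TMA code path and
keeps whichever is faster for a given key. A minimal sketch of that pattern,
outside this kernel (illustrative only; the kernel, config, and parameter
names below are made up for the example and are not part of this diff):

    import triton
    import triton.language as tl

    _configs = [
        triton.Config({"BLOCK": blk, "USE_TMA": tma}, num_stages=2, num_warps=4)
        for blk in (64, 128)
        for tma in (False, True)
    ]

    @triton.autotune(configs=_configs, key=["n"])
    @triton.jit
    def _toy_copy(x_ptr, y_ptr, n, BLOCK: tl.constexpr, USE_TMA: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        if USE_TMA:
            # the real kernel would issue a TMA descriptor load on this branch
            val = tl.load(x_ptr + offs, mask=mask)
        else:
            val = tl.load(x_ptr + offs, mask=mask)
        tl.store(y_ptr + offs, val, mask=mask)

Because USE_TMA is a tl.constexpr, the branch not taken is eliminated at
compile time, so the non-TMA configs pay no cost for the extra flag.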
"BLOCK_N": 32}, + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tma": False}, num_stages=2, num_warps=4, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32}, + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tma": False}, num_stages=4, num_warps=4, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32}, + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tma": False}, num_stages=2, num_warps=8, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32}, + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tma": False}, num_stages=4, num_warps=8, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64}, + {"BLOCK_M": 128, "BLOCK_N": 64, "enable_tma": False}, num_stages=2, num_warps=4, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64}, + {"BLOCK_M": 128, "BLOCK_N": 64, "enable_tma": False}, num_stages=2, num_warps=8, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64}, + {"BLOCK_M": 128, "BLOCK_N": 64, "enable_tma": False}, num_stages=4, num_warps=8, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128}, + {"BLOCK_M": 128, "BLOCK_N": 128, "enable_tma": False}, num_stages=4, num_warps=4, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128}, + {"BLOCK_M": 128, "BLOCK_N": 128, "enable_tma": False}, num_stages=2, num_warps=8, ), ] + + for config in base_configs: + config["enable_tma"] = False + configs.append(config) + config["enable_tma"] = True + configs.append(config) + return configs @@ -218,6 +226,10 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901 q, K_block_ptr, V_block_ptr, + K_desc_ptr, + V_desc_ptr, + offset_h, + seq_start, n_targets, ts_1_ptrs, ts_0, @@ -245,13 +257,27 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901 HAS_MAX_ATTN_LEN: tl.constexpr, IS_DELTA_Q: tl.constexpr, ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + enable_tma: tl.constexpr, ): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero") - qk = tl.dot(q, k, allow_tf32=ALLOW_TF32) * alpha + k = None + qk = None + if enable_tma: + k = tl._experimental_descriptor_load( + K_desc_ptr, + [(seq_start + start_n).to(tl.int32), offset_h.to(tl.int32)], + [BLOCK_N, BLOCK_D_Q], + tl.bfloat16, + ) + qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * alpha + else: + k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero") + qk = tl.dot(q, k, allow_tf32=ALLOW_TF32) * alpha invalid_mask = offs_m[:, None] == offs_n[None, :] if HAS_MULTIPLE_TARGETS: if INVALID_MASK_TYPE == "lower_triangular": @@ -335,7 +361,17 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901 silu = tl.where(invalid_mask, silu, 0) if HAS_ATTN_SCALE: silu = silu * attn_scale[:, None] - v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero") + + v = None + if enable_tma: + v = tl._experimental_descriptor_load( + V_desc_ptr, + [(seq_start + start_n).to(tl.int32), offset_h.to(tl.int32)], + [BLOCK_N, BLOCK_D_V], + tl.bfloat16, + ) + else: + v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero") silu = silu.to(v.dtype) return tl.dot(silu, v, allow_tf32=ALLOW_TF32) @@ -359,6 +395,8 @@ def _ragged_hstu_attn_fwd( # noqa C901 Q, K, V, + desc_k, + desc_v, seq_offsets, TS, TW, @@ -409,6 +447,7 @@ def _ragged_hstu_attn_fwd( # noqa C901 BLOCK_N: tl.constexpr, max_attn_len: tl.constexpr, HAS_MAX_ATTN_LEN: tl.constexpr, + enable_tma: tl.constexpr, ): # M_CTX == N_CTX off_hz = tl.program_id(1) @@ -452,22 +491,25 @@ def _ragged_hstu_attn_fwd( # noqa C901 block_shape=(BLOCK_M, BLOCK_D_Q), order=(1, 0), ) - K_block_ptr = 
@@ -218,6 +226,10 @@ def _ragged_hstu_attn_fwd_one_block(  # noqa: C901
     q,
     K_block_ptr,
     V_block_ptr,
+    K_desc_ptr,
+    V_desc_ptr,
+    offset_h,
+    seq_start,
     n_targets,
     ts_1_ptrs,
     ts_0,
@@ -245,13 +257,27 @@ def _ragged_hstu_attn_fwd_one_block(  # noqa: C901
     HAS_MAX_ATTN_LEN: tl.constexpr,
     IS_DELTA_Q: tl.constexpr,
     ALLOW_TF32: tl.constexpr,
+    BLOCK_D_Q: tl.constexpr,
+    BLOCK_D_V: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    enable_tma: tl.constexpr,
 ):
     start_n = tl.multiple_of(start_n, BLOCK_N)
     # -- compute qk ----
-    k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
-    qk = tl.dot(q, k, allow_tf32=ALLOW_TF32) * alpha
+    k = None
+    qk = None
+    if enable_tma:
+        k = tl._experimental_descriptor_load(
+            K_desc_ptr,
+            [(seq_start + start_n).to(tl.int32), offset_h.to(tl.int32)],
+            [BLOCK_N, BLOCK_D_Q],
+            tl.bfloat16,
+        )
+        qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * alpha
+    else:
+        k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
+        qk = tl.dot(q, k, allow_tf32=ALLOW_TF32) * alpha
     invalid_mask = offs_m[:, None] == offs_n[None, :]
     if HAS_MULTIPLE_TARGETS:
         if INVALID_MASK_TYPE == "lower_triangular":
@@ -335,7 +361,17 @@ def _ragged_hstu_attn_fwd_one_block(  # noqa: C901
     silu = tl.where(invalid_mask, silu, 0)
     if HAS_ATTN_SCALE:
         silu = silu * attn_scale[:, None]
-    v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
+
+    v = None
+    if enable_tma:
+        v = tl._experimental_descriptor_load(
+            V_desc_ptr,
+            [(seq_start + start_n).to(tl.int32), offset_h.to(tl.int32)],
+            [BLOCK_N, BLOCK_D_V],
+            tl.bfloat16,
+        )
+    else:
+        v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
     silu = silu.to(v.dtype)
     return tl.dot(silu, v, allow_tf32=ALLOW_TF32)
 
@@ -359,6 +395,8 @@ def _ragged_hstu_attn_fwd(  # noqa C901
     Q,
     K,
     V,
+    desc_k,
+    desc_v,
     seq_offsets,
     TS,
     TW,
@@ -409,6 +447,7 @@ def _ragged_hstu_attn_fwd(  # noqa C901
     BLOCK_N: tl.constexpr,
     max_attn_len: tl.constexpr,
     HAS_MAX_ATTN_LEN: tl.constexpr,
+    enable_tma: tl.constexpr,
 ):
     # M_CTX == N_CTX
     off_hz = tl.program_id(1)
@@ -452,22 +491,25 @@ def _ragged_hstu_attn_fwd(  # noqa C901
         block_shape=(BLOCK_M, BLOCK_D_Q),
         order=(1, 0),
     )
-    K_block_ptr = tl.make_block_ptr(
-        base=K + off_h * stride_kh + seq_start * stride_kn,
-        shape=(BLOCK_D_Q, seq_len),
-        strides=(1, stride_kn),
-        offsets=(0, 0),
-        block_shape=(BLOCK_D_Q, BLOCK_N),
-        order=(0, 1),
-    )
-    V_block_ptr = tl.make_block_ptr(
-        base=V + off_h * stride_vh + seq_start * stride_vn,
-        shape=(seq_len, BLOCK_D_V),
-        strides=(stride_vn, 1),
-        offsets=(0, 0),
-        block_shape=(BLOCK_N, BLOCK_D_V),
-        order=(1, 0),
-    )
+    K_block_ptr = None
+    V_block_ptr = None
+    if not enable_tma:
+        K_block_ptr = tl.make_block_ptr(
+            base=K + off_h * stride_kh + seq_start * stride_kn,
+            shape=(BLOCK_D_Q, seq_len),
+            strides=(1, stride_kn),
+            offsets=(0, 0),
+            block_shape=(BLOCK_D_Q, BLOCK_N),
+            order=(0, 1),
+        )
+        V_block_ptr = tl.make_block_ptr(
+            base=V + off_h * stride_vh + seq_start * stride_vn,
+            shape=(seq_len, BLOCK_D_V),
+            strides=(stride_vn, 1),
+            offsets=(0, 0),
+            block_shape=(BLOCK_N, BLOCK_D_V),
+            order=(1, 0),
+        )
     mask_m = offs_m < seq_len
     if ATTN_BIAS_TYPE == "fused" and USE_TIME_BIAS:
         ts_0_ptrs = TS + off_z * stride_ts + offs_m
@@ -486,6 +528,7 @@ def _ragged_hstu_attn_fwd(  # noqa C901
         scale_ptrs = Scale + off_z * stride_sz
         attn_scale = tl.load(scale_ptrs + offs_m * stride_sm, mask=offs_m < seq_len)
 
+    # TODO: convert the q load to TMA as well
     q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
     acc = tl.zeros([BLOCK_M, BLOCK_D_V], dtype=tl.float32)
     if INVALID_MASK_TYPE == "lower_triangular":
@@ -511,8 +554,13 @@ def _ragged_hstu_attn_fwd(  # noqa C901
     elif INVALID_MASK_TYPE == "upper_triangular":
        low = start_m
        high = seq_len
+    if enable_tma:
+        tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
+                                  [K], dtype=tl.int32, is_pure=False, pack=1)
+        tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
+                                  [V], dtype=tl.int32, is_pure=False, pack=1)
     # pyre-ignore[61]
-    if low > 0:
+    if low > 0 and not enable_tma:
         # pyre-ignore[61]
         K_block_ptr = tl.advance(K_block_ptr, (0, low))
         # pyre-ignore[61]
@@ -531,6 +579,11 @@ def _ragged_hstu_attn_fwd(  # noqa C901
             q=q,
             K_block_ptr=K_block_ptr,
             V_block_ptr=V_block_ptr,
+            K_desc_ptr=desc_k,
+            V_desc_ptr=desc_v,
+            offset_h=off_h * stride_vh,
+            seq_start=seq_start,
+            # pyre-ignore[61]
             n_targets=n_targets if HAS_MULTIPLE_TARGETS else None,
             ts_1_ptrs=(
                 # pyre-ignore[61]
@@ -568,18 +621,23 @@ def _ragged_hstu_attn_fwd(  # noqa C901
             ALLOW_TF32=ALLOW_TF32,
             BLOCK_M=BLOCK_M,
             BLOCK_N=BLOCK_N,
+            enable_tma=enable_tma,
         )
-        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
-        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
+        if not enable_tma:
+            K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
+            V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
 
     if HAS_MULTIPLE_TARGETS and INVALID_MASK_TYPE == "lower_triangular":
         # pyre-ignore[61]
         if uih_end < start_m:
             low_delta = start_m
             high_delta = start_m + BLOCK_M
-            offset = (low_delta - uih_end).to(tl.int32)  # pyre-ignore [61]
-            K_block_ptr = tl.advance(K_block_ptr, (0, offset))
-            V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
+            if not enable_tma:
+                offset = (low_delta - uih_end).to(tl.int32)  # pyre-ignore [61]
+                K_block_ptr = tl.advance(K_block_ptr, (0, offset))
+                V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
         for start_delta in range(low_delta, high_delta, BLOCK_N):
             cur_offs_n = offs_n + start_delta
             mask_n = cur_offs_n < seq_len
@@ -593,6 +651,10 @@ def _ragged_hstu_attn_fwd(  # noqa C901
                 q=q,
                 K_block_ptr=K_block_ptr,
                 V_block_ptr=V_block_ptr,
+                K_desc_ptr=desc_k,
+                V_desc_ptr=desc_v,
+                offset_h=off_h * stride_vh,
+                seq_start=seq_start,
                 n_targets=n_targets if HAS_MULTIPLE_TARGETS else None,
                 ts_1_ptrs=(
                     # pyre-ignore[61]
@@ -634,9 +696,13 @@ def _ragged_hstu_attn_fwd(  # noqa C901
                 ALLOW_TF32=ALLOW_TF32,
                 BLOCK_M=BLOCK_M,
                 BLOCK_N=BLOCK_N,
+                enable_tma=enable_tma,
             )
-            K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
-            V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
+            if not enable_tma:
+                K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
+                V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
 
     if IS_DELTA_Q:
         start_m_delta = tl.program_id(0) * BLOCK_M
@@ -648,6 +714,7 @@ def _ragged_hstu_attn_fwd(  # noqa C901
             + offs_v_d[None, :]
         )
         out_ptrs = Out + off_o
+        # TODO: convert the out store to TMA as well
         tl.store(out_ptrs, acc, mask=(offs_m_delta < DeltaSize)[:, None])
     else:
         # rematerialize offsets to save registers
@@ -691,6 +758,32 @@ def triton_ragged_attention(
     has_attn_scale = attn_scale is not None
     has_max_attn_len = max_attn_len is not None
 
+    # TMA setup: fill host-side 2D descriptors for K and V (BLOCK_N is fixed at 64 here).
+    TMA_SIZE = 128
+    BLOCK_N, BLOCK_D_V, BLOCK_D_Q = 64, DimV, DimQ
+    desc_k = np.empty(TMA_SIZE, dtype=np.int8)
+    desc_v = np.empty(TMA_SIZE, dtype=np.int8)
+    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
+        k.data_ptr(),
+        L,
+        H * DimQ,
+        BLOCK_N,
+        BLOCK_D_Q,
+        k.element_size(),
+        desc_k,
+    )
+    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
+        v.data_ptr(),
+        L,
+        H * DimV,
+        BLOCK_N,
+        BLOCK_D_V,
+        v.element_size(),
+        desc_v,
+    )
+    desc_k = torch.tensor(desc_k, device=v.device)
+    desc_v = torch.tensor(desc_v, device=v.device)
+
     grid = lambda meta: (  # noqa E731
         triton.cdiv(N, meta["BLOCK_M"]),
         Z * H,
@@ -709,6 +802,8 @@ def triton_ragged_attention(
         Q=q,
         K=k,
         V=v,
+        desc_k=desc_k,
+        desc_v=desc_v,
         seq_offsets=seq_offsets,
         TS=None,
         TW=None,
@@ -789,13 +884,65 @@ def triton_ragged_attention_relative_bias(
     has_multiple_targets = num_targets is not None
     has_max_pos_id = max_pos_ind is not None
     has_max_attn_len = max_attn_len is not None
-    _, H, DimQ = q.shape
+    L, H, DimQ = q.shape
     _, _, DimV = v.shape
     out = torch.empty_like(v)
     grid = lambda meta: (  # noqa E731
         triton.cdiv(N, meta["BLOCK_M"]),
         Z * H,
     )
+
+    # TMA setup: allocate descriptor buffers on the GPU; they are filled in
+    # grid2 below once the autotuner has picked BLOCK_N.
+    TMA_SIZE = 128
+    BLOCK_D_V, BLOCK_D_Q = DimV, DimQ
+    desc_k = torch.empty(TMA_SIZE, device="cuda", dtype=torch.int8)
+    desc_v = torch.empty(TMA_SIZE, device="cuda", dtype=torch.int8)
+
+    # Refill the TMA descriptors inside the grid callback so that they are
+    # built with the BLOCK_N chosen by the autotuner for this launch.
+    def grid2(META):
+        nonlocal desc_k
+        nonlocal desc_v
+        k_buf = torch.empty_like(desc_k, device="cpu")
+        v_buf = torch.empty_like(desc_v, device="cpu")
+        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
+            k.data_ptr(), L, H * DimQ, META["BLOCK_N"], BLOCK_D_Q,
+            k.element_size(), k_buf.numpy(),
+        )
+        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
+            v.data_ptr(), L, H * DimV, META["BLOCK_N"], BLOCK_D_V,
+            v.element_size(), v_buf.numpy(),
+        )
+        desc_k.copy_(k_buf)
+        desc_v.copy_(v_buf)
+        return (triton.cdiv(N, META["BLOCK_M"]), Z * H, 1)
+
     stride_sz = 0
     stride_sm = 0
     if attn_scale is not None:
@@ -807,10 +954,13 @@ def triton_ragged_attention_relative_bias(
     use_time_bias = relative_bias_type == "TIME" or relative_bias_type == "ALL"
     use_pos_bias = relative_bias_type == "POSITION" or relative_bias_type == "ALL"
 
+    # TODO: launch with grid2 so the descriptors track the autotuned BLOCK_N.
     _ragged_hstu_attn_fwd[grid](
         Q=q,
         K=k,
         V=v,
+        desc_k=desc_k,
+        desc_v=desc_v,
        seq_offsets=seq_offsets,
        TS=timestamps,
        TW=ts_weights,
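
For reference, the host-side setup above pairs
triton.runtime.driver.active.utils.fill_2d_tma_descriptor with
tl._experimental_descriptor_load inside the kernel. A self-contained sketch of
that pairing (experimental Triton TMA API as of mid-2024; make_2d_tma_descriptor
and _tma_copy_tile are illustrative names, and exact signatures may differ
between Triton versions):

    import numpy as np
    import torch
    import triton
    import triton.language as tl

    TMA_SIZE = 128  # size in bytes of a CUtensorMap descriptor

    def make_2d_tma_descriptor(t, rows, cols, block_rows, block_cols):
        # Fill the descriptor on the CPU, then move it to the GPU for the kernel.
        buf = np.empty(TMA_SIZE, dtype=np.int8)
        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
            t.data_ptr(), rows, cols, block_rows, block_cols, t.element_size(), buf
        )
        return torch.tensor(buf, device=t.device)

    @triton.jit
    def _tma_copy_tile(desc_ptr, out_ptr, BLOCK_R: tl.constexpr, BLOCK_C: tl.constexpr):
        # Load one (BLOCK_R, BLOCK_C) tile at offset (0, 0) through the descriptor.
        tile = tl._experimental_descriptor_load(
            desc_ptr, [0, 0], [BLOCK_R, BLOCK_C], tl.bfloat16
        )
        offs_r = tl.arange(0, BLOCK_R)[:, None]
        offs_c = tl.arange(0, BLOCK_C)[None, :]
        tl.store(out_ptr + offs_r * BLOCK_C + offs_c, tile)

The kernel in this patch additionally issues a
fence.proxy.tensormap::generic.acquire through tl.inline_asm_elementwise before
the first descriptor load, and grid2 refills the descriptors inside the grid
callback so that they match the BLOCK_N the autotuner selects.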