Implement Einstein summation notation kernel for 2 input arguments (#200)
1 parent 72665fc · commit 38e71df
Showing 9 changed files with 625 additions and 12 deletions.
@@ -0,0 +1,259 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .base import *

import torch

__all__ = [
    "einsum_2args_q4",
]


def einsum_util(einsum_str):
    es_in, es_out = einsum_str.split("->")
    es_in0, es_in1 = es_in.split(",")
    es_set = set(es_out)
    es_set = es_set.union(es_in0)
    es_set = es_set.union(es_in1)
    size = len(es_set)
    imap = dict()
    lmap = dict()
    for i in range(len(es_out)):
        imap[i] = es_out[i]
        lmap[es_out[i]] = i
    count = len(es_out)
    for c in es_set:
        if c not in lmap:
            imap[count] = c
            lmap[c] = count
            count += 1

    assert count == len(es_set)

    in0_idx = [lmap[i] for i in es_in0]
    in1_idx = [lmap[i] for i in es_in1]
    out_idx = [lmap[i] for i in es_out]

    input_idx_str = ", ".join(["d" + str(i) for i in range(size)])
    in0_idx_str = ", ".join(["d" + str(i) for i in in0_idx])
    in1_idx_str = ", ".join(["d" + str(i) for i in in1_idx])
    out_idx_str = ", ".join(["d" + str(i) for i in out_idx])

    iterators = ", ".join(
        ['"parallel"' if i in out_idx else '"reduction"' for i in range(size)]
    )

    affine_map_in0 = f"affine_map<({input_idx_str}) -> ({in0_idx_str})>"
    affine_map_in1 = f"affine_map<({input_idx_str}) -> ({in1_idx_str})>"
    affine_map_out = f"affine_map<({input_idx_str}) -> ({out_idx_str})>"

    indexing_maps = f"""{affine_map_in0},
    {affine_map_in1},
    {affine_map_out}
    """

    out_dyn_dim_size_str = ""
    for c in es_out:
        if c in es_in0:
            out_dyn_dim_size_str += "%a" + str(es_in0.find(c)) + ","
        elif c in es_in1:
            if es_in1.find(c) == len(es_in1) - 1:
                out_dyn_dim_size_str += "%b_unblocked_dim,"
            else:
                out_dyn_dim_size_str += "%b" + str(es_in1.find(c)) + ","
        else:
            raise Exception("Invalid einsum string")
    out_dyn_dim_size_str = out_dyn_dim_size_str[:-1]
    return (
        (in0_idx, in1_idx, out_idx),
        iterators,
        indexing_maps,
        out_dyn_dim_size_str,
    )

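
# Illustrative note (editorial, not part of the original commit): for a plain
# matmul-style string, einsum_util("ij,jk->ik") numbers the output indices first
# (i -> d0, k -> d1, j -> d2) and returns roughly:
#   in0_idx, in1_idx, out_idx: [0, 2], [2, 1], [0, 1]
#   iterators:     "parallel", "parallel", "reduction"
#   indexing maps: affine_map<(d0, d1, d2) -> (d0, d2)>,
#                  affine_map<(d0, d1, d2) -> (d2, d1)>,
#                  affine_map<(d0, d1, d2) -> (d0, d1)>
#   out_dyn_dim_size_str: "%a0,%b_unblocked_dim"
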
@CustomOp.register(library=LIBRARY)
class einsum_2args_q4(CustomOp):
    """Einsum that takes two tensor inputs and returns one tensor.
    The first input is expected to be a normal tensor.
    The second input corresponds to the BlockScaledLayout and operates on planar `d`
    and `qs` tensors as specified there:
    * `d`: `[..., K // BLOCK_SIZE, 1]`
    * `qs`: `[..., K // BLOCK_SIZE, BLOCK_SIZE // 2]` (of uint8)
    * `m`: `[..., K // BLOCK_SIZE, 1]`
    """

    signature = (
        "einsum_2args_q4(Tensor a, Tensor d, Tensor qs, Tensor m, str es) -> (Tensor)"
    )

    def select(self, ksel: KernelSelection):
        a_desc = ksel.arg_tensor(0)  # Shape [b, ] m, k
        d_desc = ksel.arg_tensor(1)  # Shape [N, K // BLOCK_SIZE, 1]
        qs_desc = ksel.arg_tensor(2)  # Shape [N, K // BLOCK_SIZE, BLOCK_SIZE // 2]
        m_desc = ksel.arg_tensor(3)  # Shape [N, K // BLOCK_SIZE, 1]
        einsum_str = ksel.attr_str(4).v

        # a arg
        a_dims = a_desc.t.shape
        torch._check(
            a_desc.t.dtype.is_floating_point,
            lambda: f"einsum_2args_q4 arg 'a': Expected floating point (got {a_desc.t.dtype})",
        )

        # qs arg
        *qs_dims, qs_group0, qs_bs_div_2 = qs_desc.t.shape
        block_size = qs_bs_div_2 * 2

        # d arg
        *d_dims, d_group0, d_one = d_desc.t.shape
        torch._check(
            d_group0 == qs_group0 and d_one == 1 and len(d_dims) == len(qs_dims),
            lambda: f"einsum_2args_q4 arg 'd': Incorrect shape (got {d_desc.t.shape})",
        )

        # m arg
        *m_dims, m_group0, m_one = m_desc.t.shape
        torch._check(
            m_desc.t.dtype == d_desc.t.dtype and len(m_dims) == len(qs_dims),
            lambda: f"einsum_2args_q4 arg 'm': Incorrect dtype (got {m_desc.t.dtype})",
        )
        # einsum_str
        torch._check(
            einsum_str.count(",") == 1 and einsum_str.count("->") == 1,
            lambda: f"einsum_2args_q4 arg 'einsum_str': Expected format '{{}},{{}}->{{}}' (got '{einsum_str}')",
        )

        es_in, es_out = einsum_str.split("->")
        es_in0, es_in1 = es_in.split(",")
        es_set = set(es_out)

        shp = qs_desc.t.shape
        b_dims = list(shp[:-2]) + [shp[-2] * block_size]
        torch._check(
            len(es_in0) == len(a_desc.t.shape)
            and len(es_in1)
            == len(qs_desc.t.shape)
            - 1,  # The quantized shape is larger until the blocks are collapsed
            lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dimensions (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})",
        )
        torch._check(
            len(es_in0) == len(set(es_in0))
            and len(es_in1) == len(set(es_in1))
            and len(es_in0) != 0
            and len(es_in1) != 0,
            lambda: f"einsum_2args_q4 arg 'einsum_str': Unsupported einsum str (got '{einsum_str}')",
        )

        # Check corresponding dimensions match
        for i in range(len(es_in0)):
            a_dim = a_dims[i]
            c = es_in0[i]
            pos = es_in1.find(c)
            if pos >= 0:
                b_dim = b_dims[pos]
                torch._check(
                    a_dim == b_dim,
                    lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dim for idx {c} (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})",
                )

        # Determine the output shape by referencing corresponding input shapes
        out_dims = []
        for c in es_out:
            pos0 = es_in0.find(c)
            pos1 = es_in1.find(c)
            a_dim = a_dims[pos0]
            b_dim = b_dims[pos1]
            if pos0 >= 0:
                out_dims.append(a_dim)
            elif pos1 >= 0:
                out_dims.append(b_dim)
            else:
                torch._check(
                    False,
                    lambda: f"einsum_2args_q4 arg 'einsum_str': output indices must be in input indices (got '{einsum_str}')",
                )

        # Specialize on BS
        qs_desc.specialize_dims(-1)
        d_desc.specialize_dims(-1)
        m_desc.specialize_dims(-1)

        # Shape batch..., m, n
        c_desc = ksel.return_new_tensor(out_dims, dtype=a_desc.t.dtype)

    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
        a = kb.arg_value(0)
        a_tensor_type = RankedTensorType(a.type)
        d = kb.arg_value(1)
        d_tensor_type = RankedTensorType(d.type)
        qs = kb.arg_value(2)
        qs_tensor_type = RankedTensorType(qs.type)
        einsum_str = ksel.arg_descs[4].v
        # einsum_str = "mek,menk->men"

        es_in, es_out = einsum_str.split("->")
        es_in0, es_in1 = es_in.split(",")

        es_name = "_".join([es_in0, es_in1, es_out])

        (
            (es_0, es_1, es_2),
            einsum_iterators,
            einsum_indexing_maps,
            oddss,
        ) = einsum_util(einsum_str)

        rank1 = len(es_1)
        dequant_iterators = ", ".join(
            ['"parallel"' for i in range(rank1 + 1)]
        )  # rank + 1 because of the group dimensions
        input_idx_str = ", ".join(["d" + str(i) for i in range(rank1 + 1)])
        broadcast_idx_str = ", ".join(
            ["d" + str(i) if i != rank1 else "0" for i in range(rank1 + 1)]
        )
        affine_map_parallel = f"affine_map<({input_idx_str}) -> ({input_idx_str})>"
        affine_map_broadcast = f"affine_map<({input_idx_str}) -> ({broadcast_idx_str})>"
        dequant_indexing_maps = f"""{affine_map_broadcast},
        {affine_map_broadcast},
        {affine_map_parallel},
        {affine_map_parallel}"""

        size_str = "x".join("?" for i in range(rank1 - 2))

        rank = a_tensor_type.rank
        *n_dims, group0, bs_i8 = qs_tensor_type.shape
        bs = bs_i8 * 2  # 2 nibbles per byte.
        group = group0 * bs
        a_type_str = str(a_tensor_type.element_type)
        scale_type_str = str(d_tensor_type.element_type)

        template_file = "einsum_2args_q4.mlir"
        target_function_name = f"sharktank_einsum_2args_q4_{es_name}_{bs}_{a_type_str}"

        target_function = inline_template_function(
            kb,
            template_file,
            target_function_name,
            bs=bs,
            bs_i8=bs_i8,
            a_type=a_type_str,
            scale_type=scale_type_str,
            dequant_indexing_maps=dequant_indexing_maps,
            dequant_iterator_types=dequant_iterators,
            einsum_indexing_maps=einsum_indexing_maps,
            einsum_iterator_types=einsum_iterators,
            es_name=es_name,
            a_size=len(es_in0),
            b_size=len(es_in1),
            c_size=len(es_out),
            out_dyn_dim_size_str=oddss,
        )
        kb.yield_results(*call_function(target_function, *kb.arg_bindings))
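
For orientation, a minimal invocation sketch (editorial illustration, not part of the
commit; the `sharktank.kernels` import path, the sizes, and the einsum string are
assumptions):

    import torch

    from sharktank import kernels  # assumed re-export of the registered op

    # Assumed BLOCK_SIZE = 32 and K = 64, i.e. K // BLOCK_SIZE = 2 blocks.
    a = torch.rand(4, 16, 64, dtype=torch.float32)  # indices "mek"
    d = torch.rand(4, 16, 8, 2, 1, dtype=torch.float32)  # per-block scales
    qs = torch.randint(0, 256, (4, 16, 8, 2, 16), dtype=torch.uint8)  # packed int4 pairs
    m = torch.rand(4, 16, 8, 2, 1, dtype=torch.float32)  # per-block offsets

    # The unblocked second operand carries indices "menk"; the result carries "men".
    out = kernels.einsum_2args_q4(a, d, qs, m, "mek,menk->men")
    print(out.shape)  # expected: torch.Size([4, 16, 8])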
sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir
104 changes: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
// Copyright 2024 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

{% set accum_type = "f32" %}

!lowp_type = i4
!a_type = {{a_type}}
!scale_type = {{scale_type}}
!accum_type = {{accum_type}}
!a_tensor_type = tensor<{% for i in range(a_size) %}?x{% endfor %}!a_type>
!qs_raw_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs_i8}}xi8>
!qs_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!lowp_type>
!d_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type>
!m_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type>
!accum_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!accum_type>
!c_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!a_type>
!b_grouped_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!a_type>
!b_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}!a_type>

module {

util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
    %a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type)
    -> !c_tensor_type {
  %debug = tensor.empty() : tensor<1xf32>
  %zero = arith.constant 0.0: !accum_type
{% for i in range(a_size) %}
  %k{{i}} = arith.constant {{i}} : index
{% endfor %}
{% for i in range(a_size, b_size) %}
  %k{{i}} = arith.constant {{i}} : index
{% endfor %}
{% for i in range(a_size) %}
  %a{{i}} = tensor.dim %a, %k{{i}}: !a_tensor_type
{% endfor %}
{% for i in range(b_size) %}
  %b{{i}} = tensor.dim %qs_raw, %k{{i}}: !qs_raw_tensor_type
{% endfor %}
  %bs = arith.constant {{bs}} : index
  %b_unblocked_dim = arith.muli %b{{b_size-1}}, %bs : index

  //%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type
  %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}}

  // Dequantize.
  %b_grouped = tensor.empty({% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}) : !b_grouped_tensor_type
  %b_grouped_dequant = linalg.generic {
      indexing_maps = [
          {{dequant_indexing_maps}}],
      iterator_types = [{{dequant_iterator_types}}] }
      ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type)
      outs(%b_grouped : !b_grouped_tensor_type) {
    ^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type):
      %q_element_ext = arith.extui %q_element : !lowp_type to i32
      %q_element_fp = arith.uitofp %q_element_ext : i32 to !a_type
    {% if scale_type == a_type %}
      %q_element_scaled = arith.mulf %q_element_fp, %d_element : !a_type
      %q_element_offset = arith.addf %q_element_scaled, %m_element : !a_type
    {% else %}
      %d_element_ext = arith.extf %d_element : !scale_type to !a_type
      %m_element_ext = arith.extf %m_element : !scale_type to !a_type
      %q_element_scaled = arith.mulf %q_element_fp, %d_element_ext : !a_type
      %q_element_offset = arith.addf %q_element_scaled, %m_element_ext : !a_type
    {% endif %}
      linalg.yield %q_element_offset : !a_type
  } -> !b_grouped_tensor_type

  // Collapse %b to the same unblocked structure.
  %b_unblocked = tensor.collapse_shape %b_grouped_dequant [{% for i in range(b_size-1) %}[{{i}}], {% endfor %}[{{b_size-1}}, {{b_size}}]] : !b_grouped_tensor_type into !b_tensor_type

  // Einsum
  %result_empty = tensor.empty({{out_dyn_dim_size_str}}) : !accum_tensor_type
  %result_fill = linalg.fill ins(%zero: !accum_type) outs(%result_empty: !accum_tensor_type) -> !accum_tensor_type
  %result = linalg.generic {
      indexing_maps = [
          {{einsum_indexing_maps}}],
      iterator_types = [{{einsum_iterator_types}}] }
      ins(%a, %b_unblocked : !a_tensor_type, !b_tensor_type)
      outs(%result_fill : !accum_tensor_type) {
    ^bb0(%a_element: !a_type, %b_element: !a_type, %out: !accum_type):
      %bmm_mul = arith.mulf %a_element, %b_element : !a_type
    {% if accum_type == a_type %}
      %bmm_accum = arith.addf %bmm_mul, %out : !a_type
    {% else %}
      %bmm_mul_ext = arith.extf %bmm_mul : !a_type to !accum_type
      %bmm_accum = arith.addf %bmm_mul_ext, %out : !accum_type
    {% endif %}
      linalg.yield %bmm_accum : !accum_type
  } -> !accum_tensor_type

  // Cast.
  %result_cast_empty = tensor.empty({{out_dyn_dim_size_str}}) : !c_tensor_type
  %result_cast = linalg.copy
      ins(%result : !accum_tensor_type)
      outs(%result_cast_empty : !c_tensor_type) -> !c_tensor_type

  //iree_input.tensor.trace "foobar" = [%a : !a_tensor_type, %d : !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type, %b_grouped_dequant: !b_grouped_tensor_type]
  util.return %result_cast : !c_tensor_type
}

}
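
As an editorial cross-check (not part of the commit), the template above amounts to:
dequantize each block of 4-bit values as scale * q + offset, collapse the block
dimensions, then run the requested einsum with f32 accumulation and cast back to the
activation type. A rough PyTorch reference, assuming low-nibble-first packing for the
i8 -> i4 bitcast:

    import torch

    def einsum_2args_q4_reference(a, d, qs, m, einsum_str):
        # Unpack two 4-bit values per byte (low nibble first is an assumption).
        lo = (qs & 0x0F).to(torch.int32)
        hi = (qs >> 4).to(torch.int32)
        q = torch.stack([lo, hi], dim=-1).flatten(-2)  # [..., K // BS, BS]
        # Dequantize per block: value = d * q + m, mirroring the linalg.generic body.
        b_grouped = d * q.to(d.dtype) + m
        # Collapse the block structure, mirroring tensor.collapse_shape.
        b = b_grouped.flatten(-2)  # [..., K]
        # Einsum with f32 accumulation, then cast back to the activation dtype.
        acc = torch.einsum(einsum_str, a.float(), b.float())
        return acc.to(a.dtype)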