Skip to content

Commit

Permalink
Implement einstein summation notation kernel for 2 input arguments (#200
Browse files Browse the repository at this point in the history
)
  • Loading branch information
KyleHerndon authored Oct 1, 2024
1 parent 72665fc commit 38e71df
Show file tree
Hide file tree
Showing 9 changed files with 625 additions and 12 deletions.
1 change: 1 addition & 0 deletions sharktank/sharktank/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .attention import *
from .einsum_2args_q4 import *
from .mmtfp import *
from .mmt_block_scaled_offset_q4 import *
from .mmt_block_scaled_q8 import *
Expand Down
259 changes: 259 additions & 0 deletions sharktank/sharktank/kernels/einsum_2args_q4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .base import *

import torch

__all__ = [
"einsum_2args_q4",
]


def einsum_util(einsum_str):
es_in, es_out = einsum_str.split("->")
es_in0, es_in1 = es_in.split(",")
es_set = set(es_out)
es_set = es_set.union(es_in0)
es_set = es_set.union(es_in1)
size = len(es_set)
imap = dict()
lmap = dict()
for i in range(len(es_out)):
imap[i] = es_out[i]
lmap[es_out[i]] = i
count = len(es_out)
for c in es_set:
if c not in lmap:
imap[count] = c
lmap[c] = count
count += 1

assert count == len(es_set)

in0_idx = [lmap[i] for i in es_in0]
in1_idx = [lmap[i] for i in es_in1]
out_idx = [lmap[i] for i in es_out]

input_idx_str = ", ".join(["d" + str(i) for i in range(size)])
in0_idx_str = ", ".join(["d" + str(i) for i in in0_idx])
in1_idx_str = ", ".join(["d" + str(i) for i in in1_idx])
out_idx_str = ", ".join(["d" + str(i) for i in out_idx])

iterators = ", ".join(
['"parallel"' if i in out_idx else '"reduction"' for i in range(size)]
)

affine_map_in0 = f"affine_map<({input_idx_str}) -> ({in0_idx_str})>"
affine_map_in1 = f"affine_map<({input_idx_str}) -> ({in1_idx_str})>"
affine_map_out = f"affine_map<({input_idx_str}) -> ({out_idx_str})>"

indexing_maps = f"""{affine_map_in0},
{affine_map_in1},
{affine_map_out}
"""

out_dyn_dim_size_str = ""
for c in es_out:
if c in es_in0:
out_dyn_dim_size_str += "%a" + str(es_in0.find(c)) + ","
elif c in es_in1:
if es_in1.find(c) == len(es_in1) - 1:
out_dyn_dim_size_str += "%b_unblocked_dim,"
else:
out_dyn_dim_size_str += "%b" + str(es_in1.find(c)) + ","
else:
raise Exception("Invalid einsum string")
out_dyn_dim_size_str = out_dyn_dim_size_str[:-1]
return (
(in0_idx, in1_idx, out_idx),
iterators,
indexing_maps,
out_dyn_dim_size_str,
)


@CustomOp.register(library=LIBRARY)
class einsum_2args_q4(CustomOp):
"""Einsum that takes two tensor inputs and returns one tensor.
The first input is expected to be a normal tensor.
The second input corresponds to the BlockScaledLayout and operates on planar `d`
and `qs` tensors as specified there:
* `d`: `[..., K // BLOCK_SIZE, 1]`
* `qs`: `[..., K // BLOCK_SIZE, BLOCK_SIZE // 2]` (of uint8)
* `m`: `[..., K // BLOCK_SIZE, 1]`
"""

signature = (
"einsum_2args_q4(Tensor a, Tensor d, Tensor qs, Tensor m, str es) -> (Tensor)"
)

def select(self, ksel: KernelSelection):
a_desc = ksel.arg_tensor(0) # Shape [b, ] m, k
d_desc = ksel.arg_tensor(1) # Shape [N, K // BLOCK_SIZE, 1]
qs_desc = ksel.arg_tensor(2) # Shape [N, K // BLOCK_SIZE, BLOCK_SIZE // 2]
m_desc = ksel.arg_tensor(3) # Shape [N, K // BLOCK_SIZE, 1]
einsum_str = ksel.attr_str(4).v

# a arg
a_dims = a_desc.t.shape
torch._check(
a_desc.t.dtype.is_floating_point,
lambda: f"einsum_2args_q4 arg 'a': Expected floating point (got {a_desc.t.dtype})",
)

# qs arg
*qs_dims, qs_group0, qs_bs_div_2 = qs_desc.t.shape
block_size = qs_bs_div_2 * 2

# d arg
*d_dims, d_group0, d_one = d_desc.t.shape
torch._check(
d_group0 == qs_group0 and d_one == 1 and len(d_dims) == len(qs_dims),
lambda: f"einsum_2args_q4 arg 'd': Incorrect shape (got {d_desc.t.shape})",
)

# m arg
*m_dims, m_group0, m_one = m_desc.t.shape
torch._check(
m_desc.t.dtype == d_desc.t.dtype and len(m_dims) == len(qs_dims),
lambda: f"einsum_2args_q4 arg 'm': Incorrect dtype (got {m_desc.t.dtype})",
)
# einsum_str
torch._check(
einsum_str.count(",") == 1 and einsum_str.count("->") == 1,
lambda: f"einsum_2args_q4 arg 'einsum_str': Expected format '{{}},{{}}->{{}}' (got '{einsum_str}')",
)

es_in, es_out = einsum_str.split("->")
es_in0, es_in1 = es_in.split(",")
es_set = set(es_out)

shp = qs_desc.t.shape
b_dims = list(shp[:-2]) + [shp[-2] * block_size]
torch._check(
len(es_in0) == len(a_desc.t.shape)
and len(es_in1)
== len(qs_desc.t.shape)
- 1, # The quantized shape is larger until the blocks are collapsed
lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dimensions (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})",
)
torch._check(
len(es_in0) == len(set(es_in0))
and len(es_in1) == len(set(es_in1))
and len(es_in0) != 0
and len(es_in1) != 0,
lambda: f"einsum_2args_q4 arg 'einsum_str': Unsupported einsum str (got '{einsum_str}')",
)

# Check corresponding dimensions match
for i in range(len(es_in0)):
a_dim = a_dims[i]
c = es_in0[i]
pos = es_in1.find(c)
if pos >= 0:
b_dim = b_dims[pos]
torch._check(
a_dim == b_dim,
lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dim for idx {c} (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})",
)

# Determine the output shape by referencing corresponding input shapes
out_dims = []
for c in es_out:
pos0 = es_in0.find(c)
pos1 = es_in1.find(c)
a_dim = a_dims[pos0]
b_dim = b_dims[pos1]
if pos0 >= 0:
out_dims.append(a_dim)
elif pos1 >= 0:
out_dims.append(b_dim)
else:
torch._check(
False,
lambda: f"einsum_2args_q4 arg 'einsum_str': output indices must be in input indices (got '{einsum_str}')",
)

# Specialize on BS
qs_desc.specialize_dims(-1)
d_desc.specialize_dims(-1)
m_desc.specialize_dims(-1)

# Shape batch..., m, n
c_desc = ksel.return_new_tensor(out_dims, dtype=a_desc.t.dtype)

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
a = kb.arg_value(0)
a_tensor_type = RankedTensorType(a.type)
d = kb.arg_value(1)
d_tensor_type = RankedTensorType(d.type)
qs = kb.arg_value(2)
qs_tensor_type = RankedTensorType(qs.type)
einsum_str = ksel.arg_descs[4].v
# einsum_str = "mek,menk->men"

es_in, es_out = einsum_str.split("->")
es_in0, es_in1 = es_in.split(",")

es_name = "_".join([es_in0, es_in1, es_out])

(
(es_0, es_1, es_2),
einsum_iterators,
einsum_indexing_maps,
oddss,
) = einsum_util(einsum_str)

rank1 = len(es_1)
dequant_iterators = ", ".join(
['"parallel"' for i in range(rank1 + 1)]
) # rank + 1 because of the group dimensions
input_idx_str = ", ".join(["d" + str(i) for i in range(rank1 + 1)])
broadcast_idx_str = ", ".join(
["d" + str(i) if i != rank1 else "0" for i in range(rank1 + 1)]
)
affine_map_parallel = f"affine_map<({input_idx_str}) -> ({input_idx_str})>"
affine_map_broadcast = f"affine_map<({input_idx_str}) -> ({broadcast_idx_str})>"
dequant_indexing_maps = f"""{affine_map_broadcast},
{affine_map_broadcast},
{affine_map_parallel},
{affine_map_parallel}"""

size_str = "x".join("?" for i in range(rank1 - 2))

rank = a_tensor_type.rank
*n_dims, group0, bs_i8 = qs_tensor_type.shape
bs = bs_i8 * 2 # 2 nibbles per byte.
group = group0 * bs
a_type_str = str(a_tensor_type.element_type)
scale_type_str = str(d_tensor_type.element_type)

template_file = "einsum_2args_q4.mlir"
target_function_name = f"sharktank_einsum_2args_q4_{es_name}_{bs}_{a_type_str}"

target_function = inline_template_function(
kb,
template_file,
target_function_name,
bs=bs,
bs_i8=bs_i8,
a_type=a_type_str,
scale_type=scale_type_str,
dequant_indexing_maps=dequant_indexing_maps,
dequant_iterator_types=dequant_iterators,
einsum_indexing_maps=einsum_indexing_maps,
einsum_iterator_types=einsum_iterators,
es_name=es_name,
a_size=len(es_in0),
b_size=len(es_in1),
c_size=len(es_out),
out_dyn_dim_size_str=oddss,
)
kb.yield_results(*call_function(target_function, *kb.arg_bindings))
104 changes: 104 additions & 0 deletions sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright 2024 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

{% set accum_type = "f32" %}

!lowp_type = i4
!a_type = {{a_type}}
!scale_type = {{scale_type}}
!accum_type = {{accum_type}}
!a_tensor_type = tensor<{% for i in range(a_size) %}?x{% endfor %}!a_type>
!qs_raw_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs_i8}}xi8>
!qs_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!lowp_type>
!d_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type>
!m_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type>
!accum_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!accum_type>
!c_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!a_type>
!b_grouped_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!a_type>
!b_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}!a_type>

module {

util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
%a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type)
-> !c_tensor_type {
%debug = tensor.empty() : tensor<1xf32>
%zero = arith.constant 0.0: !accum_type
{% for i in range(a_size) %}
%k{{i}} = arith.constant {{i}} : index
{% endfor %}
{% for i in range(a_size, b_size) %}
%k{{i}} = arith.constant {{i}} : index
{% endfor %}
{% for i in range(a_size) %}
%a{{i}} = tensor.dim %a, %k{{i}}: !a_tensor_type
{% endfor %}
{% for i in range(b_size) %}
%b{{i}} = tensor.dim %qs_raw, %k{{i}}: !qs_raw_tensor_type
{% endfor %}
%bs = arith.constant {{bs}} : index
%b_unblocked_dim = arith.muli %b{{b_size-1}}, %bs : index

//%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type
%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}}

// Dequantize.
%b_grouped = tensor.empty({% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}) : !b_grouped_tensor_type
%b_grouped_dequant = linalg.generic {
indexing_maps = [
{{dequant_indexing_maps}}],
iterator_types = [{{dequant_iterator_types}}] }
ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type)
outs(%b_grouped : !b_grouped_tensor_type) {
^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type):
%q_element_ext = arith.extui %q_element : !lowp_type to i32
%q_element_fp = arith.uitofp %q_element_ext : i32 to !a_type
{% if scale_type == a_type %}
%q_element_scaled = arith.mulf %q_element_fp, %d_element : !a_type
%q_element_offset = arith.addf %q_element_scaled, %m_element : !a_type
{% else %}
%d_element_ext = arith.extf %d_element : !scale_type to !a_type
%m_element_ext = arith.extf %m_element : !scale_type to !a_type
%q_element_scaled = arith.mulf %q_element_fp, %d_element_ext : !a_type
%q_element_offset = arith.addf %q_element_scaled, %m_element_ext : !a_type
{% endif %}
linalg.yield %q_element_offset : !a_type
} -> !b_grouped_tensor_type

// Collapse %b to the same unblocked structure.
%b_unblocked = tensor.collapse_shape %b_grouped_dequant [{% for i in range(b_size-1) %}[{{i}}], {% endfor %}[{{b_size-1}}, {{b_size}}]] : !b_grouped_tensor_type into !b_tensor_type

// Einsum
%result_empty = tensor.empty({{out_dyn_dim_size_str}}) : !accum_tensor_type
%result_fill = linalg.fill ins(%zero: !accum_type) outs(%result_empty: !accum_tensor_type) -> !accum_tensor_type
%result = linalg.generic {
indexing_maps = [
{{einsum_indexing_maps}}],
iterator_types = [{{einsum_iterator_types}}] }
ins(%a, %b_unblocked : !a_tensor_type, !b_tensor_type)
outs(%result_fill : !accum_tensor_type) {
^bb0(%a_element: !a_type, %b_element: !a_type, %out: !accum_type):
%bmm_mul = arith.mulf %a_element, %b_element : !a_type
{% if accum_type == a_type %}
%bmm_accum = arith.addf %bmm_mul, %out : !a_type
{% else %}
%bmm_mul_ext = arith.extf %bmm_mul : !a_type to !accum_type
%bmm_accum = arith.addf %bmm_mul_ext, %out : !accum_type
{% endif %}
linalg.yield %bmm_accum : !accum_type
} -> !accum_tensor_type

// Cast.
%result_cast_empty = tensor.empty({{out_dyn_dim_size_str}}) : !c_tensor_type
%result_cast = linalg.copy
ins(%result : !accum_tensor_type)
outs(%result_cast_empty : !c_tensor_type) -> !c_tensor_type

//iree_input.tensor.trace "foobar" = [%a : !a_tensor_type, %d : !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type, %b_grouped_dequant: !b_grouped_tensor_type]
util.return %result_cast : !c_tensor_type
}

}
Loading

0 comments on commit 38e71df

Please sign in to comment.