Skip to content

Commit

Permalink
[FORK][FIX] Fixed behavior for unaligned src and weights ic groups
Browse files Browse the repository at this point in the history
[FORK][FEATURE] InnerProduct primitive: src dynamic quantization
  • Loading branch information
dmitry-gorokhov committed Apr 1, 2024
1 parent 1a30d38 commit 26633ae
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
22 changes: 11 additions & 11 deletions src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2499,13 +2499,6 @@ void jit_brgemm_kernel_t<isa, Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
mov(reg_rdb_loop, brg.rdb);
L_aligned(rdb_loop_label, 64);
{
const bool is_rd_tail = false;
gemm_microkernel(bd_block2, is_bdb_tail, ld_block2,
is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail);

add(reg_aux_A, rdb_A_offset());
add(reg_aux_B, rdb_B_offset());

if (brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 ||
brg.wei_decomp_zero_points_stride != 0)) {
auto reg_local_ic = reg_aux_D;
Expand All @@ -2529,10 +2522,6 @@ void jit_brgemm_kernel_t<isa, Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
mov(ptr[rsp + reg_reg_a_offset_offs_], reg_a_offset); // preserve rdx for idiv

mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]);
add(reg_local_ic, brg.rd_block);
mov(ptr[rsp + reg_aux_ic_offs_], reg_local_ic);

if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) {
ic_group_shift(reg_aux_wei_scales_offs_, reg_aux2_wei_scales_offs_,
brg.wei_decomp_scales_group_size, brg.wei_decomp_scales_stride * sizeof(float));
Expand All @@ -2548,12 +2537,23 @@ void jit_brgemm_kernel_t<isa, Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
brg.src_scales_group_size, sizeof(float));
}

mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]);
add(reg_local_ic, brg.rd_block);
mov(ptr[rsp + reg_aux_ic_offs_], reg_local_ic);

mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
mov(reg_aux_D, ptr[rsp + reg_aux2_D_offs_]);
mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
mov(reg_a_offset, ptr[rsp + reg_reg_a_offset_offs_]);
}

const bool is_rd_tail = false;
gemm_microkernel(bd_block2, is_bdb_tail, ld_block2,
is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail);

add(reg_aux_A, rdb_A_offset());
add(reg_aux_B, rdb_B_offset());

dec(reg_rdb_loop);
cmp(reg_rdb_loop, 0);
}
Expand Down
20 changes: 10 additions & 10 deletions src/cpu/x64/jit_brgemm_inner_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,9 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
int wei_zero_points_offset = 0;
int src_scales_offset = 0;
if (jbgp.weights_decompression) {
wei_scales_offset = (ic / jbgp.wei_scales_ic_group_size) * wei_scales_d.dims()[0] + wei_scales_oc_stride * oc;
wei_zero_points_offset = ((ic / jbgp.wei_zero_points_ic_group_size) * wei_zero_points_d.dims()[0] + wei_zero_points_oc_stride * oc) * wei_zero_points_dt_size;
src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size) + (ic / jbgp.src_quant_group_size);
wei_scales_offset = wei_scales_oc_stride * oc;
wei_zero_points_offset = wei_zero_points_oc_stride * oc * wei_zero_points_dt_size;
src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size);
}

auto ptr_D = dst + dst_off;
Expand All @@ -423,10 +423,10 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(

brgemm_kernel_execute_postops(brg_kernel, gemm_batch,
addr_batch, (void *)ptr_C, (void *)ptr_D, post_ops_data,
scratch, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, 0);
scratch, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic);
} else {
brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch,
(void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, 0);
(void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic);
}
}

Expand Down Expand Up @@ -500,9 +500,9 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
int wei_zero_points_offset = 0;
int src_scales_offset = 0;
if (jbgp.weights_decompression) {
wei_scales_offset = (ic / jbgp.wei_scales_ic_group_size) * wei_scales_d.dims()[0] + wei_scales_oc_stride * oc;
wei_zero_points_offset = ((ic / jbgp.wei_zero_points_ic_group_size) * wei_zero_points_d.dims()[0] + wei_zero_points_oc_stride * oc) * wei_zero_points_dt_size;
src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size) + (ic / jbgp.src_quant_group_size);
wei_scales_offset = wei_scales_oc_stride * oc;
wei_zero_points_offset = wei_zero_points_oc_stride * oc * wei_zero_points_dt_size;
src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size);
}

auto brg_kernel_ic_tail = brg_kernels_[brg_ker_ic_tail_idx].get();
Expand All @@ -524,10 +524,10 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
nullptr, false, 1, false, false, dst_scales};

brgemm_kernel_execute_postops(brg_kernel_ic_tail, 1, addr_batch,
(void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, 0);
(void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic);
} else {
brgemm_kernel_execute(brg_kernel_ic_tail, 1, addr_batch,
(void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, 0);
(void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic);
}
}
};
Expand Down

0 comments on commit 26633ae

Please sign in to comment.