Skip to content

Commit

Permalink
Add map_mul3_map_mul2_count_int32_int32_int32, map_mul2_map_mul2_coun…
Browse files Browse the repository at this point in the history
…t_float64_float32_int16 tests. map_mul* tests are flaky,TBD why.
  • Loading branch information
rwgk committed Nov 23, 2024
1 parent 3a232d8 commit 79a6dac
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions python/cuda_parallel/tests/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy
import pytest
import random
import re
import numba.cuda
import numba.types
import cuda.parallel.experimental as cudax
Expand Down Expand Up @@ -92,6 +93,10 @@ def mul2(val):
return 2 * val


def mul3(val):
return 3 * val


@pytest.mark.parametrize("use_numpy_array", [True, False])
@pytest.mark.parametrize("input_generator", ["raw_pointer_int16",
"raw_pointer_uint16",
Expand Down Expand Up @@ -138,6 +143,8 @@ def mul2(val):
"map_mul2_count_float64_int32",
"map_mul2_count_float64_int64",
"map_mul2_count_int64_float32",
"map_mul3_map_mul2_count_int32_int32_int32",
"map_mul2_map_mul2_count_float64_float32_int16",
])
def test_device_sum_iterators(use_numpy_array, input_generator, num_items=3, start_sum_with=10):
def add_op(a, b):
Expand Down Expand Up @@ -179,6 +186,21 @@ def dtype_ntype(ix):
mul2,
iterators.count(start_sum_with, ntype=ntype_inp),
op_return_ntype=ntype_out)
elif re.match(r"map_mul\d_map_mul\d_count_", input_generator):
fac_out = int(input_generator[7])
fac_mid = int(input_generator[16])
l_input = [fac_out * (fac_mid * (start_sum_with + distance)) for distance in range(num_items)]
dtype_inp, ntype_inp = dtype_ntype(-1)
dtype_mid, ntype_mid = dtype_ntype(-2)
dtype_out, ntype_out = dtype_ntype(-3)
mul_funcs = {2: mul2, 3: mul3}
i_input = iterators.cu_map(
mul_funcs[fac_out],
iterators.cu_map(
mul_funcs[fac_mid],
iterators.count(start_sum_with, ntype=ntype_inp),
op_return_ntype=ntype_mid),
op_return_ntype=ntype_out)
else:
raise RuntimeError("Unexpected input_generator")

Expand Down

0 comments on commit 79a6dac

Please sign in to comment.