diff --git a/src/gr_mat/mul_strassen.c b/src/gr_mat/mul_strassen.c index 6aaaef06c2..15c2a97dff 100644 --- a/src/gr_mat/mul_strassen.c +++ b/src/gr_mat/mul_strassen.c @@ -13,12 +13,11 @@ #include "gr_mat.h" -#include "fmpz_mat.h" - /* todo: optimize for small matrices */ -/* todo: bodrato squaring */ /* todo: use fused add-mul operations when supported by the matrix interface in the future */ +/* todo: when squaring, pretransform A12, A21, X2 which are + used twice in the recursive multiplications */ /* The implemented sequence is not Strassen's nor Winograd's, but the sequence proposed by Bodrato, which is equivalent to Winograd's, and can be easily @@ -27,14 +26,8 @@ int gr_mat_mul_strassen(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) { slong ar, ac, br, bc; - slong anr, anc, bnr, bnc; int status = GR_SUCCESS; - gr_mat_t A11, A12, A21, A22; - gr_mat_t B11, B12, B21, B22; - gr_mat_t C11, C12, C21, C22; - gr_mat_t X1, X2; - ar = A->r; ac = A->c; br = B->r; @@ -60,132 +53,247 @@ int gr_mat_mul_strassen(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t return status; } - anr = ar / 2; - anc = ac / 2; - bnr = anc; - bnc = bc / 2; + if (A == B) + { + slong anr; - gr_mat_window_init(A11, A, 0, 0, anr, anc, ctx); - gr_mat_window_init(A12, A, 0, anc, anr, 2 * anc, ctx); - gr_mat_window_init(A21, A, anr, 0, 2 * anr, anc, ctx); - gr_mat_window_init(A22, A, anr, anc, 2 * anr, 2 * anc, ctx); + gr_mat_t A11, A12, A21, A22; + gr_mat_t C11, C12, C21, C22; + gr_mat_t X1, X2; - gr_mat_window_init(B11, B, 0, 0, bnr, bnc, ctx); - gr_mat_window_init(B12, B, 0, bnc, bnr, 2 * bnc, ctx); - gr_mat_window_init(B21, B, bnr, 0, 2 * bnr, bnc, ctx); - gr_mat_window_init(B22, B, bnr, bnc, 2 * bnr, 2 * bnc, ctx); + anr = ar / 2; - gr_mat_window_init(C11, C, 0, 0, anr, bnc, ctx); - gr_mat_window_init(C12, C, 0, bnc, anr, 2 * bnc, ctx); - gr_mat_window_init(C21, C, anr, 0, 2 * anr, bnc, ctx); - gr_mat_window_init(C22, C, anr, bnc, 2 * anr, 2 * bnc, ctx); + gr_mat_window_init(A11, A, 0, 0, anr, anr, ctx); + gr_mat_window_init(A12, A, 0, anr, anr, 2 * anr, ctx); + gr_mat_window_init(A21, A, anr, 0, 2 * anr, anr, ctx); + gr_mat_window_init(A22, A, anr, anr, 2 * anr, 2 * anr, ctx); - gr_mat_init(X1, anr, FLINT_MAX(bnc, anc), ctx); - gr_mat_init(X2, anc, bnc, ctx); + gr_mat_window_init(C11, C, 0, 0, anr, anr, ctx); + gr_mat_window_init(C12, C, 0, anr, anr, 2 * anr, ctx); + gr_mat_window_init(C21, C, anr, 0, 2 * anr, anr, ctx); + gr_mat_window_init(C22, C, anr, anr, 2 * anr, 2 * anr, ctx); - X1->c = anc; + gr_mat_init(X2, anr, anr, ctx); - status |= gr_mat_add(X1, A22, A12, ctx); - status |= gr_mat_add(X2, B22, B12, ctx); - status |= gr_mat_mul(C21, X1, X2, ctx); + status |= gr_mat_add(X2, A22, A12, ctx); + status |= gr_mat_mul(C21, X2, X2, ctx); + status |= gr_mat_sub(X2, A22, A21, ctx); + status |= gr_mat_mul(C22, X2, X2, ctx); + status |= gr_mat_add(X2, X2, A12, ctx); + status |= gr_mat_mul(C11, X2, X2, ctx); - status |= gr_mat_sub(X1, A22, A21, ctx); - status |= gr_mat_sub(X2, B22, B21, ctx); - status |= gr_mat_mul(C22, X1, X2, ctx); + status |= gr_mat_sub(X2, X2, A11, ctx); + status |= gr_mat_mul(C12, X2, A12, ctx); - status |= gr_mat_add(X1, X1, A12, ctx); - status |= gr_mat_add(X2, X2, B12, ctx); - status |= gr_mat_mul(C11, X1, X2, ctx); + gr_mat_init(X1, anr, anr, ctx); - status |= gr_mat_sub(X1, X1, A11, ctx); - status |= gr_mat_mul(C12, X1, B12, ctx); + status |= gr_mat_mul(X1, A12, A21, ctx); + status |= gr_mat_add(C11, C11, X1, ctx); + status |= gr_mat_sub(C12, C11, C12, ctx); + status |= gr_mat_sub(C11, C21, C11, ctx); + status |= gr_mat_mul(C21, A21, X2, ctx); - X1->c = bnc; - status |= gr_mat_mul(X1, A12, B21, ctx); - status |= gr_mat_add(C11, C11, X1, ctx); - status |= gr_mat_add(C12, C12, C22, ctx); - status |= gr_mat_sub(C12, C11, C12, ctx); - status |= gr_mat_sub(C11, C21, C11, ctx); - status |= gr_mat_sub(X2, X2, B11, ctx); - status |= gr_mat_mul(C21, A21, X2, ctx); + gr_mat_clear(X2, ctx); - gr_mat_clear(X2, ctx); + status |= gr_mat_sub(C21, C11, C21, ctx); + status |= gr_mat_sub(C12, C12, C22, ctx); + status |= gr_mat_add(C22, C22, C11, ctx); + status |= gr_mat_mul(C11, A11, A11, ctx); + status |= gr_mat_add(C11, X1, C11, ctx); - status |= gr_mat_sub(C21, C11, C21, ctx); - status |= gr_mat_add(C22, C22, C11, ctx); - status |= gr_mat_mul(C11, A11, B11, ctx); + gr_mat_clear(X1, ctx); - status |= gr_mat_add(C11, X1, C11, ctx); + gr_mat_window_clear(A11, ctx); + gr_mat_window_clear(A12, ctx); + gr_mat_window_clear(A21, ctx); + gr_mat_window_clear(A22, ctx); - X1->c = FLINT_MAX(bnc, anc); - gr_mat_clear(X1, ctx); + gr_mat_window_clear(C11, ctx); + gr_mat_window_clear(C12, ctx); + gr_mat_window_clear(C21, ctx); + gr_mat_window_clear(C22, ctx); - gr_mat_window_clear(A11, ctx); - gr_mat_window_clear(A12, ctx); - gr_mat_window_clear(A21, ctx); - gr_mat_window_clear(A22, ctx); + if (ar > 2 * anr) + { + { + gr_mat_t Ac, Cc; + gr_mat_window_init(Ac, A, 0, 2 * anr, ar, ar, ctx); + gr_mat_window_init(Cc, C, 0, 2 * anr, ar, ar, ctx); - gr_mat_window_clear(B11, ctx); - gr_mat_window_clear(B12, ctx); - gr_mat_window_clear(B21, ctx); - gr_mat_window_clear(B22, ctx); + status |= gr_mat_mul(Cc, A, Ac, ctx); - gr_mat_window_clear(C11, ctx); - gr_mat_window_clear(C12, ctx); - gr_mat_window_clear(C21, ctx); - gr_mat_window_clear(C22, ctx); + gr_mat_window_clear(Ac, ctx); + gr_mat_window_clear(Cc, ctx); + } - if (bc > 2 * bnc) - { - gr_mat_t Bc, Cc; - gr_mat_window_init(Bc, B, 0, 2 * bnc, ac, bc, ctx); - gr_mat_window_init(Cc, C, 0, 2 * bnc, ar, bc, ctx); - status |= gr_mat_mul(Cc, A, Bc, ctx); - gr_mat_window_clear(Bc, ctx); - gr_mat_window_clear(Cc, ctx); - } + { + gr_mat_t Ar, Cr; + gr_mat_t As; + + gr_mat_window_init(Ar, A, 2 * anr, 0, ar, ar, ctx); + gr_mat_window_init(Cr, C, 2 * anr, 0, ar, 2 * anr, ctx); + gr_mat_window_init(As, A, 0, 0, ar, 2 * anr, ctx); + + status |= gr_mat_mul(Cr, Ar, As, ctx); + + gr_mat_window_clear(As, ctx); + gr_mat_window_clear(Ar, ctx); + gr_mat_window_clear(Cr, ctx); + } + + { + gr_mat_t Ac, Ar, Cb, tmp; + + gr_mat_window_init(Ac, A, 0, 2 * anr, 2 * anr, ar, ctx); + gr_mat_window_init(Ar, A, 2 * anr, 0, ar, 2 * anr, ctx); + gr_mat_window_init(Cb, C, 0, 0, 2 * anr, 2 * anr, ctx); + gr_mat_init(tmp, 2 * anr, 2 * anr, ctx); + + status |= gr_mat_mul(tmp, Ac, Ar, ctx); + status |= gr_mat_add(Cb, Cb, tmp, ctx); - if (ar > 2 * anr) + gr_mat_clear(tmp, ctx); + gr_mat_window_clear(Ac, ctx); + gr_mat_window_clear(Ar, ctx); + gr_mat_window_clear(Cb, ctx); + } + } + } + else { - gr_mat_t Ar, Br, Cr; - gr_mat_window_init(Ar, A, 2 * anr, 0, ar, ac, ctx); - gr_mat_window_init(Cr, C, 2 * anr, 0, ar, 2 * bnc, ctx); + slong anr, anc, bnr, bnc; + gr_mat_t A11, A12, A21, A22; + gr_mat_t B11, B12, B21, B22; + gr_mat_t C11, C12, C21, C22; + gr_mat_t X1, X2; + + anr = ar / 2; + anc = ac / 2; + bnr = anc; + bnc = bc / 2; + + gr_mat_window_init(A11, A, 0, 0, anr, anc, ctx); + gr_mat_window_init(A12, A, 0, anc, anr, 2 * anc, ctx); + gr_mat_window_init(A21, A, anr, 0, 2 * anr, anc, ctx); + gr_mat_window_init(A22, A, anr, anc, 2 * anr, 2 * anc, ctx); + + gr_mat_window_init(B11, B, 0, 0, bnr, bnc, ctx); + gr_mat_window_init(B12, B, 0, bnc, bnr, 2 * bnc, ctx); + gr_mat_window_init(B21, B, bnr, 0, 2 * bnr, bnc, ctx); + gr_mat_window_init(B22, B, bnr, bnc, 2 * bnr, 2 * bnc, ctx); + + gr_mat_window_init(C11, C, 0, 0, anr, bnc, ctx); + gr_mat_window_init(C12, C, 0, bnc, anr, 2 * bnc, ctx); + gr_mat_window_init(C21, C, anr, 0, 2 * anr, bnc, ctx); + gr_mat_window_init(C22, C, anr, bnc, 2 * anr, 2 * bnc, ctx); + + gr_mat_init(X1, anr, FLINT_MAX(bnc, anc), ctx); + gr_mat_init(X2, anc, bnc, ctx); + + X1->c = anc; + + status |= gr_mat_add(X1, A22, A12, ctx); + status |= gr_mat_add(X2, B22, B12, ctx); + status |= gr_mat_mul(C21, X1, X2, ctx); + + status |= gr_mat_sub(X1, A22, A21, ctx); + status |= gr_mat_sub(X2, B22, B21, ctx); + status |= gr_mat_mul(C22, X1, X2, ctx); + + status |= gr_mat_add(X1, X1, A12, ctx); + status |= gr_mat_add(X2, X2, B12, ctx); + status |= gr_mat_mul(C11, X1, X2, ctx); + + status |= gr_mat_sub(X1, X1, A11, ctx); + status |= gr_mat_mul(C12, X1, B12, ctx); + + X1->c = bnc; + status |= gr_mat_mul(X1, A12, B21, ctx); + status |= gr_mat_add(C11, C11, X1, ctx); + status |= gr_mat_add(C12, C12, C22, ctx); + status |= gr_mat_sub(C12, C11, C12, ctx); + status |= gr_mat_sub(C11, C21, C11, ctx); + status |= gr_mat_sub(X2, X2, B11, ctx); + status |= gr_mat_mul(C21, A21, X2, ctx); + + gr_mat_clear(X2, ctx); + + status |= gr_mat_sub(C21, C11, C21, ctx); + status |= gr_mat_add(C22, C22, C11, ctx); + status |= gr_mat_mul(C11, A11, B11, ctx); + + status |= gr_mat_add(C11, X1, C11, ctx); + + X1->c = FLINT_MAX(bnc, anc); + gr_mat_clear(X1, ctx); + + gr_mat_window_clear(A11, ctx); + gr_mat_window_clear(A12, ctx); + gr_mat_window_clear(A21, ctx); + gr_mat_window_clear(A22, ctx); + + gr_mat_window_clear(B11, ctx); + gr_mat_window_clear(B12, ctx); + gr_mat_window_clear(B21, ctx); + gr_mat_window_clear(B22, ctx); + + gr_mat_window_clear(C11, ctx); + gr_mat_window_clear(C12, ctx); + gr_mat_window_clear(C21, ctx); + gr_mat_window_clear(C22, ctx); - /* don't compute the overlapping entries twice */ if (bc > 2 * bnc) { - gr_mat_window_init(Br, B, 0, 0, ac, 2 * bnc, ctx); - status |= gr_mat_mul(Cr, Ar, Br, ctx); - gr_mat_window_clear(Br, ctx); + gr_mat_t Bc, Cc; + gr_mat_window_init(Bc, B, 0, 2 * bnc, ac, bc, ctx); + gr_mat_window_init(Cc, C, 0, 2 * bnc, ar, bc, ctx); + status |= gr_mat_mul(Cc, A, Bc, ctx); + gr_mat_window_clear(Bc, ctx); + gr_mat_window_clear(Cc, ctx); } - else + + if (ar > 2 * anr) { - status |= gr_mat_mul(Cr, Ar, B, ctx); + gr_mat_t Ar, Br, Cr; + gr_mat_window_init(Ar, A, 2 * anr, 0, ar, ac, ctx); + gr_mat_window_init(Cr, C, 2 * anr, 0, ar, 2 * bnc, ctx); + + /* don't compute the overlapping entries twice */ + if (bc > 2 * bnc) + { + gr_mat_window_init(Br, B, 0, 0, ac, 2 * bnc, ctx); + status |= gr_mat_mul(Cr, Ar, Br, ctx); + gr_mat_window_clear(Br, ctx); + } + else + { + status |= gr_mat_mul(Cr, Ar, B, ctx); + } + + gr_mat_window_clear(Ar, ctx); + gr_mat_window_clear(Cr, ctx); } - gr_mat_window_clear(Ar, ctx); - gr_mat_window_clear(Cr, ctx); - } + if (ac > 2 * anc) + { + gr_mat_t Ac, Br, Cb, tmp; + slong mt, nt; - if (ac > 2 * anc) - { - gr_mat_t Ac, Br, Cb, tmp; - slong mt, nt; - - gr_mat_window_init(Ac, A, 0, 2 * anc, 2 * anr, ac, ctx); - gr_mat_window_init(Br, B, 2 * bnr, 0, ac, 2 * bnc, ctx); - gr_mat_window_init(Cb, C, 0, 0, 2 * anr, 2 * bnc, ctx); - - mt = Ac->r; - nt = Br->c; - - gr_mat_init(tmp, mt, nt, ctx); - status |= gr_mat_mul(tmp, Ac, Br, ctx); - status |= gr_mat_add(Cb, Cb, tmp, ctx); - gr_mat_clear(tmp, ctx); - gr_mat_window_clear(Ac, ctx); - gr_mat_window_clear(Br, ctx); - gr_mat_window_clear(Cb, ctx); + gr_mat_window_init(Ac, A, 0, 2 * anc, 2 * anr, ac, ctx); + gr_mat_window_init(Br, B, 2 * bnr, 0, ac, 2 * bnc, ctx); + gr_mat_window_init(Cb, C, 0, 0, 2 * anr, 2 * bnc, ctx); + + mt = Ac->r; + nt = Br->c; + + gr_mat_init(tmp, mt, nt, ctx); + status |= gr_mat_mul(tmp, Ac, Br, ctx); + status |= gr_mat_add(Cb, Cb, tmp, ctx); + gr_mat_clear(tmp, ctx); + gr_mat_window_clear(Ac, ctx); + gr_mat_window_clear(Br, ctx); + gr_mat_window_clear(Cb, ctx); + } } return status; diff --git a/src/gr_mat/test/t-mul_strassen.c b/src/gr_mat/test/t-mul_strassen.c index 0fc092faca..8829ebcbeb 100644 --- a/src/gr_mat/test/t-mul_strassen.c +++ b/src/gr_mat/test/t-mul_strassen.c @@ -30,8 +30,8 @@ TEST_FUNCTION_START(gr_mat_mul_strassen, state) gr_ctx_init_nmod(ctx, n_randtest_not_zero(state)); a = n_randint(state, 8); - b = n_randint(state, 8); - c = n_randint(state, 8); + b = n_randint(state, 2) ? a : n_randint(state, 8); + c = n_randint(state, 2) ? a : n_randint(state, 8); gr_mat_init(A, a, b, ctx); gr_mat_init(B, b, c, ctx); @@ -43,7 +43,12 @@ TEST_FUNCTION_START(gr_mat_mul_strassen, state) status |= gr_mat_randtest(C, state, ctx); status |= gr_mat_randtest(D, state, ctx); - if (b == c && n_randint(state, 2)) + if (a == b && b == c && n_randint(state, 2)) + { + status |= gr_mat_set(B, A, ctx); + status |= gr_mat_mul_strassen(C, A, A, ctx); + } + else if (b == c && n_randint(state, 2)) { status |= gr_mat_set(C, A, ctx); status |= gr_mat_mul_strassen(C, C, B, ctx); @@ -60,7 +65,7 @@ TEST_FUNCTION_START(gr_mat_mul_strassen, state) status |= gr_mat_mul_classical(D, A, B, ctx); - if (status != GR_SUCCESS && gr_mat_equal(C, D, ctx) == T_FALSE) + if (status != GR_SUCCESS || gr_mat_equal(C, D, ctx) != T_TRUE) { flint_printf("FAIL:\n"); gr_ctx_println(ctx);