Skip to content

Commit

Permalink
deprecate unary_einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Nov 19, 2024
1 parent 2d7eaf7 commit 4189cac
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 16 deletions.
38 changes: 23 additions & 15 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -649,21 +649,29 @@ function einsum(
return TracedRArray{T,length(rsize)}((), res, rsize)
end

function unary_einsum(
x::Union{TracedRNumber,TracedRArray};
equation::String,
location=MLIR.IR.Location(
"stablehlo.unary_einsum", MLIR.IR.Location(@__FILE__, @__LINE__, 0)
),
)
res = MLIR.IR.result(
stablehlo.unary_einsum(
x.mlir_data; einsum_config=MLIR.IR.Attribute(equation), location
),
)
# computing the result size is not trivial
return TracedRArray{Float64,1}((), res, (1,))
end
# function unary_einsum(
# x::TracedRArray{T};
# equation::String,
# location=MLIR.IR.Location(
# "stablehlo.unary_einsum", MLIR.IR.Location(@__FILE__, @__LINE__, 0)
# ),
# ) where {T}
# ia, ic = split(equation, "->")
# sizes = Dict(c => d for (c, d) in zip(ia, size(x)))
# rsize = Tuple(sizes[i] for i in ic)
# result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize)

# res = MLIR.IR.result(
# stablehlo.unary_einsum(
# x.mlir_data; result_0, einsum_config=MLIR.IR.Attribute(equation), location
# ),
# )
# if length(rsize) == 0
# return TracedRNumber{T}((), res)
# else
# return TracedRArray{T,length(rsize)}((), res, rsize)
# end
# end

# paralell ops
function partition_id(;
Expand Down
20 changes: 19 additions & 1 deletion test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,25 @@ end
] == @jit Ops.transpose(x, [3, 2, 1])
end

@testset "unary_einsum" begin end
# NOTE deprecated
# @testset "unary_einsum" begin
# f1(a) = Ops.unary_einsum(a; equation="i->")
# f4(a) = Ops.unary_einsum(a; equation="ij->")
# f3(a) = Ops.unary_einsum(a; equation="ij->ji")
# f4(a) = Ops.unary_einsum(a; equation="ij->j")
# f5(a) = Ops.unary_einsum(a; equation="ij->i")
# f6(a) = Ops.unary_einsum(a; equation="ii->i")

# x = ConcreteRArray([1, 2, 3, 4])
# @test sum(Array(x)) ≈ @jit f1(x)

# x = ConcreteRArray([1 2; 3 4])
# @test sum(Array(x)) ≈ @jit f4(x)
# @test Base.transpose(Array(x)) ≈ @jit f3(x)
# @test sum(Array(x); dims=1) ≈ @jit f4(x)
# @test sum(Array(x); dims=2) ≈ @jit f5(x)
# @test diag(Array(x)) ≈ @jit f6(x)
# end

@testset "xor" begin end

Expand Down

0 comments on commit 4189cac

Please sign in to comment.