Skip to content

Commit

Permalink
fix top_k
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Nov 19, 2024
1 parent c3fdc8a commit adb52c1
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -844,18 +844,18 @@ end
# return TracedRArray{T,N}((), res, size(x))
# end

function chlo.top_k(
function top_k(
x::TracedRArray{T,N},
k;
location=MLIR.IR.Location("chlo.top_k", MLIR.IR.Location(@__FILE__, @__LINE__, 0)),
) where {T,N}
rsize = [size(x)[1:(end - 1)]..., k]
values = MLIR.IR.TensorType(rsize, mlir_type(T))
indices = MLIR.IR.TensorType(rsize, mlir_type(Int))
op = chlo.top_k(x.mlir_data; values, indices, location)
values = mlir_type(TracedRArray{T,N}, rsize)
indices = mlir_type(TracedRArray{Int32,N}, rsize)
op = chlo.top_k(x.mlir_data; values, indices, k, location)
return (;
values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize),
indices=TracedRArray{Int,N}((), MLIR.IR.result(op, 2), rsize),
indices=TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize),
)
end

Expand Down

0 comments on commit adb52c1

Please sign in to comment.