diff --git a/src/Ops.jl b/src/Ops.jl index e8dd3252..a55b8777 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -844,18 +844,18 @@ end # return TracedRArray{T,N}((), res, size(x)) # end -function chlo.top_k( +function top_k( x::TracedRArray{T,N}, k; location=MLIR.IR.Location("chlo.top_k", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), ) where {T,N} rsize = [size(x)[1:(end - 1)]..., k] - values = MLIR.IR.TensorType(rsize, mlir_type(T)) - indices = MLIR.IR.TensorType(rsize, mlir_type(Int)) - op = chlo.top_k(x.mlir_data; values, indices, location) + values = mlir_type(TracedRArray{T,N}, rsize) + indices = mlir_type(TracedRArray{Int32,N}, rsize) + op = chlo.top_k(x.mlir_data; values, indices, k, location) return (; values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize), - indices=TracedRArray{Int,N}((), MLIR.IR.result(op, 2), rsize), + indices=TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize), ) end