diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py index 133e6612fd68..2291f27e32f2 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py @@ -199,13 +199,14 @@ def _get_function_signature(self, function_kind: str, parameter_decls = list(map(parameter_decl_builder, self.arguments)) parameter_decls = list(filter(None, parameter_decls)) ret_decls = list(map(ret_decl_builder, self.returns)) + ret_decls = list(filter(None, ret_decls)) parameters = ", ".join(parameter_decls) result = ", ".join(ret_decls) + if len(ret_decls) == 0: + result = "None" if len(ret_decls) >= 2: result = f"Tuple[{result}]" - if function_kind == "has_value_semantics": - result = "None" return f"def {def_name}({parameters}) -> {result}:" def get_shape_function_signature(self): @@ -288,7 +289,7 @@ def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: return "" def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: - return "None" + return "" return self._get_function_signature( "has_value_semantics", parameter_decl_builder, ret_decl_builder)