diff --git a/projects/pt1/python/CMakeLists.txt b/projects/pt1/python/CMakeLists.txt index 642b86b50490..443fcc809e2c 100644 --- a/projects/pt1/python/CMakeLists.txt +++ b/projects/pt1/python/CMakeLists.txt @@ -20,7 +20,6 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel SOURCES torchscript.py _dynamo_fx_importer.py - compiler_utils.py dynamo.py _version.py ) diff --git a/projects/pt1/python/torch_mlir/compiler_utils.py b/projects/pt1/python/torch_mlir/compiler_utils.py deleted file mode 100644 index 7792006032af..000000000000 --- a/projects/pt1/python/torch_mlir/compiler_utils.py +++ /dev/null @@ -1,75 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - -from io import StringIO -import os -import sys -import tempfile - -from torch_mlir.passmanager import PassManager -from torch_mlir.ir import StringAttr - - -def get_module_name_for_debug_dump(module): - """Gets a name suitable for a debug dump. - - The name is not guaranteed to be unique. - """ - if not "torch.debug_module_name" in module.operation.attributes: - return "UnnammedModule" - return StringAttr(module.operation.attributes["torch.debug_module_name"]).value - - -class TorchMlirCompilerError(Exception): - pass - -def run_pipeline_with_repro_report(module, - pipeline: str, - description: str, - enable_ir_printing: bool = False): - """Runs `pipeline` on `module`, with a nice repro report if it fails.""" - module_name = get_module_name_for_debug_dump(module) - try: - original_stderr = sys.stderr - sys.stderr = StringIO() - asm_for_error_report = module.operation.get_asm( - large_elements_limit=10, enable_debug_info=True) - # Lower module in place to make it ready for compiler backends. - with module.context as ctx: - pm = PassManager.parse(pipeline) - if enable_ir_printing: - ctx.enable_multithreading(False) - pm.enable_ir_printing() - pm.run(module.operation) - except Exception as e: - # TODO: More robust. - # - don't arbitrarily clutter up /tmp. When a test suite has many - # tests, this can be a big disk cost (also, /tmp/ is frequently a - # RAM fs, which increases worries about capacity). - # - don't have colliding filenames (hard to do without cluttering - # up /tmp) - # - if we do have have colliding filenames, writes should at least - # avoid being racy. - filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir") - with open(filename, 'w') as f: - f.write(asm_for_error_report) - debug_options="-mlir-print-ir-after-all -mlir-disable-threading" - # Put something descriptive here even if description is empty. - description = description or f"{module_name} compile" - - message = f"""\ - {description} failed with the following diagnostics: - {sys.stderr.getvalue()} - - python exception: {e} - - For Torch-MLIR developers, the error can be reproduced with: - $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename} - Add '{debug_options}' to get the IR dump for debugging purpose. 
- """ - trimmed_message = '\n'.join([m.lstrip() for m in message.split('\n')]) - raise TorchMlirCompilerError(trimmed_message) from None - finally: - sys.stderr = original_stderr diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py index acb487319ae9..508297cfe8f0 100644 --- a/projects/pt1/python/torch_mlir/torchscript.py +++ b/projects/pt1/python/torch_mlir/torchscript.py @@ -17,65 +17,15 @@ from torch_mlir.dynamo import _get_decomposition_table from torch.fx.experimental.proxy_tensor import make_fx -from .compiler_utils import run_pipeline_with_repro_report +from torch_mlir.compiler_utils import ( + run_pipeline_with_repro_report, + OutputType, + lower_mlir_module +) from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library -class OutputType(Enum): - """The kind of output that `torchscript.compile` can produce. - - In MLIR terminology, this describes the mix of dialects that will be - produced by the conversion process. - - In user-facing API's, this type can always be passed interchangeably with an - appropriate string specifying the output type. The allowed strings are - the set of enum vales, allowed to be case insensitive and with `-` allowed - in place of `_`. The `OutputType.get` static method can be used to convert - from a string to an `OutputType` instance. - """ - - # This output type consists of `torch` dialect ops that have been converted - # maximally to value semantics, decomposed, and shapes have been inferred. - TORCH = "torch" - - # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and - # `arith` ops (and also `math` and `tm_tensor`). It can be thought of - # as taking the `TORCH` output type and lowering it so that tensor - # computations are done with `linalg`-on-tensors ops. - LINALG_ON_TENSORS = "linalg-on-tensors" - - # This output type consists of `tosa` dialect ops. It can be thought of - # as taking the `TORCH` output type and lowering it to TOSA. - TOSA = "tosa" - - # This output type consists of `stablehlo` dialect ops. It can be thought of - # as taking the `TORCH` output type and lowering it to StableHLO. - STABLEHLO = "stablehlo" - - # Raw output of the JIT IR importer. This is not expected to be useful - # for end-users, but can be convenient for development or reporting bugs. - RAW = "raw" - - @staticmethod - def get(spec: Union[str, "OutputType"]) -> "OutputType": - """Gets an OutputType from allowed way to specify one. - - Args: - spec: An OutputType instance or the case-insensitive name of one of the - enum values. - Returns: - An OutputType instance. - """ - if isinstance(spec, OutputType): - return spec - spec = spec.upper().replace("-", "_") - if spec not in OutputType.__members__: - raise ValueError(f"For output_type= argument, expected one of: " - f"{', '.join(OutputType.__members__.keys())}") - return OutputType[spec] - - class TensorPlaceholder: """A class that represents a formal parameter of a given shape and dtype. 
@@ -270,49 +220,6 @@ def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra return "" -def _lower_mlir_module(verbose, output_type, module): - if verbose: - print("\n====================") - print("Torch Backend IR") - print(module) - - if output_type == OutputType.TORCH: - return module - - if output_type == OutputType.TOSA: - run_pipeline_with_repro_report( - module, "builtin.module(torch-backend-to-tosa-backend-pipeline)", - "Lowering Torch Backend IR -> TOSA Backend IR") - if verbose: - print("\n====================") - print("TOSA Backend IR") - print(module) - return module - - if output_type == OutputType.LINALG_ON_TENSORS: - run_pipeline_with_repro_report( - module, - "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", - "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR") - if verbose: - print("\n====================") - print("LINALG Backend IR") - print(module) - return module - - elif output_type == OutputType.STABLEHLO: - run_pipeline_with_repro_report( - module, - "builtin.module(torch-backend-to-stablehlo-backend-pipeline)", - "Lowering Torch Backend IR -> StableHLO Backend IR") - if verbose: - print("\n====================") - print("StableHLO Backend IR") - print(module) - return module - raise Exception(f"Unknown OutputType: {output_type}") - - def compile(model: torch.nn.Module, example_args: _example_args, output_type: Union[str, "OutputType"] = OutputType.TORCH, @@ -464,4 +371,4 @@ def compile(model: torch.nn.Module, enable_ir_printing=enable_ir_printing, ) - return _lower_mlir_module(verbose, output_type, mb.module) + return lower_mlir_module(verbose, output_type, mb.module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py index 0d75fe2ad3f0..e45c7b18bb7a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py @@ -12,12 +12,13 @@ from torch.export import ExportedProgram from torch_mlir import fx -from torch_mlir.torchscript import ( - _example_args, +from torch_mlir.compiler_utils import ( + run_pipeline_with_repro_report, + lower_mlir_module, OutputType, +) +from torch_mlir.torchscript import ( BACKEND_LEGAL_OPS, - run_pipeline_with_repro_report, - _lower_mlir_module, _canon_extra_library, ) from torch_mlir_e2e_test.configs.utils import ( @@ -76,7 +77,7 @@ def jit( "Lowering TorchFX IR -> Torch Backend IR", ) - return _lower_mlir_module(verbose, output_type, mlir_module) + return lower_mlir_module(verbose, output_type, mlir_module) class FxImporterTestConfig(TestConfig): diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py index bdc410741cae..13f4d3df863f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py @@ -15,14 +15,16 @@ set_model_name, ) +from torch_mlir.compiler_utils import ( + run_pipeline_with_repro_report, + lower_mlir_module, + OutputType, +) from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func from torch_mlir.dynamo import _get_decomposition_table from torch_mlir.torchscript import ( _example_args, - OutputType, BACKEND_LEGAL_OPS, - run_pipeline_with_repro_report, - _lower_mlir_module, _canon_extra_library, ) from torch_mlir_e2e_test.configs.utils import ( @@ -148,7 +150,7 @@ def 
my_aot_autograd_backend(gm: torch.fx.GraphModule,
         "Lowering TorchFX IR -> Torch Backend IR",
     )
 
-    return _lower_mlir_module(verbose, output_type, mlir_module)
+    return lower_mlir_module(verbose, output_type, mlir_module)
 
 
 class TorchDynamoTestConfig(TestConfig):
diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py
index 449e6bb40f01..fcd1efb3f4d6 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py
@@ -4,11 +4,13 @@
 # Also available under a BSD-style license. See LICENSE.
 
 
-from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+from torch_mlir.compiler_utils import (
+    run_pipeline_with_repro_report,
+    lower_mlir_module,
+    OutputType,
+)
 from torch_mlir.ir import *
 from torch_mlir.passmanager import *
-from torch_mlir.torchscript import OutputType
-from torch_mlir.torchscript import _lower_mlir_module
 
 from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
 
@@ -58,7 +60,7 @@ def compile(self, imported_module: Module):
         "Lowering TorchFX IR -> Torch Backend IR",
     )
 
-    imported_module = _lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
+    imported_module = lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
 
     compiled_module = self.refbackend.compile(imported_module)
     return compiled_module
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index e52135599864..76cdbcca41eb 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -43,6 +43,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI
   ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
   ADD_TO_PARENT TorchMLIRPythonSources
   SOURCES
+    compiler_utils.py
     fx.py
     extras/fx_decomp_util.py
 )
diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py
new file mode 100644
index 000000000000..6416b88aab5f
--- /dev/null
+++ b/python/torch_mlir/compiler_utils.py
@@ -0,0 +1,166 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+from enum import Enum
+from io import StringIO
+import os
+import sys
+import tempfile
+from typing import Union
+
+from torch_mlir.passmanager import PassManager
+from torch_mlir.ir import StringAttr
+
+
+def get_module_name_for_debug_dump(module):
+    """Gets a name suitable for a debug dump.
+
+    The name is not guaranteed to be unique.
+    """
+    if "torch.debug_module_name" not in module.operation.attributes:
+        return "UnnamedModule"
+    return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
+
+
+class TorchMlirCompilerError(Exception):
+    pass
+
+def run_pipeline_with_repro_report(module,
+                                   pipeline: str,
+                                   description: str,
+                                   enable_ir_printing: bool = False):
+    """Runs `pipeline` on `module`, with a nice repro report if it fails."""
+    module_name = get_module_name_for_debug_dump(module)
+    original_stderr = sys.stderr
+    try:
+        sys.stderr = StringIO()
+        asm_for_error_report = module.operation.get_asm(
+            large_elements_limit=10, enable_debug_info=True)
+        # Lower module in place to make it ready for compiler backends.
+        with module.context as ctx:
+            pm = PassManager.parse(pipeline)
+            if enable_ir_printing:
+                ctx.enable_multithreading(False)
+                pm.enable_ir_printing()
+            pm.run(module.operation)
+    except Exception as e:
+        # TODO: More robust.
+        # - don't arbitrarily clutter up /tmp. When a test suite has many
+        #   tests, this can be a big disk cost (also, /tmp/ is frequently a
+        #   RAM fs, which increases worries about capacity).
+        # - don't have colliding filenames (hard to do without cluttering
+        #   up /tmp)
+        # - if we do have colliding filenames, writes should at least
+        #   avoid being racy.
+        filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
+        with open(filename, 'w') as f:
+            f.write(asm_for_error_report)
+        debug_options = "-mlir-print-ir-after-all -mlir-disable-threading"
+        # Put something descriptive here even if description is empty.
+        description = description or f"{module_name} compile"
+
+        message = f"""\
+            {description} failed with the following diagnostics:
+            {sys.stderr.getvalue()}
+
+            python exception: {e}
+
+            For Torch-MLIR developers, the error can be reproduced with:
+            $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
+            Add '{debug_options}' to get the IR dump for debugging purposes.
+            """
+        trimmed_message = '\n'.join([m.lstrip() for m in message.split('\n')])
+        raise TorchMlirCompilerError(trimmed_message) from None
+    finally:
+        sys.stderr = original_stderr
+
+
+class OutputType(Enum):
+
+    # Output torch dialect. When converting from FX, this will be immediately
+    # after the import from FX to MLIR. When converting from torchscript,
+    # this will come after some cleanup passes which attempt to de-alias,
+    # decompose and infer shapes. These should be roughly the same level of
+    # abstraction since those steps are done within PyTorch itself
+    # when coming directly from Dynamo/FX.
+    TORCH = "torch"
+
+    # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
+    # `arith` ops (and also `math` and `tm_tensor`). It can be thought of
+    # as taking the `TORCH` output type and lowering it so that tensor
+    # computations are done with `linalg`-on-tensors ops.
+    LINALG_ON_TENSORS = "linalg-on-tensors"
+
+    # This output type consists of `tosa` dialect ops. It can be thought of
+    # as taking the `TORCH` output type and lowering it to TOSA.
+    TOSA = "tosa"
+
+    # This output type consists of `stablehlo` dialect ops. It can be thought of
+    # as taking the `TORCH` output type and lowering it to StableHLO.
+    STABLEHLO = "stablehlo"
+
+    # Raw output of the JIT IR importer. This is not expected to be useful
+    # for end-users, but can be convenient for development or reporting bugs.
+    RAW = "raw"
+
+    @staticmethod
+    def get(spec: Union[str, "OutputType"]) -> "OutputType":
+        """Gets an OutputType from an allowed way to specify one.
+
+        Args:
+            spec: An OutputType instance or the case-insensitive name of one of the
+                enum values.
+        Returns:
+            An OutputType instance.
+ """ + if isinstance(spec, OutputType): + return spec + spec = spec.upper().replace("-", "_") + if spec not in OutputType.__members__: + raise ValueError(f"For output_type= argument, expected one of: " + f"{', '.join(OutputType.__members__.keys())}") + return OutputType[spec] + + +def lower_mlir_module(verbose, output_type, module): + if verbose: + print("\n====================") + print("Torch Backend IR") + print(module) + + if output_type == OutputType.TORCH: + return module + + if output_type == OutputType.TOSA: + run_pipeline_with_repro_report( + module, "builtin.module(torch-backend-to-tosa-backend-pipeline)", + "Lowering Torch Backend IR -> TOSA Backend IR") + if verbose: + print("\n====================") + print("TOSA Backend IR") + print(module) + return module + + if output_type == OutputType.LINALG_ON_TENSORS: + run_pipeline_with_repro_report( + module, + "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR") + if verbose: + print("\n====================") + print("LINALG Backend IR") + print(module) + return module + + elif output_type == OutputType.STABLEHLO: + run_pipeline_with_repro_report( + module, + "builtin.module(torch-backend-to-stablehlo-backend-pipeline)", + "Lowering Torch Backend IR -> StableHLO Backend IR") + if verbose: + print("\n====================") + print("StableHLO Backend IR") + print(module) + return module + raise Exception(f"Unknown OutputType: {output_type}")