Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rfc] Use logging.getLogger for projects/pt1/e2e_testing #3173

Closed
wants to merge 12 commits into from
17 changes: 17 additions & 0 deletions projects/pt1/e2e_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Also available under a BSD-style license. See LICENSE.

import argparse
import logging
renxida marked this conversation as resolved.
Show resolved Hide resolved
import re
import sys

Expand Down Expand Up @@ -75,6 +76,10 @@ def _get_argparse():
default=False,
action="store_true",
help="report test results with additional detail")
parser.add_argument("--print-ir",
default=False,
action="store_true",
help="Set logging level to DEBUG and causes the IR to be printed for each test.")
parser.add_argument("-s", "--sequential",
default=False,
action="store_true",
Expand All @@ -93,6 +98,18 @@ def _get_argparse():
def main():
args = _get_argparse().parse_args()

ir_printer = logging.getLogger("ir_printer")
if args.print_ir:
print("WARNING: --print-ir is a work in progress feature.")
print("print-ir: Setting logging level to DEBUG and enabling IR printing.")
print("print-ir: This currently only affects the Linalg-on-Tensors and onnx configs.")
print("print-ir: Work in progress. See https://github.com/llvm/torch-mlir/issues/3172")
ir_printer.setLevel(logging.DEBUG)
else:
# disable logging
ir_printer.setLevel(logging.CRITICAL+1)
renxida marked this conversation as resolved.
Show resolved Hide resolved


all_test_unique_names = set(
test.unique_name for test in GLOBAL_TEST_REGISTRY)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import logging
ir_printer = logging.getLogger("ir_printer")

from typing import Any

Expand Down Expand Up @@ -32,7 +34,9 @@ def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torchscript.compile(
program, example_args, output_type="linalg-on-tensors")

ir_printer.debug("LinalgOnTensorsBackendTestConfig compiled module:")
ir_printer.debug(module)
ir_printer.debug("End LinalgOnTensorsBackendTestConfig compiled module")
return self.backend.compile(module)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from torch_mlir.extras import onnx_importer
from torch_mlir.dialects import torch as torch_d
from torch_mlir.ir import Context, Module

import logging
ir_printer = logging.getLogger("ir_printer")

def import_onnx(contents):
# Import the ONNX model proto from the file contents:
Expand All @@ -39,7 +40,7 @@ def import_onnx(contents):
return m


def convert_onnx(model, inputs):
def convert_onnx(model: torch.nn.Module, inputs):
buffer = io.BytesIO()

# Process the type information so we export with the dynamic shape information
Expand Down Expand Up @@ -82,6 +83,9 @@ def __init__(self, backend: OnnxBackend, use_make_fx: bool = False):
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
onnx_module = convert_onnx(program, example_args)
ir_printer.debug("OnnxBackendTestConfig imported module:")
renxida marked this conversation as resolved.
Show resolved Hide resolved
ir_printer.debug(onnx_module)
ir_printer.debug("End OnnxBackendTestConfig imported module")
compiled_module = self.backend.compile(onnx_module)
return compiled_module

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import logging
ir_printer = logging.getLogger("ir_printer")

from typing import Any

Expand Down Expand Up @@ -31,7 +33,9 @@ def __init__(self, backend: StablehloBackend):
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torchscript.compile(program, example_args, output_type="stablehlo")

ir_printer.debug("StablehloBackendTestConfig compiled module:")
renxida marked this conversation as resolved.
Show resolved Hide resolved
ir_printer.debug(module)
ir_printer.debug("End StablehloBackendTestConfig compiled module")
return self.backend.compile(module)

def run(self, artifact: Any, trace: Trace) -> Trace:
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/python/torch_mlir_e2e_test/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def compile(self, program: torch.nn.Module) -> CompiledArtifact:
"""Compile the provided torch.nn.Module into a compiled artifact"""
pass


renxida marked this conversation as resolved.
Show resolved Hide resolved
# Any should match result of `compile`.

@abc.abstractmethod
Expand Down
Loading