diff --git a/tests/test_pytorch/cmp_output.py b/tests/test_pytorch/cmp_output.py index 543f4ed1..63196dd1 100644 --- a/tests/test_pytorch/cmp_output.py +++ b/tests/test_pytorch/cmp_output.py @@ -1,9 +1,15 @@ import glob +def exclude_files(files, keys): + return [x for x in files if not any(key in x for key in keys)] + output_full_code = sorted(glob.glob("depyf_output/*/full_code_*.py")) expected_full_code = sorted(glob.glob("tests/depyf_output/*/full_code_*.py")) expected_full_code = [x[len("tests/"):] for x in expected_full_code] +output_full_code = exclude_files(output_full_code, ["insert_deferred_runtime_asserts"]) +expected_full_code = exclude_files(expected_full_code, ["insert_deferred_runtime_asserts"]) + msg = "Unexpected files:\n" for x in set(output_full_code) - set(expected_full_code): msg += x + "\n" @@ -23,6 +29,9 @@ expected_files.sort() expected_files = [x[len("tests/"):] for x in expected_files] +output_files = exclude_files(output_files, ["insert_deferred_runtime_asserts"]) +expected_files = exclude_files(expected_files, ["insert_deferred_runtime_asserts"]) + msg = f"len(output_files)={len(output_files)}, len(expected_files)={len(expected_files)}.\n" msg += "Unexpected files:\n" for x in set(output_files) - set(expected_files):