[Bug]: Allow extra kwargs in to_onnx() and/or engine.export() to be passed to torch.onnx.export #2415

Open
mbignotti opened this issue Nov 12, 2024 · 0 comments


Describe the bug

When trying to export a model larger than 2GB, onnx throws the following error:

RuntimeError: The serialized model is larger than the 2GiB limit imposed by the protobuf library. Therefore the output file must be a file path, so that the ONNX external data can be written to the same directory. Please specify the output file name.
File <command-4284593085184757>, line 1
----> 1 engine.export(model=model, export_root=TMP, export_type=ExportType.ONNX, transform=transform)

where model is a Padim model with image size 512x512 and a wide_resnet50_2 backbone with default parameters (layers and features). Note that TMP is already a path to a directory, not a path to an .onnx file.
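As a rough illustration of why this configuration crosses the limit, the sketch below estimates the size of a single PaDiM buffer. It assumes (not stated in this issue) that PaDiM stores one n_features x n_features inverse covariance matrix per spatial position of the layer1 feature map (stride 4, so 128x128 positions for a 512x512 input), and that the default number of features for wide_resnet50_2 is 550. The function name is hypothetical.

```python
def padim_inv_covariance_bytes(
    image_size: tuple[int, int],
    n_features: int,
    stride: int = 4,
    bytes_per_value: int = 4,
) -> int:
    """Estimate the size of a PaDiM-style inverse covariance buffer.

    Assumption: one (n_features x n_features) float32 matrix per spatial
    position of the first feature map, i.e. the input downsampled by `stride`.
    """
    height, width = image_size[0] // stride, image_size[1] // stride
    return height * width * n_features * n_features * bytes_per_value


PROTOBUF_LIMIT = 2**31  # the 2 GiB protobuf limit quoted in the error message

size = padim_inv_covariance_bytes((512, 512), n_features=550)
print(f"{size:,} bytes ~= {size / 2**30:.1f} GiB")  # 19,824,640,000 bytes ~= 18.5 GiB
print(size > PROTOBUF_LIMIT)  # True
```

Under these assumptions, this one tensor is roughly nine times the protobuf limit, so the model cannot be serialized into a single .onnx file.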

Large ONNX models are exported with external data (that is, split across multiple files instead of being stored entirely in the .onnx file), as specified here. Indeed, torch.onnx.export accepts an external_data argument. However, that argument is not exposed in Anomalib: the to_onnx() method defined in PyTorch Lightning, which accepts **kwargs that are passed to torch.onnx.export, is overridden by the Anomalib ExportMixin (used in the base AnomalyModule).

My guess is that exposing **kwargs to be passed to torch.onnx.export would solve the issue and allow exporting large ONNX models.
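A minimal sketch of the proposed forwarding, assuming a hypothetical standalone helper rather than Anomalib's actual ExportMixin.to_onnx. The defaults mirror the values hard-coded in the torch.onnx.export call shown in the traceback (opset_version=14, input_names=["input"], output_names=["output"]); any extra keyword arguments override them. The export_fn parameter is a hypothetical hook so the forwarding logic can be tested without torch.

```python
from __future__ import annotations

from pathlib import Path
from typing import Any, Callable

# Defaults currently hard-coded in ExportMixin.to_onnx (see the traceback).
_DEFAULT_ONNX_KWARGS: dict[str, Any] = {
    "opset_version": 14,
    "input_names": ["input"],
    "output_names": ["output"],
}


def export_onnx_with_kwargs(
    model: Any,
    input_sample: Any,
    export_root: Path | str,
    dynamic_axes: dict | None = None,
    export_fn: Callable[..., Any] | None = None,
    **onnx_kwargs: Any,
) -> Path:
    """Hypothetical helper: export to <export_root>/model.onnx, forwarding extra kwargs.

    User-supplied ``onnx_kwargs`` take precedence over the defaults above.
    """
    if export_fn is None:
        import torch  # imported lazily so the forwarding logic is testable without torch

        export_fn = torch.onnx.export
    onnx_path = Path(export_root) / "model.onnx"
    # Later keys win: explicit user kwargs override the hard-coded defaults.
    kwargs = {**_DEFAULT_ONNX_KWARGS, "dynamic_axes": dynamic_axes, **onnx_kwargs}
    export_fn(model, input_sample, str(onnx_path), **kwargs)
    return onnx_path
```

With such a hook, a call could look like `export_onnx_with_kwargs(inference_model, sample, TMP, external_data=True)`. Note that in the traceback, the dynamo=False branch of torch.onnx.export (torch/onnx/__init__.py:375) does not forward external_data to the legacy exporter, so whether that particular argument alone resolves the error would still need to be verified.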

Dataset

Other (please specify in the text field below)

Model

PADiM

Steps to reproduce the behavior

Train a large enough Padim model on any dataset, for example:

  • image size 512x512
  • backbone: wide_resnet50_2 with default layers and features

Try to export the model to ONNX with:

engine.export(model=model, export_root=TMP, export_type=ExportType.ONNX, transform=transform)

where TMP is the path to a local directory.

OS information


  • OS: Ubuntu 20.04
  • Python version: 3.10
  • Anomalib version: 1.2.0
  • PyTorch version: 2.5.1
  • CUDA/cuDNN version: 12.4
  • Any other relevant information: training on a custom dataset, but any dataset with image size 512x512 and the specified model parameters should be enough to reproduce the error.

Expected behavior

The methods engine.export() and/or model.to_onnx() should accept extra kwargs that are forwarded to torch.onnx.export(), to simplify exporting large models.


Pip/GitHub

pip

What version/branch did you use?

No response

Configuration YAML

No configuration used

Logs

RuntimeError: The serialized model is larger than the 2GiB limit imposed by the protobuf library. Therefore the output file must be a file path, so that the ONNX external data can be written to the same directory. Please specify the output file name.
File <command-4284593085184757>, line 1
----> 1 engine.export(model=model, export_root=TMP, export_type=ExportType.ONNX, transform=transform)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-cecb19f8-6e0a-4bc0-ba3d-c636bf461e18/lib/python3.10/site-packages/anomalib/engine/engine.py:958, in Engine.export(self, model, export_type, export_root, input_size, transform, compression_type, datamodule, metric, ov_args, ckpt_path)
    952     exported_model_path = model.to_torch(
    953         export_root=export_root,
    954         transform=transform,
    955         task=self.task,
    956     )
    957 elif export_type == ExportType.ONNX:
--> 958     exported_model_path = model.to_onnx(
    959         export_root=export_root,
    960         input_size=input_size,
    961         transform=transform,
    962         task=self.task,
    963     )
    964 elif export_type == ExportType.OPENVINO:
    965     exported_model_path = model.to_openvino(
    966         export_root=export_root,
    967         input_size=input_size,
   (...)
    973         ov_args=ov_args,
    974     )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-cecb19f8-6e0a-4bc0-ba3d-c636bf461e18/lib/python3.10/site-packages/anomalib/models/components/base/export_mixin.py:151, in ExportMixin.to_onnx(self, export_root, input_size, transform, task)
    149 _write_metadata_to_json(self._get_metadata(task), export_root)
    150 onnx_path = export_root / "model.onnx"
--> 151 torch.onnx.export(
    152     inference_model,
    153     input_shape.to(self.device),
    154     str(onnx_path),
    155     opset_version=14,
    156     dynamic_axes=dynamic_axes,
    157     input_names=["input"],
    158     output_names=["output"],
    159 )
    161 return onnx_path
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-cecb19f8-6e0a-4bc0-ba3d-c636bf461e18/lib/python3.10/site-packages/torch/onnx/__init__.py:375, in export(model, args, f, kwargs, export_params, verbose, input_names, output_names, opset_version, dynamic_axes, keep_initializers_as_inputs, dynamo, external_data, dynamic_shapes, report, verify, profile, dump_exported_program, artifacts_dir, fallback, training, operator_export_type, do_constant_folding, custom_opsets, export_modules_as_functions, autograd_inlining, **_)
    369 if dynamic_shapes:
    370     raise ValueError(
    371         "The exporter only supports dynamic shapes "
    372         "through parameter dynamic_axes when dynamo=False."
    373     )
--> 375 export(
    376     model,
    377     args,
    378     f,  # type: ignore[arg-type]
    379     kwargs=kwargs,
    380     export_params=export_params,
    381     verbose=verbose is True,
    382     input_names=input_names,
    383     output_names=output_names,
    384     opset_version=opset_version,
    385     dynamic_axes=dynamic_axes,
    386     keep_initializers_as_inputs=keep_initializers_as_inputs,
    387     training=training,
    388     operator_export_type=operator_export_type,
    389     do_constant_folding=do_constant_folding,
    390     custom_opsets=custom_opsets,
    391     export_modules_as_functions=export_modules_as_functions,
    392     autograd_inlining=autograd_inlining,
    393 )
    394 return None
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-cecb19f8-6e0a-4bc0-ba3d-c636bf461e18/lib/python3.10/site-packages/torch/onnx/utils.py:502, in export(model, args, f, kwargs, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, autograd_inlining)
    499 if kwargs is not None:
    500     args = args + (kwargs,)
--> 502 _export(
    503     model,
    504     args,
    505     f,
    506     export_params,
    507     verbose,
    508     training,
    509     input_names,
    510     output_names,
    511     operator_export_type=operator_export_type,
    512     opset_version=opset_version,
    513     do_constant_folding=do_constant_folding,
    514     dynamic_axes=dynamic_axes,
    515     keep_initializers_as_inputs=keep_initializers_as_inputs,
    516     custom_opsets=custom_opsets,
    517     export_modules_as_functions=export_modules_as_functions,
    518     autograd_inlining=autograd_inlining,
    519 )
    521 return None
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-cecb19f8-6e0a-4bc0-ba3d-c636bf461e18/lib/python3.10/site-packages/torch/onnx/utils.py:1564, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions, autograd_inlining)
   1561     dynamic_axes = {}
   1562 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> 1564 graph, params_dict, torch_out = _model_to_graph(
   1565     model,
   1566     args,
   1567     verbose,
   1568     input_names,
   1569     output_names,
   1570     operator_export_type,
   1571     val_do_constant_folding,
   1572     fixed_batch_size=fixed_batch_size,
   1573     training=training,
   1574     dynamic_axes=dynamic_axes,
   1575 )
   1577 # TODO: Don't allocate a in-memory string for the protobuf
   1578 defer_weight_export = (
   1579     export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
   1580 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-cecb19f8-6e0a-4bc0-ba3d-c636bf461e18/lib/python3.10/site-packages/torch/onnx/utils.py:1117, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
   1114 params_dict = _get_named_param_dict(graph, params)
   1116 try:
-> 1117     graph = _optimize_graph(
   1118         graph,
   1119         operator_export_type,
   1120         _disable_torch_constant_prop=_disable_torch_constant_prop,
   1121         fixed_batch_size=fixed_batch_size,
   1122         params_dict=params_dict,
   1123         dynamic_axes=dynamic_axes,
   1124         input_names=input_names,
   1125         module=module,
   1126     )
   1127 except Exception as e:
   1128     torch.onnx.log("Torch IR graph at exception: ", graph)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-cecb19f8-6e0a-4bc0-ba3d-c636bf461e18/lib/python3.10/site-packages/torch/onnx/utils.py:663, in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)
    661 _C._jit_pass_lint(graph)
    662 if GLOBALS.onnx_shape_inference:
--> 663     _C._jit_pass_onnx_graph_shape_type_inference(
    664         graph, params_dict, GLOBALS.export_onnx_opset_version
    665     )
    667 return graph

Code of Conduct

  • I agree to follow this project's Code of Conduct