ONNX dynamic sized model export with torch.onnx.dynamo_export fails when torch.nn.functional.interpolate is used #124884

Open
pinkhamr-fb opened this issue Apr 24, 2024 · 0 comments
Labels
module: onnx (Related to torch.onnx)
triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)

pinkhamr-fb commented Apr 24, 2024

🐛 Describe the bug

I'm trying to export a model from PyTorch to ONNX that needs to support dynamic input/output sizes at runtime. The model uses F.interpolate to resize features at a few points. However, when the dynamic_shapes export option is set to True, the export fails with a very long error that points at the interpolate step. Here's a short example that declares an extremely simple model and verifies that a forward pass runs, but then fails on export.

import torch
import torch.nn as nn
import torch.nn.functional as F


class DummyModel(torch.nn.Module):

    def __init__(self):
        super(DummyModel, self).__init__()

        # Basic set of filters
        self.conv_down1 = nn.Conv2d(1, 16, 3, 1, 0)
        self.activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv_up1 = nn.Conv2d(16, 1, 3, 1, 1)

    def forward(self, x):
        x = self.conv_down1(x)
        x = self.activation(x)

        # Upsample
        x = F.interpolate(x, scale_factor=2, mode="nearest")  # Onnx export fails here
        x = self.conv_up1(x)
        x = self.activation(x)

        return x


if __name__ == "__main__":
    # Declare the dummy model and confirm it works
    dummy_model = DummyModel()
    dummy_model = dummy_model.cuda()
    dummy_model.eval()

    fake_input = torch.randn(1, 1, 100, 100).cuda()
    fake_output = dummy_model(fake_input)

    # Now try to export it to onnx with dynamic sizing
    export_options = torch.onnx.ExportOptions(
        dynamic_shapes=True,  # Export works when set to False
    )  # Allow dynamic sizing
    onnx_program = torch.onnx.dynamo_export(
        dummy_model, fake_input, export_options=export_options
    )
    onnx_program.save("dummy_model.onnx")

The error printout is quite long, but the most informative part is the following (the full printout is at the end):

RuntimeError: aten\src\ATen\RegisterCompositeExplicitAutograd.cpp:2252: SymIntArrayRef expected to contain only concrete integers

While executing %_unsafe_index : [num_users=1] = call_function[target=torch.ops.aten._unsafe_index.Tensor](args = (%copy_, [None, None, %unsqueeze, %_to_copy_1]), kwargs = {})
Original traceback:
  File "C:\Users\pinkhamr\working_dir\dummy_model.py", line 23, in forward
    x = F.interpolate(x, scale_factor=2, mode="nearest")  # Onnx export fails here
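
For context, the failing `_unsafe_index` node suggests that nearest-neighbor interpolation is being decomposed into an index gather: one index tensor per spatial dimension, broadcast to a common shape (the `x.expand(common_shape)` frame in the full traceback). The following is only a rough pure-Python sketch of that idea, not PyTorch's actual decomposition. The function names are made up for illustration, and an integer scale factor is assumed.

```python
def nearest_source_indices(out_size, scale):
    """Map each output position to the input position it copies from.

    Assumes an integer scale factor, so the source index is floor(i / scale).
    """
    return [i // scale for i in range(out_size)]


def upsample_nearest_2d(image, scale):
    """Upsample a 2D list-of-lists by gathering with precomputed indices."""
    in_h, in_w = len(image), len(image[0])
    # One index list per spatial dimension; both depend on the output size.
    rows = nearest_source_indices(in_h * scale, scale)
    cols = nearest_source_indices(in_w * scale, scale)
    # The gather itself: every output pixel reads image[row][col].
    return [[image[r][c] for c in cols] for r in rows]


if __name__ == "__main__":
    for row in upsample_nearest_2d([[1, 2], [3, 4]], 2):
        print(row)
```

Because the index lists are derived from the output size, they become symbolic once the input height and width are dynamic, which may be why the broadcast step trips over non-concrete sizes.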

I've tried various scale_factor values and different interpolation modes, and they all seem to cause a similar export error. With dynamic_shapes disabled in the export options, all of them export successfully.

Is there any prospect of getting the interpolate operation supported for dynamic-shape export?
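
In the meantime, one avenue that may be worth trying is the TorchScript-based `torch.onnx.export`, which takes a `dynamic_axes` mapping instead of `dynamic_shapes`. The sketch below is untested against this model. The helper names, the tensor names "input" and "output", and the axis labels are placeholders chosen for illustration.

```python
INPUT_NAME = "input"    # placeholder tensor name, not taken from the model
OUTPUT_NAME = "output"  # placeholder tensor name, not taken from the model


def build_dynamic_axes(input_name=INPUT_NAME, output_name=OUTPUT_NAME):
    """Mark batch, height, and width as dynamic for NCHW input and output.

    The output gets its own axis labels because its spatial size differs
    from the input's after the convolutions and the upsampling step.
    """
    return {
        input_name: {0: "batch", 2: "height", 3: "width"},
        output_name: {0: "batch", 2: "out_height", 3: "out_width"},
    }


def export_with_dynamic_axes(model, example_input, path):
    """Export via torch.onnx.export (TorchScript-based) with dynamic axes."""
    import torch  # imported lazily so build_dynamic_axes works without torch

    torch.onnx.export(
        model,
        (example_input,),
        path,
        input_names=[INPUT_NAME],
        output_names=[OUTPUT_NAME],
        dynamic_axes=build_dynamic_axes(),
    )
```

With the script above, this would be called as `export_with_dynamic_axes(dummy_model, fake_input, "dummy_model.onnx")` in place of the `dynamo_export` call.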

Full error printout:

(mock_env) PS C:\Users\pinkhamr\working_dir> python .\dummy_model.py
C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
  warnings.warn(
Traceback (most recent call last):
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\exporter.py", line 1433, in dynamo_export
    ).export()
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\exporter.py", line 1175, in export
    graph_module = self.options.fx_tracer.generate_fx(
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\fx\dynamo_graph_extractor.py", line 232, in generate_fx
    return self.pre_export_passes(options, model, graph_module, updated_model_args)  # type: ignore[return-value]
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\fx\dynamo_graph_extractor.py", line 242, in pre_export_passes
    return exporter.common_pre_export_passes(
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\exporter.py", line 1472, in common_pre_export_passes
    module = passes.Functionalize(
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\diagnostics\infra\decorator.py", line 151, in wrapper
    ctx.log_and_raise_if_error(diag)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\diagnostics\infra\context.py", line 366, in log_and_raise_if_error
    raise diagnostic.source_exception
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\diagnostics\infra\decorator.py", line 135, in wrapper
    return_values = fn(*args, **kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\fx\_pass.py", line 275, in run
    module = self._run(*args, **kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\fx\passes\functionalization.py", line 123, in _run
    graph_module = proxy_tensor.make_fx(
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 871, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_dynamo\eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_dynamo\external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 483, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_dynamo\eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_dynamo\external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\fx\_symbolic_trace.py", line 821, in trace
    (self.create_arg(fn(*args)),),
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 519, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\fx\passes\functionalization.py", line 86, in wrapped
    out = function(*inputs_functional)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\fx\passes\_utils.py", line 30, in wrapped
    return torch.fx.Interpreter(graph_module).run(*args)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\fx\interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\fx\interpreter.py", line 195, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\fx\interpreter.py", line 267, in call_function
    return target(*args, **kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_ops.py", line 513, in __call__
    return self._op(*args, **(kwargs or {}))
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\utils\_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 596, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 631, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 376, in proxy_call
    out = func(*args, **kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_ops.py", line 513, in __call__
    return self._op(*args, **(kwargs or {}))
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\utils\_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1712, in dispatch
    r = func(*args, **kwargs)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_ops.py", line 513, in __call__
    return self._op(*args, **(kwargs or {}))
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_meta_registrations.py", line 2908, in meta_index_Tensor
    indices = list(refs._maybe_broadcast(*indices))
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_refs\__init__.py", line 434, in _maybe_broadcast
    return tuple(__maybe_broadcast(x, common_shape) for x in args)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_refs\__init__.py", line 434, in <genexpr>
    return tuple(__maybe_broadcast(x, common_shape) for x in args)
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\_refs\__init__.py", line 426, in __maybe_broadcast
    return x.expand(common_shape)
RuntimeError: aten\src\ATen\RegisterCompositeExplicitAutograd.cpp:2252: SymIntArrayRef expected to contain only concrete integers

While executing %_unsafe_index : [num_users=1] = call_function[target=torch.ops.aten._unsafe_index.Tensor](args = (%copy_, [None, None, %unsqueeze, %_to_copy_1]), kwargs = {})
Original traceback:
  File "C:\Users\pinkhamr\working_dir\dummy_model.py", line 23, in forward
    x = F.interpolate(x, scale_factor=2, mode="nearest")  # Onnx export fails here


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\pinkhamr\working_dir\dummy_model.py", line 43, in <module>
    onnx_program = torch.onnx.dynamo_export(
  File "C:\Users\pinkhamr\Anaconda3\envs\mock_env\lib\site-packages\torch\onnx\_internal\exporter.py", line 1444, in dynamo_export
    raise OnnxExporterError(
torch.onnx.OnnxExporterError: Failed to export the model to ONNX. Generating SARIF report at 'report_dynamo_export.sarif'. SARIF is a standard format for the output of static analysis tools. SARIF logs can be loaded in VS Code SARIF viewer extension, or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). Please report a bug on PyTorch Github: https://github.com/pytorch/pytorch/issues

Versions

Collecting environment information...
PyTorch version: 2.2.2+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Pro
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: N/A

Python version: 3.10.14 | packaged by Anaconda, Inc. | (main, Mar 21 2024, 16:20:14) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19045-SP0
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080
Nvidia driver version: 516.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
'wmic' is not recognized as an internal or external command,
operable program or batch file.

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] onnx==1.16.0
[pip3] onnxscript==0.1.0.dev20240418
[pip3] torch==2.2.2+cu118
[pip3] torchaudio==2.2.2+cu118
[pip3] torchvision==0.17.2+cu118
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] torch                     2.2.2+cu118              pypi_0    pypi
[conda] torchaudio                2.2.2+cu118              pypi_0    pypi
[conda] torchvision               0.17.2+cu118             pypi_0    pypi
@malfet added the module: onnx and triaged labels on Apr 24, 2024