
🐛 [Bug] [Dynamic Shapes] Encountered bug when using Torch-TensorRT #3140

Open · yjjinjie opened this issue Sep 3, 2024 · 7 comments
Labels: bug (Something isn't working)

yjjinjie commented Sep 3, 2024

Bug Description

When I use dynamic shapes in Torch-TensorRT, compilation raises the following error:

ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: Internal Error (Tensor [SLICE]-[aten_ops.expand.default]-[__/expand]_output has axis 0 with inherently negative length. Proven upper bound is -1. Network must have an instance where axis has non-negative length.)
ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: Internal Error (Output shape can not be computed for node [SLICE]-[aten_ops.expand.default]-[__/expand].)
ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: Internal Error (Output shape can not be computed for node [SLICE]-[aten_ops.expand.default]-[__/expand].)
Traceback (most recent call last):
  File "/larec/tzrec/tests/test3.py", line 73, in <module>
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 230, in compile
    trt_gm = compile_module(gm, inputs, settings)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 418, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 106, in convert_module
    interpreter_result = interpret_module_to_result(module, inputs, settings)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 87, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 327, in run
    super().run()
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 372, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
                              ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 487, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1937, in aten_ops_sub
    return impl.elementwise.sub(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py", line 492, in sub
    return convert_binary_elementwise(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py", line 154, in convert_binary_elementwise
    lhs_val, rhs_val = broadcast(
                       ^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/fx/converters/converter_utils.py", line 404, in broadcast
    a_shape = tuple(a.shape)
              ^^^^^^^^^^^^^^
ValueError: __len__() should return >= 0

While executing %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%expand, %args1_1), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7fe317191230>: ((s0, 41), torch.float32, False, (41, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7fe3170105b0>: ((s0, 1, 41), torch.float32, False, (41, 41, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7fe3174f3c70>: ((s0, 50, 41), torch.float32, False, (41, 0, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7fe317026cb0>: ((s0, 50, 41), torch.float32, False, (2050, 41, 1), torch.contiguous_format, False, {})}})
Original traceback:
  File "<eval_with_key>.0 from /larec/tzrec/tests/test3.py:32 in forward", line 22, in forward
    sub = expand - getitem_1

The static-shape case is fine; the error goes away if I just delete these lines:

torch._dynamo.mark_dynamic(a, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 1, min=1, max=50)
torch._dynamo.mark_dynamic(c, 0, min=1, max=8196)
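
A workaround sketch I have not verified (my own guess, not confirmed by the maintainers): the failing node is the aten.expand produced by queries = query.unsqueeze(1).expand(-1, max_seq_length, -1), so spelling the dynamic dims out explicitly instead of -1 might let the converter prove non-negative lengths:

# Hypothetical rewrite of the failing line (untested):
# pass concrete sizes to expand instead of -1 so TensorRT does not
# have to infer the batch/feature extents from a dynamic ITensor shape.
queries = query.unsqueeze(1).expand(query.size(0), max_seq_length, query.size(1))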

To Reproduce

Steps to reproduce the behavior:

import torch
import torch_tensorrt
from typing import Dict, List
from tzrec.modules.mlp import MLP
from torch import nn


@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features

@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    return torch.arange(end, device=device)

class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = ["query", "sequence", "sequence_length"]
        attn_mlp = {'hidden_units': [256, 64], 'dropout_ratio': [], 'activation': 'nn.ReLU', 'use_bn': False}
        self.mlp = MLP(in_features=41 * 4, **attn_mlp)
        self.linear = nn.Linear(self.mlp.hidden_units[-1], 1)

    def forward(self, *args1: List[torch.Tensor]):
        """Forward the module."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        return self.predict(args1)

    def predict(self, args: List[torch.Tensor]):
        grouped_features = _get_dict(self.keys, args)
        query = grouped_features["query"]
        sequence = grouped_features["sequence"]
        sequence_length = grouped_features["sequence_length"]
        max_seq_length = sequence.size(1)
        sequence_mask = _arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)

        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)

        attn_input = torch.cat(
            [queries, sequence, queries - sequence, queries * sequence], dim=-1
        )

        return attn_input

model = MatMul().eval().cuda()
a = torch.randn(8196, 41).cuda()
b = torch.randn(8196, 50, 41).cuda()
c = torch.randn(8196).cuda()
torch._dynamo.mark_dynamic(a, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 1, min=1, max=50)
torch._dynamo.mark_dynamic(c, 0, min=1, max=8196)
inputs = [a, b, c]
print(model(*inputs)[0][0][0])
# seq_len = torch.export.Dim("seq_len", min=1, max=10)
# dynamic_shapes = ({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
from torchrec.fx import symbolic_trace

model = symbolic_trace(model)

exp_program = torch.export.export(model, (*inputs,))
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
# Run inference
print(trt_gm(*inputs)[0][0][0])
# trt_gm = symbolic_trace(trt_gm)
trt_gm = torch.jit.trace(trt_gm, example_inputs=(a, b, c), strict=False)

scripted_model = torch.jit.script(trt_gm)
scripted_model.save("./scripted_model_trt.pt")

model_gpu = torch.jit.load("./scripted_model_trt.pt", map_location="cuda:0")
print("load:", model_gpu(*inputs)[0][0][0])

The environment:

CPU(s):                          104
On-line CPU(s) list:             0-103
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Platinum 8269CY CPU @ 2.50GHz
CPU family:                      6
Model:                           85
Thread(s) per core:              2
Core(s) per socket:              26
Socket(s):                       2
Stepping:                        7
CPU max MHz:                     3800.0000
CPU min MHz:                     1200.0000
BogoMIPS:                        5000.00
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization:                  VT-x
L1d cache:                       1.6 MiB (52 instances)
L1i cache:                       1.6 MiB (52 instances)
L2 cache:                        52 MiB (52 instances)
L3 cache:                        71.5 MiB (2 instances)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-103
Vulnerability Itlb multihit:     KVM: Mitigation: Split huge pages
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Tsx async abort:   Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] optree==0.12.1
[pip3] torch==2.4.0
[pip3] torch_tensorrt==2.4.0
[pip3] torchaudio==2.4.0
[pip3] torchelastic==0.2.2
[pip3] torchmetrics==1.0.3
[pip3] torchrec==0.8.0+cu121
[pip3] torchvision==0.19.0
[pip3] triton==3.0.0
[conda] blas                      1.0                         mkl
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344
[conda] mkl-service               2.4.0           py311h5eee18b_1
[conda] mkl_fft                   1.3.8           py311h5eee18b_0
[conda] mkl_random                1.2.4           py311hdb19cb5_0
[conda] numpy                     1.26.4          py311h08b1b3b_0
[conda] numpy-base                1.26.4          py311hf175353_0
[conda] optree                    0.12.1                   pypi_0    pypi
[conda] pytorch                   2.4.0           py3.11_cuda12.1_cudnn9.1.0_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch-tensorrt            2.4.0                    pypi_0    pypi
[conda] torchaudio                2.4.0               py311_cu121    pytorch
[conda] torchelastic              0.2.2                    pypi_0    pypi
[conda] torchmetrics              1.0.3                    pypi_0    pypi
[conda] torchrec                  0.8.0+cu121              pypi_0    pypi
[conda] torchtriton               3.0.0                     py311    pytorch
[conda] torchvision               0.19.0              py311_cu121    pytorch
yjjinjie added the bug label on Sep 3, 2024

yjjinjie commented Sep 3, 2024

@narendasan can you help me solve this problem? I want to set dynamic shapes for both batch size and seq_len.

yjjinjie commented Sep 9, 2024

@narendasan when will torch_executed_modules be supported in dynamo mode?

apbose (Collaborator) commented Sep 24, 2024

Hi @yjjinjie, you can set the dynamic shapes and pass in the dynamic inputs using torch_tensorrt.Input, something like:

compile_spec = {
    "inputs": [
        torch_tensorrt.Input(
            min_shape=(1, 3, 224, 224),
            opt_shape=(8, 3, 224, 224),
            max_shape=(16, 3, 224, 224),
            dtype=torch.half,
        )
    ],
    "enabled_precisions": {torch.half},
    "ir": "dynamo",
}
trt_model = torch_tensorrt.compile(model, **compile_spec)

where model is the module you want to compile. You can refer to the example: https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/torch_compile_resnet_example.py
Since you want to set batch_size and seq_len as dynamic, you need to pass their ranges, e.g.:

        torch_tensorrt.Input(
            min_shape=(1, 1, 224, 224),
            opt_shape=(8, 2, 224, 224),
            max_shape=(16, 3, 224, 224),
            dtype=torch.half,
        )

where the first two dimensions, (1, 8, 16) and (1, 2, 3), denote the batch_size and seq_len ranges respectively. Can you try this and see if you get the same error as above?

yjjinjie commented

Yes, I have tried torch_tensorrt.Input, but it ran into a new bug:

import torch
import torch_tensorrt
from typing import Optional, Sequence, Dict, List
from torch.nn import functional as F
from tzrec.modules.mlp import MLP
from torch import nn


@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features

@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    return torch.arange(end, device=device)

class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = ["query", "sequence", "sequence_length"]
        attn_mlp = {'hidden_units': [256, 64], 'dropout_ratio': [], 'activation': 'nn.ReLU', 'use_bn': False}
        self.mlp = MLP(in_features=41 * 4, **attn_mlp)
        self.linear = nn.Linear(self.mlp.hidden_units[-1], 1)

    def forward(self, *args1: List[torch.Tensor]):
        """Forward the module."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        return self.predict(args1)

    def predict(self, args: List[torch.Tensor]):
        grouped_features = _get_dict(self.keys, args)
        query = grouped_features["query"]
        sequence = grouped_features["sequence"]
        sequence_length = grouped_features["sequence_length"]
        max_seq_length = sequence.size(1)
        sequence_mask = _arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)

        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)

        attn_input = torch.cat(
            [queries, sequence, queries - sequence, queries * sequence], dim=-1
        )

        return attn_input

model = MatMul().eval().cuda()
a = torch.randn(2, 41).cuda()
b = torch.randn(2, 2, 41).cuda()
c = torch.randn(2).cuda()
d = torch.tensor(2)
# torch._dynamo.mark_dynamic(a, 0, min=1, max=8196)
# torch._dynamo.mark_dynamic(b, 0, min=1, max=8196)
# torch._dynamo.mark_dynamic(b, 1, min=1, max=50)
# torch._dynamo.mark_dynamic(c, 0, min=1, max=8196)
# torch._dynamo.mark_dynamic(d, 0, min=1, max=8196)
inputs = [a, b, c]
print(model(*inputs)[0][0][0])
# seq_len = torch.export.Dim("seq_len", min=1, max=10)
# dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
from torchrec.fx import symbolic_trace
model = symbolic_trace(model)

inputs_dy = []
inputs_dy.append(
    torch_tensorrt.Input(
        min_shape=[1, 41],
        opt_shape=[512, 41],
        max_shape=[8196, 41],
        name="query",
    )
)
inputs_dy.append(
    torch_tensorrt.Input(
        min_shape=[1, 1, 41],
        opt_shape=[512, 2, 41],
        max_shape=[8196, 50, 41],
        name="sequence",
    )
)
inputs_dy.append(
    torch_tensorrt.Input(
        min_shape=[1],
        opt_shape=[512],
        max_shape=[8196],
        name="sequence_length",
    )
)

trt_gm = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[inputs_dy],
    min_block_size=1,
    torch_executed_ops=["aten.expand"],
)
print(trt_gm)
# exp_program = torch.export.export(model, (*inputs,))
# trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs, assume_dynamic_shape_support=True,
#                                        allow_shape_tensors=True, min_block_size=2)
# Run inference
print(trt_gm(*inputs)[0][0][0])
# trt_gm = symbolic_trace(trt_gm)
trt_gm = torch.jit.trace(trt_gm, example_inputs=(a, b, c), strict=False)

scripted_model = torch.jit.script(trt_gm)
scripted_model.save("./scripted_model_trt.pt")

model_gpu = torch.jit.load("./scripted_model_trt.pt", map_location="cuda:0")

from torch.profiler import ProfilerActivity, profile, record_function

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof:
    with record_function("model_inference"):
        model_gpu(*inputs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))

print("load:", model_gpu(*inputs)[0][0][0])

The error is:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1266, in RAISE_VARARGS
    raise exc.ObservedException(f"raised exception {val}")
torch._dynamo.exc.ObservedException: raised exception ExceptionVariable()

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
    self.call_function(fn, args, kwargs)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 808, in step
    self.exception_handler()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1304, in exception_handler
    raise exc.ObservedException
torch._dynamo.exc.ObservedException

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/larec/tzrec/tests/test_dy2.py", line 103, in <module>
    trt_gm = torch_tensorrt.compile(
             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 248, in compile
    exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_tracer.py", line 81, in trace
    exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/__init__.py", line 174, in export
    return _export(
           ^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 945, in wrapper
    raise e
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 928, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/exported_program.py", line 89, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1455, in _export
    aten_export_artifact = export_func(
                           ^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1060, in _strict_export
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 512, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1379, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 738, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 316, in __call__
    raise e
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 303, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
           ^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
    super().run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 808, in step
    self.exception_handler()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1303, in exception_handler
    raise Unsupported("Observed exception")
torch._dynamo.exc.Unsupported: Observed exception

from user code:
   File "<eval_with_key>.0 from /larec/tzrec/tests/test_dy2.py:33 in forward", line 7, in forward
    _get_dict = __main____get_dict(['query', 'sequence', 'sequence_length'], _args1);  _args1 = None

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
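
From the "from user code" frame, export seems to trip over the raise inside the fx-wrapped _get_dict (dynamo reports an ObservedException while inlining it). A sketch of a possible sidestep, assuming the length check is purely defensive (my guess, untested):

@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Hypothetical: build the dict directly and keep the raise out of the
    # traced region; validate lengths before calling the model if needed.
    return dict(zip(grouped_features_keys, args))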

yjjinjie commented

I also tried dynamic_shapes, following https://pytorch.org/TensorRT/user_guide/dynamic_shapes.html:

import torch
import torch_tensorrt
from typing import Optional, Sequence, Dict, List
from torch.nn import functional as F
from tzrec.modules.mlp import MLP
from torch import nn


@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features

@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    return torch.arange(end, device=device)

class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = ["query", "sequence", "sequence_length"]
        attn_mlp = {'hidden_units': [256, 64], 'dropout_ratio': [], 'activation': 'nn.ReLU', 'use_bn': False}
        self.mlp = MLP(in_features=41 * 4, **attn_mlp)
        self.linear = nn.Linear(self.mlp.hidden_units[-1], 1)

    def forward(self, *args1: List[torch.Tensor]):
        """Forward the module."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        return self.predict(args1)

    def predict(self, args: List[torch.Tensor]):
        grouped_features = _get_dict(self.keys, args)
        query = grouped_features["query"]
        sequence = grouped_features["sequence"]
        sequence_length = grouped_features["sequence_length"]
        max_seq_length = sequence.size(1)
        sequence_mask = _arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)

        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)

        attn_input = torch.cat(
            [queries, sequence, queries - sequence, queries * sequence], dim=-1
        )

        return attn_input

model = MatMul().eval().cuda()
a = torch.randn(2, 41).cuda()
b = torch.randn(2, 2, 41).cuda()
c = torch.randn(2).cuda()

# torch._dynamo.mark_dynamic(a, 0, min=1, max=8196)
# torch._dynamo.mark_dynamic(b, 0, min=1, max=8196)
# torch._dynamo.mark_dynamic(b, 1, min=1, max=50)
# torch._dynamo.mark_dynamic(c, 0, min=1, max=8196)
inputs = [a, b, c]
print(model(*inputs)[0][0][0])

batch = torch.export.Dim("batch", min=1, max=8196)
seq_len = torch.export.Dim("seq_len", min=1, max=50)
dynamic_shapes = {"args1": ({0: batch}, {0: batch, 1: seq_len}, {0: batch})}
# Export the model first with custom dynamic shape constraints
from torchrec.fx import symbolic_trace

model = symbolic_trace(model)
print(model.code)
exp_program = torch.export.export(model, (*inputs,), dynamic_shapes=dynamic_shapes)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs, assume_dynamic_shape_support=True,
                                       allow_shape_tensors=True, min_block_size=2)

It fails with the same error as the torch._dynamo.mark_dynamic(a, 0, min=1, max=8196) approach above.
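
One isolation step that might narrow this down (a sketch; I am assuming dynamo.compile accepts the same torch_executed_ops knob as torch_tensorrt.compile, and the op-name spelling is my guess):

# Hypothetical: run the suspect expand in eager PyTorch so only the
# remaining graph goes through the TensorRT converter.
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs,
    min_block_size=1,
    torch_executed_ops={"torch.ops.aten.expand.default"},
)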

yjjinjie commented

@apbose can you help me?

apbose (Collaborator) commented Oct 1, 2024

Yeah sure, let me take a look and get back on this.
