diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 97aa2ec443..d213cca638 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -366,7 +366,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
-
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
@@ -408,6 +407,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
+        # Skip submodules that are not torch.fx GraphModules
+        if not isinstance(submodule, torch.fx.graph_module.GraphModule):
+            continue
         # Criteria for a module to be convertible to TRT
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
             dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
index 823a43beb8..6086d7c707 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -228,8 +228,22 @@ def partition(
     # Determine partitions based on user specifications and operator support
     # Then, fuse partitions and display overview of supported/unsupported operators
     partitions = partitioner.propose_partitions()
-    fused_graph = partitioner.fuse_partitions(partitions)
-
+    # TODO: confirm with Naren whether this change is required
+    # Tested both with and without this change; both work.
+    # The only difference is the fused submodule's node name, for example:
+    # graph():
+    #     %x : [num_users=1] = placeholder[target=x]
+    #     %_run_on_acc_0 : [num_users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
+    #     return (_run_on_acc_0,)
+
+    # or
+
+    # graph():
+    #     %x : [num_users=1] = placeholder[target=x]
+    #     %fused_0 : [num_users=1] = call_module[target=fused_0](args = (%x,), kwargs = {})
+    #     return (fused_0,)
+
+    fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
     if verbose:
         supported_ops.print_support_overview(len(partitions))
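
Note on the _compiler.py change: named_children() can yield children that are not torch.fx GraphModules, and the new guard skips them, presumably because the downstream dry-run and conversion logic expects fx GraphModules. Below is a minimal, self-contained sketch of that filtering pattern; EngineLike, parent, and the child names are made up for illustration and are not torch_tensorrt APIs.

import torch

class EngineLike(torch.nn.Module):  # stand-in for a non-GraphModule child
    def forward(self, x):
        return x

def f(x):
    return x + 1

# A parent module with one fx GraphModule child and one plain nn.Module child
parent = torch.nn.Module()
parent.graph_child = torch.fx.symbolic_trace(f)
parent.engine_child = EngineLike()

for name, _ in parent.named_children():
    submodule = getattr(parent, name)
    if not isinstance(submodule, torch.fx.graph_module.GraphModule):
        continue  # only fx GraphModules are candidates for TRT conversion
    print(f"would convert: {name}")
# prints: would convert: graph_child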
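Note on the _global_partitioner.py change: CapabilityBasedPartitioner.fuse_partitions() accepts a prefix argument, and passing prefix="_run_on_acc_" makes the global partitioner's fused submodules follow the same _run_on_acc_<i> naming that the fast partitioner produces; per the TODO in the patch, behavior is otherwise unchanged. A minimal sketch of the renaming effect, assuming a toy model and a hypothetical AddMulSupport rule (neither is part of torch_tensorrt):

import operator

import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

class AddMulSupport(OperatorSupport):
    # Pretend only add/mul are convertible, so relu splits the graph in two
    def is_node_supported(self, submodules, node):
        return node.op == "call_function" and node.target in (operator.add, operator.mul)

def toy(x):
    return torch.relu(x + 1.0) * 2.0

gm = torch.fx.symbolic_trace(toy)
partitioner = CapabilityBasedPartitioner(
    gm, AddMulSupport(), allows_single_node_partition=True
)
partitions = partitioner.propose_partitions()

# With the default prefix the fused submodules are named fused_0, fused_1, ...;
# with prefix="_run_on_acc_" they match the fast partitioner's naming convention.
fused = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
print(fused.graph)  # call_module targets: _run_on_acc_0, _run_on_acc_1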