fix the globalpartitioner bug
lanluo-nvidia committed Sep 29, 2024
1 parent 43eb560 commit 81f8b6e
Showing 2 changed files with 19 additions and 3 deletions.
py/torch_tensorrt/dynamo/_compiler.py (4 changes: 3 additions & 1 deletion)
@@ -366,7 +366,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
-
    # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
@@ -408,6 +407,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
+        # filter on the GraphModule
+        if not isinstance(submodule, torch.fx.graph_module.GraphModule):
+            continue
         # Criteria for a module to be convertible to TRT
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
             dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
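
What the new guard in _compiler.py protects against, as a minimal runnable sketch (toy module and child names assumed here, not the project's real code): named_children() can yield plain nn.Module children alongside the partitioned fx.GraphModules, and only the latter should reach the TRT conversion path.

import torch
import torch.fx

class Container(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # One partitioned-style GraphModule child and one plain child.
        self._run_on_acc_0 = torch.fx.symbolic_trace(torch.nn.ReLU())
        self.plain = torch.nn.Linear(4, 4)

mod = Container()
for name, _ in mod.named_children():
    submodule = getattr(mod, name)
    # Same guard as the fix: skip children that are not GraphModules.
    if not isinstance(submodule, torch.fx.graph_module.GraphModule):
        continue
    print(f"would hand {name} to the TRT conversion path")

Only _run_on_acc_0 is printed; the plain Linear child is skipped instead of being treated as TRT-convertible.
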
py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py (18 changes: 16 additions & 2 deletions)
@@ -228,8 +228,22 @@ def partition(
     # Determine partitions based on user specifications and operator support
     # Then, fuse partitions and display overview of supported/unsupported operators
     partitions = partitioner.propose_partitions()
-    fused_graph = partitioner.fuse_partitions(partitions)
-
+    # TODO: confirm with Naren whether this change is required or not.
+    # Tested both with and without this change; both work. The only
+    # difference is the graph node names, for example:
+    # graph():
+    #     %x : [num_users=1] = placeholder[target=x]
+    #     %_run_on_acc_0 : [num_users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
+    #     return (_run_on_acc_0,)
+
+    # or
+
+    # graph():
+    #     %x : [num_users=1] = placeholder[target=x]
+    #     %fused_0 : [num_users=1] = call_module[target=fused_0](args = (%x,), kwargs = {})
+    #     return (fused_0,)
+
+    fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
     if verbose:
         supported_ops.print_support_overview(len(partitions))
 
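
This change aligns the global partitioner's fused-submodule names with the fast partitioner's _run_on_acc_N convention that _compiler.py matches on (the TODO above notes both spellings worked in testing). Below is a minimal self-contained sketch (toy function and operator-support policy assumed; the prefix kwarg of fuse_partitions is the recent-PyTorch API the commit relies on) that reproduces the two graphs quoted in the TODO.

import torch
import torch.fx
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase

class AllCallFunctions(OperatorSupportBase):
    # Toy policy: accept every call_function node so one partition forms.
    def is_node_supported(self, submodules, node) -> bool:
        return node.op == "call_function"

def f(x):
    return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(f)
partitioner = CapabilityBasedPartitioner(gm, AllCallFunctions())
partitions = partitioner.propose_partitions()

# With the default prefix the fused node is call_module[target=fused_0];
# with prefix="_run_on_acc_" it is call_module[target=_run_on_acc_0].
fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
print(fused_graph.graph)
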
