Commit

add test case
lanluo-nvidia committed Oct 4, 2024
1 parent 488e748 commit beda6a1
Showing 1 changed file with 35 additions and 21 deletions.
tests/py/dynamo/partitioning/test_global_partitioning.py (56 changes: 35 additions & 21 deletions)
@@ -5,45 +5,59 @@
 import torch
 import torch.nn.functional as F
 import torch_tensorrt
+from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch_tensorrt.dynamo import partitioning
 
 from ..testing_utilities import lower_graph_testing
 
 
 class TestGlobalPartitioning(TestCase):
-    def test_end2end_global_partition(self):
+    @parameterized.expand(
+        [
+            ({}, 1),
+            ({"torch.ops.aten.relu.default"}, 3),
+        ]
+    )
+    def test_end2end_global_partition(self, torch_executed_ops, trt_mod_cnt):
         class SimpleCNN(torch.nn.Module):
             def __init__(self):
                 super(SimpleCNN, self).__init__()
-                self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
-                self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
-                self.fc1 = torch.nn.Linear(32 * 134 * 134, 10)
-
-            def forward(self, x):
-                x = F.relu(self.conv1(x))
-                x = F.max_pool2d(x, kernel_size=2, stride=2)
-                x = F.relu(self.conv2(x))
-                x = F.max_pool2d(x, kernel_size=2, stride=2)
-                x = x.view(x.size(0), -1)
-                x = self.fc1(x)
-                return x
-
-        mod = SimpleCNN().to(dtype=torch.float16, device=torch.device("cuda"))
+                self.conv1 = torch.nn.Conv2d(3, 12, 3, padding=1)
+                self.bn = torch.nn.BatchNorm2d(12)
+                self.conv2 = torch.nn.Conv2d(12, 12, 3, padding=1)
+                self.fc1 = torch.nn.Linear(12 * 56 * 56, 10)
+
+            def forward(self, x, b=5):
+                x = self.conv1(x)
+                x = F.relu(x)
+                x = self.bn(x)
+                x = F.max_pool2d(x, (2, 2))
+                x = self.conv2(x)
+                x = F.relu(x)
+                x = F.max_pool2d(x, (2, 2))
+                x = torch.flatten(x, 1)
+                x = x + b
+                return self.fc1(x)
+
+        mod = SimpleCNN().to("cuda")
         mod.eval()
-        batch_size, tile_size = 1, 538
-        with torch.no_grad():
-            inputs = torch.randn(
-                batch_size, 3, tile_size, tile_size, device="cuda", dtype=torch.float16
-            )
+        inputs = torch.rand((1, 3, 224, 224)).to("cuda")
         try:
-            torch_tensorrt.compile(
+            trt_mod = torch_tensorrt.compile(
                 mod,
                 ir="dynamo",
                 inputs=[inputs],
                 enabled_precisions={torch.float16},
                 min_block_size=1,
+                torch_executed_ops=torch_executed_ops,
                 use_fast_partitioner=False,
             )
+            cnt = 0
+            for name, _ in trt_mod.named_children():
+                if "_run_on_acc" in name:
+                    cnt += 1
+            self.assertEqual(cnt, trt_mod_cnt)
         except Exception as e:
             pytest.fail(f"unexpected exception raised: {e}")
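
The new assertions rely on how the global partitioner (use_fast_partitioner=False) exposes its TensorRT partitions: each accelerated subgraph becomes a child module whose name contains "_run_on_acc". Below is a minimal sketch of that counting logic, assuming a module already compiled as in the test above; the helper name count_trt_subgraphs is illustrative and not part of the commit.

import torch

def count_trt_subgraphs(compiled_mod: torch.nn.Module) -> int:
    # TensorRT-accelerated partitions show up as child modules whose names
    # contain "_run_on_acc"; everything else falls back to eager PyTorch.
    return sum(
        1 for name, _ in compiled_mod.named_children() if "_run_on_acc" in name
    )

With torch_executed_ops left empty, the whole SimpleCNN graph is expected to fuse into a single TRT block (count 1); excluding torch.ops.aten.relu.default keeps both relu calls in PyTorch, splitting the remaining ops into three TRT blocks (count 3), which matches the two parameterized cases.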
