Fix up
GuanLuo committed May 10, 2023
1 parent 62b0ebc commit 202a545
Showing 1 changed file with 28 additions and 5 deletions.
qa/L0_device_memory_tracker/test.py

@@ -27,18 +27,39 @@
 
 import unittest
 import time
+from functools import partial
 
-import tritonclient.http as tritonclient
+import tritonclient.http as httpclient
+import tritonclient.grpc as grpcclient
 
 import nvidia_smi
 
 
+class UnifiedClientProxy:
+
+    def __init__(self, client):
+        self.client_ = client
+
+    def __getattr__(self, attr):
+        forward_attr = getattr(self.client_, attr)
+        if type(self.client_) == grpcclient.InferenceServerClient:
+            if attr == "get_model_config":
+                return lambda *args, **kwargs: forward_attr(
+                    *args, **kwargs, as_json=True)["config"]
+            elif attr == "get_inference_statistics":
+                return partial(forward_attr, as_json=True)
+        return forward_attr
+
+
 class MemoryUsageTest(unittest.TestCase):
 
     def setUp(self):
         nvidia_smi.nvmlInit()
         self.gpu_handle_ = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
-        self.client_ = tritonclient.InferenceServerClient(url="localhost:8000")
+        self.http_client_ = httpclient.InferenceServerClient(
+            url="localhost:8000")
+        self.grpc_client_ = grpcclient.InferenceServerClient(
+            url="localhost:8001")
 
     def tearDown(self):
         nvidia_smi.nvmlShutdown()
@@ -55,7 +76,7 @@ def verify_recorded_usage(self, model_stat):
         recorded_gpu_usage = 0
         for usage in model_stat["memory_usage"]:
             if usage["type"] == "GPU":
-                recorded_gpu_usage += usage["byte_size"]
+                recorded_gpu_usage += int(usage["byte_size"])
         # unload and verify recorded usage
         before_total_usage = self.report_used_gpu_memory()
         self.client_.unload_model(model_stat["name"])
@@ -71,13 +92,15 @@ def verify_recorded_usage(self, model_stat):
                 .format(model_stat["name"], usage_delta * 0.9, usage_delta * 1.1,
                         recorded_gpu_usage))
 
-    def test_onnx(self):
+    def test_onnx_http(self):
+        self.client_ = UnifiedClientProxy(self.http_client_)
         model_stats = self.client_.get_inference_statistics()["model_stats"]
         for model_stat in model_stats:
             if self.is_testing_backend(model_stat["name"], "onnxruntime"):
                 self.verify_recorded_usage(model_stat)
 
-    def test_plan(self):
+    def test_plan_grpc(self):
+        self.client_ = UnifiedClientProxy(self.grpc_client_)
         model_stats = self.client_.get_inference_statistics()["model_stats"]
         for model_stat in model_stats:
             if self.is_testing_backend(model_stat["name"], "tensorrt"):
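Background on why the proxy exists: Triton's HTTP client returns JSON-derived Python dicts, while the gRPC client returns protobuf messages unless as_json=True is passed, and the gRPC get_model_config() response additionally nests the configuration under a "config" key. UnifiedClientProxy forwards every attribute lookup to the wrapped client via __getattr__ and patches just those two methods for gRPC, so the shared test bodies can index responses identically under both protocols. Below is an illustrative sketch of the difference being papered over, not code from the commit; it assumes a running Triton server with HTTP on localhost:8000 and gRPC on localhost:8001, and uses a hypothetical loaded model named "resnet".

    import tritonclient.http as httpclient
    import tritonclient.grpc as grpcclient

    http_client = httpclient.InferenceServerClient(url="localhost:8000")
    grpc_client = grpcclient.InferenceServerClient(url="localhost:8001")

    # HTTP responses are already plain dicts.
    stats = http_client.get_inference_statistics()["model_stats"]

    # gRPC responses are protobuf messages unless as_json=True is passed;
    # the model config is additionally nested under a "config" key.
    stats = grpc_client.get_inference_statistics(as_json=True)["model_stats"]
    config = grpc_client.get_model_config("resnet", as_json=True)["config"]

    # Caveat: protobuf's JSON mapping renders 64-bit integer fields as
    # strings, which is why the test casts usage["byte_size"] with int().

With the proxy from this commit, both clients then read identically, e.g. client = UnifiedClientProxy(grpc_client) followed by client.get_inference_statistics()["model_stats"].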
