From 46a4954f7eed24bb780211a2fbeffd5570d21f3a Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Fri, 3 Nov 2023 12:26:04 -0700 Subject: [PATCH] allow gpu detection (#1261) --- .../src/main/java/ai/djl/python/engine/TrtLlmUtils.java | 3 +++ serving/docker/partition/trt_llm_partition.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/engines/python/src/main/java/ai/djl/python/engine/TrtLlmUtils.java b/engines/python/src/main/java/ai/djl/python/engine/TrtLlmUtils.java index cf5cda12169..d66acbcd10f 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/TrtLlmUtils.java +++ b/engines/python/src/main/java/ai/djl/python/engine/TrtLlmUtils.java @@ -13,6 +13,7 @@ package ai.djl.python.engine; import ai.djl.engine.EngineException; +import ai.djl.util.cuda.CudaUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -79,6 +80,8 @@ private static List getStrings(PyModel model, Path trtLlmRepoDir, String commandList.add(model.getModelPath().toAbsolutePath().toString()); commandList.add("--trt_llm_model_repo"); commandList.add(trtLlmRepoDir.toAbsolutePath().toString()); + commandList.add("--gpu_count"); + commandList.add(String.valueOf(CudaUtils.getGpuCount())); if (modelId != null) { commandList.add("--model_path"); commandList.add(modelId); diff --git a/serving/docker/partition/trt_llm_partition.py b/serving/docker/partition/trt_llm_partition.py index 86719b81c08..45fcbc2a719 100644 --- a/serving/docker/partition/trt_llm_partition.py +++ b/serving/docker/partition/trt_llm_partition.py @@ -30,6 +30,8 @@ def create_trt_llm_repo(properties, args): kwargs['trt_llm_model_repo'] = args.trt_llm_model_repo kwargs = update_kwargs_with_env_vars(kwargs) model_id_or_path = args.model_path or kwargs['model_id'] + if 'max' == kwargs.get('tensor_parallel_degree', -1): + kwargs['tensor_parallel_degree'] = int(args.gpu_count) create_model_repo(model_id_or_path, **kwargs) @@ -60,6 +62,11 @@ def main(): type=str, required=True, help='local path where trt llm model repo will be created') + parser.add_argument( + '--gpu_count', + type=str, + required=True, + help='The total number of gpus in the system') parser.add_argument('--model_path', type=str, required=False,