allow gpu detection (deepjavalibrary#1261)
Qing Lan committed Nov 3, 2023
1 parent a6fa50e commit 46a4954
Showing 2 changed files with 10 additions and 0 deletions.
@@ -13,6 +13,7 @@
 package ai.djl.python.engine;
 
 import ai.djl.engine.EngineException;
+import ai.djl.util.cuda.CudaUtils;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -79,6 +80,8 @@ private static List<String> getStrings(PyModel model, Path trtLlmRepoDir, String
         commandList.add(model.getModelPath().toAbsolutePath().toString());
         commandList.add("--trt_llm_model_repo");
         commandList.add(trtLlmRepoDir.toAbsolutePath().toString());
+        commandList.add("--gpu_count");
+        commandList.add(String.valueOf(CudaUtils.getGpuCount()));
         if (modelId != null) {
             commandList.add("--model_path");
             commandList.add(modelId);
7 changes: 7 additions & 0 deletions serving/docker/partition/trt_llm_partition.py
@@ -30,6 +30,8 @@ def create_trt_llm_repo(properties, args):
     kwargs['trt_llm_model_repo'] = args.trt_llm_model_repo
     kwargs = update_kwargs_with_env_vars(kwargs)
     model_id_or_path = args.model_path or kwargs['model_id']
+    if 'max' == kwargs.get('tensor_parallel_degree', -1):
+        kwargs['tensor_parallel_degree'] = int(args.gpu_count)
     create_model_repo(model_id_or_path, **kwargs)


@@ -60,6 +62,11 @@ def main():
                         type=str,
                         required=True,
                         help='local path where trt llm model repo will be created')
+    parser.add_argument(
+        '--gpu_count',
+        type=str,
+        required=True,
+        help='The total number of gpus in the system')
     parser.add_argument('--model_path',
                         type=str,
                         required=False,
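
For illustration, here is a minimal, self-contained sketch of how the two changes fit together: the Java launcher detects the GPU count with CudaUtils.getGpuCount() and passes it to trt_llm_partition.py as --gpu_count, and the script uses that value whenever tensor_parallel_degree is set to 'max'. The resolve_tensor_parallel_degree helper and the hard-coded example value below are hypothetical; they only mirror the logic added in this commit.

import argparse


def resolve_tensor_parallel_degree(kwargs, gpu_count):
    # Hypothetical helper mirroring the check added to create_trt_llm_repo:
    # a tensor_parallel_degree of 'max' is replaced by the detected GPU count.
    if kwargs.get('tensor_parallel_degree', -1) == 'max':
        kwargs['tensor_parallel_degree'] = int(gpu_count)
    return kwargs


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Mirrors the new required argument; on the Java side it is filled with
    # String.valueOf(CudaUtils.getGpuCount()).
    parser.add_argument('--gpu_count',
                        type=str,
                        required=True,
                        help='The total number of gpus in the system')
    args = parser.parse_args(['--gpu_count', '4'])  # example value for the sketch

    kwargs = {'tensor_parallel_degree': 'max'}
    print(resolve_tensor_parallel_degree(kwargs, args.gpu_count))
    # prints: {'tensor_parallel_degree': 4}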