
Commit e69d1ef

Merge pull request #296 from Ash-Lee233/master
use GE backend for graph mode

zhanghuiyao committed Jun 18, 2024
2 parents a7a720f + aaccfd0
Showing 3 changed files with 16 additions and 7 deletions.
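
All three entry scripts gain the same guard: when the installed MindSpore registers a jit_config context parameter and the script runs in graph mode (mode 0), the jit level is raised to "O2", which, per the commit title, selects the GE backend for graph execution. A minimal standalone sketch of the pattern (not mindyolo code; it assumes a MindSpore build new enough to accept jit_config and an Ascend target, as in the scripts below):

import mindspore as ms
from mindspore._c_expression import ms_ctx_param  # internal enum listing the known context parameters

# As in the updated scripts: configure graph mode first, then opt into the GE path.
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", max_call_depth=2000)
if "jit_config" in ms_ctx_param.__members__:        # older MindSpore releases lack this parameter
    ms.set_context(jit_config={"jit_level": "O2"})  # "O2" selects the GE backend for graph mode

On older MindSpore versions the extra call is simply skipped, so the scripts keep their previous behavior there.
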
3 changes: 3 additions & 0 deletions demo/predict.py
@@ -11,6 +11,7 @@
 
 import mindspore as ms
 from mindspore import Tensor, context, nn
+from mindspore._c_expression import ms_ctx_param
 
 from mindyolo.data import COCO80_TO_COCO91_CLASS
 from mindyolo.models import create_model
@@ -53,6 +54,8 @@ def get_parser_infer(parents=None):
 def set_default_infer(args):
     # Set Context
     context.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
+    if "jit_config" in ms_ctx_param.__members__ and args.mode == 0:
+        ms.set_context(jit_config={"jit_level": "O2"})
     if args.device_target == "Ascend":
         context.set_context(device_id=int(os.getenv("DEVICE_ID", 0)))
     elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
17 changes: 10 additions & 7 deletions mindyolo/utils/utils.py
@@ -6,9 +6,10 @@
 import numpy as np
 
 import mindspore as ms
-from mindspore import context, ops, Tensor, nn
+from mindspore import ops, Tensor, nn
 from mindspore.communication.management import get_group_size, get_rank, init
-from mindspore.context import ParallelMode
+from mindspore import ParallelMode
+from mindspore._c_expression import ms_ctx_param
 
 from mindyolo.utils import logger
 
@@ -21,22 +22,24 @@ def set_seed(seed=2):
 
 def set_default(args):
     # Set Context
-    context.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
+    ms.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
+    if "jit_config" in ms_ctx_param.__members__ and args.mode == 0:
+        ms.set_context(jit_config={"jit_level": "O2"})
     if args.device_target == "Ascend":
         device_id = int(os.getenv("DEVICE_ID", 0))
-        context.set_context(device_id=device_id)
+        ms.set_context(device_id=device_id)
     elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
-        context.set_context(enable_graph_kernel=True)
+        ms.set_context(enable_graph_kernel=True)
     # Set Parallel
     if args.is_parallel:
         init()
         args.rank, args.rank_size, parallel_mode = get_rank(), get_group_size(), ParallelMode.DATA_PARALLEL
-        context.set_auto_parallel_context(device_num=args.rank_size, parallel_mode=parallel_mode, gradients_mean=True)
+        ms.set_auto_parallel_context(device_num=args.rank_size, parallel_mode=parallel_mode, gradients_mean=True)
     else:
         args.rank, args.rank_size = 0, 1
     # Set Default
     args.total_batch_size = args.per_batch_size * args.rank_size
-    args.sync_bn = args.sync_bn and context.get_context("device_target") == "Ascend" and args.rank_size > 1
+    args.sync_bn = args.sync_bn and ms.get_context("device_target") == "Ascend" and args.rank_size > 1
     args.accumulate = max(1, np.round(args.nbs / args.total_batch_size)) if args.auto_accumulate else args.accumulate
     # optimizer
     args.optimizer.warmup_epochs = args.optimizer.get("warmup_epochs", 0)
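
Beyond the GE guard, set_default in mindyolo/utils/utils.py also migrates from the mindspore.context module functions to the equivalent top-level mindspore calls (set_context, get_context, set_auto_parallel_context) and imports ParallelMode from the top-level package. Both spellings configure the same global context; the top-level form is the one this commit standardizes on. A tiny sketch of the equivalence (device_target="CPU" is used here only so the snippet runs anywhere, and is an assumption, not mindyolo's default):

import mindspore as ms
from mindspore import ParallelMode  # was: from mindspore.context import ParallelMode

# Previously: from mindspore import context; context.set_context(...)
ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU", max_call_depth=2000)
assert ms.get_context("device_target") == "CPU"  # get_context reads the same global state
parallel_mode = ParallelMode.DATA_PARALLEL       # the enum is importable from the top-level package
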
3 changes: 3 additions & 0 deletions test.py
@@ -15,6 +15,7 @@
 import mindspore as ms
 from mindspore import Tensor, context, nn, ParallelMode
 from mindspore.communication import init, get_rank, get_group_size
+from mindspore._c_expression import ms_ctx_param
 
 from mindyolo.data import COCO80_TO_COCO91_CLASS, COCODataset, create_loader
 from mindyolo.models.model_factory import create_model
@@ -71,6 +72,8 @@ def get_parser_test(parents=None):
 def set_default_test(args):
     # Set Context
     context.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
+    if "jit_config" in ms_ctx_param.__members__ and args.mode == 0:
+        ms.set_context(jit_config={"jit_level": "O2"})
     if args.device_target == "Ascend":
         context.set_context(device_id=int(os.getenv("DEVICE_ID", 0)))
     elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
