Skip to content

Commit

Permalink
repair gm_reg mode for resnet/densenet/alexnet
Browse files Browse the repository at this point in the history
  • Loading branch information
delphieritas authored Mar 14, 2021
1 parent 22d08bc commit c41b458
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion singa_easy/models/TorchModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,9 +548,19 @@ def local_explain(self, org_imgs: Image,
traceback.print_exc(file=sys.stdout)

if enable_gradcam:
if 'densenet' in self._knobs.get("model_class"):
model_arch = 'densenet'
elif 'alexnet' in self._knobs.get("model_class"):
model_arch = 'alexnet'
elif 'resnet' in self._knobs.get("model_class"):
model_arch = 'resnet'
elif 'vgg' in self._knobs.get("model_class"):
model_arch = 'vgg'
else:
raise NameError()
try:
gc = GradCam(model=self._model,
model_arch='vgg',
model_arch=model_arch,
target_layer=None,
device=self.device)
(images, _, _) = utils.dataset.normalize_images(
Expand Down

0 comments on commit c41b458

Please sign in to comment.