From 607a02864c73765ee1fdcb6b67f09e32a11cd11b Mon Sep 17 00:00:00 2001 From: YuxinZou Date: Tue, 20 Oct 2020 16:42:53 +0800 Subject: [PATCH] fix mismatch bug in inference.py --- tools/inference.py | 57 +++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/tools/inference.py b/tools/inference.py index f582c9c..9e6890c 100644 --- a/tools/inference.py +++ b/tools/inference.py @@ -11,6 +11,7 @@ from vedaseg.runners import InferenceRunner from vedaseg.utils import Config + CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', @@ -23,19 +24,23 @@ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] -def inverse_resize(pred, image_shape): - h, w, _ = image_shape - reisze_h, resized_w = pred.shape[0], pred.shape[1] - scale_factor = max(h / reisze_h, w / resized_w) - pred = cv2.resize(pred, ( - int(reisze_h * scale_factor), int(reisze_h * scale_factor)), - interpolation=cv2.INTER_NEAREST) - return pred +def calc_resized_shape(target_shape, image_shape): + h, w = image_shape + size_h, size_w = target_shape + scale_factor = min(size_h / h, size_w / w) + resized_h, resized_w = int(h * scale_factor), int(w * scale_factor) + return resized_h, resized_w + + +def inverse_resize(output, image_shape): + h, w = image_shape + output = cv2.resize(output, (w, h), interpolation=cv2.INTER_NEAREST) + return output -def inverse_pad(pred, image_shape): - h, w, _ = image_shape - return pred[:h, :w] +def inverse_pad(output, image_shape): + h, w = image_shape + return output[:h, :w] def plot_result(img, mask, cover): @@ -45,10 +50,10 @@ def plot_result(img, mask, cover): ax[0].set_title('image') ax[0].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - ax[1].set_title(f'mask') + ax[1].set_title('mask') ax[1].imshow(mask) - ax[2].set_title(f'cover') + ax[2].set_title('cover') ax[2].imshow(cv2.cvtColor(cover, cv2.COLOR_BGR2RGB)) plt.show() @@ -93,16 +98,16 @@ def result(fname, def parse_args(): - parser = argparse.ArgumentParser(description='Inference a segmentatation model') + parser = argparse.ArgumentParser( + description='Inference a segmentatation model') parser.add_argument('config', type=str, help='config file path') - parser.add_argument('checkpoint', - type=str, help='checkpoint file path') - parser.add_argument('image', - type=str, + parser.add_argument('checkpoint', type=str, + help='checkpoint file path') + parser.add_argument('image', type=str, help='input image path') parser.add_argument('--show', action='store_true', - help='show result') + help='show result images on screen') parser.add_argument('--need_resize', action='store_true', help='set true if there is LongestMaxSize in transform') parser.add_argument('--out', default='./result', @@ -123,17 +128,23 @@ def main(): runner = InferenceRunner(inference_cfg, common_cfg) runner.load_checkpoint(args.checkpoint) + image = cv2.imread(args.image) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - h, w, c = image.shape - dummy_mask = np.zeros((h, w)) + image_shape = image.shape[:2] + dummy_mask = np.zeros(image_shape) + output = runner(image, [dummy_mask]) if multi_label: output = output.transpose((1, 2, 0)) + output_shape = output.shape[:2] if args.need_resize: - output = inverse_resize(output, image.shape) - output = inverse_pad(output, image.shape) + resized_shape = calc_resized_shape(output_shape, image_shape) + output = inverse_pad(output, resized_shape) + output = inverse_resize(output, image_shape) + else: + output = inverse_pad(output, image_shape) result(args.image, output, multi_label=multi_label, classes=CLASSES, palette=PALETTE, show=args.show,