Skip to content

Commit

Permalink
Merge pull request #43 from YuxinZou/dist
Browse files Browse the repository at this point in the history
fix mismatch bug in inference.py
  • Loading branch information
hxcai authored Oct 20, 2020
2 parents 180e158 + 607a028 commit 911c794
Showing 1 changed file with 34 additions and 23 deletions.
57 changes: 34 additions & 23 deletions tools/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vedaseg.runners import InferenceRunner
from vedaseg.utils import Config


CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
Expand All @@ -23,19 +24,23 @@
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]


def inverse_resize(pred, image_shape):
h, w, _ = image_shape
reisze_h, resized_w = pred.shape[0], pred.shape[1]
scale_factor = max(h / reisze_h, w / resized_w)
pred = cv2.resize(pred, (
int(reisze_h * scale_factor), int(reisze_h * scale_factor)),
interpolation=cv2.INTER_NEAREST)
return pred
def calc_resized_shape(target_shape, image_shape):
h, w = image_shape
size_h, size_w = target_shape
scale_factor = min(size_h / h, size_w / w)
resized_h, resized_w = int(h * scale_factor), int(w * scale_factor)
return resized_h, resized_w


def inverse_resize(output, image_shape):
h, w = image_shape
output = cv2.resize(output, (w, h), interpolation=cv2.INTER_NEAREST)
return output


def inverse_pad(pred, image_shape):
h, w, _ = image_shape
return pred[:h, :w]
def inverse_pad(output, image_shape):
h, w = image_shape
return output[:h, :w]


def plot_result(img, mask, cover):
Expand All @@ -45,10 +50,10 @@ def plot_result(img, mask, cover):
ax[0].set_title('image')
ax[0].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

ax[1].set_title(f'mask')
ax[1].set_title('mask')
ax[1].imshow(mask)

ax[2].set_title(f'cover')
ax[2].set_title('cover')
ax[2].imshow(cv2.cvtColor(cover, cv2.COLOR_BGR2RGB))
plt.show()

Expand Down Expand Up @@ -93,16 +98,16 @@ def result(fname,


def parse_args():
parser = argparse.ArgumentParser(description='Inference a segmentatation model')
parser = argparse.ArgumentParser(
description='Inference a segmentatation model')
parser.add_argument('config', type=str,
help='config file path')
parser.add_argument('checkpoint',
type=str, help='checkpoint file path')
parser.add_argument('image',
type=str,
parser.add_argument('checkpoint', type=str,
help='checkpoint file path')
parser.add_argument('image', type=str,
help='input image path')
parser.add_argument('--show', action='store_true',
help='show result')
help='show result images on screen')
parser.add_argument('--need_resize', action='store_true',
help='set true if there is LongestMaxSize in transform')
parser.add_argument('--out', default='./result',
Expand All @@ -123,17 +128,23 @@ def main():

runner = InferenceRunner(inference_cfg, common_cfg)
runner.load_checkpoint(args.checkpoint)

image = cv2.imread(args.image)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
h, w, c = image.shape
dummy_mask = np.zeros((h, w))
image_shape = image.shape[:2]
dummy_mask = np.zeros(image_shape)

output = runner(image, [dummy_mask])
if multi_label:
output = output.transpose((1, 2, 0))
output_shape = output.shape[:2]

if args.need_resize:
output = inverse_resize(output, image.shape)
output = inverse_pad(output, image.shape)
resized_shape = calc_resized_shape(output_shape, image_shape)
output = inverse_pad(output, resized_shape)
output = inverse_resize(output, image_shape)
else:
output = inverse_pad(output, image_shape)

result(args.image, output, multi_label=multi_label,
classes=CLASSES, palette=PALETTE, show=args.show,
Expand Down

0 comments on commit 911c794

Please sign in to comment.