Skip to content

Commit

Permalink
Re #49: add loop to test inference time
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscyyeung committed Jun 1, 2023
1 parent 250d366 commit fc366df
Showing 1 changed file with 31 additions and 10 deletions.
41 changes: 31 additions & 10 deletions UltrasoundSegmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import matplotlib.pyplot as plt

from tqdm import tqdm
from time import perf_counter
from datetime import datetime
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
Expand Down Expand Up @@ -361,14 +362,17 @@ def main(args):
iou_metric.reset()
confusion_matrix_metric.reset()

logging.info(f"Val loss: {val_loss}")
logging.info(f"Dice: {dice}\n"
f"IoU: {iou}\n"
f"Accuracy: {(acc := cm[0].item())}\n"
f"Precision: {(pre := cm[1].item())}\n"
f"Sensitivity: {(sen := cm[2].item())}\n"
f"Specificity: {(spe := cm[3].item())}\n"
f"F1 score: {(f1 := cm[4].item())}")
logging.info(
f"Validation results:\n"
f"\tLoss: {val_loss}\n"
f"\tDice: {dice}\n"
f"\tIoU: {iou}\n"
f"\tAccuracy: {(acc := cm[0].item())}\n"
f"\tPrecision: {(pre := cm[1].item())}\n"
f"\tSensitivity: {(sen := cm[2].item())}\n"
f"\tSpecificity: {(spe := cm[3].item())}\n"
f"\tF1 score: {(f1 := cm[4].item())}"
)

# Log a random sample of 3 test images along with their ground truth and predictions
random.seed(config["seed"])
Expand Down Expand Up @@ -424,7 +428,25 @@ def main(args):
and (epoch + 1) < config["num_epochs"]):
ckpt_model_path = os.path.join(ckpt_dir, f"model_{epoch:03d}.pt")
torch.save(model.state_dict(), ckpt_model_path)
logging.info(f"Saved model checkpoint to {ckpt_model_path}")
logging.info(f"Saved model checkpoint to {ckpt_model_path}.")

# Test inference time (load images before loop to exclude from time measurement)
logging.info("Measuring inference time...")
num_test_images = 100
inputs = torch.stack([val_dataset[i]["image"] for i in range(num_test_images)])
model.eval()
with torch.no_grad():
start = perf_counter()
for i in range(num_test_images):
model(inputs[i, :, :, :].unsqueeze(0).to(device=device))
end = perf_counter()
avg_inf_time = (end - start) / num_test_images
avg_inf_fps = 1 / avg_inf_time
logging.info(f"Average inference time per image: {avg_inf_time:.4f}s ({avg_inf_fps:.2f} FPS)")
run.log({
"avg_inference_time": avg_inf_time,
"avg_inference_fps": avg_inf_fps
})

# Save the final model also under the name "model.pt" so that we can easily find it later.
# This is useful if we want to use the model for inference without having to specify the model filename.
Expand All @@ -436,7 +458,6 @@ def main(args):
if args.save_torchscript:
ts_model_path = os.path.join(run_dir, "model_traced.pt")
model = model.to("cpu")
model.eval()
example_input = torch.rand(1, config["in_channels"], config["image_size"], config["image_size"])
traced_script_module = torch.jit.trace(model, example_input)
d = {"shape": example_input.shape}
Expand Down

0 comments on commit fc366df

Please sign in to comment.