forked from jrzech/reproduce-chexnet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_model.py
58 lines (46 loc) · 1.89 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import pandas as pd
import cxr_dataset as CXR
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import sklearn
import sklearn.metrics as sklm
from torch.autograd import Variable
import numpy as np
import datetime
def make_pred_multilabel(data_transforms, model, PATH_TO_IMAGES, epoch_loss, CHROMOSOME):
    """
    Run a previously trained model over every batch of the test fold.

    The model is switched to eval mode and fed the test images batch by
    batch on the GPU. Inputs are moved with ``.cuda()``, so a CUDA device is
    required.

    NOTE(review): the model outputs are discarded and nothing is aggregated,
    so no predictions or AUCs are produced here -- confirm whether the
    prediction/AUC bookkeeping was dropped from this version on purpose.

    Args:
        data_transforms: mapping of torchvision transforms; the 'val' entry
            is used to preprocess the test images
        model: densenet-121 from torchvision previously fine tuned to
            training data
        PATH_TO_IMAGES: path at which NIH images can be found
        epoch_loss: accepted for interface compatibility; not used here
        CHROMOSOME: accepted for interface compatibility; not used here

    Returns:
        The batch size used for inference (currently always 32).
    """
    # run inference in batches of 32, can reduce if your GPU has less RAM
    BATCH_SIZE = 32

    # set model to eval mode; required for proper predictions given use of batchnorm
    model.train(False)

    # create dataloader over the test fold, using the validation transforms
    dataset = CXR.CXRDataset(
        path_to_images=PATH_TO_IMAGES,
        fold="test",
        transform=data_transforms['val'])
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # inference only: no_grad() keeps autograd from recording a graph for
    # each forward pass, which would waste GPU memory
    with torch.no_grad():
        # each batch is a 3-tuple; only the images are needed for the forward pass
        for inputs, _labels, _ in dataloader:
            model(inputs.cuda())

    return BATCH_SIZE