Skip to content

Commit

Permalink
Fix anomaly map shapes to work with tiling
Browse files Browse the repository at this point in the history
Signed-off-by: blaz-r <blaz.rolih@gmail.com>
  • Loading branch information
blaz-r committed Apr 7, 2024
1 parent 6dffa5f commit 140fb91
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/anomalib/models/image/padim/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
torch.Size([32, 128, 28, 28]),
torch.Size([32, 256, 14, 14])]
"""
output_size = input_tensor.shape[-2:]
if self.tiler:
input_tensor = self.tiler.tile(input_tensor)

Expand All @@ -143,7 +144,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
embedding=embeddings,
mean=self.gaussian.mean,
inv_covariance=self.gaussian.inv_covariance,
image_size=input_tensor.shape[-2:],
image_size=output_size,
)
return output

Expand Down
3 changes: 2 additions & 1 deletion src/anomalib/models/image/patchcore/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | dict[str, torch.
Returns:
Tensor | dict[str, torch.Tensor]: Embedding for training, anomaly map and anomaly score for testing.
"""
output_size = input_tensor.shape[-2:]
if self.tiler:
input_tensor = self.tiler.tile(input_tensor)

Expand Down Expand Up @@ -98,7 +99,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | dict[str, torch.
# reshape to w, h
patch_scores = patch_scores.reshape((batch_size, 1, width, height))
# get anomaly map
anomaly_map = self.anomaly_map_generator(patch_scores, input_tensor.shape[-2:])
anomaly_map = self.anomaly_map_generator(patch_scores, output_size)

output = {"anomaly_map": anomaly_map, "pred_score": pred_score}

Expand Down
3 changes: 2 additions & 1 deletion src/anomalib/models/image/stfpm/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def forward(self, images: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor
Returns:
Teacher and student features when in training mode, otherwise the predicted anomaly maps.
"""
output_size = images.shape[-2:]
if self.tiler:
images = self.tiler.tile(images)
teacher_features: dict[str, torch.Tensor] = self.teacher_model(images)
Expand All @@ -78,7 +79,7 @@ def forward(self, images: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor
output = self.anomaly_map_generator(
teacher_features=teacher_features,
student_features=student_features,
image_size=images.shape[-2:],
image_size=output_size,
)

return output

0 comments on commit 140fb91

Please sign in to comment.