Skip to content

Commit

Permalink
switch to pathlib instead of os.path
Browse files Browse the repository at this point in the history
  • Loading branch information
kalebphipps committed Sep 5, 2024
1 parent 7df8d73 commit be5253c
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 17 deletions.
1 change: 0 additions & 1 deletion tests/target_cropper/stj_data/stj-tower-measurements.json

This file was deleted.

39 changes: 31 additions & 8 deletions tests/target_cropper/test_focal_spot_detection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import pathlib
import sys

import torch

from paint import target_cropper
from paint import PAINT_ROOT, target_cropper

lib_dir = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir, os.pardir))
sys.path.append(lib_dir)
Expand All @@ -15,9 +16,16 @@ def test_focal_spot_detection() -> None:
applied_k_means = 5

warped_image = target_cropper.util.load_image(
os.path.join(__file__, os.pardir, "focal_spot_image.png")
pathlib.Path(PAINT_ROOT)
/ "tests"
/ "target_cropper"
/ "test_data"
/ "focal_spot_image.png"
)

mask = target_cropper.util.load_image(
pathlib.Path(PAINT_ROOT) / "tests" / "target_cropper" / "test_data" / "mask.png"
)
mask = target_cropper.util.load_image(os.path.join(__file__, os.pardir, "mask.png"))
mask = torch.where(mask != 0, torch.tensor(1.0), mask)

target = target_cropper.dataclasses.Target(
Expand All @@ -26,31 +34,47 @@ def test_focal_spot_detection() -> None:
template_offset=torch.tensor([0, 0.5]),
enu_position=torch.tensor([-1, 0, 1]),
template_image=target_cropper.util.load_image(
os.path.join(__file__, os.pardir, "stj_data", "stj_center_left.png")
pathlib.Path(PAINT_ROOT)
/ "tests"
/ "target_cropper"
/ "test_data"
/ "stj_center_left.png"
),
),
marker_2=target_cropper.dataclasses.Marker(
image_position=torch.tensor([0, 400]),
template_offset=torch.tensor([0.5, 1.0]),
enu_position=torch.tensor([1, 0, 1]),
template_image=target_cropper.util.load_image(
os.path.join(__file__, os.pardir, "stj_data", "stj_center_right.png")
pathlib.Path(PAINT_ROOT)
/ "tests"
/ "target_cropper"
/ "test_data"
/ "stj_center_right.png"
),
),
marker_3=target_cropper.dataclasses.Marker(
image_position=torch.tensor([400, 0]),
template_offset=torch.tensor([1.0, 0.5]),
enu_position=torch.tensor([-1, 0, -1]),
template_image=target_cropper.util.load_image(
os.path.join(__file__, os.pardir, "stj_data", "stj_lower_left.png")
pathlib.Path(PAINT_ROOT)
/ "tests"
/ "target_cropper"
/ "test_data"
/ "stj_lower_left.png"
),
),
marker_4=target_cropper.dataclasses.Marker(
image_position=torch.tensor([400, 400]),
template_offset=torch.tensor([0.5, 0]),
enu_position=torch.tensor([1, 0, -1]),
template_image=target_cropper.util.load_image(
os.path.join(__file__, os.pardir, "stj_data", "stj_lower_right.png")
pathlib.Path(PAINT_ROOT)
/ "tests"
/ "target_cropper"
/ "test_data"
/ "stj_lower_right.png"
),
),
output_shape=torch.Size([400, 400]),
Expand All @@ -60,7 +84,6 @@ def test_focal_spot_detection() -> None:
focal_spot = target_cropper.detect_focal_spot(
image=warped_image,
num_k_means=num_k_means,
applied_k_means=applied_k_means,
target=target,
)

Expand Down
31 changes: 23 additions & 8 deletions tests/target_cropper/test_focal_spot_detection_from_dict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
import os
import pathlib
import sys

import matplotlib.pyplot as plt
import torch

from paint import target_cropper
from paint import PAINT_ROOT, target_cropper

lib_dir = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir, os.pardir))
sys.path.append(lib_dir)
Expand All @@ -17,14 +18,24 @@ def test_focal_spot_detection_from_dict():
applied_k_means = 5

image = target_cropper.util.load_image(
os.path.join(__file__, os.pardir, "stj_data", "stj_target.png")
pathlib.Path(PAINT_ROOT)
/ "tests"
/ "target_cropper"
/ "test_data"
/ "stj_target.png"
)
mask = target_cropper.util.load_image(
pathlib.Path(PAINT_ROOT) / "tests" / "target_cropper" / "test_data" / "mask.png"
)
mask = target_cropper.util.load_image(os.path.join(__file__, os.pardir, "mask.png"))
# mask = torch.where(mask != 0, torch.tensor(1.0), mask)
mask = None

with open(
os.path.join(__file__, os.pardir, "stj_data", "stj-tower-measurements.json"),
pathlib.Path(PAINT_ROOT)
/ "tests"
/ "target_cropper"
/ "test_data"
/ "stj-tower-measurements.json",
"r",
) as file:
data_dict = json.load(file)
Expand All @@ -37,7 +48,11 @@ def test_focal_spot_detection_from_dict():
data_dict["solar_tower_juelich_lower"]["upper_left"]
),
template_image=target_cropper.util.load_image(
os.path.join(__file__, os.pardir, "stj_data", "stj_center_left.png")
pathlib.Path(PAINT_ROOT)
/ "tests"
/ "target_cropper"
/ "test_data"
/ "stj_center_left.png"
),
),
marker_2=target_cropper.dataclasses.Marker(
Expand All @@ -47,7 +62,7 @@ def test_focal_spot_detection_from_dict():
data_dict["solar_tower_juelich_lower"]["upper_right"]
),
template_image=target_cropper.util.load_image(
os.path.join(__file__, os.pardir, "stj_data", "stj_center_right.png")
os.path.join(__file__, os.pardir, "test_data", "stj_center_right.png")
),
),
marker_3=target_cropper.dataclasses.Marker(
Expand All @@ -57,7 +72,7 @@ def test_focal_spot_detection_from_dict():
data_dict["solar_tower_juelich_lower"]["lower_left"]
),
template_image=target_cropper.util.load_image(
os.path.join(__file__, os.pardir, "stj_data", "stj_lower_left.png")
os.path.join(__file__, os.pardir, "test_data", "stj_lower_left.png")
),
),
marker_4=target_cropper.dataclasses.Marker(
Expand All @@ -67,7 +82,7 @@ def test_focal_spot_detection_from_dict():
data_dict["solar_tower_juelich_lower"]["lower_right"]
),
template_image=target_cropper.util.load_image(
os.path.join(__file__, os.pardir, "stj_data", "stj_lower_right.png")
os.path.join(__file__, os.pardir, "test_data", "stj_lower_right.png")
),
),
output_shape=torch.Size([400, 400]),
Expand Down

0 comments on commit be5253c

Please sign in to comment.