Skip to content

Commit

Permalink
Torch tests demos refactor (#245)
Browse files Browse the repository at this point in the history
* initial commit

* test_torch_unet cleaned

* new torch unet demos

* jinet demos

* torch tests renamed

* unused function is removed

* reduce lr scheduler added to n2s loop

* black fix

Co-authored-by: acs-ws <asolak@ku.edu.tr>
  • Loading branch information
AhmetCanSolak and acs-ws committed Sep 27, 2022
1 parent 2c69073 commit a1a82cb
Show file tree
Hide file tree
Showing 17 changed files with 331 additions and 436 deletions.
48 changes: 0 additions & 48 deletions aydin/nn/models/torch/demo/demo_n2t.py

This file was deleted.

Empty file.
87 changes: 87 additions & 0 deletions aydin/nn/models/torch/demo/jinet/n2s_2D_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# flake8: noqa
import time
import numpy
import torch

from aydin.io.datasets import (
normalise,
add_noise,
camera,
)
from aydin.nn.models.torch.torch_jinet import JINetModel
from aydin.nn.models.torch.torch_unet import n2s_train
from aydin.util.log.log import Log


def demo(image, model_class, do_add_noise=True):
"""
Demo for self-supervised denoising using camera image with synthetic noise
"""
Log.enable_output = True
Log.set_log_max_depth(8)

image = normalise(image)
image = numpy.expand_dims(image, axis=0)
image = numpy.expand_dims(image, axis=0)
noisy = add_noise(image) if do_add_noise else image
print(noisy.shape)

# noisy = torch.tensor(noisy)
image = torch.tensor(image)

model = model_class(
nb_unet_levels=2,
spacetime_ndim=2,
)

print("training starts")

start = time.time()
n2s_train(noisy, model, nb_epochs=128)
stop = time.time()
print(f"Training: elapsed time: {stop - start} ")

noisy = torch.tensor(noisy)
model.eval()
model = model.cpu()
print(f"noisy tensor shape: {noisy.shape}")
# in case of batching we have to do this:
start = time.time()
denoised = model(noisy)
stop = time.time()
print(f"inference: elapsed time: {stop - start} ")

noisy = noisy.detach().numpy()[0, 0, :, :]
image = image.detach().numpy()[0, 0, :, :]
denoised = denoised.detach().numpy()[0, 0, :, :]

image = numpy.clip(image, 0, 1)
noisy = numpy.clip(noisy, 0, 1)
denoised = numpy.clip(denoised, 0, 1)
# psnr_noisy = psnr(image, noisy)
# ssim_noisy = ssim(image, noisy)
# psnr_denoised = psnr(image, denoised)
# ssim_denoised = ssim(image, denoised)
# print("noisy :", psnr_noisy, ssim_noisy)
# print("denoised:", psnr_denoised, ssim_denoised)

import napari

viewer = napari.Viewer() # no prior setup needed
viewer.add_image(image, name='image')
viewer.add_image(noisy, name='noisy')
viewer.add_image(denoised, name='denoised')
napari.run()


if __name__ == '__main__':
# image = newyork()
# image = lizard()
# image = characters()
image = camera()
# image = pollen()
# image = dots()

model_class = JINetModel

demo(image, model_class)
87 changes: 87 additions & 0 deletions aydin/nn/models/torch/demo/jinet/n2t_2D_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# flake8: noqa
import time
import numpy
import torch

from aydin.io.datasets import (
normalise,
add_noise,
camera,
)
from aydin.nn.models.torch.torch_jinet import JINetModel
from aydin.nn.models.torch.torch_unet import n2t_train
from aydin.util.log.log import Log


def demo(image, model_class, do_add_noise=True):
"""
Demo for self-supervised denoising using camera image with synthetic noise
"""
Log.enable_output = True
Log.set_log_max_depth(8)

image = normalise(image)
image = numpy.expand_dims(image, axis=0)
image = numpy.expand_dims(image, axis=0)
noisy = add_noise(image) if do_add_noise else image
print(noisy.shape)

# noisy = torch.tensor(noisy)
image = torch.tensor(image)

model = model_class(
nb_unet_levels=2,
spacetime_ndim=2,
)

print("training starts")

start = time.time()
n2t_train(noisy, model, nb_epochs=128)
stop = time.time()
print(f"Training: elapsed time: {stop - start} ")

noisy = torch.tensor(noisy)
model.eval()
model = model.cpu()
print(f"noisy tensor shape: {noisy.shape}")
# in case of batching we have to do this:
start = time.time()
denoised = model(noisy)
stop = time.time()
print(f"inference: elapsed time: {stop - start} ")

noisy = noisy.detach().numpy()[0, 0, :, :]
image = image.detach().numpy()[0, 0, :, :]
denoised = denoised.detach().numpy()[0, 0, :, :]

image = numpy.clip(image, 0, 1)
noisy = numpy.clip(noisy, 0, 1)
denoised = numpy.clip(denoised, 0, 1)
# psnr_noisy = psnr(image, noisy)
# ssim_noisy = ssim(image, noisy)
# psnr_denoised = psnr(image, denoised)
# ssim_denoised = ssim(image, denoised)
# print("noisy :", psnr_noisy, ssim_noisy)
# print("denoised:", psnr_denoised, ssim_denoised)

import napari

viewer = napari.Viewer() # no prior setup needed
viewer.add_image(image, name='image')
viewer.add_image(noisy, name='noisy')
viewer.add_image(denoised, name='denoised')
napari.run()


if __name__ == '__main__':
# image = newyork()
# image = lizard()
# image = characters()
image = camera()
# image = pollen()
# image = dots()

model_class = JINetModel

demo(image, model_class)
34 changes: 18 additions & 16 deletions aydin/nn/models/torch/demo/unet/n2s_2D_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
normalise,
add_noise,
camera,
newyork,
)
from aydin.nn.models.torch.torch_linear_scaling_unet import LinearScalingUNetModel
from aydin.nn.models.torch.torch_res_unet import ResidualUNetModel
from aydin.nn.models.torch.torch_unet import UNetModel, n2s_train
from aydin.util.log.log import Log


def demo(image, do_add_noise=True):
def demo(image, model_class, do_add_noise=True):
"""
Demo for self-supervised denoising using camera image with synthetic noise
"""
Expand All @@ -28,15 +31,15 @@ def demo(image, do_add_noise=True):
# noisy = torch.tensor(noisy)
image = torch.tensor(image)

model = UNetModel(
model = model_class(
nb_unet_levels=2,
spacetime_ndim=2,
)

print("training starts")

start = time.time()
n2s_train(noisy, model, nb_epochs=128)
n2s_train(noisy, model, nb_epochs=256)
stop = time.time()
print(f"Training: elapsed time: {stop - start} ")

Expand Down Expand Up @@ -74,16 +77,15 @@ def demo(image, do_add_noise=True):


if __name__ == '__main__':
# newyork_image = newyork()
# demo(newyork_image, "newyork")
# lizard_image = lizard()
# demo(lizard_image, "lizard")
# characters_image = characters()
# demo(characters_image, "characters")

camera_image = camera()
demo(camera_image, "camera")
# pollen_image = pollen()
# demo(pollen_image, "pollen")
# dots_image = dots()
# demo(dots_image, "dots")
image = newyork()[256 : 256 + 512, 256 : 256 + 512]
# image = lizard()
# image = characters()
# image = camera()
# image = pollen()
# image = dots()

model_class = UNetModel
# model_class = ResidualUNetModel
# model_class = LinearScalingUNetModel

demo(image, model_class)
63 changes: 34 additions & 29 deletions aydin/nn/models/torch/demo/unet/n2t_2D_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@
import time
import numpy
import torch
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

from aydin.io.datasets import (
normalise,
add_noise,
dots,
camera,
)
from aydin.nn.models.torch.torch_linear_scaling_unet import LinearScalingUNetModel
from aydin.nn.models.torch.torch_res_unet import ResidualUNetModel
from aydin.nn.models.torch.torch_unet import n2t_train
from aydin.nn.models.torch.torch_unet import UNetModel, n2t_train
from aydin.util.log.log import Log


def demo(image, do_add_noise=True):
def demo(image, model_class, do_add_noise=True):
"""
Demo for self-supervised denoising using camera image with synthetic noise
"""
Expand All @@ -26,19 +25,27 @@ def demo(image, do_add_noise=True):
image = numpy.expand_dims(image, axis=0)
image = numpy.expand_dims(image, axis=0)
noisy = add_noise(image) if do_add_noise else image
print(noisy.shape)

noisy = torch.tensor(noisy)
# noisy = torch.tensor(noisy)
image = torch.tensor(image)

model = ResidualUNetModel(nb_unet_levels=2, supervised=True, spacetime_ndim=2)
model = model_class(
nb_unet_levels=2,
spacetime_ndim=2,
)

print("training starts")

start = time.time()
n2t_train(noisy, image, model)
n2t_train(noisy, model, nb_epochs=128)
stop = time.time()
print(f"Training: elapsed time: {stop - start} ")

noisy = torch.tensor(noisy)
model.eval()
model = model.cpu()
print(f"noisy tensor shape: {noisy.shape}")
# in case of batching we have to do this:
start = time.time()
denoised = model(noisy)
Expand All @@ -52,12 +59,12 @@ def demo(image, do_add_noise=True):
image = numpy.clip(image, 0, 1)
noisy = numpy.clip(noisy, 0, 1)
denoised = numpy.clip(denoised, 0, 1)
psnr_noisy = psnr(image, noisy)
ssim_noisy = ssim(image, noisy)
psnr_denoised = psnr(image, denoised)
ssim_denoised = ssim(image, denoised)
print("noisy :", psnr_noisy, ssim_noisy)
print("denoised:", psnr_denoised, ssim_denoised)
# psnr_noisy = psnr(image, noisy)
# ssim_noisy = ssim(image, noisy)
# psnr_denoised = psnr(image, denoised)
# ssim_denoised = ssim(image, denoised)
# print("noisy :", psnr_noisy, ssim_noisy)
# print("denoised:", psnr_denoised, ssim_denoised)

import napari

Expand All @@ -68,18 +75,16 @@ def demo(image, do_add_noise=True):
napari.run()


# NOT Working
# newyork_image = newyork()
# demo(newyork_image, "newyork")
# lizard_image = lizard()
# demo(lizard_image, "lizard")
# characters_image = characters()
# demo(characters_image, "characters")

# Working
# camera_image = camera()
# demo(camera_image, "camera")
# pollen_image = pollen()
# demo(pollen_image, "pollen")
dots_image = dots()
demo(dots_image, "dots")
if __name__ == '__main__':
# image = newyork()
# image = lizard()
# image = characters()
image = camera()
# image = pollen()
# image = dots()

model_class = UNetModel
# model_class = ResidualUNetModel
# model_class = LinearScalingUNetModel

demo(image, model_class)
File renamed without changes.
Loading

0 comments on commit a1a82cb

Please sign in to comment.