Skip to content

Commit

Permalink
BUG: GUI pretrained model usage (#102)
Browse files Browse the repository at this point in the history
* initial commit

* to keep changes

* to keep changes

* tested, working

* flake8 fix
  • Loading branch information
AhmetCanSolak committed Feb 12, 2022
1 parent 5f2b975 commit 2530bf5
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 28 deletions.
25 changes: 24 additions & 1 deletion aydin/gui/_qt/custom_widgets/denoise_tab_pretrained_method.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from qtpy.QtCore import Qt
from qtpy.QtWidgets import QWidget, QHBoxLayout, QVBoxLayout, QLabel, QScrollArea
from qtpy.QtWidgets import (
QWidget,
QHBoxLayout,
QVBoxLayout,
QLabel,
QScrollArea,
QCheckBox,
)

from aydin.gui._qt.custom_widgets.vertical_line_break_widget import (
QVerticalLineBreakWidget,
Expand All @@ -12,6 +19,7 @@ def __init__(self, parent, loaded_it):

self.parent = parent
self.loaded_it = loaded_it
self.name = loaded_it.__class__.__name__
self.description = f"This is a pretrained model, namely uses the image translator: {loaded_it.__class__.__name__}, will not train anything new but will quickly infer on the images of your choice."

# Widget layout
Expand Down Expand Up @@ -40,6 +48,21 @@ def __init__(self, parent, loaded_it):
self.right_side_vlayout = QVBoxLayout()
self.right_side_vlayout.setAlignment(Qt.AlignTop)

# Checkboxes
self.save_json_and_model_layout = QHBoxLayout()
self.save_json_and_model_layout.setAlignment(Qt.AlignLeft)

self.save_json_checkbox = QCheckBox("Save denoising options (JSON)")
self.save_json_checkbox.setChecked(True)
self.save_json_and_model_layout.addWidget(self.save_json_checkbox)
self.save_json_and_model_layout.addWidget(QVerticalLineBreakWidget(self))

self.save_model_checkbox = QCheckBox("Save the trained model")
self.save_model_checkbox.setChecked(True)
self.save_json_and_model_layout.addWidget(self.save_model_checkbox)

self.right_side_vlayout.addLayout(self.save_json_and_model_layout)

self.layout.addLayout(self.tab_method_layout, 35)
self.layout.addWidget(QVerticalLineBreakWidget(self))
self.layout.addLayout(self.right_side_vlayout, 50)
Expand Down
70 changes: 46 additions & 24 deletions aydin/gui/_qt/job_runners/denoise_job_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from qtpy.QtWidgets import QWidget, QHBoxLayout

from aydin.gui._qt.custom_widgets.denoise_tab_pretrained_method import (
DenoiseTabPretrainedMethodWidget,
)
from aydin.gui._qt.output_wrapper import OutputWrapper
from aydin.gui._qt.job_runners.worker import Worker
from aydin.io.io import imwrite
Expand All @@ -8,7 +11,10 @@
get_options_json_path,
get_save_model_path,
)
from aydin.restoration.denoise.util.denoise_utils import get_denoiser_class_instance
from aydin.restoration.denoise.util.denoise_utils import (
get_denoiser_class_instance,
get_pretrained_denoiser_class_instance,
)
from aydin.util.log.log import Log, lprint


Expand Down Expand Up @@ -48,9 +54,12 @@ def start_func(self, progress_callback):
self.image_paths,
self.output_folders,
):
self.denoiser.train(
training_image, batch_axes=self.batch_axes, chan_axes=self.channel_axes
)
if not self.pretrained:
self.denoiser.train(
training_image,
batch_axes=self.batch_axes,
chan_axes=self.channel_axes,
)

if self.denoiser.it:
denoised = self.denoiser.denoise(
Expand Down Expand Up @@ -127,29 +136,42 @@ def prep_and_run(self):
self.channel_axes = self.parent.tabs["Dimensions"].channel_axes
self.denoise_backend = self.parent.tabs["Denoise"].selected_backend

try:
self.it_transforms = self.parent.tabs["Pre/Post-Processing"].transforms
self.lower_level_args = self.parent.tabs["Denoise"].lower_level_args
except Exception:
self.parent.status_bar.showMessage(
"There is a mistake with given parameter values..."
)
return
self.pretrained = (
self.parent.tabs["Denoise"].current_backend_widget.__class__
is DenoiseTabPretrainedMethodWidget
)

self.save_options_json = self.parent.tabs[
"Denoise"
].current_backend_widget.save_json_checkbox.isChecked()
self.it_transforms = self.parent.tabs["Pre/Post-Processing"].transforms

self.save_model = self.parent.tabs[
"Denoise"
].current_backend_widget.save_model_checkbox.isChecked()
if self.pretrained:
self.denoiser = get_pretrained_denoiser_class_instance(
self.parent.tabs["Denoise"].current_backend_widget.loaded_it
)
self.save_options_json = False
self.save_model = False
else:
self.save_options_json = self.parent.tabs[
"Denoise"
].current_backend_widget.save_json_checkbox.isChecked()

self.save_model = self.parent.tabs[
"Denoise"
].current_backend_widget.save_model_checkbox.isChecked()

try:
lower_level_args = self.parent.tabs["Denoise"].lower_level_args
except Exception:
self.parent.status_bar.showMessage(
"There is a mistake with given parameter values..."
)
return

# Initialize required restoration instances
self.denoiser = get_denoiser_class_instance(
variant=self.denoise_backend,
lower_level_args=self.lower_level_args,
it_transforms=self.it_transforms,
)
# Initialize required restoration instances
self.denoiser = get_denoiser_class_instance(
variant=self.denoise_backend,
lower_level_args=lower_level_args,
it_transforms=self.it_transforms,
)

Log.gui_statusbar = self.parent.parent.statusBar

Expand Down
2 changes: 0 additions & 2 deletions aydin/restoration/denoise/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from abc import abstractmethod, ABC
from pathlib import Path

import jsonpickle

from aydin.it.base import ImageTranslatorBase
from aydin.util.log.log import lprint

Expand Down
1 change: 0 additions & 1 deletion aydin/restoration/denoise/test/test_saveload.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from aydin import Classic
from aydin.io.datasets import normalise, add_noise
from aydin.io.folders import get_temp_folder
from aydin.restoration.denoise.noise2selfcnn import Noise2SelfCNN
from aydin.restoration.denoise.noise2selffgr import Noise2SelfFGR


Expand Down
21 changes: 21 additions & 0 deletions aydin/restoration/denoise/util/denoise_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,28 @@
import importlib

from aydin import Classic
from aydin.restoration import denoise
from aydin.restoration.denoise.base import DenoiseRestorationBase
from aydin.restoration.denoise.noise2selfcnn import Noise2SelfCNN
from aydin.restoration.denoise.noise2selffgr import Noise2SelfFGR


def get_pretrained_denoiser_class_instance(loaded_model_it):
if "Classic" in loaded_model_it.__class__.__name__:
denoiser_class = Classic
elif "FGR" in loaded_model_it.__class__.__name__:
denoiser_class = Noise2SelfFGR
elif "CNN" in loaded_model_it.__class__.__name__:
denoiser_class = Noise2SelfCNN
else:
raise ValueError(
"Loaded model is not supported on restoration level implementations."
)

denoiser = denoiser_class()
denoiser.it = loaded_model_it

return denoiser


def get_denoiser_class_instance(variant, lower_level_args=None, it_transforms=None):
Expand Down

0 comments on commit 2530bf5

Please sign in to comment.