Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GUI: Load pretrained models #97

Merged
47 changes: 47 additions & 0 deletions aydin/gui/_qt/custom_widgets/denoise_tab_pretrained_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from qtpy.QtCore import Qt
from qtpy.QtWidgets import QWidget, QHBoxLayout, QVBoxLayout, QLabel, QScrollArea

from aydin.gui._qt.custom_widgets.vertical_line_break_widget import (
QVerticalLineBreakWidget,
)


class DenoiseTabPretrainedMethodWidget(QWidget):
def __init__(self, parent, loaded_it):
super(QWidget, self).__init__(parent)

self.parent = parent
self.loaded_it = loaded_it
self.description = f"This is a pretrained model, namely uses the image translator: {loaded_it.__class__.__name__}, will not train anything new but will quickly infer on the images of your choice."

# Widget layout
self.layout = QHBoxLayout()
self.tab_method_layout = QVBoxLayout()
self.tab_method_layout.setAlignment(Qt.AlignTop)

# Description Label
self.description_scroll = QScrollArea()
self.description_scroll.setStyleSheet("QScrollArea {border: none;}")
self.description_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
self.description_scroll.setAlignment(Qt.AlignTop)
self.description_label = QLabel(self.description)
self.description_label.setWordWrap(True)

self.description_label.setTextFormat(Qt.RichText)
self.description_label.setOpenExternalLinks(True)

self.description_label.setAlignment(Qt.AlignTop)
self.description_scroll.setWidget(self.description_label)
self.description_scroll.setWidgetResizable(True)
self.description_scroll.setMinimumHeight(300)

self.tab_method_layout.addWidget(self.description_scroll)

self.right_side_vlayout = QVBoxLayout()
self.right_side_vlayout.setAlignment(Qt.AlignTop)

self.layout.addLayout(self.tab_method_layout, 35)
self.layout.addWidget(QVerticalLineBreakWidget(self))
self.layout.addLayout(self.right_side_vlayout, 50)

self.setLayout(self.layout)
7 changes: 7 additions & 0 deletions aydin/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ def setupMenubar(self):
)
runMenu.addAction(saveOptionsJSONButton)

loadPretrainedModelButton = QAction('Load Pretrained Model', self)
loadPretrainedModelButton.setStatusTip('Load Pretrained Model')
loadPretrainedModelButton.triggered.connect(
lambda: self.main_widget.load_pretrained_model()
)
runMenu.addAction(loadPretrainedModelButton)

# Preferences Menu
self.basicModeButton = QAction('Basic mode', self)
self.basicModeButton.setEnabled(False)
Expand Down
11 changes: 11 additions & 0 deletions aydin/gui/main_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
QVBoxLayout,
QHBoxLayout,
QPushButton,
QFileDialog,
QTabWidget,
QApplication,
QStyle,
Expand Down Expand Up @@ -317,6 +318,16 @@ def save_options_json(self, path=None):
for path in image_paths:
save_any_json(args_dict, path)

def load_pretrained_model(self):
options = QFileDialog.Options()
options |= QFileDialog.DontUseNativeDialog
files, _ = QFileDialog.getOpenFileNames(
self, "Open File(s)", "", "All Files (*)", options=options
)

if files:
self.tabs["Denoise"].load_pretrained_model(pretrained_model_files=files)

def filestab_changed(self):
self.tabs["File(s)"].on_data_model_update()

Expand Down
43 changes: 43 additions & 0 deletions aydin/gui/tabs/qt/denoise.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
import shutil

from qtpy.QtCore import Qt
from qtpy.QtWidgets import (
QWidget,
Expand All @@ -8,10 +11,14 @@
)

from aydin.gui._qt.custom_widgets.denoise_tab_method import DenoiseTabMethodWidget
from aydin.gui._qt.custom_widgets.denoise_tab_pretrained_method import (
DenoiseTabPretrainedMethodWidget,
)
from aydin.gui._qt.custom_widgets.horizontal_line_break_widget import (
QHorizontalLineBreakWidget,
)
from aydin.gui._qt.custom_widgets.readmoreless_label import QReadMoreLessLabel
from aydin.it.base import ImageTranslatorBase
from aydin.restoration.denoise.util.denoise_utils import (
get_list_of_denoiser_implementations,
)
Expand Down Expand Up @@ -57,6 +64,8 @@ def __init__(self, parent):

self.leftlist = QListWidget()

self.loaded_backends = []

(
backend_options,
backend_options_descriptions,
Expand Down Expand Up @@ -157,3 +166,37 @@ def set_advanced_enabled(self, enable: bool = False):
widget_index
).constructor_arguments_widget_dict.items():
constructor_arguments_widget.set_advanced_enabled(enable=enable)

self.refresh_pretrained_backends()

def load_pretrained_model(self, pretrained_model_files):
"""

Parameters
----------
pretrained_model_files : list
list of paths to the loaded pretrained model files

"""
for file in pretrained_model_files:
shutil.unpack_archive(file, os.path.dirname(file), "zip")
self.loaded_backends.append(ImageTranslatorBase.load(file[:-4]))
shutil.rmtree(file[:-4])

self.refresh_pretrained_backends()

def refresh_pretrained_backends(self):

for index in range(self.leftlist.count() - 1, -1, -1):
if "pretrained" in self.leftlist.item(index).text():
self.leftlist.takeItem(index)
self.stacked_widget.removeWidget(self.stacked_widget.widget(index))

for index, option in enumerate(self.loaded_backends):
self.leftlist.insertItem(
self.leftlist.count() + index, f"pretrained-{index}"
)

self.stacked_widget.addWidget(
DenoiseTabPretrainedMethodWidget(self, option)
)
10 changes: 1 addition & 9 deletions aydin/gui/tabs/qt/images.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pathlib
import numpy
from qtpy.QtCore import Qt, Slot
from qtpy.QtWidgets import (
Expand Down Expand Up @@ -134,14 +133,7 @@ def on_data_model_update(self):

self.image_list_tree_widget.clear()

for (
filename,
array,
metadata,
denoise,
path,
output_folder,
) in imagelist:
for (filename, array, metadata, denoise, path, output_folder) in imagelist:
qtree_widget_item = QTreeWidgetItem(
self.image_list_tree_widget,
[
Expand Down
18 changes: 3 additions & 15 deletions aydin/nn/models/torch/test/test_torch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,14 @@ def test_supervised_2D_n2t():
reload_best_model_period = 1024
best_val_loss_value = math.inf

dataset = TorchDataset(
input_image,
lizard_image,
64,
self_supervised=False,
)
dataset = TorchDataset(input_image, lizard_image, 64, self_supervised=False)

data_loader = DataLoader(
dataset,
batch_size=1,
shuffle=True,
num_workers=0,
pin_memory=True,
dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True
)

model = UNetModel(
nb_unet_levels=2,
supervised=True,
spacetime_ndim=2,
residual=True,
nb_unet_levels=2, supervised=True, spacetime_ndim=2, residual=True
)

n2t_unet_train_loop(input_image, lizard_image, model, data_loader)
Expand Down
5 changes: 1 addition & 4 deletions aydin/nn/models/utils/torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ def __init__(self, input_image, target_image, tilesize, self_supervised=False):

def extract(image):
return extract_tiles(
image,
tile_size=tilesize,
extraction_step=tilesize,
flatten=True,
image, tile_size=tilesize, extraction_step=tilesize, flatten=True
)

bc_flat_input_image = input_image.reshape(-1, *input_image.shape[2:])
Expand Down