Skip to content

Commit

Permalink
GUI: Load pretrained models (#97)
Browse files Browse the repository at this point in the history
* initial commit

* main_page level implemented

* to keep changes

* loading trained it instances directly

* removing the unzipped model after loading added

* populating gui implemented, working

* deleting older pretrained models before populating new ones implemented, tested, working

* bug fixed with mode toggle

* black fixes

* flake8 fixes
  • Loading branch information
AhmetCanSolak committed Feb 9, 2022
1 parent 8280af6 commit 6fcbcc6
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 28 deletions.
47 changes: 47 additions & 0 deletions aydin/gui/_qt/custom_widgets/denoise_tab_pretrained_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from qtpy.QtCore import Qt
from qtpy.QtWidgets import QWidget, QHBoxLayout, QVBoxLayout, QLabel, QScrollArea

from aydin.gui._qt.custom_widgets.vertical_line_break_widget import (
QVerticalLineBreakWidget,
)


class DenoiseTabPretrainedMethodWidget(QWidget):
def __init__(self, parent, loaded_it):
super(QWidget, self).__init__(parent)

self.parent = parent
self.loaded_it = loaded_it
self.description = f"This is a pretrained model, namely uses the image translator: {loaded_it.__class__.__name__}, will not train anything new but will quickly infer on the images of your choice."

# Widget layout
self.layout = QHBoxLayout()
self.tab_method_layout = QVBoxLayout()
self.tab_method_layout.setAlignment(Qt.AlignTop)

# Description Label
self.description_scroll = QScrollArea()
self.description_scroll.setStyleSheet("QScrollArea {border: none;}")
self.description_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
self.description_scroll.setAlignment(Qt.AlignTop)
self.description_label = QLabel(self.description)
self.description_label.setWordWrap(True)

self.description_label.setTextFormat(Qt.RichText)
self.description_label.setOpenExternalLinks(True)

self.description_label.setAlignment(Qt.AlignTop)
self.description_scroll.setWidget(self.description_label)
self.description_scroll.setWidgetResizable(True)
self.description_scroll.setMinimumHeight(300)

self.tab_method_layout.addWidget(self.description_scroll)

self.right_side_vlayout = QVBoxLayout()
self.right_side_vlayout.setAlignment(Qt.AlignTop)

self.layout.addLayout(self.tab_method_layout, 35)
self.layout.addWidget(QVerticalLineBreakWidget(self))
self.layout.addLayout(self.right_side_vlayout, 50)

self.setLayout(self.layout)
7 changes: 7 additions & 0 deletions aydin/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ def setupMenubar(self):
)
runMenu.addAction(saveOptionsJSONButton)

loadPretrainedModelButton = QAction('Load Pretrained Model', self)
loadPretrainedModelButton.setStatusTip('Load Pretrained Model')
loadPretrainedModelButton.triggered.connect(
lambda: self.main_widget.load_pretrained_model()
)
runMenu.addAction(loadPretrainedModelButton)

# Preferences Menu
self.basicModeButton = QAction('Basic mode', self)
self.basicModeButton.setEnabled(False)
Expand Down
11 changes: 11 additions & 0 deletions aydin/gui/main_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
QVBoxLayout,
QHBoxLayout,
QPushButton,
QFileDialog,
QTabWidget,
QApplication,
QStyle,
Expand Down Expand Up @@ -317,6 +318,16 @@ def save_options_json(self, path=None):
for path in image_paths:
save_any_json(args_dict, path)

def load_pretrained_model(self):
options = QFileDialog.Options()
options |= QFileDialog.DontUseNativeDialog
files, _ = QFileDialog.getOpenFileNames(
self, "Open File(s)", "", "All Files (*)", options=options
)

if files:
self.tabs["Denoise"].load_pretrained_model(pretrained_model_files=files)

def filestab_changed(self):
self.tabs["File(s)"].on_data_model_update()

Expand Down
43 changes: 43 additions & 0 deletions aydin/gui/tabs/qt/denoise.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
import shutil

from qtpy.QtCore import Qt
from qtpy.QtWidgets import (
QWidget,
Expand All @@ -8,10 +11,14 @@
)

from aydin.gui._qt.custom_widgets.denoise_tab_method import DenoiseTabMethodWidget
from aydin.gui._qt.custom_widgets.denoise_tab_pretrained_method import (
DenoiseTabPretrainedMethodWidget,
)
from aydin.gui._qt.custom_widgets.horizontal_line_break_widget import (
QHorizontalLineBreakWidget,
)
from aydin.gui._qt.custom_widgets.readmoreless_label import QReadMoreLessLabel
from aydin.it.base import ImageTranslatorBase
from aydin.restoration.denoise.util.denoise_utils import (
get_list_of_denoiser_implementations,
)
Expand Down Expand Up @@ -57,6 +64,8 @@ def __init__(self, parent):

self.leftlist = QListWidget()

self.loaded_backends = []

(
backend_options,
backend_options_descriptions,
Expand Down Expand Up @@ -157,3 +166,37 @@ def set_advanced_enabled(self, enable: bool = False):
widget_index
).constructor_arguments_widget_dict.items():
constructor_arguments_widget.set_advanced_enabled(enable=enable)

self.refresh_pretrained_backends()

def load_pretrained_model(self, pretrained_model_files):
"""
Parameters
----------
pretrained_model_files : list
list of paths to the loaded pretrained model files
"""
for file in pretrained_model_files:
shutil.unpack_archive(file, os.path.dirname(file), "zip")
self.loaded_backends.append(ImageTranslatorBase.load(file[:-4]))
shutil.rmtree(file[:-4])

self.refresh_pretrained_backends()

def refresh_pretrained_backends(self):

for index in range(self.leftlist.count() - 1, -1, -1):
if "pretrained" in self.leftlist.item(index).text():
self.leftlist.takeItem(index)
self.stacked_widget.removeWidget(self.stacked_widget.widget(index))

for index, option in enumerate(self.loaded_backends):
self.leftlist.insertItem(
self.leftlist.count() + index, f"pretrained-{index}"
)

self.stacked_widget.addWidget(
DenoiseTabPretrainedMethodWidget(self, option)
)
10 changes: 1 addition & 9 deletions aydin/gui/tabs/qt/images.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pathlib
import numpy
from qtpy.QtCore import Qt, Slot
from qtpy.QtWidgets import (
Expand Down Expand Up @@ -134,14 +133,7 @@ def on_data_model_update(self):

self.image_list_tree_widget.clear()

for (
filename,
array,
metadata,
denoise,
path,
output_folder,
) in imagelist:
for (filename, array, metadata, denoise, path, output_folder) in imagelist:
qtree_widget_item = QTreeWidgetItem(
self.image_list_tree_widget,
[
Expand Down
18 changes: 3 additions & 15 deletions aydin/nn/models/torch/test/test_torch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,14 @@ def test_supervised_2D_n2t():
reload_best_model_period = 1024
best_val_loss_value = math.inf

dataset = TorchDataset(
input_image,
lizard_image,
64,
self_supervised=False,
)
dataset = TorchDataset(input_image, lizard_image, 64, self_supervised=False)

data_loader = DataLoader(
dataset,
batch_size=1,
shuffle=True,
num_workers=0,
pin_memory=True,
dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True
)

model = UNetModel(
nb_unet_levels=2,
supervised=True,
spacetime_ndim=2,
residual=True,
nb_unet_levels=2, supervised=True, spacetime_ndim=2, residual=True
)

n2t_unet_train_loop(input_image, lizard_image, model, data_loader)
Expand Down
5 changes: 1 addition & 4 deletions aydin/nn/models/utils/torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ def __init__(self, input_image, target_image, tilesize, self_supervised=False):

def extract(image):
return extract_tiles(
image,
tile_size=tilesize,
extraction_step=tilesize,
flatten=True,
image, tile_size=tilesize, extraction_step=tilesize, flatten=True
)

bc_flat_input_image = input_image.reshape(-1, *input_image.shape[2:])
Expand Down

0 comments on commit 6fcbcc6

Please sign in to comment.