diff --git a/aydin/gui/_qt/custom_widgets/denoise_tab_pretrained_method.py b/aydin/gui/_qt/custom_widgets/denoise_tab_pretrained_method.py new file mode 100644 index 00000000..f6fcdbc0 --- /dev/null +++ b/aydin/gui/_qt/custom_widgets/denoise_tab_pretrained_method.py @@ -0,0 +1,47 @@ +from qtpy.QtCore import Qt +from qtpy.QtWidgets import QWidget, QHBoxLayout, QVBoxLayout, QLabel, QScrollArea + +from aydin.gui._qt.custom_widgets.vertical_line_break_widget import ( + QVerticalLineBreakWidget, +) + + +class DenoiseTabPretrainedMethodWidget(QWidget): + def __init__(self, parent, loaded_it): + super(QWidget, self).__init__(parent) + + self.parent = parent + self.loaded_it = loaded_it + self.description = f"This is a pretrained model, namely uses the image translator: {loaded_it.__class__.__name__}, will not train anything new but will quickly infer on the images of your choice." + + # Widget layout + self.layout = QHBoxLayout() + self.tab_method_layout = QVBoxLayout() + self.tab_method_layout.setAlignment(Qt.AlignTop) + + # Description Label + self.description_scroll = QScrollArea() + self.description_scroll.setStyleSheet("QScrollArea {border: none;}") + self.description_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.description_scroll.setAlignment(Qt.AlignTop) + self.description_label = QLabel(self.description) + self.description_label.setWordWrap(True) + + self.description_label.setTextFormat(Qt.RichText) + self.description_label.setOpenExternalLinks(True) + + self.description_label.setAlignment(Qt.AlignTop) + self.description_scroll.setWidget(self.description_label) + self.description_scroll.setWidgetResizable(True) + self.description_scroll.setMinimumHeight(300) + + self.tab_method_layout.addWidget(self.description_scroll) + + self.right_side_vlayout = QVBoxLayout() + self.right_side_vlayout.setAlignment(Qt.AlignTop) + + self.layout.addLayout(self.tab_method_layout, 35) + self.layout.addWidget(QVerticalLineBreakWidget(self)) + self.layout.addLayout(self.right_side_vlayout, 50) + + self.setLayout(self.layout) diff --git a/aydin/gui/gui.py b/aydin/gui/gui.py index d8ee6c7d..5b5d7f61 100644 --- a/aydin/gui/gui.py +++ b/aydin/gui/gui.py @@ -90,6 +90,13 @@ def setupMenubar(self): ) runMenu.addAction(saveOptionsJSONButton) + loadPretrainedModelButton = QAction('Load Pretrained Model', self) + loadPretrainedModelButton.setStatusTip('Load Pretrained Model') + loadPretrainedModelButton.triggered.connect( + lambda: self.main_widget.load_pretrained_model() + ) + runMenu.addAction(loadPretrainedModelButton) + # Preferences Menu self.basicModeButton = QAction('Basic mode', self) self.basicModeButton.setEnabled(False) diff --git a/aydin/gui/main_page.py b/aydin/gui/main_page.py index 0d36995d..da03ad23 100644 --- a/aydin/gui/main_page.py +++ b/aydin/gui/main_page.py @@ -5,6 +5,7 @@ QVBoxLayout, QHBoxLayout, QPushButton, + QFileDialog, QTabWidget, QApplication, QStyle, @@ -317,6 +318,16 @@ def save_options_json(self, path=None): for path in image_paths: save_any_json(args_dict, path) + def load_pretrained_model(self): + options = QFileDialog.Options() + options |= QFileDialog.DontUseNativeDialog + files, _ = QFileDialog.getOpenFileNames( + self, "Open File(s)", "", "All Files (*)", options=options + ) + + if files: + self.tabs["Denoise"].load_pretrained_model(pretrained_model_files=files) + def filestab_changed(self): self.tabs["File(s)"].on_data_model_update() diff --git a/aydin/gui/tabs/qt/denoise.py b/aydin/gui/tabs/qt/denoise.py index 6511dc00..88fc4c29 100644 --- a/aydin/gui/tabs/qt/denoise.py +++ b/aydin/gui/tabs/qt/denoise.py @@ -1,3 +1,6 @@ +import os +import shutil + from qtpy.QtCore import Qt from qtpy.QtWidgets import ( QWidget, @@ -8,10 +11,14 @@ ) from aydin.gui._qt.custom_widgets.denoise_tab_method import DenoiseTabMethodWidget +from aydin.gui._qt.custom_widgets.denoise_tab_pretrained_method import ( + DenoiseTabPretrainedMethodWidget, +) from aydin.gui._qt.custom_widgets.horizontal_line_break_widget import ( QHorizontalLineBreakWidget, ) from aydin.gui._qt.custom_widgets.readmoreless_label import QReadMoreLessLabel +from aydin.it.base import ImageTranslatorBase from aydin.restoration.denoise.util.denoise_utils import ( get_list_of_denoiser_implementations, ) @@ -57,6 +64,8 @@ def __init__(self, parent): self.leftlist = QListWidget() + self.loaded_backends = [] + ( backend_options, backend_options_descriptions, @@ -157,3 +166,37 @@ def set_advanced_enabled(self, enable: bool = False): widget_index ).constructor_arguments_widget_dict.items(): constructor_arguments_widget.set_advanced_enabled(enable=enable) + + self.refresh_pretrained_backends() + + def load_pretrained_model(self, pretrained_model_files): + """ + + Parameters + ---------- + pretrained_model_files : list + list of paths to the loaded pretrained model files + + """ + for file in pretrained_model_files: + shutil.unpack_archive(file, os.path.dirname(file), "zip") + self.loaded_backends.append(ImageTranslatorBase.load(file[:-4])) + shutil.rmtree(file[:-4]) + + self.refresh_pretrained_backends() + + def refresh_pretrained_backends(self): + + for index in range(self.leftlist.count() - 1, -1, -1): + if "pretrained" in self.leftlist.item(index).text(): + self.leftlist.takeItem(index) + self.stacked_widget.removeWidget(self.stacked_widget.widget(index)) + + for index, option in enumerate(self.loaded_backends): + self.leftlist.insertItem( + self.leftlist.count() + index, f"pretrained-{index}" + ) + + self.stacked_widget.addWidget( + DenoiseTabPretrainedMethodWidget(self, option) + ) diff --git a/aydin/gui/tabs/qt/images.py b/aydin/gui/tabs/qt/images.py index f59fa7ad..8e08affe 100644 --- a/aydin/gui/tabs/qt/images.py +++ b/aydin/gui/tabs/qt/images.py @@ -1,4 +1,3 @@ -import pathlib import numpy from qtpy.QtCore import Qt, Slot from qtpy.QtWidgets import ( @@ -134,14 +133,7 @@ def on_data_model_update(self): self.image_list_tree_widget.clear() - for ( - filename, - array, - metadata, - denoise, - path, - output_folder, - ) in imagelist: + for (filename, array, metadata, denoise, path, output_folder) in imagelist: qtree_widget_item = QTreeWidgetItem( self.image_list_tree_widget, [ diff --git a/aydin/nn/models/torch/test/test_torch_models.py b/aydin/nn/models/torch/test/test_torch_models.py index 79812802..ea80cc4c 100644 --- a/aydin/nn/models/torch/test/test_torch_models.py +++ b/aydin/nn/models/torch/test/test_torch_models.py @@ -50,26 +50,14 @@ def test_supervised_2D_n2t(): reload_best_model_period = 1024 best_val_loss_value = math.inf - dataset = TorchDataset( - input_image, - lizard_image, - 64, - self_supervised=False, - ) + dataset = TorchDataset(input_image, lizard_image, 64, self_supervised=False) data_loader = DataLoader( - dataset, - batch_size=1, - shuffle=True, - num_workers=0, - pin_memory=True, + dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True ) model = UNetModel( - nb_unet_levels=2, - supervised=True, - spacetime_ndim=2, - residual=True, + nb_unet_levels=2, supervised=True, spacetime_ndim=2, residual=True ) n2t_unet_train_loop(input_image, lizard_image, model, data_loader) diff --git a/aydin/nn/models/utils/torch_dataset.py b/aydin/nn/models/utils/torch_dataset.py index 2574a24e..ff855d9b 100644 --- a/aydin/nn/models/utils/torch_dataset.py +++ b/aydin/nn/models/utils/torch_dataset.py @@ -13,10 +13,7 @@ def __init__(self, input_image, target_image, tilesize, self_supervised=False): def extract(image): return extract_tiles( - image, - tile_size=tilesize, - extraction_step=tilesize, - flatten=True, + image, tile_size=tilesize, extraction_step=tilesize, flatten=True ) bc_flat_input_image = input_image.reshape(-1, *input_image.shape[2:])