Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: It cnn file refactor #78

Merged
merged 11 commits into from
Jan 11, 2022
37 changes: 22 additions & 15 deletions aydin/it/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,18 @@ def _load_internals(self, path: str):
self.infmodel = keras.models.load_model(join(path, "tf_inf_model"))

def get_receptive_field_radius(self, nb_unet_levels, shiftconv=False):
"""TODO: add proper docstrings here

Parameters
----------
nb_unet_levels : int
shiftconv : bool

Returns
-------
int

"""
if shiftconv:
rf = 7 if nb_unet_levels == 0 else 36 * 2 ** (nb_unet_levels - 1) - 6
else:
Expand All @@ -185,10 +197,6 @@ def stop_training(self):
"""
self.stop_fitting = True

def retrain(self, input_image, target_image, training_architecture=None):
self.training_architecture = training_architecture
self.train(input_image, target_image)

def _train(
self,
input_image,
Expand Down Expand Up @@ -283,7 +291,7 @@ def _train(
+ self.batch_size
)

lprint("Available mem: ", available_device_memory())
lprint(f"Available mem: {available_device_memory()}")
lprint(f"Batch size for training: {self.batch_size}")

# Decide whether to use validation pixels or patches
Expand Down Expand Up @@ -342,7 +350,7 @@ def _train(

# Last check of input size especially for shiftconv
if 'shiftconv' == self.training_architecture and self.self_supervised:
# TODO: Hirofumi what is going on the conditional below :D <-- check input dim is compatible w/ shiftconv
# TODO: Hirofumi what is going on the conditional below <-- check input dim is compatible w/ shiftconv
if (
numpy.mod(
img_train.shape[1:][:-1],
Expand Down Expand Up @@ -404,7 +412,7 @@ def _train(
lprint(f'Batch normalization: {self.batch_norm}')
lprint(f'Training input size: {img_train.shape[1:]}')

# End of train function and beginning of _train from legacy implmentation
# End of train function and beginning of _train from legacy implementation
input_image = img_train

with lsection(
Expand Down Expand Up @@ -492,21 +500,20 @@ def _train(
self.loss_history = self.model.fit(
input_image=input_image,
target_image=target_image,
max_epochs=self.max_epochs,
callbacks=callbacks,
train_valid_ratio=train_valid_ratio,
verbose=self.verbose,
batch_size=self.batch_size,
total_num_patches=self.total_num_patches,
img_val=self.validation_images,
create_patches_for_validation=self._create_patches_for_validation,
train_valid_ratio=train_valid_ratio,
val_marker=self.validation_markers,
training_architecture=self.training_architecture,
create_patches_for_validation=self._create_patches_for_validation,
total_num_patches=self.total_num_patches,
batch_size=self.batch_size,
random_mask_ratio=self.random_mask_ratio,
patch_size=self.patch_size,
mask_size=self.mask_size,
verbose=self.verbose,
max_epochs=self.max_epochs,
ReduceLR_patience=self.ReduceLR_patience,
parent=self,
)

def _translate(self, input_image, image_slice=None, whole_image_shape=None):
Expand Down Expand Up @@ -558,12 +565,12 @@ def _translate(self, input_image, image_slice=None, whole_image_shape=None):

# Change the batch_size in split layer or input dimensions accordingly
kwargs_for_infmodel = {
'spacetime_ndim': self.spacetime_ndim,
'mini_batch_size': 1,
'nb_unet_levels': self.nb_unet_levels,
'normalization': self.batch_norm,
'activation': self.activation_fun,
'shiftconv': 'shiftconv' == self.training_architecture,
'spacetime_ndim': self.spacetime_ndim,
}

if len(input_image.shape[1:-1]) == 2:
Expand Down
2 changes: 0 additions & 2 deletions aydin/it/test/test_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from skimage.exposure import rescale_intensity
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# from tensorflow_core.python.keras.backend import clear_session
from tensorflow.python.keras.backend import clear_session

from aydin.io import io
Expand Down
35 changes: 35 additions & 0 deletions aydin/nn/models/jinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,41 @@ def fit(

return loss_history

def predict(
    self,
    x,
    batch_size=None,
    verbose=0,
    steps=None,
    callbacks=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
):
    """Overridden Keras ``Model.predict`` for the JINet model.

    Parameters
    ----------
    x
        Input data to run inference on.
    batch_size : int, optional
        Number of samples per inference batch; forwarded to Keras.
    verbose : int
        Verbosity mode; forwarded to Keras.
    steps : int, optional
        Total number of prediction steps; forwarded to Keras.
    callbacks : list, optional
        Keras callbacks applied during prediction; forwarded to Keras.
    max_queue_size : int
        Maximum generator queue size; forwarded to Keras.
    workers : int
        Number of generator workers; forwarded to Keras.
    use_multiprocessing : bool
        Whether to use process-based parallelism; forwarded to Keras.

    Returns
    -------
    Predictions as returned by ``Model.predict`` (presumably a numpy
    array -- confirm against the Keras version in use).
    """
    # TODO: move as much as you can from it cnn _translate
    # Previously only `batch_size` and `verbose` were forwarded and the
    # remaining accepted arguments were silently ignored.  Forward them
    # all so the signature is honest; the defaults mirror Keras'
    # defaults, so existing callers observe no behavior change.
    return super().predict(
        x,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
    )

def jinet_core(self, input_lyr):
dilated_conv_list = []
total_num_features = 0
Expand Down
39 changes: 37 additions & 2 deletions aydin/nn/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
self,
input_layer_size,
spacetime_ndim,
training_architecture: str = 'random',
mini_batch_size: int = 1,
nb_unet_levels: int = 4,
normalization: str = 'batch', # None, # 'instance',
Expand Down Expand Up @@ -86,6 +87,7 @@ def __init__(
"""
self.compiled = False

self.training_architecture = training_architecture
self.rot_batch_size = mini_batch_size
self.num_lyr = nb_unet_levels
self.normalization = normalization
Expand Down Expand Up @@ -424,7 +426,6 @@ def fit(
max_epochs=None,
ReduceLR_patience=None,
reduce_lr_factor=0.3,
parent=None,
):
"""

Expand All @@ -448,7 +449,6 @@ def fit(
max_epochs
ReduceLR_patience
reduce_lr_factor
parent

Returns
-------
Expand Down Expand Up @@ -593,3 +593,38 @@ def fit(
loss_history = -1

return loss_history

def predict(
    self,
    x,
    batch_size=None,
    verbose=0,
    steps=None,
    callbacks=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
):
    """Overridden Keras ``Model.predict`` for the UNet model.

    Parameters
    ----------
    x
        Input data to run inference on.
    batch_size : int, optional
        Number of samples per inference batch; forwarded to Keras.
    verbose : int
        Verbosity mode; forwarded to Keras.
    steps : int, optional
        Total number of prediction steps; forwarded to Keras.
    callbacks : list, optional
        Keras callbacks applied during prediction; forwarded to Keras.
    max_queue_size : int
        Maximum generator queue size; forwarded to Keras.
    workers : int
        Number of generator workers; forwarded to Keras.
    use_multiprocessing : bool
        Whether to use process-based parallelism; forwarded to Keras.

    Returns
    -------
    Predictions as returned by ``Model.predict`` (presumably a numpy
    array -- confirm against the Keras version in use).
    """
    # TODO: move as much as you can from it cnn _translate
    # Previously only `batch_size` and `verbose` were forwarded and the
    # remaining accepted arguments were silently ignored.  Forward them
    # all so the signature is honest; the defaults mirror Keras'
    # defaults, so existing callers observe no behavior change.
    return super().predict(
        x,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
    )
4 changes: 1 addition & 3 deletions aydin/restoration/denoise/noise2selfcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def __init__(
"""
super().__init__()
self.lower_level_args = lower_level_args
self.backend_it, self.backend_or_model = (
variant.split("-") if variant is not None else ("cnn", "jinet")
)
self.backend_it, self.backend_or_model = ("cnn", "jinet") if variant is None else variant.split("-")

self.input_model_path = input_model_path
self.use_model_flag = use_model
Expand Down