diff --git a/discoart/helper.py b/discoart/helper.py index de54b87..4f9f762 100644 --- a/discoart/helper.py +++ b/discoart/helper.py @@ -176,7 +176,7 @@ class NOP: def __call__(self, *args, **kwargs): return NOP() - __getattr__ = __enter__ = __exit__ = __call__ + __getattr__ = __enter__ = __exit__ = __iadd__ = __add__ = __call__ if is_jupyter(): from IPython import display as dp1 @@ -200,6 +200,8 @@ def __call__(self, *args, **kwargs): nondefault_config_handle = HTML() all_config_handle = HTML() + completed_handle = HTML() + completed_handle.value = '

Completed images will be displayed below

' code_snippet_handle = Textarea(rows=20) tab = Tab() tab.children = [ @@ -207,9 +209,16 @@ def __call__(self, *args, **kwargs): nondefault_config_handle, all_config_handle, code_snippet_handle, + completed_handle, ] for idx, j in enumerate( - ('Preview', 'Non-default config', 'Full config', 'Code snippet') + ( + 'Preview', + 'Non-default config', + 'Full config', + 'Code snippet', + 'Completed', + ) ): tab.set_title(idx, j) @@ -218,6 +227,7 @@ def __call__(self, *args, **kwargs): config=nondefault_config_handle, all_config=all_config_handle, code=code_snippet_handle, + completed=completed_handle, progress=pg_bar, ) diff --git a/discoart/persist.py b/discoart/persist.py index 3a49c01..697a39d 100644 --- a/discoart/persist.py +++ b/discoart/persist.py @@ -80,8 +80,12 @@ def _sample( _display_html.append(f'step {j} minibatch {k}') + if cur_t == -1: + _handlers.completed.value += f'
seed: {da[k].tags["seed"]}
step {j} minibatch {k}
' + if is_display_step: _handlers.preview.value = '
\n'.join(_display_html) + logger.debug('sample and plot is done') is_sampling_done.set() diff --git a/discoart/runner.py b/discoart/runner.py index 0c71d2c..631753e 100644 --- a/discoart/runner.py +++ b/discoart/runner.py @@ -358,8 +358,10 @@ def cond_fn(x, t, **kwargs): ) free_memory() - _da = [Document(tags=copy.deepcopy(vars(args))) for _ in range(args.batch_size)] - _da_gif = [Document() for _ in range(args.batch_size)] + _da = DocumentArray( + [Document(tags=copy.deepcopy(vars(args))) for _ in range(args.batch_size)] + ) + _da_gif = DocumentArray([Document() for _ in range(args.batch_size)]) da_batches.extend(_da) cur_t = diffusion.num_timesteps - skip_steps - 1