diff --git a/discoart/helper.py b/discoart/helper.py
index de54b87..4f9f762 100644
--- a/discoart/helper.py
+++ b/discoart/helper.py
@@ -176,7 +176,7 @@ class NOP:
def __call__(self, *args, **kwargs):
return NOP()
- __getattr__ = __enter__ = __exit__ = __call__
+ __getattr__ = __enter__ = __exit__ = __iadd__ = __add__ = __call__
if is_jupyter():
from IPython import display as dp1
@@ -200,6 +200,8 @@ def __call__(self, *args, **kwargs):
nondefault_config_handle = HTML()
all_config_handle = HTML()
+ completed_handle = HTML()
+ completed_handle.value = '
Completed images will be displayed below
'
code_snippet_handle = Textarea(rows=20)
tab = Tab()
tab.children = [
@@ -207,9 +209,16 @@ def __call__(self, *args, **kwargs):
nondefault_config_handle,
all_config_handle,
code_snippet_handle,
+ completed_handle,
]
for idx, j in enumerate(
- ('Preview', 'Non-default config', 'Full config', 'Code snippet')
+ (
+ 'Preview',
+ 'Non-default config',
+ 'Full config',
+ 'Code snippet',
+ 'Completed',
+ )
):
tab.set_title(idx, j)
@@ -218,6 +227,7 @@ def __call__(self, *args, **kwargs):
config=nondefault_config_handle,
all_config=all_config_handle,
code=code_snippet_handle,
+ completed=completed_handle,
progress=pg_bar,
)
diff --git a/discoart/persist.py b/discoart/persist.py
index 3a49c01..697a39d 100644
--- a/discoart/persist.py
+++ b/discoart/persist.py
@@ -80,8 +80,12 @@ def _sample(
_display_html.append(f'')
+ if cur_t == -1:
+ _handlers.completed.value += f'
seed: {da[k].tags["seed"]}
'
+
if is_display_step:
_handlers.preview.value = '
\n'.join(_display_html)
+
logger.debug('sample and plot is done')
is_sampling_done.set()
diff --git a/discoart/runner.py b/discoart/runner.py
index 0c71d2c..631753e 100644
--- a/discoart/runner.py
+++ b/discoart/runner.py
@@ -358,8 +358,10 @@ def cond_fn(x, t, **kwargs):
)
free_memory()
- _da = [Document(tags=copy.deepcopy(vars(args))) for _ in range(args.batch_size)]
- _da_gif = [Document() for _ in range(args.batch_size)]
+ _da = DocumentArray(
+ [Document(tags=copy.deepcopy(vars(args))) for _ in range(args.batch_size)]
+ )
+ _da_gif = DocumentArray([Document() for _ in range(args.batch_size)])
da_batches.extend(_da)
cur_t = diffusion.num_timesteps - skip_steps - 1