Skip to content

Commit

Permalink
animatediff reaping logic
Browse files Browse the repository at this point in the history
  • Loading branch information
kabachuha committed Nov 13, 2023
1 parent d9b6c46 commit 719aa3c
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 17 deletions.
29 changes: 25 additions & 4 deletions scripts/deforum_helpers/deforum_animatediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,14 @@ def find_animatediff_script(p):
raise Exception("AnimateDiff script not found.")
return animatediff_script

def seed_animatediff(p, animatediff_args, args, anim_args, frame_idx):
if find_animatediff() is None or not is_animatediff_enabled(animatediff_args):
def get_animatediff_temp_dir(args):
    """Return the path of the AnimateDiff temp folder located inside ``args.outdir``."""
    temp_dir_name = 'animatediff_temp'
    return os.path.join(args.outdir, temp_dir_name)

def need_animatediff(animatediff_args):
    """Tell whether the AnimateDiff steps should run for the given settings.

    Args:
        animatediff_args: AnimateDiff settings object, forwarded unchanged to
            ``is_animatediff_enabled``.

    Returns:
        ``False`` when ``find_animatediff()`` returns ``None``; otherwise the
        result of ``is_animatediff_enabled(animatediff_args)``.
    """
    # `and` short-circuits, so the enabled-check is only consulted when
    # find_animatediff() actually returned something.
    # (The stray trailing ':' that used to follow this expression was a SyntaxError.)
    return find_animatediff() is not None and is_animatediff_enabled(animatediff_args)

def seed_animatediff(p, animatediff_args, args, anim_args, root, frame_idx):
if not need_animatediff(animatediff_args):
return

keys = AnimateDiffKeys(animatediff_args, anim_args) # if not parseq_adapter.use_parseq else parseq_adapter.cn_keys
Expand All @@ -198,14 +204,14 @@ def seed_animatediff(p, animatediff_args, args, anim_args, frame_idx):

# Managing the frames to be fed into AD:
# Create a temporary directory
animatediff_temp_dir = os.path.join(args.outdir, 'animatediff_temp')
animatediff_temp_dir = get_animatediff_temp_dir(args)
if os.path.exists(animatediff_temp_dir):
shutil.rmtree(animatediff_temp_dir)
os.makedirs(animatediff_temp_dir)
# Copy the frames (except for the one which is being CN-made) into that dir
for offset in range(video_length - 1):
filename = f"{root.timestring}_{frame_idx - offset - 1:09}.png"
Image.open(os.path.join(args.outdir, filename)).save(os.path.join(f"{offset:09}.png"), "PNG")
Image.open(os.path.join(args.outdir, filename)).save(os.path.join(animatediff_temp_dir, f"{offset:09}.png"), "PNG")

animatediff_script = find_animatediff_script(p)
# let's put it before ControlNet to cause less problems
Expand Down Expand Up @@ -237,3 +243,18 @@ def seed_animatediff(p, animatediff_args, args, anim_args, frame_idx):
args = list(args_dict.values())

p.script_args_value = args + p.script_args_value

def reap_animatediff(images, args, root, frame_idx, animatediff_args=None):
    """Write the images of an AnimateDiff batch back over the saved frames.

    The last element of ``images`` is stored as frame ``frame_idx``, the one
    before it as ``frame_idx - 1`` and so on, each saved as
    ``<root.timestring>_<index:09>.png`` inside ``args.outdir``.

    Args:
        images: sequence of image objects exposing ``save(path, format)``
            (presumably PIL images in ascending frame order -- confirm
            against the caller).
        args: object providing ``outdir``.
        root: object providing ``timestring``.
        frame_idx: frame index that the last element of ``images`` maps to.
        animatediff_args: AnimateDiff settings passed to ``need_animatediff``.
            It is optional only so that existing four-argument calls keep
            working; when it is omitted nothing is reaped.

    Raises:
        FileNotFoundError: if AnimateDiff is needed but its temp directory
            does not exist.
    """
    # The settings must be handed in explicitly: this function has no other
    # way to reach them, and reading an undefined name would raise NameError.
    if animatediff_args is None or not need_animatediff(animatediff_args):
        return

    animatediff_temp_dir = get_animatediff_temp_dir(args)
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if not os.path.exists(animatediff_temp_dir):
        raise FileNotFoundError(
            f"AnimateDiff temp directory not found: {animatediff_temp_dir}"
        )

    # Walk the batch from its last image backwards so that offset 0 is the
    # current frame and larger offsets are progressively older frames.
    for offset, frame in enumerate(reversed(images)):
        cur_frame_idx = frame_idx - offset
        if cur_frame_idx < 0:
            # There is no frame before index 0; never emit a negative file name.
            break

        # overwrite the results
        filename = f"{root.timestring}_{cur_frame_idx:09}.png"
        frame.save(os.path.join(args.outdir, filename), "PNG")
2 changes: 1 addition & 1 deletion scripts/deforum_helpers/deforum_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def read_cn_data(cn_idx):
p.script_args_value = [None] * controlnet_script.args_to

# Basically, launch AD on a number of previous frames once it hits the seed time
seed_animatediff(p, animatediff_args, args, anim_args, frame_idx)
seed_animatediff(p, animatediff_args, args, anim_args, root, frame_idx)

def create_cnu_dict(cn_args, prefix, img_np, mask_np, frame_idx, CnSchKeys):

Expand Down
27 changes: 15 additions & 12 deletions scripts/deforum_helpers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from types import SimpleNamespace

from .general_utils import debug_print
from .deforum_animatediff import reap_animatediff

def load_mask_latent(mask_input, shape):
# mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object
Expand Down Expand Up @@ -70,14 +71,14 @@ def pairwise_repl(iterable):
next(b, None)
return zip(a, b)

def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
def generate(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame=0, sampler_name=None):
if state.interrupted:
return None

if args.reroll_blank_frames == 'ignore':
return generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
return generate_inner(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name)

image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name)

if caught_vae_exception or not image.getbbox():
patience = args.reroll_patience
Expand All @@ -86,7 +87,7 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada
while caught_vae_exception or not image.getbbox():
print("Rerolling with +1 seed...")
args.seed += 1
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name)
patience -= 1
if patience == 0:
print("Rerolling with +1 seed failed for 10 iterations! Try setting webui's precision to 'full' and if it fails, please report this to the devs! Interrupting...")
Expand All @@ -100,12 +101,12 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada
return None
return image

def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame=0, sampler_name=None):
if cmd_opts.disable_nan_check:
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name)
else:
try:
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name)
except Exception as e:
if "A tensor with all NaNs was produced in VAE." in repr(e):
print(e)
Expand All @@ -114,7 +115,7 @@ def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args,
raise e
return image, False

def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
def generate_inner(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame=0, sampler_name=None):
# Setup the pipeline
p = get_webui_sd_pipeline(args, root)
p.prompt, p.negative_prompt = split_weighted_subprompts(args.prompt, frame, anim_args.max_frames)
Expand Down Expand Up @@ -235,7 +236,7 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
print_combined_table(args, anim_args, p_txt, keys, frame) # print dynamic table to cli

if is_controlnet_enabled(controlnet_args):
process_with_controlnet(p_txt, args, anim_args, controlnet_args, root, parseq_adapter, is_img2img=False, frame_idx=frame)
process_with_controlnet(p_txt, args, anim_args, controlnet_args, animatediff_args, root, parseq_adapter, is_img2img=False, frame_idx=frame)

with A1111OptionsOverrider({"control_net_detectedmap_dir" : os.path.join(args.outdir, "controlnet_detected_map")}):
processed = processing.process_images(p_txt)
Expand Down Expand Up @@ -277,7 +278,7 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
processed = mock_process_images(args, p, init_image)
else:
if is_controlnet_enabled(controlnet_args):
process_with_controlnet(p, args, anim_args, controlnet_args, root, parseq_adapter, is_img2img=True, frame_idx=frame)
process_with_controlnet(p, args, anim_args, controlnet_args, animatediff_args, root, parseq_adapter, is_img2img=True, frame_idx=frame)

with A1111OptionsOverrider({"control_net_detectedmap_dir" : os.path.join(args.outdir, "controlnet_detected_map")}):
processed = processing.process_images(p)
Expand All @@ -287,9 +288,11 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
root.initial_info = processed.info

if root.first_frame is None:
root.first_frame = processed.images[0]
root.first_frame = processed.images[-1]

results = processed.images[0]
results = processed.images[-1] # AD uses ascending order, so we need to get the last frame

reap_animatediff(processed.images, args, root, frame)

return results

Expand Down

0 comments on commit 719aa3c

Please sign in to comment.