Skip to content

Commit

Permalink
Allow controlnet to load input images from URL. Unify extension valid…
Browse files Browse the repository at this point in the history
…ation logic.
  • Loading branch information
rewbs committed Feb 7, 2024
1 parent df6a63d commit ee18d78
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 37 deletions.
5 changes: 2 additions & 3 deletions scripts/deforum_helpers/deforum_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,10 @@ def process_controlnet_input_frames(args, anim_args, controlnet_args, video_path
frame_path = os.path.join(args.outdir, f'controlnet_{id}_{outdir_suffix}')
os.makedirs(frame_path, exist_ok=True)

accepted_image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
if video_path and video_path.lower().endswith(accepted_image_extensions):
if video_path:
convert_image(video_path, os.path.join(frame_path, '000000000.jpg'))
print(f"Copied CN Model {id}'s single input image to inputframes folder!")
elif mask_path and mask_path.lower().endswith(accepted_image_extensions):
elif mask_path:
convert_image(mask_path, os.path.join(frame_path, '000000000.jpg'))
print(f"Copied CN Model {id}'s single input image to inputframes *mask* folder!")
else:
Expand Down
87 changes: 53 additions & 34 deletions scripts/deforum_helpers/video_audio_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import tempfile
import re
import glob
import numpy as np
import concurrent.futures
from pathlib import Path
from pkg_resources import resource_filename
Expand All @@ -35,19 +36,29 @@
from threading import Thread

def convert_image(input_path, output_path):
extension = get_extension_if_valid(input_path, ["png", "jpg", "jpeg", "bmp"])

if not extension:
return

# Read the input image
img = cv2.imread(input_path)
# Get the file extension of the output path
out_ext = os.path.splitext(output_path)[1].lower()
if input_path.startswith('http://') or input_path.startswith('https://'):
resp = requests.get(input_path, allow_redirects=True)
arr = np.asarray(bytearray(resp.content), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
else:
img = cv2.imread(input_path)

# Convert the image to the specified output format
if out_ext == ".png":
if extension == "png":
cv2.imwrite(output_path, img, [cv2.IMWRITE_PNG_COMPRESSION, 9])
elif out_ext == ".jpg" or out_ext == ".jpeg":
elif extension == "jpg" or extension == "jpeg":
cv2.imwrite(output_path, img, [cv2.IMWRITE_JPEG_QUALITY, 99])
elif out_ext == ".bmp":
elif extension == "bmp":
cv2.imwrite(output_path, img)
else:
print(f"Unsupported output format: {out_ext}")
raise ValueError(f"Unrecognized image extension: {extension}")


def get_ffmpeg_params(): # get ffmpeg params from webui's settings -> deforum tab. actual opts are set in deforum.py
f_location = opts.data.get("deforum_ffmpeg_location", find_ffmpeg_binary())
Expand Down Expand Up @@ -86,7 +97,7 @@ def vid2frames(video_path, video_in_frame_path, n=1, overwrite=True, extract_fro

video_path = clean_gradio_path_strings(video_path)
# check vid path using a function and only enter if we get True
if is_vid_path_valid(video_path):
if get_extension_if_valid(video_path, ["mov", "mpeg", "mp4", "m4v", "avi", "mpg", "webm"]):

name = get_frame_name(video_path)

Expand Down Expand Up @@ -149,36 +160,44 @@ def vid2frames(video_path, video_in_frame_path, n=1, overwrite=True, extract_fro
vidcap.release()
return video_fps

# make sure the video_path provided is an existing local file or a web URL with a supported file extension
def is_vid_path_valid(video_path):
# make sure file format is supported!
file_formats = ["mov", "mpeg", "mp4", "m4v", "avi", "mpg", "webm"]
extension = video_path.rsplit('.', 1)[-1].lower()
# vid path is actually a URL, check it
if video_path.startswith('http://') or video_path.startswith('https://'):
response = requests.head(video_path, allow_redirects=True)
extension = extension.rsplit('?', 1)[0] # remove query string before checking file format extension.
# Make sure the video_path provided is an existing local file or a web URL with a supported file extension
# If so, return the extension. If not, raise an error.
def get_extension_if_valid(path_to_check, acceptable_extensions: list[str] ) -> str:
if path_to_check.startswith('http://') or path_to_check.startswith('https://'):
# Path is actually a URL. Make sure it resolves and has a valid file extension.
response = requests.head(path_to_check, allow_redirects=True)
if response.status_code != 200:
raise ConnectionError(f"URL {path_to_check} is not valid. Response status code: {response.status_code}")

extension = path_to_check.rsplit('?', 1)[0].rsplit('.', 1)[-1] # remove query string before checking file format extension.
if extension in acceptable_extensions:
return extension

content_disposition_extension = None
content_disposition = response.headers.get('Content-Disposition')
if content_disposition and extension not in file_formats:
# Filename doesn't look valid, but perhaps the content disposition will say otherwise?
if content_disposition:
match = re.search(r'filename="?(?P<filename>[^"]+)"?', content_disposition)
if match:
extension = match.group('filename').rsplit('.', 1)[-1].lower()
if response.status_code == 404:
raise ConnectionError(f"Video URL {video_path} is not valid. Response status code: {response.status_code}")
elif response.status_code == 302:
response = requests.head(response.headers['location'], allow_redirects=True)
if response.status_code != 200:
raise ConnectionError(f"Video URL {video_path} is not valid. Response status code: {response.status_code}")
if extension not in file_formats:
raise ValueError(f"Video file {video_path} has format '{extension}', which is not supported. Supported formats are: {file_formats}")
content_disposition_extension = match.group('filename').rsplit('.', 1)[-1].lower()

if content_disposition_extension in acceptable_extensions:
return content_disposition_extension


raise ValueError(f"File {path_to_check} has format '{extension}' (from URL) or '{content_disposition_extension}' (from content disposition), which are not supported. Supported formats are: {acceptable_extensions}")

else:
video_path = os.path.realpath(video_path)
if not os.path.exists(video_path):
raise RuntimeError(f"Video path does not exist: {video_path}")
if extension not in file_formats:
raise ValueError(f"Video file {video_path} has format '{extension}', which is not supported. Supported formats are: {file_formats}")
return True
path_to_check = os.path.realpath(path_to_check)
extension = path_to_check.rsplit('.', 1)[-1].lower()

if not os.path.exists(path_to_check):
raise RuntimeError(f"Path does not exist: {path_to_check}")
if extension in acceptable_extensions:
return extension

raise ValueError(f"File {path_to_check} has format '{extension}', which is not supported. Supported formats are: {acceptable_extensions}")



# quick-retreive frame count, FPS and H/W dimensions of a video (local or URL-based)
def get_quick_vid_info(vid_path):
Expand Down

0 comments on commit ee18d78

Please sign in to comment.