This commit is contained in:
Sam Khoze
2024-07-10 21:57:07 +05:30
committed by GitHub
parent d8d3c158c9
commit 900130d38e
+6 -103
View File
@@ -26,7 +26,6 @@ from scipy.io.wavfile import write
import folder_paths
from .utils import ffmpeg_path, get_audio, hash_path, validate_path, requeue_workflow, gifski_path, calculate_file_hash, strip_path
from comfy.utils import ProgressBar
from .utils import BIGMAX, DIMMAX, calculate_file_hash, get_sorted_dir_files_from_directory, get_audio, lazy_eval, hash_path, validate_path, strip_path
from .llm_node import LLM_node
from .audio_playback import PlayBackAudio
from .audio_playback import SaveAudio
@@ -153,102 +152,6 @@ def cv_frame_generator(video, force_rate, frame_load_cap, skip_first_frames,
if prev_frame is not None:
yield prev_frame
def load_video_cv(video: str, force_rate: int, force_size: str,
custom_width: int,custom_height: int, frame_load_cap: int,
skip_first_frames: int, select_every_nth: int,
meta_batch=None, unique_id=None, memory_limit_mb=None):
print(meta_batch)
if meta_batch is None or unique_id not in meta_batch.inputs:
gen = cv_frame_generator(video, force_rate, frame_load_cap, skip_first_frames,
select_every_nth, meta_batch, unique_id)
(width, height, fps, duration, total_frames, target_frame_time) = next(gen)
if meta_batch is not None:
meta_batch.inputs[unique_id] = (gen, width, height, fps, duration, total_frames, target_frame_time)
else:
(gen, width, height, fps, duration, total_frames, target_frame_time) = meta_batch.inputs[unique_id]
if memory_limit_mb is not None:
memory_limit *= 2 ** 20
else:
#TODO: verify if garbage collection should be performed here.
#leaves ~128 MB unreserved for safety
memory_limit = (psutil.virtual_memory().available + psutil.swap_memory().free) - 2 ** 27
#space required to load as f32, exist as latent with wiggle room, decode to f32
max_loadable_frames = int(memory_limit//(width*height*3*(4+4+1/10)))
if meta_batch is not None:
if meta_batch.frames_per_batch > max_loadable_frames:
raise RuntimeError(f"Meta Batch set to {meta_batch.frames_per_batch} frames but only {max_loadable_frames} can fit in memory")
gen = itertools.islice(gen, meta_batch.frames_per_batch)
else:
original_gen = gen
gen = itertools.islice(gen, max_loadable_frames)
#Some minor wizardry to eliminate a copy and reduce max memory by a factor of ~2
images = torch.from_numpy(np.fromiter(gen, np.dtype((np.float32, (height, width, 3)))))
if meta_batch is None:
try:
next(original_gen)
raise RuntimeError(f"Memory limit hit after loading {len(images)} frames. Stopping execution.")
except StopIteration:
pass
if len(images) == 0:
raise RuntimeError("No frames generated")
if force_size != "Disabled":
new_size = target_size(width, height, force_size, custom_width, custom_height)
if new_size[0] != width or new_size[1] != height:
s = images.movedim(-1,1)
s = common_upscale(s, new_size[0], new_size[1], "lanczos", "center")
images = s.movedim(1,-1)
#Setup lambda for lazy audio capture
audio = lambda : get_audio(video, skip_first_frames * target_frame_time,
frame_load_cap*target_frame_time*select_every_nth)
#Adjust target_frame_time for select_every_nth
target_frame_time *= select_every_nth
video_info = {
"source_fps": fps,
"source_frame_count": total_frames,
"source_duration": duration,
"source_width": width,
"source_height": height,
"loaded_fps": 1/target_frame_time,
"loaded_frame_count": len(images),
"loaded_duration": len(images) * target_frame_time,
"loaded_width": images.shape[2],
"loaded_height": images.shape[1],
}
print("images", type(images))
return (images, len(images), lazy_eval(audio), video_info)
class AudioData:
def __init__(self, audio_file) -> None:
# Extract the sample rate
sample_rate = audio_file.frame_rate
# Get the number of audio channels
num_channels = audio_file.channels
# Extract the audio data as a NumPy array
audio_data = np.array(audio_file.get_array_of_samples())
self.audio_data = audio_data
self.sample_rate = sample_rate
self.num_channels = num_channels
def get_channel_audio_data(self, channel: int):
if channel < 0 or channel >= self.num_channels:
raise IndexError(f"Channel '{channel}' out of range. total channels is '{self.num_channels}'.")
return self.audio_data[channel::self.num_channels]
def get_channel_fft(self, channel: int):
audio_data = self.get_channel_audio_data(channel)
return fft(audio_data)
def gen_format_widgets(video_format):
for k in video_format:
if k.endswith("_pass"):
@@ -582,7 +485,7 @@ def load_video_cv(video: str, force_rate: int, force_size: str,
images = s.movedim(1,-1)
#Setup lambda for lazy audio capture
audio = lambda : get_audio(video, skip_first_frames * target_frame_time,
audio = get_audio(video, skip_first_frames * target_frame_time,
frame_load_cap*target_frame_time*select_every_nth)
#Adjust target_frame_time for select_every_nth
target_frame_time *= select_every_nth
@@ -599,7 +502,7 @@ def load_video_cv(video: str, force_rate: int, force_size: str,
"loaded_height": images.shape[1],
}
print("images", type(images))
return (images, len(images), lazy_eval(audio), video_info)
return (images, len(images), audio, video_info)
@@ -1018,7 +921,7 @@ class DeepFuzeFaceSwap:
print(result.stderr)
audio_file = os.path.join(audio_dir,str(time.time()).replace(".","")+".wav")
torchaudio.save(audio_file,audio["waveform"],audio["sample_rate"])
torchaudio.save(audio_file,audio["waveform"][0],audio["sample_rate"])
subprocess.run(f"ffmpeg -i {faceswap_filename} -i {audio_file} -c copy {faceswap_filename.replace('.mp4','_.mp4')} -y".split())
return load_video_cv(faceswap_filename.replace('.mp4','_.mp4'),0,'Disabled',512,512,0,0,1)
@@ -1336,7 +1239,7 @@ class DeepFuzeAdavance:
output_files.append(file_path)
audio_file = os.path.join(audio_dir,str(time.time()).replace(".","")+".wav")
torchaudio.save(audio_file,audio["waveform"],audio["sample_rate"])
torchaudio.save(audio_file,audio["waveform"][0],audio["sample_rate"])
print(audio_file)
filename = os.path.join(result_dir,f"{str(time.time()).replace('.','')}.mp4")
enhanced_filename = os.path.join(result_dir,f"enhanced_{str(time.time()).replace('.','')}.mp4")
@@ -1513,7 +1416,7 @@ class TTS_generation:
language = supported_language.split("(")[1][:-1]
file_path = os.path.join(audio_path,str(time.time()).replace(".","")+".wav")
torchaudio.save(audio_file,audio["waveform"],audio["sample_rate"])
torchaudio.save(audio_file,audio["waveform"][0],audio["sample_rate"])
command = [
'python', 'tts_generation.py',
'--model', checkpoint_path_voice,
@@ -1531,7 +1434,7 @@ class TTS_generation:
print("stdout:", result.stdout)
print("stderr:", result.stderr)
audio = get_audio(file_path)
return (audio)
return (audio,)