# Audio helpers of utils.py, reconstructed from the "Update utils.py" patch.
# `ffmpeg_path` and `logger` are module-level names supplied by the rest of
# utils.py (the ffmpeg discovery code and `from .logger import logger`).
import re
import subprocess
from collections.abc import Mapping

import torch

# Fallbacks used when ffmpeg's log does not describe the audio stream.
_DEFAULT_SAMPLE_RATE = 44100
_DEFAULT_CHANNELS = 2
# Matches e.g. ", 44100 Hz, stereo, " in ffmpeg's stream description.
_STREAM_FORMAT_RE = re.compile(r', (\d+) Hz, (\w+), ')
_CHANNEL_COUNTS = {"mono": 1, "stereo": 2}


def _parse_stream_format(ffmpeg_log):
    """Return ``(sample_rate, channels)`` parsed from ffmpeg's log text."""
    match = _STREAM_FORMAT_RE.search(ffmpeg_log)
    if match is None:
        return _DEFAULT_SAMPLE_RATE, _DEFAULT_CHANNELS
    #NOTE: Just throwing an error (KeyError) for other channel types right now
    #Will deal with issues if they come
    return int(match.group(1)), _CHANNEL_COUNTS[match.group(2)]


def get_audio(file, start_time=0, duration=0):
    """Decode the audio of ``file`` with ffmpeg into a float32 tensor.

    Args:
        file: Path handed to ffmpeg's ``-i`` option.
        start_time: Passed to ffmpeg as ``-ss`` when greater than 0.
        duration: Passed to ffmpeg as ``-t`` when greater than 0.

    Returns:
        A dict with ``'waveform'``, a tensor shaped ``(1, channels, frames)``,
        and ``'sample_rate'``, an int. If ffmpeg exits with an error, a
        warning is logged and a zero-filled waveform is returned instead.

    Raises:
        KeyError: ffmpeg reports a channel layout other than mono/stereo.
    """
    args = [ffmpeg_path, "-i", file]
    if start_time > 0:
        args += ["-ss", str(start_time)]
    if duration > 0:
        args += ["-t", str(duration)]
    try:
        #TODO: scan for sample rate and maintain
        res = subprocess.run(args + ["-f", "f32le", "-"],
                             capture_output=True, check=True)
    except subprocess.CalledProcessError as e:
        logger.warning(f"Failed to extract audio from: {file}")
        audio = torch.zeros(1, 2)
        # `res` is never bound when run() raises; the captured log lives on
        # the exception instead.
        stderr = e.stderr or b""
    else:
        if res.stdout:
            audio = torch.frombuffer(bytearray(res.stdout),
                                     dtype=torch.float32)
        else:
            # torch.frombuffer rejects an empty buffer
            audio = torch.zeros(0, dtype=torch.float32)
        stderr = res.stderr
    # ffmpeg's log may contain bytes that are not valid UTF-8 (file names,
    # metadata); never let that abort the extraction.
    ar, ac = _parse_stream_format(stderr.decode('utf-8', errors='replace'))
    # Explicit frame count: reshape(-1, ac) is ambiguous for empty tensors.
    frames = audio.numel() // ac
    audio = audio.reshape((frames, ac)).transpose(0, 1).unsqueeze(0)
    return {'waveform': audio, 'sample_rate': ar}


class LazyAudioMap(Mapping):
    """Read-only mapping that defers ``get_audio`` until first access."""

    def __init__(self, file, start_time, duration):
        self.file = file
        self.start_time = start_time
        self.duration = duration
        # Filled with the result of get_audio on first use.
        self._dict = None

    def _load(self):
        """Run get_audio once and return the cached result dict."""
        if self._dict is None:
            self._dict = get_audio(self.file, self.start_time, self.duration)
        return self._dict

    def __getitem__(self, key):
        return self._load()[key]

    def __iter__(self):
        return iter(self._load())

    def __len__(self):
        return len(self._load())


def lazy_get_audio(file, start_time=0, duration=0):
    """Return a LazyAudioMap that extracts audio only when it is read."""
    return LazyAudioMap(file, start_time, duration)


def is_url(url):
    """Return True when ``url`` starts with an ``http://`` or ``https://`` scheme."""
    return url.split("://")[0] in ["http", "https"]