Update utils.py
This commit is contained in:
@@ -4,8 +4,11 @@ from typing import Iterable
|
||||
import shutil
|
||||
import subprocess
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
import torch
|
||||
|
||||
import server
|
||||
from .logger import logger
|
||||
|
||||
BIGMIN = -(2**53-1)
|
||||
BIGMAX = (2**53-1)
|
||||
@@ -45,7 +48,7 @@ else:
|
||||
except:
|
||||
if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
|
||||
raise
|
||||
print("Failed to import imageio_ffmpeg")
|
||||
logger.warn("Failed to import imageio_ffmpeg")
|
||||
if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
|
||||
ffmpeg_path = imageio_ffmpeg_path
|
||||
else:
|
||||
@@ -57,7 +60,7 @@ else:
|
||||
if os.path.isfile("ffmpeg.exe"):
|
||||
ffmpeg_paths.append(os.path.abspath("ffmpeg.exe"))
|
||||
if len(ffmpeg_paths) == 0:
|
||||
print("No valid ffmpeg found.")
|
||||
logger.error("No valid ffmpeg found.")
|
||||
ffmpeg_path = None
|
||||
elif len(ffmpeg_paths) == 1:
|
||||
#Evaluation of suitability isn't required, can take sole option
|
||||
@@ -153,32 +156,51 @@ def requeue_workflow(requeue_required=(-1,True)):
|
||||
requeue_workflow_unchecked()
|
||||
|
||||
def get_audio(file, start_time=0, duration=0):
    """Decode the audio of *file* with ffmpeg into a float32 waveform tensor.

    ffmpeg is asked to write raw ``f32le`` samples to stdout. The sample rate
    and channel layout are parsed out of the stream description that ffmpeg
    prints on stderr (which is why ``-v error`` is not passed).

    Args:
        file: Input handed to ffmpeg after ``-i``.
        start_time: Seek offset passed as ``-ss``; only used when > 0.
        duration: Length passed as ``-t``; only used when > 0.

    Returns:
        ``{'waveform': tensor, 'sample_rate': int}``. The tensor is reshaped
        to ``(1, channels, samples)``. When ffmpeg exits with an error a
        warning is logged and a small all-zero tensor is used instead.

    Raises:
        KeyError: If ffmpeg reports a channel layout other than ``mono`` or
            ``stereo``.
    """
    args = [ffmpeg_path, "-i", file]
    if start_time > 0:
        args += ["-ss", str(start_time)]
    if duration > 0:
        args += ["-t", str(duration)]
    try:
        #TODO: scan for sample rate and maintain
        res = subprocess.run(args + ["-f", "f32le", "-"],
                             capture_output=True, check=True)
        audio = torch.frombuffer(bytearray(res.stdout), dtype=torch.float32)
        stderr = res.stderr
    except subprocess.CalledProcessError as e:
        logger.warning(f"Failed to extract audio from: {file}")
        audio = torch.zeros(1,2)
        # run() raised before `res` was assigned, so the captured stderr is
        # only reachable through the exception object.
        stderr = e.stderr or b""
    # ffmpeg may echo metadata that is not valid UTF-8; never fail on decode.
    match = re.search(', (\\d+) Hz, (\\w+), ', stderr.decode('utf-8', errors='replace'))
    if match:
        ar = int(match.group(1))
        #NOTE: Just throwing an error for other channel types right now
        #Will deal with issues if they come
        ac = {"mono": 1, "stereo": 2}[match.group(2)]
    else:
        # No stream description found: fall back to 44.1 kHz stereo.
        ar = 44100
        ac = 2
    # Interleaved samples -> (samples, channels) -> (channels, samples) -> add batch dim.
    audio = audio.reshape((-1,ac)).transpose(0,1).unsqueeze(0)
    return {'waveform': audio, 'sample_rate': ar}


def lazy_eval(func):
    """Wrap *func* in a zero-argument callable that caches its result.

    The wrapped function is called on first use and its result is reused on
    later calls. A result of ``None`` is not cached: ``None`` marks "not yet
    computed", so *func* is called again.
    """
    class Cache:
        def __init__(self, func):
            self.res = None
            self.func = func
        def get(self):
            if self.res is None:
                self.res = self.func()
            return self.res
    cache = Cache(func)
    return lambda : cache.get()
|
||||
|
||||
class LazyAudioMap(Mapping):
    """Read-only mapping whose contents come from ``get_audio`` on first use.

    Construction only records the arguments. The first item lookup,
    iteration or ``len()`` call runs ``get_audio(file, start_time, duration)``
    and stores the returned dict in ``_dict``; later accesses reuse it.
    """

    def __init__(self, file, start_time, duration):
        self.file = file
        self.start_time = start_time
        self.duration = duration
        # Stays None until the first access triggers the load.
        self._dict = None

    def _load(self):
        """Return the backing dict, calling ``get_audio`` if nothing is stored yet."""
        if self._dict is None:
            self._dict = get_audio(self.file, self.start_time, self.duration)
        return self._dict

    def __getitem__(self, key):
        return self._load()[key]

    def __iter__(self):
        return iter(self._load())

    def __len__(self):
        return len(self._load())
|
||||
def lazy_get_audio(file, start_time=0, duration=0):
    """Return a ``LazyAudioMap`` built from *file*, *start_time* and *duration*."""
    return LazyAudioMap(file=file, start_time=start_time, duration=duration)
|
||||
|
||||
def is_url(url):
    """Return True when *url* starts with ``http://`` or ``https://``.

    The full ``scheme://`` prefix is required, so a bare ``"http"`` or
    ``"https"`` (for example a local file with that name) is not a URL.
    """
    return url.startswith(("http://", "https://"))
|
||||
|
||||
Reference in New Issue
Block a user