From 2d95749eec3a06f32ad4dcb4fcda824251d672c1 Mon Sep 17 00:00:00 2001 From: Sam Khoze <68170403+SamKhoze@users.noreply.github.com> Date: Sun, 16 Jun 2024 16:24:22 +0530 Subject: [PATCH] Update tts_generation.py --- tts_generation.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tts_generation.py b/tts_generation.py index 299ef8b..7367ba9 100644 --- a/tts_generation.py +++ b/tts_generation.py @@ -20,11 +20,14 @@ def main(): parser.add_argument('--device', type=str, choices=["cpu", "mps","cuda"], default="cpu" if torch.cuda.is_available() else "cpu", help="The device to run the model on.") args = parser.parse_args() + + if args.device == "cuda": + device="cuda" + else: + device="cpu" # Init TTS - tts = TTS() - tts.load_tts_model_by_path(model_path=args.model,config_path=os.path.join(args.model,"config.json")) - # tts.to(args.device) + tts = TTS(model_path=args.model,config_path=os.path.join(args.model,"config.json")).to(device) # Run TTS and save to file tts.tts_to_file(text=args.text, speaker_wav=args.speaker_wav, language=args.language, file_path=args.output_file)