diff --git a/tts_generation.py b/tts_generation.py index 299ef8b..7367ba9 100644 --- a/tts_generation.py +++ b/tts_generation.py @@ -20,11 +20,14 @@ def main(): parser.add_argument('--device', type=str, choices=["cpu", "mps","cuda"], default="cpu" if torch.cuda.is_available() else "cpu", help="The device to run the model on.") args = parser.parse_args() + + if args.device == "cuda": + device="cuda" + else: + device="cpu" # Init TTS - tts = TTS() - tts.load_tts_model_by_path(model_path=args.model,config_path=os.path.join(args.model,"config.json")) - # tts.to(args.device) + tts = TTS(model_path=args.model,config_path=os.path.join(args.model,"config.json")).to(device) # Run TTS and save to file tts.tts_to_file(text=args.text, speaker_wav=args.speaker_wav, language=args.language, file_path=args.output_file)