From c01d582b3fa2cb0c8496ddd8fa54e0cddb4f1a04 Mon Sep 17 00:00:00 2001 From: Smiril Date: Tue, 12 Jul 2022 12:31:05 +0200 Subject: [PATCH] Update fs_model.py --- models/fs_model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/models/fs_model.py b/models/fs_model.py index 548dd92..c54a28d 100644 --- a/models/fs_model.py +++ b/models/fs_model.py @@ -48,7 +48,13 @@ class fsModel(BaseModel): torch.backends.cudnn.benchmark = True self.isTrain = opt.isTrain - device = torch.device("cuda:0") + if not torch.backends.mps.is_available(): + if not torch.backends.mps.is_built(): + device = torch.device("cuda:0") + else: + print("ERROR") + else: + device = torch.device("mps") if opt.crop_size == 224: from .fs_networks import Generator_Adain_Upsample, Discriminator