diff --git a/models/fs_model.py b/models/fs_model.py index 548dd92..c54a28d 100644 --- a/models/fs_model.py +++ b/models/fs_model.py @@ -48,7 +48,13 @@ class fsModel(BaseModel): torch.backends.cudnn.benchmark = True self.isTrain = opt.isTrain - device = torch.device("cuda:0") + if not torch.backends.mps.is_available(): + if not torch.backends.mps.is_built(): + device = torch.device("cuda:0") + else: + print("ERROR") + else: + device = torch.device("mps") if opt.crop_size == 224: from .fs_networks import Generator_Adain_Upsample, Discriminator