Update fs_model.py
This commit is contained in:
@@ -48,7 +48,13 @@ class fsModel(BaseModel):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
self.isTrain = opt.isTrain
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
if not torch.backends.mps.is_available():
|
||||
if not torch.backends.mps.is_built():
|
||||
device = torch.device("cuda:0")
|
||||
else:
|
||||
print("ERROR")
|
||||
else:
|
||||
device = torch.device("mps")
|
||||
|
||||
if opt.crop_size == 224:
|
||||
from .fs_networks import Generator_Adain_Upsample, Discriminator
|
||||
|
||||
Reference in New Issue
Block a user