Add M1 GPU Support #321

Open
Smiril wants to merge 1 commits from Smiril/patch-1 into main

View File

@@ -80,8 +80,16 @@ class BaseOptions():
self.opt.gpu_ids.append(id)
# set gpu ids
if len(self.opt.gpu_ids) > 0:
torch.cuda.set_device(self.opt.gpu_ids[0])
if not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
if len(self.opt.gpu_ids) > 0:
torch.cuda.set_device(self.opt.gpu_ids[0])
else:
print("ERROR")
else:
print("ERROR")
else:
torch.device("mps")
args = vars(self.opt)