Support multi-GPU

This commit is contained in:
chenxuanhong
2022-02-08 16:37:30 +08:00
parent 94534e2e30
commit 8dc2cec3dc
53 changed files with 8778 additions and 34 deletions
+11 -8
View File
@@ -8,13 +8,15 @@ from torch.utils import data
from torchvision import transforms as T
# from StyleResize import StyleResize
class data_prefetcher():
def __init__(self, loader):
def __init__(self, loader, cur_gpu):
self.loader = loader
self.dataiter = iter(loader)
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1,3,1,1)
self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1,3,1,1)
self.stream = torch.cuda.Stream(device=cur_gpu)
self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda(device=cur_gpu).view(1,3,1,1)
self.std = torch.tensor([0.229, 0.224, 0.225]).cuda(device=cur_gpu).view(1,3,1,1)
self.cur_gpu = cur_gpu
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.mean = self.mean.half()
@@ -30,9 +32,9 @@ class data_prefetcher():
self.src_image1, self.src_image2 = next(self.dataiter)
with torch.cuda.stream(self.stream):
self.src_image1 = self.src_image1.cuda(non_blocking=True)
self.src_image1 = self.src_image1.cuda(device= self.cur_gpu, non_blocking=True)
self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
self.src_image2 = self.src_image2.cuda(non_blocking=True)
self.src_image2 = self.src_image2.cuda(device= self.cur_gpu, non_blocking=True)
self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
@@ -41,7 +43,7 @@ class data_prefetcher():
# self.next_input = self.next_input.float()
# self.next_input = self.next_input.sub_(self.mean).div_(self.std)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
torch.cuda.current_stream(device= self.cur_gpu,).wait_stream(self.stream)
src_image1 = self.src_image1
src_image2 = self.src_image2
self.preload()
@@ -102,6 +104,7 @@ class VGGFace2HQDataset(data.Dataset):
return self.num_images
def GetLoader( dataset_roots,
cur_gpu,
batch_size=16,
**kwargs
):
@@ -123,7 +126,7 @@ def GetLoader( dataset_roots,
random_seed)
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)
prefetcher = data_prefetcher(content_data_loader)
prefetcher = data_prefetcher(content_data_loader,cur_gpu)
return prefetcher
def denorm(x):