support multi-gpu
This commit is contained in:
@@ -8,13 +8,15 @@ from torch.utils import data
|
||||
from torchvision import transforms as T
|
||||
# from StyleResize import StyleResize
|
||||
|
||||
|
||||
class data_prefetcher():
|
||||
def __init__(self, loader):
|
||||
def __init__(self, loader, cur_gpu):
|
||||
self.loader = loader
|
||||
self.dataiter = iter(loader)
|
||||
self.stream = torch.cuda.Stream()
|
||||
self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1,3,1,1)
|
||||
self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1,3,1,1)
|
||||
self.stream = torch.cuda.Stream(device=cur_gpu)
|
||||
self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda(device=cur_gpu).view(1,3,1,1)
|
||||
self.std = torch.tensor([0.229, 0.224, 0.225]).cuda(device=cur_gpu).view(1,3,1,1)
|
||||
self.cur_gpu = cur_gpu
|
||||
# With Amp, it isn't necessary to manually convert data to half.
|
||||
# if args.fp16:
|
||||
# self.mean = self.mean.half()
|
||||
@@ -30,9 +32,9 @@ class data_prefetcher():
|
||||
self.src_image1, self.src_image2 = next(self.dataiter)
|
||||
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.src_image1 = self.src_image1.cuda(non_blocking=True)
|
||||
self.src_image1 = self.src_image1.cuda(device= self.cur_gpu, non_blocking=True)
|
||||
self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
|
||||
self.src_image2 = self.src_image2.cuda(non_blocking=True)
|
||||
self.src_image2 = self.src_image2.cuda(device= self.cur_gpu, non_blocking=True)
|
||||
self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)
|
||||
# With Amp, it isn't necessary to manually convert data to half.
|
||||
# if args.fp16:
|
||||
@@ -41,7 +43,7 @@ class data_prefetcher():
|
||||
# self.next_input = self.next_input.float()
|
||||
# self.next_input = self.next_input.sub_(self.mean).div_(self.std)
|
||||
def next(self):
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
torch.cuda.current_stream(device= self.cur_gpu,).wait_stream(self.stream)
|
||||
src_image1 = self.src_image1
|
||||
src_image2 = self.src_image2
|
||||
self.preload()
|
||||
@@ -102,6 +104,7 @@ class VGGFace2HQDataset(data.Dataset):
|
||||
return self.num_images
|
||||
|
||||
def GetLoader( dataset_roots,
|
||||
cur_gpu,
|
||||
batch_size=16,
|
||||
**kwargs
|
||||
):
|
||||
@@ -123,7 +126,7 @@ def GetLoader( dataset_roots,
|
||||
random_seed)
|
||||
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
|
||||
drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)
|
||||
prefetcher = data_prefetcher(content_data_loader)
|
||||
prefetcher = data_prefetcher(content_data_loader,cur_gpu)
|
||||
return prefetcher
|
||||
|
||||
def denorm(x):
|
||||
|
||||
Reference in New Issue
Block a user