diff --git a/data/data_loader_Swapping.py b/data/data_loader_Swapping.py index 8788b95..274c982 100644 --- a/data/data_loader_Swapping.py +++ b/data/data_loader_Swapping.py @@ -63,7 +63,7 @@ class SwappingDataset(data.Dataset): def preprocess(self): """Preprocess the Swapping dataset.""" - print("processing Swapping dataset images...") + #print("processing Swapping dataset images...") temp_path = os.path.join(self.image_dir,'*/') pathes = glob.glob(temp_path) diff --git a/train.py b/train.py index ebe5fde..24fe6db 100644 --- a/train.py +++ b/train.py @@ -15,14 +15,15 @@ import time import random import argparse import numpy as np - +from accelerate import Accelerator import torch import torch.nn.functional as F from torch.backends import cudnn import torch.utils.tensorboard as tensorboard +import wandb from util import util -from util.plot import plot_batch +from util.plot import plot_batch, plot_batch_wandb from models.projected_model import fsModel from data.data_loader_Swapping import GetLoader @@ -64,7 +65,7 @@ class TrainOptions: self.parser.add_argument('--lambda_id', type=float, default=30.0, help='weight for id loss') self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss') - self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar', help="run ONNX model via TRT") + self.parser.add_argument("--Arc_path", type=str, default='/home/disk_3/shreyansh/SimSwap/arcface_model/arcface_checkpoint.tar', help="run ONNX model via TRT") self.parser.add_argument("--total_step", type=int, default=1000000, help='total training step') self.parser.add_argument("--log_frep", type=int, default=200, help='frequence for printing log information') self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequence for sampling') @@ -106,6 +107,7 @@ if __name__ == '__main__': opt = TrainOptions().parse() iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') + accelerator = Accelerator(log_with='wandb') sample_path = os.path.join(opt.checkpoints_dir, opt.name, 'samples') @@ -122,29 +124,31 @@ if __name__ == '__main__': start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) except: start_epoch, epoch_iter = 1, 0 - print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) + accelerator.print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) else: start_epoch, epoch_iter = 1, 0 os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_ids) - print("GPU used : ", str(opt.gpu_ids)) + accelerator.print("GPU used : ", str(opt.gpu_ids)) cudnn.benchmark = True - + accelerator.print(" i m rechine line 134") + - model = fsModel() - + accelerator.print("Model is created") model.initialize(opt) - + accelerator.print("model is initialised") ##################################################### if opt.use_tensorboard: tensorboard_writer = tensorboard.SummaryWriter(log_path) logger = tensorboard_writer - - log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + #wandb.init() + accelerator.init_trackers(project_name='SimSwap') + log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + accelerator.print("I am reaching here") with open(log_name, "a") as log_file: now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) @@ -167,18 +171,19 @@ if __name__ == '__main__': start = int(opt.which_epoch) total_step = opt.total_step import datetime - print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + accelerator.print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) from util.logo_class import logo_class logo_class.print_start_training() model.netD.feature_network.requires_grad_(False) - + feat_net = model.netD.feature_network + model.netG, model.netD, model.netArc, optimizer_D, optimizer_G, train_loader = accelerator.prepare( model.netG, model.netD, model.netArc, optimizer_D, optimizer_G, train_loader) # Training Cycle for step in range(start, total_step): model.netG.train() for interval in range(2): + src_image1, src_image2 = train_loader.next() ### to improve the volatility of the gpu random.shuffle(randindex) - src_image1, src_image2 = train_loader.next() if step%2 == 0: img_id = src_image2 @@ -199,7 +204,8 @@ if __name__ == '__main__': loss_D = loss_Dgen + loss_Dreal optimizer_D.zero_grad() - loss_D.backward() + # loss_D.backward() + accelerator.backward(loss_D) optimizer_D.step() else: @@ -213,7 +219,8 @@ if __name__ == '__main__': latent_fake = model.netArc(img_fake_down) latent_fake = F.normalize(latent_fake, p=2, dim=1) loss_G_ID = (1 - model.cosin_metric(latent_fake, latent_id)).mean() - real_feat = model.netD.get_feature(src_image1) + # real_feat = model.netD.get_feature(src_image1) + real_feat = feat_net(src_image1, get_features=True) feat_match_loss = model.criterionFeat(feat["3"],real_feat["3"]) loss_G = loss_Gmain + loss_G_ID * opt.lambda_id + feat_match_loss * opt.lambda_feat @@ -224,7 +231,8 @@ if __name__ == '__main__': loss_G += loss_G_Rec optimizer_G.zero_grad() - loss_G.backward() + # loss_G.backward() + accelerator.backward(loss_G) optimizer_G.step() @@ -245,11 +253,14 @@ if __name__ == '__main__': if opt.use_tensorboard: for tag, value in errors.items(): logger.add_scalar(tag, value, step) + for tag, value in errors.items(): + + accelerator.log({tag: value}) message = '( step: %d, ) ' % (step) for k, v in errors.items(): message += '%s: %.3f ' % (k, v) - print(message) + accelerator.print(message) with open(log_name, "a") as log_file: log_file.write('%s\n' % message) @@ -259,6 +270,11 @@ if __name__ == '__main__': with torch.no_grad(): imgs = list() zero_img = (torch.zeros_like(src_image1[0,...])) + # print(src_image1.shape) + # print(src_image1[:opt.batchSize//2,...].shape) + # src_image1 = src_image1[:opt.batchSize//2,...] + # src_image2 = src_image2[:opt.batchSize//2,...] + imgs.append(zero_img.cpu().numpy()) save_img = ((src_image1.cpu())* imagenet_std + imagenet_mean).numpy() for r in range(opt.batchSize): @@ -278,13 +294,16 @@ if __name__ == '__main__': img_fake = img_fake.numpy() for j in range(opt.batchSize): imgs.append(img_fake[j,...]) - print("Save test data") + accelerator.print("Save test data") imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1) + imgs_wandb = wandb.Image(plot_batch_wandb(imgs), caption="Source and Target Images") + accelerator.log({"examples":(imgs_wandb)}) plot_batch(imgs, os.path.join(sample_path, 'step_'+str(step+1)+'.jpg')) + accelerator.print("data successfully saved, ") ### save latest model if (step+1) % opt.model_freq==0: - print('saving the latest model (steps %d)' % (step+1)) + accelerator.print('saving the latest model (steps %d)' % (step+1)) model.save(step+1) np.savetxt(iter_path, (step+1, total_step), delimiter=',', fmt='%d') - wandb.finish() \ No newline at end of file + accelerator.end_training() diff --git a/util/plot.py b/util/plot.py index 0da1c75..1097c9b 100644 --- a/util/plot.py +++ b/util/plot.py @@ -34,4 +34,20 @@ def plot_batch(X, out_path): rows = cols = math.ceil(rc) canvas = tile(X, rows, cols) canvas = np.squeeze(canvas) - PIL.Image.fromarray(canvas).save(out_path) \ No newline at end of file + PIL.Image.fromarray(canvas).save(out_path) + +def plot_batch_wandb(X): + """Save batch of images tiled.""" + n_channels = X.shape[3] + if n_channels > 3: + X = X[:,:,:,np.random.choice(n_channels, size = 3)] + X = postprocess(X) + rc = math.sqrt(X.shape[0]) + rows = cols = math.ceil(rc) + canvas = tile(X, rows, cols) + canvas = np.squeeze(canvas) + print(canvas.shape) + print(rows, cols) + canvas = canvas[:2016, :2016,:] + # canvas = cv2.resize(canvas, None, fx=0.5, fy=0.5) + return canvas \ No newline at end of file