Wandb and accelerate #453

Closed
shreygan21 wants to merge 1 commit from wandB-accelerate into main
3 changed files with 58 additions and 23 deletions

View File

@@ -63,7 +63,7 @@ class SwappingDataset(data.Dataset):
def preprocess(self):
"""Preprocess the Swapping dataset."""
print("processing Swapping dataset images...")
#print("processing Swapping dataset images...")
temp_path = os.path.join(self.image_dir,'*/')
pathes = glob.glob(temp_path)

View File

@@ -15,14 +15,15 @@ import time
import random
import argparse
import numpy as np
from accelerate import Accelerator
import torch
import torch.nn.functional as F
from torch.backends import cudnn
import torch.utils.tensorboard as tensorboard
import wandb
from util import util
from util.plot import plot_batch
from util.plot import plot_batch, plot_batch_wandb
from models.projected_model import fsModel
from data.data_loader_Swapping import GetLoader
@@ -64,7 +65,7 @@ class TrainOptions:
self.parser.add_argument('--lambda_id', type=float, default=30.0, help='weight for id loss')
self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss')
self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar', help="run ONNX model via TRT")
self.parser.add_argument("--Arc_path", type=str, default='/home/disk_3/shreyansh/SimSwap/arcface_model/arcface_checkpoint.tar', help="run ONNX model via TRT")
self.parser.add_argument("--total_step", type=int, default=1000000, help='total training step')
self.parser.add_argument("--log_frep", type=int, default=200, help='frequence for printing log information')
self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequence for sampling')
@@ -106,6 +107,7 @@ if __name__ == '__main__':
opt = TrainOptions().parse()
iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
accelerator = Accelerator(log_with='wandb')
sample_path = os.path.join(opt.checkpoints_dir, opt.name, 'samples')
@@ -122,29 +124,31 @@ if __name__ == '__main__':
start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
except:
start_epoch, epoch_iter = 1, 0
print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
accelerator.print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
else:
start_epoch, epoch_iter = 1, 0
os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_ids)
print("GPU used : ", str(opt.gpu_ids))
accelerator.print("GPU used : ", str(opt.gpu_ids))
cudnn.benchmark = True
accelerator.print(" i m rechine line 134")
model = fsModel()
accelerator.print("Model is created")
model.initialize(opt)
accelerator.print("model is initialised")
#####################################################
if opt.use_tensorboard:
tensorboard_writer = tensorboard.SummaryWriter(log_path)
logger = tensorboard_writer
log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
#wandb.init()
accelerator.init_trackers(project_name='SimSwap')
log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
accelerator.print("I am reaching here")
with open(log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
@@ -167,18 +171,19 @@ if __name__ == '__main__':
start = int(opt.which_epoch)
total_step = opt.total_step
import datetime
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
accelerator.print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
from util.logo_class import logo_class
logo_class.print_start_training()
model.netD.feature_network.requires_grad_(False)
feat_net = model.netD.feature_network
model.netG, model.netD, model.netArc, optimizer_D, optimizer_G, train_loader = accelerator.prepare( model.netG, model.netD, model.netArc, optimizer_D, optimizer_G, train_loader)
# Training Cycle
for step in range(start, total_step):
model.netG.train()
for interval in range(2):
src_image1, src_image2 = train_loader.next() ### to improve the volatility of the gpu
random.shuffle(randindex)
src_image1, src_image2 = train_loader.next()
if step%2 == 0:
img_id = src_image2
@@ -199,7 +204,8 @@ if __name__ == '__main__':
loss_D = loss_Dgen + loss_Dreal
optimizer_D.zero_grad()
loss_D.backward()
# loss_D.backward()
accelerator.backward(loss_D)
optimizer_D.step()
else:
@@ -213,7 +219,8 @@ if __name__ == '__main__':
latent_fake = model.netArc(img_fake_down)
latent_fake = F.normalize(latent_fake, p=2, dim=1)
loss_G_ID = (1 - model.cosin_metric(latent_fake, latent_id)).mean()
real_feat = model.netD.get_feature(src_image1)
# real_feat = model.netD.get_feature(src_image1)
real_feat = feat_net(src_image1, get_features=True)
feat_match_loss = model.criterionFeat(feat["3"],real_feat["3"])
loss_G = loss_Gmain + loss_G_ID * opt.lambda_id + feat_match_loss * opt.lambda_feat
@@ -224,7 +231,8 @@ if __name__ == '__main__':
loss_G += loss_G_Rec
optimizer_G.zero_grad()
loss_G.backward()
# loss_G.backward()
accelerator.backward(loss_G)
optimizer_G.step()
@@ -245,11 +253,14 @@ if __name__ == '__main__':
if opt.use_tensorboard:
for tag, value in errors.items():
logger.add_scalar(tag, value, step)
for tag, value in errors.items():
accelerator.log({tag: value})
message = '( step: %d, ) ' % (step)
for k, v in errors.items():
message += '%s: %.3f ' % (k, v)
print(message)
accelerator.print(message)
with open(log_name, "a") as log_file:
log_file.write('%s\n' % message)
@@ -259,6 +270,11 @@ if __name__ == '__main__':
with torch.no_grad():
imgs = list()
zero_img = (torch.zeros_like(src_image1[0,...]))
# print(src_image1.shape)
# print(src_image1[:opt.batchSize//2,...].shape)
# src_image1 = src_image1[:opt.batchSize//2,...]
# src_image2 = src_image2[:opt.batchSize//2,...]
imgs.append(zero_img.cpu().numpy())
save_img = ((src_image1.cpu())* imagenet_std + imagenet_mean).numpy()
for r in range(opt.batchSize):
@@ -278,13 +294,16 @@ if __name__ == '__main__':
img_fake = img_fake.numpy()
for j in range(opt.batchSize):
imgs.append(img_fake[j,...])
print("Save test data")
accelerator.print("Save test data")
imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
imgs_wandb = wandb.Image(plot_batch_wandb(imgs), caption="Source and Target Images")
accelerator.log({"examples":(imgs_wandb)})
plot_batch(imgs, os.path.join(sample_path, 'step_'+str(step+1)+'.jpg'))
accelerator.print("data successfully saved, ")
### save latest model
if (step+1) % opt.model_freq==0:
print('saving the latest model (steps %d)' % (step+1))
accelerator.print('saving the latest model (steps %d)' % (step+1))
model.save(step+1)
np.savetxt(iter_path, (step+1, total_step), delimiter=',', fmt='%d')
wandb.finish()
accelerator.end_training()

View File

@@ -34,4 +34,20 @@ def plot_batch(X, out_path):
rows = cols = math.ceil(rc)
canvas = tile(X, rows, cols)
canvas = np.squeeze(canvas)
PIL.Image.fromarray(canvas).save(out_path)
PIL.Image.fromarray(canvas).save(out_path)
def plot_batch_wandb(X, max_size=2016):
    """Tile a batch of images into a single canvas and return it.

    Nothing is written to disk; the tiled array is returned so the caller
    can pass it on to a logger (the training script wraps it in
    ``wandb.Image``).

    Args:
        X: Batch of images indexed as ``X[n, :, :, c]``. ``X.shape[0]`` is
            read as the number of images and ``X.shape[3]`` as the number
            of channels.
        max_size: Upper bound on the size of the first two axes of the
            returned canvas; anything beyond it is cropped off. Defaults
            to 2016, the limit that used to be hard-coded.

    Returns:
        The canvas built by ``tile`` and passed through ``np.squeeze``,
        cropped to at most ``max_size`` entries along its first two axes.
    """
    n_channels = X.shape[3]
    if n_channels > 3:
        # Keep three channels picked at random. np.random.choice samples
        # with replacement by default, so a channel may be picked twice.
        X = X[:, :, :, np.random.choice(n_channels, size=3)]

    # NOTE(review): postprocess() is defined elsewhere in this module;
    # presumably it maps values to a displayable pixel range - confirm.
    X = postprocess(X)

    # Smallest square grid with room for every image in the batch.
    rows = cols = math.ceil(math.sqrt(X.shape[0]))
    canvas = np.squeeze(tile(X, rows, cols))

    # Index only the two leading axes: np.squeeze may have removed a
    # singleton channel axis, and a three-index slice would then raise
    # IndexError. For a 3-D canvas this is identical to [:m, :m, :].
    return canvas[:max_size, :max_size]