Wandb and accelerate #453
data/data_loader_Swapping.py
@@ -63,7 +63,7 @@ class SwappingDataset(data.Dataset):
     def preprocess(self):
         """Preprocess the Swapping dataset."""
-        print("processing Swapping dataset images...")
+        #print("processing Swapping dataset images...")

         temp_path = os.path.join(self.image_dir,'*/')
         pathes = glob.glob(temp_path)
train.py
@@ -15,14 +15,15 @@ import time
 import random
 import argparse
 import numpy as np

+from accelerate import Accelerator
 import torch
 import torch.nn.functional as F
 from torch.backends import cudnn
 import torch.utils.tensorboard as tensorboard
+import wandb

 from util import util
-from util.plot import plot_batch
+from util.plot import plot_batch, plot_batch_wandb

 from models.projected_model import fsModel
 from data.data_loader_Swapping import GetLoader
@@ -64,7 +65,7 @@ class TrainOptions:
         self.parser.add_argument('--lambda_id', type=float, default=30.0, help='weight for id loss')
         self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss')

-        self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar', help="run ONNX model via TRT")
+        self.parser.add_argument("--Arc_path", type=str, default='/home/disk_3/shreyansh/SimSwap/arcface_model/arcface_checkpoint.tar', help="run ONNX model via TRT")
         self.parser.add_argument("--total_step", type=int, default=1000000, help='total training step')
         self.parser.add_argument("--log_frep", type=int, default=200, help='frequence for printing log information')
         self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequence for sampling')
@@ -106,6 +107,7 @@ if __name__ == '__main__':

     opt = TrainOptions().parse()
     iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
+    accelerator = Accelerator(log_with='wandb')

     sample_path = os.path.join(opt.checkpoints_dir, opt.name, 'samples')
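Note on this hunk: `Accelerator(log_with='wandb')` wires `accelerator.log()` to Weights & Biases, but the tracker still has to be started and stopped explicitly. A minimal sketch of the full lifecycle this PR relies on (project name and logged values are placeholders, not taken from the diff):

```python
from accelerate import Accelerator

accelerator = Accelerator(log_with="wandb")   # route accelerator.log() to wandb

# Start the wandb run; only the main process actually creates it.
accelerator.init_trackers(project_name="my-project")  # placeholder name

for step in range(3):
    # log() fans a dict of scalars/media out to every active tracker
    accelerator.log({"loss": 1.0 / (step + 1)}, step=step)

accelerator.end_training()  # flushes and closes the wandb run
```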
@@ -122,29 +124,31 @@ if __name__ == '__main__':
             start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
         except:
             start_epoch, epoch_iter = 1, 0
-        print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
+        accelerator.print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
     else:
         start_epoch, epoch_iter = 1, 0

     os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_ids)
-    print("GPU used : ", str(opt.gpu_ids))
+    accelerator.print("GPU used : ", str(opt.gpu_ids))

     cudnn.benchmark = True

+    accelerator.print(" i m rechine line 134")

     model = fsModel()
+    accelerator.print("Model is created")
     model.initialize(opt)
+    accelerator.print("model is initialised")
     #####################################################
     if opt.use_tensorboard:
         tensorboard_writer = tensorboard.SummaryWriter(log_path)
         logger = tensorboard_writer

-    log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+    #wandb.init()
+    accelerator.init_trackers(project_name='SimSwap')
+    log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+    accelerator.print("I am reaching here")
     with open(log_name, "a") as log_file:
         now = time.strftime("%c")
         log_file.write('================ Training Loss (%s) ================\n' % now)
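Suggestion on `init_trackers`: it also accepts a `config` dict that wandb stores as run hyperparameters, which would make runs comparable in the UI. A sketch, assuming `opt` is the argparse namespace returned by `TrainOptions().parse()`:

```python
# vars() turns the argparse.Namespace into a plain dict for the tracker.
accelerator.init_trackers(project_name='SimSwap', config=vars(opt))
```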
@@ -167,18 +171,19 @@ if __name__ == '__main__':
     start = int(opt.which_epoch)
     total_step = opt.total_step
     import datetime
-    print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
+    accelerator.print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

     from util.logo_class import logo_class
     logo_class.print_start_training()
     model.netD.feature_network.requires_grad_(False)

+    feat_net = model.netD.feature_network
+    model.netG, model.netD, model.netArc, optimizer_D, optimizer_G, train_loader = accelerator.prepare( model.netG, model.netD, model.netArc, optimizer_D, optimizer_G, train_loader)
     # Training Cycle
     for step in range(start, total_step):
         model.netG.train()
         for interval in range(2):
-            src_image1, src_image2 = train_loader.next() ### to improve the volatility of the gpu
             random.shuffle(randindex)
+            src_image1, src_image2 = train_loader.next()

         if step%2 == 0:
             img_id = src_image2
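For readers unfamiliar with the `prepare()` call added here: it moves each model to the right device, wraps it for DDP when launched distributed, and shards the dataloader across processes; `feat_net` is captured before that so the raw module stays reachable. A self-contained sketch of the pattern with illustrative names:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
net = torch.nn.Linear(8, 2)                     # stand-in for netG/netD/netArc
opt_ = torch.optim.Adam(net.parameters())
loader = torch.utils.data.DataLoader(torch.randn(32, 8), batch_size=4)

# prepare() returns device-placed (and possibly DDP-wrapped) counterparts.
net, opt_, loader = accelerator.prepare(net, opt_, loader)
```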
@@ -199,7 +204,8 @@ if __name__ == '__main__':

             loss_D = loss_Dgen + loss_Dreal
             optimizer_D.zero_grad()
-            loss_D.backward()
+            # loss_D.backward()
+            accelerator.backward(loss_D)
             optimizer_D.step()
         else:
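The `loss.backward()` → `accelerator.backward(loss)` swap is the standard Accelerate idiom: it applies loss scaling when mixed precision is enabled and cooperates with gradient accumulation and DDP gradient sync. An isolated, runnable sketch (toy model and data, not SimSwap's):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()   # mixed_precision="fp16" would enable loss scaling
net = torch.nn.Linear(4, 1)
opt_ = torch.optim.SGD(net.parameters(), lr=0.1)
net, opt_ = accelerator.prepare(net, opt_)

x = torch.randn(8, 4, device=accelerator.device)
y = torch.randn(8, 1, device=accelerator.device)

opt_.zero_grad()
loss = torch.nn.functional.mse_loss(net(x), y)
accelerator.backward(loss)    # replaces loss.backward()
opt_.step()
```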
@@ -213,7 +219,8 @@ if __name__ == '__main__':
                 latent_fake = model.netArc(img_fake_down)
                 latent_fake = F.normalize(latent_fake, p=2, dim=1)
                 loss_G_ID = (1 - model.cosin_metric(latent_fake, latent_id)).mean()
-                real_feat = model.netD.get_feature(src_image1)
+                # real_feat = model.netD.get_feature(src_image1)
+                real_feat = feat_net(src_image1, get_features=True)
                 feat_match_loss = model.criterionFeat(feat["3"],real_feat["3"])
                 loss_G = loss_Gmain + loss_G_ID * opt.lambda_id + feat_match_loss * opt.lambda_feat
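Presumably `feat_net(src_image1, get_features=True)` replaces `model.netD.get_feature(...)` because `prepare()` may wrap netD (e.g. in DistributedDataParallel), hiding custom methods behind `.module`. Accelerate's own escape hatch is `unwrap_model`; a hypothetical alternative that keeps the original call:

```python
# unwrap_model() strips the DDP wrapper so custom methods resolve again.
raw_netD = accelerator.unwrap_model(model.netD)
real_feat = raw_netD.get_feature(src_image1)
```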
@@ -224,7 +231,8 @@ if __name__ == '__main__':
                 loss_G += loss_G_Rec

                 optimizer_G.zero_grad()
-                loss_G.backward()
+                # loss_G.backward()
+                accelerator.backward(loss_G)
                 optimizer_G.step()

@@ -245,11 +253,14 @@ if __name__ == '__main__':
             if opt.use_tensorboard:
                 for tag, value in errors.items():
                     logger.add_scalar(tag, value, step)
+            for tag, value in errors.items():
+                accelerator.log({tag: value})
             message = '( step: %d, ) ' % (step)
             for k, v in errors.items():
                 message += '%s: %.3f ' % (k, v)

-            print(message)
+            accelerator.print(message)
             with open(log_name, "a") as log_file:
                 log_file.write('%s\n' % message)
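Minor point on the logging loop: `accelerator.log` takes an explicit `step`, and one call with the whole dict keeps the wandb charts on the same x-axis as the tensorboard scalars. A sketch using the `errors` dict already built above:

```python
# Single call per step; keys match the tensorboard tags.
accelerator.log(dict(errors), step=step)
```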
@@ -259,6 +270,11 @@ if __name__ == '__main__':
             with torch.no_grad():
                 imgs = list()
                 zero_img = (torch.zeros_like(src_image1[0,...]))
+                # print(src_image1.shape)
+                # print(src_image1[:opt.batchSize//2,...].shape)
+                # src_image1 = src_image1[:opt.batchSize//2,...]
+                # src_image2 = src_image2[:opt.batchSize//2,...]
+
                 imgs.append(zero_img.cpu().numpy())
                 save_img = ((src_image1.cpu())* imagenet_std + imagenet_mean).numpy()
                 for r in range(opt.batchSize):
@@ -278,13 +294,16 @@ if __name__ == '__main__':
                 img_fake = img_fake.numpy()
                 for j in range(opt.batchSize):
                     imgs.append(img_fake[j,...])
-                print("Save test data")
+                accelerator.print("Save test data")
                 imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
+                imgs_wandb = wandb.Image(plot_batch_wandb(imgs), caption="Source and Target Images")
+                accelerator.log({"examples":(imgs_wandb)})
                 plot_batch(imgs, os.path.join(sample_path, 'step_'+str(step+1)+'.jpg'))
+                accelerator.print("data successfully saved, ")

             ### save latest model
             if (step+1) % opt.model_freq==0:
-                print('saving the latest model (steps %d)' % (step+1))
+                accelerator.print('saving the latest model (steps %d)' % (step+1))
                 model.save(step+1)
                 np.savetxt(iter_path, (step+1, total_step), delimiter=',', fmt='%d')
+    wandb.finish()
+    accelerator.end_training()
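On the shutdown lines: `accelerator.end_training()` already finishes every tracker the Accelerator started, including the wandb run created by `init_trackers()`, so the explicit `wandb.finish()` just above is most likely redundant and could be dropped:

```python
# end_training() closes all trackers started via init_trackers(),
# which finishes the wandb run; no separate wandb.finish() needed.
accelerator.end_training()
```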
util/plot.py
@@ -34,4 +34,20 @@ def plot_batch(X, out_path):
     rows = cols = math.ceil(rc)
     canvas = tile(X, rows, cols)
     canvas = np.squeeze(canvas)
-    PIL.Image.fromarray(canvas).save(out_path)
+    PIL.Image.fromarray(canvas).save(out_path)
+
+def plot_batch_wandb(X):
+    """Save batch of images tiled."""
+    n_channels = X.shape[3]
+    if n_channels > 3:
+        X = X[:,:,:,np.random.choice(n_channels, size = 3)]
+    X = postprocess(X)
+    rc = math.sqrt(X.shape[0])
+    rows = cols = math.ceil(rc)
+    canvas = tile(X, rows, cols)
+    canvas = np.squeeze(canvas)
+    print(canvas.shape)
+    print(rows, cols)
+    canvas = canvas[:2016, :2016,:]
+    # canvas = cv2.resize(canvas, None, fx=0.5, fy=0.5)
+    return canvas
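Two observations on `plot_batch_wandb`: the docstring says it saves the batch, but the function actually returns the tiled canvas (the logging happens in train.py), and the hard-coded `[:2016, :2016]` crop only bites when the canvas exceeds that size. A minimal usage sketch with illustrative shapes and a hypothetical project name:

```python
import numpy as np
import wandb

from util.plot import plot_batch_wandb

wandb.init(project="simswap-debug")  # hypothetical project name

batch = np.random.uniform(-1, 1, (4, 128, 128, 3)).astype(np.float32)  # NHWC
canvas = plot_batch_wandb(batch)     # tiled HWC array; the crop is a no-op here
wandb.log({"examples": wandb.Image(canvas, caption="Source and Target Images")})
```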