Files
disrupting-deepfakes/cyclegan/train.lua
T
Nataniel Ruiz a3fe19383a cyclegan
2019-12-25 17:21:03 -04:00

179 lines
5.7 KiB
Lua

-- usage example: DATA_ROOT=/path/to/data/ which_direction=BtoA name=expt1 th train.lua
-- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix
require 'torch'
require 'nn'
require 'optim'
util = paths.dofile('util/util.lua')
content = paths.dofile('util/content_loss.lua')
require 'image'
require 'models.architectures'
-- Parse the training options into the global `opt`, which is read
-- throughout the rest of this script.
options = require 'options'
opt = options.parse_options('train')
-- Visualization helpers (display server windows and saved result images).
visualizer = require 'util/visualizer'
-- Pick the compute device and make the default tensor type match it:
-- opt.gpu > 0 selects that CUDA device, anything else runs on the CPU.
local use_gpu = opt.gpu > 0
if not use_gpu then
  torch.setdefaulttensortype('torch.FloatTensor')
  print ("CPU Mode")
else
  require 'cutorch'
  require 'cunn'
  cutorch.setDevice(opt.gpu)
  print ("GPU Mode")
  torch.setdefaulttensortype('torch.CudaTensor')
end
-- Load data. opt.align_data > 0 selects the aligned loader, otherwise the
-- unaligned one; each required module provides the constructor called
-- right after it.
local function make_data_loader(aligned)
  if not aligned then
    require 'data.unaligned_data_loader'
    return UnalignedDataLoader()
  end
  require 'data.aligned_data_loader'
  return AlignedDataLoader()
end
local data_loader = make_data_loader(opt.align_data > 0)
print( "DataLoader " .. data_loader:name() .. " was created.")
data_loader:Initialize(opt)
-- Choose batch vs. instance normalization for the networks (opt.norm).
set_normalization(opt.norm)
-- Timers: epoch_tm is reset once per epoch, tm once per training iteration.
local epoch_tm, tm = torch.Timer(), torch.Timer()
-- define model
local model = nil
local display_plot = nil
-- Supported models: the data loader each one needs and a builder that
-- requires the model's module and constructs it. The loader is checked
-- before the module is loaded, for every model alike.
local model_registry = {
  cycle_gan = {
    loader = 'UnalignedDataLoader',
    build = function()
      require 'models.cycle_gan_model'
      return CycleGANModel()
    end,
  },
  pix2pix = {
    loader = 'AlignedDataLoader',
    build = function()
      require 'models.pix2pix_model'
      return Pix2PixModel()
    end,
  },
  bigan = {
    loader = 'UnalignedDataLoader',
    build = function()
      require 'models.bigan_model'
      return BiGANModel()
    end,
  },
  content_gan = {
    loader = 'UnalignedDataLoader',
    build = function()
      require 'models.content_gan_model'
      return ContentGANModel()
    end,
  },
}
local model_spec = model_registry[opt.model]
if model_spec == nil then
  error(("Please specify a correct model (got '%s')"):format(tostring(opt.model)))
end
-- Fail early with a readable message when the loader does not match the
-- model; which loader was created is decided by opt.align_data.
local loader_name = data_loader:name()
assert(loader_name == model_spec.loader,
  ("model '%s' requires %s, but %s was created (check opt.align_data)"):format(
    opt.model, model_spec.loader, tostring(loader_name)))
model = model_spec.build()
-- print the model name
print('Model ' .. model:model_name() .. ' was specified.')
model:Initialize(opt)
-- set up the loss plot
require 'util/plot_util'
plotUtil = PlotUtil()
display_plot = model:DisplayPlot(opt)
plotUtil:Initialize(display_plot, opt.display_id, opt.name)
--------------------------------------------------------------------------------
-- Helper Functions
--------------------------------------------------------------------------------
--- Push the model's current visuals to the display server.
-- Visual number k is shown in window opt.display_id + k, titled
-- "<opt.name> <label>".
function visualize_current_results()
  local visuals = model:GetCurrentVisuals(opt)
  for idx, entry in ipairs(visuals) do
    local window_id = opt.display_id + idx
    local title = opt.name .. ' ' .. entry.label
    visualizer.disp_image(entry.img, opt.display_winsize, window_id, title)
  end
end
--- Save the model's current visuals as JPEG files under opt.visual_dir.
-- Files are named train_epoch<epoch>_iter<counter>_<label>.jpg.
-- @param epoch current epoch number
-- @param counter global iteration count
function save_current_results(epoch, counter)
  local visuals = model:GetCurrentVisuals(opt)
  for _, visual in ipairs(visuals) do
    -- local: previously this leaked a global `output_path` on every call
    local output_path = paths.concat(opt.visual_dir,
      'train_epoch' .. epoch .. '_iter' .. counter .. '_' .. visual.label .. '.jpg')
    visualizer.save_results(visual.img, output_path)
  end
end
--- Print one progress line for the current iteration.
-- The line shows the batch position within the epoch, the iteration time and
-- data-fetch time (both divided by opt.batchSize), and the model's error summary.
-- @param epoch current epoch number
-- @param counter_in_epoch 1-based index of the first sample of the current batch
function print_current_errors(epoch, counter_in_epoch)
  local batch_size = opt.batchSize
  local batch_idx = (counter_in_epoch - 1) / batch_size
  local batches_per_epoch = math.floor(math.min(data_loader:size(), opt.ntrain) / batch_size)
  -- read the timers in the same order as the argument list
  local iter_time = tm:time().real / batch_size
  local fetch_time = data_loader:time_elapsed_to_fetch_data() / batch_size
  local errors_desc = model:GetCurrentErrorDescription()
  print(string.format('Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f %s',
    epoch, batch_idx, batches_per_epoch, iter_time, fetch_time, errors_desc))
end
--- Send the model's current losses to the loss plot.
-- @param epoch current epoch number
-- @param counter_ratio progress within the epoch; the plotted x value is epoch + counter_ratio
-- @param opt options table, forwarded to model:GetCurrentErrors
function plot_current_errors(epoch, counter_ratio, opt)
  local losses = model:GetCurrentErrors(opt)
  plotUtil:Display({ epoch + counter_ratio }, losses)
end
--------------------------------------------------------------------------------
-- Main Training Loop
--------------------------------------------------------------------------------
-- `counter` counts training iterations across all epochs.
local counter = 0
-- Samples consumed per epoch (capped at opt.ntrain) and the number of full
-- batches they make up.
local num_samples = math.min(data_loader:size(), opt.ntrain)
local num_batches = math.floor(num_samples / opt.batchSize)
local num_epochs = opt.niter + opt.niter_decay
print('#training iterations: ' .. num_epochs)
for epoch = 1, num_epochs do
  epoch_tm:reset()
  -- counter_in_epoch is the 1-based index of the first sample in the batch.
  for counter_in_epoch = 1, num_samples, opt.batchSize do
    tm:reset()
    -- load a batch and run the forward pass on it
    local real_dataA, real_dataB, _, _ = data_loader:GetNextBatch()
    model:Forward({real_A=real_dataA, real_B=real_dataB}, opt)
    -- record the global iteration count on opt, which is passed to
    -- OptimizeParameters below
    opt.counter = counter
    -- run backward pass
    model:OptimizeParameters(opt)
    -- display on the web server
    if counter % opt.display_freq == 0 and opt.display_id > 0 then
      visualize_current_results()
    end
    -- logging
    if counter % opt.print_freq == 0 then
      print_current_errors(epoch, counter_in_epoch)
      -- Progress through the epoch for the plot's x axis. counter_in_epoch
      -- counts samples, so convert it to a 1-based batch index before
      -- dividing by num_batches (unchanged when batchSize == 1). Guard
      -- against num_batches == 0 when there are fewer samples than a batch.
      local batch_index = (counter_in_epoch - 1) / opt.batchSize + 1
      plot_current_errors(epoch, batch_index / math.max(num_batches, 1), opt)
    end
    -- save latest model
    if counter % opt.save_latest_freq == 0 and counter > 0 then
      print(('saving the latest model (epoch %d, iters %d)'):format(epoch, counter))
      model:Save('latest', opt)
    end
    -- save the current visuals to disk
    if counter % opt.save_display_freq == 0 then
      save_current_results(epoch, counter)
    end
    counter = counter + 1
  end
  -- save model at the end of epoch
  if epoch % opt.save_epoch_freq == 0 then
    print(('saving the model (epoch %d, iters %d)'):format(epoch, counter))
    model:Save('latest', opt)
    model:Save(epoch, opt)
  end
  -- print the timing information after each epoch
  print(('End of epoch %d / %d \t Time Taken: %.3f'):
    format(epoch, num_epochs, epoch_tm:time().real))
  -- update the learning rate during the final opt.niter_decay epochs
  if epoch > opt.niter then
    model:UpdateLearningRate(opt)
  end
  -- refresh parameters
  model:RefreshParameters(opt)
end