Files
disrupting-deepfakes/cyclegan/models/cycle_gan_model.lua
T
Nataniel Ruiz a3fe19383a cyclegan
2019-12-25 17:21:03 -04:00

325 lines
13 KiB
Lua

-- Class helper used below to declare CycleGANModel.
local class = require 'class'
-- Project modules. Presumably these provide BaseModel, defineG/defineD and
-- ImagePool, which are used further down -- verify against those files.
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
-- NOTE(review): `util` is assigned without `local`, so it becomes a global
-- visible to every other chunk; confirm nothing else relies on that before
-- changing it.
util = paths.dofile('../util/util.lua')
-- Declares the (global) CycleGANModel class, naming 'BaseModel' as its parent.
CycleGANModel = class('CycleGANModel', 'BaseModel')
--- Constructor: delegates all set-up to the BaseModel constructor.
-- The former trailing `conf = conf or {}` was removed: it ran after `conf`
-- had already been forwarded and the value was never read again, so it had
-- no effect.
-- @param conf optional configuration value, passed through unchanged
--             (may be nil).
function CycleGANModel:__init(conf)
  BaseModel.__init(self, conf)
end
--- Reports the identifier string of this model class.
-- @return the string 'CycleGANModel'
function CycleGANModel:model_name()
  local name = 'CycleGANModel'
  return name
end
--- Creates a fresh optimizer state table.
-- The table is now a local; previously it was written to the global
-- `optimState`, which every call silently overwrote.
-- NOTE(review): `opt` is not a parameter here, so the learning rate and beta1
-- are read from the global `opt` table -- confirm the training script defines
-- it before this method is called.
-- @param use_wgan unused; kept for interface compatibility.
-- @return table with `learningRate` and `beta1` fields
function CycleGANModel:InitializeStates(use_wgan)
  local optim_state = {
    learningRate = opt.lr,
    beta1 = opt.beta1,
  }
  return optim_state
end
--- Defines models and networks.
-- Sets up, in order: the fake-image pools and working tensors (training
-- only), the loss criteria, the two generators and two discriminators
-- (loaded from disk or freshly defined), and -- for training -- the target
-- labels, optimizer states and flattened parameter views.
-- @param opt option table. Fields read here: test, pool_size, batchSize,
--            input_nc, output_nc, fineSize, use_lsgan, continue_train,
--            use_optnet, ngf, ndf, which_model_netG, which_model_netD,
--            n_layers_D, arch.
function CycleGANModel:Initialize(opt)
  if opt.test == 0 then
    -- Pools of previously generated images, queried by fDAx / fDBx.
    self.fakeAPool = ImagePool(opt.pool_size)
    self.fakeBPool = ImagePool(opt.pool_size)

    -- Working tensors for training. They are zero-filled because real_A and
    -- real_B are pushed through the discriminators below to discover the
    -- label size; torch.Tensor(...) alone leaves the memory uninitialised.
    local function buffer(channels)
      return torch.Tensor(opt.batchSize, channels, opt.fineSize, opt.fineSize):zero()
    end
    self.real_A = buffer(opt.input_nc)
    self.real_B = buffer(opt.output_nc)
    self.fake_A = buffer(opt.input_nc)
    self.fake_B = buffer(opt.output_nc)
    self.rec_A = buffer(opt.input_nc)
    self.rec_B = buffer(opt.output_nc)
  end

  -- Adversarial criterion: MSE when opt.use_lsgan == 1, BCE otherwise.
  local use_lsgan = (opt.use_lsgan == 1)
  if use_lsgan then
    self.criterionGAN = nn.MSECriterion()
  else
    self.criterionGAN = nn.BCECriterion()
  end
  -- L1 criterion shared by the cycle and identity terms.
  self.criterionRec = nn.AbsCriterion()

  -- load/define models
  local netG_A, netD_A, netG_B, netD_B = nil, nil, nil, nil
  if opt.continue_train == 1 then
    if opt.test == 1 then -- test mode: only the generators are loaded
      netG_A = util.load_test_model('G_A', opt)
      netG_B = util.load_test_model('G_B', opt)
      -- setup optnet to save a little bit of memory
      if opt.use_optnet == 1 then
        local sample_input = torch.randn(1, opt.input_nc, 2, 2)
        local optnet = require 'optnet'
        optnet.optimizeMemory(netG_A, sample_input, {inplace=true, reuseBuffers=true})
        optnet.optimizeMemory(netG_B, sample_input, {inplace=true, reuseBuffers=true})
      end
    else
      netG_A = util.load_model('G_A', opt)
      netG_B = util.load_model('G_B', opt)
      netD_A = util.load_model('D_A', opt)
      netD_B = util.load_model('D_B', opt)
    end
  else
    -- use_sigmoid is true for the BCE objective and false for LSGAN; it is
    -- handed to defineD, which presumably decides whether to append a final
    -- sigmoid layer -- see models.architectures.
    local use_sigmoid = (not use_lsgan)
    netG_A = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netG_A...', netG_A)
    netD_A = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, use_sigmoid)
    print('netD_A...', netD_A)
    netG_B = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netG_B...', netG_B)
    netD_B = defineD(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, use_sigmoid)
    print('netD_B', netD_B)
  end

  self.netD_A = netD_A
  self.netG_A = netG_A
  self.netG_B = netG_B
  self.netD_B = netD_B

  -- define real/fake labels
  if opt.test == 0 then
    -- One probe forward pass per discriminator yields the size that the
    -- constant target tensors must have.
    local D_A_size = self.netD_A:forward(self.real_B):size()
    self.fake_label_A = torch.Tensor(D_A_size):fill(0.0)
    self.real_label_A = torch.Tensor(D_A_size):fill(1.0) -- no soft smoothing

    local D_B_size = self.netD_B:forward(self.real_A):size()
    self.fake_label_B = torch.Tensor(D_B_size):fill(0.0)
    self.real_label_B = torch.Tensor(D_B_size):fill(1.0) -- no soft smoothing

    -- Independent optimizer state for each of the four networks.
    self.optimStateD_A = self:InitializeStates()
    self.optimStateG_A = self:InitializeStates()
    self.optimStateD_B = self:InitializeStates()
    self.optimStateG_B = self:InitializeStates()

    self:RefreshParameters()

    print('---------- # Learnable Parameters --------------')
    print(('G_A = %d'):format(self.parametersG_A:size(1)))
    print(('D_A = %d'):format(self.parametersD_A:size(1)))
    print(('G_B = %d'):format(self.parametersG_B:size(1)))
    print(('D_B = %d'):format(self.parametersD_B:size(1)))
    print('------------------------------------------------')
  end
end
--- Runs the forward pass of the network and saves the result to member
-- variables of the class.
-- In training mode (opt.test == 0) the batch is only copied into the
-- pre-allocated self.real_A / self.real_B; the generators are run later by
-- the optimisation closures. In test mode (opt.test == 1) both generators
-- are run and fake_A, fake_B, rec_A, rec_B are stored on self.
--
-- The BtoA swap is done on local references, so the caller's `input` table
-- is left untouched. It used to be overwritten in place, so a second call
-- with the same table swapped the domains back.
-- @param input table with tensor fields `real_A` and `real_B`
-- @param opt   option table; reads which_direction, test and gpu
function CycleGANModel:Forward(input, opt)
  local real_A, real_B = input.real_A, input.real_B
  if opt.which_direction == 'BtoA' then
    real_A, real_B = real_B, real_A
  end

  if opt.test == 0 then
    self.real_A:copy(real_A)
    self.real_B:copy(real_B)
  end

  if opt.test == 1 then -- forward for test
    if opt.gpu > 0 then
      self.real_A = real_A:cuda()
      self.real_B = real_B:cuda()
    else
      self.real_A = real_A:clone()
      self.real_B = real_B:clone()
    end
    -- Each generator is run twice below, so every output is cloned before
    -- the next forward call on the same network.
    self.fake_B = self.netG_A:forward(self.real_A):clone()
    self.fake_A = self.netG_B:forward(self.real_B):clone()
    self.rec_A = self.netG_B:forward(self.fake_B):clone()
    self.rec_B = self.netG_A:forward(self.fake_A):clone()
  end
end
--- Evaluates f(X) and df/dX for one discriminator.
-- Zeroes `gradParams`, then runs netD forward and backward on the real batch
-- against `real_label` and on the fake batch against `fake_label`.
-- @param x          unused
-- @param gradParams gradient tensor that is zeroed here and returned; the
--                   callers in this file pass the discriminator's flattened
--                   gradient view
-- @param netD       discriminator being evaluated
-- @param netG       generator; only handed to util.BiasZero
-- @param real       batch of real images
-- @param fake       batch of generated images
-- @param real_label target tensor for the real batch
-- @param fake_label target tensor for the fake batch
-- @param opt        unused
-- @return mean of the real and fake criterion values, and `gradParams`
function CycleGANModel:fDx_basic(x, gradParams, netD, netG, real, fake, real_label, fake_label, opt)
  -- NOTE(review): util.BiasZero is defined in util/util.lua; its effect is
  -- not visible from this file.
  util.BiasZero(netD)
  util.BiasZero(netG)
  gradParams:zero()

  local gan_criterion = self.criterionGAN

  -- One forward/backward round trip through netD for a batch and its target.
  local function accumulate(images, target)
    local prediction = netD:forward(images)
    local loss = gan_criterion:forward(prediction, target)
    local grad_prediction = gan_criterion:backward(prediction, target)
    netD:backward(images, grad_prediction)
    return loss
  end

  local loss_real = accumulate(real, real_label) -- Real: log(D(real))
  local loss_fake = accumulate(fake, fake_label) -- Fake: log(1 - D(G(x)))

  return (loss_real + loss_fake) / 2.0, gradParams
end
--- Loss/gradient evaluation for discriminator D_A (real_B vs. generated B).
-- The generated batch is taken from the image pool, which stores old fake
-- images. Intermediate values are locals; they used to leak into the globals
-- `fake_B` and `gradParams`.
-- @param x   forwarded to fDx_basic, which does not use it
-- @param opt option table, forwarded to fDx_basic
-- @return self.errD_A and the gradient tensor returned by fDx_basic
function CycleGANModel:fDAx(x, opt)
  local pooled_fake_B = self.fakeBPool:Query(self.fake_B)
  local err, grad = self:fDx_basic(x, self.gradparametersD_A, self.netD_A, self.netG_A,
    self.real_B, pooled_fake_B, self.real_label_A, self.fake_label_A, opt)
  self.errD_A = err
  return self.errD_A, grad
end
--- Loss/gradient evaluation for discriminator D_B (real_A vs. generated A).
-- The generated batch is taken from the image pool, which stores old fake
-- images. Intermediate values are locals; they used to leak into the globals
-- `fake_A` and `gradParams`.
-- @param x   forwarded to fDx_basic, which does not use it
-- @param opt option table, forwarded to fDx_basic
-- @return self.errD_B and the gradient tensor returned by fDx_basic
function CycleGANModel:fDBx(x, opt)
  local pooled_fake_A = self.fakeAPool:Query(self.fake_A)
  local err, grad = self:fDx_basic(x, self.gradparametersD_B, self.netD_B, self.netG_B,
    self.real_A, pooled_fake_A, self.real_label_B, self.fake_label_B, opt)
  self.errD_B = err
  return self.errD_B, grad
end
-- Evaluates the losses of one generator and accumulates its gradients.
-- Parameters:
--   x          unused
--   gradParams gradient tensor that is zeroed here and returned; the callers
--              in this file pass netG's flattened gradient view
--   netG       generator being optimised
--   netD       discriminator that judges netG's output
--   netE       the other generator (inverse mapping)
--   real       batch that netG translates
--   real2      batch from the other domain
--   real_label target tensor for the GAN criterion
--   lambda1    weight of the forward cycle loss
--   lambda2    weight of the backward cycle loss and of the identity loss
--   opt        option table; only lambda_identity is read
-- Returns gradParams, errG, errRec, errI, fake, rec, identity. errI and
-- identity are nil when opt.lambda_identity is not > 0.
-- The statement order matters: each netG:backward call follows the
-- netG:forward call on the same input.
function CycleGANModel:fGx_basic(x, gradParams, netG, netD, netE, real, real2, real_label, lambda1, lambda2, opt)
-- NOTE(review): util.BiasZero is defined in util/util.lua; its effect is not
-- visible from this file.
util.BiasZero(netD)
util.BiasZero(netG)
util.BiasZero(netE) -- inverse mapping
gradParams:zero()
-- G should be identity if real2 is fed.
local errI = nil
local identity = nil
if opt.lambda_identity > 0 then
-- cloned, presumably because netG's output buffer is reused by the later
-- forward calls
identity = netG:forward(real2):clone()
errI = self.criterionRec:forward(identity, real2) * lambda2 * opt.lambda_identity
local didentity_loss_do = self.criterionRec:backward(identity, real2):mul(lambda2):mul(opt.lambda_identity)
netG:backward(real2, didentity_loss_do)
end
-- GAN loss: D(G(real)) is scored against real_label
local fake = netG:forward(real):clone()
local output = netD:forward(fake)
local errG = self.criterionGAN:forward(output, real_label)
local df_do1 = self.criterionGAN:backward(output, real_label)
-- updateGradInput (rather than backward) yields the gradient w.r.t. `fake`
-- only; per the nn Module API it does not accumulate parameter gradients.
local df_d_GAN = netD:updateGradInput(fake, df_do1)
-- forward cycle loss: netE(netG(real)) is compared with real
local rec = netE:forward(fake):clone()
local errRec = self.criterionRec:forward(rec, real) * lambda1
local df_do2 = self.criterionRec:backward(rec, real):mul(lambda1)
local df_do_rec = netE:updateGradInput(fake, df_do2)
-- netG's most recent forward input is still `real`, so both gradients
-- w.r.t. `fake` are summed and pushed through netG in one backward call.
netG:backward(real, df_d_GAN + df_do_rec)
-- backward cycle loss: netG(netE(real2)) is compared with real2; the
-- gradient is applied to netG only, nothing is propagated back through netE.
local fake2 = netE:forward(real2)--:clone()
local rec2 = netG:forward(fake2)--:clone()
-- NOTE(review): errAdapt is computed but neither returned nor stored.
local errAdapt = self.criterionRec:forward(rec2, real2) * lambda2
local df_do_coadapt = self.criterionRec:backward(rec2, real2):mul(lambda2)
netG:backward(fake2, df_do_coadapt)
return gradParams, errG, errRec, errI, fake, rec, identity
end
--- Loss/gradient evaluation for generator G_A (A -> B).
-- Delegates to fGx_basic with D_A as the critic and G_B as the inverse
-- mapping, then stores the losses and generated images on self.
-- @param x   forwarded to fGx_basic
-- @param opt option table; lambda_A weights the forward cycle, lambda_B the
--            backward cycle
-- @return self.errG_A and self.gradparametersG_A
function CycleGANModel:fGAx(x, opt)
  local grad, err_gan, err_cycle, err_identity, fake_B, rec_A, identity_B = self:fGx_basic(
    x, self.gradparametersG_A, self.netG_A, self.netD_A, self.netG_B,
    self.real_A, self.real_B, self.real_label_A, opt.lambda_A, opt.lambda_B, opt)

  self.gradparametersG_A = grad
  self.errG_A = err_gan
  self.errRec_A = err_cycle
  self.errI_A = err_identity
  self.fake_B = fake_B
  self.rec_A = rec_A
  self.identity_B = identity_B

  return self.errG_A, self.gradparametersG_A
end
--- Loss/gradient evaluation for generator G_B (B -> A).
-- Delegates to fGx_basic with D_B as the critic and G_A as the inverse
-- mapping, then stores the losses and generated images on self.
-- @param x   forwarded to fGx_basic
-- @param opt option table; lambda_B weights the forward cycle, lambda_A the
--            backward cycle
-- @return self.errG_B and self.gradparametersG_B
function CycleGANModel:fGBx(x, opt)
  local grad, err_gan, err_cycle, err_identity, fake_A, rec_B, identity_A = self:fGx_basic(
    x, self.gradparametersG_B, self.netG_B, self.netD_B, self.netG_A,
    self.real_B, self.real_A, self.real_label_B, opt.lambda_B, opt.lambda_A, opt)

  self.gradparametersG_B = grad
  self.errG_B = err_gan
  self.errRec_B = err_cycle
  self.errI_B = err_identity
  self.fake_A = fake_A
  self.rec_B = rec_B
  self.identity_A = identity_A

  return self.errG_B, self.gradparametersG_B
end
--- Performs one Adam update on each of the four networks.
-- The order is G_A, D_A, G_B, D_B. Each step hands optim.adam a closure that
-- evaluates the matching loss method with the current `opt`.
-- @param opt option table passed through to the loss methods
function CycleGANModel:OptimizeParameters(opt)
  -- {loss method, flattened parameters field, optimizer state field}
  local schedule = {
    {'fGAx', 'parametersG_A', 'optimStateG_A'},
    {'fDAx', 'parametersD_A', 'optimStateD_A'},
    {'fGBx', 'parametersG_B', 'optimStateG_B'},
    {'fDBx', 'parametersD_B', 'optimStateD_B'},
  }
  for _, step in ipairs(schedule) do
    local method, params_key, state_key = step[1], step[2], step[3]
    local feval = function(x) return self[method](self, x, opt) end
    optim.adam(feval, self[params_key], self[state_key])
  end
end
--- Re-fetches the flattened parameter and gradient tensors of all networks.
-- The old references are dropped first to avoid spiking memory, then
-- getParameters() is called on each network in the order G_A, D_A, G_B, D_B.
function CycleGANModel:RefreshParameters()
  local suffixes = {'G_A', 'D_A', 'G_B', 'D_B'}

  -- nil them to avoid spiking memory
  for _, suffix in ipairs(suffixes) do
    self['parameters' .. suffix] = nil
    self['gradparameters' .. suffix] = nil
  end

  -- define parameters of optimization
  for _, suffix in ipairs(suffixes) do
    local params, grad_params = self['net' .. suffix]:getParameters()
    self['parameters' .. suffix] = params
    self['gradparameters' .. suffix] = grad_params
  end
end
--- Saves all four networks through util.save_model.
-- Files are named `<prefix>_net_<NAME>.t7` for NAME in G_A, D_A, G_B, D_B,
-- written in that order.
-- @param prefix path prefix prepended to every file name
-- @param opt    unused
function CycleGANModel:Save(prefix, opt)
  for _, name in ipairs({'G_A', 'D_A', 'G_B', 'D_B'}) do
    local network = self['net' .. name]
    -- third argument (1) is passed through as in every existing call;
    -- its meaning is defined by util.save_model
    util.save_model(network, prefix .. '_net_' .. name .. '.t7', 1)
  end
end
--- Formats the most recent losses of both directions as a single line.
-- Any loss that has not been computed yet (nil) is shown as -1.
-- The result is built in a local; it used to be assigned to the global
-- `description`.
-- @return formatted string
function CycleGANModel:GetCurrentErrorDescription()
  local template = '[A] G: %.4f D: %.4f Rec: %.4f I: %.4f || [B] G: %.4f D: %.4f Rec: %.4f I:%.4f'
  return template:format(
    self.errG_A or -1,
    self.errD_A or -1,
    self.errRec_A or -1,
    self.errI_A or -1,
    self.errG_B or -1,
    self.errD_B or -1,
    self.errRec_B or -1,
    self.errI_B or -1)
end
--- Collects the most recent loss values into a table keyed by name.
-- Losses that are still nil simply do not appear in the result.
-- @return table with keys errG_A, errD_A, errRec_A, errI_A, errG_B, errD_B,
--         errRec_B, errI_B
function CycleGANModel:GetCurrentErrors()
  local names = {
    'errG_A', 'errD_A', 'errRec_A', 'errI_A',
    'errG_B', 'errD_B', 'errRec_B', 'errI_B',
  }
  local snapshot = {}
  for _, name in ipairs(names) do
    snapshot[name] = self[name]
  end
  return snapshot
end
--- Returns a string that describes the display plot configuration.
-- The identity-loss curves are listed only when opt.lambda_identity > 0.
-- @param opt option table; reads lambda_identity
-- @return comma-separated list of error names
function CycleGANModel:DisplayPlot(opt)
  if not (opt.lambda_identity > 0) then
    return 'errG_A,errD_A,errRec_A,errG_B,errD_B,errRec_B'
  end
  return 'errG_A,errD_A,errRec_A,errI_A,errG_B,errD_B,errRec_B,errI_B'
end
--- Lowers the learning rate of all four optimizers by one linear step.
-- The step is opt.lr / opt.niter_decay. The current rate is read from D_A's
-- optimizer state and the new rate is written to all four states.
-- @param opt option table; reads lr and niter_decay
function CycleGANModel:UpdateLearningRate(opt)
  local previous = self.optimStateD_A['learningRate']
  local step = opt.lr / opt.niter_decay
  local updated = previous - step

  local state_keys = {'optimStateD_A', 'optimStateD_B', 'optimStateG_A', 'optimStateG_B'}
  for _, key in ipairs(state_keys) do
    self[key]['learningRate'] = updated
  end

  print(('update learning rate: %f -> %f'):format(previous, updated))
end
--- Expands a batch whose second dimension has size 1 to size 3.
-- Batches with any other size in dimension 2 are returned unchanged (the
-- same object, not a copy).
-- @param im tensor indexed as in the rest of this file, i.e. channels in
--           dimension 2
-- @return the input, or a copy repeated three times along dimension 2
local function MakeIm3(im)
  local channels = im:size(2)
  if channels ~= 1 then
    return im
  end
  return torch.repeatTensor(im, 1, 3, 1, 1)
end
--- Builds the list of images to display, each entry being {img=..., label=...}.
-- Order: real_A, fake_B, rec_A, [identity_A], real_B, fake_A, rec_B,
-- [identity_B]. The identity entries are added only in training mode with
-- opt.lambda_identity > 0. Every image goes through MakeIm3.
-- @param opt  option table; reads test and lambda_identity
-- @param size unused
-- @return array of {img=tensor, label=string} tables
function CycleGANModel:GetCurrentVisuals(opt, size)
  local visuals = {}

  -- Appends the member tensor called `label` under that same label.
  local function add(label)
    table.insert(visuals, {img=MakeIm3(self[label]), label=label})
  end

  add('real_A')
  add('fake_B')
  add('rec_A')
  if opt.test == 0 and opt.lambda_identity > 0 then
    add('identity_A')
  end

  add('real_B')
  add('fake_A')
  add('rec_B')
  if opt.test == 0 and opt.lambda_identity > 0 then
    add('identity_B')
  end

  return visuals
end