-- Source: disrupting-deepfakes/cyclegan/models/content_gan_model.lua
-- Commit a3fe19383a ("cyclegan") by Nataniel Ruiz, 2019-12-25 17:21:03 -04:00
-- Original listing: 202 lines, 7.2 KiB, Lua

local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
content = paths.dofile('../util/content_loss.lua')
-- Declare ContentGANModel as a subclass of BaseModel. Kept global (not `local`)
-- as in the original; other files may look the class up by name -- confirm
-- before localizing it.
ContentGANModel = class('ContentGANModel', 'BaseModel')

--- Constructor: delegates all setup to BaseModel.
-- Networks, criteria and buffers are created later in Initialize(opt), not here.
-- @param conf optional configuration table, forwarded unchanged to BaseModel.__init
function ContentGANModel:__init(conf)
  BaseModel.__init(self, conf)
end
--- Identifier string for this model type.
-- @treturn string the constant 'ContentGANModel'
function ContentGANModel:model_name()
  local name = 'ContentGANModel'
  return name
end
--- Build a fresh optimizer state table (consumed by optim.adam in OptimizeParameters).
-- The options table is accepted explicitly so the state is built from the same
-- options the caller is working with, instead of silently reading a global.
-- Callers that pass nothing keep the old behaviour of using the global `opt`.
-- @param opt optional options table providing `lr` and `beta1`
-- @treturn table {learningRate = opt.lr, beta1 = opt.beta1}
function ContentGANModel:InitializeStates(opt)
  opt = opt or _G.opt
  assert(opt, 'InitializeStates: no options table given and global `opt` is unset')
  return {learningRate = opt.lr, beta1 = opt.beta1}
end
-- Defines models and networks
--- Define networks, loss criteria, label tensors and optimizer states.
-- @param opt options table. Fields read here: test, pool_size, batchSize,
--   input_nc, output_nc, fineSize, content_loss, layer_name, continue_train,
--   which_epoch, ngf, ndf, which_model_netG, which_model_netD, n_layers_D,
--   plus lr / beta1 through InitializeStates.
function ContentGANModel:Initialize(opt)
  -- The fake-image history pool is only needed when training D.
  if opt.test == 0 then
    self.fakePool = ImagePool(opt.pool_size)
  end

  -- Preallocated batch buffers; Forward() copies each batch into them.
  self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
  self.fake_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
  self.real_B = self.fake_B:clone()

  -- Losses: MSE (least-squares) GAN criterion, L1 (AbsCriterion) content criterion.
  self.criterionGAN = nn.MSECriterion()
  self.criterionContent = nn.AbsCriterion()
  self.contentFunc = content.defineContent(opt.content_loss, opt.layer_name)

  -- Networks: restore from checkpoints or build from scratch.
  self.netG, self.netD = nil, nil
  if opt.continue_train == 1 then
    if opt.which_epoch then -- which_epoch option exists in test mode
      self.netG = util.load_test_model('G_A', opt)
      self.netD = util.load_test_model('D_A', opt)
    else
      self.netG = util.load_model('G_A', opt)
      self.netD = util.load_model('D_A', opt)
    end
  else
    self.netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG)
    print('netG...', self.netG)
    self.netD = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false)
    print('netD...', self.netD)
  end

  -- Label tensors must match D's output shape. D is built for output_nc
  -- channels and is only ever fed B-domain images (real_B / fake_B), so probe
  -- it with a B-domain buffer; probing with real_A fails if input_nc ~= output_nc.
  local netD_output_size = self.netD:forward(self.real_B):size()
  self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
  self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing

  self.optimStateD = self:InitializeStates(opt)
  self.optimStateG = self:InitializeStates(opt)
  self:RefreshParameters()
  print('---------- # Learnable Parameters --------------')
  print(('G = %d'):format(self.parametersG:size(1)))
  print(('D = %d'):format(self.parametersD:size(1)))
  print('------------------------------------------------')
end
-- Runs the forward pass of the network and
-- saves the result to member variables of the class
--- Run the generator on a batch and cache the result on the model.
-- Copies the batch into the preallocated real_A / real_B buffers (swapping the
-- two domains when opt.which_direction == 'BtoA') and stores G(real_A) in
-- self.fake_B.
-- @param input table with fields real_A and real_B (tensors); not modified
-- @param opt options table; only which_direction is read
-- @treturn table an empty table
function ContentGANModel:Forward(input, opt)
  -- Swap through locals so the caller's input table is left untouched.
  local source, target = input.real_A, input.real_B
  if opt.which_direction == 'BtoA' then
    source, target = target, source
  end
  self.real_A:copy(source)
  self.real_B:copy(target)
  -- clone(): netG:forward returns netG's output buffer, which a later forward
  -- (e.g. the one in fGx_basic) would overwrite.
  self.fake_B = self.netG:forward(self.real_A):clone()
  local output = {}
  return output
end
-- create closure to evaluate f(X) and df/dX of discriminator
--- Discriminator loss and parameter gradients for one step.
-- Zeroes gradParams, then accumulates netD's gradients for a real batch
-- (target self.real_label) and a fake batch (target self.fake_label) using
-- self.criterionGAN.
-- @param x flattened D parameters supplied by optim (unused here)
-- @param gradParams flattened gradient tensor of netD; zeroed and then filled
-- @param netD discriminator
-- @param netG generator (only passed to util.BiasZero here)
-- @param real_target batch of real target-domain images
-- @param fake_target batch of generated target-domain images
-- @param opt options table (unused here)
-- @return errD (mean of the real and fake losses), gradParams
function ContentGANModel:fDx_basic(x, gradParams, netD, netG,
                                   real_target, fake_target, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  gradParams:zero()

  -- Real batch: push D(real_target) towards real_label.
  local output = netD:forward(real_target)
  local errD_real = self.criterionGAN:forward(output, self.real_label)
  local df_do = self.criterionGAN:backward(output, self.real_label)
  netD:backward(real_target, df_do)

  -- Fake batch: push D(fake_target) towards fake_label.
  output = netD:forward(fake_target)
  local errD_fake = self.criterionGAN:forward(output, self.fake_label)
  df_do = self.criterionGAN:backward(output, self.fake_label)
  netD:backward(fake_target, df_do)

  local errD = (errD_real + errD_fake) / 2.0
  return errD, gradParams
end
--- optim-compatible objective for the discriminator.
-- Takes the fake batch from self.fakePool (queried with the current
-- self.fake_B), computes D's loss/gradients via fDx_basic against self.real_B,
-- and caches the loss in self.errD.
-- @param x flattened D parameters supplied by optim
-- @param opt options table, forwarded to fDx_basic
-- @return errD, gradient tensor of D's parameters
function ContentGANModel:fDx(x, opt)
  local fake_B = self.fakePool:Query(self.fake_B)
  local gradParams
  self.errD, gradParams = self:fDx_basic(x, self.gradparametersD, self.netD, self.netG,
                                         self.real_B, fake_B, opt)
  return self.errD, gradParams
end
--- Generator loss and parameter gradients for one step.
-- Combines a GAN loss (D should score fake_target as real) with a content loss
-- between real_source and fake_target, and backpropagates both through G.
-- @param x flattened G parameters supplied by optim (unused here)
-- @param netG_source generator
-- @param netD_source discriminator (only its input gradient is used)
-- @param real_source source-domain input batch fed to netG_source
-- @param real_target target-domain batch (unused here)
-- @param fake_target generator output for real_source
-- @param gradParametersG_source flattened gradient tensor of netG_source
-- @param opt options table; reads content_loss and lambda_A
-- @return gradParametersG_source, errGAN, errContent
function ContentGANModel:fGx_basic(x, netG_source, netD_source, real_source, real_target, fake_target,
                                   gradParametersG_source, opt)
  util.BiasZero(netD_source)
  util.BiasZero(netG_source)
  gradParametersG_source:zero()

  -- GAN loss: D(fake_target) should match real_label.
  -- D is run on fake_target here instead of reusing netD_source.output. In
  -- OptimizeParameters the last D forward (in fDx) was on a batch drawn from
  -- the fake-image pool, which may be an older image, and used D's weights
  -- from before the Adam update. So the cached output and activations did not
  -- correspond to fake_target.
  local output = netD_source:forward(fake_target)
  local errGAN = self.criterionGAN:forward(output, self.real_label)
  local df_do = self.criterionGAN:backward(output, self.real_label)
  local df_d_GAN = netD_source:updateGradInput(fake_target, df_do)

  -- Content loss between the source and generated images (weight: opt.lambda_A).
  local errContent, df_d_content = content.lossUpdate(self.criterionContent, real_source, fake_target,
                                                      self.contentFunc, opt.content_loss, opt.lambda_A)

  -- Re-run G on real_source so its internal state matches before backprop.
  netG_source:forward(real_source)
  netG_source:backward(real_source, df_d_GAN + df_d_content)
  return gradParametersG_source, errGAN, errContent
end
--- optim-compatible objective for the generator.
-- Delegates to fGx_basic with this model's networks and buffers, then caches
-- the gradient tensor, GAN loss and content loss on the model.
-- @param x flattened G parameters supplied by optim
-- @param opt options table, forwarded to fGx_basic
-- @return errG, gradient tensor of G's parameters
function ContentGANModel:fGx(x, opt)
  local grads, errGAN, errContent = self:fGx_basic(x, self.netG, self.netD,
                                                   self.real_A, self.real_B, self.fake_B,
                                                   self.gradparametersG, opt)
  self.gradparametersG = grads
  self.errG = errGAN
  self.errCont = errContent
  return errGAN, grads
end
--- One training iteration: an Adam step on D, followed by an Adam step on G.
-- D is stepped first, so the G step runs against the freshly updated D.
-- @param opt options table forwarded to fDx / fGx
function ContentGANModel:OptimizeParameters(opt)
  optim.adam(function(params) return self:fDx(params, opt) end,
             self.parametersD, self.optimStateD)
  optim.adam(function(params) return self:fGx(params, opt) end,
             self.parametersG, self.optimStateG)
end
--- (Re)create the flattened parameter / gradient tensors of G and D.
-- The previous references are cleared first (original intent: avoid a memory
-- spike) before getParameters() is called on each network.
function ContentGANModel:RefreshParameters()
  -- Drop the old flattened views before re-flattening.
  self.parametersD = nil
  self.gradparametersD = nil
  self.parametersG = nil
  self.gradparametersG = nil

  -- G is flattened first, then D.
  local paramsG, gradsG = self.netG:getParameters()
  self.parametersG, self.gradparametersG = paramsG, gradsG
  local paramsD, gradsD = self.netD:getParameters()
  self.parametersD, self.gradparametersD = paramsD, gradsD
end
--- Write the generator and discriminator checkpoints.
-- Files are named prefix .. '_net_G_A.t7' and prefix .. '_net_D_A.t7'.
-- @param prefix path prefix for the checkpoint files
-- @param opt options table (unused here)
function ContentGANModel:Save(prefix, opt)
  local checkpoints = {
    {self.netG, '_net_G_A.t7'},
    {self.netD, '_net_D_A.t7'},
  }
  for _, entry in ipairs(checkpoints) do
    util.save_model(entry[1], prefix .. entry[2], 1.0)
  end
end
--- One-line summary of the most recent losses.
-- A loss that has not been computed yet (nil) is shown as -1.
-- @treturn string formatted as 'G: %.4f D: %.4f Content: %.4f'
function ContentGANModel:GetCurrentErrorDescription()
  local description = ('G: %.4f D: %.4f Content: %.4f'):format(
    self.errG or -1,
    self.errD or -1,
    self.errCont or -1)
  return description
end
--- Most recent losses keyed by name (errG, errD, errCont).
-- A loss that has not been computed yet is reported as -1.
-- @treturn table
function ContentGANModel:GetCurrentErrors()
  return {
    errG = self.errG or -1,
    errD = self.errD or -1,
    errCont = self.errCont or -1,
  }
end
-- returns a string that describes the display plot configuration
--- Comma-separated names of the losses to plot.
-- The names match the keys returned by GetCurrentErrors.
-- @param opt options table (unused here)
-- @treturn string
function ContentGANModel:DisplayPlot(opt)
  local plotted = 'errG,errD,errCont'
  return plotted
end
--- Images to display: real_A, fake_B and real_B, in that order.
-- @param opt options table; display_winsize is read only when size is not given
-- @param size optional display size (resolved but not otherwise used here)
-- @treturn table list of {img = tensor, label = string}
function ContentGANModel:GetCurrentVisuals(opt, size)
  size = size or opt.display_winsize
  return {
    {img = self.real_A, label = 'real_A'},
    {img = self.fake_B, label = 'fake_B'},
    {img = self.real_B, label = 'real_B'},
  }
end