Files
disrupting-deepfakes/cyclegan/models/bigan_model.lua
T
Nataniel Ruiz a3fe19383a cyclegan
2019-12-25 17:21:03 -04:00

255 lines
8.8 KiB
Lua

local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
content = paths.dofile('../util/content_loss.lua')
BiGANModel = class('BiGANModel', 'BaseModel')

--- Constructor: all setup is delegated to BaseModel.
-- Networks, buffers and optimizer states are not created here; see
-- BiGANModel:Initialize(opt).
-- @param conf optional configuration table, forwarded unchanged to BaseModel.__init
function BiGANModel:__init(conf)
  BaseModel.__init(self, conf)
end
--- Name of this model class.
-- @treturn string the constant 'BiGANModel'
function BiGANModel:model_name()
  local name = 'BiGANModel'
  return name
end
--- Build a fresh optimizer state table (consumed by optim.adam in OptimizeParameters).
-- @param use_wgan unused; kept for interface compatibility
-- @param options optional options table providing `lr` and `beta1`; when
--   omitted, the global `opt` is read (the original behavior)
-- @treturn table new table {learningRate = lr, beta1 = beta1}
function BiGANModel:InitializeStates(use_wgan, options)
  local o = options or opt
  -- local: previously `optimState` leaked into (and was overwritten in) _G
  local optimState = {learningRate = o.lr, beta1 = o.beta1}
  return optimState
end
--- Define (or load) the networks, loss, label tensors and optimizer states.
-- @param opt options table. Fields read here: test, pool_size, input_nc,
--   output_nc, batchSize, fineSize, continue_train, ndf, which_model_netD,
--   n_layers_D, ngf, which_model_netG, arch (util.load_model /
--   util.load_test_model may read more).
function BiGANModel:Initialize(opt)
  -- Image pools are only created for training; fDx queries them.
  if opt.test == 0 then
    self.realABPool = ImagePool(opt.pool_size)
    self.fakeABPool = ImagePool(opt.pool_size)
  end

  -- D's input is A and B concatenated along dimension 2.
  local d_input_nc = opt.input_nc + opt.output_nc
  self.real_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
  self.fake_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)

  -- MSE against 0/1 label tensors (least-squares GAN objective).
  self.criterionGAN = nn.MSECriterion()

  local netG, netE, netD
  if opt.continue_train == 1 then
    if opt.test == 1 then -- which_epoch option exists in test mode
      netG = util.load_test_model('G', opt)
      netE = util.load_test_model('E', opt)
      netD = util.load_test_model('D', opt)
    else
      netG = util.load_model('G', opt)
      netE = util.load_model('E', opt)
      netD = util.load_model('D', opt)
    end
  else
    netD = defineD(d_input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false) -- no sigmoid layer
    print('netD...', netD)
    -- G maps A -> B; E maps B -> A, hence the swapped channel counts.
    netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netG...', netG)
    netE = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netE...', netE)
  end
  self.netD = netD
  self.netG = netG
  self.netE = netE

  -- Label tensors must match D's output shape, so probe D once with the
  -- freshly allocated input buffer. (Was an accidental global.)
  local netD_output_size = self.netD:forward(self.real_AB):size()
  self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
  self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing

  -- opt is passed explicitly; InitializeStates otherwise falls back to the global.
  self.optimStateD = self:InitializeStates(nil, opt)
  self.optimStateG = self:InitializeStates(nil, opt)
  self.optimStateE = self:InitializeStates(nil, opt)

  -- Index tables selecting the A channels and the B channels (dimension 2).
  self.A_idx = {{}, {1, opt.input_nc}, {}, {}}
  self.B_idx = {{}, {opt.input_nc+1, opt.input_nc+opt.output_nc}, {}, {}}

  self:RefreshParameters()
  print('---------- # Learnable Parameters --------------')
  print(('G = %d'):format(self.parametersG:size(1)))
  print(('E = %d'):format(self.parametersE:size(1)))
  print(('D = %d'):format(self.parametersD:size(1)))
  print('------------------------------------------------')
end
--- Run G and E on the current batch and store the results on `self`.
-- After this call:
--   real_AB = (real_A, G(real_A)), the pair D is trained to label real;
--   fake_AB = (E(real_B), real_B), the pair D is trained to label fake.
-- @param input table with fields real_A and real_B (not modified)
-- @param opt options table; `which_direction == 'BtoA'` swaps the roles of A and B
function BiGANModel:Forward(input, opt)
  local src, dst = input.real_A, input.real_B
  -- Swap locals rather than the caller's table. The previous in-place swap made
  -- repeated calls with the same table alternate direction.
  if opt.which_direction == 'BtoA' then
    src, dst = dst, src
  end

  self.real_AB[self.A_idx]:copy(src)
  self.fake_AB[self.B_idx]:copy(dst)
  -- real_A / real_B are views into the pair buffers.
  self.real_A = self.real_AB[self.A_idx]
  self.real_B = self.fake_AB[self.B_idx]

  self.fake_B = self.netG:forward(self.real_A):clone()
  self.fake_A = self.netE:forward(self.real_B):clone()
  self.real_AB[self.B_idx]:copy(self.fake_B) -- real_AB: real_A, fake_B -> real_label
  self.fake_AB[self.A_idx]:copy(self.fake_A) -- fake_AB: fake_A, real_B -> fake_label
  -- Image-pool history is applied in fDx only; fGx/fEx use these current pairs.
end
--- Evaluate the discriminator loss and accumulate its parameter gradients.
-- D is pushed toward real_label on `real_AB` and toward fake_label on `fake_AB`.
-- NOTE(review): the returned loss is the mean of the two terms, but gradients
-- are their sum (not halved).
-- @param x unused (required by the optim closure signature)
-- @param gradParams gradient tensor zeroed here and returned; presumably the
--   flattened gradParameters of netD (fDx passes self.gradParametersD)
-- @param netD discriminator
-- @param real_AB batch labelled real
-- @param fake_AB batch labelled fake
-- @param opt unused
-- @treturn number mean of the real and fake losses
-- @return gradParams
function BiGANModel:fDx_basic(x, gradParams, netD, real_AB, fake_AB, opt)
  util.BiasZero(netD)
  gradParams:zero()

  local criterion = self.criterionGAN
  -- One D forward/backward on `batch` against `target`; returns the loss.
  local function accumulate(batch, target)
    local prediction = netD:forward(batch):clone()
    local loss = criterion:forward(prediction, target)
    netD:backward(batch, criterion:backward(prediction, target))
    return loss
  end

  local loss_real = accumulate(real_AB, self.real_label)
  local loss_fake = accumulate(fake_AB, self.fake_label)
  return (loss_real + loss_fake) / 2.0, gradParams
end
--- optim closure body for the discriminator step.
-- Pairs are passed through the image pools first (which keep older pairs),
-- then handed to fDx_basic. Stores the loss in self.errD.
-- @param x parameter vector supplied by optim (unused)
-- @param opt options table forwarded to fDx_basic
-- @treturn number discriminator loss
-- @return gradient tensor of D
function BiGANModel:fDx(x, opt)
  -- locals: these three names previously leaked into _G
  local real_AB = self.realABPool:Query(self.real_AB)
  local fake_AB = self.fakeABPool:Query(self.fake_AB)
  local errD, gradParams = self:fDx_basic(x, self.gradParametersD, self.netD,
                                          real_AB, fake_AB, opt)
  self.errD = errD
  return errD, gradParams
end
--- Generator objective: update G so that D labels (real_A, G(real_A)) as fake.
-- D only propagates gradients to its input (updateGradInput); D's own
-- parameter gradients are not accumulated here.
-- @param x unused
-- @param netG generator (A -> B)
-- @param netD discriminator
-- @param gradParametersG G's gradient tensor, zeroed here and returned
-- @param opt unused
-- @return gradParametersG
-- @treturn number generator loss
function BiGANModel:fGx_basic(x, netG, netD, gradParametersG, opt)
  util.BiasZero(netG)
  util.BiasZero(netD)
  gradParametersG:zero()

  local crit, target = self.criterionGAN, self.fake_label
  local pair = self.real_AB -- (real_A, G(real_A))
  local d_out = netD:forward(pair):clone()
  local loss = crit:forward(d_out, target)
  local grad_d_in = netD:updateGradInput(pair, crit:backward(d_out, target))
  -- Only the B channels of D's input were produced by G.
  netG:backward(self.real_A, grad_d_in[self.B_idx])
  return gradParametersG, loss
end
--- optim closure body for the generator step; records self.errG.
-- @param x parameter vector supplied by optim (unused)
-- @param opt options table forwarded to fGx_basic
-- @treturn number generator loss
-- @return G's gradient tensor
function BiGANModel:fGx(x, opt)
  local grads, err = self:fGx_basic(x, self.netG, self.netD, self.gradParametersG, opt)
  self.gradParametersG = grads
  self.errG = err
  return err, grads
end
--- Encoder objective: update E so that D labels (E(real_B), real_B) as real.
-- D only propagates gradients to its input (updateGradInput); D's own
-- parameter gradients are not accumulated here.
-- @param x unused
-- @param netE encoder (B -> A)
-- @param netD discriminator
-- @param gradParametersE E's gradient tensor, zeroed here and returned
-- @param opt unused
-- @return gradParametersE
-- @treturn number encoder loss
function BiGANModel:fEx_basic(x, netE, netD, gradParametersE, opt)
  util.BiasZero(netE)
  util.BiasZero(netD)
  gradParametersE:zero()

  local crit, target = self.criterionGAN, self.real_label
  local pair = self.fake_AB -- (E(real_B), real_B)
  local d_out = netD:forward(pair):clone()
  local loss = crit:forward(d_out, target)
  local grad_d_in = netD:updateGradInput(pair, crit:backward(d_out, target))
  -- Only the A channels of D's input were produced by E.
  netE:backward(self.real_B, grad_d_in[self.A_idx])
  return gradParametersE, loss
end
--- optim closure body for the encoder step; records self.errE.
-- @param x parameter vector supplied by optim (unused)
-- @param opt options table forwarded to fEx_basic
-- @treturn number encoder loss
-- @return E's gradient tensor
function BiGANModel:fEx(x, opt)
  local grads, err = self:fEx_basic(x, self.netE, self.netD, self.gradParametersE, opt)
  self.gradParametersE = grads
  self.errE = err
  return err, grads
end
--- One Adam step each for D, then G, then E.
-- G and E are evaluated against the D that was just updated.
-- @param opt options table forwarded to the fDx/fGx/fEx closures
function BiGANModel:OptimizeParameters(opt)
  local schedule = {
    {'fDx', self.parametersD, self.optimStateD},
    {'fGx', self.parametersG, self.optimStateG},
    {'fEx', self.parametersE, self.optimStateE},
  }
  for _, step in ipairs(schedule) do
    local method = step[1]
    optim.adam(function(x) return self[method](self, x, opt) end, step[2], step[3])
  end
end
--- (Re)build the flattened parameter / gradient views of D, G and E.
-- Old references are dropped first so they are not held alongside the new
-- flattened copies (avoids a memory spike).
function BiGANModel:RefreshParameters()
  for _, tag in ipairs({'D', 'G', 'E'}) do
    self['parameters' .. tag], self['gradParameters' .. tag] = nil, nil
  end

  self.parametersD, self.gradParametersD = self.netD:getParameters()
  self.parametersG, self.gradParametersG = self.netG:getParameters()
  self.parametersE, self.gradParametersE = self.netE:getParameters()
end
--- Save G, E and D to `<prefix>_net_{G,E,D}.t7` via util.save_model.
-- @param prefix path prefix for the three files
-- @param opt unused
function BiGANModel:Save(prefix, opt)
  local targets = {
    {self.netG, '_net_G.t7'},
    {self.netE, '_net_E.t7'},
    {self.netD, '_net_D.t7'},
  }
  for _, target in ipairs(targets) do
    util.save_model(target[1], prefix .. target[2], 1)
  end
end
--- One-line summary of the latest losses; -1 stands in for losses not yet computed.
-- @treturn string e.g. "D: 0.2500 G: 0.5000 E: 0.4000"
function BiGANModel:GetCurrentErrorDescription()
  -- local: `description` previously leaked into _G
  local description = ('D: %.4f G: %.4f E: %.4f'):format(
    self.errD or -1,
    self.errG or -1,
    self.errE or -1)
  return description
end
--- Latest losses keyed by name (a field is nil if that loss was never computed).
-- @treturn table {errD=..., errG=..., errE=...}
function BiGANModel:GetCurrentErrors()
  return {
    errD = self.errD,
    errG = self.errG,
    errE = self.errE,
  }
end
--- Comma-separated names of the loss series to plot.
-- @param opt unused
-- @treturn string 'errD,errG,errE'
function BiGANModel:DisplayPlot(opt)
  local series = 'errD,errG,errE'
  return series
end
--- Linearly decay the shared learning rate by opt.lr / opt.niter_decay.
-- D's current rate is the reference. The new value is written to all three
-- optimizer states, which keeps D, G and E in lock-step.
-- @param opt options table with fields lr and niter_decay
function BiGANModel:UpdateLearningRate(opt)
  -- With niter_decay <= 0 the step is inf or negative, which would send the
  -- rate to -inf or make it grow; leave the rate untouched instead.
  if opt.niter_decay <= 0 then
    return
  end
  local lrd = opt.lr / opt.niter_decay
  local old_lr = self.optimStateD['learningRate']
  -- Clamp at zero: repeated float subtraction can undershoot slightly below 0.
  local lr = math.max(old_lr - lrd, 0)
  self.optimStateD['learningRate'] = lr
  self.optimStateG['learningRate'] = lr
  self.optimStateE['learningRate'] = lr
  print(('update learning rate: %f -> %f'):format(old_lr, lr))
end
--- Expand a one-channel image batch to three channels for display.
-- Assumes dimension 2 is the channel axis (e.g. N x C x H x W) — TODO confirm.
-- Inputs whose dimension 2 is not 1 are returned unchanged (same object).
-- @param im image tensor
-- @return a three-channel tensor, or `im` itself
local function MakeIm3(im)
  if im:size(2) ~= 1 then
    return im
  end
  return torch.repeatTensor(im, 1, 3, 1, 1)
end
--- Images from the most recent Forward, as {img=..., label=...} entries.
-- Order: real_A, fake_B, real_B, fake_A. Each image goes through MakeIm3.
-- @param opt options table; display_winsize is read when `size` is not given
-- @param size optional display size
-- @treturn table list of visual entries
function BiGANModel:GetCurrentVisuals(opt, size)
  -- NOTE(review): `size` is defaulted but never used below.
  size = size or opt.display_winsize
  local visuals = {}
  for _, name in ipairs({'real_A', 'fake_B', 'real_B', 'fake_A'}) do
    visuals[#visuals + 1] = {img = MakeIm3(self[name]), label = name}
  end
  return visuals
end