255 lines
8.8 KiB
Lua
255 lines
8.8 KiB
Lua
local class = require 'class'
|
|
require 'models.base_model'
|
|
require 'models.architectures'
|
|
require 'util.image_pool'
|
|
util = paths.dofile('../util/util.lua')
|
|
content = paths.dofile('../util/content_loss.lua')
|
|
|
|
-- Declares the global BiGANModel class through the `class` library,
-- naming 'BaseModel' as its parent class.
BiGANModel = class('BiGANModel', 'BaseModel')
|
|
|
|
-- Constructor: forwards `conf` unchanged to the BaseModel constructor.
-- conf: optional configuration value handed to the parent class.
function BiGANModel:__init(conf)
  BaseModel.__init(self, conf)
  -- The fallback below runs after the parent call, so it only rebinds
  -- the local parameter and has no effect on the parent constructor.
  if not conf then
    conf = {}
  end
end
|
|
|
|
-- Returns the fixed identifier string of this model type.
function BiGANModel:model_name()
  local name = 'BiGANModel'
  return name
end
|
|
|
|
-- Builds a fresh optimizer state table.
-- use_wgan: accepted for interface compatibility; not used here.
-- Returns a new table {learningRate, beta1} on every call.
-- NOTE: reads the learning rate and beta1 from the global `opt` table,
-- not from an argument, so `opt` must be defined before this is called.
function BiGANModel:InitializeStates(use_wgan)
  -- Kept local: the state must not leak into (or be shared through) the
  -- global environment, since each network gets its own table.
  local optimState = {learningRate = opt.lr, beta1 = opt.beta1}
  return optimState
end
|
|
-- Defines models and networks.
-- Sets up the image pools (outside test mode), the stacked A/B input
-- buffers, the GAN criterion, the three networks (G, E, D), the
-- real/fake label tensors, one optimizer state per network, the channel
-- index tables, and the flattened parameter vectors.
-- opt: options table; the fields read here are test, pool_size, input_nc,
--      output_nc, batchSize, fineSize, continue_train, ndf, ngf,
--      which_model_netD, which_model_netG, n_layers_D and arch.
function BiGANModel:Initialize(opt)
  -- Image pools are only created when not in test mode.
  if opt.test == 0 then
    self.realABPool = ImagePool(opt.pool_size)
    self.fakeABPool = ImagePool(opt.pool_size)
  end

  -- define tensors: the discriminator input stacks the A and B channels
  local d_input_nc = opt.input_nc + opt.output_nc
  self.real_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
  self.fake_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)

  -- load/define models
  self.criterionGAN = nn.MSECriterion()

  local netG, netE, netD = nil, nil, nil
  if opt.continue_train == 1 then
    if opt.test == 1 then -- which_epoch option exists in test mode
      netG = util.load_test_model('G', opt)
      netE = util.load_test_model('E', opt)
      netD = util.load_test_model('D', opt)
    else
      netG = util.load_model('G', opt)
      netE = util.load_model('E', opt)
      netD = util.load_model('D', opt)
    end
  else
    netD = defineD(d_input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false) -- no sigmoid layer
    print('netD...', netD)
    -- G maps input_nc -> output_nc channels
    netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netG...', netG)
    -- E uses the same architecture with the channel counts reversed
    netE = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netE...', netE)
  end

  self.netD = netD
  self.netG = netG
  self.netE = netE

  -- define real/fake labels, sized from one discriminator forward pass
  -- (local: this size is only needed here and must not become a global)
  local netD_output_size = self.netD:forward(self.real_AB):size()
  self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
  self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing

  -- one independent optimizer state per network
  self.optimStateD = self:InitializeStates()
  self.optimStateG = self:InitializeStates()
  self.optimStateE = self:InitializeStates()

  -- index tables selecting the A channels / B channels of a stacked AB tensor
  self.A_idx = {{}, {1, opt.input_nc}, {}, {}}
  self.B_idx = {{}, {opt.input_nc+1, opt.input_nc+opt.output_nc}, {}, {}}
  self:RefreshParameters()

  print('---------- # Learnable Parameters --------------')
  print(('G = %d'):format(self.parametersG:size(1)))
  print(('E = %d'):format(self.parametersE:size(1)))
  print(('D = %d'):format(self.parametersD:size(1)))
  print('------------------------------------------------')
end
|
|
|
|
-- Runs the forward pass of the network and
-- saves the result to member variables of the class.
-- input: table with fields real_A and real_B.
-- opt:   options table; when opt.which_direction == 'BtoA' the roles of
--        real_A and real_B are exchanged.
-- After the call:
--   self.real_AB holds real_A in the A channels and fake_B = netG(real_A)
--                in the B channels;
--   self.fake_AB holds fake_A = netE(real_B) in the A channels and real_B
--                in the B channels;
--   self.real_A / self.real_B are selections of those buffers and
--   self.fake_A / self.fake_B are cloned network outputs.
function BiGANModel:Forward(input, opt)
  -- Swap through locals instead of rewriting the caller's table, so that
  -- passing the same `input` table twice does not flip the direction back.
  local real_A, real_B = input.real_A, input.real_B
  if opt.which_direction == 'BtoA' then
    real_A, real_B = real_B, real_A
  end

  self.real_AB[self.A_idx]:copy(real_A)
  self.fake_AB[self.B_idx]:copy(real_B)
  self.real_A = self.real_AB[self.A_idx]
  self.real_B = self.fake_AB[self.B_idx]

  -- clone the outputs so a later forward call on the nets cannot overwrite them
  self.fake_B = self.netG:forward(self.real_A):clone()
  self.fake_A = self.netE:forward(self.real_B):clone()

  self.real_AB[self.B_idx]:copy(self.fake_B) -- real_AB: real_A, fake_B -> real_label
  self.fake_AB[self.A_idx]:copy(self.fake_A) -- fake_AB: fake_A, real_B -> fake_label
  -- The image pools are not queried here; fDx queries them itself.
end
|
|
|
|
-- Evaluates the discriminator loss and its parameter gradients.
-- x:          unused here (kept for the optimizer-closure signature).
-- gradParams: gradient tensor that is zeroed first and returned.
-- netD:       discriminator network.
-- real_AB:    batch scored against self.real_label.
-- fake_AB:    batch scored against self.fake_label.
-- opt:        unused here.
-- Returns the mean of the two criterion values, and gradParams.
function BiGANModel:fDx_basic(x, gradParams, netD, real_AB, fake_AB, opt)
  util.BiasZero(netD)
  gradParams:zero()

  local criterion = self.criterionGAN

  -- pass 1: real_AB against the real label
  local pred_real = netD:forward(real_AB):clone()
  local loss_real = criterion:forward(pred_real, self.real_label)
  netD:backward(real_AB, criterion:backward(pred_real, self.real_label))

  -- pass 2: fake_AB against the fake label
  local pred_fake = netD:forward(fake_AB):clone()
  local loss_fake = criterion:forward(pred_fake, self.fake_label)
  netD:backward(fake_AB, criterion:backward(pred_fake, self.fake_label))

  -- average of the two losses
  return (loss_real + loss_fake) / 2.0, gradParams
end
|
|
|
|
|
|
-- Discriminator objective in the form expected by the optimizer.
-- Queries both image pools with the current real_AB / fake_AB buffers,
-- evaluates fDx_basic on the returned batches and stores the loss in
-- self.errD.
-- x:   forwarded to fDx_basic.
-- opt: forwarded to fDx_basic.
-- Returns the discriminator loss and the discriminator gradient tensor.
function BiGANModel:fDx(x, opt)
  -- use image pool that stores the old fake images
  -- (locals: these batches must not leak into the global environment)
  local real_AB = self.realABPool:Query(self.real_AB)
  local fake_AB = self.fakeABPool:Query(self.fake_AB)
  local errD, gradParams = self:fDx_basic(x, self.gradParametersD, self.netD,
                                          real_AB, fake_AB, opt)
  self.errD = errD
  return errD, gradParams
end
|
|
|
|
|
|
|
|
-- Evaluates the generator loss and accumulates the generator gradients.
-- The discriminator scores self.real_AB (real_A stacked with fake_B)
-- against self.fake_label, i.e. the opposite of the label fDx_basic uses
-- for that pair, and the resulting gradient on the B channels is pushed
-- back through netG.
-- x, opt: unused here.
-- Returns gradParametersG first, then the loss value.
function BiGANModel:fGx_basic(x, netG, netD, gradParametersG, opt)
  util.BiasZero(netG)
  util.BiasZero(netD)
  gradParametersG:zero()

  local criterion, target = self.criterionGAN, self.fake_label
  local pair = self.real_AB

  local pred = netD:forward(pair):clone()
  local errG = criterion:forward(pred, target)
  local grad_pred = criterion:backward(pred, target)
  -- updateGradInput (not backward) is called on netD here; only the
  -- gradient with respect to the stacked input is needed.
  local grad_pair = netD:updateGradInput(pair, grad_pred)
  -- only the B channels of the pair were produced by netG
  netG:backward(self.real_A, grad_pair[self.B_idx])

  return gradParametersG, errG
end
|
|
|
|
|
|
-- Generator objective in the form expected by the optimizer.
-- Stores the results of fGx_basic in self.gradParametersG and self.errG,
-- then returns them as (loss, gradient).
function BiGANModel:fGx(x, opt)
  local grads, err = self:fGx_basic(x, self.netG, self.netD, self.gradParametersG, opt)
  self.gradParametersG = grads
  self.errG = err
  return self.errG, self.gradParametersG
end
|
|
|
|
|
|
-- Evaluates the encoder loss and accumulates the encoder gradients.
-- The discriminator scores self.fake_AB (fake_A stacked with real_B)
-- against self.real_label, i.e. the opposite of the label fDx_basic uses
-- for that pair, and the resulting gradient on the A channels is pushed
-- back through netE.
-- x, opt: unused here.
-- Returns gradParametersE first, then the loss value.
function BiGANModel:fEx_basic(x, netE, netD, gradParametersE, opt)
  util.BiasZero(netE)
  util.BiasZero(netD)
  gradParametersE:zero()

  local criterion, target = self.criterionGAN, self.real_label
  local pair = self.fake_AB

  local pred = netD:forward(pair):clone()
  local errE = criterion:forward(pred, target)
  local grad_pred = criterion:backward(pred, target)
  -- updateGradInput (not backward) is called on netD here; only the
  -- gradient with respect to the stacked input is needed.
  local grad_pair = netD:updateGradInput(pair, grad_pred)
  -- only the A channels of the pair were produced by netE
  netE:backward(self.real_B, grad_pair[self.A_idx])

  return gradParametersE, errE
end
|
|
|
|
|
|
-- Encoder objective in the form expected by the optimizer.
-- Stores the results of fEx_basic in self.gradParametersE and self.errE,
-- then returns them as (loss, gradient).
function BiGANModel:fEx(x, opt)
  local grads, err = self:fEx_basic(x, self.netE, self.netD, self.gradParametersE, opt)
  self.gradParametersE = grads
  self.errE = err
  return self.errE, self.gradParametersE
end
|
|
|
|
|
|
-- Performs one Adam step for each network, in the order D, G, E.
-- opt: options table captured by the three objective closures.
function BiGANModel:OptimizeParameters(opt)
  -- objective closures binding `self` and `opt`
  local function fG(x) return self:fGx(x, opt) end
  local function fE(x) return self:fEx(x, opt) end
  local function fD(x) return self:fDx(x, opt) end

  -- each network has its own parameter vector and optimizer state
  optim.adam(fD, self.parametersD, self.optimStateD)
  optim.adam(fG, self.parametersG, self.optimStateG)
  optim.adam(fE, self.parametersE, self.optimStateE)
end
|
|
|
|
-- Re-fetches the parameter and gradient tensors of D, G and E via
-- getParameters() and stores them on the model.
function BiGANModel:RefreshParameters()
  -- drop every old reference first to avoid spiking memory
  self.parametersD = nil
  self.gradParametersD = nil
  self.parametersG = nil
  self.gradParametersG = nil
  self.parametersE = nil
  self.gradParametersE = nil

  -- define parameters of optimization
  local paramsD, gradsD = self.netD:getParameters()
  self.parametersD, self.gradParametersD = paramsD, gradsD

  local paramsG, gradsG = self.netG:getParameters()
  self.parametersG, self.gradParametersG = paramsG, gradsG

  local paramsE, gradsE = self.netE:getParameters()
  self.parametersE, self.gradParametersE = paramsE, gradsE
end
|
|
|
|
-- Saves the three networks through util.save_model, in the order G, E, D.
-- prefix: string prepended to each file name suffix.
-- opt:    unused here.
function BiGANModel:Save(prefix, opt)
  local targets = {
    {net = self.netG, suffix = '_net_G.t7'},
    {net = self.netE, suffix = '_net_E.t7'},
    {net = self.netD, suffix = '_net_D.t7'},
  }
  for i = 1, #targets do
    local target = targets[i]
    util.save_model(target.net, prefix .. target.suffix, 1)
  end
end
|
|
|
|
-- Formats the most recent D, G and E losses as a single line.
-- A loss that has not been computed yet (nil) is shown as -1.
-- Returns the formatted string.
function BiGANModel:GetCurrentErrorDescription()
  -- local: the formatted text must not be written to a global variable
  local description = ('D: %.4f G: %.4f E: %.4f'):format(
    self.errD or -1,
    self.errG or -1,
    self.errE or -1)
  return description
end
|
|
|
|
-- Returns a new table holding the latest losses under the keys
-- errD, errG and errE (a field is nil if that loss was never computed).
function BiGANModel:GetCurrentErrors()
  return {
    errD = self.errD,
    errG = self.errG,
    errE = self.errE,
  }
end
|
|
|
|
-- Returns the display plot configuration: a comma-separated list of the
-- error names to plot.
-- opt: unused here.
function BiGANModel:DisplayPlot(opt)
  local plot_config = 'errD,errG,errE'
  return plot_config
end
|
|
-- Lowers the learning rate by opt.lr / opt.niter_decay.
-- The current rate is read from the discriminator's optimizer state and
-- the new rate is written to all three optimizer states.
-- opt: options table providing lr and niter_decay.
function BiGANModel:UpdateLearningRate(opt)
  local step = opt.lr / opt.niter_decay
  local previous = self.optimStateD.learningRate
  local updated = previous - step

  self.optimStateD.learningRate = updated
  self.optimStateG.learningRate = updated
  self.optimStateE.learningRate = updated

  print(('update learning rate: %f -> %f'):format(previous, updated))
end
|
|
|
|
-- Returns `im` unchanged unless its second dimension has size 1, in which
-- case that dimension is repeated three times via torch.repeatTensor.
-- im: assumed to be a 4-D image batch, given the (1, 3, 1, 1) repeat -- TODO confirm
local function MakeIm3(im)
  if im:size(2) ~= 1 then
    return im
  end
  return torch.repeatTensor(im, 1, 3, 1, 1)
end
|
|
-- Collects the images of the latest forward pass for display.
-- opt:  options table; display_winsize is read when `size` is not given.
-- size: optional; defaulted below but not otherwise used in this function.
-- Returns an array of {img=..., label=...} entries in the order
-- real_A, fake_B, real_B, fake_A, each image passed through MakeIm3.
function BiGANModel:GetCurrentVisuals(opt, size)
  size = size or opt.display_winsize

  return {
    {img = MakeIm3(self.real_A), label = 'real_A'},
    {img = MakeIm3(self.fake_B), label = 'fake_B'},
    {img = MakeIm3(self.real_B), label = 'real_B'},
    {img = MakeIm3(self.fake_A), label = 'fake_A'},
  }
end
|