385 lines
18 KiB
Lua
385 lines
18 KiB
Lua
require 'nngraph'
|
|
|
|
|
|
----------------------------------------------------------------------------
|
|
--- Initialize a module's parameters in place; meant for `net:apply(weights_init)`.
-- Modules whose type name contains 'Convolution' get weights drawn from
-- N(0, 0.02) and a zero bias. Modules whose type name contains
-- 'Normalization' get weights drawn from N(1, 0.02) and a zero bias.
-- Any other module is left untouched.
-- @param m an nn module
local function weights_init(m)
  local name = torch.type(m)
  if name:find('Convolution') then
    m.weight:normal(0.0, 0.02)
    -- A convolution built with :noBias() has no bias tensor; skip it instead
    -- of indexing nil.
    if m.bias then m.bias:fill(0) end
  elseif name:find('Normalization') then
    -- The normalization module may carry no weight/bias, so guard both.
    if m.weight then m.weight:normal(1.0, 0.02) end
    if m.bias then m.bias:fill(0) end
  end
end
|
|
|
|
|
|
-- Global normalization-layer constructor used by the builders in this file.
-- It stays nil until set_normalization() is called.
normalization = nil

--- Select the normalization layer constructor used by the network builders.
-- 'instance' loads util.InstanceNormalization and uses nn.InstanceNormalization;
-- 'batch' uses nn.SpatialBatchNormalization. The choice is stored in the
-- global `normalization`.
-- @param norm 'instance' or 'batch'
-- @raise error for any other value, so a typo fails here rather than later
--   as "attempt to call a nil value" inside a builder.
function set_normalization(norm)
  if norm == 'instance' then
    require 'util.InstanceNormalization'
    print('use InstanceNormalization')
    normalization = nn.InstanceNormalization
  elseif norm == 'batch' then
    print('use SpatialBatchNormalization')
    normalization = nn.SpatialBatchNormalization
  else
    error('unsupported normalization: ' .. tostring(norm), 2)
  end
end
|
|
|
|
--- Construct a generator network by name and initialize its weights.
-- @param input_nc number of input channels
-- @param output_nc number of output channels
-- @param ngf base number of generator filters
-- @param which_model_netG 'encoder_decoder', 'unet128', 'unet256',
--   'resnet_6blocks' or 'resnet_9blocks'
-- @param nz not used by this function
-- @param arch not used by this function
-- @return the generator, after `weights_init` has been applied to it
function defineG(input_nc, output_nc, ngf, which_model_netG, nz, arch)
  -- Built on every call so the builder globals are resolved at call time,
  -- after the rest of the file has defined them.
  local builders = {
    encoder_decoder = defineG_encoder_decoder,
    unet128 = defineG_unet128,
    unet256 = defineG_unet256,
    resnet_6blocks = defineG_resnet_6blocks,
    resnet_9blocks = defineG_resnet_9blocks,
  }

  local build = builders[which_model_netG]
  if build == nil then
    error("unsupported netG model")
  end

  local net = build(input_nc, output_nc, ngf)
  net:apply(weights_init)
  return net
end
|
|
|
|
--- Construct a discriminator network by name and initialize its weights.
-- @param input_nc number of input channels
-- @param ndf base number of discriminator filters
-- @param which_model_netD 'basic', 'imageGAN' or 'n_layers'
-- @param n_layers_D layer count, only forwarded for 'n_layers'
-- @param use_sigmoid forwarded to the chosen builder
-- @return the discriminator, after `weights_init` has been applied to it
function defineD(input_nc, ndf, which_model_netD, n_layers_D, use_sigmoid)
  local builders = {
    basic = function()
      return defineD_basic(input_nc, ndf, use_sigmoid)
    end,
    imageGAN = function()
      return defineD_imageGAN(input_nc, ndf, use_sigmoid)
    end,
    n_layers = function()
      return defineD_n_layers(input_nc, ndf, n_layers_D, use_sigmoid)
    end,
  }

  local build = builders[which_model_netD]
  if not build then
    error("unsupported netD model")
  end

  local net = build()
  net:apply(weights_init)
  return net
end
|
|
|
|
--- Encoder-decoder generator without skip connections (256x256 input).
-- Eight 4x4 stride-2 convolutions take (nc) x 256 x 256 down to
-- (ngf*8) x 1 x 1. Eight 4x4 stride-2 full convolutions bring it back to
-- (output_nc) x 256 x 256, followed by Tanh.
-- @param input_nc number of input channels
-- @param output_nc number of output channels
-- @param ngf base number of generator filters
-- @return nn.gModule
function defineG_encoder_decoder(input_nc, output_nc, ngf)
  -- Encoder stage: LeakyReLU -> 4x4 stride-2 conv (halves H and W) -> optional norm.
  local function down(node, n_in, n_out, with_norm)
    node = node - nn.LeakyReLU(0.2, true)
    node = node - nn.SpatialConvolution(n_in, n_out, 4, 4, 2, 2, 1, 1)
    if with_norm then
      node = node - normalization(n_out)
    end
    return node
  end

  -- Decoder stage: ReLU -> 4x4 stride-2 full conv (doubles H and W) -> norm -> optional dropout.
  local function up(node, n_in, n_out, with_dropout)
    node = node - nn.ReLU(true)
    node = node - nn.SpatialFullConvolution(n_in, n_out, 4, 4, 2, 2, 1, 1)
    node = node - normalization(n_out)
    if with_dropout then
      node = node - nn.Dropout(0.5)
    end
    return node
  end

  -- (nc) x 256 x 256 -> (ngf) x 128 x 128; this node is also the graph input.
  local input = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)

  -- Encoder: 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1.
  local x = input
  x = down(x, ngf, ngf * 2, true)
  x = down(x, ngf * 2, ngf * 4, true)
  x = down(x, ngf * 4, ngf * 8, true)
  x = down(x, ngf * 8, ngf * 8, true)
  x = down(x, ngf * 8, ngf * 8, true)
  x = down(x, ngf * 8, ngf * 8, true)
  x = down(x, ngf * 8, ngf * 8, false) -- innermost 1x1 stage is left unnormalized

  -- Decoder: 1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128; dropout on the first three stages.
  x = up(x, ngf * 8, ngf * 8, true)
  x = up(x, ngf * 8, ngf * 8, true)
  x = up(x, ngf * 8, ngf * 8, true)
  x = up(x, ngf * 8, ngf * 8, false)
  x = up(x, ngf * 8, ngf * 4, false)
  x = up(x, ngf * 4, ngf * 2, false)
  x = up(x, ngf * 2, ngf, false)

  -- (ngf) x 128 x 128 -> (output_nc) x 256 x 256, squashed by Tanh.
  local output = x - nn.ReLU(true) - nn.SpatialFullConvolution(ngf, output_nc, 4, 4, 2, 2, 1, 1) - nn.Tanh()

  return nn.gModule({input}, {output})
end
|
|
|
|
|
|
--- U-Net generator for 128x128 inputs.
-- Seven 4x4 stride-2 convolutions encode (nc) x 128 x 128 down to
-- (ngf*8) x 1 x 1. Seven 4x4 stride-2 full convolutions decode back up.
-- Every decoder output except the last is concatenated with JoinTable(2)
-- with the encoder activation of the same spatial size. Dim 2 is the
-- feature dimension, assuming batched N x C x H x W input.
-- @param input_nc number of input channels
-- @param output_nc number of output channels
-- @param ngf base number of generator filters
-- @return nn.gModule
function defineG_unet128(input_nc, output_nc, ngf)
  -- input is (nc) x 128 x 128
  local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
  -- e1: (ngf) x 64 x 64
  local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
  -- e2: (ngf * 2) x 32 x 32
  local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
  -- e3: (ngf * 4) x 16 x 16
  local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
  -- e4: (ngf * 8) x 8 x 8
  local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
  -- e5: (ngf * 8) x 4 x 4
  local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
  -- e6: (ngf * 8) x 2 x 2
  -- The innermost stage is deliberately left without normalization.
  local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1)
  -- e7: (ngf * 8) x 1 x 1

  local d1_ = e7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
  -- d1_: (ngf * 8) x 2 x 2; after concat with e6: (ngf * 8 * 2) x 2 x 2
  local d1 = {d1_, e6} - nn.JoinTable(2)
  local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
  -- d2_: (ngf * 8) x 4 x 4; after concat with e5: (ngf * 8 * 2) x 4 x 4
  local d2 = {d2_, e5} - nn.JoinTable(2)
  local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
  -- d3_: (ngf * 8) x 8 x 8; after concat with e4: (ngf * 8 * 2) x 8 x 8
  local d3 = {d3_, e4} - nn.JoinTable(2)
  local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
  -- d4_: (ngf * 4) x 16 x 16; after concat with e3: (ngf * 4 * 2) x 16 x 16
  local d4 = {d4_, e3} - nn.JoinTable(2)
  local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
  -- d5_: (ngf * 2) x 32 x 32; after concat with e2: (ngf * 2 * 2) x 32 x 32
  local d5 = {d5_, e2} - nn.JoinTable(2)
  local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf)
  -- d6_: (ngf) x 64 x 64; after concat with e1: (ngf * 2) x 64 x 64
  local d6 = {d6_, e1} - nn.JoinTable(2)

  local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
  -- d7: (output_nc) x 128 x 128

  local o1 = d7 - nn.Tanh()
  local netG = nn.gModule({e1}, {o1})
  return netG
end
|
|
|
|
|
|
--- U-Net generator for 256x256 inputs.
-- Eight 4x4 stride-2 convolutions encode (nc) x 256 x 256 down to
-- (ngf*8) x 1 x 1. Eight 4x4 stride-2 full convolutions decode back to
-- (output_nc) x 256 x 256, followed by Tanh. Every decoder output except
-- the last is concatenated with JoinTable(2) with the encoder activation
-- of the same spatial size.
-- @param input_nc number of input channels
-- @param output_nc number of output channels
-- @param ngf base number of generator filters
-- @return nn.gModule
function defineG_unet256(input_nc, output_nc, ngf)
  -- Encoder stage: LeakyReLU -> 4x4 stride-2 conv (halves H and W) -> optional norm.
  local function down(node, n_in, n_out, with_norm)
    node = node - nn.LeakyReLU(0.2, true)
    node = node - nn.SpatialConvolution(n_in, n_out, 4, 4, 2, 2, 1, 1)
    if with_norm then
      node = node - normalization(n_out)
    end
    return node
  end

  -- Decoder stage: ReLU -> 4x4 stride-2 full conv (doubles H and W) -> norm
  -- -> optional dropout, then concatenation with the matching encoder output.
  local function up_and_join(node, skip, n_in, n_out, with_dropout)
    node = node - nn.ReLU(true)
    node = node - nn.SpatialFullConvolution(n_in, n_out, 4, 4, 2, 2, 1, 1)
    node = node - normalization(n_out)
    if with_dropout then
      node = node - nn.Dropout(0.5)
    end
    return {node, skip} - nn.JoinTable(2)
  end

  -- Width multiplier (relative to ngf) of the encoder outputs enc[1]..enc[8].
  -- enc[k] has spatial size 256 / 2^k.
  local mult = {1, 2, 4, 8, 8, 8, 8, 8}

  local enc = {}
  enc[1] = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
  for k = 2, 8 do
    -- The innermost 1x1 stage (k = 8) is left unnormalized.
    enc[k] = down(enc[k - 1], ngf * mult[k - 1], ngf * mult[k], k < 8)
  end

  local x = enc[8]
  for k = 7, 1, -1 do
    -- The first decoder stage reads enc[8] directly. Later stages read a
    -- concatenation, which doubles the incoming channel count.
    local n_in
    if k == 7 then
      n_in = ngf * mult[8]
    else
      n_in = ngf * mult[k + 1] * 2
    end
    -- Dropout on the first three decoder stages (k = 7, 6, 5).
    x = up_and_join(x, enc[k], n_in, ngf * mult[k], k >= 5)
  end

  -- (ngf * 2) x 128 x 128 -> (output_nc) x 256 x 256, squashed by Tanh.
  local output = x - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1) - nn.Tanh()

  return nn.gModule({enc[1]}, {output})
end
|
|
|
|
--------------------------------------------------------------------------------
|
|
-- Justin Johnson's model from https://github.com/jcjohnson/fast-neural-style/
|
|
--------------------------------------------------------------------------------
|
|
|
|
--- Residual branch: two 3x3 stride-1 convolutions that keep H and W.
-- Layout: [pad] conv -> norm -> ReLU -> [pad] conv -> norm.
-- 'reflect' and 'replicate' insert an explicit 1-pixel padding module before
-- each conv. 'zero' instead passes padding 1 to the convolution itself.
-- @param dim number of input and output feature planes
-- @param padding_type 'reflect', 'replicate' or 'zero'
-- @return nn.Sequential
-- @raise error on any other padding_type. Without padding, the convs would
--   shrink the map and the residual sum would fail later on a size mismatch.
local function build_conv_block(dim, padding_type)
  local pad_layer = nil -- constructor of the explicit padding module, if any
  local p = 0           -- implicit zero padding given to the convolutions
  if padding_type == 'reflect' then
    pad_layer = nn.SpatialReflectionPadding
  elseif padding_type == 'replicate' then
    pad_layer = nn.SpatialReplicationPadding
  elseif padding_type == 'zero' then
    p = 1
  else
    error('unsupported padding_type: ' .. tostring(padding_type), 2)
  end

  local conv_block = nn.Sequential()
  if pad_layer then
    conv_block:add(pad_layer(1, 1, 1, 1))
  end
  conv_block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, p, p))
  conv_block:add(normalization(dim))
  conv_block:add(nn.ReLU(true))
  if pad_layer then
    conv_block:add(pad_layer(1, 1, 1, 1))
  end
  conv_block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, p, p))
  conv_block:add(normalization(dim))
  return conv_block
end
|
|
|
|
|
|
--- Residual block: output = conv_block(x) + x.
-- A ConcatTable feeds the input to both the convolutional branch and an
-- identity shortcut, and a CAddTable sums the two results.
-- @param dim number of feature planes (unchanged by the block)
-- @param padding_type forwarded to build_conv_block
-- @return nn.Sequential
local function build_res_block(dim, padding_type)
  local residual = build_conv_block(dim, padding_type)
  local branches = nn.ConcatTable():add(residual):add(nn.Identity())
  return nn.Sequential():add(branches):add(nn.CAddTable())
end
|
|
|
|
--- Shared builder for the ResNet generators (architecture credited in the
-- header above to jcjohnson/fast-neural-style).
-- Layout:
--   * a 7x7 reflection-padded stem conv;
--   * two 3x3 stride-2 downsampling convs;
--   * `n_blocks` residual blocks at ngf*4 width;
--   * two 3x3 stride-2 upsampling full convs;
--   * a 7x7 reflection-padded output conv followed by Tanh.
-- @param input_nc number of input channels
-- @param output_nc number of output channels
-- @param ngf base number of generator filters
-- @param n_blocks number of residual blocks
-- @return nn.gModule
local function build_resnet_generator(input_nc, output_nc, ngf, n_blocks)
  -- Kept local: a bare assignment here would create or overwrite a global.
  local padding_type = 'reflect'
  local ks = 3          -- kernel size of the down/upsampling convs
  local f = 7           -- kernel size of the stem and output convs
  local p = (f - 1) / 2 -- reflection padding so the 7x7 stride-1 convs keep H and W

  local data = -nn.Identity()
  local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)

  local d1 = e3
  for _ = 1, n_blocks do
    d1 = d1 - build_res_block(ngf*4, padding_type)
  end

  -- The trailing 1, 1 are the output adjustments (adjW, adjH). With kernel 3,
  -- stride 2 and pad 1 they make each upsampling exactly double H and W.
  local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh()

  return nn.gModule({data}, {d4})
end

--- ResNet generator with 6 residual blocks.
-- @param input_nc number of input channels
-- @param output_nc number of output channels
-- @param ngf base number of generator filters
-- @return nn.gModule
function defineG_resnet_6blocks(input_nc, output_nc, ngf)
  return build_resnet_generator(input_nc, output_nc, ngf, 6)
end

--- ResNet generator with 9 residual blocks.
-- @param input_nc number of input channels
-- @param output_nc number of output channels
-- @param ngf base number of generator filters
-- @return nn.gModule
function defineG_resnet_9blocks(input_nc, output_nc, ngf)
  return build_resnet_generator(input_nc, output_nc, ngf, 9)
end
|
|
|
|
--- Whole-image discriminator built from 4x4 stride-2 convolutions.
-- For a (nc) x 256 x 256 input, the seven stride-2 convs reduce the map to
-- 1 x 2 x 2. Stages after the first use SpatialBatchNormalization
-- regardless of the global `normalization` setting.
-- @param input_nc number of input channels
-- @param ndf base number of discriminator filters
-- @param use_sigmoid append nn.Sigmoid() to the output when truthy
-- @return nn.Sequential
function defineD_imageGAN(input_nc, ndf, use_sigmoid)
  local net = nn.Sequential()

  -- First stage: (nc) x 256 x 256 -> (ndf) x 128 x 128, no normalization.
  net:add(nn.SpatialConvolution(input_nc, ndf, 4, 4, 2, 2, 1, 1))
  net:add(nn.LeakyReLU(0.2, true))

  -- Five normalized stages: 128 -> 64 -> 32 -> 16 -> 8 -> 4.
  -- Width goes ndf -> 2 -> 4 -> 8 -> 8 -> 8 (times ndf).
  local widths = {1, 2, 4, 8, 8, 8}
  for i = 2, #widths do
    local n_out = ndf * widths[i]
    net:add(nn.SpatialConvolution(ndf * widths[i - 1], n_out, 4, 4, 2, 2, 1, 1))
    net:add(nn.SpatialBatchNormalization(n_out))
    net:add(nn.LeakyReLU(0.2, true))
  end

  -- (ndf*8) x 4 x 4 -> 1 x 2 x 2
  net:add(nn.SpatialConvolution(ndf * 8, 1, 4, 4, 2, 2, 1, 1))

  if use_sigmoid then
    net:add(nn.Sigmoid())
  end
  return net
end
|
|
|
|
|
|
|
|
--- Default discriminator: defineD_n_layers with three layers.
-- @param input_nc number of input channels
-- @param ndf base number of discriminator filters
-- @param use_sigmoid forwarded to defineD_n_layers
-- @return the network built by defineD_n_layers
function defineD_basic(input_nc, ndf, use_sigmoid)
  -- Local: the original bare assignment leaked `n_layers` as a global.
  local n_layers = 3
  return defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid)
end
|
|
|
|
--- Per-pixel discriminator (receptive field of 1 pixel).
-- Every convolution is 1x1 with stride 1 and no padding, so the output is a
-- 1-channel map with the input's height and width.
-- @param input_nc number of input channels
-- @param ndf base number of discriminator filters
-- @param use_sigmoid append nn.Sigmoid() to the output when truthy
-- @return nn.Sequential
function defineD_pixelGAN(input_nc, ndf, use_sigmoid)
  local function pointwise(n_in, n_out)
    return nn.SpatialConvolution(n_in, n_out, 1, 1, 1, 1, 0, 0)
  end

  local hidden = ndf * 2
  local net = nn.Sequential()
  -- (nc) x H x W -> (ndf) x H x W
  net:add(pointwise(input_nc, ndf))
  net:add(nn.LeakyReLU(0.2, true))
  -- (ndf) x H x W -> (ndf*2) x H x W
  net:add(pointwise(ndf, hidden))
  net:add(normalization(hidden))
  net:add(nn.LeakyReLU(0.2, true))
  -- (ndf*2) x H x W -> 1 x H x W
  net:add(pointwise(hidden, 1))

  if use_sigmoid then
    net:add(nn.Sigmoid())
  end
  return net
end
|
|
|
|
-- Receptive field by n_layers, for the default kw = 4:
--   n = 0: pixelGAN (rf = 1)
--   n = 1: 16     n = 2: 34     n = 3: 70
--   n = 4: 142    n = 5: 286    n = 6: 574

--- Build an n-layer convolutional discriminator producing a score map.
-- @param input_nc number of input channels
-- @param ndf base number of discriminator filters
-- @param n_layers number of stride-2 conv stages; 0 delegates to defineD_pixelGAN
-- @param use_sigmoid append nn.Sigmoid() to the output when truthy
-- @param kw kernel size (default 4)
-- @param dropout_ratio dropout probability after each intermediate
--   stride-2 normalization (default 0.0)
-- @return nn.Sequential producing a 1-channel map
function defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid, kw, dropout_ratio)
  if dropout_ratio == nil then
    dropout_ratio = 0.0
  end
  if kw == nil then
    kw = 4
  end
  -- Local: a bare assignment here would create or overwrite a global `padw`.
  local padw = math.ceil((kw - 1) / 2)

  if n_layers == 0 then
    return defineD_pixelGAN(input_nc, ndf, use_sigmoid)
  end

  local netD = nn.Sequential()

  -- First stage: stride-2 conv from the input, no normalization.
  netD:add(nn.SpatialConvolution(input_nc, ndf, kw, kw, 2, 2, padw, padw))
  netD:add(nn.LeakyReLU(0.2, true))

  -- Remaining stride-2 stages; width doubles per stage, capped at ndf * 8.
  local nf_mult = 1
  local nf_mult_prev = 1
  for n = 1, n_layers - 1 do
    nf_mult_prev = nf_mult
    nf_mult = math.min(2^n, 8)
    netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 2, 2, padw, padw))
    netD:add(normalization(ndf * nf_mult)):add(nn.Dropout(dropout_ratio))
    netD:add(nn.LeakyReLU(0.2, true))
  end

  -- Stride-1 stage to (ndf * min(2^n_layers, 8)) planes. With the default
  -- kw = 4 (padw = 2), each stride-1 conv grows H and W by 1.
  nf_mult_prev = nf_mult
  nf_mult = math.min(2^n_layers, 8)
  netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 1, 1, padw, padw))
  netD:add(normalization(ndf * nf_mult)):add(nn.LeakyReLU(0.2, true))

  -- Stride-1 projection to a single-channel score map.
  netD:add(nn.SpatialConvolution(ndf * nf_mult, 1, kw, kw, 1, 1, padw, padw))

  if use_sigmoid then
    netD:add(nn.Sigmoid())
  end
  return netD
end
|