fixes and optimizations
This commit is contained in:
@@ -58,13 +58,12 @@ class AVATARModel(ModelBase):
|
||||
self.D = modelify(AVATARModel.Discriminator() ) (Input(df_bgr_shape))
|
||||
self.C = modelify(AVATARModel.ResNet (9, n_blocks=6, ngf=128, use_dropout=False))( Input(res_bgr_t_shape))
|
||||
|
||||
if self.is_first_run():
|
||||
conv_weights_list = []
|
||||
self.CA_conv_weights_list = []
|
||||
if self.is_first_run():
|
||||
for model, _ in self.get_model_filename_list():
|
||||
for layer in model.layers:
|
||||
if type(layer) == keras.layers.Conv2D:
|
||||
conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
|
||||
CAInitializerMP ( conv_weights_list )
|
||||
self.CA_conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
|
||||
|
||||
if not self.is_first_run():
|
||||
self.load_weights_safe( self.get_model_filename_list() )
|
||||
@@ -247,7 +246,14 @@ class AVATARModel(ModelBase):
|
||||
#override
|
||||
def onSave(self):
|
||||
self.save_weights_safe( self.get_model_filename_list() )
|
||||
|
||||
|
||||
#override
|
||||
def on_success_train_one_iter(self):
|
||||
if len(self.CA_conv_weights_list) != 0:
|
||||
exec(nnlib.import_all(), locals(), globals())
|
||||
CAInitializerMP ( self.CA_conv_weights_list )
|
||||
self.CA_conv_weights_list = []
|
||||
|
||||
#override
|
||||
def onTrainOneIter(self, generators_samples, generators_list):
|
||||
warped_src64, src64, src64m = generators_samples[0]
|
||||
|
||||
Reference in New Issue
Block a user