fix DFLJPG,
SAE: added "rare sample booster" SAE: pixel loss replaced to smooth transition from DSSIM to PixelLoss in 15k epochs by default
This commit is contained in:
+8
-9
@@ -199,7 +199,7 @@ class ModelBase(object):
|
||||
pass
|
||||
|
||||
#overridable
|
||||
def onTrainOneEpoch(self, sample):
|
||||
def onTrainOneEpoch(self, sample, generator_list):
|
||||
#train your keras models here
|
||||
|
||||
#return array of losses
|
||||
@@ -293,7 +293,8 @@ class ModelBase(object):
|
||||
images = []
|
||||
for generator in self.generator_list:
|
||||
for i,batch in enumerate(next(generator)):
|
||||
images.append( batch[0] )
|
||||
if len(batch.shape) == 4:
|
||||
images.append( batch[0] )
|
||||
|
||||
return image_utils.equalize_and_stack_square (images)
|
||||
|
||||
@@ -305,14 +306,12 @@ class ModelBase(object):
|
||||
supressor = std_utils.suppress_stdout_stderr()
|
||||
supressor.__enter__()
|
||||
|
||||
self.last_sample = self.generate_next_sample()
|
||||
|
||||
epoch_time = time.time()
|
||||
|
||||
losses = self.onTrainOneEpoch(self.last_sample)
|
||||
|
||||
sample = self.generate_next_sample()
|
||||
epoch_time = time.time()
|
||||
losses = self.onTrainOneEpoch(sample, self.generator_list)
|
||||
epoch_time = time.time() - epoch_time
|
||||
|
||||
self.last_sample = sample
|
||||
|
||||
self.loss_history.append ( [float(loss[1]) for loss in losses] )
|
||||
|
||||
if self.supress_std_once:
|
||||
|
||||
Reference in New Issue
Block a user