upd liae loss
This commit is contained in:
@@ -390,23 +390,23 @@ class SAEv2Model(ModelBase):
|
||||
self.target_srcm, self.target_dstm = Input(mask_shape), Input(mask_shape)
|
||||
|
||||
warped_src_code = self.encoder (self.warped_src)
|
||||
self.src_code = warped_src_inter_AB_code = self.inter_AB (warped_src_code)
|
||||
src_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
|
||||
warped_src_inter_AB_code = self.inter_AB (warped_src_code)
|
||||
self.src_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
|
||||
|
||||
warped_dst_code = self.encoder (self.warped_dst)
|
||||
self.dst_code = warped_dst_inter_B_code = self.inter_B (warped_dst_code)
|
||||
warped_dst_inter_B_code = self.inter_B (warped_dst_code)
|
||||
warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
|
||||
dst_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
|
||||
self.dst_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
|
||||
|
||||
src_dst_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
|
||||
|
||||
self.pred_src_src = self.decoder(src_code)
|
||||
self.pred_dst_dst = self.decoder(dst_code)
|
||||
self.pred_src_src = self.decoder(self.src_code)
|
||||
self.pred_dst_dst = self.decoder(self.dst_code)
|
||||
self.pred_src_dst = self.decoder(src_dst_code)
|
||||
|
||||
if learn_mask:
|
||||
self.pred_src_srcm = self.decoderm(src_code)
|
||||
self.pred_dst_dstm = self.decoderm(dst_code)
|
||||
self.pred_src_srcm = self.decoderm(self.src_code)
|
||||
self.pred_dst_dstm = self.decoderm(self.dst_code)
|
||||
self.pred_src_dstm = self.decoderm(src_dst_code)
|
||||
|
||||
def get_model_filename_list(self, exclude_for_pretrain=False):
|
||||
|
||||
Reference in New Issue
Block a user