fix
This commit is contained in:
@@ -644,14 +644,14 @@ class SAEHDModel(ModelBase):
|
||||
self.D_train = D_train
|
||||
|
||||
if gan_power != 0:
|
||||
def D_src_dst_train(warped_src, target_src, target_srcm, \
|
||||
warped_dst, target_dst, target_dstm):
|
||||
def D_src_dst_train(warped_src, target_src, target_srcm_all, \
|
||||
warped_dst, target_dst, target_dstm_all:
|
||||
nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src,
|
||||
self.target_src :target_src,
|
||||
self.target_srcm:target_srcm,
|
||||
self.target_srcm_all:target_srcm_all,
|
||||
self.warped_dst :warped_dst,
|
||||
self.target_dst :target_dst,
|
||||
self.target_dstm:target_dstm})
|
||||
self.target_dstm_all:target_dstm_all})
|
||||
self.D_src_dst_train = D_src_dst_train
|
||||
|
||||
if learn_mask:
|
||||
|
||||
Reference in New Issue
Block a user