fix for plaidml
This commit is contained in:
@@ -483,7 +483,8 @@ class SAEv2Model(ModelBase):
|
||||
dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) )
|
||||
dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )
|
||||
|
||||
opt_D_loss = []
|
||||
G_loss = src_loss+dst_loss
|
||||
|
||||
if self.true_face_training:
|
||||
def DLoss(labels,logits):
|
||||
return K.mean(K.binary_crossentropy(labels,logits))
|
||||
@@ -493,7 +494,7 @@ class SAEv2Model(ModelBase):
|
||||
src_code_d_zeros = K.zeros_like(src_code_d)
|
||||
dst_code_d = self.dis( self.model.dst_code )
|
||||
dst_code_d_ones = K.ones_like(dst_code_d)
|
||||
opt_D_loss = [ 0.01*DLoss(src_code_d_ones, src_code_d) ]
|
||||
G_loss += 0.01*DLoss(src_code_d_ones, src_code_d)
|
||||
|
||||
loss_D = (DLoss(dst_code_d_ones , dst_code_d) + \
|
||||
DLoss(src_code_d_zeros, src_code_d) ) * 0.5
|
||||
@@ -502,7 +503,7 @@ class SAEv2Model(ModelBase):
|
||||
|
||||
self.src_dst_train = K.function ([self.model.warped_src, self.model.warped_dst, self.model.target_src, self.model.target_srcm, self.model.target_dst, self.model.target_dstm],
|
||||
[src_loss,dst_loss],
|
||||
self.src_dst_opt.get_updates( [src_loss+dst_loss]+opt_D_loss, self.model.src_dst_trainable_weights)
|
||||
self.src_dst_opt.get_updates( G_loss, self.model.src_dst_trainable_weights)
|
||||
)
|
||||
|
||||
if self.options['learn_mask']:
|
||||
|
||||
+3
-2
@@ -232,6 +232,7 @@ class device:
|
||||
|
||||
plaidML_build = os.environ.get("DFL_PLAIDML_BUILD", "0") == "1"
|
||||
plaidML_devices = None
|
||||
plaidML_devices_count = 0
|
||||
cuda_devices = None
|
||||
|
||||
if plaidML_build:
|
||||
@@ -253,8 +254,8 @@ if plaidML_build:
|
||||
ctx.shutdown()
|
||||
except:
|
||||
pass
|
||||
|
||||
if len(plaidML_devices) != 0:
|
||||
plaidML_devices_count = len(plaidML_devices)
|
||||
if plaidML_devices_count != 0:
|
||||
device.backend = "plaidML"
|
||||
else:
|
||||
if cuda_devices is None:
|
||||
|
||||
Reference in New Issue
Block a user