diff --git a/options/test_options.py b/options/test_options.py index 0a4b4be..a3c768b 100644 --- a/options/test_options.py +++ b/options/test_options.py @@ -15,8 +15,10 @@ class TestOptions(BaseOptions): self.parser.add_argument("--engine", type=str, help="run serialized TRT engine") self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT") self.parser.add_argument("--Arc_path", type=str, default='models/BEST_checkpoint.tar', help="run ONNX model via TRT") - self.parser.add_argument("--pic_a_path", type=str, default='crop_224/gdg.jpg', help="people a") - self.parser.add_argument("--pic_b_path", type=str, default='crop_224/zrf.jpg', help="people b") - self.parser.add_argument("--output_path", type=str, default='output/', help="people b") + self.parser.add_argument("--pic_a_path", type=str, default='./crop_224/gdg.jpg', help="People who provide identity information") + self.parser.add_argument("--pic_b_path", type=str, default='./crop_224/zrf.jpg', help="People who provide information other than their identity") + self.parser.add_argument("--video_path", type=str, default='./demo_file/mutil_people_1080p.mp4', help="path for the video to swap") + self.parser.add_argument("--temp_path", type=str, default='./temp_results', help="path to save temporarily images") + self.parser.add_argument("--output_path", type=str, default='./output/', help="results path") self.isTrain = False diff --git a/simswaplogo/simswaplogo.png b/simswaplogo/simswaplogo.png new file mode 100644 index 0000000..806d3ac Binary files /dev/null and b/simswaplogo/simswaplogo.png differ diff --git a/test_one_image.py b/test_one_image.py index ea930dd..01c17fe 100644 --- a/test_one_image.py +++ b/test_one_image.py @@ -26,60 +26,60 @@ detransformer = transforms.Compose([ transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]), transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1]) ]) +if __name__ == '__main__': + opt = TestOptions().parse() -opt = TestOptions().parse() + start_epoch, epoch_iter = 1, 0 -start_epoch, epoch_iter = 1, 0 - -torch.nn.Module.dump_patches = True -model = create_model(opt) -model.eval() + torch.nn.Module.dump_patches = True + model = create_model(opt) + model.eval() -pic_a = opt.pic_a_path -img_a = Image.open(pic_a).convert('RGB') -img_a = transformer_Arcface(img_a) -img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) + pic_a = opt.pic_a_path + img_a = Image.open(pic_a).convert('RGB') + img_a = transformer_Arcface(img_a) + img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) -pic_b = opt.pic_b_path + pic_b = opt.pic_b_path -img_b = Image.open(pic_b).convert('RGB') -img_b = transformer(img_b) -img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2]) + img_b = Image.open(pic_b).convert('RGB') + img_b = transformer(img_b) + img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2]) -# convert numpy to tensor -img_id = img_id.cuda() -img_att = img_att.cuda() + # convert numpy to tensor + img_id = img_id.cuda() + img_att = img_att.cuda() -#create latent id -img_id_downsample = F.interpolate(img_id, scale_factor=0.5) -latend_id = model.netArc(img_id_downsample) -latend_id = latend_id.detach().to('cpu') -latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True) -latend_id = latend_id.to('cuda') + #create latent id + img_id_downsample = F.interpolate(img_id, scale_factor=0.5) + latend_id = model.netArc(img_id_downsample) + latend_id = latend_id.detach().to('cpu') + latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True) + latend_id = latend_id.to('cuda') -############## Forward Pass ###################### -img_fake = model(img_id, img_att, latend_id, latend_id, True) + ############## Forward Pass ###################### + img_fake = model(img_id, img_att, latend_id, latend_id, True) -for i in range(img_id.shape[0]): - if i == 0: - row1 = img_id[i] - row2 = img_att[i] - row3 = img_fake[i] - else: - row1 = torch.cat([row1, img_id[i]], dim=2) - row2 = torch.cat([row2, img_att[i]], dim=2) - row3 = torch.cat([row3, img_fake[i]], dim=2) + for i in range(img_id.shape[0]): + if i == 0: + row1 = img_id[i] + row2 = img_att[i] + row3 = img_fake[i] + else: + row1 = torch.cat([row1, img_id[i]], dim=2) + row2 = torch.cat([row2, img_att[i]], dim=2) + row3 = torch.cat([row3, img_fake[i]], dim=2) -#full = torch.cat([row1, row2, row3], dim=1).detach() -full = row3.detach() -full = full.permute(1, 2, 0) -output = full.to('cpu') -output = np.array(output) -output = output[..., ::-1] + #full = torch.cat([row1, row2, row3], dim=1).detach() + full = row3.detach() + full = full.permute(1, 2, 0) + output = full.to('cpu') + output = np.array(output) + output = output[..., ::-1] -output = output*255 + output = output*255 -cv2.imwrite(opt.output_path + 'result.jpg',output) \ No newline at end of file + cv2.imwrite(opt.output_path + 'result.jpg',output) \ No newline at end of file diff --git a/util/add_watermark.py b/util/add_watermark.py new file mode 100644 index 0000000..cc62eaa --- /dev/null +++ b/util/add_watermark.py @@ -0,0 +1,131 @@ +import cv2 +import numpy as np +from PIL import Image +import math +import numpy as np +# import torch +# from torchvision import transforms + +def rotate_image(image, angle, center = None, scale = 1.0): + (h, w) = image.shape[:2] + + if center is None: + center = (w / 2, h / 2) + + # Perform the rotation + M = cv2.getRotationMatrix2D(center, angle, scale) + rotated = cv2.warpAffine(image, M, (w, h)) + + return rotated + +class watermark_image: + def __init__(self, logo_path, size=0.3, oritation="DR", margin=(5,20,20,100), angle=15, rgb_weight=(0,1,1.5), input_frame_shape=None) -> None: + logo_image = cv2.imread(logo_path, cv2.IMREAD_UNCHANGED) + h,w,c = logo_image.shape + + if angle%360 != 0: + new_h = w*math.sin(angle/180*math.pi) + h*math.cos(angle/180*math.pi) + pad_h = int((new_h-h)//2) + + padding = np.zeros((pad_h, w, c), dtype=np.uint8) + logo_image = cv2.vconcat([logo_image, padding]) + logo_image = cv2.vconcat([padding, logo_image]) + + logo_image = rotate_image(logo_image, angle) + print(logo_image.shape) + self.logo_image = logo_image + + if self.logo_image.shape[2] < 4: + print("No alpha channel found!") + self.logo_image = self.__addAlpha__(self.logo_image) #add alpha channel + self.size = size + self.oritation = oritation + self.margin = margin + self.ori_shape = self.logo_image.shape + self.resized = False + self.rgb_weight = rgb_weight + + self.logo_image[:, :, 2] = self.logo_image[:, :, 2]*self.rgb_weight[0] + self.logo_image[:, :, 1] = self.logo_image[:, :, 1]*self.rgb_weight[1] + self.logo_image[:, :, 0] = self.logo_image[:, :, 0]*self.rgb_weight[2] + + if input_frame_shape is not None: + if input_frame_shape[0] > input_frame_shape[1]: + logo_h = input_frame_shape[0] * self.size + ratio = logo_h / self.ori_shape[0] + logo_w = int(ratio * self.ori_shape[1]) + logo_h = int(logo_h) + else: + logo_w = input_frame_shape[1] * self.size + ratio = logo_w / self.ori_shape[1] + logo_h = int(ratio * self.ori_shape[0]) + logo_w = int(logo_w) + + size = (logo_w, logo_h) + self.logo_image = cv2.resize(self.logo_image, size, interpolation = cv2.INTER_CUBIC) + self.resized = True + if oritation == "UL": + self.coor_h = self.margin[1] + self.coor_w = self.margin[0] + elif oritation == "UR": + self.coor_h = self.margin[1] + self.coor_w = input_frame_shape[1] - (logo_w + self.margin[2]) + elif oritation == "DL": + self.coor_h = input_frame_shape[0] - (logo_h + self.margin[1]) + self.coor_w = self.margin[0] + else: + self.coor_h = input_frame_shape[0] - (logo_h + self.margin[1]) + self.coor_w = input_frame_shape[1] - (logo_w + self.margin[2]) + self.logo_w = logo_w + self.logo_h = logo_h + self.mask = self.logo_image[:,:,3] + self.mask = cv2.bitwise_not(self.mask//255) + + + + + def apply_frames(self, frame): + if not self.resized: + shape = frame.shape + if shape[0] > shape[1]: + logo_h = shape[0] * self.size + ratio = logo_h / self.ori_shape[0] + logo_w = int(ratio * self.ori_shape[1]) + logo_h = int(logo_h) + else: + logo_w = shape[1] * self.size + ratio = logo_w / self.ori_shape[1] + logo_h = int(ratio * self.ori_shape[0]) + logo_w = int(logo_w) + + size = (logo_w, logo_h) + self.logo_image = cv2.resize(self.logo_image, size, interpolation = cv2.INTER_CUBIC) + self.resized = True + if self.oritation == "UL": + self.coor_h = self.margin[1] + self.coor_w = self.margin[0] + elif self.oritation == "UR": + self.coor_h = self.margin[1] + self.coor_w = shape[1] - (logo_w + self.margin[2]) + elif self.oritation == "DL": + self.coor_h = shape[0] - (logo_h + self.margin[1]) + self.coor_w = self.margin[0] + else: + self.coor_h = shape[0] - (logo_h + self.margin[1]) + self.coor_w = shape[1] - (logo_w + self.margin[2]) + self.logo_w = logo_w + self.logo_h = logo_h + self.mask = self.logo_image[:,:,3] + self.mask = cv2.bitwise_not(self.mask//255) + + + original_frame = frame[self.coor_h:(self.coor_h+self.logo_h), self.coor_w:(self.coor_w+self.logo_w),:] + blending_logo = cv2.add(self.logo_image[:,:,0:3],original_frame,mask = self.mask) + frame[self.coor_h:(self.coor_h+self.logo_h), self.coor_w:(self.coor_w+self.logo_w),:] = blending_logo + return frame + + def __addAlpha__(self, image): + shape = image.shape + alpha_channel = np.ones((shape[0],shape[1],1),np.uint8)*255 + return np.concatenate((image,alpha_channel),2) +