Files
SimSwapPlus/filter.py
T
chenxuanhong 29d8914c0a update
2022-04-24 15:44:47 +08:00

53 lines
1.9 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: filter.py
# Created Date: Wednesday April 13th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 13th April 2022 3:49:23 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import cv2
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms
class HighPass(nn.Module):
    """Depthwise 3x3 Laplacian-style high-pass filter.

    The kernel ``[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]] / w_hpf`` is applied
    independently to every channel of the input (grouped convolution with
    ``groups == channels``), with zero padding of 1 so the spatial size is
    preserved.

    Args:
        w_hpf: Divisor applied to the kernel; larger values attenuate the
            response. Must be non-zero.
        device: Device the kernel is initially placed on. The kernel is
            registered as a non-persistent buffer, so it also follows later
            ``module.to(...)`` / ``.cuda()`` calls.

    Raises:
        ValueError: If ``w_hpf`` is zero (the kernel would become inf/nan).
    """

    def __init__(self, w_hpf, device):
        super(HighPass, self).__init__()
        if w_hpf == 0:
            raise ValueError("w_hpf must be non-zero")
        kernel = torch.tensor([[-1., -1., -1.],
                               [-1., 8., -1.],
                               [-1., -1., -1.]], device=device) / w_hpf
        # A plain tensor attribute is ignored by module.to()/.cuda(); a buffer
        # moves with the module. persistent=False keeps it out of state_dict,
        # so previously saved checkpoints still load unchanged.
        self.register_buffer("filter", kernel, persistent=False)

    def forward(self, x):
        """Apply the high-pass kernel channel-wise.

        Args:
            x: Input tensor. ``x.size(1)`` is used as the channel count, so a
                batched ``(N, C, H, W)`` layout is assumed.

        Returns:
            Tensor with the same shape as ``x``.
        """
        channels = x.size(1)
        # Match the input's dtype/device so half/double inputs (or inputs on
        # another device) do not fail inside conv2d.
        kernel = self.filter.to(dtype=x.dtype, device=x.device)
        # (3, 3) -> (C, 1, 3, 3): one single-channel kernel per group.
        weight = kernel.unsqueeze(0).unsqueeze(1).repeat(channels, 1, 1, 1)
        return F.conv2d(x, weight, padding=1, groups=channels)
if __name__ == "__main__":
    # Quick visual check: high-pass one image and write the result to disk.
    # Usage: python filter.py [input_image] [output_png]
    import sys

    transformer_Arcface = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)

    img = sys.argv[1] if len(sys.argv) > 1 else "G:/swap_data/ID/2.jpg"
    out_path = sys.argv[2] if len(sys.argv) > 2 else "filter_results2.png"

    attr = cv2.imread(img)
    # cv2.imread signals failure by returning None rather than raising.
    if attr is None:
        raise SystemExit(f"Could not read image: {img}")
    attr = Image.fromarray(cv2.cvtColor(attr, cv2.COLOR_BGR2RGB))
    attr = transformer_Arcface(attr).unsqueeze(0)  # add batch dim -> (1, C, H, W)

    with torch.no_grad():
        results = HighPass(0.5, torch.device("cpu"))(attr)

    # Undo the ImageNet normalisation purely so the output is viewable.
    results = results * imagenet_std + imagenet_mean
    results = results.cpu().permute(0, 2, 3, 1)[0, ...].numpy()  # (H, W, C)
    results = np.clip(results, 0.0, 1.0) * 255
    # Convert explicitly to 8-bit (round + cast) instead of relying on
    # imwrite's implicit conversion of float data.
    results = np.round(results).astype(np.uint8)
    results = cv2.cvtColor(results, cv2.COLOR_RGB2BGR)
    # cv2.imwrite reports failure via its boolean return value.
    if not cv2.imwrite(out_path, results):
        raise SystemExit(f"Could not write image: {out_path}")