53 lines
1.9 KiB
Python
53 lines
1.9 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
#############################################################
|
|
# File: filter.py
|
|
# Created Date: Wednesday April 13th 2022
|
|
# Author: Chen Xuanhong
|
|
# Email: chenxuanhongzju@outlook.com
|
|
# Last Modified: Wednesday, 13th April 2022 3:49:23 pm
|
|
# Modified By: Chen Xuanhong
|
|
# Copyright (c) 2022 Shanghai Jiao Tong University
|
|
#############################################################
|
|
|
|
import cv2
|
|
import torch.nn as nn
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
|
|
class HighPass(nn.Module):
    """Depthwise 3x3 Laplacian-style high-pass filter.

    The same 3x3 kernel (8 at the centre, -1 on the 8 neighbours) is applied
    independently to every input channel via a grouped convolution
    (``groups == channels``). Zero padding of 1 keeps the spatial size.

    Args:
        w_hpf: Divisor applied to the kernel, i.e. the response is scaled
            by ``1 / w_hpf``.
        device: Device on which the kernel is initially allocated.
    """

    def __init__(self, w_hpf, device):
        super(HighPass, self).__init__()
        kernel = torch.tensor([[-1., -1., -1.],
                               [-1., 8., -1.],
                               [-1., -1., -1.]], device=device) / w_hpf
        # Register as a buffer so ``module.to(...)`` / ``.cuda()`` move the
        # kernel with the module; non-persistent so it does not appear in
        # ``state_dict`` (keeps existing checkpoints loadable with strict=True).
        self.register_buffer("filter", kernel, persistent=False)

    def forward(self, x):
        """Apply the high-pass kernel to each channel of ``x`` separately.

        Args:
            x: Batched 4-D tensor laid out as (N, C, H, W); the channel count
                is read from dimension 1.

        Returns:
            Tensor with the same shape as ``x``.
        """
        channels = x.size(1)
        # Follow the input's dtype/device so half/double inputs (or a module
        # used on a different device than it was built on) still work.
        kernel = self.filter.to(dtype=x.dtype, device=x.device)
        # One (1, 3, 3) kernel per channel -> conv weight of shape (C, 1, 3, 3).
        weight = kernel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
        return F.conv2d(x, weight, padding=1, groups=channels)
|
|
|
|
if __name__ == "__main__":
    # Demo: apply HighPass to one image and save the visualised response.
    # Usage: python filter.py [input_image] [output_image]
    import sys

    # ImageNet statistics, defined once so the input normalisation and the
    # output de-normalisation cannot drift apart.
    mean_rgb = [0.485, 0.456, 0.406]
    std_rgb = [0.229, 0.224, 0.225]

    transformer_Arcface = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean_rgb, std_rgb),
    ])
    imagenet_std = torch.tensor(std_rgb).view(3, 1, 1)
    imagenet_mean = torch.tensor(mean_rgb).view(3, 1, 1)

    img = sys.argv[1] if len(sys.argv) > 1 else "G:/swap_data/ID/2.jpg"
    out_path = sys.argv[2] if len(sys.argv) > 2 else "filter_results2.png"

    attr = cv2.imread(img)
    # cv2.imread reports failure by returning None rather than raising.
    if attr is None:
        raise FileNotFoundError(f"could not read image: {img}")
    attr = Image.fromarray(cv2.cvtColor(attr, cv2.COLOR_BGR2RGB))
    attr = transformer_Arcface(attr).unsqueeze(0)

    with torch.no_grad():
        results = HighPass(0.5, torch.device("cpu"))(attr)

    # Undo the ImageNet normalisation for display, then NCHW -> HWC.
    results = results * imagenet_std + imagenet_mean
    results = results.cpu().permute(0, 2, 3, 1)[0, ...].numpy()
    # Quantise explicitly to 8-bit samples for the PNG encoder.
    results = (np.clip(results, 0.0, 1.0) * 255).round().astype(np.uint8)
    results = cv2.cvtColor(results, cv2.COLOR_RGB2BGR)
    # cv2.imwrite returns False on failure instead of raising.
    if not cv2.imwrite(out_path, results):
        raise OSError(f"could not write image: {out_path}")
|