mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Partial use Resnet34 as a DownSample replacement
This commit is contained in:
@@ -2,23 +2,32 @@ from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torchvision import models
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(UNet, self).__init__()
|
||||
self.down_samples = self.create_down_samples()
|
||||
self.resnet = models.resnet34(pretrained = True)
|
||||
self.down_samples = self.create_down_samples(self)
|
||||
self.up_samples = self.create_up_samples()
|
||||
|
||||
@staticmethod
|
||||
def create_down_samples() -> nn.ModuleList:
|
||||
def create_down_samples(self) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
DownSample(3, 32),
|
||||
nn.Sequential(
|
||||
self.resnet.conv1,
|
||||
self.resnet.bn1,
|
||||
self.resnet.relu,
|
||||
nn.Conv2d(64, 32, kernel_size = 1, bias = False),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.LeakyReLU(0.1, inplace = True)
|
||||
),
|
||||
DownSample(32, 64),
|
||||
DownSample(64, 128),
|
||||
DownSample(128, 256),
|
||||
DownSample(256, 512),
|
||||
self.resnet.layer2,
|
||||
self.resnet.layer3,
|
||||
self.resnet.layer4,
|
||||
DownSample(512, 1024),
|
||||
DownSample(1024, 1024)
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user