mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Some code review
This commit is contained in:
@@ -16,23 +16,23 @@ class MaskNet(nn.Module):
|
||||
self.conv = nn.Conv2d(self.config_base_channels, self.config_output_channels, kernel_size = 1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def create_down_samples(self, input_channels : int, base_channels: int) -> nn.ModuleList:
|
||||
down_samples = nn.ModuleList(
|
||||
@staticmethod
|
||||
def create_down_samples(input_channels : int, base_channels : int) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
DownSample(input_channels, base_channels),
|
||||
DownSample(base_channels, base_channels * 2),
|
||||
DownSample(base_channels * 2, base_channels * 4)
|
||||
])
|
||||
return down_samples
|
||||
|
||||
def create_up_samples(self, base_channels : int) -> nn.ModuleList:
|
||||
down_samples = nn.ModuleList(
|
||||
@staticmethod
|
||||
def create_up_samples(base_channels : int) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
UpSample(base_channels * 4, base_channels * 2),
|
||||
UpSample(base_channels * 2, base_channels),
|
||||
UpSample(base_channels, base_channels)
|
||||
])
|
||||
return down_samples
|
||||
|
||||
def forward(self, target_tensor : Tensor, target_attribute : Tensor) -> Tensor:
|
||||
output_tensor = torch.cat([ target_tensor, target_attribute ], dim = 1)
|
||||
@@ -51,14 +51,14 @@ class MaskNet(nn.Module):
|
||||
|
||||
|
||||
class BottleNeck(nn.Module):
|
||||
def __init__(self, channels : int):
|
||||
def __init__(self, base_channels : int):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(channels, channels, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(channels),
|
||||
nn.Conv2d(base_channels, base_channels, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(base_channels),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(channels, channels, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(channels),
|
||||
nn.Conv2d(base_channels, base_channels, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(base_channels),
|
||||
nn.ReLU(inplace = True)
|
||||
)
|
||||
self.relu = nn.ReLU(inplace = True)
|
||||
|
||||
Reference in New Issue
Block a user