mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
reduce layer
This commit is contained in:
@@ -12,7 +12,7 @@ class MaskNet(nn.Module):
|
||||
self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters')
|
||||
self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters)
|
||||
self.up_samples = self.create_up_samples(self.config_num_filters)
|
||||
self.bottleneck = BottleNeck(self.config_num_filters * 4)
|
||||
self.bottleneck = BottleNeck(self.config_num_filters * 2)
|
||||
self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
@@ -22,14 +22,12 @@ class MaskNet(nn.Module):
|
||||
[
|
||||
DownSample(input_channels, num_filters),
|
||||
DownSample(num_filters, num_filters * 2),
|
||||
DownSample(num_filters * 2, num_filters * 4)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def create_up_samples(num_filters : int) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
UpSample(num_filters * 4, num_filters * 2),
|
||||
UpSample(num_filters * 2, num_filters),
|
||||
UpSample(num_filters, num_filters)
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user