diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index be9c236..6f757ec 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -13,7 +13,7 @@ CONFIG.read('config.ini') class Generator(nn.Module): def __init__(self) -> None: super(Generator, self).__init__() - encoder_type = CONFIG.getint('training.model.generator', 'encoder_type') + encoder_type = CONFIG.get('training.model.generator', 'encoder_type') id_channels = CONFIG.getint('training.model.generator', 'id_channels') num_blocks = CONFIG.getint('training.model.generator', 'num_blocks') diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 127fdef..5f9cb05 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -59,7 +59,7 @@ class UNet(nn.Module): class UNetPro(UNet): def __init__(self) -> None: - super(UNetPro, self).__init__() + super(UNet, self).__init__() self.resnet = models.resnet34(pretrained = True) self.down_samples = self.create_down_samples(self) self.up_samples = self.create_up_samples()