diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 4873913..56f8632 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -3,10 +3,9 @@ import os import warnings from typing import Tuple, cast -import lightning import torch import torchvision -from lightning import Trainer +from lightning import Trainer, LightningModule from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from torch import Tensor, nn @@ -26,7 +25,7 @@ CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -class FaceSwapperTrainer(lightning.LightningModule): +class FaceSwapperTrainer(LightningModule): def __init__(self) -> None: super().__init__() embedder_path = CONFIG.get('training.model', 'embedder_path')