diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index b52057a..4198673 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -3,7 +3,7 @@ import os from typing import Tuple import torch -from lightning import Trainer, LightningModule +from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from torch import Tensor, nn @@ -12,7 +12,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from .dataset import StaticDataset from .models.embedding_converter import EmbeddingConverter -from .types import Batch, Embedding, OptimizerSet, Config +from .types import Batch, Config, Embedding, OptimizerSet CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 56f8632..a99f89f 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -5,7 +5,7 @@ from typing import Tuple, cast import torch import torchvision -from lightning import Trainer, LightningModule +from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from torch import Tensor, nn