diff --git a/tests/conftest.py b/tests/conftest.py index 80907e8f2..fc69ab2ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ from tenacity import retry, stop_after_delay from src.text_generation import config from src.text_generation.adapters.llm.embedding_model import EmbeddingModel from src.text_generation.adapters.llm.language_model import LanguageModel +from src.text_generation.adapters.llm.language_model_with_rag import LanguageModelWithRag from src.text_generation.services.language_models.text_generation_response_service import TextGenerationResponseService from src.text_generation.services.similarity_scoring.generated_text_guardrail_service import GeneratedTextGuardrailService @@ -50,6 +51,10 @@ def language_model(): def embedding_model(): return EmbeddingModel() +@pytest.fixture(scope="session") +def language_model_with_rag(embedding_model): + return LanguageModelWithRag(embeddings=embedding_model) + @pytest.fixture(scope="session") def text_generation_response_service(language_model): return TextGenerationResponseService(language_model)