diff --git a/tests/conftest.py b/tests/conftest.py
index ea4f1503f..1802c9844 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -23,7 +23,6 @@ from src.text_generation.common.constants import Constants
 from src.text_generation.services.guardrails.generated_text_guardrail_service import GeneratedTextGuardrailService
 from src.text_generation.services.guardrails.reflexion_security_guidelines_service import ReflexionSecurityGuardrailsService
 from src.text_generation.services.guidelines.chain_of_thought_security_guidelines_service import ChainOfThoughtSecurityGuidelinesService
-from src.text_generation.services.guidelines.guidelines_factory import GuidelinesFactory
 from src.text_generation.services.guidelines.rag_context_security_guidelines_configuration_builder import RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder
 from src.text_generation.services.guidelines.rag_context_security_guidelines_service import RagContextSecurityGuidelinesService
 from src.text_generation.services.guidelines.rag_plus_cot_security_guidelines_service import RagPlusCotSecurityGuidelinesService
@@ -97,6 +96,10 @@ def setup_test_environment():
 def constants():
     return Constants()
 
+@pytest.fixture(scope="session")
+def foundation_model():
+    return TextGenerationFoundationModel()
+
 @pytest.fixture(scope="session")
 def embedding_model():
     return EmbeddingModel()
@@ -128,6 +131,48 @@ def rag_config_builder(
 def llm_configuration_introspection_service():
     return LLMConfigurationIntrospectionService()
 
+@pytest.fixture(scope="session")
+def rag_context_guidelines(
+    foundation_model,
+    response_processing_service,
+    prompt_template_service,
+    llm_configuration_introspection_service,
+    rag_config_builder):
+    return RagContextSecurityGuidelinesService(
+        foundation_model=foundation_model,
+        response_processing_service=response_processing_service,
+        prompt_template_service=prompt_template_service,
+        llm_configuration_introspection_service=llm_configuration_introspection_service,
+        config_builder=rag_config_builder
+    )
+
+@pytest.fixture(scope="session")
+def chain_of_thought_guidelines(
+    foundation_model,
+    response_processing_service,
+    llm_configuration_introspection_service,
+    prompt_template_service):
+    return ChainOfThoughtSecurityGuidelinesService(
+        foundation_model=foundation_model,
+        response_processing_service=response_processing_service,
+        llm_configuration_introspection_service=llm_configuration_introspection_service,
+        prompt_template_service=prompt_template_service
+    )
+
+@pytest.fixture(scope="session")
+def rag_plus_cot_guidelines(
+    foundation_model,
+    response_processing_service,
+    prompt_template_service,
+    llm_configuration_introspection_service,
+    rag_config_builder):
+    return RagPlusCotSecurityGuidelinesService(
+        foundation_model=foundation_model,
+        response_processing_service=response_processing_service,
+        prompt_template_service=prompt_template_service,
+        llm_configuration_introspection_service=llm_configuration_introspection_service,
+        config_builder=rag_config_builder
+    )
 
 @pytest.fixture(scope="session")
 def prompt_injection_example_service(prompt_injection_example_repository):
@@ -154,36 +199,30 @@ def reflexion_guardrails(
 def response_processing_service():
     return ResponseProcessingService()
 
-
 @pytest.fixture(scope="session")
-def guidelines_factory():
-    return GuidelinesFactory()
-
-@pytest.fixture(scope="session")
-def guidelines_config_builder(
-    embedding_model,
-    prompt_template_service,
-    prompt_injection_example_repository):
-    return RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder(
-        embedding_model=embedding_model,
-        prompt_template_service=prompt_template_service,
-        prompt_injection_example_repository=prompt_injection_example_repository
-    )
+def llm_configuration_introspection_service():
+    return LLMConfigurationIntrospectionService()
 
 @pytest.fixture(scope="session")
 def text_generation_completion_service(
+    foundation_model,
     response_processing_service,
     prompt_template_service,
-    guidelines_factory,
-    guidelines_config_builder,
+    chain_of_thought_guidelines,
+    rag_context_guidelines,
+    rag_plus_cot_guidelines,
+    reflexion_guardrails,
     semantic_similarity_service,
     prompt_injection_example_service,
     llm_configuration_introspection_service):
     return TextGenerationCompletionService(
+        foundation_model=foundation_model,
        response_processing_service=response_processing_service,
         prompt_template_service=prompt_template_service,
-        guidelines_factory=guidelines_factory,
-        guidelines_config_builder=guidelines_config_builder,
+        chain_of_thought_guidelines=chain_of_thought_guidelines,
+        rag_context_guidelines=rag_context_guidelines,
+        rag_plus_cot_guidelines=rag_plus_cot_guidelines,
+        reflexion_guardrails=reflexion_guardrails,
         semantic_similarity_service=semantic_similarity_service,
         prompt_injection_example_service=prompt_injection_example_service,
         llm_configuration_introspection_service=llm_configuration_introspection_service
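
With the GuidelinesFactory indirection removed, tests can request the concrete guideline fixtures directly by parameter name. A minimal sketch of how a test might consume the rewired fixtures, assuming pytest injects the session-scoped fixtures defined in this conftest.py; the test name and assertions are illustrative, not part of the change:

# Hypothetical test (not part of this diff): pytest resolves the session-scoped
# fixtures from conftest.py and injects them by parameter name.
def test_completion_service_wiring(
    text_generation_completion_service,
    rag_plus_cot_guidelines,
    foundation_model):
    # The completion service is now built from the explicit guideline fixtures,
    # so each collaborator is constructed once per test session.
    assert text_generation_completion_service is not None
    assert rag_plus_cot_guidelines is not None
    assert foundation_model is not None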