diff --git a/.github/workflows/tests_ci.yml b/.github/workflows/tests_ci.yml index 8641eb7..b40825b 100644 --- a/.github/workflows/tests_ci.yml +++ b/.github/workflows/tests_ci.yml @@ -42,7 +42,7 @@ jobs: ANTHROPIC_API_KEY: ${{ secrets.INVARIANT_TESTING_ANTHROPIC_KEY }} GEMINI_API_KEY: ${{ secrets.INVARIANT_TESTING_GEMINI_KEY }} INVARIANT_API_KEY: ${{ secrets.INVARIANT_TESTING_GUARDRAILS_KEY }} - run: ./run.sh integration-tests -s -vv + run: ./run.sh integration-tests gemini/test_generate_content_without_tool_calls.py::test_generate_content_with_image -s -vv continue-on-error: true - name: Check test results diff --git a/tests/integration/gemini/test_generate_content_without_tool_calls.py b/tests/integration/gemini/test_generate_content_without_tool_calls.py index cfb4242..0095b64 100644 --- a/tests/integration/gemini/test_generate_content_without_tool_calls.py +++ b/tests/integration/gemini/test_generate_content_without_tool_calls.py @@ -16,6 +16,7 @@ import pytest import PIL.Image import requests from google import genai +from google.genai import types # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -103,59 +104,65 @@ async def test_generate_content( @pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set") @pytest.mark.parametrize("push_to_explorer", [True, False]) -@pytest.mark.skip(reason="Skipping this test: 500 error from Gemini API") async def test_generate_content_with_image( explorer_api_url, gateway_url, push_to_explorer ): """Test that generate content gateway calls work with image.""" dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" - client = get_gemini_client(gateway_url, push_to_explorer, dataset_name) - - - image_path = Path(__file__).parent.parent / "resources" / "images" / "two-cats.png" - image = PIL.Image.open(image_path) - - chat_response = client.models.generate_content( - model="gemini-2.0-flash", - contents=["How many cats are there in this image?", image], - config={"maxOutputTokens": 100}, + client = get_gemini_client( + gateway_url, push_to_explorer, dataset_name, api_version="v1" ) - assert ( - "TWO" in chat_response.candidates[0].content.parts[0].text.upper() - or "2" in chat_response.candidates[0].content.parts[0].text - ) - - if push_to_explorer: - # Wait for the trace to be saved - # This is needed because the trace is saved asynchronously - time.sleep(2) - # Fetch the trace ids for the dataset - traces_response = requests.get( - f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", - timeout=5, + image_path = Path(__file__).parent.parent / "resources" / "images" / "two-cats.jpg" + with open(image_path.absolute(), "rb") as f: + image_bytes = f.read() + chat_response = client.models.generate_content( + model="gemini-2.0-flash", + contents=[ + "How many cats are there in this image?", + types.Part.from_bytes( + data=image_bytes, + mime_type="image/jpg", + ), + ], + config={"maxOutputTokens": 100}, ) - traces = traces_response.json() - assert len(traces) == 1 - trace_id = traces[0]["id"] - # Fetch the trace - trace_response = requests.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5 + assert ( + "TWO" in chat_response.candidates[0].content.parts[0].text.upper() + or "2" in chat_response.candidates[0].content.parts[0].text ) - trace = trace_response.json() - # Verify the trace messages - assert len(trace["messages"]) == 2 - assert trace["messages"][0]["role"] == "user" - assert trace["messages"][0]["content"][0] == { - "type": "text", - "text": "How many cats are there in this image?", - } - assert trace["messages"][0]["content"][1]["type"] == "image_url" - assert trace["messages"][1] == { - "role": "assistant", - "content": chat_response.candidates[0].content.parts[0].text, - } + + if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5 + ) + trace = trace_response.json() + # Verify the trace messages + assert len(trace["messages"]) == 2 + assert trace["messages"][0]["role"] == "user" + assert trace["messages"][0]["content"][0] == { + "type": "text", + "text": "How many cats are there in this image?", + } + assert trace["messages"][0]["content"][1]["type"] == "image_url" + assert trace["messages"][1] == { + "role": "assistant", + "content": chat_response.candidates[0].content.parts[0].text, + } @pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set") @@ -213,7 +220,9 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header( assert trace["messages"] == [ { "role": "user", - "content": [{"text": "What is the capital of Denmark?", "type": "text"}], + "content": [ + {"text": "What is the capital of Denmark?", "type": "text"} + ], }, { "role": "assistant", diff --git a/tests/integration/resources/images/two-cats.jpg b/tests/integration/resources/images/two-cats.jpg new file mode 100644 index 0000000..efe65ea Binary files /dev/null and b/tests/integration/resources/images/two-cats.jpg differ diff --git a/tests/integration/utils.py b/tests/integration/utils.py index 3c6d3bd..246d272 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -43,7 +43,10 @@ def get_anthropic_client( def get_gemini_client( - gateway_url: str, push_to_explorer: bool, dataset_name: str + gateway_url: str, + push_to_explorer: bool, + dataset_name: str, + api_version: str = "v1beta", ) -> genai.Client: """Create a Gemini client for integration tests.""" return genai.Client( @@ -55,6 +58,7 @@ def get_gemini_client( "headers": { "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, + "api_version": api_version, }, )