diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index f9ee804be..15398ebb7 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -74,8 +74,8 @@ def _extract_generation_config_params( ) -> dict[str, Any]: """Extract valid parameters from a GenerationConfig object. - This function extracts parameters from the internal _raw_generation_config - protobuf and returns them as a dict that can be passed to GenerationConfig(). + Prefer ``GenerationConfig.to_dict()`` so fields only available via + ``from_dict`` (e.g. ``thinking_config``) and zero-valued numerics are kept. Args: config: A GenerationConfig object @@ -86,12 +86,22 @@ def _extract_generation_config_params( """ from vertexai.generative_models import GenerationConfig + if config is None: + return {} + + if hasattr(config, "to_dict"): + params = dict(config.to_dict()) + if exclude_schema: + for key in _GENERATION_CONFIG_SCHEMA_PARAMS: + params.pop(key, None) + return params + if not hasattr(config, "_raw_generation_config"): return {} raw = config._raw_generation_config - # Get valid params from GenerationConfig signature + # Fallback for test doubles without to_dict(). sig = inspect.signature(GenerationConfig.__init__) valid_params = { name @@ -100,14 +110,24 @@ def _extract_generation_config_params( and (not exclude_schema or name not in _GENERATION_CONFIG_SCHEMA_PARAMS) } - preserved = {} + preserved: dict[str, Any] = {} for param in valid_params: val = getattr(raw, param, None) - if val: # Only include non-empty values - # Convert repeated fields (like stop_sequences) to lists - if hasattr(val, "__iter__") and not isinstance(val, (str, bytes, dict)): - val = list(val) - preserved[param] = val + if val is None: + continue + if isinstance(val, (str, bytes)) and not val: + continue + if hasattr(val, "__iter__") and not isinstance(val, (str, bytes, dict)): + val = list(val) + if not val: + continue + preserved[param] = val + + thinking_config = getattr(raw, "thinking_config", None) + if thinking_config is not None and str(thinking_config).strip(): + preserved["thinking_config"] = { + "thinking_budget": thinking_config.thinking_budget + } return preserved @@ -548,7 +568,10 @@ def _get_call_params_v2( # Apply kwargs (they override constructor values but preserve schema) params.update(kwargs) - options["generation_config"] = GenerationConfig(**params) + if "thinking_config" in params: + options["generation_config"] = GenerationConfig.from_dict(params) + else: + options["generation_config"] = GenerationConfig(**params) options["contents"] = contents return options diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index 9ce1ad541..90dbf8513 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -26,12 +26,16 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.types import ToolCallResponse -from neo4j_graphrag.llm.vertexai_llm import VertexAILLM +from neo4j_graphrag.llm.vertexai_llm import ( + VertexAILLM, + _extract_generation_config_params, +) from neo4j_graphrag.tool import Tool from neo4j_graphrag.types import LLMMessage from neo4j_graphrag.utils.rate_limit import NoOpRateLimitHandler from pydantic import BaseModel, ConfigDict +from vertexai.generative_models import GenerationConfig @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None) @@ -647,3 +651,86 @@ async def test_vertexai_ainvoke_v2_rate_limit_handler_called( assert response.content == "Hi there!" spy_handler.handle_async.assert_called_once() + + +def test_extract_generation_config_params_preserves_thinking_config() -> None: + """thinking_config and zero-valued temperature survive extraction.""" + config = GenerationConfig.from_dict( + {"temperature": 0.0, "thinking_config": {"thinking_budget": 0}} + ) + + params = _extract_generation_config_params(config) + + assert params["temperature"] == 0.0 + assert params["thinking_config"] == {"thinking_budget": 0} + + +def test_extract_generation_config_params_excludes_schema_when_requested() -> None: + config = GenerationConfig.from_dict( + { + "response_mime_type": "application/json", + "response_schema": {"type": "object"}, + "thinking_config": {"thinking_budget": 512}, + } + ) + + params = _extract_generation_config_params(config, exclude_schema=True) + + assert "response_schema" not in params + assert "response_mime_type" not in params + assert params["thinking_config"] == {"thinking_budget": 512} + + +@pytest.mark.parametrize("thinking_budget", [0, 512]) +@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") +def test_vertexai_invoke_v2_preserves_thinking_config_with_response_format( + GenerativeModelMock: MagicMock, + thinking_budget: int, +) -> None: + """Structured-output v2 invoke must keep thinking_config from constructor config.""" + messages: List[LLMMessage] = [{"role": "user", "content": "Extract person info"}] + mock_response = Mock() + mock_response.text = '{"name": "John", "age": 30}' + mock_response.usage_metadata = None + mock_model = GenerativeModelMock.return_value + mock_model.generate_content.return_value = mock_response + + generation_config = GenerationConfig.from_dict( + { + "temperature": 0.0, + "thinking_config": {"thinking_budget": thinking_budget}, + } + ) + llm = VertexAILLM( + model_name="gemini-2.5-flash", + generation_config=generation_config, + ) + llm.invoke(messages, response_format=_TestModelForVertexAI) + + call_args = mock_model.generate_content.call_args.kwargs + rebuilt = call_args["generation_config"] + assert rebuilt.to_dict()["thinking_config"] == {"thinking_budget": thinking_budget} + assert rebuilt.to_dict()["temperature"] == 0.0 + + +@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") +def test_vertexai_invoke_v2_preserves_temperature_zero_without_thinking_config( + GenerativeModelMock: MagicMock, +) -> None: + """temperature=0.0 must survive structured-output rebuild without thinking_config.""" + messages: List[LLMMessage] = [{"role": "user", "content": "Test"}] + mock_response = Mock() + mock_response.text = '{"result": "success"}' + mock_response.usage_metadata = None + mock_model = GenerativeModelMock.return_value + mock_model.generate_content.return_value = mock_response + + llm = VertexAILLM( + model_name="gemini-2.5-flash", + generation_config=GenerationConfig(temperature=0.0), + ) + llm.invoke(messages, response_format=_TEST_JSON_SCHEMA) + + rebuilt = mock_model.generate_content.call_args.kwargs["generation_config"] + assert rebuilt.to_dict()["temperature"] == 0.0 + assert "thinking_config" not in rebuilt.to_dict()