Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions src/neo4j_graphrag/llm/vertexai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def _extract_generation_config_params(
) -> dict[str, Any]:
"""Extract valid parameters from a GenerationConfig object.

This function extracts parameters from the internal _raw_generation_config
protobuf and returns them as a dict that can be passed to GenerationConfig().
Prefer ``GenerationConfig.to_dict()`` so fields only available via
``from_dict`` (e.g. ``thinking_config``) and zero-valued numerics are kept.

Args:
config: A GenerationConfig object
Expand All @@ -86,12 +86,22 @@ def _extract_generation_config_params(
"""
from vertexai.generative_models import GenerationConfig

if config is None:
return {}

if hasattr(config, "to_dict"):
params = dict(config.to_dict())
if exclude_schema:
for key in _GENERATION_CONFIG_SCHEMA_PARAMS:
params.pop(key, None)
return params

if not hasattr(config, "_raw_generation_config"):
return {}

raw = config._raw_generation_config

# Get valid params from GenerationConfig signature
# Fallback for test doubles without to_dict().
sig = inspect.signature(GenerationConfig.__init__)
valid_params = {
name
Expand All @@ -100,14 +110,24 @@ def _extract_generation_config_params(
and (not exclude_schema or name not in _GENERATION_CONFIG_SCHEMA_PARAMS)
}

preserved = {}
preserved: dict[str, Any] = {}
for param in valid_params:
val = getattr(raw, param, None)
if val: # Only include non-empty values
# Convert repeated fields (like stop_sequences) to lists
if hasattr(val, "__iter__") and not isinstance(val, (str, bytes, dict)):
val = list(val)
preserved[param] = val
if val is None:
continue
if isinstance(val, (str, bytes)) and not val:
continue
if hasattr(val, "__iter__") and not isinstance(val, (str, bytes, dict)):
val = list(val)
if not val:
continue
preserved[param] = val

thinking_config = getattr(raw, "thinking_config", None)
if thinking_config is not None and str(thinking_config).strip():
preserved["thinking_config"] = {
"thinking_budget": thinking_config.thinking_budget
}

return preserved

Expand Down Expand Up @@ -548,7 +568,10 @@ def _get_call_params_v2(
# Apply kwargs (they override constructor values but preserve schema)
params.update(kwargs)

options["generation_config"] = GenerationConfig(**params)
if "thinking_config" in params:
options["generation_config"] = GenerationConfig.from_dict(params)
else:
options["generation_config"] = GenerationConfig(**params)
options["contents"] = contents
return options

Expand Down
89 changes: 88 additions & 1 deletion tests/unit/llm/test_vertexai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@

from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.types import ToolCallResponse
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
from neo4j_graphrag.llm.vertexai_llm import (
VertexAILLM,
_extract_generation_config_params,
)
from neo4j_graphrag.tool import Tool
from neo4j_graphrag.types import LLMMessage
from neo4j_graphrag.utils.rate_limit import NoOpRateLimitHandler

from pydantic import BaseModel, ConfigDict
from vertexai.generative_models import GenerationConfig


@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None)
Expand Down Expand Up @@ -647,3 +651,86 @@ async def test_vertexai_ainvoke_v2_rate_limit_handler_called(

assert response.content == "Hi there!"
spy_handler.handle_async.assert_called_once()


def test_extract_generation_config_params_preserves_thinking_config() -> None:
"""thinking_config and zero-valued temperature survive extraction."""
config = GenerationConfig.from_dict(
{"temperature": 0.0, "thinking_config": {"thinking_budget": 0}}
)

params = _extract_generation_config_params(config)

assert params["temperature"] == 0.0
assert params["thinking_config"] == {"thinking_budget": 0}


def test_extract_generation_config_params_excludes_schema_when_requested() -> None:
config = GenerationConfig.from_dict(
{
"response_mime_type": "application/json",
"response_schema": {"type": "object"},
"thinking_config": {"thinking_budget": 512},
}
)

params = _extract_generation_config_params(config, exclude_schema=True)

assert "response_schema" not in params
assert "response_mime_type" not in params
assert params["thinking_config"] == {"thinking_budget": 512}


@pytest.mark.parametrize("thinking_budget", [0, 512])
@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
def test_vertexai_invoke_v2_preserves_thinking_config_with_response_format(
GenerativeModelMock: MagicMock,
thinking_budget: int,
) -> None:
"""Structured-output v2 invoke must keep thinking_config from constructor config."""
messages: List[LLMMessage] = [{"role": "user", "content": "Extract person info"}]
mock_response = Mock()
mock_response.text = '{"name": "John", "age": 30}'
mock_response.usage_metadata = None
mock_model = GenerativeModelMock.return_value
mock_model.generate_content.return_value = mock_response

generation_config = GenerationConfig.from_dict(
{
"temperature": 0.0,
"thinking_config": {"thinking_budget": thinking_budget},
}
)
llm = VertexAILLM(
model_name="gemini-2.5-flash",
generation_config=generation_config,
)
llm.invoke(messages, response_format=_TestModelForVertexAI)

call_args = mock_model.generate_content.call_args.kwargs
rebuilt = call_args["generation_config"]
assert rebuilt.to_dict()["thinking_config"] == {"thinking_budget": thinking_budget}
assert rebuilt.to_dict()["temperature"] == 0.0


@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
def test_vertexai_invoke_v2_preserves_temperature_zero_without_thinking_config(
GenerativeModelMock: MagicMock,
) -> None:
"""temperature=0.0 must survive structured-output rebuild without thinking_config."""
messages: List[LLMMessage] = [{"role": "user", "content": "Test"}]
mock_response = Mock()
mock_response.text = '{"result": "success"}'
mock_response.usage_metadata = None
mock_model = GenerativeModelMock.return_value
mock_model.generate_content.return_value = mock_response

llm = VertexAILLM(
model_name="gemini-2.5-flash",
generation_config=GenerationConfig(temperature=0.0),
)
llm.invoke(messages, response_format=_TEST_JSON_SCHEMA)

rebuilt = mock_model.generate_content.call_args.kwargs["generation_config"]
assert rebuilt.to_dict()["temperature"] == 0.0
assert "thinking_config" not in rebuilt.to_dict()
Loading