diff --git a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py index 9395fe1c69..8940365930 100644 --- a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py +++ b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py @@ -10,10 +10,9 @@ import tempfile import threading from collections.abc import AsyncIterable, AsyncIterator, Generator, Sequence -from contextlib import suppress +from contextlib import AbstractAsyncContextManager, AsyncExitStack, suppress from dataclasses import asdict, is_dataclass from pathlib import Path -from contextlib import AbstractAsyncContextManager, AsyncExitStack, suppress from typing import Protocol, cast from agent_framework import ( @@ -73,6 +72,7 @@ MessageContentOutputTextContent, MessageContentReasoningTextContent, MessageContentRefusalContent, + MessageRole, OAuthConsentRequestOutputItem, OutputItem, OutputItemApplyPatchToolCall, @@ -117,6 +117,8 @@ logger = logging.getLogger(__name__) +_AZURE_RESPONSES_MESSAGE_ROLE_TYPE = f"{MessageRole.__module__}:{MessageRole.__qualname__}" + # region Approval Storage class ApprovalStorage(Protocol): @@ -250,7 +252,12 @@ def _checkpoint_storage_for_context(root: str, context_id: str) -> FileCheckpoin storage_path = (root_path / context_id).resolve() if not storage_path.is_relative_to(root_path): raise RuntimeError(f"Invalid checkpoint context id: {context_id!r}") - return FileCheckpointStorage(storage_path) + return FileCheckpointStorage( + storage_path, + # Keep this provider-specific allowlist narrow. Hosted workflow + # checkpoints can persist Azure's role enum inside Message objects. + allowed_checkpoint_types=[_AZURE_RESPONSES_MESSAGE_ROLE_TYPE], + ) # endregion Approval Storage diff --git a/python/packages/foundry_hosting/tests/test_responses.py b/python/packages/foundry_hosting/tests/test_responses.py index d5e25b99f9..9358549a86 100644 --- a/python/packages/foundry_hosting/tests/test_responses.py +++ b/python/packages/foundry_hosting/tests/test_responses.py @@ -13,7 +13,7 @@ import json from collections.abc import AsyncIterator, Callable from dataclasses import dataclass -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -26,6 +26,9 @@ Message, RawAgent, ResponseStream, + WorkflowCheckpoint, + WorkflowCheckpointException, + WorkflowMessage, ) from azure.ai.agentserver.responses import InMemoryResponseProvider from mcp import McpError @@ -34,6 +37,7 @@ from agent_framework_foundry_hosting import ResponsesHostServer from agent_framework_foundry_hosting._responses import ( + _AZURE_RESPONSES_MESSAGE_ROLE_TYPE, # pyright: ignore[reportPrivateUsage] CONSENT_ERROR_CODE, FileBasedFunctionApprovalStorage, # pyright: ignore[reportPrivateUsage] InMemoryFunctionApprovalStorage, # pyright: ignore[reportPrivateUsage] @@ -2712,6 +2716,23 @@ def _helper() -> Callable[[str, str], FileCheckpointStorage]: return _checkpoint_storage_for_context + @staticmethod + def _checkpoint_with_azure_message_role() -> WorkflowCheckpoint: + from azure.ai.agentserver.responses.models import MessageRole + + return WorkflowCheckpoint( + workflow_name="wf", + graph_signature_hash="hash", + messages={ + "executor": [ + WorkflowMessage( + data=Message(role=MessageRole.USER, contents=[Content.from_text("hello")]), + source_id="source", + ) + ] + }, + ) + def test_valid_segment_creates_storage_under_root(self, tmp_path: Any) -> None: helper = self._helper() root = tmp_path / "root" @@ -2720,6 +2741,124 @@ def test_valid_segment_creates_storage_under_root(self, tmp_path: Any) -> None: assert storage.storage_path.is_dir() assert storage.storage_path.parent == root.resolve() + def test_azure_message_role_allowlist_type_matches_generated_sdk_path(self) -> None: + assert ( + _AZURE_RESPONSES_MESSAGE_ROLE_TYPE + == "azure.ai.agentserver.responses.models._generated.sdk.models.models._enums:MessageRole" + ) + + async def test_storage_allows_azure_message_role_checkpoint_restore(self, tmp_path: Any) -> None: + from azure.ai.agentserver.responses.models import MessageRole + + helper = self._helper() + root = tmp_path / "root" + root.mkdir() + storage = helper(str(root), "resp_abc123") + checkpoint = self._checkpoint_with_azure_message_role() + + await storage.save(checkpoint) + loaded = await storage.load(checkpoint.checkpoint_id) + + loaded_message = loaded.messages["executor"][0].data + assert isinstance(loaded_message, Message) + assert type(loaded_message.role) is MessageRole + assert loaded_message.role == MessageRole.USER + assert loaded_message.text == "hello" + + async def test_plain_storage_blocks_azure_message_role_checkpoint_restore(self, tmp_path: Any) -> None: + storage = FileCheckpointStorage(tmp_path / "plain") + checkpoint = self._checkpoint_with_azure_message_role() + + await storage.save(checkpoint) + with pytest.raises(WorkflowCheckpointException, match="MessageRole"): + await storage.load(checkpoint.checkpoint_id) + + async def test_get_latest_restores_azure_message_role(self, tmp_path: Any) -> None: + from azure.ai.agentserver.responses.models import MessageRole + + helper = self._helper() + root = tmp_path / "root" + root.mkdir() + storage = helper(str(root), "resp_abc123") + checkpoint = self._checkpoint_with_azure_message_role() + + await storage.save(checkpoint) + latest = await storage.get_latest(workflow_name="wf") + + assert latest is not None + assert latest.checkpoint_id == checkpoint.checkpoint_id + latest_message = latest.messages["executor"][0].data + assert isinstance(latest_message, Message) + assert type(latest_message.role) is MessageRole + + async def test_get_latest_silently_skips_without_allowlist( + self, tmp_path: Any, caplog: pytest.LogCaptureFixture + ) -> None: + import logging + + storage = FileCheckpointStorage(tmp_path / "plain") + checkpoint = self._checkpoint_with_azure_message_role() + + await storage.save(checkpoint) + with caplog.at_level(logging.WARNING, logger="agent_framework"): + latest = await storage.get_latest(workflow_name="wf") + + assert latest is None + assert any("MessageRole" in message for message in caplog.messages) + + async def test_handle_inner_workflow_restores_message_role_checkpoint_from_previous_response( + self, tmp_path: Any + ) -> None: + from agent_framework import WorkflowAgent + from azure.ai.agentserver.responses import ResponseContext + from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage + + previous_response_id = "resp_previous" + response_id = "resp_current" + root = tmp_path / "root" + root.mkdir() + checkpoint_storage = self._helper()(str(root), previous_response_id) + checkpoint = self._checkpoint_with_azure_message_role() + await checkpoint_storage.save(checkpoint) + + agent = MagicMock(spec=WorkflowAgent) + agent.id = "wf-agent" + agent.name = "wf" + agent.description = "" + agent.context_providers = [] + agent.workflow = MagicMock() + agent.workflow.name = "wf" + agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False) + agent.run = AsyncMock( + side_effect=[ + AgentResponse(messages=[]), + AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])]), + ] + ) + server = ResponsesHostServer(agent, store=InMemoryResponseProvider()) + server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage] + + request = CreateResponse(model="m", input="hi", previous_response_id=previous_response_id) + context = ResponseContext( + response_id=response_id, previous_response_id=previous_response_id, mode_flags=MagicMock() + ) + input_item = ItemMessage({"type": "message", "role": "user", "content": "next turn"}) + + with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[input_item])): + async for _ in server._handle_inner_workflow(request, context): # pyright: ignore[reportPrivateUsage] + pass + + assert agent.run.call_count == 2 + restore_call = agent.run.call_args_list[0] + assert restore_call.kwargs["checkpoint_id"] == checkpoint.checkpoint_id + assert restore_call.kwargs["checkpoint_storage"].storage_path == (root / previous_response_id).resolve() + + new_turn_call = agent.run.call_args_list[1] + new_turn_messages = new_turn_call.args[0] + assert len(new_turn_messages) == 1 + assert new_turn_messages[0].text == "next turn" + assert new_turn_call.kwargs["checkpoint_storage"].storage_path == (root / response_id).resolve() + @pytest.mark.parametrize( "bad_id", [ @@ -2923,6 +3062,8 @@ async def test_malicious_context_id_rejected_e2e(self, tmp_path: Any, context_fi f"before={before} after={after}" ) assert list(root.iterdir()) == [], f"Checkpoint directory created inside root for {context_field}={bad_id!r}" + + # region Agent lifecycle (lazy entry & OAuth consent surfacing)