Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
CLIENT_ORIGINS=https://fake-origin.example.com
CLIENT_ORIGINS_REGEX="^http://fake-localhost:.*"
SESSION_COOKIE_DOMAIN=.example.com
ENV=development

##### AZURE #####
Expand Down
260 changes: 26 additions & 234 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ langgraph-checkpoint-postgres = "^2.0.23"
azure-ai-inference = "^1.0.0b9"
azure-identity = "^1.25.0"
psycopg = {extras = ["binary"], version = "^3.2.10"}
welearn-database = "^1.4.5"
bs4 = "^0.0.2"
Comment thread
jmsevin marked this conversation as resolved.
urllib3 = "^2.6.3"
refinedoc = "^1.0.1"
Expand All @@ -50,7 +51,6 @@ langchain-mistralai = "^1.1.2"
langchain-azure-ai = "^1.2.3"
langgraph = "^1.1.10"
mistralai = "^2.4.3"
welearn-database = "^1.4.5"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
Expand Down
90 changes: 82 additions & 8 deletions src/app/api/api_v1/endpoints/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@
from langchain_core.messages import ToolMessage
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from openai import RateLimitError
from psycopg.rows import dict_row
from psycopg.rows import AsyncRowFactory, DictRow, dict_row
from pydantic import BaseModel

from src.app.api.api_v1.endpoints.chat_utils import (
_resolve_thread_id,
_sse_wrap,
_stream_agent_response,
)
from src.app.models import chat as models
from src.app.search.services.search import SearchService, get_search_service
from src.app.services.data_collection import get_data_collection_service
Expand Down Expand Up @@ -42,6 +47,16 @@
database=settings.PG_DATABASE,
)

# psycopg exposes dict_row with a BaseCursor annotation, while AsyncConnection.connect
# expects an async row factory type. Runtime is valid; cast keeps static typing happy.
ASYNC_DICT_ROW_FACTORY = cast(AsyncRowFactory[DictRow], dict_row)

SSE_HEADERS = {
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
}


def get_params(body: models.Context) -> models.ContextOut:
body.sources = body.sources[:7]
Expand All @@ -55,6 +70,7 @@ def get_params(body: models.Context) -> models.ContextOut:
history=body.history or [],
query=body.query,
subject=body.subject,
conversation_id=None,
Comment thread
jmsevin marked this conversation as resolved.
)


Expand Down Expand Up @@ -221,8 +237,9 @@ async def q_and_a_rephrase_stream(
)

return StreamingResponse(
content=content,
content=_sse_wrap(content),
media_type="text/event-stream",
headers=SSE_HEADERS,
)


Expand Down Expand Up @@ -316,8 +333,9 @@ async def q_and_a_stream(
)

return StreamingResponse(
content=content,
content=_sse_wrap(content),
media_type="text/event-stream",
headers=SSE_HEADERS,
)
except LanguageNotSupportedError as e:
bad_request(message=e.message, msg_code=e.msg_code)
Expand All @@ -338,8 +356,11 @@ async def get_chat_history(
chatfactory=Depends(get_chat_service),
) -> list[Dict[str, str | list[Dict[str, str]] | None]]:
if thread_id:
async with await psycopg.AsyncConnection.connect(
DB_URI, autocommit=True, prepare_threshold=0, row_factory=dict_row
async with await psycopg.AsyncConnection[DictRow].connect(
DB_URI,
autocommit=True,
prepare_threshold=0,
row_factory=ASYNC_DICT_ROW_FACTORY,
) as conn:
await conn.execute("SET SEARCH_PATH to agent_related")
await conn.commit()
Expand All @@ -351,6 +372,56 @@ async def get_chat_history(
return res


@router.post(
"/chat/agent_stream",
summary="Agent Response Stream",
description="This endpoint streams an agent response to the user's message and ends with the full response payload.",
response_class=StreamingResponse,
)
Comment thread
jmsevin marked this conversation as resolved.
@backoff.on_exception(
    backoff.expo,
    RateLimitError,
    factor=2,
    jitter=backoff.random_jitter,
    max_tries=5,
    max_time=180,
    logger=logger,
)
async def agent_stream_response(
    request: Request,
    background_tasks: BackgroundTasks,
    body: models.AgentContext = Depends(get_agent_params),
    chatfactory=Depends(get_chat_service),
    sp: SearchService = Depends(get_search_service),
    data_collection=Depends(get_data_collection_service),
) -> StreamingResponse:
    """Answer the user's message as a Server-Sent Events stream.

    The session cookie and the conversation thread id are resolved up front;
    the agent itself runs inside the async generator handed to the
    ``StreamingResponse``.

    NOTE(review): that generator is only iterated while the response is being
    sent, i.e. after this coroutine has returned, so errors raised during
    streaming reach neither the ``except`` clause below nor the backoff
    retry. Confirm this is intended.

    Raises:
        EmptyQueryError: If the request body carries no query.
    """
    try:
        session_id = extract_session_cookie(request)
        thread_id = _resolve_thread_id(body.thread_id)

        if body.query is None:
            raise EmptyQueryError()

        event_stream = _stream_agent_response(
            db_uri=DB_URI,
            async_dict_row_factory=ASYNC_DICT_ROW_FACTORY,
            body=body,
            chatfactory=chatfactory,
            sp=sp,
            background_tasks=background_tasks,
            data_collection=data_collection,
            session_id=session_id,
            thread_id=thread_id,
        )
        return StreamingResponse(
            content=event_stream,
            media_type="text/event-stream",
            headers=SSE_HEADERS,
        )
    except LanguageNotSupportedError as err:
        bad_request(message=err.message, msg_code=err.msg_code)
        raise


@router.post(
"/chat/agent",
summary="Agent Response",
Expand Down Expand Up @@ -388,8 +459,11 @@ async def agent_response(
raise EmptyQueryError()

if thread_id:
async with await psycopg.AsyncConnection.connect(
DB_URI, autocommit=True, prepare_threshold=0, row_factory=dict_row
async with await psycopg.AsyncConnection[DictRow].connect(
DB_URI,
autocommit=True,
prepare_threshold=0,
row_factory=ASYNC_DICT_ROW_FACTORY,
) as conn:
await conn.execute("SET SEARCH_PATH to agent_related")
await conn.commit()
Expand Down Expand Up @@ -425,7 +499,7 @@ async def agent_response(
}

try:
conversation_id, message_id = await data_collection.register_chat_data(
_, message_id = await data_collection.register_chat_data(
session_id=session_id,
user_query=body.query,
conversation_id=thread_id,
Expand Down
203 changes: 203 additions & 0 deletions src/app/api/api_v1/endpoints/chat_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
import json
import uuid
from typing import Any, AsyncGenerator, cast
from uuid import UUID

import psycopg
from fastapi import BackgroundTasks
from fastapi.encoders import jsonable_encoder
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg.rows import AsyncRowFactory, DictRow

from src.app.models import chat as models
from src.app.search.services.search import SearchService
from src.app.utils.logger import logger as utils_logger

logger = utils_logger(__name__)


def _format_sse_event(data: str) -> str:
    """Frame *data* as a single Server-Sent Events message.

    Each line of the payload becomes its own ``data:`` field and the message
    is closed by a blank line.

    Only the line terminators the SSE wire format recognises (CRLF, CR and
    LF) split the payload. ``str.splitlines`` is deliberately avoided: it
    also breaks on characters such as U+2028 or form feed, and it discards a
    trailing newline, so a streamed text chunk like ``"\\n"`` or
    ``"word\\n"`` would reach the client with its line break removed.

    Args:
        data: Payload to send; may span several lines.

    Returns:
        The framed message. An empty payload yields a lone ``"\\n"``, which
        carries no ``data:`` field.
    """
    if not data:
        # Nothing to send: keep emitting a bare blank line rather than an
        # event with an empty data field.
        return "\n"

    # Normalise CRLF first so it is not counted as two separate breaks.
    lines = data.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    return "".join(f"data: {line}\n" for line in lines) + "\n"


async def _sse_wrap(stream: Any) -> AsyncGenerator[str, None]:
    """Re-emit every item of *stream* as a Server-Sent Events message.

    Text is forwarded as-is, ``bytes`` are decoded as UTF-8 (undecodable
    bytes become the replacement character), and any other object is
    JSON-serialised after going through ``jsonable_encoder``.

    Args:
        stream: Async iterable producing the chunks to forward.

    Yields:
        One framed SSE message per incoming chunk.
    """
    async for item in stream:
        if isinstance(item, str):
            yield _format_sse_event(item)
            continue
        if isinstance(item, bytes):
            yield _format_sse_event(item.decode("utf-8", errors="replace"))
            continue
        yield _format_sse_event(json.dumps(jsonable_encoder(item)))


def _resolve_thread_id(thread_id: UUID | None) -> UUID:
if thread_id:
return thread_id

logger.info("No thread_id provided. Generating new thread_id.")
return uuid.uuid4()


def _update_agent_stream_state(
    chunk: dict[str, Any],
    current_final_content: str,
    current_docs: Any,
) -> tuple[str, Any]:
    """Fold one agent chunk into the running answer text and document list.

    The chunk's ``status`` selects what is updated:

    * ``"processing"``: a non-empty ``docs`` entry replaces the documents.
    * ``"streaming"``: ``content`` is appended to the accumulated answer.
    * ``"stop"``: a non-empty ``content`` replaces the accumulated answer.

    Any other status leaves both values untouched. A ``content`` that is
    missing or ``None`` counts as an empty string, so a streaming chunk with
    an explicit ``None`` does not raise ``TypeError``.

    Args:
        chunk: Event emitted by the agent stream.
        current_final_content: Answer text accumulated so far.
        current_docs: Documents collected so far.

    Returns:
        The updated ``(final_content, docs)`` pair.
    """
    status = chunk.get("status")

    if status == "processing":
        return current_final_content, chunk.get("docs") or current_docs

    # ``dict.get`` only falls back to its default for a *missing* key, so an
    # explicit ``None`` has to be normalised separately.
    content = cast(str, chunk.get("content") or "")

    if status == "streaming":
        return current_final_content + content, current_docs
    if status == "stop" and content:
        return content, current_docs
    return current_final_content, current_docs


def _serialize_agent_stream_chunk(chunk: dict[str, Any]) -> str:
    """Serialise the client-facing fields of an agent chunk to JSON.

    Only ``content``, ``status``, ``step``, ``label`` and ``docs`` are
    copied; a field missing from the chunk is sent as ``null``.

    Args:
        chunk: Event emitted by the agent stream.

    Returns:
        The JSON text of the trimmed payload.
    """
    exposed_keys = ("content", "status", "step", "label", "docs")
    payload = {key: chunk.get(key) for key in exposed_keys}
    return json.dumps(jsonable_encoder(payload))


async def _stream_agent_with_memory(
    *,
    db_uri: str,
    async_dict_row_factory: AsyncRowFactory[DictRow],
    chatfactory: Any,
    body: models.AgentContext,
    sp: SearchService,
    background_tasks: BackgroundTasks,
    thread_id: UUID,
) -> AsyncGenerator[dict[str, Any], None]:
    """Stream the agent's chunks using a Postgres-backed checkpointer.

    A dedicated async connection is opened, its search path is set to
    ``agent_related``, and an ``AsyncPostgresSaver`` bound to it is passed to
    ``chatfactory.agent_message`` as ``memory``. The chunks are yielded from
    inside the connection's ``async with`` block, so the connection is held
    for as long as the generator is being consumed.

    Args:
        db_uri: Connection string for the checkpoint database.
        async_dict_row_factory: Row factory handed to psycopg.
        chatfactory: Service exposing ``agent_message``.
        body: Request payload (query, corpora, SDG filter).
        sp: Search service forwarded to the agent.
        background_tasks: FastAPI background task queue forwarded to the agent.
        thread_id: Conversation thread identifier.

    Yields:
        Each chunk produced by the agent, unchanged.
    """
    pending_connection = await psycopg.AsyncConnection[DictRow].connect(
        db_uri,
        autocommit=True,
        prepare_threshold=0,
        row_factory=async_dict_row_factory,
    )
    async with pending_connection as connection:
        await connection.execute("SET SEARCH_PATH to agent_related")
        await connection.commit()

        checkpointer = AsyncPostgresSaver(connection)
        agent_chunks = await chatfactory.agent_message(
            query=body.query,
            memory=checkpointer,
            thread_id=thread_id,
            corpora=body.corpora,
            sdg_filter=body.sdg_filter,
            sp=sp,
            background_tasks=background_tasks,
            streamed_ans=True,
        )

        async for agent_chunk in agent_chunks:
            yield agent_chunk


def _build_final_stream_payload(
    *,
    final_content: str,
    docs: Any,
    thread_id: UUID,
) -> dict[str, Any]:
    """Assemble the closing event of an agent stream.

    Args:
        final_content: Answer text to place in the event.
        docs: Documents gathered during the run.
        thread_id: Conversation thread identifier.

    Returns:
        A dict with ``content``, ``status`` (always ``"stop"``), ``docs`` and
        ``thread_id``, in that order.
    """
    return dict(
        content=final_content,
        status="stop",
        docs=docs,
        thread_id=thread_id,
    )


async def _register_stream_chat_data(
*,
data_collection: Any,
session_id: UUID | None,
user_query: str,
conversation_id: UUID,
answer_content: str,
sources: Any,
) -> Any:
_, message_id = await data_collection.register_chat_data(
session_id=session_id,
user_query=user_query,
conversation_id=conversation_id,
answer_content=answer_content,
sources=sources,
)
return message_id


async def _stream_agent_response(
    *,
    db_uri: str,
    async_dict_row_factory: AsyncRowFactory[DictRow],
    body: models.AgentContext,
    chatfactory: Any,
    sp: SearchService,
    background_tasks: BackgroundTasks,
    data_collection: Any,
    session_id: UUID | None,
    thread_id: UUID,
) -> AsyncGenerator[str, None]:
    """Drive one agent turn and emit it as Server-Sent Events.

    Every chunk except the agent's own ``"stop"`` chunk is forwarded as soon
    as it arrives. Once the agent is done, the exchange is recorded through
    ``data_collection`` and a single closing ``"stop"`` event is sent. That
    event carries the documents, the thread id and, when recording
    succeeded, the message id. Its ``content`` is left empty if answer text
    was already streamed, so the client does not receive the answer twice.

    A failure while emitting a chunk or while recording the exchange is
    logged and does not interrupt the stream.

    Args:
        db_uri: Connection string for the checkpoint database.
        async_dict_row_factory: Row factory handed to psycopg.
        body: Request payload.
        chatfactory: Service exposing ``agent_message``.
        sp: Search service forwarded to the agent.
        background_tasks: FastAPI background task queue.
        data_collection: Service used to record the exchange.
        session_id: Session identifier, if one is known.
        thread_id: Conversation thread identifier.

    Yields:
        Framed SSE messages.
    """
    answer_text = ""
    sources = []
    streamed_any_text = False

    agent_chunks = _stream_agent_with_memory(
        db_uri=db_uri,
        async_dict_row_factory=async_dict_row_factory,
        chatfactory=chatfactory,
        body=body,
        sp=sp,
        background_tasks=background_tasks,
        thread_id=thread_id,
    )

    async for agent_chunk in agent_chunks:
        answer_text, sources = _update_agent_stream_state(
            agent_chunk, answer_text, sources
        )
        chunk_status = agent_chunk.get("status")
        if chunk_status == "stop":
            # The agent's own stop chunk is replaced by the closing event
            # built after the loop.
            continue
        if chunk_status == "streaming" and agent_chunk.get("content"):
            streamed_any_text = True
        try:
            yield _format_sse_event(_serialize_agent_stream_chunk(agent_chunk))
        except Exception as e:
            logger.error("Error while yielding chunk: %s", e)

    closing_event = _build_final_stream_payload(
        # Text already delivered incrementally is not repeated.
        final_content="" if streamed_any_text else answer_text,
        docs=sources,
        thread_id=thread_id,
    )

    try:
        closing_event["message_id"] = await _register_stream_chat_data(
            data_collection=data_collection,
            session_id=session_id,
            user_query=cast(str, body.query),
            conversation_id=thread_id,
            answer_content=answer_text,
            sources=sources,
        )
    except Exception as e:
        logger.error("Error while registering chat data: %s", e)

    yield _format_sse_event(json.dumps(jsonable_encoder(closing_event)))
2 changes: 2 additions & 0 deletions src/app/models/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class AgentContext(SDGFilter):

class AgentResponse(BaseModel):
content: str | None = None
status: str | None = None
step: str | None = None
docs: list[ScoredPoint] | None = None
thread_id: uuid.UUID | None = None

Expand Down
Loading
Loading