diff --git a/backend/app/api/docs/llm/speech_to_speech.md b/backend/app/api/docs/llm/speech_to_speech.md new file mode 100644 index 000000000..e4ad03e6f --- /dev/null +++ b/backend/app/api/docs/llm/speech_to_speech.md @@ -0,0 +1,228 @@ +# Speech-to-Speech (STS) with RAG + +Execute a complete speech-to-speech workflow with knowledge base retrieval. + +## Endpoint + +``` +POST /llm/chain/sts +``` + +## Flow + +``` +Voice Input → STT (auto language) → RAG (Knowledge Base) → TTS → Voice Output +``` + +## Input + +- **Voice note**: WhatsApp-compatible audio format (required) +- **Knowledge base IDs**: One or more knowledge bases for RAG (required) +- **Languages**: Input and output languages (optional; input defaults to auto-detection, output defaults to the input language) +- **Models**: STT, LLM, and TTS model selection (optional; defaults to Gemini for STT/TTS and GPT-4o for RAG) + +## Output + +You will receive **3 callbacks** to your webhook URL: + +1. **STT Callback** (Intermediate): Transcribed text from audio +2. **LLM Callback** (Intermediate): RAG-enhanced response text +3. **TTS Callback** (Final): Audio output + response text + +Each callback includes: +- Output from that step +- Token usage +- Latency information (check timestamps) + +## Supported Languages + +### Primary Indian Languages +- English, Hindi, Hinglish (code-switching) +- Bengali, Kannada, Malayalam, Marathi +- Odia, Punjabi, Tamil, Telugu, Gujarati + +### Additional Languages (Sarvam Saaras V3) +- Assamese, Urdu, Nepali +- Konkani, Kashmiri, Sindhi +- Sanskrit, Santali, Manipuri +- Bodo, Maithili, Dogri + +**Total: 25 languages** with automatic language detection + +## Available Models + +### STT (Speech-to-Text) +- `saaras:v3` - Sarvam Saaras V3 (fast, auto language detection, optimized for Indian languages) +- `gemini-2.5-pro` - Google Gemini 2.5 Pro (**default**) + +**Note:** Sarvam STT uses automatic language detection. No need to specify input language. 
+ +### LLM (RAG) +- `gpt-4o` - OpenAI GPT-4o (**default**, best quality) +- `gpt-4o-mini` - OpenAI GPT-4o Mini (faster, lower cost) + +### TTS (Text-to-Speech) +- `bulbul:v3` - Sarvam Bulbul V3 (natural Indian voices, MP3 output) +- `gemini-2.5-pro-preview-tts` - Google Gemini 2.5 Pro (OGG OPUS output); when no TTS model is specified the endpoint uses `gemini-2.5-flash-preview-tts` with OGG output + +## Edge Cases & Error Handling + +### Empty STT Output +If speech-to-text returns empty/blank: +- Chain fails immediately +- Error message: "STT returned no transcription" +- No subsequent blocks are executed + +### Audio Size Limit +WhatsApp limit: 16MB +- TTS providers may fail if output exceeds limit +- Error is caught and reported in callback +- Consider using shorter responses or compression + +### Invalid Audio Format +If input audio format is unsupported: +- STT provider fails with format error +- Error reported in callback +- Supported: MP3, WAV, OGG, OPUS, M4A + +### Provider Failures +Each block has independent error handling: +- STT fails → Chain stops, STT error reported +- LLM fails → Chain stops, RAG error reported +- TTS fails → Chain stops, TTS error reported + +## Example Request + +```bash +curl -X POST https://api.kaapi.ai/llm/chain/sts \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d @- < tuple[str, str]: + """Pick effective input/output language codes. + + If input is "auto" and output isn't pinned, output also becomes "auto" + so the TTS mapper falls back to provider auto-detection. 
+ """ + input_lang = request.input_language or "auto" + if request.output_language: + return input_lang, request.output_language + return input_lang, ("auto" if input_lang == "auto" else input_lang) + + +def _merge_stt(user: STTLLMParams | None, input_lang: str) -> STTLLMParams: + base = STTLLMParams(model=DEFAULT_STT_MODEL, input_language=input_lang) + if user is None: + return base + overrides = user.model_dump(exclude_unset=True) + overrides["input_language"] = input_lang # route owns this + return base.model_copy(update=overrides) + + +def _merge_rag( + user: TextLLMParams | None, knowledge_base_ids: list[str] +) -> TextLLMParams: + base = TextLLMParams( + model=DEFAULT_RAG_MODEL, + temperature=DEFAULT_RAG_TEMPERATURE, + instructions=DEFAULT_RAG_INSTRUCTIONS, + ) + merged = ( + base + if user is None + else base.model_copy(update=user.model_dump(exclude_unset=True)) + ) + return merged.model_copy(update={"knowledge_base_ids": knowledge_base_ids}) + + +def _merge_tts(user: TTSLLMParams | None, output_lang: str) -> TTSLLMParams: + base = TTSLLMParams( + model=DEFAULT_TTS_MODEL, + voice=DEFAULT_TTS_VOICE, + language=output_lang, + response_format=DEFAULT_TTS_FORMAT, + ) + if user is None: + return base + overrides = user.model_dump(exclude_unset=True) + overrides["language"] = output_lang # route owns this + return base.model_copy(update=overrides) + + +def _inline_call_config( + type_: BlockType, + params: STTLLMParams | TextLLMParams | TTSLLMParams, + provider: str | None, +) -> LLMCallConfig: + return LLMCallConfig( + blob=ConfigBlob( + completion=KaapiCompletionConfig( + provider=provider, + type=type_, + params=params.model_dump(exclude_none=True), + ) + ) + ) + + +def _stored_call_config(config_id: UUID, config_version: int) -> LLMCallConfig: + return LLMCallConfig(id=config_id, version=config_version) + + +# ---------- Per-block resolution ---------- + + +def _resolve_stt_block( + spec: STTBlockSpec | None, input_lang: str, provider: str | None +) -> 
ChainBlock: + if spec and spec.is_stored_ref: + config = _stored_call_config(spec.config_id, spec.config_version) + else: + merged = _merge_stt(spec.params if spec else None, input_lang) + config = _inline_call_config("stt", merged, provider) + return ChainBlock(config=config, intermediate_callback=True) + + +def _resolve_rag_block( + spec: RAGBlockSpec | None, + knowledge_base_ids: list[str], + provider: str | None, +) -> ChainBlock: + if spec and spec.is_stored_ref: + config = _stored_call_config(spec.config_id, spec.config_version) + else: + merged = _merge_rag(spec.params if spec else None, knowledge_base_ids) + config = _inline_call_config("text", merged, provider or "openai") + return ChainBlock(config=config, intermediate_callback=True) + + +def _resolve_tts_block( + spec: TTSBlockSpec | None, output_lang: str, provider: str | None +) -> ChainBlock: + if spec and spec.is_stored_ref: + config = _stored_call_config(spec.config_id, spec.config_version) + else: + merged = _merge_tts(spec.params if spec else None, output_lang) + config = _inline_call_config("tts", merged, provider) + return ChainBlock(config=config, intermediate_callback=False) + + +# ---------- Metadata ---------- + + +def _model_for_metadata( + spec: STTBlockSpec | RAGBlockSpec | TTSBlockSpec | None, default: str +) -> str: + """Best-effort model label for logs/metadata.""" + if spec is None: + return default + if spec.is_stored_ref: + return f"stored:{spec.config_id}@v{spec.config_version}" + if spec.params and getattr(spec.params, "model", None): + return spec.params.model + return default + + +def _build_metadata( + request: SpeechToSpeechRequest, input_lang: str, output_lang: str +) -> dict[str, Any]: + metadata = dict(request.request_metadata or {}) + metadata.update( + { + "speech_to_speech": True, + "input_language": input_lang, + "output_language": output_lang, + "stt_model": _model_for_metadata(request.stt, DEFAULT_STT_MODEL), + "llm_model": _model_for_metadata(request.rag, 
DEFAULT_RAG_MODEL), + "tts_model": _model_for_metadata(request.tts, DEFAULT_TTS_MODEL), + } + ) + return metadata + + +# ---------- Endpoint ---------- + + +@router.post( + "/llm/chain/sts", + description=load_description("llm/speech_to_speech.md"), + response_model=APIResponse[Message], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def speech_to_speech( + _current_user: AuthContextDep, + session: SessionDep, + request: SpeechToSpeechRequest, +) -> APIResponse[Message]: + """Run the STT → RAG → TTS chain for a single voice input.""" + project_id = _current_user.project_.id + organization_id = _current_user.organization_.id + + if request.callback_url: + validate_callback_url(str(request.callback_url)) + + if ( + request.input_language + and request.input_language not in SUPPORTED_LANGUAGE_CODES + ): + raise HTTPException( + status_code=422, + detail=f"Unsupported input language code: {request.input_language}. Supported: {', '.join(SUPPORTED_LANGUAGE_CODES)}", + ) + + if request.output_language and ( + request.output_language not in SUPPORTED_LANGUAGE_CODES + or request.output_language in ("auto", "unknown") + ): + tts_supported = SUPPORTED_LANGUAGE_CODES - {"auto", "unknown"} + raise HTTPException( + status_code=422, + detail=f"Unsupported output language code: {request.output_language}. 
Supported: {', '.join(tts_supported)}", + ) + + input_lang, output_lang = _resolve_languages(request) + + blocks = [ + _resolve_stt_block(request.stt, input_lang, request.stt_provider), + _resolve_rag_block( + request.rag, request.knowledge_base_ids, request.rag_provider + ), + _resolve_tts_block(request.tts, output_lang, request.tts_provider), + ] + + logger.info( + f"[speech_to_speech] Starting STS chain | " + f"project_id={project_id}, " + f"input_lang={input_lang}, output_lang={output_lang}" + ) + + chain_request = LLMChainRequest( + query=QueryParams(input=request.query), + blocks=blocks, + callback_url=request.callback_url, + request_metadata=_build_metadata(request, input_lang, output_lang), + ) + + start_chain_job( + db=session, + request=chain_request, + project_id=project_id, + organization_id=organization_id, + ) + + return APIResponse.success_response( + data=Message( + message=( + "Speech-to-speech processing initiated. " + "You will receive intermediate callbacks for STT and LLM outputs, " + "followed by the final callback with audio and text." 
+ ) + ) + ) diff --git a/backend/app/models/llm/constants.py b/backend/app/models/llm/constants.py index 1838da79d..54ac8e543 100644 --- a/backend/app/models/llm/constants.py +++ b/backend/app/models/llm/constants.py @@ -1,6 +1,7 @@ DEFAULT_STT_MODEL = "gemini-2.5-pro" DEFAULT_TTS_MODEL = "gemini-2.5-flash-preview-tts" DEFAULT_TTS_VOICE = "Kore" +DEFAULT_RAG_MODEL = "gpt-4o" # BCP-47 to language tag -> Gemini ISO 639-1 code (Indic + English) BCP47_LOCALE_TO_GEMINI_LANG: dict[str, str] = { diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index da0c18120..ec8df7415 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -97,6 +97,9 @@ class TTSLLMParams(SQLModel): class TextContent(SQLModel): format: Literal["text"] = "text" value: str = Field(..., description="Text content") + language_code: str | None = Field( + None, description="Optional detected language code in STT 'auto' mode" + ) class AudioContent(SQLModel): @@ -804,3 +807,153 @@ class LlmChain(SQLModel, table=True): "comment": "Timestamp when the chain record was last updated" }, ) + + +class _BlockSpecBase(SQLModel): + """Common xor logic for per-block specs: provide *either* (config_id + + config_version) for a stored config, *or* inline `params`. Not both. + """ + + config_id: UUID | None = Field( + default=None, + description="ID of a stored LLM config to use for this block.", + ) + config_version: int | None = Field( + default=None, + ge=1, + description="Version of the stored config (required when config_id is set).", + ) + + @model_validator(mode="after") + def _validate_xor(self): + has_ref = self.config_id is not None or self.config_version is not None + has_params = getattr(self, "params", None) is not None + + if has_ref and has_params: + raise ValueError( + "Provide either (config_id + config_version) OR inline 'params', not both." 
+ ) + if has_ref and (self.config_id is None or self.config_version is None): + raise ValueError( + "Both 'config_id' and 'config_version' must be set together." + ) + return self + + @property + def is_stored_ref(self) -> bool: + return self.config_id is not None and self.config_version is not None + + +class STTBlockSpec(_BlockSpecBase): + params: STTLLMParams | None = Field( + default=None, + description="Inline STT parameters. Omit to use endpoint defaults.", + ) + + +class RAGBlockSpec(_BlockSpecBase): + params: TextLLMParams | None = Field( + default=None, + description="Inline RAG (text) parameters. Omit to use endpoint defaults.", + ) + + +class TTSBlockSpec(_BlockSpecBase): + params: TTSLLMParams | None = Field( + default=None, + description="Inline TTS parameters. Omit to use endpoint defaults.", + ) + + +class SpeechToSpeechRequest(SQLModel): + """ + API request for speech-to-speech (STS) with RAG. + + Convenience endpoint that orchestrates a 3-block chain: + STT → RAG → TTS + + Input: Audio + Output: Audio + Text (via callback) + """ + + query: AudioInput = Field( + ..., description="Voice note input (WhatsApp compatible format)" + ) + knowledge_base_ids: list[str] = Field( + ..., min_length=1, description="Knowledge base IDs for RAG retrieval" + ) + + # Optional language config (BCP-47 codes) + input_language: str | None = Field( + "auto", + description=( + "BCP-47 language code for STT input (auto-detect by default). " + "Supported codes: 'auto', 'en-IN', 'hi-IN', 'bn-IN', 'kn-IN', 'ml-IN', 'mr-IN', 'od-IN', " + "'pa-IN', 'ta-IN', 'te-IN', 'gu-IN', 'as-IN', 'ur-IN', 'ne-IN', 'kok-IN', 'ks-IN', " + "'sd-IN', 'sa-IN', 'sat-IN', 'mni-IN', 'brx-IN', 'mai-IN', 'doi-IN'" + ), + ) + output_language: str | None = Field( + None, + description=( + "BCP-47 language code for TTS output (defaults to input_language if not specified). " + "Supported codes: same as input_language (except 'auto')." + ), + ) + + # Per-block specs. 
Each spec accepts EITHER (config_id + config_version) + # to reference a stored config, OR inline `params` to override the + # endpoint defaults. Omit entirely to use defaults only. + stt: STTBlockSpec | None = Field( + None, + description=( + "STT block spec. Use 'params' for inline overrides or " + "'config_id' + 'config_version' to reference a stored config." + ), + ) + rag: RAGBlockSpec | None = Field( + None, + description=( + "RAG block spec. Use 'params' for inline overrides or " + "'config_id' + 'config_version' to reference a stored config." + ), + ) + tts: TTSBlockSpec | None = Field( + None, + description=( + "TTS block spec. Use 'params' for inline overrides or " + "'config_id' + 'config_version' to reference a stored config." + ), + ) + + # Provider hints. Optional — KaapiCompletionConfig auto-defaults to + # "google" for stt/tts when omitted. + stt_provider: Literal["google", "sarvamai", "elevenlabs"] | None = None + tts_provider: Literal["google", "sarvamai", "elevenlabs"] | None = None + rag_provider: Literal["openai"] | None = None + + # Callback and metadata + callback_url: HttpUrl | None = Field( + None, description="Webhook URL for async response delivery" + ) + request_metadata: dict[str, Any] | None = Field( + None, description="Client-provided metadata" + ) + + @model_validator(mode="after") + def validate_languages(self): + """Normalize BCP-47 language codes to standard format (e.g., 'hi-in' -> 'hi-IN').""" + # Normalize input_language + if self.input_language and self.input_language != "auto": + # Normalize BCP-47: lowercase language, uppercase region (e.g., "hi-IN") + parts = self.input_language.split("-") + if len(parts) == 2: + self.input_language = f"{parts[0].lower()}-{parts[1].upper()}" + + # Normalize output_language + if self.output_language: + parts = self.output_language.split("-") + if len(parts) == 2: + self.output_language = f"{parts[0].lower()}-{parts[1].upper()}" + + return self diff --git 
a/backend/app/services/llm/chain/chain.py b/backend/app/services/llm/chain/chain.py index 2cd80708b..ec8f8fb19 100644 --- a/backend/app/services/llm/chain/chain.py +++ b/backend/app/services/llm/chain/chain.py @@ -36,6 +36,9 @@ class ChainContext: langfuse_credentials: dict[str, Any] | None = None request_metadata: dict | None = None intermediate_callback_flags: list[bool] = field(default_factory=list) + detected_language: str | None = ( + None # Stores language detected by STT for use by TTS + ) aggregated_usage: Usage = field( default_factory=lambda: Usage( input_tokens=0, @@ -45,17 +48,37 @@ class ChainContext: ) -def result_to_query(result: BlockResult) -> QueryParams: +def result_to_query( + result: BlockResult, context: ChainContext | None = None +) -> QueryParams: """Convert a block's output into the next block's QueryParams. Text output → TextInput query Audio output → AudioInput query + + Also preserves language_code from STT output for use by downstream TTS blocks. """ output = result.response.response.output if isinstance(output, TextOutput): + # Preserve language_code if present (from STT auto-detection) + language_code = ( + output.content.language_code + if hasattr(output.content, "language_code") + else None + ) + + # Store detected language in context for TTS to use + if context and language_code: + context.detected_language = language_code + logger.info(f"[result_to_query] Detected language: {language_code}") + return QueryParams( - input=TextInput(content=TextContent(value=output.content.value)) + input=TextInput( + content=TextContent( + value=output.content.value, language_code=language_code + ) + ) ) elif isinstance(output, AudioOutput): return QueryParams(input=AudioInput(content=output.content)) @@ -96,6 +119,7 @@ def execute(self, query: QueryParams) -> BlockResult: langfuse_credentials=self._context.langfuse_credentials, include_provider_raw_response=self._include_provider_raw_response, chain_id=self._context.chain_id, + 
detected_language=self._context.detected_language, ) @@ -132,6 +156,6 @@ def execute( return result if block is not self._blocks[-1]: - current_query = result_to_query(result) + current_query = result_to_query(result, self._context) return result diff --git a/backend/app/services/llm/chain/utils.py b/backend/app/services/llm/chain/utils.py new file mode 100644 index 000000000..2955d58c2 --- /dev/null +++ b/backend/app/services/llm/chain/utils.py @@ -0,0 +1,37 @@ +"""Utility functions for LLM chain operations, including speech-to-speech helpers.""" + +# BCP-47 language codes accepted by the speech-to-speech endpoint. +SUPPORTED_LANGUAGE_CODES = { + "auto", + "unknown", + # Primary Indian languages + "en-IN", + "hi-IN", + "bn-IN", + "kn-IN", + "ml-IN", + "mr-IN", + "od-IN", + "pa-IN", + "ta-IN", + "te-IN", + "gu-IN", + # Additional languages + "as-IN", + "ur-IN", + "ne-IN", + "kok-IN", + "ks-IN", + "sd-IN", + "sa-IN", + "sat-IN", + "mni-IN", + "brx-IN", + "mai-IN", + "doi-IN", +} + +DEFAULT_RAG_INSTRUCTIONS = ( + "Answer the user's question using the provided knowledge base. " + "Be concise and accurate." +) diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 550d7ff41..66913aa40 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -45,6 +45,7 @@ ImageInput, KaapiCompletionConfig, LLMCallConfig, + NativeCompletionConfig, PDFInput, QueryParams, TextContent, @@ -504,10 +505,14 @@ def execute_llm_call( langfuse_credentials: dict | None, include_provider_raw_response: bool = False, chain_id: UUID | None = None, + detected_language: str | None = None, ) -> BlockResult: """Execute a single LLM call. Shared by /llm/call and /llm/chain. Returns BlockResult with response + usage on success, or error on failure. 
+ + Args: + detected_language: Language code detected by STT (used to replace {{detected}} marker in TTS) """ config_blob: ConfigBlob | None = None @@ -627,6 +632,26 @@ def execute_llm_call( request_metadata = {} request_metadata.setdefault("warnings", []).extend(warnings) + # Replace {{detected}} marker in TTS configs with actual detected language + if ( + isinstance(completion_config, NativeCompletionConfig) + and completion_config.type == "tts" + ): + params = completion_config.params + # Replace {{detected}} marker in any language-related params + for key in ["target_language_code", "language_code"]: + if key in params and params[key] == "{{detected}}": + if detected_language: + params[key] = detected_language + logger.info( + f"[execute_llm_call] Using detected language for TTS: {detected_language} | job_id={job_id}" + ) + else: + # Fallback to English if no language was detected + params[key] = "en-IN" + logger.warning( + f"[execute_llm_call] No language detected, falling back to en-IN for TTS | job_id={job_id}" + ) model_name = str(completion_config.params.get("model") or "") resolved_config_blob = ConfigBlob( diff --git a/backend/app/services/llm/providers/sai.py b/backend/app/services/llm/providers/sai.py index f4a6cc5e7..5866bc384 100644 --- a/backend/app/services/llm/providers/sai.py +++ b/backend/app/services/llm/providers/sai.py @@ -116,7 +116,10 @@ def _execute_stt( provider=provider_name, model=model, output=TextOutput( - content=TextContent(value=sarvam_response.transcript) + content=TextContent( + value=sarvam_response.transcript, + language_code=sarvam_response.language_code, + ) ), ), usage=Usage( diff --git a/backend/app/tests/services/llm/providers/test_sai.py b/backend/app/tests/services/llm/providers/test_sai.py index 78fdd30a8..c7550c123 100644 --- a/backend/app/tests/services/llm/providers/test_sai.py +++ b/backend/app/tests/services/llm/providers/test_sai.py @@ -18,14 +18,17 @@ def mock_sarvam_stt_response( transcript: str = "नमस्ते", 
request_id: str = "req_stt_123", + language_code: str | None = "hi-IN", ) -> SimpleNamespace: """Create a mock SarvamAI STT response object.""" response = SimpleNamespace( transcript=transcript, request_id=request_id, + language_code=language_code, model_dump=lambda: { "transcript": transcript, "request_id": request_id, + "language_code": language_code, }, ) return response diff --git a/backend/app/tests/services/llm/test_sts.py b/backend/app/tests/services/llm/test_sts.py new file mode 100644 index 000000000..f1141b3ef --- /dev/null +++ b/backend/app/tests/services/llm/test_sts.py @@ -0,0 +1,880 @@ +"""Tests for the /llm/chain/sts endpoint. + +Covers: +- Default block construction (models, voices, formats) +- Language resolution (auto, pinned, cross-language, BCP-47 normalisation) +- Provider combos (google/sarvamai/elevenlabs for STT/TTS, openai for RAG) +- Inline param overrides on each block +- Stored config references on each block (and mixed) +- Intermediate callback flags +- Metadata construction (defaults, inline overrides, stored-ref labels) +- Error paths: bad language codes, "unknown"/"auto" as output, XOR violations, + missing required fields +""" + +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from fastapi.testclient import TestClient +from pydantic import ValidationError + +from app.models.llm.constants import ( + DEFAULT_RAG_MODEL, + DEFAULT_STT_MODEL, + DEFAULT_TTS_MODEL, + DEFAULT_TTS_VOICE, +) +from app.models.llm.request import ( + AudioContent, + AudioInput, + RAGBlockSpec, + SpeechToSpeechRequest, + STTBlockSpec, + STTLLMParams, + TextLLMParams, + TTSBlockSpec, + TTSLLMParams, +) + +URL = "/api/v1/llm/chain/sts" + + +# ---------- Fixtures ---------- + + +@pytest.fixture +def audio_input() -> AudioInput: + return AudioInput( + content=AudioContent( + format="base64", + value="SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4Lg==", + mime_type="audio/ogg", + ) + ) + + +@pytest.fixture +def audio_url_input() -> AudioInput: + return 
AudioInput( + content=AudioContent( + format="url", + value="https://example.com/audio.ogg", + mime_type="audio/ogg", + ) + ) + + +@pytest.fixture +def kb_ids() -> list[str]: + return ["kb-faq", "kb-product"] + + +def _post(client: TestClient, headers: dict, payload: SpeechToSpeechRequest): + return client.post(URL, json=payload.model_dump(mode="json"), headers=headers) + + +def _chain_request(mock_start): + return mock_start.call_args.kwargs["request"] + + +# ---------- Defaults ---------- + + +class TestDefaults: + def test_three_blocks_always_created( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + response = _post( + client, + user_api_key_header, + SpeechToSpeechRequest(query=audio_input, knowledge_base_ids=kb_ids), + ) + assert response.status_code == 200 + assert len(_chain_request(mock).blocks) == 3 + + def test_default_stt_model(self, client, user_api_key_header, audio_input, kb_ids): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest(query=audio_input, knowledge_base_ids=kb_ids), + ) + stt_params = _chain_request(mock).blocks[0].config.blob.completion.params + assert stt_params["model"] == DEFAULT_STT_MODEL + + def test_default_rag_model_and_temperature( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest(query=audio_input, knowledge_base_ids=kb_ids), + ) + rag_params = _chain_request(mock).blocks[1].config.blob.completion.params + assert rag_params["model"] == DEFAULT_RAG_MODEL + assert rag_params["temperature"] == 0.1 + + def test_default_tts_model_voice_format( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + 
SpeechToSpeechRequest(query=audio_input, knowledge_base_ids=kb_ids), + ) + tts_params = _chain_request(mock).blocks[2].config.blob.completion.params + assert tts_params["model"] == DEFAULT_TTS_MODEL + assert tts_params["voice"] == DEFAULT_TTS_VOICE + assert tts_params["response_format"] == "ogg" + + def test_rag_block_always_has_knowledge_base_ids( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest(query=audio_input, knowledge_base_ids=kb_ids), + ) + rag_params = _chain_request(mock).blocks[1].config.blob.completion.params + assert rag_params["knowledge_base_ids"] == kb_ids + + def test_stt_and_rag_are_intermediate_tts_is_not( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest(query=audio_input, knowledge_base_ids=kb_ids), + ) + blocks = _chain_request(mock).blocks + assert blocks[0].intermediate_callback is True + assert blocks[1].intermediate_callback is True + assert blocks[2].intermediate_callback is False + + def test_default_stt_input_language_is_auto( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest(query=audio_input, knowledge_base_ids=kb_ids), + ) + stt_params = _chain_request(mock).blocks[0].config.blob.completion.params + assert stt_params["input_language"] == "auto" + + +# ---------- Language resolution ---------- + + +class TestLanguageResolution: + def test_pinned_input_propagates_to_tts_when_output_not_set( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + 
knowledge_base_ids=kb_ids, + input_language="hi-IN", + ), + ) + tts_params = _chain_request(mock).blocks[2].config.blob.completion.params + assert tts_params["language"] == "hi-IN" + + def test_explicit_output_language_overrides_input( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + input_language="hi-IN", + output_language="ta-IN", + ), + ) + tts_params = _chain_request(mock).blocks[2].config.blob.completion.params + assert tts_params["language"] == "ta-IN" + + def test_auto_input_with_pinned_output( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + output_language="kn-IN", + ), + ) + stt_params = _chain_request(mock).blocks[0].config.blob.completion.params + tts_params = _chain_request(mock).blocks[2].config.blob.completion.params + assert stt_params["input_language"] == "auto" + assert tts_params["language"] == "kn-IN" + + def test_bcp47_normalisation_lowercase_region( + self, client, user_api_key_header, audio_input, kb_ids + ): + """'hi-in' should be normalised to 'hi-IN' before validation.""" + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + response = _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + input_language="hi-in", + ), + ) + assert response.status_code == 200 + tts_params = _chain_request(mock).blocks[2].config.blob.completion.params + assert tts_params["language"] == "hi-IN" + + def test_bcp47_normalisation_uppercase_language( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + response = _post( + client, + 
user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + input_language="HI-IN", + ), + ) + assert response.status_code == 200 + tts_params = _chain_request(mock).blocks[2].config.blob.completion.params + assert tts_params["language"] == "hi-IN" + + def test_route_always_owns_stt_input_language( + self, client, user_api_key_header, audio_input, kb_ids + ): + """User passing input_language via STTBlockSpec params should be overridden by the route.""" + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + input_language="bn-IN", + stt=STTBlockSpec( + params=STTLLMParams(model="saaras:v3", input_language="hi-IN") + ), + ), + ) + stt_params = _chain_request(mock).blocks[0].config.blob.completion.params + assert stt_params["input_language"] == "bn-IN" + + def test_route_always_owns_tts_language( + self, client, user_api_key_header, audio_input, kb_ids + ): + """User passing language via TTSBlockSpec params should be overridden by the route.""" + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + input_language="te-IN", + tts=TTSBlockSpec( + params=TTSLLMParams(model=DEFAULT_TTS_MODEL, language="hi-IN") + ), + ), + ) + tts_params = _chain_request(mock).blocks[2].config.blob.completion.params + assert tts_params["language"] == "te-IN" + + +# ---------- Provider combos ---------- + + +class TestProviders: + def test_stt_sarvamai_provider( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + stt_provider="sarvamai", + ), + ) + stt_completion = 
_chain_request(mock).blocks[0].config.blob.completion + assert stt_completion.provider == "sarvamai" + + def test_tts_sarvamai_provider( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + tts_provider="sarvamai", + ), + ) + tts_completion = _chain_request(mock).blocks[2].config.blob.completion + assert tts_completion.provider == "sarvamai" + + def test_tts_elevenlabs_provider( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + tts_provider="elevenlabs", + ), + ) + tts_completion = _chain_request(mock).blocks[2].config.blob.completion + assert tts_completion.provider == "elevenlabs" + + def test_rag_provider_defaults_to_openai( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest(query=audio_input, knowledge_base_ids=kb_ids), + ) + rag_completion = _chain_request(mock).blocks[1].config.blob.completion + assert rag_completion.provider == "openai" + + def test_sarvamai_stt_with_saaras_model( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + stt_provider="sarvamai", + stt=STTBlockSpec(params=STTLLMParams(model="saaras:v3")), + ), + ) + stt_completion = _chain_request(mock).blocks[0].config.blob.completion + assert stt_completion.provider == "sarvamai" + assert stt_completion.params["model"] == "saaras:v3" + + def test_google_stt_with_gemini_model( + self, 
client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + stt_provider="google", + stt=STTBlockSpec(params=STTLLMParams(model="gemini-2.5-pro")), + ), + ) + stt_completion = _chain_request(mock).blocks[0].config.blob.completion + assert stt_completion.provider == "google" + assert stt_completion.params["model"] == "gemini-2.5-pro" + + def test_all_three_providers_set_independently( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + stt_provider="sarvamai", + rag_provider="openai", + tts_provider="elevenlabs", + ), + ) + blocks = _chain_request(mock).blocks + assert blocks[0].config.blob.completion.provider == "sarvamai" + assert blocks[1].config.blob.completion.provider == "openai" + assert blocks[2].config.blob.completion.provider == "elevenlabs" + + +# ---------- Inline param overrides ---------- + + +class TestInlineOverrides: + def test_stt_model_override(self, client, user_api_key_header, audio_input, kb_ids): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + stt=STTBlockSpec(params=STTLLMParams(model="saaras:v3")), + ), + ) + stt_params = _chain_request(mock).blocks[0].config.blob.completion.params + assert stt_params["model"] == "saaras:v3" + + def test_rag_model_and_instructions_override( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + rag=RAGBlockSpec( + 
params=TextLLMParams( + model="gpt-4o-mini", + instructions="Be brief.", + temperature=0.5, + ) + ), + ), + ) + rag_params = _chain_request(mock).blocks[1].config.blob.completion.params + assert rag_params["model"] == "gpt-4o-mini" + assert rag_params["instructions"] == "Be brief." + assert rag_params["temperature"] == 0.5 + + def test_rag_inline_still_injects_kb_ids( + self, client, user_api_key_header, audio_input, kb_ids + ): + """knowledge_base_ids must be injected even when user provides partial RAG params.""" + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + rag=RAGBlockSpec(params=TextLLMParams(model="gpt-4o-mini")), + ), + ) + rag_params = _chain_request(mock).blocks[1].config.blob.completion.params + assert rag_params["knowledge_base_ids"] == kb_ids + + def test_tts_voice_override(self, client, user_api_key_header, audio_input, kb_ids): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + tts=TTSBlockSpec(params=TTSLLMParams(voice="Orus")), + ), + ) + tts_params = _chain_request(mock).blocks[2].config.blob.completion.params + assert tts_params["voice"] == "Orus" + + def test_tts_format_override( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + tts=TTSBlockSpec(params=TTSLLMParams(response_format="mp3")), + ), + ) + tts_params = _chain_request(mock).blocks[2].config.blob.completion.params + assert tts_params["response_format"] == "mp3" + + +# ---------- Stored config references ---------- + + +class TestStoredConfigRefs: + def test_stored_stt_block(self, client, user_api_key_header, audio_input, kb_ids): + 
config_id = uuid4() + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + stt=STTBlockSpec(config_id=config_id, config_version=2), + ), + ) + stt_config = _chain_request(mock).blocks[0].config + assert stt_config.id == config_id + assert stt_config.version == 2 + assert stt_config.blob is None + + def test_stored_rag_block(self, client, user_api_key_header, audio_input, kb_ids): + config_id = uuid4() + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + rag=RAGBlockSpec(config_id=config_id, config_version=1), + ), + ) + rag_config = _chain_request(mock).blocks[1].config + assert rag_config.id == config_id + assert rag_config.version == 1 + assert rag_config.blob is None + + def test_stored_tts_block(self, client, user_api_key_header, audio_input, kb_ids): + config_id = uuid4() + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + tts=TTSBlockSpec(config_id=config_id, config_version=3), + ), + ) + tts_config = _chain_request(mock).blocks[2].config + assert tts_config.id == config_id + assert tts_config.version == 3 + assert tts_config.blob is None + + def test_all_blocks_stored(self, client, user_api_key_header, audio_input, kb_ids): + stt_id, rag_id, tts_id = uuid4(), uuid4(), uuid4() + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + stt=STTBlockSpec(config_id=stt_id, config_version=1), + rag=RAGBlockSpec(config_id=rag_id, config_version=1), + tts=TTSBlockSpec(config_id=tts_id, config_version=1), + ), + ) + blocks = _chain_request(mock).blocks + assert 
blocks[0].config.id == stt_id + assert blocks[1].config.id == rag_id + assert blocks[2].config.id == tts_id + + def test_mixed_stored_and_inline( + self, client, user_api_key_header, audio_input, kb_ids + ): + """STT stored, RAG inline, TTS stored.""" + stt_id, tts_id = uuid4(), uuid4() + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + stt=STTBlockSpec(config_id=stt_id, config_version=1), + rag=RAGBlockSpec(params=TextLLMParams(model="gpt-4o-mini")), + tts=TTSBlockSpec(config_id=tts_id, config_version=2), + ), + ) + blocks = _chain_request(mock).blocks + assert blocks[0].config.id == stt_id + assert blocks[1].config.blob is not None + assert blocks[2].config.id == tts_id + + +# ---------- Metadata ---------- + + +class TestMetadata: + def test_metadata_has_speech_to_speech_flag( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest(query=audio_input, knowledge_base_ids=kb_ids), + ) + meta = _chain_request(mock).request_metadata + assert meta["speech_to_speech"] is True + + def test_metadata_reflects_resolved_languages( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + input_language="mr-IN", + output_language="ta-IN", + ), + ) + meta = _chain_request(mock).request_metadata + assert meta["input_language"] == "mr-IN" + assert meta["output_language"] == "ta-IN" + + def test_metadata_default_model_labels( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest(query=audio_input, 
knowledge_base_ids=kb_ids), + ) + meta = _chain_request(mock).request_metadata + assert meta["stt_model"] == DEFAULT_STT_MODEL + assert meta["llm_model"] == DEFAULT_RAG_MODEL + assert meta["tts_model"] == DEFAULT_TTS_MODEL + + def test_metadata_inline_model_labels( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + stt=STTBlockSpec(params=STTLLMParams(model="saaras:v3")), + rag=RAGBlockSpec(params=TextLLMParams(model="gpt-4o-mini")), + tts=TTSBlockSpec( + params=TTSLLMParams(model="gemini-2.5-flash-preview-tts") + ), + ), + ) + meta = _chain_request(mock).request_metadata + assert meta["stt_model"] == "saaras:v3" + assert meta["llm_model"] == "gpt-4o-mini" + assert meta["tts_model"] == "gemini-2.5-flash-preview-tts" + + def test_metadata_stored_ref_label_format( + self, client, user_api_key_header, audio_input, kb_ids + ): + config_id = uuid4() + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + stt=STTBlockSpec(config_id=config_id, config_version=4), + ), + ) + meta = _chain_request(mock).request_metadata + assert meta["stt_model"] == f"stored:{config_id}@v4" + + def test_caller_metadata_is_preserved( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + request_metadata={"session_id": "abc123", "user": "test"}, + ), + ) + meta = _chain_request(mock).request_metadata + assert meta["session_id"] == "abc123" + assert meta["user"] == "test" + assert meta["speech_to_speech"] is True + + def test_caller_metadata_cannot_override_sts_keys( + self, client, 
user_api_key_header, audio_input, kb_ids + ): + """STS-injected keys must win over any caller-provided values.""" + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + input_language="hi-IN", + request_metadata={ + "input_language": "caller-value", + "output_language": "caller-value", + "speech_to_speech": False, + "stt_model": "caller-value", + "llm_model": "caller-value", + "tts_model": "caller-value", + }, + ), + ) + meta = _chain_request(mock).request_metadata + assert meta["input_language"] == "hi-IN" + assert meta["output_language"] == "hi-IN" + assert meta["speech_to_speech"] is True + assert meta["stt_model"] == DEFAULT_STT_MODEL + assert meta["llm_model"] == DEFAULT_RAG_MODEL + assert meta["tts_model"] == DEFAULT_TTS_MODEL + + +# ---------- Error paths ---------- + + +class TestErrorPaths: + def test_invalid_input_language_returns_422( + self, client, user_api_key_header, audio_input, kb_ids + ): + response = _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + input_language="hindi", + ), + ) + assert response.status_code == 422 + assert "input language" in response.json()["error"].lower() + + def test_invalid_output_language_returns_422( + self, client, user_api_key_header, audio_input, kb_ids + ): + response = _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + output_language="english", + ), + ) + assert response.status_code == 422 + assert "output language" in response.json()["error"].lower() + + def test_unknown_as_output_language_returns_422( + self, client, user_api_key_header, audio_input, kb_ids + ): + """'unknown' is a detection sentinel, not a valid TTS target.""" + response = _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + 
output_language="unknown", + ), + ) + assert response.status_code == 422 + + def test_auto_as_output_language_returns_422( + self, client, user_api_key_header, audio_input, kb_ids + ): + """'auto' cannot be pinned as TTS output — TTS needs a concrete language.""" + response = _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + output_language="auto", + ), + ) + assert response.status_code == 422 + + def test_auto_as_input_language_is_valid( + self, client, user_api_key_header, audio_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job"): + response = _post( + client, + user_api_key_header, + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + input_language="auto", + ), + ) + assert response.status_code == 200 + + def test_explicit_null_input_language_defaults_to_auto( + self, client, user_api_key_header, audio_input, kb_ids + ): + """Sending input_language=null in JSON should still result in STT getting 'auto'.""" + with patch("app.api.routes.llm_sts.start_chain_job") as mock: + response = client.post( + URL, + json={ + "query": audio_input.model_dump(mode="json"), + "knowledge_base_ids": kb_ids, + "input_language": None, + }, + headers=user_api_key_header, + ) + assert response.status_code == 200 + stt_params = _chain_request(mock).blocks[0].config.blob.completion.params + assert stt_params["input_language"] == "auto" + + def test_empty_knowledge_base_ids_rejected(self, audio_input): + with pytest.raises(ValidationError): + SpeechToSpeechRequest(query=audio_input, knowledge_base_ids=[]) + + def test_block_spec_rejects_params_and_config_id_together( + self, audio_input, kb_ids + ): + with pytest.raises(ValidationError): + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + stt=STTBlockSpec( + config_id=uuid4(), + config_version=1, + params=STTLLMParams(model="saaras:v3"), + ), + ) + + def 
test_block_spec_rejects_config_id_without_version(self, audio_input, kb_ids): + with pytest.raises(ValidationError): + SpeechToSpeechRequest( + query=audio_input, + knowledge_base_ids=kb_ids, + rag=RAGBlockSpec(config_id=uuid4()), + ) + + def test_url_audio_input_accepted( + self, client, user_api_key_header, audio_url_input, kb_ids + ): + with patch("app.api.routes.llm_sts.start_chain_job"): + response = _post( + client, + user_api_key_header, + SpeechToSpeechRequest(query=audio_url_input, knowledge_base_ids=kb_ids), + ) + assert response.status_code == 200