mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
141 lines
5.2 KiB
Python
141 lines
5.2 KiB
Python
|
|
import json
|
|||
|
|
import os
|
|||
|
|
from unittest.mock import AsyncMock, patch
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
from pipecat.services.openai.llm import OpenAILLMContext
|
|||
|
|
|
|||
|
|
from api.services.workflow.dto import ExtractionVariableDTO, VariableType
|
|||
|
|
from api.services.workflow.pipecat_engine_variable_extractor import (
|
|||
|
|
VariableExtractionManager,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DummyLLM:
|
|||
|
|
"""A minimal stub that mimics the parts of an LLM service used by the extractor."""
|
|||
|
|
|
|||
|
|
def __init__(self, streamed_response: str | None = None):
|
|||
|
|
# Optionally provide a pre-defined streaming response for _perform_extraction tests
|
|||
|
|
self._streamed_response = streamed_response or "{}"
|
|||
|
|
self.registered_functions: dict[str, AsyncMock] = {}
|
|||
|
|
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
# API used by VariableExtractionManager
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
def register_function(self, name: str, func, cancel_on_interruption=True): # noqa: D401 – simple delegate
|
|||
|
|
self.registered_functions[name] = func
|
|||
|
|
|
|||
|
|
async def get_chat_completions(self, _context, _messages):
|
|||
|
|
"""Return an async generator that yields a single chunk with the full response."""
|
|||
|
|
|
|||
|
|
class _Delta: # noqa: D401 – tiny helper classes for stub response
|
|||
|
|
def __init__(self, content):
|
|||
|
|
self.content = content
|
|||
|
|
|
|||
|
|
class _Choice:
|
|||
|
|
def __init__(self, delta):
|
|||
|
|
self.delta = delta
|
|||
|
|
|
|||
|
|
class _Chunk:
|
|||
|
|
def __init__(self, content):
|
|||
|
|
self.choices = [_Choice(_Delta(content))]
|
|||
|
|
|
|||
|
|
async def _stream():
|
|||
|
|
yield _Chunk(self._streamed_response)
|
|||
|
|
|
|||
|
|
return _stream()
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DummyEngine:
    """Minimal engine double carrying just the attributes the extractor relies on."""

    def __init__(self, llm):
        # Conversation context built with the library's default arguments.
        self.context = OpenAILLMContext()
        self.llm = llm
        # VariableExtractionManager currently writes to this private attribute
        # directly, so the stub has to provide it.
        self._gathered_context: dict = {}
        self._pending_function_calls: int = 0
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
# Tests
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.mark.asyncio
async def test_perform_extraction_parses_json_correctly():
    """The dict returned by _perform_extraction must equal the JSON sent by the mocked client."""
    # A placeholder OPENAI_API_KEY keeps client initialisation from failing.
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
        expected_payload = {"name": "Alice", "age": 30}
        payload_json = json.dumps(expected_payload)

        manager = VariableExtractionManager(DummyEngine(DummyLLM(payload_json)))

        # Assemble the fake completion bottom-up: message -> choice -> response.
        fake_message = AsyncMock()
        fake_message.content = payload_json
        fake_choice = AsyncMock()
        fake_choice.message = fake_message
        fake_completion = AsyncMock()
        fake_completion.choices = [fake_choice]

        fake_client = AsyncMock()
        fake_client.chat.completions.create.return_value = fake_completion

        # Only the variable names/types matter here; the prompts are irrelevant.
        wanted = [
            ExtractionVariableDTO(
                name="name", type=VariableType.string, prompt="user name"
            ),
            ExtractionVariableDTO(
                name="age", type=VariableType.number, prompt="user age"
            ),
        ]

        # Swap the AsyncOpenAI constructor seen by the extractor module for one
        # that hands out the fake client.
        openai_patch = patch(
            "api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI",
            return_value=fake_client,
        )
        with openai_patch:
            result = await manager._perform_extraction(
                wanted, parent_ctx=None, extraction_prompt=""
            )

        assert result == expected_payload
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.mark.asyncio
async def test_perform_extraction_with_custom_system_prompt():
    """_perform_extraction should use the provided extraction_prompt as system prompt.

    Checks the parsed result and also inspects the arguments of the mocked
    ``chat.completions.create`` call, so the test fails if the custom prompt
    is silently dropped.
    """
    # Set dummy OpenAI API key to prevent initialization errors
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
        expected_payload = {"color": "blue"}
        llm = DummyLLM(json.dumps(expected_payload))
        engine = DummyEngine(llm)
        manager = VariableExtractionManager(engine)

        # Mock the AsyncOpenAI client and its response
        mock_response = AsyncMock()
        mock_response.choices = [AsyncMock()]
        mock_response.choices[0].message = AsyncMock()
        mock_response.choices[0].message.content = json.dumps(expected_payload)

        mock_client = AsyncMock()
        mock_client.chat.completions.create.return_value = mock_response

        with patch(
            "api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI",
            return_value=mock_client,
        ):
            extraction_variables = [
                ExtractionVariableDTO(
                    name="color", type=VariableType.string, prompt="favourite color"
                )
            ]

            # Call with a custom extraction prompt
            custom_prompt = "You are a color extraction specialist."
            result = await manager._perform_extraction(
                extraction_variables, parent_ctx=None, extraction_prompt=custom_prompt
            )

        assert result == expected_payload

        # The result alone cannot show that the prompt was used, because the
        # mocked client returns the same payload whatever it is sent.
        create_mock = mock_client.chat.completions.create
        create_mock.assert_awaited()
        # NOTE(review): assumes the extractor forwards the prompt text verbatim
        # somewhere in the create() arguments (e.g. inside `messages`) — confirm
        # against the extractor and tighten to an exact system-message check.
        assert any(
            custom_prompt in repr(awaited_call)
            for awaited_call in create_mock.await_args_list
        ), "custom extraction_prompt was not passed to chat.completions.create"
|