mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
chore: bump pipecat version and fix tests (#263)
* chore: bump pipecat version and fix tests * chore: add github workflow to run tests * fix: install reqirements.dev.txt in test script * fix: fix api-test action * feat: add integration test * test: add integration tests * test: add test for function call mute strategy
This commit is contained in:
parent
d256c6005c
commit
0e12c41fc7
76 changed files with 1776 additions and 670 deletions
|
|
@ -14,18 +14,10 @@ result in the context when generating the next response.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from typing import List
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import (
|
||||
AGENT_SYSTEM_PROMPT,
|
||||
END_CALL_SYSTEM_PROMPT,
|
||||
START_CALL_SYSTEM_PROMPT,
|
||||
)
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -35,75 +27,21 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMAssistantAggregatorParams,
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
|
||||
class ContextCapturingMockLLM(MockLLMService):
|
||||
"""A MockLLMService that captures the context state at each generation.
|
||||
|
||||
This allows us to verify that tool call results are present in the context
|
||||
when the next LLM generation is triggered.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.captured_contexts: List[Dict[str, Any]] = []
|
||||
|
||||
async def _stream_chat_completions_universal_context(self, context):
|
||||
"""Override to capture context state before streaming chunks."""
|
||||
# Deep copy the messages to avoid mutation issues
|
||||
messages_snapshot = []
|
||||
for msg in context.messages:
|
||||
msg_copy = dict(msg)
|
||||
# Copy content to avoid reference issues
|
||||
if "content" in msg_copy:
|
||||
msg_copy["content"] = (
|
||||
str(msg_copy["content"]) if msg_copy["content"] else None
|
||||
)
|
||||
messages_snapshot.append(msg_copy)
|
||||
|
||||
self.captured_contexts.append(
|
||||
{
|
||||
"step": self._current_step,
|
||||
"messages": messages_snapshot,
|
||||
"system_prompt": self._settings.system_instruction,
|
||||
}
|
||||
)
|
||||
|
||||
# Call parent implementation to stream the mock chunks
|
||||
return await super()._stream_chat_completions_universal_context(context)
|
||||
|
||||
def get_context_at_step(self, step: int) -> Dict[str, Any]:
|
||||
"""Get the captured context at a specific step (0-indexed)."""
|
||||
for ctx in self.captured_contexts:
|
||||
if ctx["step"] == step:
|
||||
return ctx
|
||||
return None
|
||||
|
||||
def has_tool_call_result_at_step(self, step: int, function_name: str) -> bool:
|
||||
"""Check if a tool call result for the given function exists in context at step."""
|
||||
ctx = self.get_context_at_step(step)
|
||||
if not ctx:
|
||||
return False
|
||||
|
||||
for msg in ctx["messages"]:
|
||||
# Check for tool/function role messages
|
||||
if msg.get("role") == "tool" and msg.get("name") == function_name:
|
||||
return True
|
||||
# Also check for tool_call_id which indicates a tool response
|
||||
if msg.get("tool_call_id") and function_name in str(msg.get("name", "")):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_system_prompt_at_step(self, step: int) -> str:
|
||||
"""Get the system prompt from settings at a specific step."""
|
||||
ctx = self.get_context_at_step(step)
|
||||
if ctx:
|
||||
return ctx.get("system_prompt") or ""
|
||||
return ""
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import (
|
||||
AGENT_SYSTEM_PROMPT,
|
||||
END_CALL_SYSTEM_PROMPT,
|
||||
START_CALL_SYSTEM_PROMPT,
|
||||
)
|
||||
from pipecat.tests import (
|
||||
ContextCapturingMockLLM,
|
||||
MockLLMService,
|
||||
MockTTSService,
|
||||
)
|
||||
|
||||
|
||||
async def run_pipeline_and_capture_context(
|
||||
|
|
@ -142,7 +80,7 @@ async def run_pipeline_and_capture_context(
|
|||
context = LLMContext()
|
||||
|
||||
# Add assistant context aggregator
|
||||
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
|
||||
assistant_params = LLMAssistantAggregatorParams()
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context, assistant_params=assistant_params
|
||||
)
|
||||
|
|
@ -184,7 +122,7 @@ async def run_pipeline_and_capture_context(
|
|||
|
||||
# Patch DB calls
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
|
||||
"api.db:db_client.get_organization_id_by_workflow_run_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=1,
|
||||
):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue