mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
chore: fix formatting and generate client
This commit is contained in:
parent
129a6d700c
commit
e23cce444f
8 changed files with 120 additions and 92 deletions
|
|
@ -53,9 +53,9 @@ router = APIRouter(prefix="/ws")
|
|||
class NonRelayFilterPolicy(Enum):
|
||||
"""What to filter from non-relay ICE candidates. Relay candidates always pass."""
|
||||
|
||||
NONE = "none" # filter nothing — pass all candidates
|
||||
NONE = "none" # filter nothing — pass all candidates
|
||||
PRIVATE = "private" # filter non-relay candidates with private/CGNAT IPs
|
||||
ALL = "all" # filter all non-relay candidates (relay-only mode)
|
||||
ALL = "all" # filter all non-relay candidates (relay-only mode)
|
||||
|
||||
|
||||
def is_local_or_cgnat_ip(ip_str: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -214,9 +214,7 @@ async def _load_text_session_or_404(
|
|||
user: UserModel,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
if user.selected_organization_id is None:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Organization context is required"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="Organization context is required")
|
||||
text_session = await db_client.get_workflow_run_text_session(
|
||||
run_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from typing import TYPE_CHECKING, Awaitable, Callable, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Union
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Literal, Optional, Union
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import (
|
||||
|
|
@ -683,7 +682,11 @@ class PipecatEngine:
|
|||
)
|
||||
return "greeting"
|
||||
|
||||
if generate_if_no_greeting and self.llm is not None and self.context is not None:
|
||||
if (
|
||||
generate_if_no_greeting
|
||||
and self.llm is not None
|
||||
and self.context is not None
|
||||
):
|
||||
logger.debug("Queueing initial LLM generation for node opening")
|
||||
# Queue after the voicemail detector in the live pipeline so the
|
||||
# detector can gate initial generations when needed.
|
||||
|
|
|
|||
|
|
@ -12,14 +12,13 @@ from pipecat.frames.frames import (
|
|||
EndFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMContextFrame,
|
||||
LLMAssistantPushAggregationFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSTextFrame,
|
||||
TTSStoppedFrame,
|
||||
TextFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -174,9 +173,7 @@ class _TextChatCaptureProcessor(FrameProcessor):
|
|||
if isinstance(frame, TTSSpeakFrame):
|
||||
text_frame = TextFrame(frame.text)
|
||||
text_frame.append_to_context = (
|
||||
frame.append_to_context
|
||||
if frame.append_to_context is not None
|
||||
else True
|
||||
frame.append_to_context if frame.append_to_context is not None else True
|
||||
)
|
||||
await self.push_frame(text_frame, direction)
|
||||
await self.push_frame(LLMAssistantPushAggregationFrame(), direction)
|
||||
|
|
@ -196,7 +193,10 @@ class _TextChatCaptureProcessor(FrameProcessor):
|
|||
):
|
||||
self._response_window.note_llm_start()
|
||||
|
||||
if isinstance(frame, LLMFullResponseEndFrame) and direction is FrameDirection.DOWNSTREAM:
|
||||
if (
|
||||
isinstance(frame, LLMFullResponseEndFrame)
|
||||
and direction is FrameDirection.DOWNSTREAM
|
||||
):
|
||||
self._response_window.note_llm_end()
|
||||
await self.push_frame(frame, direction)
|
||||
# Text chat has no TTS/output transport, so mixed text+tool responses
|
||||
|
|
@ -254,9 +254,7 @@ def _merge_usage_info(
|
|||
+ int(value.get("completion_tokens") or 0),
|
||||
"total_tokens": int(current.get("total_tokens") or 0)
|
||||
+ int(value.get("total_tokens") or 0),
|
||||
"cache_read_input_tokens": int(
|
||||
current.get("cache_read_input_tokens") or 0
|
||||
)
|
||||
"cache_read_input_tokens": int(current.get("cache_read_input_tokens") or 0)
|
||||
+ int(value.get("cache_read_input_tokens") or 0),
|
||||
"cache_creation_input_tokens": int(
|
||||
current.get("cache_creation_input_tokens") or 0
|
||||
|
|
@ -271,9 +269,9 @@ def _merge_usage_info(
|
|||
merged_section[key] = float(merged_section.get(key) or 0) + float(value)
|
||||
merged[section] = merged_section
|
||||
|
||||
merged["call_duration_seconds"] = int(merged.get("call_duration_seconds") or 0) + int(
|
||||
delta.get("call_duration_seconds") or 0
|
||||
)
|
||||
merged["call_duration_seconds"] = int(
|
||||
merged.get("call_duration_seconds") or 0
|
||||
) + int(delta.get("call_duration_seconds") or 0)
|
||||
|
||||
return merged
|
||||
|
||||
|
|
@ -331,9 +329,11 @@ async def _wait_for_quiescence(
|
|||
await asyncio.sleep(0.05)
|
||||
continue
|
||||
|
||||
if response_window.frontier_is_idle and (
|
||||
time.monotonic() - capture_processor.last_activity_at
|
||||
) >= TEXT_CHAT_IDLE_SETTLE_SECONDS:
|
||||
if (
|
||||
response_window.frontier_is_idle
|
||||
and (time.monotonic() - capture_processor.last_activity_at)
|
||||
>= TEXT_CHAT_IDLE_SETTLE_SECONDS
|
||||
):
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
|
@ -514,12 +514,9 @@ async def execute_text_chat_pending_turn(
|
|||
await engine.set_node(target_node_id)
|
||||
|
||||
opening_marker = capture_processor.activity_count
|
||||
opening_expects_llm = (
|
||||
pending_user_message is None
|
||||
and (
|
||||
current_node_id == target_node_id
|
||||
or engine.get_node_greeting(target_node_id) is None
|
||||
)
|
||||
opening_expects_llm = pending_user_message is None and (
|
||||
current_node_id == target_node_id
|
||||
or engine.get_node_greeting(target_node_id) is None
|
||||
)
|
||||
if opening_expects_llm:
|
||||
response_window.note_direct_context_request()
|
||||
|
|
|
|||
|
|
@ -167,9 +167,7 @@ class TestIsLocalOrCgnatIp:
|
|||
|
||||
class TestKeepCandidate:
|
||||
def test_private_relay_candidate_survives_private_policy(self):
|
||||
candidate = (
|
||||
"candidate:111 1 udp 41885439 192.168.1.50 50000 typ relay raddr 0.0.0.0 rport 0"
|
||||
)
|
||||
candidate = "candidate:111 1 udp 41885439 192.168.1.50 50000 typ relay raddr 0.0.0.0 rport 0"
|
||||
assert _keep_candidate(candidate, NonRelayFilterPolicy.PRIVATE) is True
|
||||
|
||||
def test_private_host_candidate_drops_under_private_policy(self):
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.tests import MockLLMService
|
||||
|
||||
from api.db.models import OrganizationModel, UserModel
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
|
||||
from pipecat.tests import MockLLMService
|
||||
|
||||
|
||||
async def _create_user_and_workflow(
|
||||
|
|
@ -92,17 +92,22 @@ async def test_text_chat_session_creation_executes_initial_assistant_turn(
|
|||
)
|
||||
|
||||
llm = MockLLMService(
|
||||
mock_steps=[MockLLMService.create_text_chunks("Hello from the workflow tester.")],
|
||||
mock_steps=[
|
||||
MockLLMService.create_text_chunks("Hello from the workflow tester.")
|
||||
],
|
||||
chunk_delay=0.001,
|
||||
)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
with patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
return_value=llm,
|
||||
), patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
with (
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
return_value=llm,
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
create_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
|
||||
|
|
@ -179,18 +184,23 @@ async def test_text_chat_message_executes_assistant_turn(
|
|||
llm_responses = [
|
||||
MockLLMService(mock_steps=[], chunk_delay=0.001),
|
||||
MockLLMService(
|
||||
mock_steps=[MockLLMService.create_text_chunks("Hello from the workflow tester.")],
|
||||
mock_steps=[
|
||||
MockLLMService.create_text_chunks("Hello from the workflow tester.")
|
||||
],
|
||||
chunk_delay=0.001,
|
||||
),
|
||||
]
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
with patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
side_effect=llm_responses,
|
||||
), patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
with (
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
side_effect=llm_responses,
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
create_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
|
||||
|
|
@ -295,12 +305,15 @@ async def test_text_chat_executes_deferred_tool_calls_after_text_response(
|
|||
]
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
with patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
side_effect=llm_responses,
|
||||
), patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
with (
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
side_effect=llm_responses,
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
create_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
|
||||
|
|
@ -428,12 +441,15 @@ async def test_text_chat_chains_multiple_follow_up_completions_in_one_turn(
|
|||
]
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
with patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
side_effect=llm_responses,
|
||||
), patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
with (
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
side_effect=llm_responses,
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
create_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
|
||||
|
|
@ -458,11 +474,14 @@ async def test_text_chat_chains_multiple_follow_up_completions_in_one_turn(
|
|||
assert "Moving to agent two." in assistant_text
|
||||
assert "Agent two here." in assistant_text
|
||||
assert payload["checkpoint"]["current_node_id"] == "agent2"
|
||||
assert sum(
|
||||
1
|
||||
for event in payload["session_data"]["turns"][1]["events"]
|
||||
if event["type"] == "tool_call_started"
|
||||
) == 2
|
||||
assert (
|
||||
sum(
|
||||
1
|
||||
for event in payload["session_data"]["turns"][1]["events"]
|
||||
if event["type"] == "tool_call_started"
|
||||
)
|
||||
== 2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -530,12 +549,15 @@ async def test_text_chat_greeting_only_plays_on_fresh_node_entry(
|
|||
]
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
with patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
side_effect=llm_responses,
|
||||
), patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
with (
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
side_effect=llm_responses,
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
create_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
|
||||
|
|
@ -543,7 +565,9 @@ async def test_text_chat_greeting_only_plays_on_fresh_node_entry(
|
|||
)
|
||||
assert create_response.status_code == 200
|
||||
session = create_response.json()
|
||||
opening_text = session["session_data"]["turns"][0]["assistant_message"]["text"]
|
||||
opening_text = session["session_data"]["turns"][0]["assistant_message"][
|
||||
"text"
|
||||
]
|
||||
|
||||
first_message = await client.post(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{session['workflow_run_id']}/messages",
|
||||
|
|
@ -565,9 +589,9 @@ async def test_text_chat_greeting_only_plays_on_fresh_node_entry(
|
|||
assert second_message.status_code == 200
|
||||
|
||||
first_text = first_payload["session_data"]["turns"][1]["assistant_message"]["text"]
|
||||
second_text = second_message.json()["session_data"]["turns"][2]["assistant_message"][
|
||||
"text"
|
||||
]
|
||||
second_text = second_message.json()["session_data"]["turns"][2][
|
||||
"assistant_message"
|
||||
]["text"]
|
||||
|
||||
assert opening_text == "Welcome to the workflow tester."
|
||||
assert "Welcome to the workflow tester." not in first_text
|
||||
|
|
@ -699,12 +723,15 @@ async def test_text_chat_rewind_reuses_checkpoint_snapshot(
|
|||
]
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
with patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
side_effect=llm_responses,
|
||||
), patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
with (
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
side_effect=llm_responses,
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
create_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
|
||||
|
|
@ -825,15 +852,20 @@ async def test_text_chat_session_is_not_accessible_from_another_org(
|
|||
|
||||
async with test_client_factory(owner_user) as owner_client:
|
||||
llm = MockLLMService(
|
||||
mock_steps=[MockLLMService.create_text_chunks("Hello from the workflow tester.")],
|
||||
mock_steps=[
|
||||
MockLLMService.create_text_chunks("Hello from the workflow tester.")
|
||||
],
|
||||
chunk_delay=0.001,
|
||||
)
|
||||
with patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
return_value=llm,
|
||||
), patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
with (
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
return_value=llm,
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
create_response = await owner_client.post(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue