mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
feat: add end_call tool (#118)
* feat: add end_call tool * chore: remove run_llm=True from properties
This commit is contained in:
parent
e7712474c1
commit
a172db8022
26 changed files with 1274 additions and 716 deletions
|
|
@ -0,0 +1,50 @@
|
|||
"""add end_call tool category
|
||||
|
||||
Revision ID: 493ca2bb001f
|
||||
Revises: b79f19f68157
|
||||
Create Date: 2026-01-14 15:04:48.899778
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from alembic_postgresql_enum import TableReference
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "493ca2bb001f"
|
||||
down_revision: Union[str, None] = "b79f19f68157"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.sync_enum_values(
|
||||
enum_schema="public",
|
||||
enum_name="tool_category",
|
||||
new_values=["http_api", "end_call", "native", "integration"],
|
||||
affected_columns=[
|
||||
TableReference(
|
||||
table_schema="public", table_name="tools", column_name="category"
|
||||
)
|
||||
],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.sync_enum_values(
|
||||
enum_schema="public",
|
||||
enum_name="tool_category",
|
||||
new_values=["http_api", "native", "integration"],
|
||||
affected_columns=[
|
||||
TableReference(
|
||||
table_schema="public", table_name="tools", column_name="category"
|
||||
)
|
||||
],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
"""remove unique tool name constraint
|
||||
|
||||
Revision ID: dcb0a27d98c6
|
||||
Revises: 493ca2bb001f
|
||||
Create Date: 2026-01-14 16:07:01.940879
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "dcb0a27d98c6"
|
||||
down_revision: Union[str, None] = "493ca2bb001f"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(op.f("unique_org_tool_name"), "tools", type_="unique")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_unique_constraint(
|
||||
op.f("unique_org_tool_name"),
|
||||
"tools",
|
||||
["organization_id", "name"],
|
||||
postgresql_nulls_not_distinct=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -889,5 +889,4 @@ class ToolModel(Base):
|
|||
Index("ix_tools_uuid", "tool_uuid"),
|
||||
Index("ix_tools_status", "status"),
|
||||
Index("ix_tools_category", "category"),
|
||||
UniqueConstraint("organization_id", "name", name="unique_org_tool_name"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -86,7 +86,12 @@ class ToolClient(BaseDBClient):
|
|||
)
|
||||
|
||||
if status:
|
||||
query = query.where(ToolModel.status == status)
|
||||
# Support comma-separated status values (e.g., "active,archived")
|
||||
status_list = [s.strip() for s in status.split(",")]
|
||||
if len(status_list) > 1:
|
||||
query = query.where(ToolModel.status.in_(status_list))
|
||||
else:
|
||||
query = query.where(ToolModel.status == status)
|
||||
else:
|
||||
# By default, exclude archived tools
|
||||
query = query.where(ToolModel.status != ToolStatus.ARCHIVED.value)
|
||||
|
|
@ -233,6 +238,44 @@ class ToolClient(BaseDBClient):
|
|||
return True
|
||||
return False
|
||||
|
||||
async def unarchive_tool(
|
||||
self, tool_uuid: str, organization_id: int
|
||||
) -> Optional[ToolModel]:
|
||||
"""Restore an archived tool by setting its status to active.
|
||||
|
||||
Args:
|
||||
tool_uuid: The unique tool UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
The unarchived ToolModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
update(ToolModel)
|
||||
.where(
|
||||
ToolModel.tool_uuid == tool_uuid,
|
||||
ToolModel.organization_id == organization_id,
|
||||
ToolModel.status == ToolStatus.ARCHIVED.value,
|
||||
)
|
||||
.values(
|
||||
status=ToolStatus.ACTIVE.value,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
if result.rowcount > 0:
|
||||
logger.info(
|
||||
f"Unarchived tool {tool_uuid} for organization {organization_id}"
|
||||
)
|
||||
# Fetch and return the updated tool
|
||||
result = await session.execute(
|
||||
select(ToolModel).where(ToolModel.tool_uuid == tool_uuid)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
return None
|
||||
|
||||
async def validate_tool_uuid(self, tool_uuid: str, organization_id: int) -> bool:
|
||||
"""Check if a tool UUID exists and belongs to the organization.
|
||||
|
||||
|
|
|
|||
|
|
@ -121,9 +121,8 @@ class ToolCategory(Enum):
|
|||
"""Tool category types"""
|
||||
|
||||
HTTP_API = "http_api" # Custom HTTP API calls (implemented)
|
||||
NATIVE = (
|
||||
"native" # Built-in integrations (future: call_transfer, dtmf_input, end_call)
|
||||
)
|
||||
END_CALL = "end_call" # End call tool
|
||||
NATIVE = "native" # Built-in integrations (future: call_transfer, dtmf_input)
|
||||
INTEGRATION = "integration" # Third-party integrations (future: Google Calendar, Salesforce, etc.)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""API routes for managing tools."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -45,14 +45,38 @@ class HttpApiConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
"""Tool definition schema."""
|
||||
class EndCallConfig(BaseModel):
|
||||
"""Configuration for End Call tools."""
|
||||
|
||||
schema_version: int = Field(
|
||||
default=1, description="Schema version for compatibility"
|
||||
messageType: Literal["none", "custom"] = Field(
|
||||
default="none", description="Type of goodbye message"
|
||||
)
|
||||
type: str = Field(description="Tool type (http_api)")
|
||||
config: HttpApiConfig = Field(description="Tool configuration")
|
||||
customMessage: Optional[str] = Field(
|
||||
default=None, description="Custom message to play before ending the call"
|
||||
)
|
||||
|
||||
|
||||
class HttpApiToolDefinition(BaseModel):
|
||||
"""Tool definition for HTTP API tools."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version")
|
||||
type: Literal["http_api"] = Field(description="Tool type")
|
||||
config: HttpApiConfig = Field(description="HTTP API configuration")
|
||||
|
||||
|
||||
class EndCallToolDefinition(BaseModel):
|
||||
"""Tool definition for End Call tools."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version")
|
||||
type: Literal["end_call"] = Field(description="Tool type")
|
||||
config: EndCallConfig = Field(description="End Call configuration")
|
||||
|
||||
|
||||
# Union type for tool definitions - Pydantic will discriminate based on 'type' field
|
||||
ToolDefinition = Annotated[
|
||||
Union[HttpApiToolDefinition, EndCallToolDefinition],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class CreateToolRequest(BaseModel):
|
||||
|
|
@ -140,13 +164,15 @@ def validate_category(category: str) -> None:
|
|||
|
||||
|
||||
def validate_status(status: str) -> None:
|
||||
"""Validate that the status is valid."""
|
||||
"""Validate that the status is valid. Supports comma-separated values."""
|
||||
valid_statuses = [s.value for s in ToolStatus]
|
||||
if status not in valid_statuses:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status '{status}'. Must be one of: {', '.join(valid_statuses)}",
|
||||
)
|
||||
status_list = [s.strip() for s in status.split(",")]
|
||||
for s in status_list:
|
||||
if s not in valid_statuses:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status '{s}'. Must be one of: {', '.join(valid_statuses)}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
|
|
@ -205,27 +231,18 @@ async def create_tool(
|
|||
|
||||
validate_category(request.category)
|
||||
|
||||
try:
|
||||
tool = await db_client.create_tool(
|
||||
organization_id=user.selected_organization_id,
|
||||
user_id=user.id,
|
||||
name=request.name,
|
||||
definition=request.definition.model_dump(),
|
||||
category=request.category,
|
||||
description=request.description,
|
||||
icon=request.icon,
|
||||
icon_color=request.icon_color,
|
||||
)
|
||||
tool = await db_client.create_tool(
|
||||
organization_id=user.selected_organization_id,
|
||||
user_id=user.id,
|
||||
name=request.name,
|
||||
definition=request.definition.model_dump(),
|
||||
category=request.category,
|
||||
description=request.description,
|
||||
icon=request.icon,
|
||||
icon_color=request.icon_color,
|
||||
)
|
||||
|
||||
return build_tool_response(tool)
|
||||
|
||||
except Exception as e:
|
||||
if "unique_org_tool_name" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A tool with the name '{request.name}' already exists",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
return build_tool_response(tool)
|
||||
|
||||
|
||||
@router.get("/{tool_uuid}")
|
||||
|
|
@ -281,32 +298,21 @@ async def update_tool(
|
|||
if request.status:
|
||||
validate_status(request.status)
|
||||
|
||||
try:
|
||||
tool = await db_client.update_tool(
|
||||
tool_uuid=tool_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
definition=request.definition.model_dump() if request.definition else None,
|
||||
icon=request.icon,
|
||||
icon_color=request.icon_color,
|
||||
status=request.status,
|
||||
)
|
||||
tool = await db_client.update_tool(
|
||||
tool_uuid=tool_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
definition=request.definition.model_dump() if request.definition else None,
|
||||
icon=request.icon,
|
||||
icon_color=request.icon_color,
|
||||
status=request.status,
|
||||
)
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
|
||||
return build_tool_response(tool, include_created_by=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
if "unique_org_tool_name" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A tool with the name '{request.name}' already exists",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
return build_tool_response(tool, include_created_by=True)
|
||||
|
||||
|
||||
@router.delete("/{tool_uuid}")
|
||||
|
|
@ -334,3 +340,30 @@ async def delete_tool(
|
|||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
|
||||
return {"status": "archived", "tool_uuid": tool_uuid}
|
||||
|
||||
|
||||
@router.post("/{tool_uuid}/unarchive")
|
||||
async def unarchive_tool(
|
||||
tool_uuid: str,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> ToolResponse:
|
||||
"""
|
||||
Unarchive a tool (restore from archived state).
|
||||
|
||||
Args:
|
||||
tool_uuid: The UUID of the tool to unarchive
|
||||
|
||||
Returns:
|
||||
The unarchived tool
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
tool = await db_client.unarchive_tool(tool_uuid, user.selected_organization_id)
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
|
||||
return build_tool_response(tool)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import ToolCategory
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
|
|
@ -20,13 +21,17 @@ from api.services.workflow.tools.custom_tool import (
|
|||
tool_to_function_schema,
|
||||
)
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.frames.frames import FunctionCallResultProperties
|
||||
from pipecat.frames.frames import FunctionCallResultProperties, TTSSpeakFrame
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
# End task reason for end call tool
|
||||
END_CALL_TOOL_REASON = "end_call_tool"
|
||||
|
||||
|
||||
class CustomToolManager:
|
||||
"""Manager for custom tool registration and execution.
|
||||
|
||||
|
|
@ -34,14 +39,12 @@ class CustomToolManager:
|
|||
1. Fetching tools from the database based on tool UUIDs
|
||||
2. Converting tools to LLM function schemas
|
||||
3. Registering tool execution handlers with the LLM
|
||||
4. Executing HTTP API tools when invoked by the LLM
|
||||
4. Executing tools when invoked by the LLM
|
||||
"""
|
||||
|
||||
def __init__(self, engine: "PipecatEngine") -> None:
|
||||
self._engine = engine
|
||||
self._organization_id: Optional[int] = None
|
||||
# Cache: maps function_name -> (tool, schema)
|
||||
self._tools_cache: dict[str, tuple[Any, dict]] = {}
|
||||
|
||||
async def get_organization_id(self) -> Optional[int]:
|
||||
"""Get and cache the organization ID from workflow run."""
|
||||
|
|
@ -73,9 +76,6 @@ class CustomToolManager:
|
|||
raw_schema = tool_to_function_schema(tool)
|
||||
function_name = raw_schema["function"]["name"]
|
||||
|
||||
# Cache the tool for later execution
|
||||
self._tools_cache[function_name] = (tool, raw_schema)
|
||||
|
||||
# Convert to FunctionSchema object for compatibility with update_llm_context
|
||||
func_schema = get_function_schema(
|
||||
function_name,
|
||||
|
|
@ -117,9 +117,6 @@ class CustomToolManager:
|
|||
schema = tool_to_function_schema(tool)
|
||||
function_name = schema["function"]["name"]
|
||||
|
||||
# Cache the tool for potential later use
|
||||
self._tools_cache[function_name] = (tool, schema)
|
||||
|
||||
# Create and register the handler
|
||||
handler = self._create_handler(tool, function_name)
|
||||
self._engine.llm.register_function(function_name, handler)
|
||||
|
|
@ -133,7 +130,7 @@ class CustomToolManager:
|
|||
logger.error(f"Failed to register custom tool handlers: {e}")
|
||||
|
||||
def _create_handler(self, tool: Any, function_name: str):
|
||||
"""Create a handler function for a custom tool.
|
||||
"""Create a handler function for a tool based on its category.
|
||||
|
||||
Args:
|
||||
tool: The ToolModel instance
|
||||
|
|
@ -142,17 +139,29 @@ class CustomToolManager:
|
|||
Returns:
|
||||
Async handler function for the tool
|
||||
"""
|
||||
# Run LLM after tool execution to continue conversation
|
||||
properties = FunctionCallResultProperties(run_llm=True)
|
||||
if tool.category == ToolCategory.END_CALL.value:
|
||||
return self._create_end_call_handler(tool, function_name)
|
||||
|
||||
async def custom_tool_handler(
|
||||
return self._create_http_tool_handler(tool, function_name)
|
||||
|
||||
def _create_http_tool_handler(self, tool: Any, function_name: str):
|
||||
"""Create a handler function for an HTTP API tool.
|
||||
|
||||
Args:
|
||||
tool: The ToolModel instance
|
||||
function_name: The function name used by the LLM
|
||||
|
||||
Returns:
|
||||
Async handler function for the HTTP API tool
|
||||
"""
|
||||
|
||||
async def http_tool_handler(
|
||||
function_call_params: FunctionCallParams,
|
||||
) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: {function_name}")
|
||||
logger.info(f"HTTP Tool EXECUTED: {function_name}")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
|
||||
try:
|
||||
# Execute the HTTP API tool
|
||||
result = await execute_http_tool(
|
||||
tool=tool,
|
||||
arguments=function_call_params.arguments,
|
||||
|
|
@ -160,30 +169,66 @@ class CustomToolManager:
|
|||
organization_id=self._organization_id,
|
||||
)
|
||||
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Custom tool '{function_name}' execution failed: {e}")
|
||||
logger.error(f"HTTP tool '{function_name}' execution failed: {e}")
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": str(e)},
|
||||
properties=properties,
|
||||
{"status": "error", "error": str(e)}
|
||||
)
|
||||
|
||||
return custom_tool_handler
|
||||
return http_tool_handler
|
||||
|
||||
def get_cached_tool(self, function_name: str) -> Optional[tuple[Any, dict]]:
|
||||
"""Get a cached tool by its function name.
|
||||
def _create_end_call_handler(self, tool: Any, function_name: str):
|
||||
"""Create a handler function for an end call tool.
|
||||
|
||||
Args:
|
||||
tool: The ToolModel instance
|
||||
function_name: The function name used by the LLM
|
||||
|
||||
Returns:
|
||||
Tuple of (tool, schema) if found, None otherwise
|
||||
Async handler function for the end call tool
|
||||
"""
|
||||
return self._tools_cache.get(function_name)
|
||||
# Don't run LLM after end call - we're terminating
|
||||
properties = FunctionCallResultProperties(run_llm=False)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the tools cache."""
|
||||
self._tools_cache.clear()
|
||||
async def end_call_handler(
|
||||
function_call_params: FunctionCallParams,
|
||||
) -> None:
|
||||
logger.info(f"End Call Tool EXECUTED: {function_name}")
|
||||
|
||||
try:
|
||||
# Get the end call configuration
|
||||
config = tool.definition.get("config", {})
|
||||
message_type = config.get("messageType", "none")
|
||||
custom_message = config.get("customMessage", "")
|
||||
|
||||
# Send result callback first
|
||||
await function_call_params.result_callback(
|
||||
{"status": "success", "action": "ending_call"},
|
||||
properties=properties,
|
||||
)
|
||||
|
||||
if message_type == "custom" and custom_message:
|
||||
# Queue the custom message to be spoken
|
||||
logger.info(f"Playing custom goodbye message: {custom_message}")
|
||||
await self._engine.task.queue_frame(TTSSpeakFrame(custom_message))
|
||||
# End the call after the message (not immediately)
|
||||
await self._engine.send_end_task_frame(
|
||||
END_CALL_TOOL_REASON, abort_immediately=False
|
||||
)
|
||||
else:
|
||||
# No message - end call immediately
|
||||
logger.info("Ending call immediately (no goodbye message)")
|
||||
await self._engine.send_end_task_frame(
|
||||
END_CALL_TOOL_REASON, abort_immediately=True
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"End call tool '{function_name}' execution failed: {e}")
|
||||
# Still try to end the call even if there's an error
|
||||
await self._engine.send_end_task_frame(
|
||||
END_CALL_TOOL_REASON, abort_immediately=True
|
||||
)
|
||||
|
||||
return end_call_handler
|
||||
|
|
|
|||
|
|
@ -113,7 +113,6 @@ class WorkflowGraph:
|
|||
# )
|
||||
|
||||
errors.extend(self._assert_start_node())
|
||||
errors.extend(self._assert_end_node())
|
||||
errors.extend(self._assert_connection_counts())
|
||||
errors.extend(self._assert_global_node())
|
||||
errors.extend(self._assert_node_configs())
|
||||
|
|
@ -174,30 +173,6 @@ class WorkflowGraph:
|
|||
)
|
||||
return errors
|
||||
|
||||
def _assert_end_node(self):
|
||||
errors: list[WorkflowError] = []
|
||||
end_node = [n for n in self.nodes.values() if n.data.is_end]
|
||||
if not end_node:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message="Workflow must have exactly one end node",
|
||||
)
|
||||
)
|
||||
|
||||
elif len(end_node) > 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message="Workflow must have exactly one end node",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def _assert_connection_counts(self):
|
||||
errors: list[WorkflowError] = []
|
||||
|
||||
|
|
@ -235,13 +210,13 @@ class WorkflowGraph:
|
|||
)
|
||||
)
|
||||
case NodeType.agentNode:
|
||||
if in_d < 1 or out_d < 1:
|
||||
if in_d < 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.node,
|
||||
id=n.id,
|
||||
field=None,
|
||||
message=f"Worker must have at least 1 incoming and 1 outgoing edge",
|
||||
message=f"Worker must have at least 1 incoming edge",
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ class MockToolModel:
|
|||
tool_uuid: str
|
||||
name: str
|
||||
description: str
|
||||
category: str
|
||||
definition: Dict[str, Any]
|
||||
|
||||
|
||||
|
|
@ -55,6 +56,7 @@ class TestToolToFunctionSchema:
|
|||
tool_uuid="test-uuid-1",
|
||||
name="Get Weather",
|
||||
description="Get current weather for a location",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -97,6 +99,7 @@ class TestToolToFunctionSchema:
|
|||
tool_uuid="test-uuid-2",
|
||||
name="Book Appointment",
|
||||
description="Book an appointment with the service",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -145,6 +148,7 @@ class TestToolToFunctionSchema:
|
|||
tool_uuid="test-uuid-3",
|
||||
name="Get User's Account Info!!!",
|
||||
description="Get account information",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -167,6 +171,7 @@ class TestToolToFunctionSchema:
|
|||
tool_uuid="test-uuid-4",
|
||||
name="Ping Server",
|
||||
description="Check if server is alive",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -188,6 +193,7 @@ class TestToolToFunctionSchema:
|
|||
tool_uuid="test-uuid-5",
|
||||
name="My Tool",
|
||||
description=None,
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -213,6 +219,7 @@ class TestExecuteHttpTool:
|
|||
tool_uuid="test-uuid",
|
||||
name="Create User",
|
||||
description="Create a new user",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -257,6 +264,7 @@ class TestExecuteHttpTool:
|
|||
tool_uuid="test-uuid",
|
||||
name="Search Users",
|
||||
description="Search for users",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -297,6 +305,7 @@ class TestExecuteHttpTool:
|
|||
tool_uuid="test-uuid",
|
||||
name="Delete User",
|
||||
description="Delete a user",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -336,6 +345,7 @@ class TestExecuteHttpTool:
|
|||
tool_uuid="test-uuid",
|
||||
name="Slow API",
|
||||
description="A slow API call",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -368,6 +378,7 @@ class TestExecuteHttpTool:
|
|||
tool_uuid="test-uuid",
|
||||
name="API with Headers",
|
||||
description="API that requires headers",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -406,6 +417,7 @@ class TestExecuteHttpTool:
|
|||
tool_uuid="test-uuid",
|
||||
name="Authenticated API",
|
||||
description="API that requires authentication",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -457,6 +469,7 @@ class TestExecuteHttpTool:
|
|||
tool_uuid="test-uuid",
|
||||
name="API with Credential",
|
||||
description="API with credential configured",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -724,6 +737,7 @@ class TestCustomToolManagerUnit:
|
|||
tool_uuid="uuid-1",
|
||||
name="Test Tool",
|
||||
description="A test tool",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -791,6 +805,7 @@ class TestCustomToolManagerUnit:
|
|||
tool_uuid="uuid-1",
|
||||
name="API Call",
|
||||
description="Make an API call",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
|
|
@ -846,53 +861,6 @@ class TestCustomToolManagerUnit:
|
|||
# Verify result was returned
|
||||
assert result_received["status"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tools_cache_prevents_duplicate_fetches(self):
|
||||
"""Test that tools are cached after first fetch."""
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
|
||||
mock_engine = Mock()
|
||||
mock_engine._workflow_run_id = 1
|
||||
mock_engine._call_context_vars = {}
|
||||
mock_engine.llm = Mock()
|
||||
mock_engine.llm.register_function = Mock()
|
||||
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
mock_tool = MockToolModel(
|
||||
tool_uuid="uuid-1",
|
||||
name="Cached Tool",
|
||||
description="A tool that should be cached",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {"method": "GET", "url": "https://api.example.com"},
|
||||
},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=[mock_tool])
|
||||
|
||||
# First call should fetch from DB
|
||||
await manager.get_tool_schemas(["uuid-1"])
|
||||
|
||||
# Verify tool is now in cache
|
||||
cached = manager.get_cached_tool("cached_tool")
|
||||
assert cached is not None
|
||||
assert cached[0].tool_uuid == "uuid-1"
|
||||
|
||||
# Clear cache and verify it's empty
|
||||
manager.clear_cache()
|
||||
cached = manager.get_cached_tool("cached_tool")
|
||||
assert cached is None
|
||||
|
||||
|
||||
class TestUpdateLLMContext:
|
||||
"""Tests for update_llm_context function."""
|
||||
|
|
|
|||
|
|
@ -206,31 +206,6 @@ class TestCustomToolManagerContextIntegration:
|
|||
assert "get_current_time" in tool_names
|
||||
assert "get_weather" in tool_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tools_cached_after_first_fetch(self, mock_engine, sample_tools):
|
||||
"""Test that CustomToolManager caches tools after first fetch."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]])
|
||||
|
||||
# First fetch
|
||||
await manager.get_tool_schemas(["weather-uuid-123"])
|
||||
|
||||
# Verify tool is cached (cache stores raw schema dict, not FunctionSchema)
|
||||
cached = manager.get_cached_tool("get_weather")
|
||||
assert cached is not None
|
||||
tool, raw_schema = cached
|
||||
assert tool.tool_uuid == "weather-uuid-123"
|
||||
assert raw_schema["function"]["name"] == "get_weather"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_preserves_function_call_history(
|
||||
self, mock_engine, sample_tools
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import pytest
|
|||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT, MockTransportProcessor
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
|
|
@ -123,6 +124,7 @@ async def run_pipeline_with_tool_calls(
|
|||
# Small delay to let runner start
|
||||
await asyncio.sleep(0.01)
|
||||
await engine.initialize()
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
|
||||
# Run both concurrently
|
||||
await asyncio.gather(run_pipeline(), initialize_engine())
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import pytest
|
|||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import MockTransportProcessor
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
|
|
@ -128,6 +129,7 @@ async def run_pipeline_with_user_idle(
|
|||
# Small delay to let runner start
|
||||
await asyncio.sleep(0.01)
|
||||
await engine.initialize()
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
|
||||
# Calculate total wait time:
|
||||
# - Initial bot speech
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue