mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
Merge remote-tracking branch 'origin/main' into feat/text-chat
This commit is contained in:
commit
129a6d700c
160 changed files with 9287 additions and 3935 deletions
|
|
@ -6,35 +6,40 @@ FastAPI backend for the Dograh voice AI platform.
|
|||
|
||||
```
|
||||
api/
|
||||
├── app.py # Application entry point, FastAPI setup
|
||||
├── routes/ # API endpoint handlers
|
||||
├── services/ # Business logic and integrations
|
||||
├── services/ # Domain logic, runtime systems, and extension seams
|
||||
├── db/ # Database models and data access
|
||||
├── schemas/ # Pydantic request/response schemas
|
||||
├── tasks/ # Background jobs (ARQ)
|
||||
├── utils/ # Utility functions
|
||||
├── tasks/ # Background jobs and post-call work
|
||||
├── mcp_server/ # MCP surface exposed by the backend
|
||||
├── utils/ # Shared utilities
|
||||
├── alembic/ # Database migrations
|
||||
├── constants.py # Environment variables and constants
|
||||
└── tests/ # Test suite
|
||||
```
|
||||
|
||||
## Where to Find Things
|
||||
|
||||
| Looking for... | Go to... |
|
||||
| ---------------------- | ------------------------------------------------------------------------ |
|
||||
| API endpoints | `routes/` - each file is a router module, aggregated in `routes/main.py` |
|
||||
| Business logic | `services/` - organized by domain (telephony, workflow, campaign, etc.) |
|
||||
| Database models | `db/models.py` |
|
||||
| Database queries | `db/*_client.py` files (repository pattern) |
|
||||
| Request/response types | `schemas/` |
|
||||
| Background tasks | `tasks/` - uses ARQ for async job processing |
|
||||
| Environment config | `constants.py` |
|
||||
| Looking for... | Go to... |
|
||||
| ---------------------------- | ----------------------------------------------------------------------------- |
|
||||
| API endpoints | `routes/` - domain routers mounted under `/api/v1` |
|
||||
| Workflow graph and node data | `services/workflow/` |
|
||||
| Live pipeline runtime | `services/pipecat/` |
|
||||
| Telephony providers/call flow| `services/telephony/` |
|
||||
| Third-party integrations | `services/integrations/` |
|
||||
| Campaign and other domains | `services/` |
|
||||
| Database access | `db/` |
|
||||
| Request/response types | `schemas/` |
|
||||
| Background jobs | `tasks/` |
|
||||
| MCP backend surface | `mcp_server/` |
|
||||
| Tests | `tests/` |
|
||||
|
||||
## API Structure
|
||||
|
||||
- All routes are mounted at `/api/v1` prefix
|
||||
- Routes are organized by domain (workflow, telephony, campaign, user, etc.)
|
||||
- `routes/main.py` aggregates all routers
|
||||
- Routes are organized by domain under `routes/`
|
||||
- Workflow execution spans `services/workflow/`, `services/pipecat/`, and `tasks/`
|
||||
- Telephony is a full subsystem under `services/telephony/`, with provider-specific packages under `services/telephony/providers/`
|
||||
- Integrations extend through `services/integrations/`; package-specific rules should live in that subtree's own `AGENTS.md`
|
||||
|
||||
## Database Migrations
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ RUN pip install --user --no-cache-dir -r requirements.txt && \
|
|||
|
||||
# Copy and install pipecat from local submodule
|
||||
COPY pipecat /tmp/pipecat
|
||||
RUN pip install --user --no-cache-dir '/tmp/pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb]' && \
|
||||
RUN pip install --user --no-cache-dir '/tmp/pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb,mcp]' && \
|
||||
# Swap opencv-python (pulled by pipecat[webrtc]) for opencv-python-headless
|
||||
# to drop X11/Qt dependencies that otherwise require libxcb etc. in runner.
|
||||
pip uninstall -y opencv-python && \
|
||||
|
|
|
|||
64
api/alembic/versions/0a1b2c3d4e5f_add_mcp_in_toolcategory.py
Normal file
64
api/alembic/versions/0a1b2c3d4e5f_add_mcp_in_toolcategory.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
"""add mcp in ToolCategory
|
||||
|
||||
Revision ID: 0a1b2c3d4e5f
|
||||
Revises: 4c1f1e3e8ef2
|
||||
Create Date: 2026-05-16 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from alembic_postgresql_enum import TableReference
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0a1b2c3d4e5f"
|
||||
down_revision: Union[str, None] = "4c1f1e3e8ef2"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.sync_enum_values(
|
||||
enum_schema="public",
|
||||
enum_name="tool_category",
|
||||
new_values=[
|
||||
"http_api",
|
||||
"end_call",
|
||||
"transfer_call",
|
||||
"calculator",
|
||||
"native",
|
||||
"integration",
|
||||
"mcp",
|
||||
],
|
||||
affected_columns=[
|
||||
TableReference(
|
||||
table_schema="public", table_name="tools", column_name="category"
|
||||
)
|
||||
],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.sync_enum_values(
|
||||
enum_schema="public",
|
||||
enum_name="tool_category",
|
||||
new_values=[
|
||||
"http_api",
|
||||
"end_call",
|
||||
"transfer_call",
|
||||
"calculator",
|
||||
"native",
|
||||
"integration",
|
||||
],
|
||||
affected_columns=[
|
||||
TableReference(
|
||||
table_schema="public", table_name="tools", column_name="category"
|
||||
)
|
||||
],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
"""rename integrations organisation_id to organization_id
|
||||
|
||||
Revision ID: 19d2a4b6c8ef
|
||||
Revises: 0a1b2c3d4e5f
|
||||
Create Date: 2026-05-19 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "19d2a4b6c8ef"
|
||||
down_revision: Union[str, None] = "0a1b2c3d4e5f"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"integrations",
|
||||
"organisation_id",
|
||||
new_column_name="organization_id",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"integrations",
|
||||
"organization_id",
|
||||
new_column_name="organisation_id",
|
||||
)
|
||||
|
|
@ -143,11 +143,4 @@ FORCE_TURN_RELAY = os.getenv("FORCE_TURN_RELAY", "false").lower() == "true"
|
|||
OSS_JWT_SECRET = os.getenv("OSS_JWT_SECRET", "change-me-in-production")
|
||||
OSS_JWT_EXPIRY_HOURS = int(os.getenv("OSS_JWT_EXPIRY_HOURS", "720")) # 30 days
|
||||
|
||||
# REMOVE-AFTER 2026-05-15: transitional flag. When True, Telnyx webhook
|
||||
# signature verification is skipped for configs that have no
|
||||
# webhook_public_key set (existing configs predating the field). Set in prod
|
||||
# through 2026-05-15 to give users time to add their key; once removed,
|
||||
# configs without a key will fail signature verification.
|
||||
TELNYX_WEBHOOK_VERIFICATION_OPTIONAL = (
|
||||
os.getenv("TELNYX_WEBHOOK_VERIFICATION_OPTIONAL", "false").lower() == "true"
|
||||
)
|
||||
TUNER_BASE_URL = os.getenv("TUNER_BASE_URL", "https://api.usetuner.ai")
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class IntegrationClient(BaseDBClient):
|
|||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(IntegrationModel).where(
|
||||
IntegrationModel.organisation_id == organization_id
|
||||
IntegrationModel.organization_id == organization_id
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
|
@ -23,7 +23,7 @@ class IntegrationClient(BaseDBClient):
|
|||
self,
|
||||
integration_id: str,
|
||||
provider: str,
|
||||
organisation_id: int,
|
||||
organization_id: int,
|
||||
connection_details: dict,
|
||||
created_by: int = None,
|
||||
is_active: bool = True,
|
||||
|
|
@ -32,7 +32,7 @@ class IntegrationClient(BaseDBClient):
|
|||
async with self.async_session() as session:
|
||||
new_integration = IntegrationModel(
|
||||
integration_id=integration_id,
|
||||
organisation_id=organisation_id,
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
is_active=is_active,
|
||||
provider=provider,
|
||||
|
|
@ -96,7 +96,7 @@ class IntegrationClient(BaseDBClient):
|
|||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(IntegrationModel).where(
|
||||
IntegrationModel.organisation_id == organization_id,
|
||||
IntegrationModel.organization_id == organization_id,
|
||||
IntegrationModel.is_active == True,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -292,8 +292,10 @@ class IntegrationModel(Base):
|
|||
__tablename__ = "integrations"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
integration_id = Column(String, nullable=False, index=True) # Nango Connection ID
|
||||
organisation_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
|
||||
integration_id = Column(
|
||||
String, nullable=False, index=True
|
||||
) # External connection ID
|
||||
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
|
||||
provider = Column(String, nullable=False)
|
||||
created_by = Column(Integer, ForeignKey("users.id"))
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
|
|
@ -598,8 +600,8 @@ class CampaignModel(Base):
|
|||
)
|
||||
|
||||
# Source configuration
|
||||
source_type = Column(String, nullable=False, default="google-sheet")
|
||||
source_id = Column(String, nullable=False) # Sheet URL
|
||||
source_type = Column(String, nullable=False, default="csv")
|
||||
source_id = Column(String, nullable=False) # CSV file key
|
||||
|
||||
# State management
|
||||
state = Column(
|
||||
|
|
|
|||
|
|
@ -134,6 +134,7 @@ class ToolCategory(Enum):
|
|||
CALCULATOR = "calculator" # Built-in calculator tool
|
||||
NATIVE = "native" # Built-in integrations (future: dtmf_input)
|
||||
INTEGRATION = "integration" # Third-party integrations (future: Google Calendar, Salesforce, etc.)
|
||||
MCP = "mcp" # Customer-provided MCP server exposing a tool catalog
|
||||
|
||||
|
||||
class ToolStatus(Enum):
|
||||
|
|
|
|||
|
|
@ -18,7 +18,9 @@ async def authenticate_mcp_request() -> UserModel:
|
|||
the `langfuse.user.id` / `langfuse.session.id` attributes make the
|
||||
span filterable in the Langfuse UI.
|
||||
"""
|
||||
headers = get_http_headers()
|
||||
# FastMCP strips Authorization by default unless explicitly included.
|
||||
# Preserve it here so Bearer API keys work for MCP tool invocations.
|
||||
headers = get_http_headers(include={"authorization"})
|
||||
api_key = headers.get("x-api-key")
|
||||
if not api_key:
|
||||
auth = headers.get("authorization", "")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,14 @@ handling, hard constraints). Design-level per-field guidance belongs in
|
|||
each `PropertySpec.llm_hint`; it flows out through `get_node_type` and
|
||||
doesn't need to be repeated here.
|
||||
|
||||
Tool names, parameters, and per-tool `error_code` values are NOT
|
||||
authoritative here — they reach the model dynamically via `tools/list`
|
||||
from each tool's own signature and docstring. Reference tools by bare
|
||||
name and describe orchestration; do not restate signatures (they drift)
|
||||
or re-enumerate error codes (document those on the tool itself).
|
||||
`test_mcp_instructions_drift.py` fails if this guide names a tool that
|
||||
is not registered, or if a tool's error codes aren't in its docstring.
|
||||
|
||||
Extend based on real LLM failures — every bullet below ideally maps to a
|
||||
mistake the system has seen at least once.
|
||||
"""
|
||||
|
|
@ -16,18 +24,23 @@ You build and edit Dograh voice-AI workflows by emitting TypeScript that uses th
|
|||
|
||||
## Call order
|
||||
|
||||
### Reading documentation
|
||||
1. `search_docs` — use first for keyword or acronym lookup when the user is asking how Dograh works or how to configure something.
|
||||
2. `read_doc` — fetch the full page once one result looks likely. Prefer this over reasoning from search summaries alone.
|
||||
3. `list_docs` — use when the user wants to browse a topic area or when search terms are too vague. Call it with no arguments for the top-level sections; returned section paths feed back into `list_docs`, returned page paths feed into `read_doc`.
|
||||
|
||||
### Editing an existing workflow
|
||||
1. `list_workflows` — locate the target workflow.
|
||||
2. `get_workflow_code(workflow_id)` — fetch the current source.
|
||||
3. (optional) `list_node_types` / `get_node_type(name)` — consult before adding or editing a node type whose fields aren't already visible in the current code.
|
||||
2. `get_workflow_code` — fetch the current source for that workflow.
|
||||
3. (optional) `list_node_types` / `get_node_type` — consult before adding or editing a node type whose fields aren't already visible in the current code.
|
||||
4. Mutate the code in place. Preserve existing nodes, edges, and variable names unless the task requires removing or renaming them.
|
||||
5. `save_workflow(workflow_id, code)` — persist as a new draft. The published version is untouched.
|
||||
5. `save_workflow` — persist as a new draft. The published version is untouched.
|
||||
|
||||
### Creating a new workflow
|
||||
1. Create a simple 1-node workflow with only `startCall`. The user can iteratively add complexity by editing it.
|
||||
2. `list_node_types` / `get_node_type(name)` — consult to learn the fields available on the node types you intend to use.
|
||||
2. `list_node_types` / `get_node_type` — consult to learn the fields available on the node types you intend to use.
|
||||
3. Author SDK TypeScript from scratch. The `new Workflow({ name: "..." })` call is required — `name` becomes the workflow's display name.
|
||||
4. `create_workflow(code)` — persists a new workflow as version 1 (published). Returns the new `workflow_id`. For subsequent edits use `save_workflow(workflow_id, code)` (which writes a draft).
|
||||
4. `create_workflow` — persists a new workflow as version 1 (published). Returns the new `workflow_id`. For subsequent edits use `save_workflow` (which writes a draft).
|
||||
|
||||
## Allowed source shape
|
||||
|
||||
|
|
@ -68,14 +81,7 @@ Example:
|
|||
|
||||
## Iterating on errors
|
||||
|
||||
`save_workflow` and `create_workflow` return one of:
|
||||
- `parse_error` — Disallowed construct (see grammar above) or malformed TypeScript.
|
||||
- `validation_error` — Node data failed spec validation (unknown field, missing required, wrong type, bad `options` value).
|
||||
- `graph_validation` — Structural rule broken (missing startCall, unreachable node, edge to/from wrong node type).
|
||||
- `missing_name` — (`create_workflow` only) `new Workflow({ name })` is absent or empty.
|
||||
- `bridge_error` — Internal; retry once, then surface to the user.
|
||||
|
||||
Every error carries `line` and `column`. Fix at that location and resubmit the **complete source** — this tool does not accept patches.
|
||||
A failed `save_workflow` / `create_workflow` returns a result with `saved`/`created` set to false, a machine-readable `error_code`, and a human-readable `error` message — carrying `line` and `column` when the problem is locatable in your source. The full set of `error_code` values and their meanings is documented on each tool (visible in its description). Read the `error` message, fix at the reported location, and resubmit the **complete source** — these tools do not accept patches. If a failure looks internal or transient rather than a problem with your code, retry once before surfacing it to the user.
|
||||
|
||||
## Field conventions
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from fastmcp import FastMCP
|
||||
from mcp.types import ToolAnnotations
|
||||
|
||||
from api.mcp_server.instructions import DOGRAH_MCP_INSTRUCTIONS
|
||||
from api.mcp_server.tools.catalog import (
|
||||
|
|
@ -8,6 +9,7 @@ from api.mcp_server.tools.catalog import (
|
|||
list_tools,
|
||||
)
|
||||
from api.mcp_server.tools.create_workflow import create_workflow
|
||||
from api.mcp_server.tools.docs_search import list_docs, read_doc, search_docs
|
||||
from api.mcp_server.tools.get_workflow_code import get_workflow_code
|
||||
from api.mcp_server.tools.node_types import get_node_type, list_node_types
|
||||
from api.mcp_server.tools.save_workflow import save_workflow
|
||||
|
|
@ -29,3 +31,13 @@ for _tool in (
|
|||
save_workflow,
|
||||
):
|
||||
mcp.tool(_tool)
|
||||
|
||||
_DOCS_TOOL_ANNOTATIONS = ToolAnnotations(
|
||||
readOnlyHint=True,
|
||||
idempotentHint=True,
|
||||
destructiveHint=False,
|
||||
openWorldHint=False,
|
||||
)
|
||||
|
||||
for _tool in (list_docs, read_doc, search_docs):
|
||||
mcp.tool(_tool, annotations=_DOCS_TOOL_ANNOTATIONS)
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ Execution flow mirrors `save_workflow`:
|
|||
4. Persist via `db_client.create_workflow` — workflow row + v1
|
||||
published definition in a single transaction.
|
||||
|
||||
Error codes surfaced to the LLM match `save_workflow`. An additional
|
||||
`missing_name` error is returned when the source omits
|
||||
`new Workflow({ name: "..." })` — the name is required and there is no
|
||||
prior workflow to fall back to.
|
||||
Each failure path returns an `error_code` via `_error_result`. Those
|
||||
codes and their meanings are documented in the `create_workflow`
|
||||
docstring (the description shipped to the LLM via `tools/list`); keep the
|
||||
two in sync — `test_mcp_instructions_drift.py` enforces it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -86,6 +86,22 @@ async def create_workflow(code: str) -> dict[str, Any]:
|
|||
On success the new workflow is published as version 1. Use
|
||||
`save_workflow(workflow_id, code)` for subsequent edits — those go to
|
||||
a draft.
|
||||
|
||||
On failure the result has `created: false`, a machine-readable
|
||||
`error_code`, and a human-readable `error` (with file:line:column
|
||||
where the problem is locatable). Resubmit the full corrected source —
|
||||
patches are not accepted. Possible `error_code` values:
|
||||
- `parse_error` — disallowed construct or malformed TypeScript.
|
||||
- `validation_error` — node data failed spec validation (unknown
|
||||
field, missing required, wrong type, option out of range).
|
||||
- `schema_validation` — wire-format (DTO) rejection; rare.
|
||||
- `graph_validation` — structural rule broken (e.g. no start node,
|
||||
unreachable node, edge to/from the wrong node type).
|
||||
- `missing_name` — `new Workflow({ name })` is absent or empty; the
|
||||
name is required and there is no prior workflow to fall back to.
|
||||
- `trigger_path_conflict` — a trigger node's path is already used by
|
||||
another workflow in this organization; rename it and resubmit.
|
||||
- `bridge_error` — internal/transient; retry once, then surface it.
|
||||
"""
|
||||
user = await authenticate_mcp_request()
|
||||
|
||||
|
|
|
|||
704
api/mcp_server/tools/docs_search.py
Normal file
704
api/mcp_server/tools/docs_search.py
Normal file
|
|
@ -0,0 +1,704 @@
|
|||
"""MCP docs discovery tools over the Mintlify docs tree.
|
||||
|
||||
The docs surface is intentionally split into three steps:
|
||||
|
||||
- ``list_docs`` for lightweight navigation over the published hierarchy
|
||||
- ``search_docs`` for keyword lookup across the visible docs catalog
|
||||
- ``read_doc`` for the full content of one chosen page (or one section)
|
||||
|
||||
The runtime index is derived from ``docs/docs.json`` plus the referenced
|
||||
``.mdx``/``.md`` files. That keeps navigation, ordering, and visibility in
|
||||
sync with the published docs rather than indexing every file under ``docs/``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, replace
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
from api.mcp_server.tracing import traced_tool
|
||||
|
||||
DOCS_SEARCH_MAX_LIMIT = 25
|
||||
DOCS_LIST_MAX_DEPTH = 3
|
||||
_ROOT_SECTION_PATH = "__root__"
|
||||
|
||||
_TOKEN_RE = re.compile(r"[A-Za-z0-9_]+")
|
||||
_FRONTMATTER_RE = re.compile(r"\A---\s*\n(.*?)\n---\s*\n?", re.DOTALL)
|
||||
_HEADING_RE = re.compile(r"^(#{1,6})\s+(.*?)\s*$", re.MULTILINE)
|
||||
_STOPWORDS = {
|
||||
"a",
|
||||
"an",
|
||||
"and",
|
||||
"are",
|
||||
"at",
|
||||
"be",
|
||||
"by",
|
||||
"can",
|
||||
"do",
|
||||
"for",
|
||||
"from",
|
||||
"how",
|
||||
"i",
|
||||
"if",
|
||||
"in",
|
||||
"is",
|
||||
"it",
|
||||
"me",
|
||||
"my",
|
||||
"of",
|
||||
"on",
|
||||
"or",
|
||||
"the",
|
||||
"to",
|
||||
"what",
|
||||
"when",
|
||||
"where",
|
||||
"with",
|
||||
"you",
|
||||
"your",
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocSection:
|
||||
title: str
|
||||
slug: str
|
||||
level: int
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocPage:
|
||||
path: str
|
||||
file_path: str
|
||||
title: str
|
||||
description: str
|
||||
llm_hint: str
|
||||
aliases: tuple[str, ...]
|
||||
breadcrumb: tuple[str, ...]
|
||||
content: str
|
||||
sections: tuple[DocSection, ...]
|
||||
order: int
|
||||
|
||||
def breadcrumb_text(self) -> str:
|
||||
return " > ".join(self.breadcrumb)
|
||||
|
||||
def routing_hint(self) -> str:
|
||||
return self.llm_hint or self.description
|
||||
|
||||
def to_catalog_dict(self, section: DocSection | None = None) -> dict:
|
||||
data = {
|
||||
"kind": "page",
|
||||
"path": self.path,
|
||||
"title": self.title,
|
||||
"breadcrumb": self.breadcrumb_text(),
|
||||
"llm_hint": self.routing_hint(),
|
||||
}
|
||||
if section is not None:
|
||||
data["section_title"] = section.title
|
||||
data["section_slug"] = section.slug
|
||||
return _compact_dict(data)
|
||||
|
||||
def to_read_dict(self, section: DocSection | None = None) -> dict:
|
||||
active_section = section
|
||||
content = self.content
|
||||
if active_section is not None:
|
||||
content = active_section.content
|
||||
|
||||
return _compact_dict(
|
||||
{
|
||||
"path": self.path,
|
||||
"title": self.title,
|
||||
"breadcrumb": self.breadcrumb_text(),
|
||||
"llm_hint": self.routing_hint(),
|
||||
"section_title": active_section.title if active_section else None,
|
||||
"section_slug": active_section.slug if active_section else None,
|
||||
"content": content,
|
||||
"sections": [
|
||||
{"title": sec.title, "slug": sec.slug}
|
||||
for sec in self.sections
|
||||
if sec.title and sec.slug
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NavSection:
|
||||
path: str
|
||||
title: str
|
||||
breadcrumb: tuple[str, ...]
|
||||
children: tuple[tuple[str, str], ...]
|
||||
descendant_page_count: int = 0
|
||||
|
||||
def breadcrumb_text(self) -> str:
|
||||
return " > ".join(self.breadcrumb)
|
||||
|
||||
def to_mcp_dict(self) -> dict:
|
||||
hint = None
|
||||
if self.descendant_page_count:
|
||||
hint = f"Browse {self.descendant_page_count} docs in this section."
|
||||
return _compact_dict(
|
||||
{
|
||||
"kind": "section",
|
||||
"path": self.path,
|
||||
"title": self.title,
|
||||
"breadcrumb": self.breadcrumb_text(),
|
||||
"llm_hint": hint,
|
||||
"has_children": bool(self.children),
|
||||
"child_count": len(self.children),
|
||||
"page_count": self.descendant_page_count,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocsIndex:
|
||||
pages_by_path: dict[str, DocPage]
|
||||
sections_by_path: dict[str, NavSection]
|
||||
|
||||
|
||||
def _compact_dict(data: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
key: value for key, value in data.items() if value not in (None, "", [], (), {})
|
||||
}
|
||||
|
||||
|
||||
def _slugify(value: str) -> str:
|
||||
slug = re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-")
|
||||
return slug or "section"
|
||||
|
||||
|
||||
def _coerce_docs_root(candidate: Path) -> Path | None:
|
||||
candidate = candidate.expanduser().resolve()
|
||||
if (candidate / "docs.json").is_file():
|
||||
return candidate
|
||||
nested = candidate / "docs"
|
||||
if (nested / "docs.json").is_file():
|
||||
return nested
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_docs_root() -> Path | None:
|
||||
"""Return the path to the on-disk docs tree, or None if not found."""
|
||||
override = os.environ.get("DOGRAH_DOCS_PATH")
|
||||
if override:
|
||||
resolved = _coerce_docs_root(Path(override))
|
||||
if resolved is not None:
|
||||
return resolved
|
||||
|
||||
docker_default = _coerce_docs_root(Path("/app/docs"))
|
||||
if docker_default is not None:
|
||||
return docker_default
|
||||
|
||||
for parent in Path(__file__).resolve().parents:
|
||||
resolved = _coerce_docs_root(parent / "docs")
|
||||
if resolved is not None:
|
||||
return resolved
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _split_frontmatter(contents: str) -> tuple[dict[str, Any], str]:
|
||||
match = _FRONTMATTER_RE.match(contents)
|
||||
if not match:
|
||||
return {}, contents
|
||||
try:
|
||||
frontmatter = yaml.safe_load(match.group(1)) or {}
|
||||
except yaml.YAMLError:
|
||||
return {}, contents
|
||||
if not isinstance(frontmatter, dict):
|
||||
frontmatter = {}
|
||||
return frontmatter, contents[match.end() :].lstrip("\n")
|
||||
|
||||
|
||||
def _strip_frontmatter(contents: str) -> str:
|
||||
"""Drop the YAML frontmatter block from a docs page body."""
|
||||
return _split_frontmatter(contents)[1]
|
||||
|
||||
|
||||
def _clean_heading_text(raw: str) -> str:
|
||||
text = re.sub(r"\s*\{#.*\}\s*$", "", raw.strip())
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
def _extract_page_title(contents: str, fallback: str) -> str:
|
||||
"""Pull a human-readable title for a docs page."""
|
||||
frontmatter, body = _split_frontmatter(contents)
|
||||
title = frontmatter.get("title")
|
||||
if isinstance(title, str) and title.strip():
|
||||
return title.strip()
|
||||
|
||||
match = _HEADING_RE.search(body)
|
||||
if match:
|
||||
return _clean_heading_text(match.group(2))
|
||||
|
||||
return fallback
|
||||
|
||||
|
||||
def _normalize_text(value: Any) -> str:
|
||||
if isinstance(value, str):
|
||||
return " ".join(value.strip().split())
|
||||
return ""
|
||||
|
||||
|
||||
def _normalize_aliases(value: Any) -> tuple[str, ...]:
|
||||
if isinstance(value, str):
|
||||
aliases = [value]
|
||||
elif isinstance(value, list):
|
||||
aliases = [item for item in value if isinstance(item, str)]
|
||||
else:
|
||||
aliases = []
|
||||
return tuple(alias.strip() for alias in aliases if alias.strip())
|
||||
|
||||
|
||||
def _extract_sections(body: str) -> tuple[DocSection, ...]:
|
||||
matches = list(_HEADING_RE.finditer(body))
|
||||
stripped_body = body.strip()
|
||||
if not matches:
|
||||
if not stripped_body:
|
||||
return ()
|
||||
return (
|
||||
DocSection(
|
||||
title="Overview",
|
||||
slug="overview",
|
||||
level=1,
|
||||
content=stripped_body,
|
||||
),
|
||||
)
|
||||
|
||||
sections: list[DocSection] = []
|
||||
preamble = body[: matches[0].start()].strip()
|
||||
if preamble:
|
||||
sections.append(
|
||||
DocSection(
|
||||
title="Overview",
|
||||
slug="overview",
|
||||
level=1,
|
||||
content=preamble,
|
||||
)
|
||||
)
|
||||
|
||||
for index, match in enumerate(matches):
|
||||
start = match.start()
|
||||
end = matches[index + 1].start() if index + 1 < len(matches) else len(body)
|
||||
title = _clean_heading_text(match.group(2))
|
||||
sections.append(
|
||||
DocSection(
|
||||
title=title or "Section",
|
||||
slug=_slugify(title or "section"),
|
||||
level=len(match.group(1)),
|
||||
content=body[start:end].strip(),
|
||||
)
|
||||
)
|
||||
return tuple(sections)
|
||||
|
||||
|
||||
def _tokenize_text(text: str) -> list[str]:
|
||||
return [
|
||||
token
|
||||
for token in _TOKEN_RE.findall(text.lower())
|
||||
if len(token) >= 2 and token not in _STOPWORDS
|
||||
]
|
||||
|
||||
|
||||
def _tokenize_query(query: str) -> list[str]:
|
||||
"""Split a user query into lowercased keyword terms."""
|
||||
seen: set[str] = set()
|
||||
terms: list[str] = []
|
||||
for token in _TOKEN_RE.findall(query.lower()):
|
||||
if len(token) < 2 or token in _STOPWORDS or token in seen:
|
||||
continue
|
||||
seen.add(token)
|
||||
terms.append(token)
|
||||
return terms
|
||||
|
||||
|
||||
def _resolve_doc_file(root: Path, route_path: str) -> Path | None:
|
||||
candidates = (
|
||||
root / f"{route_path}.mdx",
|
||||
root / f"{route_path}.md",
|
||||
root / route_path / "index.mdx",
|
||||
root / route_path / "index.md",
|
||||
)
|
||||
for candidate in candidates:
|
||||
if candidate.is_file():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
def _build_doc_page(
|
||||
root: Path,
|
||||
route_path: str,
|
||||
*,
|
||||
breadcrumb: tuple[str, ...],
|
||||
order: int,
|
||||
) -> DocPage | None:
|
||||
file_path = _resolve_doc_file(root, route_path)
|
||||
if file_path is None:
|
||||
return None
|
||||
try:
|
||||
contents = file_path.read_text(encoding="utf-8")
|
||||
except (OSError, UnicodeDecodeError):
|
||||
return None
|
||||
|
||||
frontmatter, body = _split_frontmatter(contents)
|
||||
fallback = route_path.rsplit("/", 1)[-1].replace("-", " ").title()
|
||||
title = _extract_page_title(contents, fallback=fallback)
|
||||
description = _normalize_text(frontmatter.get("description"))
|
||||
llm_hint = _normalize_text(frontmatter.get("llm_hint"))
|
||||
aliases = _normalize_aliases(frontmatter.get("aliases"))
|
||||
content = body.strip()
|
||||
|
||||
return DocPage(
|
||||
path=route_path,
|
||||
file_path=file_path.relative_to(root).as_posix(),
|
||||
title=title,
|
||||
description=description,
|
||||
llm_hint=llm_hint,
|
||||
aliases=aliases,
|
||||
breadcrumb=breadcrumb,
|
||||
content=content,
|
||||
sections=_extract_sections(content),
|
||||
order=order,
|
||||
)
|
||||
|
||||
|
||||
def _score_counter(counter: Counter[str], term: str, *, weight: int, cap: int) -> int:
|
||||
return min(counter.get(term, 0), cap) * weight
|
||||
|
||||
|
||||
def _normalized_phrase(text: str) -> str:
|
||||
return " ".join(_tokenize_text(text))
|
||||
|
||||
|
||||
def _score_section(section: DocSection, terms: list[str]) -> int:
|
||||
title_counts = Counter(_tokenize_text(section.title))
|
||||
body_counts = Counter(_tokenize_text(section.content))
|
||||
score = 0
|
||||
matched_terms = 0
|
||||
for term in terms:
|
||||
term_score = _score_counter(
|
||||
title_counts, term, weight=7, cap=2
|
||||
) + _score_counter(body_counts, term, weight=1, cap=4)
|
||||
if term_score:
|
||||
matched_terms += 1
|
||||
score += term_score
|
||||
score += matched_terms * 4
|
||||
|
||||
phrase = " ".join(terms)
|
||||
if phrase and phrase in _normalized_phrase(section.content):
|
||||
score += 6
|
||||
return score
|
||||
|
||||
|
||||
def _score_page(page: DocPage, terms: list[str]) -> tuple[int, DocSection | None]:
|
||||
if not terms:
|
||||
return 0, None
|
||||
|
||||
path_counts = Counter(_tokenize_text(page.path))
|
||||
title_counts = Counter(_tokenize_text(page.title))
|
||||
breadcrumb_counts = Counter(_tokenize_text(" ".join(page.breadcrumb)))
|
||||
hint_counts = Counter(_tokenize_text(page.routing_hint()))
|
||||
alias_counts = Counter(_tokenize_text(" ".join(page.aliases)))
|
||||
|
||||
score = 0
|
||||
matched_terms = 0
|
||||
for term in terms:
|
||||
term_score = (
|
||||
_score_counter(path_counts, term, weight=6, cap=3)
|
||||
+ _score_counter(title_counts, term, weight=10, cap=2)
|
||||
+ _score_counter(breadcrumb_counts, term, weight=4, cap=2)
|
||||
+ _score_counter(hint_counts, term, weight=7, cap=3)
|
||||
+ _score_counter(alias_counts, term, weight=7, cap=3)
|
||||
)
|
||||
if term_score:
|
||||
matched_terms += 1
|
||||
score += term_score
|
||||
|
||||
best_section = None
|
||||
best_section_score = 0
|
||||
for section in page.sections:
|
||||
section_score = _score_section(section, terms)
|
||||
if section_score > best_section_score:
|
||||
best_section = section
|
||||
best_section_score = section_score
|
||||
|
||||
if score == 0 and best_section_score == 0:
|
||||
return 0, None
|
||||
|
||||
score += matched_terms * 8 + best_section_score
|
||||
|
||||
phrase = " ".join(terms)
|
||||
if phrase:
|
||||
if phrase in _normalized_phrase(page.title):
|
||||
score += 12
|
||||
elif phrase in _normalized_phrase(page.routing_hint()):
|
||||
score += 8
|
||||
elif phrase in _normalized_phrase(page.path):
|
||||
score += 8
|
||||
elif best_section is not None and phrase in _normalized_phrase(
|
||||
best_section.content
|
||||
):
|
||||
score += 4
|
||||
|
||||
return score, best_section
|
||||
|
||||
|
||||
def _set_descendant_counts(
|
||||
sections_by_path: dict[str, NavSection],
|
||||
section_path: str,
|
||||
) -> int:
|
||||
section = sections_by_path[section_path]
|
||||
page_count = 0
|
||||
for child_kind, child_path in section.children:
|
||||
if child_kind == "page":
|
||||
page_count += 1
|
||||
else:
|
||||
page_count += _set_descendant_counts(sections_by_path, child_path)
|
||||
sections_by_path[section_path] = replace(section, descendant_page_count=page_count)
|
||||
return page_count
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _docs_index() -> DocsIndex:
|
||||
root = _resolve_docs_root()
|
||||
if root is None:
|
||||
return DocsIndex(pages_by_path={}, sections_by_path={})
|
||||
|
||||
try:
|
||||
docs_config = json.loads((root / "docs.json").read_text(encoding="utf-8"))
|
||||
except (OSError, UnicodeDecodeError, json.JSONDecodeError):
|
||||
return DocsIndex(pages_by_path={}, sections_by_path={})
|
||||
|
||||
pages_by_path: dict[str, DocPage] = {}
|
||||
sections_by_path: dict[str, NavSection] = {}
|
||||
page_order = 0
|
||||
|
||||
def ensure_unique_section_path(base_path: str) -> str:
|
||||
if base_path not in sections_by_path:
|
||||
return base_path
|
||||
suffix = 2
|
||||
while f"{base_path}-{suffix}" in sections_by_path:
|
||||
suffix += 1
|
||||
return f"{base_path}-{suffix}"
|
||||
|
||||
def walk_pages(
|
||||
items: list[Any],
|
||||
*,
|
||||
section_path: str,
|
||||
section_title: str,
|
||||
ancestor_breadcrumb: tuple[str, ...],
|
||||
) -> None:
|
||||
nonlocal page_order
|
||||
children: list[tuple[str, str]] = []
|
||||
page_breadcrumb = ancestor_breadcrumb + (section_title,)
|
||||
|
||||
for item in items:
|
||||
if isinstance(item, str):
|
||||
route_path = item.strip("/")
|
||||
if not route_path:
|
||||
continue
|
||||
if route_path not in pages_by_path:
|
||||
page = _build_doc_page(
|
||||
root,
|
||||
route_path,
|
||||
breadcrumb=page_breadcrumb,
|
||||
order=page_order,
|
||||
)
|
||||
if page is not None:
|
||||
pages_by_path[route_path] = page
|
||||
page_order += 1
|
||||
if route_path in pages_by_path:
|
||||
children.append(("page", route_path))
|
||||
continue
|
||||
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
group_title = str(item.get("group", "")).strip()
|
||||
nested_pages = item.get("pages")
|
||||
if not group_title or not isinstance(nested_pages, list):
|
||||
continue
|
||||
|
||||
child_path = ensure_unique_section_path(
|
||||
f"{section_path}/{_slugify(group_title)}"
|
||||
)
|
||||
walk_pages(
|
||||
nested_pages,
|
||||
section_path=child_path,
|
||||
section_title=group_title,
|
||||
ancestor_breadcrumb=page_breadcrumb,
|
||||
)
|
||||
children.append(("section", child_path))
|
||||
|
||||
sections_by_path[section_path] = NavSection(
|
||||
path=section_path,
|
||||
title=section_title,
|
||||
breadcrumb=ancestor_breadcrumb,
|
||||
children=tuple(children),
|
||||
)
|
||||
|
||||
root_children: list[tuple[str, str]] = []
|
||||
tabs = docs_config.get("navigation", {}).get("tabs", [])
|
||||
for tab in tabs:
|
||||
if not isinstance(tab, dict):
|
||||
continue
|
||||
tab_title = str(tab.get("tab", "")).strip() or "Docs"
|
||||
for group in tab.get("groups", []):
|
||||
if not isinstance(group, dict):
|
||||
continue
|
||||
group_title = str(group.get("group", "")).strip()
|
||||
group_pages = group.get("pages")
|
||||
if not group_title or not isinstance(group_pages, list):
|
||||
continue
|
||||
top_level_path = ensure_unique_section_path(
|
||||
f"{_slugify(tab_title)}/{_slugify(group_title)}"
|
||||
)
|
||||
walk_pages(
|
||||
group_pages,
|
||||
section_path=top_level_path,
|
||||
section_title=group_title,
|
||||
ancestor_breadcrumb=(tab_title,),
|
||||
)
|
||||
root_children.append(("section", top_level_path))
|
||||
|
||||
sections_by_path[_ROOT_SECTION_PATH] = NavSection(
|
||||
path=_ROOT_SECTION_PATH,
|
||||
title="Docs",
|
||||
breadcrumb=(),
|
||||
children=tuple(root_children),
|
||||
)
|
||||
_set_descendant_counts(sections_by_path, _ROOT_SECTION_PATH)
|
||||
|
||||
return DocsIndex(pages_by_path=pages_by_path, sections_by_path=sections_by_path)
|
||||
|
||||
|
||||
def _get_page_or_404(path: str) -> DocPage:
|
||||
page = _docs_index().pages_by_path.get(path.strip("/"))
|
||||
if page is None:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown docs page: {path!r}")
|
||||
return page
|
||||
|
||||
|
||||
def _find_section(page: DocPage, section: str) -> DocSection | None:
|
||||
target = section.strip().lower()
|
||||
for candidate in page.sections:
|
||||
if candidate.slug.lower() == target or candidate.title.lower() == target:
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
def _expand_nav_entries(
|
||||
index: DocsIndex,
|
||||
section_path: str,
|
||||
depth: int,
|
||||
) -> list[dict]:
|
||||
section = index.sections_by_path[section_path]
|
||||
results: list[dict] = []
|
||||
for child_kind, child_path in section.children:
|
||||
if child_kind == "section":
|
||||
child_section = index.sections_by_path[child_path]
|
||||
results.append(child_section.to_mcp_dict())
|
||||
if depth > 1:
|
||||
results.extend(_expand_nav_entries(index, child_path, depth - 1))
|
||||
else:
|
||||
results.append(index.pages_by_path[child_path].to_catalog_dict())
|
||||
return results
|
||||
|
||||
|
||||
@traced_tool
|
||||
async def list_docs(path: str | None = None, depth: int = 1) -> list[dict]:
|
||||
"""Browse the Dograh docs hierarchy before reading a page in full.
|
||||
|
||||
``path`` addresses navigation sections exposed by this tool. Page paths
|
||||
returned by ``search_docs`` and ``read_doc`` are the published docs routes
|
||||
instead, for example ``voice-agent/tools/mcp-tool``.
|
||||
"""
|
||||
await authenticate_mcp_request()
|
||||
|
||||
if depth < 1 or depth > DOCS_LIST_MAX_DEPTH:
|
||||
raise ValueError(f"`depth` must be between 1 and {DOCS_LIST_MAX_DEPTH}.")
|
||||
|
||||
index = _docs_index()
|
||||
if not index.sections_by_path:
|
||||
return []
|
||||
|
||||
if path is None:
|
||||
return _expand_nav_entries(index, _ROOT_SECTION_PATH, depth)
|
||||
|
||||
normalized = path.strip("/")
|
||||
if normalized in index.sections_by_path:
|
||||
return _expand_nav_entries(index, normalized, depth)
|
||||
if normalized in index.pages_by_path:
|
||||
return [index.pages_by_path[normalized].to_catalog_dict()]
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"Unknown docs section: {path!r}")
|
||||
|
||||
|
||||
@traced_tool
|
||||
async def read_doc(path: str, section: str | None = None) -> dict:
|
||||
"""Read one docs page after you have narrowed to a likely match."""
|
||||
await authenticate_mcp_request()
|
||||
|
||||
if not isinstance(path, str) or not path.strip():
|
||||
raise ValueError("`path` must be a non-empty string.")
|
||||
|
||||
page = _get_page_or_404(path)
|
||||
active_section = None
|
||||
if section is not None:
|
||||
active_section = _find_section(page, section)
|
||||
if active_section is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Unknown section {section!r} for docs page {path!r}",
|
||||
)
|
||||
return page.to_read_dict(section=active_section)
|
||||
|
||||
|
||||
@traced_tool
|
||||
async def search_docs(query: str, limit: int = 5) -> list[dict]:
|
||||
"""Search the Dograh documentation and return a lean ranked shortlist.
|
||||
|
||||
Use this first for keyword or acronym lookup. Once the right page looks
|
||||
likely, call ``read_doc(path)`` instead of reasoning from summaries alone.
|
||||
"""
|
||||
await authenticate_mcp_request()
|
||||
|
||||
if not isinstance(query, str) or not query.strip():
|
||||
raise ValueError("`query` must be a non-empty string.")
|
||||
if limit < 1:
|
||||
raise ValueError("`limit` must be at least 1.")
|
||||
|
||||
terms = _tokenize_query(query)
|
||||
if not terms:
|
||||
raise ValueError(
|
||||
"`query` must contain at least one non-stopword alphanumeric term."
|
||||
)
|
||||
|
||||
index = _docs_index()
|
||||
if not index.pages_by_path:
|
||||
return []
|
||||
|
||||
capped_limit = min(limit, DOCS_SEARCH_MAX_LIMIT)
|
||||
ranked: list[tuple[int, int, DocPage, DocSection | None]] = []
|
||||
for page in index.pages_by_path.values():
|
||||
score, best_section = _score_page(page, terms)
|
||||
if score <= 0:
|
||||
continue
|
||||
ranked.append((score, page.order, page, best_section))
|
||||
|
||||
ranked.sort(key=lambda item: (-item[0], item[1], item[2].path))
|
||||
return [
|
||||
page.to_catalog_dict(section=best_section)
|
||||
for _, _, page, best_section in ranked[:capped_limit]
|
||||
]
|
||||
|
|
@ -40,15 +40,17 @@ async def list_node_types() -> dict:
|
|||
|
||||
@traced_tool
|
||||
async def get_node_type(name: str) -> dict:
|
||||
"""Fetch the full schema for a node type, including every property's
|
||||
type, default, conditional visibility rules, and LLM-readable
|
||||
description, plus worked examples.
|
||||
"""Fetch the authoring schema for a node type: each property's name,
|
||||
type, default, requiredness, enum options, validation bounds, and
|
||||
LLM-readable description, plus worked examples and graph constraints.
|
||||
|
||||
Use the property `description` and the `examples` list to understand
|
||||
semantics — types alone are not enough.
|
||||
UI-only metadata (display labels, placeholders, conditional visibility
|
||||
rules, renderer hints) is intentionally omitted — set only the fields
|
||||
you need. Use the property `description`/`llm_hint` and the `examples`
|
||||
list to understand semantics; types alone are not enough.
|
||||
"""
|
||||
await authenticate_mcp_request()
|
||||
spec = get_spec(name)
|
||||
if spec is None:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown node type: {name!r}")
|
||||
return spec.model_dump(mode="json")
|
||||
return spec.to_mcp_dict()
|
||||
|
|
|
|||
|
|
@ -10,16 +10,12 @@ Execution flow:
|
|||
4. Save as a new draft via `db_client.save_workflow_draft` — the
|
||||
published version stays intact, so edits are rollback-safe.
|
||||
|
||||
Error codes surfaced to the LLM:
|
||||
parse_error — TS parse failed or a disallowed construct was used
|
||||
validation_error — node data failed spec validation (unknown field,
|
||||
missing required, wrong type, option out of range)
|
||||
schema_validation — ReactFlowDTO Pydantic rejection (rare; parser bug)
|
||||
graph_validation — semantic graph rule broken (e.g. no start node)
|
||||
bridge_error — Node subprocess failed before returning JSON
|
||||
|
||||
All LLM-facing errors include file:line:column where available so the
|
||||
LLM can correct its code directly.
|
||||
Each failure path returns an `error_code` via `_error_result`. Those
|
||||
codes and their meanings are documented in the `save_workflow` docstring
|
||||
(the description shipped to the LLM via `tools/list`); keep the two in
|
||||
sync — `test_mcp_instructions_drift.py` enforces it. All LLM-facing
|
||||
errors include file:line:column where available so the LLM can correct
|
||||
its code directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -91,6 +87,18 @@ async def save_workflow(workflow_id: int, code: str) -> dict[str, Any]:
|
|||
|
||||
On success the draft version is saved; the published version is
|
||||
untouched.
|
||||
|
||||
On failure the result has `saved: false`, a machine-readable
|
||||
`error_code`, and a human-readable `error` (with file:line:column
|
||||
where the problem is locatable). Resubmit the full corrected source —
|
||||
patches are not accepted. Possible `error_code` values:
|
||||
- `parse_error` — disallowed construct or malformed TypeScript.
|
||||
- `validation_error` — node data failed spec validation (unknown
|
||||
field, missing required, wrong type, option out of range).
|
||||
- `schema_validation` — wire-format (DTO) rejection; rare.
|
||||
- `graph_validation` — structural rule broken (e.g. no start node,
|
||||
unreachable node, edge to/from the wrong node type).
|
||||
- `bridge_error` — internal/transient; retry once, then surface it.
|
||||
"""
|
||||
user = await authenticate_mcp_request()
|
||||
|
||||
|
|
|
|||
|
|
@ -18,4 +18,5 @@ bcrypt==5.0.0
|
|||
email-validator==2.3.0
|
||||
posthog==7.11.1
|
||||
fastmcp==3.2.4
|
||||
tuner-pipecat-sdk==0.2.0
|
||||
PyNaCl==1.6.2
|
||||
|
|
|
|||
|
|
@ -152,8 +152,8 @@ class CircuitBreakerConfigResponse(BaseModel):
|
|||
class CreateCampaignRequest(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
workflow_id: int
|
||||
source_type: str = Field(..., pattern="^(google-sheet|csv)$")
|
||||
source_id: str # Google Sheet URL or CSV file key
|
||||
source_type: str = Field(..., pattern="^csv$")
|
||||
source_id: str # CSV file key
|
||||
# Optional during the legacy → multi-config migration window. Required in
|
||||
# a follow-up. When omitted, the dispatcher falls back to the org's
|
||||
# default config.
|
||||
|
|
@ -929,8 +929,6 @@ async def get_campaign_source_download_url(
|
|||
user: UserModel = Depends(get_user),
|
||||
) -> CampaignSourceDownloadResponse:
|
||||
"""Get presigned download URL for campaign CSV source file
|
||||
|
||||
Only works for CSV source type. For Google Sheets, use the source_id directly.
|
||||
Validates that the campaign belongs to the user's organization for security.
|
||||
"""
|
||||
# Verify campaign exists and belongs to organization
|
||||
|
|
|
|||
|
|
@ -1,266 +0,0 @@
|
|||
"""
|
||||
Route for 3rd party integrations. Currently being backed by nango.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.integrations.nango import nango_service
|
||||
|
||||
router = APIRouter(prefix="/integration")
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntegrationResponse:
|
||||
id: int
|
||||
integration_id: str
|
||||
organisation_id: int
|
||||
created_by: Optional[int]
|
||||
provider: str
|
||||
is_active: bool
|
||||
created_at: str
|
||||
action: str
|
||||
provider_data: dict
|
||||
|
||||
|
||||
class SessionResponse(TypedDict):
|
||||
session_token: str
|
||||
expires_at: str
|
||||
|
||||
|
||||
class WebhookResponse(TypedDict):
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class UpdateIntegrationRequest(BaseModel):
|
||||
selected_files: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class AccessTokenResponse(BaseModel):
|
||||
access_token: Optional[str]
|
||||
refresh_token: Optional[str]
|
||||
expires_at: Optional[str]
|
||||
connection_id: str
|
||||
|
||||
|
||||
def build_integration_response(integration) -> IntegrationResponse:
|
||||
"""Build a standardized integration response with provider-specific data."""
|
||||
provider_data = {}
|
||||
|
||||
if integration.provider == "google-sheet":
|
||||
# For Google Sheets, include selected_files
|
||||
provider_data["selected_files"] = integration.connection_details.get(
|
||||
"selected_files", []
|
||||
)
|
||||
elif integration.provider == "slack":
|
||||
# For Slack, include channel information
|
||||
channel = integration.connection_details.get("connection_config", {}).get(
|
||||
"incoming_webhook.channel"
|
||||
)
|
||||
if channel:
|
||||
provider_data["channel"] = channel
|
||||
|
||||
return IntegrationResponse(
|
||||
id=integration.id,
|
||||
integration_id=integration.integration_id,
|
||||
organisation_id=integration.organisation_id,
|
||||
created_by=integration.created_by,
|
||||
provider=integration.provider,
|
||||
is_active=integration.is_active,
|
||||
created_at=integration.created_at.isoformat(),
|
||||
action=integration.action,
|
||||
provider_data=provider_data,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def get_integrations(
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> list[IntegrationResponse]:
|
||||
"""
|
||||
Get all integrations for the user's selected organization.
|
||||
|
||||
Returns:
|
||||
List of integrations associated with the user's selected organization
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
integrations = await db_client.get_integrations_by_organization_id(
|
||||
user.selected_organization_id
|
||||
)
|
||||
|
||||
return [build_integration_response(integration) for integration in integrations]
|
||||
|
||||
|
||||
@router.post("/session")
|
||||
async def create_session(
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> SessionResponse:
|
||||
"""
|
||||
Create a Nango session for the user's selected organization.
|
||||
|
||||
Returns:
|
||||
Session token and ID for the created session
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
try:
|
||||
session_data = await nango_service.create_session(
|
||||
user_id=str(user.id), organization_id=user.selected_organization_id
|
||||
)
|
||||
|
||||
return {
|
||||
"session_token": session_data["data"]["token"],
|
||||
"expires_at": session_data["data"]["expires_at"],
|
||||
}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create session: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{integration_id}")
|
||||
async def update_integration(
|
||||
integration_id: int,
|
||||
request: UpdateIntegrationRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> IntegrationResponse:
|
||||
"""
|
||||
Update an integration's selected files (for Google Sheets).
|
||||
|
||||
Args:
|
||||
integration_id: The ID of the integration to update
|
||||
request: The update request containing selected files
|
||||
user: The authenticated user
|
||||
|
||||
Returns:
|
||||
Updated integration details
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
# Get the integration first to verify ownership
|
||||
integrations = await db_client.get_integrations_by_organization_id(
|
||||
user.selected_organization_id
|
||||
)
|
||||
|
||||
integration = next((i for i in integrations if i.id == integration_id), None)
|
||||
if not integration:
|
||||
raise HTTPException(status_code=404, detail="Integration not found")
|
||||
|
||||
# Only allow updating selected_files for google-sheet provider
|
||||
if integration.provider != "google-sheet":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="This endpoint only supports updating Google Sheet integrations",
|
||||
)
|
||||
|
||||
# Update the connection_details with the new selected_files
|
||||
updated_connection_details = integration.connection_details.copy()
|
||||
updated_connection_details["selected_files"] = request.selected_files
|
||||
|
||||
# Update the integration
|
||||
updated_integration = await db_client.update_integration_connection_details(
|
||||
integration_id=integration_id, connection_details=updated_connection_details
|
||||
)
|
||||
|
||||
if not updated_integration:
|
||||
raise HTTPException(status_code=500, detail="Failed to update integration")
|
||||
|
||||
return build_integration_response(updated_integration)
|
||||
|
||||
|
||||
@router.get("/{integration_id}/access-token")
|
||||
async def get_integration_access_token(
|
||||
integration_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> AccessTokenResponse:
|
||||
"""
|
||||
Get the latest access token for an integration from Nango.
|
||||
|
||||
Args:
|
||||
integration_id: The ID of the integration
|
||||
user: The authenticated user
|
||||
|
||||
Returns:
|
||||
Dict containing access token and expiration info
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
# Get the integration to verify ownership and get connection details
|
||||
integrations = await db_client.get_integrations_by_organization_id(
|
||||
user.selected_organization_id
|
||||
)
|
||||
|
||||
integration = next((i for i in integrations if i.id == integration_id), None)
|
||||
if not integration:
|
||||
raise HTTPException(status_code=404, detail="Integration not found")
|
||||
|
||||
try:
|
||||
# Fetch the latest access token from Nango
|
||||
token_data = await nango_service.get_access_token(
|
||||
connection_id=integration.integration_id,
|
||||
provider_config_key=integration.provider,
|
||||
)
|
||||
|
||||
# Extract relevant fields
|
||||
return AccessTokenResponse(
|
||||
access_token=token_data.get("credentials", {}).get("access_token"),
|
||||
refresh_token=token_data.get("credentials", {}).get("refresh_token"),
|
||||
expires_at=token_data.get("credentials", {}).get("expires_at"),
|
||||
connection_id=integration.integration_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get access token: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch access token: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/webhook", include_in_schema=False)
|
||||
async def handle_nango_webhook(
|
||||
request: Request,
|
||||
) -> WebhookResponse:
|
||||
"""
|
||||
Handle Nango integration webhook requests.
|
||||
|
||||
Processes webhook events from Nango when integrations are created/updated
|
||||
and stores the integration details in the database.
|
||||
|
||||
Args:
|
||||
request: The raw FastAPI request object
|
||||
|
||||
Returns:
|
||||
WebhookResponse with status and message
|
||||
"""
|
||||
raw_body = await request.body()
|
||||
|
||||
# Get signature from headers (you may need to adjust the header name)
|
||||
signature = request.headers.get("X-Nango-Signature")
|
||||
|
||||
# Use the nango service to process the webhook
|
||||
result = await nango_service.process_webhook(raw_body, signature)
|
||||
|
||||
return result
|
||||
|
|
@ -6,7 +6,6 @@ from api.routes.agent_stream import router as agent_stream_router
|
|||
from api.routes.auth import router as auth_router
|
||||
from api.routes.campaign import router as campaign_router
|
||||
from api.routes.credentials import router as credentials_router
|
||||
from api.routes.integration import router as integration_router
|
||||
from api.routes.knowledge_base import router as knowledge_base_router
|
||||
from api.routes.node_types import router as node_types_router
|
||||
from api.routes.organization import router as organization_router
|
||||
|
|
@ -27,6 +26,7 @@ from api.routes.workflow import router as workflow_router
|
|||
from api.routes.workflow_embed import router as workflow_embed_router
|
||||
from api.routes.workflow_recording import router as workflow_recording_router
|
||||
from api.routes.workflow_text_chat import router as workflow_text_chat_router
|
||||
from api.services.integrations import all_routers
|
||||
|
||||
router = APIRouter(
|
||||
tags=["main"],
|
||||
|
|
@ -41,7 +41,6 @@ router.include_router(user_router)
|
|||
router.include_router(campaign_router)
|
||||
router.include_router(credentials_router)
|
||||
router.include_router(tool_router)
|
||||
router.include_router(integration_router)
|
||||
router.include_router(organization_router)
|
||||
router.include_router(s3_router)
|
||||
router.include_router(service_keys_router)
|
||||
|
|
@ -59,6 +58,9 @@ router.include_router(auth_router)
|
|||
router.include_router(node_types_router)
|
||||
router.include_router(agent_stream_router)
|
||||
|
||||
for _integration_router in all_routers():
|
||||
router.include_router(_integration_router)
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
"""Public download endpoints for workflow recordings and transcripts.
|
||||
|
||||
These endpoints provide secure, token-based public access to workflow artifacts
|
||||
without requiring authentication. Tokens are generated on-demand when webhooks
|
||||
are executed and included in the webhook payload.
|
||||
without requiring authentication. Tokens are generated on-demand during
|
||||
post-call processing for runs that execute integrations, QA, or campaign
|
||||
reporting.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
"""API routes for managing tools."""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from api.db import db_client
|
||||
|
|
@ -13,9 +15,23 @@ from api.enums import PostHogEvent, ToolCategory, ToolStatus
|
|||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.workflow.mcp_tool_session import discover_mcp_tools
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpDefinitionError,
|
||||
validate_mcp_definition,
|
||||
)
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpToolConfig as SharedMcpToolConfig,
|
||||
)
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpToolDefinition as SharedMcpToolDefinition,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/tools")
|
||||
|
||||
McpToolConfig = SharedMcpToolConfig
|
||||
McpToolDefinition = SharedMcpToolDefinition
|
||||
|
||||
|
||||
# Request/Response schemas
|
||||
class ToolParameter(BaseModel):
|
||||
|
|
@ -183,6 +199,7 @@ ToolDefinition = Annotated[
|
|||
EndCallToolDefinition,
|
||||
TransferCallToolDefinition,
|
||||
CalculatorToolDefinition,
|
||||
McpToolDefinition,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
@ -248,6 +265,14 @@ class ToolResponse(BaseModel):
|
|||
from_attributes = True
|
||||
|
||||
|
||||
class McpRefreshResponse(BaseModel):
|
||||
"""Result of re-discovering an MCP server's tool catalog."""
|
||||
|
||||
tool_uuid: str
|
||||
discovered_tools: list = Field(default_factory=list)
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def build_tool_response(tool, include_created_by: bool = False) -> ToolResponse:
|
||||
"""Build a response from a tool model."""
|
||||
created_by = None
|
||||
|
|
@ -336,6 +361,52 @@ async def list_tools(
|
|||
return [build_tool_response(tool) for tool in tools]
|
||||
|
||||
|
||||
async def _fetch_credential(credential_uuid: Optional[str], organization_id: int):
|
||||
"""Best-effort credential lookup for MCP auth. A missing/failed credential
|
||||
degrades to ``None`` (unauthenticated) rather than failing the request."""
|
||||
if not credential_uuid:
|
||||
return None
|
||||
try:
|
||||
return await db_client.get_credential_by_uuid(credential_uuid, organization_id)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(f"MCP: credential fetch failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _populate_discovered_tools(definition: dict, *, organization_id: int) -> dict:
|
||||
"""Best-effort: for an MCP definition, connect to the server, list its
|
||||
tools, and overwrite ``config.discovered_tools``. Never raises and never
|
||||
blocks tool save — a dead server yields ``discovered_tools: []``. Non-MCP
|
||||
definitions pass through untouched."""
|
||||
if not isinstance(definition, dict) or definition.get("type") != "mcp":
|
||||
return definition
|
||||
try:
|
||||
cfg = validate_mcp_definition(definition)
|
||||
except McpDefinitionError:
|
||||
return definition
|
||||
|
||||
credential = await _fetch_credential(cfg.get("credential_uuid"), organization_id)
|
||||
|
||||
# Run discovery in an isolated asyncio task so an anyio cancel-scope
|
||||
# CancelledError doesn't bleed into the parent task and corrupt the
|
||||
# subsequent DB write. _run() never raises (degrades to []).
|
||||
async def _run() -> list:
|
||||
try:
|
||||
return await discover_mcp_tools(
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
except BaseException as e: # noqa: BLE001
|
||||
logger.warning(f"MCP discovery failed; caching empty list: {e}")
|
||||
return []
|
||||
|
||||
discovered = await asyncio.ensure_future(_run())
|
||||
definition["config"]["discovered_tools"] = discovered
|
||||
return definition
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_tool(
|
||||
request: CreateToolRequest,
|
||||
|
|
@ -357,11 +428,16 @@ async def create_tool(
|
|||
|
||||
validate_category(request.category)
|
||||
|
||||
definition = await _populate_discovered_tools(
|
||||
request.definition.model_dump(),
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
tool = await db_client.create_tool(
|
||||
organization_id=user.selected_organization_id,
|
||||
user_id=user.id,
|
||||
name=request.name,
|
||||
definition=request.definition.model_dump(),
|
||||
definition=definition,
|
||||
category=request.category,
|
||||
description=request.description,
|
||||
icon=request.icon,
|
||||
|
|
@ -410,6 +486,67 @@ async def get_tool(
|
|||
return build_tool_response(tool, include_created_by=True)
|
||||
|
||||
|
||||
@router.post("/{tool_uuid}/mcp/refresh")
|
||||
async def refresh_mcp_tools(
|
||||
tool_uuid: str,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> McpRefreshResponse:
|
||||
"""Re-discover an MCP tool's server catalog and overwrite the cached
|
||||
``definition.config.discovered_tools``. Server down → 200 with error
|
||||
(cache not overwritten on transient failure)."""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
tool = await db_client.get_tool_by_uuid(
|
||||
tool_uuid, user.selected_organization_id, include_archived=True
|
||||
)
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
if tool.category != ToolCategory.MCP.value:
|
||||
raise HTTPException(status_code=400, detail="Tool is not an MCP tool")
|
||||
|
||||
try:
|
||||
cfg = validate_mcp_definition(tool.definition)
|
||||
except McpDefinitionError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid MCP definition: {e}")
|
||||
|
||||
credential = await _fetch_credential(
|
||||
cfg.get("credential_uuid"), user.selected_organization_id
|
||||
)
|
||||
|
||||
try:
|
||||
discovered = await discover_mcp_tools(
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(f"MCP refresh discovery failed: {e}")
|
||||
discovered = []
|
||||
|
||||
if not discovered:
|
||||
error = (
|
||||
f"Could not reach the MCP server at {cfg['url']} "
|
||||
f"(or it exposes no tools). Previously cached list retained."
|
||||
)
|
||||
# Do NOT clobber a previously-good cache with [] on a transient outage.
|
||||
return McpRefreshResponse(tool_uuid=tool_uuid, discovered_tools=[], error=error)
|
||||
|
||||
new_def = dict(tool.definition or {})
|
||||
new_def["config"] = {**new_def.get("config", {}), "discovered_tools": discovered}
|
||||
await db_client.update_tool(
|
||||
tool_uuid=tool_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
definition=new_def,
|
||||
)
|
||||
return McpRefreshResponse(
|
||||
tool_uuid=tool_uuid, discovered_tools=discovered, error=None
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{tool_uuid}")
|
||||
async def update_tool(
|
||||
tool_uuid: str,
|
||||
|
|
@ -434,12 +571,21 @@ async def update_tool(
|
|||
if request.status:
|
||||
validate_status(request.status)
|
||||
|
||||
definition = (
|
||||
await _populate_discovered_tools(
|
||||
request.definition.model_dump(),
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
if request.definition
|
||||
else None
|
||||
)
|
||||
|
||||
tool = await db_client.update_tool(
|
||||
tool_uuid=tool_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
definition=request.definition.model_dump() if request.definition else None,
|
||||
definition=definition,
|
||||
icon=request.icon,
|
||||
icon_color=request.icon_color,
|
||||
status=request.status,
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import asyncio
|
|||
import ipaddress
|
||||
import os
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from aiortc import RTCIceServer
|
||||
|
|
@ -49,6 +50,63 @@ from api.services.quota_service import check_dograh_quota
|
|||
router = APIRouter(prefix="/ws")
|
||||
|
||||
|
||||
class NonRelayFilterPolicy(Enum):
|
||||
"""What to filter from non-relay ICE candidates. Relay candidates always pass."""
|
||||
|
||||
NONE = "none" # filter nothing — pass all candidates
|
||||
PRIVATE = "private" # filter non-relay candidates with private/CGNAT IPs
|
||||
ALL = "all" # filter all non-relay candidates (relay-only mode)
|
||||
|
||||
|
||||
def is_local_or_cgnat_ip(ip_str: str) -> bool:
|
||||
"""Return True for RFC1918, loopback, link-local, and CGNAT addresses."""
|
||||
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
is_cgnat = ip.version == 4 and ip in ipaddress.ip_network("100.64.0.0/10")
|
||||
return ip.is_private or ip.is_loopback or ip.is_link_local or is_cgnat
|
||||
|
||||
|
||||
def resolve_ice_filter_policies(
|
||||
environment: str,
|
||||
force_turn_relay: bool,
|
||||
server_ip: str,
|
||||
) -> tuple[NonRelayFilterPolicy, NonRelayFilterPolicy]:
|
||||
"""Resolve outbound and inbound non-relay filtering for this deployment."""
|
||||
|
||||
private_lan_deployment = (
|
||||
environment != Environment.LOCAL.value and is_local_or_cgnat_ip(server_ip)
|
||||
)
|
||||
|
||||
if force_turn_relay:
|
||||
# Relay-only diagnostics stay explicit. On private LAN deployments we
|
||||
# must still accept inbound private candidates for relay<->host pairs.
|
||||
outbound_policy = NonRelayFilterPolicy.ALL
|
||||
inbound_policy = (
|
||||
NonRelayFilterPolicy.NONE
|
||||
if private_lan_deployment
|
||||
else NonRelayFilterPolicy.PRIVATE
|
||||
)
|
||||
return outbound_policy, inbound_policy
|
||||
|
||||
if environment == Environment.LOCAL.value or private_lan_deployment:
|
||||
return NonRelayFilterPolicy.NONE, NonRelayFilterPolicy.NONE
|
||||
|
||||
# Public remote deployment: drop private-IP host candidates to avoid
|
||||
# coturn denied-peer-ip errors against Docker bridge and LAN interfaces.
|
||||
return NonRelayFilterPolicy.PRIVATE, NonRelayFilterPolicy.PRIVATE
|
||||
|
||||
|
||||
ICE_OUTBOUND_POLICY, ICE_INBOUND_POLICY = resolve_ice_filter_policies(
|
||||
ENVIRONMENT,
|
||||
FORCE_TURN_RELAY,
|
||||
os.getenv("SERVER_IP", ""),
|
||||
)
|
||||
|
||||
|
||||
def is_private_ip_candidate(candidate_str: str) -> bool:
|
||||
"""Check if ICE candidate contains a private IP address or CGNAT IP Address.
|
||||
|
||||
|
|
@ -69,61 +127,58 @@ def is_private_ip_candidate(candidate_str: str) -> bool:
|
|||
if "typ" in parts:
|
||||
typ_index = parts.index("typ")
|
||||
ip_str = parts[typ_index - 2]
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
is_cgnat = ip in ipaddress.ip_network("100.64.0.0/10")
|
||||
return ip.is_private or is_cgnat
|
||||
return is_local_or_cgnat_ip(ip_str)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def filter_outbound_sdp(sdp: str) -> str:
|
||||
"""Strip ICE candidates from an outbound answer SDP based on env config.
|
||||
def _keep_candidate(candidate_str: str, policy: NonRelayFilterPolicy) -> bool:
|
||||
"""Return True if this ICE candidate should be kept under the given policy.
|
||||
|
||||
Two filters apply:
|
||||
|
||||
1. In non-LOCAL environments, drop host candidates with private/CGNAT IPs.
|
||||
aiortc gathers host candidates from every interface on the box, including
|
||||
Docker bridges (172.17.0.1, 172.18.0.1). Advertising those to the browser
|
||||
causes coturn "peer IP X denied" errors when the browser asks TURN to
|
||||
permit them.
|
||||
|
||||
2. When FORCE_TURN_RELAY is set, drop every non-relay candidate so the
|
||||
only path the browser can use is via TURN. Lets you verify TURN
|
||||
connectivity end-to-end — if TURN is broken, the call simply fails.
|
||||
Relay candidates always pass — a relay with a private IP (LAN TURN server)
|
||||
must never be dropped regardless of policy.
|
||||
"""
|
||||
if ENVIRONMENT == Environment.LOCAL.value and not FORCE_TURN_RELAY:
|
||||
if " typ relay" in candidate_str:
|
||||
return True
|
||||
if policy == NonRelayFilterPolicy.NONE:
|
||||
return True
|
||||
if policy == NonRelayFilterPolicy.ALL:
|
||||
return False
|
||||
# PRIVATE: drop non-relay candidates with private/CGNAT IPs
|
||||
return not is_private_ip_candidate(candidate_str)
|
||||
|
||||
|
||||
def filter_outbound_sdp(sdp: str) -> str:
|
||||
"""Strip ICE candidates from an outbound answer SDP based on ICE_OUTBOUND_POLICY."""
|
||||
if ICE_OUTBOUND_POLICY == NonRelayFilterPolicy.NONE:
|
||||
return sdp
|
||||
|
||||
lines = sdp.split("\r\n")
|
||||
filtered: List[str] = []
|
||||
dropped_non_relay = 0
|
||||
dropped = 0
|
||||
kept_relay = 0
|
||||
for line in lines:
|
||||
if line.startswith("a=candidate:"):
|
||||
candidate_str = line[2:]
|
||||
if FORCE_TURN_RELAY and " typ relay" not in candidate_str:
|
||||
dropped_non_relay += 1
|
||||
if not _keep_candidate(candidate_str, ICE_OUTBOUND_POLICY):
|
||||
dropped += 1
|
||||
continue
|
||||
if ENVIRONMENT != Environment.LOCAL.value and is_private_ip_candidate(
|
||||
candidate_str
|
||||
):
|
||||
continue
|
||||
if FORCE_TURN_RELAY:
|
||||
if " typ relay" in candidate_str:
|
||||
kept_relay += 1
|
||||
filtered.append(line)
|
||||
|
||||
if FORCE_TURN_RELAY:
|
||||
if ICE_OUTBOUND_POLICY == NonRelayFilterPolicy.ALL:
|
||||
if kept_relay == 0:
|
||||
logger.warning(
|
||||
"FORCE_TURN_RELAY is on but the answer SDP has no relay candidates "
|
||||
f"(dropped {dropped_non_relay} non-relay). TURN may be unreachable; "
|
||||
f"(dropped {dropped} non-relay). TURN may be unreachable; "
|
||||
"the connection will fail."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"FORCE_TURN_RELAY: kept {kept_relay} relay candidates, "
|
||||
f"dropped {dropped_non_relay} non-relay"
|
||||
f"dropped {dropped} non-relay"
|
||||
)
|
||||
|
||||
return "\r\n".join(filtered)
|
||||
|
|
@ -370,9 +425,7 @@ class SignalingManager:
|
|||
Uses SmallWebRTC's native ICE trickling support via add_ice_candidate().
|
||||
Candidates are parsed using aiortc's candidate_from_sdp() for proper formatting,
|
||||
consistent with SmallWebRTCRequestHandler.handle_patch_request().
|
||||
|
||||
In non-local environments, private IP candidates are filtered out to prevent
|
||||
TURN relay errors when coturn blocks private IP ranges (denied-peer-ip).
|
||||
Candidates are filtered according to ICE_INBOUND_POLICY before being added.
|
||||
"""
|
||||
pc_id = payload.get("pc_id")
|
||||
candidate_data = payload.get("candidate")
|
||||
|
|
@ -389,13 +442,9 @@ class SignalingManager:
|
|||
if candidate_data:
|
||||
candidate_str = candidate_data.get("candidate", "")
|
||||
|
||||
# Filter out private IP candidates in non-local environments
|
||||
# This prevents TURN relay errors when coturn blocks private IP ranges
|
||||
if ENVIRONMENT != Environment.LOCAL.value and is_private_ip_candidate(
|
||||
candidate_str
|
||||
):
|
||||
if not _keep_candidate(candidate_str, ICE_INBOUND_POLICY):
|
||||
logger.debug(
|
||||
f"Skipping private IP candidate in {ENVIRONMENT}: {candidate_str[:50]}..."
|
||||
f"Dropping inbound candidate per policy ({ICE_INBOUND_POLICY.value}): {candidate_str[:50]}..."
|
||||
)
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -129,22 +129,6 @@ async def get_user(
|
|||
return user_model
|
||||
|
||||
|
||||
async def get_user_optional(
|
||||
authorization: Annotated[str | None, Header()] = None,
|
||||
x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
|
||||
) -> UserModel | None:
|
||||
"""
|
||||
Same as get_user but returns None instead of raising 401 if unauthorized.
|
||||
Useful for endpoints that need to work both with and without auth.
|
||||
"""
|
||||
try:
|
||||
return await get_user(authorization, x_api_key)
|
||||
except HTTPException as e:
|
||||
if e.status_code == 401:
|
||||
return None
|
||||
raise
|
||||
|
||||
|
||||
async def _handle_oss_auth(authorization: str | None) -> UserModel:
|
||||
"""
|
||||
Handle authentication for OSS deployment mode.
|
||||
|
|
|
|||
|
|
@ -183,9 +183,7 @@ class CampaignSourceSyncService(ABC):
|
|||
async def get_source_credentials(
|
||||
self, organization_id: int, source_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Gets OAuth tokens or API credentials via Nango"""
|
||||
# This would be implemented to work with Nango service
|
||||
# For now, returning placeholder
|
||||
"""Gets source credentials when a sync service requires them."""
|
||||
logger.info(
|
||||
f"Getting credentials for org {organization_id}, source {source_type}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,15 +1,12 @@
|
|||
from api.services.campaign.source_sync import CampaignSourceSyncService
|
||||
from api.services.campaign.sources.csv import CSVSyncService
|
||||
from api.services.campaign.sources.google_sheets import GoogleSheetsSyncService
|
||||
|
||||
|
||||
def get_sync_service(source_type: str) -> CampaignSourceSyncService:
|
||||
"""Returns appropriate sync service based on source type"""
|
||||
|
||||
services = {
|
||||
"google-sheet": GoogleSheetsSyncService,
|
||||
"csv": CSVSyncService,
|
||||
# Add more as needed: "hubspot": HubSpotSyncService,
|
||||
}
|
||||
|
||||
service_class = services.get(source_type)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
"""Campaign source sync services"""
|
||||
|
||||
from .google_sheets import GoogleSheetsSyncService
|
||||
|
||||
__all__ = ["GoogleSheetsSyncService"]
|
||||
__all__: list[str] = []
|
||||
|
|
|
|||
|
|
@ -1,224 +0,0 @@
|
|||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.campaign.source_sync import (
|
||||
CampaignSourceSyncService,
|
||||
ValidationError,
|
||||
ValidationResult,
|
||||
)
|
||||
from api.services.integrations.nango import NangoService
|
||||
|
||||
|
||||
class GoogleSheetsSyncService(CampaignSourceSyncService):
|
||||
"""Implementation for Google Sheets synchronization"""
|
||||
|
||||
def __init__(self):
|
||||
self.nango_service = NangoService()
|
||||
self.sheets_api_base = "https://sheets.googleapis.com/v4/spreadsheets"
|
||||
|
||||
async def _get_access_token(self, organization_id: int) -> str:
|
||||
"""Get OAuth access token for Google Sheets via Nango."""
|
||||
integrations = await db_client.get_integrations_by_organization_id(
|
||||
organization_id
|
||||
)
|
||||
integration = None
|
||||
for intg in integrations:
|
||||
if intg.provider == "google-sheet" and intg.is_active:
|
||||
integration = intg
|
||||
break
|
||||
|
||||
if not integration:
|
||||
raise ValueError("Google Sheets integration not found or inactive")
|
||||
|
||||
token_data = await self.nango_service.get_access_token(
|
||||
connection_id=integration.integration_id, provider_config_key="google-sheet"
|
||||
)
|
||||
return token_data["credentials"]["access_token"]
|
||||
|
||||
async def _fetch_all_sheet_data(
|
||||
self, sheet_url: str, organization_id: int
|
||||
) -> List[List[str]]:
|
||||
"""Fetch all data from a Google Sheet. Returns all rows including header."""
|
||||
access_token = await self._get_access_token(organization_id)
|
||||
sheet_id = self._extract_sheet_id(sheet_url)
|
||||
|
||||
metadata = await self._get_sheet_metadata(sheet_id, access_token)
|
||||
if not metadata.get("sheets"):
|
||||
raise ValueError("No sheets found in the spreadsheet")
|
||||
|
||||
sheet_name = metadata["sheets"][0]["properties"]["title"]
|
||||
|
||||
return await self._fetch_sheet_data(sheet_id, f"{sheet_name}!A:Z", access_token)
|
||||
|
||||
async def validate_source(
|
||||
self, source_id: str, organization_id: Optional[int] = None
|
||||
) -> ValidationResult:
|
||||
"""Validate a Google Sheet source for campaign creation."""
|
||||
if organization_id is None:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error=ValidationError(
|
||||
message="Organization ID is required for Google Sheets validation"
|
||||
),
|
||||
)
|
||||
|
||||
# Validate URL format first
|
||||
pattern = r"/spreadsheets/d/([a-zA-Z0-9-_]+)"
|
||||
if not re.search(pattern, source_id):
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error=ValidationError(
|
||||
message=f"Invalid Google Sheets URL: {source_id}"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
rows = await self._fetch_all_sheet_data(source_id, organization_id)
|
||||
except ValueError as e:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error=ValidationError(message=str(e)),
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error fetching Google Sheet: {e.response.status_code}")
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error=ValidationError(
|
||||
message=f"Failed to fetch Google Sheet data: {e.response.status_code}"
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Google Sheet: {e}")
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error=ValidationError(message="Failed to fetch Google Sheet data"),
|
||||
)
|
||||
|
||||
if not rows or len(rows) < 2:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error=ValidationError(
|
||||
message="Google Sheet must have a header row and at least one data row"
|
||||
),
|
||||
)
|
||||
|
||||
headers = rows[0]
|
||||
data_rows = rows[1:]
|
||||
|
||||
return self.validate_source_data(headers, data_rows)
|
||||
|
||||
async def sync_source_data(self, campaign_id: int) -> int:
|
||||
"""
|
||||
Fetches data from Google Sheets and creates queued_runs
|
||||
"""
|
||||
# Get campaign
|
||||
campaign = await db_client.get_campaign_by_id(campaign_id)
|
||||
if not campaign:
|
||||
raise ValueError(f"Campaign {campaign_id} not found")
|
||||
|
||||
rows = await self._fetch_all_sheet_data(
|
||||
campaign.source_id, campaign.organization_id
|
||||
)
|
||||
|
||||
if not rows or len(rows) < 2:
|
||||
logger.warning(f"No data found in sheet for campaign {campaign_id}")
|
||||
return 0
|
||||
|
||||
headers = self.normalize_headers(rows[0])
|
||||
data_rows = rows[1:]
|
||||
|
||||
sheet_id = self._extract_sheet_id(campaign.source_id)
|
||||
|
||||
queued_runs = []
|
||||
for idx, row_values in enumerate(data_rows, 1):
|
||||
# Pad row to match headers length
|
||||
padded_row = row_values + [""] * (len(headers) - len(row_values))
|
||||
|
||||
# Create context variables dict
|
||||
context_vars = dict(zip(headers, padded_row))
|
||||
|
||||
# Skip if no phone number
|
||||
if not context_vars.get("phone_number"):
|
||||
logger.debug(f"Skipping row {idx}: no phone_number")
|
||||
continue
|
||||
|
||||
# Generate unique source UUID
|
||||
source_uuid = f"sheet_{sheet_id}_row_{idx}"
|
||||
|
||||
queued_runs.append(
|
||||
{
|
||||
"campaign_id": campaign_id,
|
||||
"source_uuid": source_uuid,
|
||||
"context_variables": context_vars,
|
||||
"state": "queued",
|
||||
}
|
||||
)
|
||||
|
||||
# Bulk insert
|
||||
if queued_runs:
|
||||
await db_client.bulk_create_queued_runs(queued_runs)
|
||||
logger.info(
|
||||
f"Created {len(queued_runs)} queued runs for campaign {campaign_id}"
|
||||
)
|
||||
|
||||
# Update campaign total_rows
|
||||
await db_client.update_campaign(
|
||||
campaign_id=campaign_id,
|
||||
total_rows=len(queued_runs),
|
||||
source_sync_status="completed",
|
||||
)
|
||||
|
||||
return len(queued_runs)
|
||||
|
||||
async def _fetch_sheet_data(
|
||||
self, sheet_id: str, range: str, access_token: str
|
||||
) -> List[List[str]]:
|
||||
"""Fetch data from Google Sheets API"""
|
||||
url = f"{self.sheets_api_base}/{sheet_id}/values/{range}"
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
return data.get("values", [])
|
||||
|
||||
async def _get_sheet_metadata(
|
||||
self, sheet_id: str, access_token: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Get sheet metadata including sheet names"""
|
||||
url = f"{self.sheets_api_base}/{sheet_id}"
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
logger.debug(f"Fetching sheet metadata from URL: {url}")
|
||||
logger.debug(f"Using sheet_id: {sheet_id}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error {e.response.status_code} for URL: {url}")
|
||||
logger.error(f"Response body: {e.response.text}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching sheet metadata: {e}")
|
||||
raise
|
||||
|
||||
def _extract_sheet_id(self, sheet_url: str) -> str:
|
||||
"""
|
||||
Extract sheet ID from various Google Sheets URL formats:
|
||||
- https://docs.google.com/spreadsheets/d/{id}/edit
|
||||
- https://docs.google.com/spreadsheets/d/{id}/edit#gid=0
|
||||
"""
|
||||
pattern = r"/spreadsheets/d/([a-zA-Z0-9-_]+)"
|
||||
match = re.search(pattern, sheet_url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
raise ValueError(f"Invalid Google Sheets URL: {sheet_url}")
|
||||
|
|
@ -13,6 +13,7 @@ from typing import Any, Dict, Optional
|
|||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.registry import ServiceConfig
|
||||
from api.services.integrations import get_node_secret_fields
|
||||
|
||||
VISIBLE_CHARS = 4 # number of trailing characters to reveal
|
||||
MASK_CHAR = "*"
|
||||
|
|
@ -129,14 +130,22 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow definition helpers – mask / merge QA-node API keys
|
||||
# Workflow definition helpers – mask / merge node API keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_QA_API_KEY_FIELD = "qa_api_key"
|
||||
_NODE_SECRET_FIELDS: dict[str, tuple[str, ...]] = {
|
||||
"qa": ("qa_api_key",),
|
||||
}
|
||||
|
||||
|
||||
def _secret_fields_for_node_type(node_type: str | None) -> tuple[str, ...]:
|
||||
if not node_type:
|
||||
return ()
|
||||
return _NODE_SECRET_FIELDS.get(node_type, ()) or get_node_secret_fields(node_type)
|
||||
|
||||
|
||||
def mask_workflow_definition(workflow_definition: Optional[Dict]) -> Optional[Dict]:
|
||||
"""Return a *shallow copy* of *workflow_definition* with QA-node API keys masked."""
|
||||
"""Return a copy of *workflow_definition* with node secret fields masked."""
|
||||
if not workflow_definition:
|
||||
return workflow_definition
|
||||
|
||||
|
|
@ -144,47 +153,46 @@ def mask_workflow_definition(workflow_definition: Optional[Dict]) -> Optional[Di
|
|||
|
||||
masked = copy.deepcopy(workflow_definition)
|
||||
for node in masked.get("nodes", []):
|
||||
if node.get("type") != "qa":
|
||||
secret_fields = _secret_fields_for_node_type(node.get("type"))
|
||||
if not secret_fields:
|
||||
continue
|
||||
data = node.get("data", {})
|
||||
raw_key = data.get(_QA_API_KEY_FIELD)
|
||||
if raw_key:
|
||||
data[_QA_API_KEY_FIELD] = mask_key(raw_key)
|
||||
for field in secret_fields:
|
||||
raw_key = data.get(field)
|
||||
if raw_key:
|
||||
data[field] = mask_key(raw_key)
|
||||
return masked
|
||||
|
||||
|
||||
def merge_workflow_api_keys(
|
||||
incoming_definition: Optional[Dict], existing_definition: Optional[Dict]
|
||||
) -> Optional[Dict]:
|
||||
"""Preserve real QA-node API keys when the incoming value is a masked placeholder.
|
||||
|
||||
For each QA node in *incoming_definition*, if its ``qa_api_key`` equals
|
||||
the masked form of the corresponding node in *existing_definition*, the
|
||||
real key is restored so it is never lost.
|
||||
"""
|
||||
"""Preserve real node secret fields when the incoming value is masked."""
|
||||
if not incoming_definition or not existing_definition:
|
||||
return incoming_definition
|
||||
|
||||
# Build lookup: node-id → data for existing QA nodes
|
||||
existing_qa: Dict[str, Dict] = {}
|
||||
existing_nodes: Dict[str, Dict] = {}
|
||||
for node in existing_definition.get("nodes", []):
|
||||
if node.get("type") == "qa":
|
||||
existing_qa[node["id"]] = node.get("data", {})
|
||||
if _secret_fields_for_node_type(node.get("type")):
|
||||
existing_nodes[node["id"]] = node.get("data", {})
|
||||
|
||||
for node in incoming_definition.get("nodes", []):
|
||||
if node.get("type") != "qa":
|
||||
secret_fields = _secret_fields_for_node_type(node.get("type"))
|
||||
if not secret_fields:
|
||||
continue
|
||||
data = node.get("data", {})
|
||||
incoming_key = data.get(_QA_API_KEY_FIELD)
|
||||
if not incoming_key:
|
||||
continue
|
||||
|
||||
old_data = existing_qa.get(node["id"])
|
||||
old_data = existing_nodes.get(node["id"])
|
||||
if not old_data:
|
||||
continue
|
||||
|
||||
old_key = old_data.get(_QA_API_KEY_FIELD, "")
|
||||
if old_key and is_mask_of(incoming_key, old_key):
|
||||
data[_QA_API_KEY_FIELD] = old_key
|
||||
for field in secret_fields:
|
||||
incoming_key = data.get(field)
|
||||
if not incoming_key:
|
||||
continue
|
||||
|
||||
old_key = old_data.get(field, "")
|
||||
if old_key and is_mask_of(incoming_key, old_key):
|
||||
data[field] = old_key
|
||||
|
||||
return incoming_definition
|
||||
|
|
|
|||
|
|
@ -975,29 +975,67 @@ class SarvamSTTConfiguration(BaseSTTConfiguration):
|
|||
|
||||
# Speechmatics STT Service
|
||||
SPEECHMATICS_STT_LANGUAGES = [
|
||||
"en",
|
||||
"es",
|
||||
"fr",
|
||||
"de",
|
||||
"it",
|
||||
"pt",
|
||||
"ar",
|
||||
"ar_en",
|
||||
"ba",
|
||||
"eu",
|
||||
"be",
|
||||
"bn",
|
||||
"bg",
|
||||
"yue",
|
||||
"ca",
|
||||
"hr",
|
||||
"cs",
|
||||
"da",
|
||||
"nl",
|
||||
"en",
|
||||
"eo",
|
||||
"et",
|
||||
"fi",
|
||||
"fr",
|
||||
"gl",
|
||||
"de",
|
||||
"el",
|
||||
"he",
|
||||
"hi",
|
||||
"hu",
|
||||
"id",
|
||||
"ia",
|
||||
"ga",
|
||||
"it",
|
||||
"ja",
|
||||
"ko",
|
||||
"zh",
|
||||
"ru",
|
||||
"ar",
|
||||
"hi",
|
||||
"pl",
|
||||
"tr",
|
||||
"vi",
|
||||
"th",
|
||||
"id",
|
||||
"lv",
|
||||
"lt",
|
||||
"ms",
|
||||
"sv",
|
||||
"da",
|
||||
"en_ms",
|
||||
"mt",
|
||||
"cmn",
|
||||
"cmn_en",
|
||||
"cmn_en_ms_ta",
|
||||
"mr",
|
||||
"mn",
|
||||
"no",
|
||||
"fi",
|
||||
"fa",
|
||||
"pl",
|
||||
"pt",
|
||||
"ro",
|
||||
"ru",
|
||||
"sk",
|
||||
"sl",
|
||||
"es",
|
||||
"sw",
|
||||
"sv",
|
||||
"tl",
|
||||
"ta",
|
||||
"en_ta",
|
||||
"th",
|
||||
"tr",
|
||||
"uk",
|
||||
"ur",
|
||||
"ug",
|
||||
"vi",
|
||||
"cy",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
239
api/services/integrations/AGENTS.md
Normal file
239
api/services/integrations/AGENTS.md
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
# Integrations - Plugin Contract
|
||||
|
||||
`api/services/integrations/` is the extension seam for third-party integrations.
|
||||
New integrations should be self-contained here. Do not bleed integration-specific
|
||||
logic into `workflow/dto.py`, `workflow/node_specs/`, `run_pipeline.py`,
|
||||
`event_handlers.py`, or `run_integrations.py` unless you are changing the generic
|
||||
framework itself.
|
||||
|
||||
## Golden Path
|
||||
|
||||
Create a package:
|
||||
|
||||
```text
|
||||
api/services/integrations/<name>/
|
||||
├── __init__.py
|
||||
├── node.py
|
||||
├── runtime.py # optional
|
||||
├── completion.py # optional
|
||||
├── routes.py # optional
|
||||
└── client.py # optional
|
||||
```
|
||||
|
||||
The package self-registers on import via `register_package(...)`. Discovery is
|
||||
automatic: `api/services/integrations/loader.py` imports every submodule under
|
||||
`api.services.integrations` except the reserved internal names `base`, `loader`,
|
||||
and `registry`.
|
||||
|
||||
## Registration Pattern
|
||||
|
||||
`__init__.py` should register one `IntegrationPackageSpec`, following the
|
||||
existing integration packages in this directory.
|
||||
|
||||
Use:
|
||||
|
||||
```python
|
||||
PACKAGE = register_package(
|
||||
IntegrationPackageSpec(
|
||||
name="<package_name>",
|
||||
nodes=(NODE,),
|
||||
create_runtime_sessions=create_runtime_sessions, # optional
|
||||
run_completion=run_completion, # optional
|
||||
routers=(router,), # optional
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
The package name is the registry key. The node `type_name` is the workflow node
|
||||
type string and must stay stable once exposed.
|
||||
|
||||
## Node Model + Spec
|
||||
|
||||
For integration nodes, the Pydantic model is the source of truth. The serialized
|
||||
`NodeSpec` is derived from it.
|
||||
|
||||
Refer to an existing integration node for the overall structure:
|
||||
|
||||
- Define one Pydantic model per node, inheriting
|
||||
`api/services/workflow/node_data.py:BaseNodeData`.
|
||||
- Annotate it with `@node_spec(...)`.
|
||||
- Define fields with `spec_field(...)`.
|
||||
- Generate the external spec with `SPEC = build_spec(ModelClass)`.
|
||||
- Register the node with `IntegrationNodeRegistration(...)`.
|
||||
|
||||
Important rules:
|
||||
|
||||
- Put runtime validation in the model, not in the generated spec.
|
||||
Example: conditional requiredness belongs in `@model_validator(mode="after")`.
|
||||
- Keep `@node_spec(name=...)` and `IntegrationNodeRegistration.type_name`
|
||||
identical. They are the same workflow node type string.
|
||||
- Put wire constraints in the field itself where possible.
|
||||
Example: `gt=0`, `min_length=1`, `pattern=...`.
|
||||
- Put UI/export-only differences in `field_overrides`.
|
||||
Use this for `display_name`, `description`, `required`, `spec_default`,
|
||||
`display_options`, or property ordering.
|
||||
- Use `spec_exclude=True` for internal fields that must exist in persisted data
|
||||
but must not show up in `/api/v1/node-types`.
|
||||
- Set `property_order=(...)` in `@node_spec(...)` when the editor field order
|
||||
must remain stable.
|
||||
|
||||
Typical workflow graph constraints for configuration-only integration nodes:
|
||||
|
||||
```python
|
||||
GraphConstraints(min_incoming=0, max_incoming=0, min_outgoing=0, max_outgoing=0)
|
||||
```
|
||||
|
||||
These constraints control how the node can be connected in the workflow graph.
|
||||
Use them for configuration nodes that are not conversational graph steps.
|
||||
|
||||
## Secret Fields
|
||||
|
||||
If the node stores secrets, register them in
|
||||
`IntegrationNodeRegistration.sensitive_fields`.
|
||||
|
||||
That is enough for generic masking / masked round-trip preservation via
|
||||
`api/services/configuration/masking.py`. Do not add new integration-specific
|
||||
masking branches unless you are changing the shared masking framework.
|
||||
|
||||
## No Central DTO Edits
|
||||
|
||||
Do not add integration node classes to `api/services/workflow/dto.py`.
|
||||
|
||||
Integration nodes are resolved dynamically through:
|
||||
|
||||
- `get_node_data_model()` in `workflow/dto.py`
|
||||
- `get_node_spec()` / `all_node_specs()` in `services/integrations/registry.py`
|
||||
|
||||
`RFNodeDTO` validates integration nodes by `type` through the registry. That is
|
||||
the intended extension path.
|
||||
|
||||
## Live Call Path
|
||||
|
||||
If the integration needs live call data, implement `create_runtime_sessions(...)`
|
||||
in `runtime.py` and return `IntegrationRuntimeSession` objects.
|
||||
|
||||
The generic wiring is already in `api/services/pipecat/run_pipeline.py`:
|
||||
|
||||
- `create_runtime_sessions(IntegrationRuntimeContext(...))` is called before the
|
||||
pipeline task starts.
|
||||
- Each returned session gets `session.attach(task)` called.
|
||||
|
||||
Use this only for lightweight live collection:
|
||||
|
||||
- attach task observers
|
||||
- read context messages
|
||||
- capture timing / turn / tool events
|
||||
- build an in-memory snapshot
|
||||
|
||||
Do not do outbound network I/O in the live path unless there is a very strong
|
||||
reason. Prefer the standard pattern: collect live, deliver after the call.
|
||||
|
||||
`IntegrationRuntimeContext` gives you:
|
||||
|
||||
- `workflow_run_id`
|
||||
- `workflow_run`
|
||||
- `workflow_graph`
|
||||
- `run_definition`
|
||||
- `user_config`
|
||||
- `is_realtime`
|
||||
- `context_messages_provider`
|
||||
|
||||
Typical runtime pattern:
|
||||
|
||||
- scan `context.workflow_graph.nodes.values()` for enabled nodes of your type
|
||||
- if none are enabled, return `[]`
|
||||
- build one collector/session per workflow run, not per node, unless the
|
||||
integration truly needs multiple independent collectors
|
||||
|
||||
## Call-Finish Snapshot Path
|
||||
|
||||
`api/services/pipecat/event_handlers.py` finalizes runtime sessions before the
|
||||
engine is cleaned up.
|
||||
|
||||
The generic flow:
|
||||
|
||||
1. `on_pipeline_finished` builds `gathered_context`
|
||||
2. each runtime session gets `await session.on_call_finished(...)`
|
||||
3. returned dicts are merged into `integration_logs`
|
||||
4. those logs are persisted into `workflow_run.logs`
|
||||
|
||||
Use `on_call_finished(...)` to emit a compact, serializable snapshot that the
|
||||
post-call completion handler can consume later. Return `None` if there is nothing
|
||||
to persist.
|
||||
|
||||
This is the handoff between the live call path and the post-call task path.
|
||||
|
||||
## Post-Call Completion Path
|
||||
|
||||
If the integration needs durable artifacts, public URLs, retries, or external
|
||||
delivery, implement `run_completion(nodes, context)` in `completion.py`.
|
||||
|
||||
The generic orchestration is already in `api/tasks/run_integrations.py`:
|
||||
|
||||
1. load the pinned workflow definition from the workflow run
|
||||
2. create a public token if post-call work exists
|
||||
3. run QA nodes first
|
||||
4. run registered integration completion handlers
|
||||
5. run webhook nodes last
|
||||
|
||||
Your handler receives:
|
||||
|
||||
- `nodes`: raw workflow node dicts for your node types only
|
||||
- `IntegrationCompletionContext`:
|
||||
- `workflow_run_id`
|
||||
- `workflow_run`
|
||||
- `workflow_definition`
|
||||
- `definition_id`
|
||||
- `organization_id`
|
||||
- `public_token`
|
||||
|
||||
Expected completion handler pattern:
|
||||
|
||||
- validate each node with `YourNodeData.model_validate(node.get("data", {}))`
|
||||
- skip disabled nodes
|
||||
- read any runtime snapshot from `context.workflow_run.logs`
|
||||
- build durable URLs using `public_token` when appropriate
|
||||
- perform external delivery
|
||||
- return a result dict keyed per node, usually with `node_id` embedded
|
||||
|
||||
Returned data is merged into `workflow_run.annotations`.
|
||||
|
||||
Do not assume completion runs inside the live pipeline process. Treat it as a
|
||||
separate post-call worker step.
|
||||
|
||||
## Optional Routes
|
||||
|
||||
If an integration exposes HTTP routes, put them in `routes.py` and include the
|
||||
router in `IntegrationPackageSpec.routers`.
|
||||
|
||||
Routers are mounted automatically by `api/routes/main.py` through `all_routers()`.
|
||||
Do not edit `routes/main.py` for per-integration route wiring.
|
||||
|
||||
## Import Discipline
|
||||
|
||||
Keep package import side effects light.
|
||||
|
||||
The integration loader runs during:
|
||||
|
||||
- node-type/spec enumeration
|
||||
- tests
|
||||
- route startup
|
||||
- registry access
|
||||
|
||||
So avoid top-level imports that require environment variables, network access,
|
||||
or heavyweight initialization when possible. Prefer lazy imports inside
|
||||
`run_completion()` / `create_runtime_sessions()` if the dependency is optional or
|
||||
environment-sensitive.
|
||||
|
||||
## Testing Expectations
|
||||
|
||||
At minimum, new integrations should add coverage for:
|
||||
|
||||
- node model validation
|
||||
- generated spec/example validity
|
||||
- secret masking + masked round-trip preservation if secrets exist
|
||||
- runtime snapshot creation if live collectors exist
|
||||
- completion handler happy path and disabled-node skip path
|
||||
|
||||
If you change shared integration machinery, test the framework in the generic
|
||||
code path, not only the concrete integration.
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
from api.services.integrations.base import (
|
||||
IntegrationCompletionContext,
|
||||
IntegrationNodeRegistration,
|
||||
IntegrationPackageSpec,
|
||||
IntegrationRuntimeContext,
|
||||
IntegrationRuntimeSession,
|
||||
)
|
||||
from api.services.integrations.registry import (
|
||||
all_node_specs,
|
||||
all_packages,
|
||||
all_routers,
|
||||
create_runtime_sessions,
|
||||
get_node_data_model,
|
||||
get_node_registration,
|
||||
get_node_secret_fields,
|
||||
get_node_spec,
|
||||
has_completion_handlers,
|
||||
register_package,
|
||||
run_completion_handlers,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"IntegrationCompletionContext",
|
||||
"IntegrationNodeRegistration",
|
||||
"IntegrationPackageSpec",
|
||||
"IntegrationRuntimeContext",
|
||||
"IntegrationRuntimeSession",
|
||||
"all_node_specs",
|
||||
"all_packages",
|
||||
"all_routers",
|
||||
"create_runtime_sessions",
|
||||
"get_node_data_model",
|
||||
"get_node_registration",
|
||||
"get_node_secret_fields",
|
||||
"get_node_spec",
|
||||
"has_completion_handlers",
|
||||
"register_package",
|
||||
"run_completion_handlers",
|
||||
]
|
||||
69
api/services/integrations/base.py
Normal file
69
api/services/integrations/base.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Protocol
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from api.services.workflow.node_data import BaseNodeData
|
||||
from api.services.workflow.node_specs._base import NodeSpec
|
||||
|
||||
|
||||
class IntegrationRuntimeSession(Protocol):
|
||||
name: str
|
||||
|
||||
def attach(self, task: Any) -> None: ...
|
||||
|
||||
async def on_call_finished(
|
||||
self,
|
||||
*,
|
||||
gathered_context: dict[str, Any],
|
||||
) -> dict[str, Any] | None: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IntegrationRuntimeContext:
|
||||
workflow_run_id: int
|
||||
workflow_run: Any
|
||||
workflow_graph: Any
|
||||
run_definition: Any
|
||||
user_config: Any
|
||||
is_realtime: bool
|
||||
context_messages_provider: Callable[[], list[dict[str, Any]]]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IntegrationCompletionContext:
|
||||
workflow_run_id: int
|
||||
workflow_run: Any
|
||||
workflow_definition: dict[str, Any]
|
||||
definition_id: int | None
|
||||
organization_id: int
|
||||
public_token: str | None
|
||||
|
||||
|
||||
RuntimeFactory = Callable[
|
||||
[IntegrationRuntimeContext],
|
||||
list[IntegrationRuntimeSession],
|
||||
]
|
||||
CompletionHandler = Callable[
|
||||
[list[dict[str, Any]], IntegrationCompletionContext],
|
||||
Awaitable[dict[str, Any]],
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IntegrationNodeRegistration:
|
||||
type_name: str
|
||||
data_model: type[BaseNodeData]
|
||||
node_spec: NodeSpec
|
||||
sensitive_fields: tuple[str, ...] = ()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IntegrationPackageSpec:
|
||||
name: str
|
||||
nodes: tuple[IntegrationNodeRegistration, ...] = ()
|
||||
routers: tuple[APIRouter, ...] = ()
|
||||
create_runtime_sessions: RuntimeFactory | None = None
|
||||
run_completion: CompletionHandler | None = None
|
||||
21
api/services/integrations/loader.py
Normal file
21
api/services/integrations/loader.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import pkgutil
|
||||
|
||||
_INTERNAL_MODULES = {"base", "loader", "registry"}
|
||||
_loaded = False
|
||||
|
||||
|
||||
def ensure_integrations_loaded() -> None:
|
||||
global _loaded
|
||||
if _loaded:
|
||||
return
|
||||
|
||||
package = importlib.import_module("api.services.integrations")
|
||||
for module_info in pkgutil.iter_modules(package.__path__):
|
||||
if module_info.name in _INTERNAL_MODULES:
|
||||
continue
|
||||
importlib.import_module(f"{package.__name__}.{module_info.name}")
|
||||
|
||||
_loaded = True
|
||||
|
|
@ -1,253 +0,0 @@
|
|||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
NANGO_ALLOWED_INTEGRATIONS = [
|
||||
i.strip() for i in os.environ.get("NANGO_ALLOWED_INTEGRATIONS", "slack").split(",")
|
||||
]
|
||||
|
||||
|
||||
class NangoWebhookRequest(BaseModel):
|
||||
type: str
|
||||
connectionId: str
|
||||
providerConfigKey: str
|
||||
authMode: str
|
||||
provider: str
|
||||
environment: str
|
||||
operation: str
|
||||
endUser: dict # Contains endUserId and organizationId
|
||||
success: bool
|
||||
|
||||
|
||||
class NangoService:
|
||||
def __init__(self):
|
||||
self.base_url = "https://api.nango.dev"
|
||||
self.secret_key = os.getenv("NANGO_API_KEY")
|
||||
|
||||
def _verify_webhook_signature(
|
||||
self, request_body: str, signature: str = None
|
||||
) -> bool:
|
||||
"""
|
||||
Verify the webhook signature using SHA256 hash.
|
||||
|
||||
Args:
|
||||
request_body: The raw request body as string
|
||||
signature: The signature from request headers (optional for now)
|
||||
|
||||
Returns:
|
||||
True if signature is valid
|
||||
"""
|
||||
expected_signature = self.secret_key + request_body
|
||||
expected_hash = hashlib.sha256(expected_signature.encode("utf-8")).hexdigest()
|
||||
return expected_hash == signature
|
||||
|
||||
async def create_session(
|
||||
self, user_id: str, organization_id: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a Nango session for the given user and organization.
|
||||
|
||||
Args:
|
||||
user_id: The end user ID
|
||||
organization_id: The organization ID
|
||||
|
||||
Returns:
|
||||
Response from Nango API
|
||||
"""
|
||||
if not self.secret_key:
|
||||
raise ValueError("NANGO_SECRET_KEY environment variable is not set")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.secret_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"end_user": {"id": user_id},
|
||||
"organization": {"id": str(organization_id)},
|
||||
"allowed_integrations": NANGO_ALLOWED_INTEGRATIONS,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/connect/sessions", headers=headers, json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 201:
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Nango API error: {response.status_code}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def process_webhook(
|
||||
self, raw_body: bytes, signature: str = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Process incoming Nango webhook request.
|
||||
|
||||
Args:
|
||||
raw_body: The raw request body as bytes
|
||||
signature: Optional signature from request headers
|
||||
|
||||
Returns:
|
||||
Dict with status and message
|
||||
"""
|
||||
# Decode and parse the request body
|
||||
try:
|
||||
body_text = raw_body.decode("utf-8")
|
||||
webhook_json = json.loads(body_text) if body_text else {}
|
||||
logger.debug(f"received webhook from nango: {webhook_json}")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON decode error: {e} body_text: {body_text}")
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
|
||||
|
||||
# Verify webhook signature
|
||||
if not self._verify_webhook_signature(body_text, signature):
|
||||
raise HTTPException(status_code=401, detail="Invalid webhook signature")
|
||||
|
||||
# Parse webhook data
|
||||
try:
|
||||
webhook_data = NangoWebhookRequest(**webhook_json)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse webhook data: {e}")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid webhook format: {str(e)}"
|
||||
)
|
||||
|
||||
# Extract user and organization IDs from the webhook payload
|
||||
end_user = webhook_data.endUser
|
||||
if (
|
||||
not end_user
|
||||
or "endUserId" not in end_user
|
||||
or "organizationId" not in end_user
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing endUser information in webhook payload"
|
||||
)
|
||||
|
||||
user_id = int(end_user["endUserId"])
|
||||
organization_id = int(end_user["organizationId"])
|
||||
|
||||
# Use the connectionId as the integration_id since it's unique per integration
|
||||
integration_id = webhook_data.connectionId
|
||||
|
||||
# Initialize connection_details
|
||||
connection_details = {}
|
||||
|
||||
# Fetch connection details if type is auth and provider is slack
|
||||
if webhook_data.type == "auth":
|
||||
connection_details = await self._fetch_connection_details(
|
||||
integration_id, webhook_data.provider
|
||||
)
|
||||
|
||||
# Create the integration in the database
|
||||
integration = await db_client.create_integration(
|
||||
integration_id=integration_id,
|
||||
organisation_id=organization_id,
|
||||
provider=webhook_data.provider,
|
||||
created_by=user_id,
|
||||
is_active=True,
|
||||
connection_details=connection_details,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Integration created successfully with ID: {integration.id}",
|
||||
}
|
||||
|
||||
async def _fetch_connection_details(
|
||||
self, connection_id: str, provider_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetch connection details from Nango API for a given connection ID.
|
||||
|
||||
Args:
|
||||
connection_id: The connection ID from the webhook
|
||||
|
||||
Returns:
|
||||
Connection details as a dictionary
|
||||
"""
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.secret_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
url = f"{self.base_url}/connection/{connection_id}/?provider_config_key={provider_key}"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"Failed to fetch connection details: {response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Nango API error while fetching connection: {response.status_code}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
connection_details = response.json()
|
||||
return connection_details
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error while fetching connection details: {e}")
|
||||
# Return empty dict if API call fails, but log the error
|
||||
return {}
|
||||
|
||||
async def get_access_token(
|
||||
self, connection_id: str, provider_config_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the latest access token for a connection from Nango.
|
||||
|
||||
Args:
|
||||
connection_id: The connection ID
|
||||
provider_config_key: The provider config key (e.g., 'google-sheet')
|
||||
|
||||
Returns:
|
||||
Dict containing access token and other connection details
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.secret_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
url = f"{self.base_url}/connection/{connection_id}?provider_config_key={provider_config_key}"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"Failed to get access token: {response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Nango API error: {response.status_code}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error while getting access token: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance
|
||||
nango_service = NangoService()
|
||||
128
api/services/integrations/registry.py
Normal file
128
api/services/integrations/registry.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from api.services.integrations.base import (
|
||||
IntegrationCompletionContext,
|
||||
IntegrationNodeRegistration,
|
||||
IntegrationPackageSpec,
|
||||
IntegrationRuntimeContext,
|
||||
)
|
||||
from api.services.workflow.node_data import BaseNodeData
|
||||
|
||||
_PACKAGE_REGISTRY: dict[str, IntegrationPackageSpec] = {}
|
||||
|
||||
|
||||
def register_package(spec: IntegrationPackageSpec) -> IntegrationPackageSpec:
|
||||
existing = _PACKAGE_REGISTRY.get(spec.name)
|
||||
if existing is not None and existing is not spec:
|
||||
raise ValueError(
|
||||
f"Duplicate integration package registration for {spec.name!r}"
|
||||
)
|
||||
_PACKAGE_REGISTRY[spec.name] = spec
|
||||
return spec
|
||||
|
||||
|
||||
def _ensure_loaded() -> None:
|
||||
from api.services.integrations.loader import ensure_integrations_loaded
|
||||
|
||||
ensure_integrations_loaded()
|
||||
|
||||
|
||||
def all_packages() -> list[IntegrationPackageSpec]:
|
||||
_ensure_loaded()
|
||||
return [_PACKAGE_REGISTRY[name] for name in sorted(_PACKAGE_REGISTRY)]
|
||||
|
||||
|
||||
def get_package(name: str) -> IntegrationPackageSpec | None:
|
||||
_ensure_loaded()
|
||||
return _PACKAGE_REGISTRY.get(name)
|
||||
|
||||
|
||||
def get_node_registration(type_name: str) -> IntegrationNodeRegistration | None:
|
||||
_ensure_loaded()
|
||||
for package in _PACKAGE_REGISTRY.values():
|
||||
for node in package.nodes:
|
||||
if node.type_name == type_name:
|
||||
return node
|
||||
return None
|
||||
|
||||
|
||||
def get_node_data_model(type_name: str) -> type[BaseNodeData] | None:
|
||||
registration = get_node_registration(type_name)
|
||||
return registration.data_model if registration else None
|
||||
|
||||
|
||||
def get_node_spec(type_name: str):
|
||||
registration = get_node_registration(type_name)
|
||||
return registration.node_spec if registration else None
|
||||
|
||||
|
||||
def get_node_secret_fields(type_name: str) -> tuple[str, ...]:
|
||||
registration = get_node_registration(type_name)
|
||||
return registration.sensitive_fields if registration else ()
|
||||
|
||||
|
||||
def all_node_specs():
|
||||
_ensure_loaded()
|
||||
specs = []
|
||||
for package in all_packages():
|
||||
specs.extend(node.node_spec for node in package.nodes)
|
||||
return specs
|
||||
|
||||
|
||||
def all_routers():
|
||||
_ensure_loaded()
|
||||
routers = []
|
||||
for package in all_packages():
|
||||
routers.extend(package.routers)
|
||||
return routers
|
||||
|
||||
|
||||
def create_runtime_sessions(
|
||||
context: IntegrationRuntimeContext,
|
||||
):
|
||||
_ensure_loaded()
|
||||
sessions = []
|
||||
for package in all_packages():
|
||||
if package.create_runtime_sessions is None:
|
||||
continue
|
||||
sessions.extend(package.create_runtime_sessions(context))
|
||||
return sessions
|
||||
|
||||
|
||||
def iter_completion_packages(
|
||||
workflow_definition: dict[str, Any],
|
||||
):
|
||||
_ensure_loaded()
|
||||
nodes = workflow_definition.get("nodes", []) if workflow_definition else []
|
||||
for package in all_packages():
|
||||
node_types = {node.type_name for node in package.nodes}
|
||||
package_nodes = [
|
||||
node
|
||||
for node in nodes
|
||||
if isinstance(node, dict) and node.get("type") in node_types
|
||||
]
|
||||
if package_nodes:
|
||||
yield package, package_nodes
|
||||
|
||||
|
||||
def has_completion_handlers(workflow_definition: dict[str, Any]) -> bool:
|
||||
return any(
|
||||
package.run_completion is not None
|
||||
for package, _nodes in iter_completion_packages(workflow_definition)
|
||||
)
|
||||
|
||||
|
||||
async def run_completion_handlers(
|
||||
*,
|
||||
context: IntegrationCompletionContext,
|
||||
) -> dict[str, Any]:
|
||||
results: dict[str, Any] = {}
|
||||
for package, nodes in iter_completion_packages(context.workflow_definition):
|
||||
if package.run_completion is None:
|
||||
continue
|
||||
package_result = await package.run_completion(nodes, context)
|
||||
if package_result:
|
||||
results.update(package_result)
|
||||
return results
|
||||
19
api/services/integrations/tuner/__init__.py
Normal file
19
api/services/integrations/tuner/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from api.services.integrations.base import IntegrationPackageSpec
|
||||
from api.services.integrations.registry import register_package
|
||||
|
||||
from .completion import run_completion
|
||||
from .node import NODE
|
||||
from .runtime import create_runtime_sessions
|
||||
|
||||
PACKAGE = register_package(
|
||||
IntegrationPackageSpec(
|
||||
name="tuner",
|
||||
nodes=(NODE,),
|
||||
create_runtime_sessions=create_runtime_sessions,
|
||||
run_completion=run_completion,
|
||||
)
|
||||
)
|
||||
|
||||
__all__ = ["PACKAGE"]
|
||||
71
api/services/integrations/tuner/client.py
Normal file
71
api/services/integrations/tuner/client.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class TunerDeliveryConfig(BaseModel):
|
||||
base_url: str
|
||||
api_key: str
|
||||
workspace_id: int
|
||||
agent_id: str
|
||||
|
||||
@field_validator("api_key", "agent_id")
|
||||
@classmethod
|
||||
def _must_not_be_empty(cls, value: str) -> str:
|
||||
if not value or not value.strip():
|
||||
raise ValueError("must not be empty")
|
||||
return value
|
||||
|
||||
@field_validator("workspace_id")
|
||||
@classmethod
|
||||
def _workspace_must_be_positive(cls, value: int) -> int:
|
||||
if value <= 0:
|
||||
raise ValueError("must be a positive integer")
|
||||
return value
|
||||
|
||||
|
||||
async def post_call(
|
||||
config: TunerDeliveryConfig,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
url = (
|
||||
f"{config.base_url}/api/v1/public/call"
|
||||
f"?workspace_id={config.workspace_id}"
|
||||
f"&agent_remote_identifier={config.agent_id}"
|
||||
)
|
||||
headers = {"Authorization": f"Bearer {config.api_key}"}
|
||||
|
||||
logger.info(
|
||||
"[tuner] posting completed call {} to workspace {} / agent {}",
|
||||
payload.get("call_id"),
|
||||
config.workspace_id,
|
||||
config.agent_id,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code == 409:
|
||||
logger.info("[tuner] call {} already exists in tuner", payload.get("call_id"))
|
||||
return {"status": "duplicate", "status_code": response.status_code}
|
||||
|
||||
if response.status_code >= 400:
|
||||
logger.error(
|
||||
"[tuner] POST failed for call {} with status {}: {}",
|
||||
payload.get("call_id"),
|
||||
response.status_code,
|
||||
response.text[:200],
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
logger.info(
|
||||
"[tuner] POST succeeded for call {} with status {}",
|
||||
payload.get("call_id"),
|
||||
response.status_code,
|
||||
)
|
||||
return {"status": "delivered", "status_code": response.status_code}
|
||||
191
api/services/integrations/tuner/collector.py
Normal file
191
api/services/integrations/tuner/collector.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
MetricsFrame,
|
||||
StartFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.observers.turn_tracking_observer import TurnTrackingObserver
|
||||
from pipecat.observers.user_bot_latency_observer import UserBotLatencyObserver
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.utils.context.message_sanitization import strip_thought_ids_from_messages
|
||||
from tuner_pipecat_sdk.accumulator import CallAccumulator
|
||||
from tuner_pipecat_sdk.payload_builder import build_payload
|
||||
|
||||
from api.enums import WorkflowRunMode
|
||||
|
||||
TUNER_RECORDING_PLACEHOLDER = "pipecat://no-recording"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _PayloadConfig:
|
||||
call_id: str
|
||||
call_type: str
|
||||
recording_url: str
|
||||
asr_model: str
|
||||
llm_model: str
|
||||
tts_model: str
|
||||
sip_call_id: str | None = None
|
||||
sip_headers: dict[str, str] | None = None
|
||||
agent_version: int | None = None
|
||||
|
||||
|
||||
def mode_to_tuner_call_type(mode: str | None) -> str:
|
||||
if mode in {
|
||||
WorkflowRunMode.WEBRTC.value,
|
||||
WorkflowRunMode.SMALLWEBRTC.value,
|
||||
}:
|
||||
return "web_call"
|
||||
return "phone_call"
|
||||
|
||||
|
||||
class TunerCollector(BaseObserver):
|
||||
"""Collect runtime call metadata and build a deferred Tuner payload."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
workflow_run_id: int,
|
||||
call_type: str,
|
||||
asr_model: str = "",
|
||||
llm_model: str = "",
|
||||
tts_model: str = "",
|
||||
agent_version: int | None = None,
|
||||
max_frames: int = 500,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._call_id = str(workflow_run_id)
|
||||
self._call_type = call_type
|
||||
self._asr_model = asr_model
|
||||
self._llm_model = llm_model
|
||||
self._tts_model = tts_model
|
||||
self._agent_version = agent_version
|
||||
self._acc = CallAccumulator()
|
||||
self._acc.call_start_abs_ns = time.time_ns()
|
||||
self._pipeline_start_rel_ns: int | None = None
|
||||
self._context_provider: Callable[[], list[dict[str, Any]]] | None = None
|
||||
self._processed_frames: set[int] = set()
|
||||
self._frame_history: deque[int] = deque(maxlen=max_frames)
|
||||
|
||||
def attach_context(self, provider: Callable[[], list[dict[str, Any]]]) -> None:
|
||||
self._context_provider = provider
|
||||
|
||||
def set_disconnection_reason(self, reason: str | None) -> None:
|
||||
if reason:
|
||||
self._acc.set_disconnection_reason(reason)
|
||||
|
||||
def attach_turn_tracking_observer(
|
||||
self, turn_tracker: TurnTrackingObserver | None
|
||||
) -> None:
|
||||
if turn_tracker is None:
|
||||
return
|
||||
|
||||
@turn_tracker.event_handler("on_turn_started")
|
||||
async def _on_turn_started(_tracker: Any, turn_number: int) -> None:
|
||||
self._acc.on_turn_started(turn_number, time.time_ns())
|
||||
|
||||
@turn_tracker.event_handler("on_turn_ended")
|
||||
async def _on_turn_ended(
|
||||
_tracker: Any, turn_number: int, _duration: float, was_interrupted: bool
|
||||
) -> None:
|
||||
self._acc.on_turn_ended(turn_number, was_interrupted)
|
||||
|
||||
def attach_latency_observer(
|
||||
self, latency_observer: UserBotLatencyObserver | None
|
||||
) -> None:
|
||||
if latency_observer is None:
|
||||
return
|
||||
|
||||
@latency_observer.event_handler("on_latency_measured")
|
||||
async def _on_latency_measured(_observer: Any, latency: float) -> None:
|
||||
self._acc.on_latency_measured(latency)
|
||||
|
||||
@latency_observer.event_handler("on_latency_breakdown")
|
||||
async def _on_latency_breakdown(_observer: Any, breakdown: Any) -> None:
|
||||
self._acc.on_latency_breakdown(breakdown)
|
||||
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
if data.direction != FrameDirection.DOWNSTREAM:
|
||||
return
|
||||
|
||||
if data.frame.id in self._processed_frames:
|
||||
return
|
||||
|
||||
self._processed_frames.add(data.frame.id)
|
||||
self._frame_history.append(data.frame.id)
|
||||
if len(self._processed_frames) > len(self._frame_history):
|
||||
self._processed_frames = set(self._frame_history)
|
||||
|
||||
frame = data.frame
|
||||
|
||||
# data.timestamp is a pipeline-relative clock (ns since pipeline start).
|
||||
# Convert to absolute ns so the accumulator's _rel_ms() works correctly.
|
||||
if self._pipeline_start_rel_ns is None:
|
||||
self._pipeline_start_rel_ns = data.timestamp
|
||||
timestamp_ns = self._acc.call_start_abs_ns + (
|
||||
data.timestamp - self._pipeline_start_rel_ns
|
||||
)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
self._acc.on_start(timestamp_ns)
|
||||
elif isinstance(frame, FunctionCallInProgressFrame):
|
||||
self._acc.on_function_call_in_progress(frame, timestamp_ns)
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
self._acc.on_function_call_result(frame.tool_call_id, timestamp_ns)
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
self._acc.on_metrics_frame(frame)
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
self._acc.on_user_started_speaking(timestamp_ns)
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
self._acc.on_user_stopped_speaking(timestamp_ns)
|
||||
self._acc.on_user_turn_stopped(timestamp_ns)
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
self._acc.on_bot_started_speaking(timestamp_ns)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._acc.on_bot_stopped(timestamp_ns)
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
self._acc.on_vad_stopped(timestamp_ns)
|
||||
elif isinstance(frame, (CancelFrame, EndFrame)):
|
||||
self._acc.on_call_end(timestamp_ns)
|
||||
|
||||
def build_payload_snapshot(
|
||||
self,
|
||||
*,
|
||||
recording_url: str = TUNER_RECORDING_PLACEHOLDER,
|
||||
) -> dict[str, Any] | None:
|
||||
if self._context_provider is None:
|
||||
logger.warning(
|
||||
"[tuner] no context provider attached; skipping payload snapshot"
|
||||
)
|
||||
return None
|
||||
|
||||
transcript = strip_thought_ids_from_messages(list(self._context_provider()))
|
||||
payload = build_payload(
|
||||
self._acc,
|
||||
_PayloadConfig(
|
||||
call_id=self._call_id,
|
||||
call_type=self._call_type,
|
||||
recording_url=recording_url,
|
||||
asr_model=self._asr_model,
|
||||
llm_model=self._llm_model,
|
||||
tts_model=self._tts_model,
|
||||
agent_version=self._agent_version,
|
||||
),
|
||||
transcript,
|
||||
)
|
||||
return payload.to_dict()
|
||||
76
api/services/integrations/tuner/completion.py
Normal file
76
api/services/integrations/tuner/completion.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import BACKEND_API_ENDPOINT, TUNER_BASE_URL
|
||||
from api.services.integrations.base import IntegrationCompletionContext
|
||||
|
||||
from .client import TunerDeliveryConfig, post_call
|
||||
from .collector import TUNER_RECORDING_PLACEHOLDER
|
||||
from .node import TunerNodeData
|
||||
|
||||
|
||||
def _build_recording_url(
|
||||
context: IntegrationCompletionContext,
|
||||
) -> str | None:
|
||||
workflow_run = context.workflow_run
|
||||
if context.public_token:
|
||||
base_url = f"{BACKEND_API_ENDPOINT}/api/v1/public/download/workflow/{context.public_token}"
|
||||
return f"{base_url}/recording" if workflow_run.recording_url else None
|
||||
return workflow_run.recording_url
|
||||
|
||||
|
||||
async def run_completion(
|
||||
nodes: list[dict[str, Any]],
|
||||
context: IntegrationCompletionContext,
|
||||
) -> dict[str, Any]:
|
||||
results: dict[str, Any] = {}
|
||||
payload_snapshot = (context.workflow_run.logs or {}).get("tuner_payload")
|
||||
recording_url = _build_recording_url(context) or TUNER_RECORDING_PLACEHOLDER
|
||||
|
||||
for node in nodes:
|
||||
node_id = node.get("id", "unknown")
|
||||
try:
|
||||
tuner_data = TunerNodeData.model_validate(node.get("data", {}))
|
||||
except Exception as exc:
|
||||
logger.warning(f"Tuner node #{node_id} failed validation, skipping: {exc}")
|
||||
results[f"tuner_{node_id}"] = {"error": "validation_failed"}
|
||||
continue
|
||||
|
||||
if not tuner_data.tuner_enabled:
|
||||
logger.debug(f"Tuner node '{tuner_data.name}' is disabled, skipping")
|
||||
continue
|
||||
|
||||
if not payload_snapshot:
|
||||
logger.warning(
|
||||
f"Tuner payload snapshot missing for node '{tuner_data.name}' (#{node_id})"
|
||||
)
|
||||
results[f"tuner_{node_id}"] = {"error": "missing_payload_snapshot"}
|
||||
continue
|
||||
|
||||
payload = copy.deepcopy(payload_snapshot)
|
||||
payload["recording_url"] = recording_url
|
||||
|
||||
try:
|
||||
config = TunerDeliveryConfig(
|
||||
base_url=TUNER_BASE_URL,
|
||||
api_key=tuner_data.tuner_api_key or "",
|
||||
workspace_id=tuner_data.tuner_workspace_id or 0,
|
||||
agent_id=tuner_data.tuner_agent_id or "",
|
||||
)
|
||||
delivery = await post_call(config, payload)
|
||||
results[f"tuner_{node_id}"] = {
|
||||
**delivery,
|
||||
"workspace_id": tuner_data.tuner_workspace_id,
|
||||
"agent_id": tuner_data.tuner_agent_id,
|
||||
"exported_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.error(f"Tuner export failed for node '{tuner_data.name}': {exc}")
|
||||
results[f"tuner_{node_id}"] = {"error": str(exc)}
|
||||
|
||||
return results
|
||||
139
api/services/integrations/tuner/node.py
Normal file
139
api/services/integrations/tuner/node.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from api.services.integrations.base import IntegrationNodeRegistration
|
||||
from api.services.workflow.node_data import BaseNodeData
|
||||
from api.services.workflow.node_specs._base import (
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
PropertyType,
|
||||
)
|
||||
from api.services.workflow.node_specs.model_spec import (
|
||||
build_spec,
|
||||
node_spec,
|
||||
spec_field,
|
||||
)
|
||||
|
||||
|
||||
@node_spec(
|
||||
name="tuner",
|
||||
display_name="Tuner",
|
||||
description="Export the completed call to Tuner for Agent Observability",
|
||||
llm_hint=(
|
||||
"Tuner is a post-call observability export. It does not participate in the "
|
||||
"conversation graph and should not be connected to other nodes."
|
||||
),
|
||||
category=NodeCategory.integration,
|
||||
icon="Activity",
|
||||
examples=[
|
||||
NodeExample(
|
||||
name="tuner_export",
|
||||
data={
|
||||
"name": "Primary Tuner Export",
|
||||
"tuner_enabled": True,
|
||||
"tuner_agent_id": "sales-bot-prod",
|
||||
"tuner_workspace_id": 42,
|
||||
"tuner_api_key": "tuner_live_xxxxxxxx",
|
||||
},
|
||||
)
|
||||
],
|
||||
graph_constraints=GraphConstraints(
|
||||
min_incoming=0,
|
||||
max_incoming=0,
|
||||
min_outgoing=0,
|
||||
max_outgoing=0,
|
||||
),
|
||||
property_order=(
|
||||
"name",
|
||||
"tuner_enabled",
|
||||
"tuner_agent_id",
|
||||
"tuner_workspace_id",
|
||||
"tuner_api_key",
|
||||
),
|
||||
field_overrides={
|
||||
"name": {
|
||||
"spec_default": "Tuner",
|
||||
"description": "Short identifier for this Tuner export configuration.",
|
||||
},
|
||||
"tuner_enabled": {
|
||||
"display_name": "Enabled",
|
||||
"description": "When false, Dograh skips exporting this call to Tuner.",
|
||||
},
|
||||
"tuner_agent_id": {
|
||||
"display_name": "Tuner Agent ID",
|
||||
"description": "The agent identifier registered in your Tuner workspace.",
|
||||
"required": True,
|
||||
},
|
||||
"tuner_workspace_id": {
|
||||
"display_name": "Tuner Workspace ID",
|
||||
"description": "Your numeric Tuner workspace ID.",
|
||||
"required": True,
|
||||
"min_value": 1,
|
||||
},
|
||||
"tuner_api_key": {
|
||||
"display_name": "Tuner API Key",
|
||||
"description": "Bearer token used when posting completed calls to Tuner.",
|
||||
"required": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
class TunerNodeData(BaseNodeData):
|
||||
tuner_enabled: bool = spec_field(
|
||||
default=True,
|
||||
ui_type=PropertyType.boolean,
|
||||
display_name="Enabled",
|
||||
description="When false, Dograh skips exporting this call to Tuner.",
|
||||
)
|
||||
tuner_agent_id: str | None = spec_field(
|
||||
default=None,
|
||||
ui_type=PropertyType.string,
|
||||
display_name="Tuner Agent ID",
|
||||
description="The agent identifier registered in your Tuner workspace.",
|
||||
)
|
||||
tuner_workspace_id: int | None = spec_field(
|
||||
default=None,
|
||||
gt=0,
|
||||
ui_type=PropertyType.number,
|
||||
display_name="Tuner Workspace ID",
|
||||
description="Your numeric Tuner workspace ID.",
|
||||
)
|
||||
tuner_api_key: str | None = spec_field(
|
||||
default=None,
|
||||
ui_type=PropertyType.string,
|
||||
display_name="Tuner API Key",
|
||||
description="Bearer token used when posting completed calls to Tuner.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_enabled_config(self):
|
||||
if not self.tuner_enabled:
|
||||
return self
|
||||
|
||||
missing: list[str] = []
|
||||
if not self.tuner_agent_id or not self.tuner_agent_id.strip():
|
||||
missing.append("tuner_agent_id")
|
||||
if self.tuner_workspace_id is None:
|
||||
missing.append("tuner_workspace_id")
|
||||
if not self.tuner_api_key or not self.tuner_api_key.strip():
|
||||
missing.append("tuner_api_key")
|
||||
|
||||
if missing:
|
||||
fields = ", ".join(missing)
|
||||
raise ValueError(
|
||||
f"Tuner node is enabled but missing required fields: {fields}"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
SPEC = build_spec(TunerNodeData)
|
||||
|
||||
|
||||
NODE = IntegrationNodeRegistration(
|
||||
type_name="tuner",
|
||||
data_model=TunerNodeData,
|
||||
node_spec=SPEC,
|
||||
sensitive_fields=("tuner_api_key",),
|
||||
)
|
||||
101
api/services/integrations/tuner/runtime.py
Normal file
101
api/services/integrations/tuner/runtime.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.integrations.base import (
|
||||
IntegrationRuntimeContext,
|
||||
IntegrationRuntimeSession,
|
||||
)
|
||||
|
||||
from .collector import TunerCollector, mode_to_tuner_call_type
|
||||
|
||||
|
||||
def _format_model_label(provider: str | None, model: str | None) -> str:
|
||||
if provider and model:
|
||||
return f"{provider}/{model}"
|
||||
if model:
|
||||
return model
|
||||
return provider or ""
|
||||
|
||||
|
||||
def _resolve_model_labels(context: IntegrationRuntimeContext) -> tuple[str, str, str]:
|
||||
user_config = context.user_config
|
||||
|
||||
if context.is_realtime and user_config.realtime:
|
||||
realtime_provider = user_config.realtime.provider
|
||||
realtime_model = user_config.realtime.model
|
||||
llm_model = _format_model_label(realtime_provider, realtime_model)
|
||||
if realtime_provider in {
|
||||
ServiceProviders.GOOGLE_REALTIME.value,
|
||||
ServiceProviders.GOOGLE_VERTEX_REALTIME.value,
|
||||
ServiceProviders.OPENAI_REALTIME.value,
|
||||
}:
|
||||
return "", llm_model, ""
|
||||
return "", llm_model, ""
|
||||
|
||||
return (
|
||||
_format_model_label(
|
||||
getattr(user_config.stt, "provider", None),
|
||||
getattr(user_config.stt, "model", None),
|
||||
),
|
||||
_format_model_label(
|
||||
getattr(user_config.llm, "provider", None),
|
||||
getattr(user_config.llm, "model", None),
|
||||
),
|
||||
_format_model_label(
|
||||
getattr(user_config.tts, "provider", None),
|
||||
getattr(user_config.tts, "model", None),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TunerRuntimeSession(IntegrationRuntimeSession):
|
||||
name = "tuner"
|
||||
|
||||
def __init__(self, collector: TunerCollector) -> None:
|
||||
self._collector = collector
|
||||
|
||||
def attach(self, task: Any) -> None:
|
||||
self._collector.attach_turn_tracking_observer(task.turn_tracking_observer)
|
||||
self._collector.attach_latency_observer(task.user_bot_latency_observer)
|
||||
task.add_observer(self._collector)
|
||||
|
||||
async def on_call_finished(
|
||||
self,
|
||||
*,
|
||||
gathered_context: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
self._collector.set_disconnection_reason(
|
||||
gathered_context.get("call_disposition")
|
||||
)
|
||||
payload = self._collector.build_payload_snapshot()
|
||||
if payload is None:
|
||||
return None
|
||||
return {"tuner_payload": payload}
|
||||
|
||||
|
||||
def create_runtime_sessions(
|
||||
context: IntegrationRuntimeContext,
|
||||
) -> list[IntegrationRuntimeSession]:
|
||||
tuner_nodes = [
|
||||
node
|
||||
for node in context.workflow_graph.nodes.values()
|
||||
if node.node_type == "tuner" and getattr(node.data, "tuner_enabled", True)
|
||||
]
|
||||
if not tuner_nodes:
|
||||
return []
|
||||
|
||||
asr_model, llm_model, tts_model = _resolve_model_labels(context)
|
||||
|
||||
collector = TunerCollector(
|
||||
workflow_run_id=context.workflow_run_id,
|
||||
call_type=mode_to_tuner_call_type(context.workflow_run.mode),
|
||||
asr_model=asr_model,
|
||||
llm_model=llm_model,
|
||||
tts_model=tts_model,
|
||||
agent_version=getattr(context.run_definition, "version_number", None),
|
||||
)
|
||||
collector.attach_context(context.context_messages_provider)
|
||||
|
||||
return [TunerRuntimeSession(collector)]
|
||||
|
|
@ -5,6 +5,7 @@ from loguru import logger
|
|||
from api.db import db_client
|
||||
from api.enums import PostHogEvent, WorkflowRunState
|
||||
from api.services.campaign.circuit_breaker import circuit_breaker
|
||||
from api.services.integrations import IntegrationRuntimeSession
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.audio_playback import play_audio_loop
|
||||
from api.services.pipecat.in_memory_buffers import (
|
||||
|
|
@ -67,6 +68,7 @@ def register_event_handlers(
|
|||
audio_config=AudioConfig,
|
||||
pre_call_fetch_task: asyncio.Task | None = None,
|
||||
user_provider_id: str | None = None,
|
||||
integration_runtime_sessions: list[IntegrationRuntimeSession] | None = None,
|
||||
):
|
||||
"""Register all event handlers for transport and task events.
|
||||
|
||||
|
|
@ -272,6 +274,20 @@ def register_event_handlers(
|
|||
)
|
||||
|
||||
# Clean up engine resources (including voicemail detector)
|
||||
integration_logs: dict[str, object] = {}
|
||||
for runtime_session in integration_runtime_sessions or []:
|
||||
try:
|
||||
session_logs = await runtime_session.on_call_finished(
|
||||
gathered_context=gathered_context
|
||||
)
|
||||
if session_logs:
|
||||
integration_logs.update(session_logs)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error finalizing integration runtime session '{runtime_session.name}': {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
await engine.cleanup()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -321,14 +337,11 @@ def register_event_handlers(
|
|||
)
|
||||
)
|
||||
|
||||
# Save real-time feedback logs to workflow run
|
||||
logs_update: dict[str, object] = {}
|
||||
if not in_memory_logs_buffer.is_empty:
|
||||
try:
|
||||
feedback_events = in_memory_logs_buffer.get_events()
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
logs={"realtime_feedback_events": feedback_events},
|
||||
)
|
||||
logs_update["realtime_feedback_events"] = feedback_events
|
||||
logger.debug(
|
||||
f"Saved {len(feedback_events)} feedback events to workflow run logs"
|
||||
)
|
||||
|
|
@ -337,6 +350,17 @@ def register_event_handlers(
|
|||
else:
|
||||
logger.debug("Logs buffer is empty, skipping save")
|
||||
|
||||
logs_update.update(integration_logs)
|
||||
|
||||
if logs_update:
|
||||
try:
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
logs=logs_update,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving workflow run logs: {e}", exc_info=True)
|
||||
|
||||
# Write buffers to temp files and enqueue combined processing task
|
||||
audio_temp_path = None
|
||||
transcript_temp_path = None
|
||||
|
|
|
|||
|
|
@ -7,6 +7,10 @@ from loguru import logger
|
|||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.integrations import (
|
||||
IntegrationRuntimeContext,
|
||||
create_runtime_sessions,
|
||||
)
|
||||
from api.services.pipecat.audio_config import AudioConfig, create_audio_config
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_audio_data_handler,
|
||||
|
|
@ -525,6 +529,18 @@ async def _run_pipeline(
|
|||
# Create pipeline components
|
||||
audio_buffer, context = create_pipeline_components(audio_config)
|
||||
|
||||
integration_runtime_sessions = create_runtime_sessions(
|
||||
IntegrationRuntimeContext(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run=workflow_run,
|
||||
workflow_graph=workflow_graph,
|
||||
run_definition=run_definition,
|
||||
user_config=user_config,
|
||||
is_realtime=is_realtime,
|
||||
context_messages_provider=lambda: context.messages,
|
||||
)
|
||||
)
|
||||
|
||||
# Set the context, audio_config, and audio_buffer after creation
|
||||
engine.set_context(context)
|
||||
engine.set_audio_config(audio_config)
|
||||
|
|
@ -717,6 +733,14 @@ async def _run_pipeline(
|
|||
# Create pipeline task with audio configuration
|
||||
task = create_pipeline_task(pipeline, workflow_run_id, audio_config)
|
||||
|
||||
for runtime_session in integration_runtime_sessions:
|
||||
runtime_session.attach(task)
|
||||
logger.info(
|
||||
"[integrations] attached runtime session '{}' for workflow run {}",
|
||||
runtime_session.name,
|
||||
workflow_run_id,
|
||||
)
|
||||
|
||||
# Now set the task and transport output on the engine
|
||||
engine.set_task(task)
|
||||
engine.set_transport_output(transport.output())
|
||||
|
|
@ -780,6 +804,7 @@ async def _run_pipeline(
|
|||
audio_config=audio_config,
|
||||
pre_call_fetch_task=pre_call_fetch_task,
|
||||
user_provider_id=user_provider_id,
|
||||
integration_runtime_sessions=integration_runtime_sessions,
|
||||
)
|
||||
|
||||
register_audio_data_handler(audio_buffer, workflow_run_id, in_memory_audio_buffer)
|
||||
|
|
|
|||
10
api/services/telephony/AGENTS.md
Normal file
10
api/services/telephony/AGENTS.md
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
# Telephony
|
||||
|
||||
Shared telephony code lives here. Provider-specific code lives in `providers/`;
|
||||
read `providers/AGENTS.md` before changing a provider package.
|
||||
|
||||
- Keep cross-provider contracts, registry/factory wiring, shared status/transfer handling, and org-scoped config resolution in this folder.
|
||||
- Keep provider-specific transports, serializers, config models, and webhook handlers in `providers/`.
|
||||
- Resolve providers through the shared telephony helpers in this folder; do not instantiate provider classes directly from routes, tasks, or unrelated services.
|
||||
- Keep telephony config lookups tenant-safe and respect any run-scoped telephony configuration carried on a workflow run.
|
||||
- Keep provider-specific HTTP routes in provider packages; shared route glue belongs in `api/routes/`.
|
||||
123
api/services/telephony/providers/AGENTS.md
Normal file
123
api/services/telephony/providers/AGENTS.md
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
# Telephony Providers
|
||||
|
||||
Each subdirectory here is a self-registering telephony provider. Adding a new one should touch this folder plus **exactly two lines** outside it. If a change you're making requires editing `factory.py`, `audio_config.py`, `run_pipeline.py`, `routes/telephony.py`, or any frontend file, stop — that's a smell. Push the variation through the registry instead.
|
||||
|
||||
## Anatomy of a provider package
|
||||
|
||||
```
|
||||
providers/<name>/
|
||||
├── __init__.py # Required. Builds + register()s ProviderSpec
|
||||
├── config.py # Required. Pydantic Request + Response, both with `provider: Literal["<name>"]`
|
||||
├── provider.py # Required. TelephonyProvider subclass
|
||||
├── transport.py # Required. async create_transport(...) -> FastAPIWebsocketTransport
|
||||
├── serializers.py # Optional but conventional. Re-export from pipecat
|
||||
├── routes.py # Optional. APIRouter mounted lazily under /api/v1/telephony
|
||||
└── strategies.py # Optional. Transfer/Hangup strategies for the frame serializer
|
||||
```
|
||||
|
||||
Every file is provider-local. Nothing here imports another provider package.
|
||||
|
||||
## The two edits outside this folder
|
||||
|
||||
After creating `providers/<name>/`:
|
||||
|
||||
1. `providers/__init__.py` — add `<name>` to the import-for-side-effects list. Registration runs at import time.
|
||||
2. `api/schemas/telephony_config.py` — import `<Name>ConfigurationRequest`/`Response` and add the request to the `TelephonyConfigRequest` `Union[...]` and the response as an optional field on `TelephonyConfigurationResponse`.
|
||||
|
||||
If you find yourself editing anything else, re-read the registry plumbing first:
|
||||
|
||||
| Want to change... | Source of truth |
|
||||
| ------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Outbound provider lookup | `factory.get_default_telephony_provider`, `get_telephony_provider_by_id`, and `get_telephony_provider_for_run` read `registry.get(name).provider_cls` |
|
||||
| Stored credentials → constructor dict | `ProviderSpec.config_loader` |
|
||||
| Audio sample rate / VAD rate | `ProviderSpec.transport_sample_rate` (full `AudioConfig` is built in `pipecat/audio_config.py::create_audio_config`) |
|
||||
| Which transport runs in `run_pipeline_telephony` | `ProviderSpec.transport_factory` |
|
||||
| Save-request validation + masked response shape | `ProviderSpec.config_request_cls` / `config_response_cls` |
|
||||
| Form rendered by the telephony-config UI | `ProviderSpec.ui_metadata` (`ProviderUIField` list) |
|
||||
| Which credential masks on read | `ui_metadata.fields[*].sensitive=True` (no separate list) |
|
||||
| Inbound webhook → config row matching | `ProviderSpec.account_id_credential_field` |
|
||||
| HTTP routes (answer URL, status callbacks) | `providers/<name>/routes.py` (auto-mounted via `importlib`) |
|
||||
|
||||
## ProviderSpec — minimum viable shape
|
||||
|
||||
```python
|
||||
SPEC = ProviderSpec(
|
||||
name="<name>", # registry key, WorkflowRunMode value, stored discriminator
|
||||
provider_cls=YourProvider,
|
||||
config_loader=_config_loader, # raw dict from DB → constructor dict
|
||||
transport_factory=create_transport,
|
||||
transport_sample_rate=8000, # wire-format rate; pipecat derives the full AudioConfig
|
||||
config_request_cls=YourProviderConfigurationRequest,
|
||||
config_response_cls=YourProviderConfigurationResponse,
|
||||
ui_metadata=ProviderUIMetadata(...), # drives the form UI
|
||||
account_id_credential_field="api_key", # "" if provider has no account-id concept
|
||||
)
|
||||
register(SPEC)
|
||||
```
|
||||
|
||||
`ProviderSpec` is frozen — immutable post-registration. Re-registration with the same instance is a no-op; re-registration with a different instance raises.
|
||||
|
||||
## Registration is import-driven, not config-driven
|
||||
|
||||
`api/services/telephony/__init__.py` imports `providers/` for side effects. Don't add a registration call elsewhere — by the time `factory`, `audio_config`, or `run_pipeline_telephony` look the spec up, the package init has already executed.
|
||||
|
||||
The package init **does not import `routes.py`** — `api/routes/telephony.py::_mount_provider_routers()` walks `registry.all_specs()` and uses `importlib.import_module(f"...providers.{spec.name}.routes")`, treating `ModuleNotFoundError` as "no routes for this provider." This is what keeps `from api.services.telephony.base import TelephonyProvider` from fanning out to every route handler in the app. Don't undo it by importing `.routes` from `__init__.py`.
|
||||
|
||||
## Conventions
|
||||
|
||||
### `provider: Literal["<name>"]` on both Request and Response
|
||||
|
||||
Pydantic's discriminated union dispatches on this field. Forgetting `Literal` makes the union accept any provider's payload as yours. Default it to the literal so save calls don't have to send it explicitly.
|
||||
|
||||
### Transports load credentials lazily
|
||||
|
||||
Always:
|
||||
|
||||
```python
|
||||
from api.services.telephony.factory import load_credentials_for_transport
|
||||
|
||||
config = await load_credentials_for_transport(
|
||||
organization_id, telephony_configuration_id, expected_provider="<name>",
|
||||
)
|
||||
```
|
||||
|
||||
Never read the org's default config from `transport.py`. The workflow run carries `telephony_configuration_id` in `initial_context` for multi-config orgs; `load_credentials_for_transport` resolves the right row and validates the provider matches.
|
||||
|
||||
### `_config_loader` is a pure dict reshape
|
||||
|
||||
It runs over `TelephonyConfigurationModel.credentials` (the JSONB column). Don't do I/O in it. Don't pull `from_numbers` from credentials — the factory attaches active phone numbers from `telephony_phone_numbers` after the loader runs, by joining and normalizing addresses.
|
||||
|
||||
### Sensitive fields
|
||||
|
||||
Mark every credential field `sensitive=True` in `ProviderUIMetadata`. The org routes derive masking from `ui_metadata`, not from a separate hardcoded list. If you re-submit a masked value, `preserve_masked_fields` restores the original — relying on this means you should never write `sensitive=False` on a real secret to "make the form simpler."
|
||||
|
||||
### Inbound webhook routing
|
||||
|
||||
When multiple configs of the same provider live in one org (e.g. two Twilio sub-accounts), the inbound dispatcher matches the webhook to a config by `credentials[<account_id_credential_field>]`. Set this to whatever your provider stamps on inbound payloads (`account_sid` for Twilio, `auth_id` for Plivo, etc.). Set `""` only when the provider truly has no account-id concept (e.g. ARI — there's at most one config per org).
|
||||
|
||||
### `configure_inbound` defaults to no-op
|
||||
|
||||
Override only when the provider supports programmatic webhook binding (Plivo `application_id`, Telnyx app config). Markup-response providers that learn the webhook URL from console-side configuration leave the default. Returning `ProviderSyncResult(ok=False, message="...")` surfaces a non-fatal warning to the user without aborting the DB write.
|
||||
|
||||
## Reference implementations
|
||||
|
||||
Pick the closest shape and copy from it.
|
||||
|
||||
| Provider | Pick when... |
|
||||
| ------------ | ------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `twilio/` | Markup-response (TwiML), HMAC-signed webhooks, conference-style transfers, status callbacks. The most full-featured reference. |
|
||||
| `plivo/` | Markup-response with multi-callback signature schemes, programmatic answer-URL sync via Application API. |
|
||||
| `vonage/` | JWT auth, 16 kHz Linear PCM wire format, NCCO JSON responses. |
|
||||
| `cloudonix/` | SIP-trunk-style with custom transfer/hangup strategies. |
|
||||
| `telnyx/` | Call-control style — REST calls to answer/stream rather than markup response. |
|
||||
| `vobiz/` | Body-signed webhooks (signature covers raw bytes). |
|
||||
| `ari/` | Smallest viable: no `routes.py`, no `verify_inbound_signature`, WebSocket-only, no account-id. |
|
||||
|
||||
## What NOT to do
|
||||
|
||||
- **Don't import another provider's `provider.py` or `transport.py`.** Cross-provider behavior belongs in `services/telephony/` (e.g. `status_processor`, `ari_manager`, `call_transfer_manager`), not in another provider's package.
|
||||
- **Don't add a hardcoded provider list anywhere.** If you need to iterate, use `registry.all_specs()` / `registry.names()`.
|
||||
- **Don't add a route under `routes/telephony.py` for a single provider.** Provider-specific handlers go in `providers/<name>/routes.py`. Cross-provider handlers (`/inbound/run`, `/twiml`) stay in `routes/telephony.py`.
|
||||
- **Don't import `.routes` from a provider's `__init__.py`.** That's the cycle we deliberately broke — see "Registration is import-driven."
|
||||
- **Don't write a frontend form for a new provider.** The UI consumes `GET /api/v1/organizations/telephony-providers/metadata` and renders generically from `ProviderUIField`. If a `field.type` you need doesn't exist (`text`/`password`/`textarea`/`string-array`/`number`), extend the shared telephony UI under `ui/src/components/telephony/` once — not per provider.
|
||||
- **Don't run a database migration to add a provider.** The discriminator lives in JSONB credentials and a `VARCHAR(64)` `mode` column; nothing in the DB schema knows the set of provider names.
|
||||
|
|
@ -1,123 +1 @@
|
|||
# Telephony Providers
|
||||
|
||||
Each subdirectory here is a self-registering telephony provider. Adding a new one should touch this folder plus **exactly two lines** outside it. If a change you're making requires editing `factory.py`, `audio_config.py`, `run_pipeline.py`, `routes/telephony.py`, or any frontend file, stop — that's a smell. Push the variation through the registry instead.
|
||||
|
||||
## Anatomy of a provider package
|
||||
|
||||
```
|
||||
providers/<name>/
|
||||
├── __init__.py # Required. Builds + register()s ProviderSpec
|
||||
├── config.py # Required. Pydantic Request + Response, both with `provider: Literal["<name>"]`
|
||||
├── provider.py # Required. TelephonyProvider subclass
|
||||
├── transport.py # Required. async create_transport(...) -> FastAPIWebsocketTransport
|
||||
├── serializers.py # Optional but conventional. Re-export from pipecat
|
||||
├── routes.py # Optional. APIRouter mounted lazily under /api/v1/telephony
|
||||
└── strategies.py # Optional. Transfer/Hangup strategies for the frame serializer
|
||||
```
|
||||
|
||||
Every file is provider-local. Nothing here imports another provider package.
|
||||
|
||||
## The two edits outside this folder
|
||||
|
||||
After creating `providers/<name>/`:
|
||||
|
||||
1. `providers/__init__.py` — add `<name>` to the import-for-side-effects list. Registration runs at import time.
|
||||
2. `api/schemas/telephony_config.py` — import `<Name>ConfigurationRequest`/`Response` and add the request to the `TelephonyConfigRequest` `Union[...]` and the response as an optional field on `TelephonyConfigurationResponse`.
|
||||
|
||||
If you find yourself editing anything else, re-read the registry plumbing first:
|
||||
|
||||
| Want to change... | Source of truth |
|
||||
| --- | --- |
|
||||
| Outbound provider lookup | `factory.get_default_telephony_provider`, `get_telephony_provider_by_id`, and `get_telephony_provider_for_run` read `registry.get(name).provider_cls` |
|
||||
| Stored credentials → constructor dict | `ProviderSpec.config_loader` |
|
||||
| Audio sample rate / VAD rate | `ProviderSpec.transport_sample_rate` (full `AudioConfig` is built in `pipecat/audio_config.py::create_audio_config`) |
|
||||
| Which transport runs in `run_pipeline_telephony` | `ProviderSpec.transport_factory` |
|
||||
| Save-request validation + masked response shape | `ProviderSpec.config_request_cls` / `config_response_cls` |
|
||||
| Form rendered by the telephony-config UI | `ProviderSpec.ui_metadata` (`ProviderUIField` list) |
|
||||
| Which credential masks on read | `ui_metadata.fields[*].sensitive=True` (no separate list) |
|
||||
| Inbound webhook → config row matching | `ProviderSpec.account_id_credential_field` |
|
||||
| HTTP routes (answer URL, status callbacks) | `providers/<name>/routes.py` (auto-mounted via `importlib`) |
|
||||
|
||||
## ProviderSpec — minimum viable shape
|
||||
|
||||
```python
|
||||
SPEC = ProviderSpec(
|
||||
name="<name>", # registry key, WorkflowRunMode value, stored discriminator
|
||||
provider_cls=YourProvider,
|
||||
config_loader=_config_loader, # raw dict from DB → constructor dict
|
||||
transport_factory=create_transport,
|
||||
transport_sample_rate=8000, # wire-format rate; pipecat derives the full AudioConfig
|
||||
config_request_cls=YourProviderConfigurationRequest,
|
||||
config_response_cls=YourProviderConfigurationResponse,
|
||||
ui_metadata=ProviderUIMetadata(...), # drives the form UI
|
||||
account_id_credential_field="api_key", # "" if provider has no account-id concept
|
||||
)
|
||||
register(SPEC)
|
||||
```
|
||||
|
||||
`ProviderSpec` is frozen — immutable post-registration. Re-registration with the same instance is a no-op; re-registration with a different instance raises.
|
||||
|
||||
## Registration is import-driven, not config-driven
|
||||
|
||||
`api/services/telephony/__init__.py` imports `providers/` for side effects. Don't add a registration call elsewhere — by the time `factory`, `audio_config`, or `run_pipeline_telephony` look the spec up, the package init has already executed.
|
||||
|
||||
The package init **does not import `routes.py`** — `api/routes/telephony.py::_mount_provider_routers()` walks `registry.all_specs()` and uses `importlib.import_module(f"...providers.{spec.name}.routes")`, treating `ModuleNotFoundError` as "no routes for this provider." This is what keeps `from api.services.telephony.base import TelephonyProvider` from fanning out to every route handler in the app. Don't undo it by importing `.routes` from `__init__.py`.
|
||||
|
||||
## Conventions
|
||||
|
||||
### `provider: Literal["<name>"]` on both Request and Response
|
||||
|
||||
Pydantic's discriminated union dispatches on this field. Forgetting `Literal` makes the union accept any provider's payload as yours. Default it to the literal so save calls don't have to send it explicitly.
|
||||
|
||||
### Transports load credentials lazily
|
||||
|
||||
Always:
|
||||
|
||||
```python
|
||||
from api.services.telephony.factory import load_credentials_for_transport
|
||||
|
||||
config = await load_credentials_for_transport(
|
||||
organization_id, telephony_configuration_id, expected_provider="<name>",
|
||||
)
|
||||
```
|
||||
|
||||
Never read the org's default config from `transport.py`. The workflow run carries `telephony_configuration_id` in `initial_context` for multi-config orgs; `load_credentials_for_transport` resolves the right row and validates the provider matches.
|
||||
|
||||
### `_config_loader` is a pure dict reshape
|
||||
|
||||
It runs over `TelephonyConfigurationModel.credentials` (the JSONB column). Don't do I/O in it. Don't pull `from_numbers` from credentials — the factory attaches active phone numbers from `telephony_phone_numbers` after the loader runs, by joining and normalizing addresses.
|
||||
|
||||
### Sensitive fields
|
||||
|
||||
Mark every credential field `sensitive=True` in `ProviderUIMetadata`. The org routes derive masking from `ui_metadata`, not from a separate hardcoded list. If you re-submit a masked value, `preserve_masked_fields` restores the original — relying on this means you should never write `sensitive=False` on a real secret to "make the form simpler."
|
||||
|
||||
### Inbound webhook routing
|
||||
|
||||
When multiple configs of the same provider live in one org (e.g. two Twilio sub-accounts), the inbound dispatcher matches the webhook to a config by `credentials[<account_id_credential_field>]`. Set this to whatever your provider stamps on inbound payloads (`account_sid` for Twilio, `auth_id` for Plivo, etc.). Set `""` only when the provider truly has no account-id concept (e.g. ARI — there's at most one config per org).
|
||||
|
||||
### `configure_inbound` defaults to no-op
|
||||
|
||||
Override only when the provider supports programmatic webhook binding (Plivo `application_id`, Telnyx app config). Markup-response providers that learn the webhook URL from console-side configuration leave the default. Returning `ProviderSyncResult(ok=False, message="...")` surfaces a non-fatal warning to the user without aborting the DB write.
|
||||
|
||||
## Reference implementations
|
||||
|
||||
Pick the closest shape and copy from it.
|
||||
|
||||
| Provider | Pick when... |
|
||||
| --- | --- |
|
||||
| `twilio/` | Markup-response (TwiML), HMAC-signed webhooks, conference-style transfers, status callbacks. The most full-featured reference. |
|
||||
| `plivo/` | Markup-response with multi-callback signature schemes, programmatic answer-URL sync via Application API. |
|
||||
| `vonage/` | JWT auth, 16 kHz Linear PCM wire format, NCCO JSON responses. |
|
||||
| `cloudonix/` | SIP-trunk-style with custom transfer/hangup strategies. |
|
||||
| `telnyx/` | Call-control style — REST calls to answer/stream rather than markup response. |
|
||||
| `vobiz/` | Body-signed webhooks (signature covers raw bytes). |
|
||||
| `ari/` | Smallest viable: no `routes.py`, no `verify_inbound_signature`, WebSocket-only, no account-id. |
|
||||
|
||||
## What NOT to do
|
||||
|
||||
- **Don't import another provider's `provider.py` or `transport.py`.** Cross-provider behavior belongs in `services/telephony/` (e.g. `status_processor`, `ari_manager`, `call_transfer_manager`), not in another provider's package.
|
||||
- **Don't add a hardcoded provider list anywhere.** If you need to iterate, use `registry.all_specs()` / `registry.names()`.
|
||||
- **Don't add a route under `routes/telephony.py` for a single provider.** Provider-specific handlers go in `providers/<name>/routes.py`. Cross-provider handlers (`/inbound/run`, `/twiml`) stay in `routes/telephony.py`.
|
||||
- **Don't import `.routes` from a provider's `__init__.py`.** That's the cycle we deliberately broke — see "Registration is import-driven."
|
||||
- **Don't write a frontend form for a new provider.** The UI consumes `GET /api/v1/organizations/telephony-providers/metadata` and renders generically from `ProviderUIField`. If a `field.type` you need doesn't exist (`text`/`password`/`textarea`/`string-array`/`number`), extend the renderer in `ui/src/app/(authenticated)/telephony-configurations/` once — not per provider.
|
||||
- **Don't run a database migration to add a provider.** The discriminator lives in JSONB credentials and a `VARCHAR(64)` `mode` column; nothing in the DB schema knows the set of provider names.
|
||||
@AGENTS.md
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ TELNYX_TIMESTAMP_TOLERANCE_SECONDS = 300
|
|||
TELNYX_PUBLIC_KEY_BYTES = 32
|
||||
TELNYX_SIGNATURE_BYTES = 64
|
||||
|
||||
from api.constants import TELNYX_WEBHOOK_VERIFICATION_OPTIONAL
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.telephony.base import (
|
||||
CallInitiationResult,
|
||||
|
|
@ -211,12 +210,6 @@ class TelnyxProvider(TelephonyProvider):
|
|||
return False
|
||||
|
||||
if not self.webhook_public_key:
|
||||
# REMOVE-AFTER 2026-05-15: transition window. Allow webhooks
|
||||
# through for configs that haven't added the key yet. Remove this
|
||||
# branch along with TELNYX_WEBHOOK_VERIFICATION_OPTIONAL after
|
||||
# the cutoff.
|
||||
if TELNYX_WEBHOOK_VERIFICATION_OPTIONAL:
|
||||
return True
|
||||
logger.error("Missing Telnyx webhook_public_key configuration")
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ script in `api/services/admin_utils/local_exec.py` is the production
|
|||
consumer.
|
||||
"""
|
||||
|
||||
from api.services.workflow.node_specs import REGISTRY
|
||||
from api.services.workflow.node_specs import all_specs
|
||||
|
||||
|
||||
def _build_type_rules() -> tuple[set[str], set[str]]:
|
||||
|
|
@ -16,14 +16,14 @@ def _build_type_rules() -> tuple[set[str], set[str]]:
|
|||
(max_incoming == 0)."""
|
||||
src_forbidden: set[str] = set()
|
||||
tgt_forbidden: set[str] = set()
|
||||
for name, spec in REGISTRY.items():
|
||||
for spec in all_specs():
|
||||
gc = spec.graph_constraints
|
||||
if gc is None:
|
||||
continue
|
||||
if gc.max_outgoing == 0:
|
||||
src_forbidden.add(name)
|
||||
src_forbidden.add(spec.name)
|
||||
if gc.max_incoming == 0:
|
||||
tgt_forbidden.add(name)
|
||||
tgt_forbidden.add(spec.name)
|
||||
return src_forbidden, tgt_forbidden
|
||||
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
254
api/services/workflow/mcp_tool_session.py
Normal file
254
api/services/workflow/mcp_tool_session.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
"""Single unit that knows the MCP protocol + credentials.
|
||||
|
||||
Wraps the vendored Pipecat ``MCPClient`` for connection/session, builds
|
||||
streamable-HTTP params from a Dograh credential, exposes namespaced
|
||||
``FunctionSchema``s, and proxies tool calls. Connection failures degrade
|
||||
(``available = False``) instead of raising — the call must survive a
|
||||
dead MCP server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
|
||||
from loguru import logger
|
||||
from mcp.client.session_group import StreamableHttpParameters
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.services.mcp_service import MCPClient
|
||||
|
||||
from api.services.workflow.tools.mcp_tool import namespace_function_name
|
||||
from api.utils.credential_auth import build_auth_header
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.db.models import ExternalCredentialModel
|
||||
|
||||
|
||||
def build_streamable_http_params(
|
||||
*,
|
||||
url: str,
|
||||
credential: Optional["ExternalCredentialModel"],
|
||||
timeout_secs: int,
|
||||
sse_read_timeout_secs: int,
|
||||
) -> StreamableHttpParameters:
|
||||
"""Build Pipecat/MCP streamable-HTTP params, injecting the auth header
|
||||
from an ExternalCredentialModel (reuses the http_api credential path)."""
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
if credential is not None:
|
||||
auth = build_auth_header(credential)
|
||||
headers = auth or None
|
||||
return StreamableHttpParameters(
|
||||
url=url,
|
||||
headers=headers,
|
||||
timeout=timedelta(seconds=timeout_secs),
|
||||
sse_read_timeout=timedelta(seconds=sse_read_timeout_secs),
|
||||
)
|
||||
|
||||
|
||||
class McpToolSession:
|
||||
"""One live MCP server connection for the duration of a call."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tool_uuid: str,
|
||||
tool_name: str,
|
||||
url: str,
|
||||
credential: Optional["ExternalCredentialModel"],
|
||||
tools_filter: List[str],
|
||||
timeout_secs: int,
|
||||
sse_read_timeout_secs: int,
|
||||
) -> None:
|
||||
self._tool_uuid = tool_uuid
|
||||
self._tool_name = tool_name
|
||||
self._url = url
|
||||
self._credential = credential
|
||||
# An empty list is intentionally treated as "no filter (expose all
|
||||
# tools)" — Pipecat's MCPClient applies a filter only when this is a
|
||||
# non-empty list, so [] and None are equivalent ("all tools").
|
||||
self._tools_filter = tools_filter or None
|
||||
self._timeout_secs = timeout_secs
|
||||
self._sse_read_timeout_secs = sse_read_timeout_secs
|
||||
|
||||
self._client: Optional[MCPClient] = None
|
||||
self._session: Any = None # mcp.ClientSession (read once after start)
|
||||
self._schemas: List[FunctionSchema] = []
|
||||
# namespaced LLM name -> original MCP tool name
|
||||
self._name_map: Dict[str, str] = {}
|
||||
self.available: bool = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Connect, initialize, and cache the tool list. Never raises —
|
||||
on any failure the session is marked unavailable."""
|
||||
try:
|
||||
params = build_streamable_http_params(
|
||||
url=self._url,
|
||||
credential=self._credential,
|
||||
timeout_secs=self._timeout_secs,
|
||||
sse_read_timeout_secs=self._sse_read_timeout_secs,
|
||||
)
|
||||
self._client = MCPClient(params, tools_filter=self._tools_filter)
|
||||
await self._client.start()
|
||||
# Single, isolated touch of Pipecat internals (vendored submodule).
|
||||
self._session = self._client._active_session
|
||||
tools_schema = await self._client.get_tools_schema()
|
||||
|
||||
fallback = self._tool_uuid[:8] if self._tool_uuid else "server"
|
||||
for fs in tools_schema.standard_tools:
|
||||
ns_name = namespace_function_name(
|
||||
self._tool_name, fs.name, fallback=fallback
|
||||
)
|
||||
self._name_map[ns_name] = fs.name
|
||||
self._schemas.append(
|
||||
FunctionSchema(
|
||||
name=ns_name,
|
||||
description=fs.description,
|
||||
properties=fs.properties,
|
||||
required=fs.required,
|
||||
)
|
||||
)
|
||||
self.available = True
|
||||
logger.info(
|
||||
f"MCP session ready for tool '{self._tool_name}' "
|
||||
f"({self._tool_uuid}): {sorted(self._name_map)}"
|
||||
)
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
raise
|
||||
except asyncio.CancelledError as e:
|
||||
# Empirically, a dead/unreachable MCP server does NOT surface as a
|
||||
# plain Exception here. The real failure is httpx.ConnectError, but
|
||||
# anyio's streamablehttp_client task group, while tearing down that
|
||||
# ConnectError, re-surfaces it to our frame as an *internal*
|
||||
# cancel-scope CancelledError carrying the signature message
|
||||
# "Cancelled via cancel scope <id>". A genuine *external*
|
||||
# cancellation (call teardown / shutdown) is a bare CancelledError
|
||||
# (empty args) or one with an application-chosen message. Type, MRO,
|
||||
# context chain, and asyncio task.cancelling() are all identical
|
||||
# between the two, so the anyio scope-signature message is the only
|
||||
# reliable discriminator. Re-raise genuine external cancellation to
|
||||
# preserve structured concurrency; degrade only on the anyio
|
||||
# connect-teardown artifact.
|
||||
msg = "" if not e.args else str(e.args[0] or "")
|
||||
if not msg.startswith("Cancelled via cancel scope"):
|
||||
raise
|
||||
await self._degrade(e)
|
||||
except Exception as e: # noqa: BLE001 — see _degrade docstring
|
||||
# Defensive: if a future Pipecat/httpx version surfaces the connect
|
||||
# failure directly (e.g. httpx.ConnectError) instead of via the
|
||||
# anyio cancel-scope artifact above, still degrade gracefully.
|
||||
await self._degrade(e)
|
||||
|
||||
async def _degrade(self, e: BaseException) -> None:
|
||||
"""Mark this session unavailable and tear down any dangling client so
|
||||
start() leaves self._client either fully usable or None. The contract
|
||||
requires graceful degradation on any *connect* failure (never raising
|
||||
for a dead MCP server) while genuine external cancellation /
|
||||
KeyboardInterrupt / SystemExit are re-raised by the caller."""
|
||||
self.available = False
|
||||
self._schemas = []
|
||||
self._name_map = {}
|
||||
# Self-contained cleanup: _client.start() may have succeeded before a
|
||||
# later step (e.g. get_tools_schema()) failed, leaving an open client.
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._client = None
|
||||
self._session = None
|
||||
logger.warning(
|
||||
f"MCP session unavailable for tool '{self._tool_name}' "
|
||||
f"({self._tool_uuid}) at {self._url}: {e!r}. "
|
||||
f"Call proceeds without these tools."
|
||||
)
|
||||
|
||||
@property
|
||||
def call_timeout_secs(self) -> float:
|
||||
"""Pipecat function-call timeout for this server's tools. Slightly
|
||||
longer than the transport read timeout so a slow MCP call surfaces
|
||||
as a structured tool error (handled in the handler) rather than a
|
||||
hard pipeline timeout."""
|
||||
return float(self._sse_read_timeout_secs) + 5.0
|
||||
|
||||
def function_schemas(
|
||||
self, allowed_raw_names: Optional[Set[str]] = None
|
||||
) -> List[FunctionSchema]:
|
||||
"""Return cached FunctionSchemas, optionally filtered by raw MCP tool name.
|
||||
|
||||
``allowed_raw_names=None`` returns all schemas. An empty set returns none.
|
||||
Raw names are the pre-namespace MCP tool names (e.g. ``echo``, not
|
||||
``mcp__slug__echo``).
|
||||
"""
|
||||
if allowed_raw_names is None:
|
||||
return list(self._schemas)
|
||||
return [
|
||||
s for s in self._schemas if self._name_map.get(s.name) in allowed_raw_names
|
||||
]
|
||||
|
||||
def discovered_tools(self) -> List[Dict[str, str]]:
|
||||
"""Raw MCP tool catalog for UI/cache: ``[{name, description}]``
|
||||
using the *raw* server names (not the namespaced LLM names).
|
||||
Empty if the session is unavailable."""
|
||||
out: List[Dict[str, str]] = []
|
||||
for s in self._schemas:
|
||||
raw = self._name_map.get(s.name)
|
||||
if raw is None:
|
||||
continue
|
||||
out.append({"name": raw, "description": s.description or ""})
|
||||
return out
|
||||
|
||||
async def call(self, namespaced_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""Invoke an MCP tool by its namespaced LLM name. Returns a string
|
||||
(flattened text content). Raises if the session is unavailable so
|
||||
the caller can map it to a structured error for the LLM."""
|
||||
if not self.available or self._session is None:
|
||||
raise RuntimeError(f"MCP session unavailable for {namespaced_name}")
|
||||
original = self._name_map.get(namespaced_name)
|
||||
if original is None:
|
||||
raise RuntimeError(f"Unknown MCP function {namespaced_name}")
|
||||
result = await self._session.call_tool(original, arguments=arguments)
|
||||
text = ""
|
||||
for content in getattr(result, "content", []) or []:
|
||||
if getattr(content, "text", None):
|
||||
text += content.text
|
||||
return text or "Sorry, the MCP tool returned no content."
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing MCP session {self._tool_uuid}: {e}")
|
||||
finally:
|
||||
self._client = None
|
||||
self._session = None
|
||||
|
||||
|
||||
async def discover_mcp_tools(
|
||||
*,
|
||||
url: str,
|
||||
credential: Optional["ExternalCredentialModel"],
|
||||
timeout_secs: int,
|
||||
sse_read_timeout_secs: int,
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Open an ephemeral MCP session, list its tools, close it. Returns
|
||||
``[{name, description}]`` (raw names). Never raises — on any connect
|
||||
failure returns ``[]``."""
|
||||
session = McpToolSession(
|
||||
tool_uuid="discover",
|
||||
tool_name="discover",
|
||||
url=url,
|
||||
credential=credential,
|
||||
tools_filter=[],
|
||||
timeout_secs=timeout_secs,
|
||||
sse_read_timeout_secs=sse_read_timeout_secs,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
if not session.available:
|
||||
return []
|
||||
return session.discovered_tools()
|
||||
finally:
|
||||
await session.close()
|
||||
19
api/services/workflow/node_data.py
Normal file
19
api/services/workflow/node_data.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.services.workflow.node_specs._base import PropertyType
|
||||
from api.services.workflow.node_specs.model_spec import spec_field
|
||||
|
||||
|
||||
class BaseNodeData(BaseModel):
|
||||
name: str = spec_field(
|
||||
...,
|
||||
min_length=1,
|
||||
ui_type=PropertyType.string,
|
||||
display_name="Name",
|
||||
description="Short identifier shown in the canvas and call logs.",
|
||||
required=True,
|
||||
)
|
||||
is_start: bool = spec_field(default=False, spec_exclude=True)
|
||||
is_end: bool = spec_field(default=False, spec_exclude=True)
|
||||
|
|
@ -1,10 +1,8 @@
|
|||
"""Node specification registry.
|
||||
|
||||
Adding a new node type:
|
||||
1. Create a new module under this package, define a `SPEC: NodeSpec`.
|
||||
2. Add it to the imports + REGISTRY below.
|
||||
3. The Pydantic discriminated-union variant in dto.py must use the same
|
||||
`name` value as `SPEC.name`.
|
||||
Core node specs are generated from the workflow DTO models. Third-party
|
||||
integration node specs live under `api.services.integrations/<name>/` and
|
||||
register through the integration registry so they don't need edits here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -21,8 +19,10 @@ from api.services.workflow.node_specs._base import (
|
|||
PropertyType,
|
||||
evaluate_display_options,
|
||||
)
|
||||
from api.services.workflow.node_specs.model_spec import build_spec
|
||||
|
||||
REGISTRY: dict[str, NodeSpec] = {}
|
||||
_CORE_SPECS_LOADED = False
|
||||
|
||||
|
||||
def register(spec: NodeSpec) -> NodeSpec:
|
||||
|
|
@ -38,12 +38,23 @@ def register(spec: NodeSpec) -> NodeSpec:
|
|||
|
||||
|
||||
def get_spec(name: str) -> NodeSpec | None:
|
||||
return REGISTRY.get(name)
|
||||
_ensure_core_registered()
|
||||
if name in REGISTRY:
|
||||
return REGISTRY[name]
|
||||
|
||||
from api.services.integrations import get_node_spec
|
||||
|
||||
return get_node_spec(name)
|
||||
|
||||
|
||||
def all_specs() -> list[NodeSpec]:
|
||||
"""All registered specs, sorted by name for stable output."""
|
||||
return [REGISTRY[name] for name in sorted(REGISTRY)]
|
||||
_ensure_core_registered()
|
||||
from api.services.integrations import all_node_specs
|
||||
|
||||
specs = {spec.name: spec for spec in REGISTRY.values()}
|
||||
specs.update({spec.name: spec for spec in all_node_specs()})
|
||||
return [specs[name] for name in sorted(specs)]
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -64,19 +75,15 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
# Side-effect imports — each module's `register(SPEC)` call populates REGISTRY.
|
||||
# Keep at module bottom so the registry helpers are defined first.
|
||||
from api.services.workflow.node_specs import ( # noqa: E402, F401
|
||||
agent,
|
||||
end_call,
|
||||
global_node,
|
||||
qa,
|
||||
start_call,
|
||||
trigger,
|
||||
webhook,
|
||||
)
|
||||
def _ensure_core_registered() -> None:
|
||||
global _CORE_SPECS_LOADED
|
||||
if _CORE_SPECS_LOADED:
|
||||
return
|
||||
|
||||
# Wire up registrations from the SPEC constants in each module.
|
||||
for _module in (start_call, agent, end_call, global_node, trigger, webhook, qa):
|
||||
register(_module.SPEC)
|
||||
del _module
|
||||
from api.services.workflow.dto import _CORE_NODE_DATA_CLASSES
|
||||
|
||||
for model_cls in _CORE_NODE_DATA_CLASSES.values():
|
||||
if model_cls.__node_spec_metadata__.name in REGISTRY:
|
||||
continue
|
||||
register(build_spec(model_cls))
|
||||
_CORE_SPECS_LOADED = True
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
"""Spec schema for node definitions.
|
||||
|
||||
A `NodeSpec` is the single source of truth for a node type. It drives:
|
||||
- Pydantic validation (the per-type DTOs in dto.py mirror these property types)
|
||||
- The generic UI renderer (frontend reads specs via /api/v1/node-types)
|
||||
- The LLM SDK (constructors and JSON-Schema derived from these specs)
|
||||
`NodeSpec` is the serialized contract exposed to the frontend, MCP tools, and
|
||||
SDKs. Core workflow node specs are generated from the DTO models plus
|
||||
model-attached metadata; integration packages may generate them the same way or
|
||||
register a prebuilt spec object.
|
||||
|
||||
Every property's `description` is LLM-readable copy — treat it as production
|
||||
documentation, not internal notes. Spec lint enforces non-empty descriptions
|
||||
|
|
@ -122,6 +122,16 @@ class PropertyOption(BaseModel):
|
|||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def to_mcp_dict(self) -> dict[str, Any]:
|
||||
"""Lean projection for `get_node_type`: the `value` an LLM writes in
|
||||
code, plus a `description` when one carries real meaning. The UI
|
||||
`label` is dropped — it's the option's display string, never used
|
||||
when authoring."""
|
||||
out: dict[str, Any] = {"value": self.value}
|
||||
if self.description:
|
||||
out["description"] = self.description
|
||||
return out
|
||||
|
||||
|
||||
class PropertySpec(BaseModel):
|
||||
"""Single field on a node.
|
||||
|
|
@ -175,6 +185,43 @@ class PropertySpec(BaseModel):
|
|||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def to_mcp_dict(self) -> dict[str, Any]:
|
||||
"""Lean projection of this property for the `get_node_type` MCP tool.
|
||||
|
||||
Keeps only what an LLM needs to author a valid value: name, type,
|
||||
description, llm_hint, requiredness, default, enum options, nested
|
||||
row properties, and validation bounds. UI-rendering concerns
|
||||
(`display_name`, `placeholder`, `display_options`, `editor`,
|
||||
`extra`) and null/empty fields are omitted — they're noise in the
|
||||
model's context and never appear in authored SDK code.
|
||||
"""
|
||||
out: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"type": self.type.value,
|
||||
"description": self.description,
|
||||
}
|
||||
if self.llm_hint:
|
||||
out["llm_hint"] = self.llm_hint
|
||||
if self.required:
|
||||
out["required"] = True
|
||||
if self.default is not None:
|
||||
out["default"] = self.default
|
||||
if self.options:
|
||||
out["options"] = [opt.to_mcp_dict() for opt in self.options]
|
||||
if self.properties:
|
||||
out["properties"] = [prop.to_mcp_dict() for prop in self.properties]
|
||||
for constraint in (
|
||||
"min_value",
|
||||
"max_value",
|
||||
"min_length",
|
||||
"max_length",
|
||||
"pattern",
|
||||
):
|
||||
value = getattr(self, constraint)
|
||||
if value is not None:
|
||||
out[constraint] = value
|
||||
return out
|
||||
|
||||
|
||||
PropertySpec.model_rebuild()
|
||||
|
||||
|
|
@ -222,3 +269,33 @@ class NodeSpec(BaseModel):
|
|||
graph_constraints: Optional[GraphConstraints] = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def to_mcp_dict(self) -> dict[str, Any]:
|
||||
"""Lean projection of this spec for the `get_node_type` MCP tool.
|
||||
|
||||
Drops node-level UI metadata (`display_name`, `category`, `icon`,
|
||||
`version`) and the per-property rendering concerns trimmed by
|
||||
`PropertySpec.to_mcp_dict`, leaving just the authoring-relevant
|
||||
schema the LLM consumes when composing a workflow. The full spec is
|
||||
still served verbatim to the frontend renderer (REST `node-types`
|
||||
route) and the SDK codegen / TS validator (`ts_bridge`), which need
|
||||
the dropped fields.
|
||||
"""
|
||||
out: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
}
|
||||
if self.llm_hint:
|
||||
out["llm_hint"] = self.llm_hint
|
||||
out["properties"] = [prop.to_mcp_dict() for prop in self.properties]
|
||||
if self.examples:
|
||||
out["examples"] = [
|
||||
ex.model_dump(mode="json", exclude_none=True) for ex in self.examples
|
||||
]
|
||||
if self.graph_constraints:
|
||||
constraints = self.graph_constraints.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
)
|
||||
if constraints:
|
||||
out["graph_constraints"] = constraints
|
||||
return out
|
||||
|
|
|
|||
|
|
@ -1,168 +0,0 @@
|
|||
"""Spec for the Agent node — the workhorse mid-call node where the LLM
|
||||
executes a focused conversational step with optional tools and documents."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
DisplayOptions,
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertyOption,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
SPEC = NodeSpec(
|
||||
name="agentNode",
|
||||
display_name="Agent Node",
|
||||
description="Conversational step — the LLM runs one focused exchange.",
|
||||
llm_hint=(
|
||||
"Mid-call step executed by the LLM. Most workflows are a chain of "
|
||||
"agent nodes connected by edges that describe transition conditions. "
|
||||
"Each agent node can invoke tools and reference documents."
|
||||
),
|
||||
category=NodeCategory.call_node,
|
||||
icon="Headset",
|
||||
properties=[
|
||||
PropertySpec(
|
||||
name="name",
|
||||
type=PropertyType.string,
|
||||
display_name="Name",
|
||||
description=(
|
||||
"Short identifier for this step (e.g., 'Qualify Budget'). "
|
||||
"Appears in call logs and edge transition tools."
|
||||
),
|
||||
required=True,
|
||||
min_length=1,
|
||||
default="Agent",
|
||||
),
|
||||
PropertySpec(
|
||||
name="prompt",
|
||||
type=PropertyType.mention_textarea,
|
||||
display_name="Prompt",
|
||||
description=(
|
||||
"Agent system prompt for this step. Supports "
|
||||
"{{template_variables}} from extraction or pre-call fetch."
|
||||
),
|
||||
required=True,
|
||||
min_length=1,
|
||||
placeholder="Ask the caller about their budget and timeline.",
|
||||
),
|
||||
PropertySpec(
|
||||
name="allow_interrupt",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Allow Interruption",
|
||||
description=(
|
||||
"When true, the user can interrupt the agent mid-utterance. "
|
||||
"Set false for non-interruptible disclosures."
|
||||
),
|
||||
default=True,
|
||||
),
|
||||
PropertySpec(
|
||||
name="add_global_prompt",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Add Global Prompt",
|
||||
description=(
|
||||
"When true and a Global node exists, prepends the global "
|
||||
"prompt to this node's prompt at runtime."
|
||||
),
|
||||
default=True,
|
||||
),
|
||||
PropertySpec(
|
||||
name="extraction_enabled",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Enable Variable Extraction",
|
||||
description=(
|
||||
"When true, runs an LLM extraction pass on transition out of "
|
||||
"this node to capture variables from the conversation."
|
||||
),
|
||||
default=False,
|
||||
),
|
||||
PropertySpec(
|
||||
name="extraction_prompt",
|
||||
type=PropertyType.string,
|
||||
display_name="Extraction Prompt",
|
||||
description="Overall instructions guiding variable extraction.",
|
||||
display_options=DisplayOptions(show={"extraction_enabled": [True]}),
|
||||
editor="textarea",
|
||||
),
|
||||
PropertySpec(
|
||||
name="extraction_variables",
|
||||
type=PropertyType.fixed_collection,
|
||||
display_name="Variables to Extract",
|
||||
description=(
|
||||
"Each entry declares one variable to capture from the "
|
||||
"conversation, with its name, type, and per-variable hint."
|
||||
),
|
||||
display_options=DisplayOptions(show={"extraction_enabled": [True]}),
|
||||
properties=[
|
||||
PropertySpec(
|
||||
name="name",
|
||||
type=PropertyType.string,
|
||||
display_name="Variable Name",
|
||||
description="snake_case identifier used downstream.",
|
||||
required=True,
|
||||
),
|
||||
PropertySpec(
|
||||
name="type",
|
||||
type=PropertyType.options,
|
||||
display_name="Type",
|
||||
description="Data type of the extracted value.",
|
||||
required=True,
|
||||
default="string",
|
||||
options=[
|
||||
PropertyOption(value="string", label="String"),
|
||||
PropertyOption(value="number", label="Number"),
|
||||
PropertyOption(value="boolean", label="Boolean"),
|
||||
],
|
||||
),
|
||||
PropertySpec(
|
||||
name="prompt",
|
||||
type=PropertyType.string,
|
||||
display_name="Extraction Hint",
|
||||
description="Per-variable hint describing what to look for.",
|
||||
editor="textarea",
|
||||
),
|
||||
],
|
||||
),
|
||||
PropertySpec(
|
||||
name="tool_uuids",
|
||||
type=PropertyType.tool_refs,
|
||||
display_name="Tools",
|
||||
description="Tools the agent can invoke during this step.",
|
||||
llm_hint="List of tool UUIDs from `list_tools`.",
|
||||
),
|
||||
PropertySpec(
|
||||
name="document_uuids",
|
||||
type=PropertyType.document_refs,
|
||||
display_name="Knowledge Base Documents",
|
||||
description="Documents the agent can reference during this step.",
|
||||
llm_hint="List of document UUIDs from `list_documents`.",
|
||||
),
|
||||
],
|
||||
examples=[
|
||||
NodeExample(
|
||||
name="qualify_lead",
|
||||
data={
|
||||
"name": "Qualify Budget",
|
||||
"prompt": "Ask about budget and timeline. Capture both before transitioning.",
|
||||
"allow_interrupt": True,
|
||||
"extraction_enabled": True,
|
||||
"extraction_prompt": "Extract budget amount and rough timeline.",
|
||||
"extraction_variables": [
|
||||
{
|
||||
"name": "budget_usd",
|
||||
"type": "number",
|
||||
"prompt": "Stated budget in USD",
|
||||
},
|
||||
{
|
||||
"name": "timeline",
|
||||
"type": "string",
|
||||
"prompt": "When they want to start",
|
||||
},
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
graph_constraints=GraphConstraints(min_incoming=1),
|
||||
)
|
||||
44
api/services/workflow/node_specs/constants.py
Normal file
44
api/services/workflow/node_specs/constants.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
DEFAULT_QA_SYSTEM_PROMPT = """You are a QA analyst evaluating a specific segment of a voice AI conversation.
|
||||
|
||||
## Node Purpose
|
||||
{{node_summary}}
|
||||
|
||||
## Previous Conversation Context (For start of conversation, previous conversation summary can be empty.)
|
||||
{{previous_conversation_summary}}
|
||||
|
||||
## Tags to evaluate
|
||||
|
||||
Examine the conversation carefully and identify which of the following tags apply:
|
||||
|
||||
- UNCLEAR_CONVERSATION - The conversation is not coherent or clear, messages don't connect logically
|
||||
- ASSISTANT_IN_LOOP - The assistant asks the same question multiple times or gets stuck repeating itself
|
||||
- ASSISTANT_REPLY_IMPROPER - The assistant did not reply properly to the user's question/query or seems confused by what the user said
|
||||
- USER_FRUSTRATED - The user seems angry, frustrated, or is complaining about something in the call
|
||||
- USER_NOT_UNDERSTANDING - The user explicitly says they don't understand or repeatedly asks for clarification
|
||||
- HEARING_ISSUES - Either party can't hear the other ("hello?", "are you there?", "can you hear me?")
|
||||
- DEAD_AIR - Unusually long silences in the conversation (use the timestamps to judge)
|
||||
- USER_REQUESTING_FEATURE - The user asks for something the assistant can't fulfill
|
||||
- ASSISTANT_LACKS_EMPATHY - The assistant ignores the user's personal situation or emotional state and continues pitching or pushing the agenda.
|
||||
- USER_DETECTS_AI - The user suspects or identifies that they are talking to an AI/robot/bot rather than a real human.
|
||||
|
||||
## Call metrics (pre-computed)
|
||||
|
||||
Use these alongside the transcript for your analysis:
|
||||
{{metrics}}
|
||||
|
||||
## Output format
|
||||
|
||||
Return ONLY a valid JSON object (no markdown):
|
||||
{
|
||||
"tags": [
|
||||
{
|
||||
"tag": "TAG_NAME",
|
||||
"reason": "Short reason with evidence from the transcript"
|
||||
}
|
||||
],
|
||||
"overall_sentiment": "positive|neutral|negative",
|
||||
"call_quality_score": <1-10>,
|
||||
"summary": "1-2 sentence summary of this segment"
|
||||
}
|
||||
|
||||
If no tags apply, return an empty tags list. Always provide sentiment, score, and summary."""
|
||||
|
|
@ -1,141 +0,0 @@
|
|||
"""Spec for the End Call node — terminal node that wraps up a conversation
|
||||
and optionally extracts variables before hangup."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
DisplayOptions,
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertyOption,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
SPEC = NodeSpec(
|
||||
name="endCall",
|
||||
display_name="End Call",
|
||||
description="Closes the conversation and hangs up.",
|
||||
llm_hint=(
|
||||
"Terminal node that politely closes the conversation. Variable "
|
||||
"extraction can run before hangup. A workflow can have multiple "
|
||||
"endCall nodes reached via different edge conditions."
|
||||
),
|
||||
category=NodeCategory.call_node,
|
||||
icon="OctagonX",
|
||||
properties=[
|
||||
PropertySpec(
|
||||
name="name",
|
||||
type=PropertyType.string,
|
||||
display_name="Name",
|
||||
description=(
|
||||
"Short identifier shown in call logs. Should describe the "
|
||||
"ending context (e.g., 'Successful close', 'Polite decline')."
|
||||
),
|
||||
required=True,
|
||||
min_length=1,
|
||||
default="End Call",
|
||||
),
|
||||
PropertySpec(
|
||||
name="prompt",
|
||||
type=PropertyType.mention_textarea,
|
||||
display_name="Prompt",
|
||||
description=(
|
||||
"Agent system prompt for the closing exchange. Supports "
|
||||
"{{template_variables}} from extraction or pre-call fetch."
|
||||
),
|
||||
required=True,
|
||||
min_length=1,
|
||||
placeholder="Thank the caller and confirm next steps before ending the call.",
|
||||
),
|
||||
PropertySpec(
|
||||
name="add_global_prompt",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Add Global Prompt",
|
||||
description=(
|
||||
"When true and a Global node exists, prepends the global "
|
||||
"prompt to this node's prompt at runtime."
|
||||
),
|
||||
default=False,
|
||||
),
|
||||
PropertySpec(
|
||||
name="extraction_enabled",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Enable Variable Extraction",
|
||||
description=(
|
||||
"When true, runs an LLM extraction pass before hangup to "
|
||||
"capture variables from the conversation."
|
||||
),
|
||||
default=False,
|
||||
),
|
||||
PropertySpec(
|
||||
name="extraction_prompt",
|
||||
type=PropertyType.string,
|
||||
display_name="Extraction Prompt",
|
||||
description=(
|
||||
"Overall instructions guiding how variables should be "
|
||||
"extracted from the conversation."
|
||||
),
|
||||
display_options=DisplayOptions(show={"extraction_enabled": [True]}),
|
||||
editor="textarea",
|
||||
),
|
||||
PropertySpec(
|
||||
name="extraction_variables",
|
||||
type=PropertyType.fixed_collection,
|
||||
display_name="Variables to Extract",
|
||||
description=(
|
||||
"Each entry declares one variable to capture from the "
|
||||
"conversation, with its name, data type, and a per-variable "
|
||||
"extraction hint."
|
||||
),
|
||||
display_options=DisplayOptions(show={"extraction_enabled": [True]}),
|
||||
properties=[
|
||||
PropertySpec(
|
||||
name="name",
|
||||
type=PropertyType.string,
|
||||
display_name="Variable Name",
|
||||
description="snake_case identifier used downstream.",
|
||||
required=True,
|
||||
),
|
||||
PropertySpec(
|
||||
name="type",
|
||||
type=PropertyType.options,
|
||||
display_name="Type",
|
||||
description="The data type of the extracted value.",
|
||||
required=True,
|
||||
default="string",
|
||||
options=[
|
||||
PropertyOption(value="string", label="String"),
|
||||
PropertyOption(value="number", label="Number"),
|
||||
PropertyOption(value="boolean", label="Boolean"),
|
||||
],
|
||||
),
|
||||
PropertySpec(
|
||||
name="prompt",
|
||||
type=PropertyType.string,
|
||||
display_name="Extraction Hint",
|
||||
description=(
|
||||
"Per-variable hint describing what to look for in "
|
||||
"the conversation."
|
||||
),
|
||||
editor="textarea",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
examples=[
|
||||
NodeExample(
|
||||
name="successful_close",
|
||||
data={
|
||||
"name": "Successful Close",
|
||||
"prompt": "Confirm the appointment time, thank the caller, and end the call.",
|
||||
"add_global_prompt": False,
|
||||
},
|
||||
),
|
||||
],
|
||||
graph_constraints=GraphConstraints(
|
||||
min_incoming=1,
|
||||
min_outgoing=0,
|
||||
max_outgoing=0,
|
||||
),
|
||||
)
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
"""Spec for the Global node — system-level instructions appended to every
|
||||
agent node that opts in via `add_global_prompt`."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
SPEC = NodeSpec(
|
||||
name="globalNode",
|
||||
display_name="Global Node",
|
||||
description="Persona/tone appended to every agent node's prompt.",
|
||||
llm_hint=(
|
||||
"System-level prompt appended to every prompted node whose "
|
||||
"`add_global_prompt` is true. Use it for persona, tone, and shared "
|
||||
"rules that apply across the entire conversation. At most one "
|
||||
"global node per workflow."
|
||||
),
|
||||
category=NodeCategory.global_node,
|
||||
icon="Globe",
|
||||
properties=[
|
||||
PropertySpec(
|
||||
name="name",
|
||||
type=PropertyType.string,
|
||||
display_name="Name",
|
||||
description=(
|
||||
"Short identifier shown in the canvas and call logs. Has no "
|
||||
"runtime effect."
|
||||
),
|
||||
required=True,
|
||||
min_length=1,
|
||||
default="Global Node",
|
||||
),
|
||||
PropertySpec(
|
||||
name="prompt",
|
||||
type=PropertyType.mention_textarea,
|
||||
display_name="Global Prompt",
|
||||
description=(
|
||||
"Text appended to every prompted node's system prompt when "
|
||||
"that node has `add_global_prompt=true`. Supports "
|
||||
"{{template_variables}}."
|
||||
),
|
||||
required=True,
|
||||
min_length=1,
|
||||
placeholder="You are a friendly assistant calling on behalf of {{company_name}}.",
|
||||
default=(
|
||||
"You are a helpful assistant whose mode of interaction with "
|
||||
"the user is voice. So don't use any special characters which "
|
||||
"can not be pronounced. Use short sentences and simple language."
|
||||
),
|
||||
),
|
||||
],
|
||||
examples=[
|
||||
NodeExample(
|
||||
name="basic_persona",
|
||||
description="Establishes a consistent persona across the call.",
|
||||
data={
|
||||
"name": "Persona",
|
||||
"prompt": (
|
||||
"You are Sarah, a polite and warm representative from "
|
||||
"Acme Corp. Always thank the caller for their time and "
|
||||
"speak in short conversational sentences."
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
graph_constraints=GraphConstraints(
|
||||
min_incoming=0,
|
||||
max_incoming=0,
|
||||
min_outgoing=0,
|
||||
max_outgoing=0,
|
||||
),
|
||||
)
|
||||
404
api/services/workflow/node_specs/model_spec.py
Normal file
404
api/services/workflow/node_specs/model_spec.py
Normal file
|
|
@ -0,0 +1,404 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field as dataclass_field
|
||||
from enum import Enum
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, Literal, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.fields import FieldInfo, PydanticUndefined
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
DisplayOptions,
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertyOption,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
_SPEC_FIELD_META_KEY = "__dograh_spec_field__"
|
||||
_UNSET = object()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NodeSpecMetadata:
|
||||
name: str
|
||||
display_name: str
|
||||
description: str
|
||||
category: NodeCategory
|
||||
icon: str
|
||||
llm_hint: str | None = None
|
||||
version: str = "1.0.0"
|
||||
examples: tuple[NodeExample, ...] = ()
|
||||
graph_constraints: GraphConstraints | None = None
|
||||
property_order: tuple[str, ...] = ()
|
||||
field_overrides: dict[str, dict[str, Any]] = dataclass_field(default_factory=dict)
|
||||
|
||||
|
||||
def spec_field(
|
||||
*field_args: Any,
|
||||
ui_type: PropertyType | str | None = None,
|
||||
display_name: str | None = None,
|
||||
llm_hint: str | None = None,
|
||||
required: bool | None = None,
|
||||
spec_default: Any = _UNSET,
|
||||
placeholder: str | None = None,
|
||||
display_options: DisplayOptions | None = None,
|
||||
options: list[PropertyOption] | None = None,
|
||||
editor: str | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
spec_exclude: bool = False,
|
||||
min_value: float | None = None,
|
||||
max_value: float | None = None,
|
||||
min_length: int | None = None,
|
||||
max_length: int | None = None,
|
||||
pattern: str | None = None,
|
||||
**field_kwargs: Any,
|
||||
):
|
||||
json_schema_extra = dict(field_kwargs.pop("json_schema_extra", {}) or {})
|
||||
json_schema_extra[_SPEC_FIELD_META_KEY] = {
|
||||
"ui_type": ui_type.value if isinstance(ui_type, PropertyType) else ui_type,
|
||||
"display_name": display_name,
|
||||
"llm_hint": llm_hint,
|
||||
"required": required,
|
||||
"placeholder": placeholder,
|
||||
"display_options": display_options,
|
||||
"options": options,
|
||||
"editor": editor,
|
||||
"extra": extra or {},
|
||||
"spec_exclude": spec_exclude,
|
||||
"min_value": min_value,
|
||||
"max_value": max_value,
|
||||
"min_length": min_length,
|
||||
"max_length": max_length,
|
||||
"pattern": pattern,
|
||||
}
|
||||
if spec_default is not _UNSET:
|
||||
json_schema_extra[_SPEC_FIELD_META_KEY]["spec_default"] = spec_default
|
||||
return Field(*field_args, json_schema_extra=json_schema_extra, **field_kwargs)
|
||||
|
||||
|
||||
def node_spec(
|
||||
*,
|
||||
name: str,
|
||||
display_name: str,
|
||||
description: str,
|
||||
category: NodeCategory,
|
||||
icon: str,
|
||||
llm_hint: str | None = None,
|
||||
version: str = "1.0.0",
|
||||
examples: list[NodeExample] | tuple[NodeExample, ...] = (),
|
||||
graph_constraints: GraphConstraints | None = None,
|
||||
property_order: list[str] | tuple[str, ...] = (),
|
||||
field_overrides: dict[str, dict[str, Any]] | None = None,
|
||||
) -> Callable[[type[BaseModel]], type[BaseModel]]:
|
||||
metadata = NodeSpecMetadata(
|
||||
name=name,
|
||||
display_name=display_name,
|
||||
description=description,
|
||||
category=category,
|
||||
icon=icon,
|
||||
llm_hint=llm_hint,
|
||||
version=version,
|
||||
examples=tuple(examples),
|
||||
graph_constraints=graph_constraints,
|
||||
property_order=tuple(property_order),
|
||||
field_overrides=field_overrides or {},
|
||||
)
|
||||
|
||||
def decorator(model_cls: type[BaseModel]) -> type[BaseModel]:
|
||||
setattr(model_cls, "__node_spec_metadata__", metadata)
|
||||
return model_cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def build_spec(model_cls: type[BaseModel]) -> NodeSpec:
|
||||
metadata: NodeSpecMetadata | None = getattr(
|
||||
model_cls, "__node_spec_metadata__", None
|
||||
)
|
||||
if metadata is None:
|
||||
raise ValueError(f"{model_cls.__name__} is missing __node_spec_metadata__")
|
||||
|
||||
properties: list[PropertySpec] = []
|
||||
for name, field in model_cls.model_fields.items():
|
||||
prop = _build_property_spec(model_cls, name, field)
|
||||
if prop is not None:
|
||||
properties.append(prop)
|
||||
properties = _sort_properties(metadata.name, properties, metadata.property_order)
|
||||
|
||||
return NodeSpec(
|
||||
name=metadata.name,
|
||||
display_name=metadata.display_name,
|
||||
description=metadata.description,
|
||||
llm_hint=metadata.llm_hint,
|
||||
category=metadata.category,
|
||||
icon=metadata.icon,
|
||||
version=metadata.version,
|
||||
properties=properties,
|
||||
examples=list(metadata.examples),
|
||||
graph_constraints=metadata.graph_constraints,
|
||||
)
|
||||
|
||||
|
||||
def _sort_properties(
|
||||
spec_name: str,
|
||||
properties: list[PropertySpec],
|
||||
property_order: tuple[str, ...],
|
||||
) -> list[PropertySpec]:
|
||||
if not property_order:
|
||||
return properties
|
||||
|
||||
property_names = {prop.name for prop in properties}
|
||||
missing = [name for name in property_order if name not in property_names]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"{spec_name}: property_order references unknown properties: {missing}"
|
||||
)
|
||||
|
||||
order_map = {name: idx for idx, name in enumerate(property_order)}
|
||||
ordered = sorted(
|
||||
enumerate(properties),
|
||||
key=lambda item: (order_map.get(item[1].name, len(order_map)), item[0]),
|
||||
)
|
||||
return [prop for _, prop in ordered]
|
||||
|
||||
|
||||
def _build_property_spec(
|
||||
owner_cls: type[BaseModel],
|
||||
field_name: str,
|
||||
field: FieldInfo,
|
||||
) -> PropertySpec | None:
|
||||
meta = _merged_field_meta(owner_cls, field_name, field)
|
||||
if meta.get("spec_exclude"):
|
||||
return None
|
||||
|
||||
prop_type = _resolve_property_type(field.annotation, meta)
|
||||
nested_properties = _resolve_nested_properties(field.annotation, prop_type)
|
||||
options = _resolve_options(field.annotation, meta, prop_type)
|
||||
min_value, max_value, min_length, max_length, pattern = _resolve_constraints(
|
||||
field, meta
|
||||
)
|
||||
|
||||
description = meta.get("description") or field.description
|
||||
if not description:
|
||||
raise ValueError(f"{owner_cls.__name__}.{field_name} is missing a description")
|
||||
|
||||
return PropertySpec(
|
||||
name=field_name,
|
||||
type=prop_type,
|
||||
display_name=meta.get("display_name") or _humanize_identifier(field_name),
|
||||
description=description,
|
||||
llm_hint=meta.get("llm_hint"),
|
||||
default=_resolve_default(field, meta),
|
||||
required=_resolve_required(field, meta),
|
||||
placeholder=meta.get("placeholder"),
|
||||
display_options=meta.get("display_options"),
|
||||
options=options,
|
||||
properties=nested_properties,
|
||||
min_value=min_value,
|
||||
max_value=max_value,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
pattern=pattern,
|
||||
editor=meta.get("editor"),
|
||||
extra=meta.get("extra") or {},
|
||||
)
|
||||
|
||||
|
||||
def _merged_field_meta(
|
||||
owner_cls: type[BaseModel],
|
||||
field_name: str,
|
||||
field: FieldInfo,
|
||||
) -> dict[str, Any]:
|
||||
field_meta = {}
|
||||
if isinstance(field.json_schema_extra, dict):
|
||||
field_meta = dict(field.json_schema_extra.get(_SPEC_FIELD_META_KEY, {}) or {})
|
||||
metadata: NodeSpecMetadata | None = getattr(
|
||||
owner_cls, "__node_spec_metadata__", None
|
||||
)
|
||||
override = (
|
||||
dict(metadata.field_overrides.get(field_name, {}) or {})
|
||||
if metadata is not None
|
||||
else {}
|
||||
)
|
||||
merged = dict(field_meta)
|
||||
merged.update(override)
|
||||
return merged
|
||||
|
||||
|
||||
def _resolve_property_type(annotation: Any, meta: dict[str, Any]) -> PropertyType:
|
||||
ui_type = meta.get("ui_type")
|
||||
if ui_type:
|
||||
return PropertyType(ui_type)
|
||||
|
||||
inner = _strip_optional(annotation)
|
||||
origin = get_origin(inner)
|
||||
args = get_args(inner)
|
||||
|
||||
if origin is list:
|
||||
item_type = _strip_optional(args[0]) if args else Any
|
||||
if isinstance(item_type, type) and issubclass(item_type, BaseModel):
|
||||
return PropertyType.fixed_collection
|
||||
raise ValueError(
|
||||
"List-valued fields must declare an explicit ui_type unless they wrap a "
|
||||
f"BaseModel row type (field annotation: {annotation!r})."
|
||||
)
|
||||
|
||||
if _is_enum(inner) or _is_literal(inner):
|
||||
return PropertyType.options
|
||||
|
||||
if inner in (str,):
|
||||
return PropertyType.string
|
||||
if inner in (int, float):
|
||||
return PropertyType.number
|
||||
if inner is bool:
|
||||
return PropertyType.boolean
|
||||
if inner in (dict, Any) or origin is dict:
|
||||
return PropertyType.json
|
||||
|
||||
raise ValueError(f"Unable to derive PropertyType for annotation {annotation!r}")
|
||||
|
||||
|
||||
def _resolve_nested_properties(
|
||||
annotation: Any,
|
||||
prop_type: PropertyType,
|
||||
) -> list[PropertySpec] | None:
|
||||
if prop_type != PropertyType.fixed_collection:
|
||||
return None
|
||||
|
||||
inner = _strip_optional(annotation)
|
||||
args = get_args(inner)
|
||||
if not args:
|
||||
raise ValueError(
|
||||
f"fixed_collection field annotation is missing row type: {annotation!r}"
|
||||
)
|
||||
row_type = _strip_optional(args[0])
|
||||
if not isinstance(row_type, type) or not issubclass(row_type, BaseModel):
|
||||
raise ValueError(
|
||||
f"fixed_collection rows must be BaseModel subclasses: {annotation!r}"
|
||||
)
|
||||
|
||||
properties: list[PropertySpec] = []
|
||||
for field_name, field in row_type.model_fields.items():
|
||||
prop = _build_property_spec(row_type, field_name, field)
|
||||
if prop is not None:
|
||||
properties.append(prop)
|
||||
return properties
|
||||
|
||||
|
||||
def _resolve_options(
|
||||
annotation: Any,
|
||||
meta: dict[str, Any],
|
||||
prop_type: PropertyType,
|
||||
) -> list[PropertyOption] | None:
|
||||
if prop_type not in (PropertyType.options, PropertyType.multi_options):
|
||||
return meta.get("options")
|
||||
|
||||
if meta.get("options"):
|
||||
return meta["options"]
|
||||
|
||||
inner = _strip_optional(annotation)
|
||||
if prop_type == PropertyType.multi_options:
|
||||
inner = _strip_optional(get_args(inner)[0])
|
||||
|
||||
if _is_enum(inner):
|
||||
return [
|
||||
PropertyOption(
|
||||
value=member.value, label=_humanize_option_label(member.value)
|
||||
)
|
||||
for member in inner
|
||||
]
|
||||
if _is_literal(inner):
|
||||
return [
|
||||
PropertyOption(value=value, label=_humanize_option_label(value))
|
||||
for value in get_args(inner)
|
||||
if value is not None
|
||||
]
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_constraints(
|
||||
field: FieldInfo,
|
||||
meta: dict[str, Any],
|
||||
) -> tuple[float | None, float | None, int | None, int | None, str | None]:
|
||||
min_value = meta.get("min_value")
|
||||
max_value = meta.get("max_value")
|
||||
min_length = meta.get("min_length")
|
||||
max_length = meta.get("max_length")
|
||||
pattern = meta.get("pattern")
|
||||
|
||||
for item in field.metadata:
|
||||
if min_value is None:
|
||||
if hasattr(item, "ge") and item.ge is not None:
|
||||
min_value = item.ge
|
||||
elif hasattr(item, "gt") and item.gt is not None:
|
||||
min_value = item.gt
|
||||
if max_value is None:
|
||||
if hasattr(item, "le") and item.le is not None:
|
||||
max_value = item.le
|
||||
elif hasattr(item, "lt") and item.lt is not None:
|
||||
max_value = item.lt
|
||||
if (
|
||||
min_length is None
|
||||
and hasattr(item, "min_length")
|
||||
and item.min_length is not None
|
||||
):
|
||||
min_length = item.min_length
|
||||
if (
|
||||
max_length is None
|
||||
and hasattr(item, "max_length")
|
||||
and item.max_length is not None
|
||||
):
|
||||
max_length = item.max_length
|
||||
if pattern is None and hasattr(item, "pattern") and item.pattern is not None:
|
||||
pattern = item.pattern
|
||||
|
||||
return min_value, max_value, min_length, max_length, pattern
|
||||
|
||||
|
||||
def _resolve_default(field: FieldInfo, meta: dict[str, Any]) -> Any:
|
||||
if "spec_default" in meta:
|
||||
return meta["spec_default"]
|
||||
if field.default is not PydanticUndefined:
|
||||
return field.default
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_required(field: FieldInfo, meta: dict[str, Any]) -> bool:
|
||||
if meta.get("required") is not None:
|
||||
return bool(meta["required"])
|
||||
return bool(field.is_required())
|
||||
|
||||
|
||||
def _strip_optional(annotation: Any) -> Any:
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
return annotation
|
||||
|
||||
args = [arg for arg in get_args(annotation) if arg is not NoneType]
|
||||
if len(args) == 1 and len(args) != len(get_args(annotation)):
|
||||
return args[0]
|
||||
return annotation
|
||||
|
||||
|
||||
def _is_enum(annotation: Any) -> bool:
|
||||
return isinstance(annotation, type) and issubclass(annotation, Enum)
|
||||
|
||||
|
||||
def _is_literal(annotation: Any) -> bool:
|
||||
return get_origin(annotation) is Literal
|
||||
|
||||
|
||||
def _humanize_identifier(name: str) -> str:
|
||||
return name.replace("_", " ").strip().title()
|
||||
|
||||
|
||||
def _humanize_option_label(value: Any) -> str:
|
||||
if isinstance(value, str):
|
||||
return value.replace("_", " ").replace("-", " ").strip().title()
|
||||
return str(value)
|
||||
|
|
@ -1,203 +0,0 @@
|
|||
"""Spec for the QA Analysis node — runs an LLM quality review on the call
|
||||
transcript after completion."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
DisplayOptions,
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertyOption,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
DEFAULT_QA_SYSTEM_PROMPT = """You are a QA analyst evaluating a specific segment of a voice AI conversation.
|
||||
|
||||
## Node Purpose
|
||||
{{node_summary}}
|
||||
|
||||
## Previous Conversation Context (For start of conversation, previous conversation summary can be empty.)
|
||||
{{previous_conversation_summary}}
|
||||
|
||||
## Tags to evaluate
|
||||
|
||||
Examine the conversation carefully and identify which of the following tags apply:
|
||||
|
||||
- UNCLEAR_CONVERSATION - The conversation is not coherent or clear, messages don't connect logically
|
||||
- ASSISTANT_IN_LOOP - The assistant asks the same question multiple times or gets stuck repeating itself
|
||||
- ASSISTANT_REPLY_IMPROPER - The assistant did not reply properly to the user's question/query or seems confused by what the user said
|
||||
- USER_FRUSTRATED - The user seems angry, frustrated, or is complaining about something in the call
|
||||
- USER_NOT_UNDERSTANDING - The user explicitly says they don't understand or repeatedly asks for clarification
|
||||
- HEARING_ISSUES - Either party can't hear the other ("hello?", "are you there?", "can you hear me?")
|
||||
- DEAD_AIR - Unusually long silences in the conversation (use the timestamps to judge)
|
||||
- USER_REQUESTING_FEATURE - The user asks for something the assistant can't fulfill
|
||||
- ASSISTANT_LACKS_EMPATHY - The assistant ignores the user's personal situation or emotional state and continues pitching or pushing the agenda.
|
||||
- USER_DETECTS_AI - The user suspects or identifies that they are talking to an AI/robot/bot rather than a real human.
|
||||
|
||||
## Call metrics (pre-computed)
|
||||
|
||||
Use these alongside the transcript for your analysis:
|
||||
{{metrics}}
|
||||
|
||||
## Output format
|
||||
|
||||
Return ONLY a valid JSON object (no markdown):
|
||||
{
|
||||
"tags": [
|
||||
{
|
||||
"tag": "TAG_NAME",
|
||||
"reason": "Short reason with evidence from the transcript"
|
||||
}
|
||||
],
|
||||
"overall_sentiment": "positive|neutral|negative",
|
||||
"call_quality_score": <1-10>,
|
||||
"summary": "1-2 sentence summary of this segment"
|
||||
}
|
||||
|
||||
If no tags apply, return an empty tags list. Always provide sentiment, score, and summary."""
|
||||
|
||||
|
||||
SPEC = NodeSpec(
|
||||
name="qa",
|
||||
display_name="QA Analysis",
|
||||
description="Run LLM quality analysis on the call transcript.",
|
||||
llm_hint=(
|
||||
"Runs an LLM quality review on the call transcript after completion. "
|
||||
"Per-node analysis splits the conversation by node and evaluates each "
|
||||
"segment against the configured system prompt. Sampling, minimum "
|
||||
"duration, and voicemail filters are supported."
|
||||
),
|
||||
category=NodeCategory.integration,
|
||||
icon="ClipboardCheck",
|
||||
properties=[
|
||||
PropertySpec(
|
||||
name="name",
|
||||
type=PropertyType.string,
|
||||
display_name="Name",
|
||||
description="Short identifier for this QA configuration.",
|
||||
required=True,
|
||||
min_length=1,
|
||||
default="QA Analysis",
|
||||
),
|
||||
PropertySpec(
|
||||
name="qa_enabled",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Enabled",
|
||||
description="When false, the QA run is skipped.",
|
||||
default=True,
|
||||
),
|
||||
PropertySpec(
|
||||
name="qa_system_prompt",
|
||||
type=PropertyType.string,
|
||||
display_name="System Prompt",
|
||||
description=(
|
||||
"Instructions to the QA reviewer LLM. Supports placeholders: "
|
||||
"`{node_summary}`, `{previous_conversation_summary}`, "
|
||||
"`{transcript}`, `{metrics}`."
|
||||
),
|
||||
editor="textarea",
|
||||
default=DEFAULT_QA_SYSTEM_PROMPT,
|
||||
),
|
||||
PropertySpec(
|
||||
name="qa_min_call_duration",
|
||||
type=PropertyType.number,
|
||||
display_name="Minimum Call Duration (seconds)",
|
||||
description="Calls shorter than this are skipped.",
|
||||
default=15,
|
||||
min_value=0,
|
||||
),
|
||||
PropertySpec(
|
||||
name="qa_voicemail_calls",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Include Voicemail Calls",
|
||||
description="When false, calls flagged as voicemail are skipped.",
|
||||
default=False,
|
||||
),
|
||||
PropertySpec(
|
||||
name="qa_sample_rate",
|
||||
type=PropertyType.number,
|
||||
display_name="Sample Rate (%)",
|
||||
description=(
|
||||
"Percent of eligible calls QA'd. 100 means every call; lower "
|
||||
"values use random sampling."
|
||||
),
|
||||
default=100,
|
||||
min_value=1,
|
||||
max_value=100,
|
||||
),
|
||||
# ---- LLM configuration ----
|
||||
PropertySpec(
|
||||
name="qa_use_workflow_llm",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Use Workflow's LLM",
|
||||
description=(
|
||||
"When true, the QA pass uses the same LLM the workflow runs "
|
||||
"with. Set false to specify a separate provider/model."
|
||||
),
|
||||
default=True,
|
||||
),
|
||||
PropertySpec(
|
||||
name="qa_provider",
|
||||
type=PropertyType.options,
|
||||
display_name="QA LLM Provider",
|
||||
description="LLM provider used for the QA pass.",
|
||||
display_options=DisplayOptions(show={"qa_use_workflow_llm": [False]}),
|
||||
options=[
|
||||
PropertyOption(value="openai", label="OpenAI"),
|
||||
PropertyOption(value="azure", label="Azure OpenAI"),
|
||||
PropertyOption(value="openrouter", label="OpenRouter"),
|
||||
PropertyOption(value="anthropic", label="Anthropic"),
|
||||
],
|
||||
),
|
||||
PropertySpec(
|
||||
name="qa_model",
|
||||
type=PropertyType.string,
|
||||
display_name="QA Model",
|
||||
description=(
|
||||
"Model identifier (e.g., 'gpt-4o', 'claude-sonnet-4-6'). "
|
||||
"Provider-specific."
|
||||
),
|
||||
display_options=DisplayOptions(show={"qa_use_workflow_llm": [False]}),
|
||||
default="default",
|
||||
),
|
||||
PropertySpec(
|
||||
name="qa_api_key",
|
||||
type=PropertyType.string,
|
||||
display_name="API Key",
|
||||
description="API key for the chosen provider.",
|
||||
display_options=DisplayOptions(show={"qa_use_workflow_llm": [False]}),
|
||||
),
|
||||
PropertySpec(
|
||||
name="qa_endpoint",
|
||||
type=PropertyType.url,
|
||||
display_name="Azure Endpoint",
|
||||
description="Required for the Azure provider.",
|
||||
display_options=DisplayOptions(
|
||||
show={"qa_use_workflow_llm": [False], "qa_provider": ["azure"]}
|
||||
),
|
||||
),
|
||||
],
|
||||
examples=[
|
||||
NodeExample(
|
||||
name="basic_qa",
|
||||
data={
|
||||
"name": "Compliance Check",
|
||||
"qa_enabled": True,
|
||||
"qa_system_prompt": (
|
||||
"You are a compliance reviewer. Review the transcript and "
|
||||
"produce a JSON object with `tags`, `summary`, "
|
||||
"`call_quality_score`, and `overall_sentiment`."
|
||||
),
|
||||
"qa_min_call_duration": 30,
|
||||
"qa_sample_rate": 100,
|
||||
},
|
||||
),
|
||||
],
|
||||
# QA runs post-call against the saved transcript (run_integrations
|
||||
# scans by type), never as a graph step. Reject any edge into or out
|
||||
# of a QA node.
|
||||
graph_constraints=GraphConstraints(
|
||||
min_incoming=0, max_incoming=0, min_outgoing=0, max_outgoing=0
|
||||
),
|
||||
)
|
||||
|
|
@ -1,250 +0,0 @@
|
|||
"""Spec for the Start Call node — the single entry point of every workflow.
|
||||
Carries greeting, pre-call data fetch, and the same prompt/extraction/tools
|
||||
fields as agent nodes."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
DisplayOptions,
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertyOption,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
SPEC = NodeSpec(
|
||||
name="startCall",
|
||||
display_name="Start Call",
|
||||
description="Entry point of the workflow — plays a greeting and opens the conversation.",
|
||||
llm_hint=(
|
||||
"The entry point of every workflow (exactly one required). Plays an "
|
||||
"optional greeting, can fetch context from an external API before "
|
||||
"the call begins, and executes the first conversational turn."
|
||||
),
|
||||
category=NodeCategory.call_node,
|
||||
icon="Play",
|
||||
properties=[
|
||||
PropertySpec(
|
||||
name="name",
|
||||
type=PropertyType.string,
|
||||
display_name="Name",
|
||||
description="Short identifier shown in the canvas and call logs.",
|
||||
required=True,
|
||||
min_length=1,
|
||||
default="Start Call",
|
||||
),
|
||||
# ---- Greeting (variant via greeting_type) ----
|
||||
PropertySpec(
|
||||
name="greeting_type",
|
||||
type=PropertyType.options,
|
||||
display_name="Greeting Type",
|
||||
description=(
|
||||
"Whether the optional greeting is spoken via TTS from text "
|
||||
"or played from a pre-recorded audio file."
|
||||
),
|
||||
default="text",
|
||||
options=[
|
||||
PropertyOption(value="text", label="Text (TTS)"),
|
||||
PropertyOption(value="audio", label="Pre-recorded Audio"),
|
||||
],
|
||||
),
|
||||
PropertySpec(
|
||||
name="greeting",
|
||||
type=PropertyType.string,
|
||||
display_name="Greeting Text",
|
||||
description=(
|
||||
"Text spoken via TTS at the start of the call. Supports "
|
||||
"{{template_variables}}. Leave empty to skip the greeting."
|
||||
),
|
||||
display_options=DisplayOptions(show={"greeting_type": ["text"]}),
|
||||
editor="textarea",
|
||||
placeholder="Hi {{first_name}}, this is Sarah from Acme.",
|
||||
),
|
||||
PropertySpec(
|
||||
name="greeting_recording_id",
|
||||
type=PropertyType.recording_ref,
|
||||
display_name="Greeting Recording",
|
||||
description="Pre-recorded audio file played at the start of the call.",
|
||||
llm_hint=(
|
||||
"Value is the `recording_id` string. Use the `list_recordings` "
|
||||
"MCP tool to discover available recordings."
|
||||
),
|
||||
display_options=DisplayOptions(show={"greeting_type": ["audio"]}),
|
||||
),
|
||||
PropertySpec(
|
||||
name="prompt",
|
||||
type=PropertyType.mention_textarea,
|
||||
display_name="Prompt",
|
||||
description=(
|
||||
"Agent system prompt for the opening turn. Supports "
|
||||
"{{template_variables}} from pre-call fetch and the initial context."
|
||||
),
|
||||
required=True,
|
||||
min_length=1,
|
||||
placeholder="Greet the caller warmly and ask how you can help today.",
|
||||
),
|
||||
# ---- Behavior toggles ----
|
||||
PropertySpec(
|
||||
name="allow_interrupt",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Allow Interruption",
|
||||
description=("When true, the user can interrupt the agent mid-utterance."),
|
||||
default=False,
|
||||
),
|
||||
PropertySpec(
|
||||
name="add_global_prompt",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Add Global Prompt",
|
||||
description=(
|
||||
"When true and a Global node exists, prepends the global "
|
||||
"prompt to this node's prompt at runtime."
|
||||
),
|
||||
default=True,
|
||||
),
|
||||
PropertySpec(
|
||||
name="delayed_start",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Delayed Start",
|
||||
description=(
|
||||
"When true, the agent waits before speaking after pickup. "
|
||||
"Useful for outbound calls where the called party needs a "
|
||||
"moment to settle."
|
||||
),
|
||||
default=False,
|
||||
),
|
||||
PropertySpec(
|
||||
name="delayed_start_duration",
|
||||
type=PropertyType.number,
|
||||
display_name="Delay Duration (seconds)",
|
||||
description="Seconds to wait before the agent speaks. 0.1–10.",
|
||||
default=2.0,
|
||||
min_value=0.1,
|
||||
max_value=10.0,
|
||||
display_options=DisplayOptions(show={"delayed_start": [True]}),
|
||||
),
|
||||
# ---- Variable extraction ----
|
||||
PropertySpec(
|
||||
name="extraction_enabled",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Enable Variable Extraction",
|
||||
description=(
|
||||
"When true, runs an LLM extraction pass on transition out of "
|
||||
"this node to capture variables from the opening turn."
|
||||
),
|
||||
default=False,
|
||||
),
|
||||
PropertySpec(
|
||||
name="extraction_prompt",
|
||||
type=PropertyType.string,
|
||||
display_name="Extraction Prompt",
|
||||
description="Overall instructions guiding variable extraction.",
|
||||
display_options=DisplayOptions(show={"extraction_enabled": [True]}),
|
||||
editor="textarea",
|
||||
),
|
||||
PropertySpec(
|
||||
name="extraction_variables",
|
||||
type=PropertyType.fixed_collection,
|
||||
display_name="Variables to Extract",
|
||||
description=(
|
||||
"Each entry declares one variable to capture, with its name, "
|
||||
"data type, and per-variable extraction hint."
|
||||
),
|
||||
display_options=DisplayOptions(show={"extraction_enabled": [True]}),
|
||||
properties=[
|
||||
PropertySpec(
|
||||
name="name",
|
||||
type=PropertyType.string,
|
||||
display_name="Variable Name",
|
||||
description="snake_case identifier used downstream.",
|
||||
required=True,
|
||||
),
|
||||
PropertySpec(
|
||||
name="type",
|
||||
type=PropertyType.options,
|
||||
display_name="Type",
|
||||
description="Data type of the extracted value.",
|
||||
required=True,
|
||||
default="string",
|
||||
options=[
|
||||
PropertyOption(value="string", label="String"),
|
||||
PropertyOption(value="number", label="Number"),
|
||||
PropertyOption(value="boolean", label="Boolean"),
|
||||
],
|
||||
),
|
||||
PropertySpec(
|
||||
name="prompt",
|
||||
type=PropertyType.string,
|
||||
display_name="Extraction Hint",
|
||||
description="Per-variable hint describing what to look for.",
|
||||
editor="textarea",
|
||||
),
|
||||
],
|
||||
),
|
||||
# ---- Tools / documents ----
|
||||
PropertySpec(
|
||||
name="tool_uuids",
|
||||
type=PropertyType.tool_refs,
|
||||
display_name="Tools",
|
||||
description="Tools the agent can invoke during the opening turn.",
|
||||
llm_hint="List of tool UUIDs from `list_tools`.",
|
||||
),
|
||||
PropertySpec(
|
||||
name="document_uuids",
|
||||
type=PropertyType.document_refs,
|
||||
display_name="Knowledge Base Documents",
|
||||
description="Documents the agent can reference.",
|
||||
llm_hint="List of document UUIDs from `list_documents`.",
|
||||
),
|
||||
# ---- Pre-call data fetch (advanced) ----
|
||||
PropertySpec(
|
||||
name="pre_call_fetch_enabled",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Pre-Call Data Fetch",
|
||||
description=(
|
||||
"When true, makes a POST request to an external API before "
|
||||
"the call starts and merges the JSON response into the call "
|
||||
"context as template variables."
|
||||
),
|
||||
default=False,
|
||||
),
|
||||
PropertySpec(
|
||||
name="pre_call_fetch_url",
|
||||
type=PropertyType.url,
|
||||
display_name="Endpoint URL",
|
||||
description=(
|
||||
"URL the pre-call POST request is sent to. The request body "
|
||||
"includes caller and called numbers."
|
||||
),
|
||||
display_options=DisplayOptions(show={"pre_call_fetch_enabled": [True]}),
|
||||
placeholder="https://api.example.com/customer-lookup",
|
||||
),
|
||||
PropertySpec(
|
||||
name="pre_call_fetch_credential_uuid",
|
||||
type=PropertyType.credential_ref,
|
||||
display_name="Authentication",
|
||||
description="Optional credential attached to the pre-call request.",
|
||||
llm_hint="Credential UUID from `list_credentials`.",
|
||||
display_options=DisplayOptions(show={"pre_call_fetch_enabled": [True]}),
|
||||
),
|
||||
],
|
||||
examples=[
|
||||
NodeExample(
|
||||
name="warm_greeting",
|
||||
data={
|
||||
"name": "Greeting",
|
||||
"prompt": "Greet warmly and ask the caller's reason for calling.",
|
||||
"greeting_type": "text",
|
||||
"greeting": "Hi {{first_name}}, this is Sarah from Acme.",
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
),
|
||||
],
|
||||
# `min_outgoing` is intentionally unset: a startCall is allowed to
|
||||
# sit on the canvas without an outgoing edge (e.g. a workflow with
|
||||
# just a greeting). Only constraint: nothing flows INTO the start.
|
||||
graph_constraints=GraphConstraints(
|
||||
min_incoming=0,
|
||||
max_incoming=0,
|
||||
),
|
||||
)
|
||||
|
|
@ -1,79 +0,0 @@
|
|||
"""Spec for the API Trigger node — exposes a public webhook URL that
|
||||
external systems can hit to launch the workflow."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
SPEC = NodeSpec(
|
||||
name="trigger",
|
||||
display_name="API Trigger",
|
||||
description=("Public HTTP endpoints that launch the workflow."),
|
||||
llm_hint=(
|
||||
"Exposes two public HTTP POST endpoints derived from the auto-generated "
|
||||
"`trigger_path`:\n"
|
||||
" • Production: `<backend>/api/v1/public/agent/<trigger_path>` — runs "
|
||||
"the published agent. Use this from production systems.\n"
|
||||
" • Test: `<backend>/api/v1/public/agent/test/<trigger_path>` — runs "
|
||||
"the latest draft, useful for verifying changes before publishing. "
|
||||
"Falls back to the published agent when no draft exists.\n"
|
||||
"Both require an API key in the `X-API-Key` header.\n"
|
||||
"Request body fields:\n"
|
||||
" • `phone_number` (string, required) — destination to dial.\n"
|
||||
" • `initial_context` (object, optional) — merged into the run's "
|
||||
"initial context.\n"
|
||||
" • `telephony_configuration_id` (int, optional) — pick a specific "
|
||||
"telephony configuration for the call. Must belong to the same "
|
||||
"organization as the trigger. When omitted, the org's default "
|
||||
"outbound configuration is used."
|
||||
),
|
||||
category=NodeCategory.trigger,
|
||||
icon="Webhook",
|
||||
properties=[
|
||||
PropertySpec(
|
||||
name="name",
|
||||
type=PropertyType.string,
|
||||
display_name="Name",
|
||||
description="Short identifier shown in the canvas. No runtime effect.",
|
||||
required=True,
|
||||
min_length=1,
|
||||
default="API Trigger",
|
||||
),
|
||||
PropertySpec(
|
||||
name="enabled",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Enabled",
|
||||
description="When false, the trigger URL returns 404.",
|
||||
default=True,
|
||||
),
|
||||
PropertySpec(
|
||||
name="trigger_path",
|
||||
type=PropertyType.string,
|
||||
display_name="Trigger Path",
|
||||
description=(
|
||||
"Auto-generated UUID-style path segment that uniquely "
|
||||
"identifies this trigger. Used in both URLs:\n"
|
||||
" • Production: `/api/v1/public/agent/<trigger_path>` — "
|
||||
"executes the published agent.\n"
|
||||
" • Test: `/api/v1/public/agent/test/<trigger_path>` — "
|
||||
"executes the latest draft.\n"
|
||||
"Do not edit manually."
|
||||
),
|
||||
),
|
||||
],
|
||||
examples=[
|
||||
NodeExample(
|
||||
name="default",
|
||||
data={"name": "Inbound Trigger", "enabled": True},
|
||||
),
|
||||
],
|
||||
graph_constraints=GraphConstraints(
|
||||
min_incoming=0,
|
||||
max_incoming=0,
|
||||
),
|
||||
)
|
||||
|
|
@ -1,133 +0,0 @@
|
|||
"""Spec for the Webhook node — sends an HTTP request to an external system
|
||||
after the workflow completes."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertyOption,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
SPEC = NodeSpec(
|
||||
name="webhook",
|
||||
display_name="Webhook",
|
||||
description="Send HTTP request after the workflow completes.",
|
||||
llm_hint=(
|
||||
"Sends an HTTP request to an external system after the workflow "
|
||||
"completes. The payload is a Jinja-templated JSON body with access "
|
||||
"to `workflow_run_id`, `initial_context`, `gathered_context`, "
|
||||
"`annotations`, and call metadata."
|
||||
),
|
||||
category=NodeCategory.integration,
|
||||
icon="Link2",
|
||||
properties=[
|
||||
PropertySpec(
|
||||
name="name",
|
||||
type=PropertyType.string,
|
||||
display_name="Name",
|
||||
description="Short identifier shown in the canvas and run logs.",
|
||||
required=True,
|
||||
min_length=1,
|
||||
default="Webhook",
|
||||
),
|
||||
PropertySpec(
|
||||
name="enabled",
|
||||
type=PropertyType.boolean,
|
||||
display_name="Enabled",
|
||||
description="When false, the webhook is skipped at run time.",
|
||||
default=True,
|
||||
),
|
||||
PropertySpec(
|
||||
name="http_method",
|
||||
type=PropertyType.options,
|
||||
display_name="HTTP Method",
|
||||
description="HTTP verb used for the outbound request.",
|
||||
default="POST",
|
||||
options=[
|
||||
PropertyOption(value="GET", label="GET"),
|
||||
PropertyOption(value="POST", label="POST"),
|
||||
PropertyOption(value="PUT", label="PUT"),
|
||||
PropertyOption(value="PATCH", label="PATCH"),
|
||||
PropertyOption(value="DELETE", label="DELETE"),
|
||||
],
|
||||
),
|
||||
PropertySpec(
|
||||
name="endpoint_url",
|
||||
type=PropertyType.url,
|
||||
display_name="Endpoint URL",
|
||||
description="URL the request is sent to.",
|
||||
placeholder="https://api.example.com/webhook",
|
||||
),
|
||||
PropertySpec(
|
||||
name="credential_uuid",
|
||||
type=PropertyType.credential_ref,
|
||||
display_name="Authentication",
|
||||
description="Optional credential applied as the Authorization header.",
|
||||
llm_hint="Credential UUID from `list_credentials`.",
|
||||
),
|
||||
PropertySpec(
|
||||
name="custom_headers",
|
||||
type=PropertyType.fixed_collection,
|
||||
display_name="Custom Headers",
|
||||
description="Additional HTTP headers to include with the request.",
|
||||
properties=[
|
||||
PropertySpec(
|
||||
name="key",
|
||||
type=PropertyType.string,
|
||||
display_name="Header Name",
|
||||
description="HTTP header name (e.g., 'X-Source').",
|
||||
required=True,
|
||||
),
|
||||
PropertySpec(
|
||||
name="value",
|
||||
type=PropertyType.string,
|
||||
display_name="Header Value",
|
||||
description="Header value (supports {{template_variables}}).",
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
PropertySpec(
|
||||
name="payload_template",
|
||||
type=PropertyType.json,
|
||||
display_name="Payload Template",
|
||||
description=(
|
||||
"JSON body of the request. Values are Jinja-rendered against "
|
||||
"the run context — `{{workflow_run_id}}`, "
|
||||
"`{{gathered_context.foo}}`, `{{annotations.qa_xxx}}`, etc."
|
||||
),
|
||||
default={
|
||||
"call_id": "{{workflow_run_id}}",
|
||||
"first_name": "{{initial_context.first_name}}",
|
||||
"rsvp": "{{gathered_context.rsvp}}",
|
||||
"duration": "{{cost_info.call_duration_seconds}}",
|
||||
"recording_url": "{{recording_url}}",
|
||||
"transcript_url": "{{transcript_url}}",
|
||||
},
|
||||
),
|
||||
],
|
||||
examples=[
|
||||
NodeExample(
|
||||
name="post_to_crm",
|
||||
data={
|
||||
"name": "Notify CRM",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://crm.example.com/calls",
|
||||
"payload_template": {
|
||||
"run_id": "{{workflow_run_id}}",
|
||||
"outcome": "{{gathered_context.call_disposition}}",
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
# Webhooks fire post-call (run_integrations scans nodes by type),
|
||||
# never as a graph step. Reject any edge into or out of a webhook so
|
||||
# the editor can't wire one into the conversation flow.
|
||||
graph_constraints=GraphConstraints(
|
||||
min_incoming=0, max_incoming=0, min_outgoing=0, max_outgoing=0
|
||||
),
|
||||
)
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import TYPE_CHECKING, Awaitable, Callable, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Union
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import (
|
||||
|
|
@ -17,6 +18,7 @@ from pipecat.services.settings import LLMSettings
|
|||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import ToolCategory
|
||||
from api.services.pipecat.audio_playback import play_audio
|
||||
from api.services.workflow.disposition_mapper import apply_disposition_mapping
|
||||
from api.services.workflow.workflow_graph import Node, WorkflowGraph
|
||||
|
|
@ -35,6 +37,7 @@ import asyncio
|
|||
from loguru import logger
|
||||
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
compose_functions_for_node,
|
||||
compose_system_prompt_for_node,
|
||||
|
|
@ -117,6 +120,9 @@ class PipecatEngine:
|
|||
# Cached organization ID (resolved lazily from workflow run)
|
||||
self._organization_id: Optional[int] = None
|
||||
|
||||
# Open MCP tool sessions for this call, keyed by tool_uuid
|
||||
self._mcp_sessions: Dict[str, McpToolSession] = {}
|
||||
|
||||
# Embeddings configuration (passed from run_pipeline.py)
|
||||
self._embeddings_api_key: Optional[str] = embeddings_api_key
|
||||
self._embeddings_model: Optional[str] = embeddings_model
|
||||
|
|
@ -179,6 +185,9 @@ class PipecatEngine:
|
|||
# Helper that encapsulates custom tool management
|
||||
self._custom_tool_manager = CustomToolManager(self)
|
||||
|
||||
# Open persistent MCP server sessions for this call (degrades on failure)
|
||||
await self._open_mcp_sessions()
|
||||
|
||||
# Helper that encapsulates context summarization
|
||||
if self._context_compaction_enabled:
|
||||
self._context_summarization_manager = ContextSummarizationManager(self)
|
||||
|
|
@ -504,7 +513,10 @@ class PipecatEngine:
|
|||
|
||||
# Register custom tool handlers for this node
|
||||
if node.tool_uuids and self._custom_tool_manager:
|
||||
await self._custom_tool_manager.register_handlers(node.tool_uuids)
|
||||
await self._custom_tool_manager.register_handlers(
|
||||
node.tool_uuids,
|
||||
mcp_tool_filters=getattr(node, "mcp_tool_filters", None),
|
||||
)
|
||||
|
||||
# Register knowledge base retrieval handler if node has documents
|
||||
if node.document_uuids:
|
||||
|
|
@ -530,7 +542,7 @@ class PipecatEngine:
|
|||
node = self.workflow.nodes[node_id]
|
||||
|
||||
logger.debug(
|
||||
f"Executing node: name: {node.name} is_static: {node.is_static} allow_interrupt: {node.allow_interrupt} is_end: {node.is_end}"
|
||||
f"Executing node: name: {node.name} allow_interrupt: {node.allow_interrupt} is_end: {node.is_end}"
|
||||
)
|
||||
|
||||
# Track previous node for transition event
|
||||
|
|
@ -585,11 +597,8 @@ class PipecatEngine:
|
|||
)
|
||||
await asyncio.sleep(delay_duration)
|
||||
|
||||
if node.is_static:
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Setup LLM Context with Prompts and Functions
|
||||
await self._setup_llm_context(node)
|
||||
# Setup LLM context with prompts and functions.
|
||||
await self._setup_llm_context(node)
|
||||
|
||||
def get_node_greeting(self, node_id: str) -> Optional[tuple[str, Optional[str]]]:
|
||||
"""Return the greeting info for a node, or None if not configured.
|
||||
|
|
@ -685,19 +694,13 @@ class PipecatEngine:
|
|||
|
||||
async def _handle_end_node(self, node: Node) -> None:
|
||||
"""Handle end node execution."""
|
||||
if node.is_static:
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Setup LLM Context with Prompts and Functions
|
||||
await self._setup_llm_context(node)
|
||||
# Setup LLM context with prompts and functions.
|
||||
await self._setup_llm_context(node)
|
||||
|
||||
async def _handle_agent_node(self, node: Node) -> None:
|
||||
"""Handle agent node execution."""
|
||||
if node.is_static:
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Setup LLM Context with Prompts and Functions
|
||||
await self._setup_llm_context(node)
|
||||
# Setup LLM context with prompts and functions.
|
||||
await self._setup_llm_context(node)
|
||||
|
||||
async def end_call_with_reason(
|
||||
self,
|
||||
|
|
@ -884,6 +887,79 @@ class PipecatEngine:
|
|||
"""Get the gathered context including extracted variables."""
|
||||
return self._gathered_context.copy()
|
||||
|
||||
async def _open_mcp_sessions(self) -> None:
|
||||
"""Connect every MCP-category tool referenced by any workflow node.
|
||||
Failures degrade (session marked unavailable); never raises."""
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpDefinitionError,
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
||||
try:
|
||||
tool_uuids: set[str] = set()
|
||||
for node in self.workflow.nodes.values():
|
||||
for tu in getattr(node, "tool_uuids", None) or []:
|
||||
tool_uuids.add(tu)
|
||||
if not tool_uuids:
|
||||
return
|
||||
|
||||
organization_id = await self._get_organization_id()
|
||||
if not organization_id:
|
||||
logger.warning("Cannot open MCP sessions: organization_id missing")
|
||||
return
|
||||
|
||||
tools = await db_client.get_tools_by_uuids(
|
||||
list(tool_uuids), organization_id
|
||||
)
|
||||
for tool in tools:
|
||||
if tool.category != ToolCategory.MCP.value:
|
||||
continue
|
||||
try:
|
||||
cfg = validate_mcp_definition(tool.definition)
|
||||
except McpDefinitionError as e:
|
||||
logger.warning(
|
||||
f"Skipping MCP tool '{tool.name}' ({tool.tool_uuid}): "
|
||||
f"invalid definition: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
credential = None
|
||||
if cfg["credential_uuid"]:
|
||||
try:
|
||||
credential = await db_client.get_credential_by_uuid(
|
||||
cfg["credential_uuid"], organization_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MCP tool '{tool.name}': credential fetch failed: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
tools_filter=cfg["tools_filter"],
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
await session.start()
|
||||
self._mcp_sessions[tool.tool_uuid] = session
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to open MCP sessions; call proceeds without MCP tools: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _close_mcp_sessions(self) -> None:
|
||||
for tool_uuid, session in list(self._mcp_sessions.items()):
|
||||
try:
|
||||
await session.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing MCP session {tool_uuid}: {e}")
|
||||
self._mcp_sessions = {}
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up engine resources on disconnect."""
|
||||
# Cancel any pending timeout tasks
|
||||
|
|
@ -893,6 +969,12 @@ class PipecatEngine:
|
|||
):
|
||||
self._user_response_timeout_task.cancel()
|
||||
|
||||
# Cancel any in-flight background summarization
|
||||
if self._context_summarization_manager:
|
||||
await self._context_summarization_manager.cleanup()
|
||||
# Cancel any in-flight background summarization.
|
||||
# MCP sessions are closed in a finally block so they are guaranteed to
|
||||
# run even if the summarization cleanup raises an exception.
|
||||
try:
|
||||
if self._context_summarization_manager:
|
||||
await self._context_summarization_manager.cleanup()
|
||||
finally:
|
||||
# Close any open MCP tool sessions
|
||||
await self._close_mcp_sessions()
|
||||
|
|
|
|||
|
|
@ -117,7 +117,8 @@ async def compose_functions_for_node(
|
|||
# Custom tools
|
||||
if node.tool_uuids and custom_tool_manager:
|
||||
custom_tool_schemas = await custom_tool_manager.get_tool_schemas(
|
||||
node.tool_uuids
|
||||
node.tool_uuids,
|
||||
mcp_tool_filters=getattr(node, "mcp_tool_filters", None),
|
||||
)
|
||||
functions.extend(custom_tool_schemas)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ from api.services.workflow.tools.custom_tool import (
|
|||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
|
|
@ -121,11 +122,18 @@ class CustomToolManager:
|
|||
"""Get the organization ID from the engine (shared cache)."""
|
||||
return await self._engine._get_organization_id()
|
||||
|
||||
async def get_tool_schemas(self, tool_uuids: list[str]) -> list[FunctionSchema]:
|
||||
async def get_tool_schemas(
|
||||
self,
|
||||
tool_uuids: list[str],
|
||||
mcp_tool_filters: Optional[dict[str, list[str]]] = None,
|
||||
) -> list[FunctionSchema]:
|
||||
"""Fetch custom tools and convert them to function schemas.
|
||||
|
||||
Args:
|
||||
tool_uuids: List of tool UUIDs to fetch
|
||||
mcp_tool_filters: Optional per-node filter mapping tool_uuid → list of
|
||||
raw MCP tool names to expose. None (default) exposes all tools.
|
||||
Empty dict or entry with [] suppresses all tools for that uuid.
|
||||
|
||||
Returns:
|
||||
List of FunctionSchema objects for LLM
|
||||
|
|
@ -154,6 +162,22 @@ class CustomToolManager:
|
|||
)
|
||||
continue
|
||||
|
||||
if tool.category == ToolCategory.MCP.value:
|
||||
session = self._engine._mcp_sessions.get(tool.tool_uuid)
|
||||
if session is None or not session.available:
|
||||
logger.warning(
|
||||
f"MCP tool '{tool.name}' ({tool.tool_uuid}) "
|
||||
f"unavailable; skipping"
|
||||
)
|
||||
continue
|
||||
allowed = (
|
||||
None
|
||||
if mcp_tool_filters is None
|
||||
else set(mcp_tool_filters.get(tool.tool_uuid, []))
|
||||
)
|
||||
schemas.extend(session.function_schemas(allowed))
|
||||
continue
|
||||
|
||||
raw_schema = tool_to_function_schema(tool)
|
||||
function_name = raw_schema["function"]["name"]
|
||||
|
||||
|
|
@ -178,11 +202,18 @@ class CustomToolManager:
|
|||
logger.error(f"Failed to fetch custom tools: {e}")
|
||||
return []
|
||||
|
||||
async def register_handlers(self, tool_uuids: list[str]) -> None:
|
||||
async def register_handlers(
|
||||
self,
|
||||
tool_uuids: list[str],
|
||||
mcp_tool_filters: Optional[dict[str, list[str]]] = None,
|
||||
) -> None:
|
||||
"""Register custom tool execution handlers with the LLM.
|
||||
|
||||
Args:
|
||||
tool_uuids: List of tool UUIDs to register handlers for
|
||||
mcp_tool_filters: Optional per-node filter mapping tool_uuid → list of
|
||||
raw MCP tool names to expose. None (default) exposes all tools.
|
||||
Empty dict or entry with [] suppresses all tools for that uuid.
|
||||
"""
|
||||
organization_id = await self.get_organization_id()
|
||||
if not organization_id:
|
||||
|
|
@ -203,6 +234,32 @@ class CustomToolManager:
|
|||
)
|
||||
continue
|
||||
|
||||
if tool.category == ToolCategory.MCP.value:
|
||||
session = self._engine._mcp_sessions.get(tool.tool_uuid)
|
||||
if session is None or not session.available:
|
||||
logger.warning(
|
||||
f"MCP tool '{tool.name}' ({tool.tool_uuid}) "
|
||||
f"unavailable; skipping handler registration"
|
||||
)
|
||||
continue
|
||||
allowed = (
|
||||
None
|
||||
if mcp_tool_filters is None
|
||||
else set(mcp_tool_filters.get(tool.tool_uuid, []))
|
||||
)
|
||||
mcp_schemas = session.function_schemas(allowed)
|
||||
for fs in mcp_schemas:
|
||||
self._engine.llm.register_function(
|
||||
fs.name,
|
||||
self._create_mcp_handler(session, fs.name),
|
||||
timeout_secs=session.call_timeout_secs,
|
||||
)
|
||||
logger.debug(
|
||||
f"Registered {len(mcp_schemas)} MCP "
|
||||
f"handlers for tool '{tool.name}' ({tool.tool_uuid})"
|
||||
)
|
||||
continue
|
||||
|
||||
schema = tool_to_function_schema(tool)
|
||||
function_name = schema["function"]["name"]
|
||||
|
||||
|
|
@ -335,6 +392,29 @@ class CustomToolManager:
|
|||
|
||||
return http_tool_handler
|
||||
|
||||
def _create_mcp_handler(self, session: "McpToolSession", function_name: str):
|
||||
"""Create a handler that proxies an LLM function call to a live MCP
|
||||
session. Errors are returned to the LLM as structured text so the
|
||||
agent can recover verbally; the call is never crashed."""
|
||||
|
||||
async def mcp_tool_handler(
|
||||
function_call_params: FunctionCallParams,
|
||||
) -> None:
|
||||
logger.info(f"MCP Tool EXECUTED: {function_name}")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
try:
|
||||
result = await session.call(
|
||||
function_name, function_call_params.arguments or {}
|
||||
)
|
||||
await function_call_params.result_callback(result)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool '{function_name}' failed: {e}")
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": str(e)}
|
||||
)
|
||||
|
||||
return mcp_tool_handler
|
||||
|
||||
def _create_end_call_handler(self, tool: Any, function_name: str):
|
||||
"""Create a handler function for an end call tool.
|
||||
|
||||
|
|
|
|||
116
api/services/workflow/tools/mcp_tool.py
Normal file
116
api/services/workflow/tools/mcp_tool.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
"""Pure helpers for MCP-category tools: definition validation and
|
||||
LLM-function-name namespacing. No I/O, no MCP protocol here."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
|
||||
DEFAULT_TIMEOUT_SECS = 30
|
||||
DEFAULT_SSE_READ_TIMEOUT_SECS = 300
|
||||
|
||||
|
||||
class McpDefinitionError(ValueError):
|
||||
"""Raised when an MCP tool definition is structurally invalid."""
|
||||
|
||||
|
||||
class McpToolConfig(BaseModel):
|
||||
"""Configuration for an MCP tool definition."""
|
||||
|
||||
transport: Literal["streamable_http"] = Field(
|
||||
default="streamable_http", description="MCP transport protocol"
|
||||
)
|
||||
url: str = Field(description="MCP server URL (must be http:// or https://)")
|
||||
credential_uuid: Optional[str] = Field(
|
||||
default=None, description="Reference to ExternalCredentialModel for auth"
|
||||
)
|
||||
tools_filter: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Allowlist of MCP tool names to expose (empty = all tools)",
|
||||
)
|
||||
timeout_secs: int = Field(
|
||||
default=DEFAULT_TIMEOUT_SECS, description="Connection timeout in seconds"
|
||||
)
|
||||
sse_read_timeout_secs: int = Field(
|
||||
default=DEFAULT_SSE_READ_TIMEOUT_SECS,
|
||||
description="SSE read timeout in seconds",
|
||||
)
|
||||
discovered_tools: list[dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Server-managed cache of the MCP server's tool catalog "
|
||||
"[{name, description}]. Populated best-effort by the backend."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def validate_url(cls, v: str) -> str:
|
||||
if not isinstance(v, str) or not v.startswith(("http://", "https://")):
|
||||
raise ValueError("config.url must be an http(s) URL")
|
||||
return v
|
||||
|
||||
@field_validator("tools_filter")
|
||||
@classmethod
|
||||
def validate_tools_filter(cls, v: list[str]) -> list[str]:
|
||||
if not all(isinstance(tool_name, str) for tool_name in v):
|
||||
raise ValueError("config.tools_filter must be a list of strings")
|
||||
return v
|
||||
|
||||
|
||||
class McpToolDefinition(BaseModel):
|
||||
"""Persisted MCP tool definition."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version")
|
||||
type: Literal["mcp"] = Field(description="Tool type")
|
||||
config: McpToolConfig = Field(description="MCP server configuration")
|
||||
|
||||
|
||||
def _format_validation_error(error: ValidationError) -> str:
|
||||
parts: list[str] = []
|
||||
for item in error.errors():
|
||||
location = ".".join(str(part) for part in item["loc"])
|
||||
parts.append(f"{location}: {item['msg']}")
|
||||
return "; ".join(parts)
|
||||
|
||||
|
||||
def validate_mcp_definition(definition: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate a ``type: "mcp"`` ToolModel definition and return a
|
||||
normalized config dict with defaults applied.
|
||||
|
||||
Raises:
|
||||
McpDefinitionError: if the definition is missing required fields
|
||||
or uses an unsupported transport.
|
||||
"""
|
||||
if not isinstance(definition, dict) or definition.get("type") != "mcp":
|
||||
raise McpDefinitionError("definition.type must be 'mcp'")
|
||||
|
||||
config = definition.get("config")
|
||||
if not isinstance(config, dict):
|
||||
raise McpDefinitionError("definition.config is required and must be an object")
|
||||
|
||||
try:
|
||||
parsed = McpToolDefinition.model_validate(definition)
|
||||
except ValidationError as e:
|
||||
raise McpDefinitionError(_format_validation_error(e)) from e
|
||||
|
||||
return parsed.config.model_dump(exclude={"discovered_tools"})
|
||||
|
||||
|
||||
def _slugify(value: str) -> str:
|
||||
slug = re.sub(r"[^a-z0-9]+", "_", value.strip().lower()).strip("_")
|
||||
return slug
|
||||
|
||||
|
||||
def namespace_function_name(
|
||||
tool_name: str, mcp_tool_name: str, *, fallback: str = "server"
|
||||
) -> str:
|
||||
"""Build a collision-safe LLM function name: ``mcp__<slug>__<tool>``.
|
||||
|
||||
``slug`` is derived from the Dograh ToolModel name; if it slugifies to
|
||||
empty, ``fallback`` (e.g. first 8 chars of tool_uuid) is used instead.
|
||||
"""
|
||||
slug = _slugify(tool_name) or _slugify(fallback) or "server"
|
||||
return f"mcp__{slug}__{mcp_tool_name}"
|
||||
|
|
@ -4,7 +4,8 @@ from typing import Dict, List, Set
|
|||
|
||||
from api.services.workflow.dto import EdgeDataDTO, NodeType, ReactFlowDTO
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
from api.services.workflow.node_specs import REGISTRY
|
||||
from api.services.workflow.node_data import BaseNodeData
|
||||
from api.services.workflow.node_specs import get_spec
|
||||
|
||||
# Regex for matching {{ variable }} template placeholders.
|
||||
# Captures: group(1) = variable path, group(2) = filter name, group(3) = filter value.
|
||||
|
|
@ -62,7 +63,7 @@ class Edge:
|
|||
|
||||
|
||||
class Node:
|
||||
def __init__(self, id: str, node_type: NodeType, data):
|
||||
def __init__(self, id: str, node_type: str, data: BaseNodeData):
|
||||
self.id, self.node_type, self.data = id, node_type, data
|
||||
self.out: Dict[str, "Node"] = {} # forward nodes
|
||||
self.out_edges: List[Edge] = [] # forward edges with properties
|
||||
|
|
@ -75,7 +76,6 @@ class Node:
|
|||
# Type-specific fields — read with getattr so this works for every
|
||||
# node variant in the discriminated union.
|
||||
self.prompt = getattr(data, "prompt", None)
|
||||
self.is_static = getattr(data, "is_static", False)
|
||||
self.allow_interrupt = getattr(data, "allow_interrupt", False)
|
||||
self.extraction_enabled = getattr(data, "extraction_enabled", False)
|
||||
self.extraction_prompt = getattr(data, "extraction_prompt", None)
|
||||
|
|
@ -84,11 +84,11 @@ class Node:
|
|||
self.greeting = getattr(data, "greeting", None)
|
||||
self.greeting_type = getattr(data, "greeting_type", None)
|
||||
self.greeting_recording_id = getattr(data, "greeting_recording_id", None)
|
||||
self.detect_voicemail = getattr(data, "detect_voicemail", False)
|
||||
self.delayed_start = getattr(data, "delayed_start", False)
|
||||
self.delayed_start_duration = getattr(data, "delayed_start_duration", None)
|
||||
self.tool_uuids = getattr(data, "tool_uuids", None)
|
||||
self.document_uuids = getattr(data, "document_uuids", None)
|
||||
self.mcp_tool_filters = getattr(data, "mcp_tool_filters", None)
|
||||
self.pre_call_fetch_enabled = getattr(data, "pre_call_fetch_enabled", False)
|
||||
self.pre_call_fetch_url = getattr(data, "pre_call_fetch_url", None)
|
||||
self.pre_call_fetch_credential_uuid = getattr(
|
||||
|
|
@ -105,11 +105,11 @@ class WorkflowGraph:
|
|||
"""
|
||||
|
||||
def __init__(self, dto: ReactFlowDTO):
|
||||
# build adjacency list. n.type comes off the discriminated-union
|
||||
# variant as a literal string; coerce to NodeType for downstream
|
||||
# comparisons.
|
||||
# Build adjacency list from validated DTO nodes. Core node comparisons
|
||||
# still use NodeType string enums; integration nodes remain plain
|
||||
# strings and resolve constraints through node specs.
|
||||
self.nodes: Dict[str, Node] = {
|
||||
n.id: Node(n.id, NodeType(n.type), n.data) for n in dto.nodes
|
||||
n.id: Node(n.id, n.type, n.data) for n in dto.nodes
|
||||
}
|
||||
|
||||
# Store all edges
|
||||
|
|
@ -139,7 +139,7 @@ class WorkflowGraph:
|
|||
# Get a reference to the global node
|
||||
try:
|
||||
self.global_node_id = [
|
||||
n.id for n in dto.nodes if n.type == NodeType.globalNode
|
||||
n.id for n in dto.nodes if n.type == NodeType.globalNode.value
|
||||
][0]
|
||||
except IndexError:
|
||||
self.global_node_id = None
|
||||
|
|
@ -249,7 +249,7 @@ class WorkflowGraph:
|
|||
def _assert_global_node(self):
|
||||
errors: list[WorkflowError] = []
|
||||
global_node = [
|
||||
n for n in self.nodes.values() if n.node_type == NodeType.globalNode
|
||||
n for n in self.nodes.values() if n.node_type == NodeType.globalNode.value
|
||||
]
|
||||
if not len(global_node) <= 1:
|
||||
errors.append(
|
||||
|
|
@ -281,7 +281,7 @@ class WorkflowGraph:
|
|||
in_deg[m.id] += 1
|
||||
|
||||
for n in self.nodes.values():
|
||||
spec = REGISTRY.get(n.node_type.value)
|
||||
spec = get_spec(n.node_type)
|
||||
if spec is None or spec.graph_constraints is None:
|
||||
continue
|
||||
gc = spec.graph_constraints
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ async def sync_campaign_source(ctx: Dict, campaign_id: int) -> None:
|
|||
Phase 1: Syncs data from configured source to queued_runs table
|
||||
- Campaign state should already be 'syncing'
|
||||
- Determines source type from campaign configuration
|
||||
- Fetches data via appropriate sync service (Google Sheets, HubSpot, etc.)
|
||||
- Fetches data via the appropriate sync service
|
||||
- Creates queued_run entries with unique source_uuid
|
||||
- Updates campaign total_rows
|
||||
- Transitions campaign state to 'running' on success
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
"""Execute integrations (QA analysis, webhooks) after workflow run completion."""
|
||||
|
||||
import random
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
|
@ -14,6 +13,11 @@ from api.constants import BACKEND_API_ENDPOINT
|
|||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.integrations import (
|
||||
IntegrationCompletionContext,
|
||||
has_completion_handlers,
|
||||
run_completion_handlers,
|
||||
)
|
||||
from api.services.pipecat.tracing_config import register_org_langfuse_credentials
|
||||
from api.services.workflow.dto import (
|
||||
QANodeData,
|
||||
|
|
@ -214,16 +218,20 @@ async def run_integrations_post_workflow_run(_ctx, workflow_run_id: int):
|
|||
nodes = workflow_definition.get("nodes", [])
|
||||
qa_nodes = [n for n in nodes if n.get("type") == "qa"]
|
||||
webhook_nodes = [n for n in nodes if n.get("type") == "webhook"]
|
||||
has_registered_integrations = has_completion_handlers(workflow_definition)
|
||||
|
||||
# Step 4: Generate public access token if webhooks exist or campaign_id is set
|
||||
# Step 4: Generate a public access token for any run that needs post-call work.
|
||||
has_campaign = workflow_run.campaign_id is not None
|
||||
if not webhook_nodes and not qa_nodes and not has_campaign:
|
||||
if (
|
||||
not webhook_nodes
|
||||
and not qa_nodes
|
||||
and not has_registered_integrations
|
||||
and not has_campaign
|
||||
):
|
||||
logger.debug("No integration nodes and no campaign, skipping")
|
||||
return
|
||||
|
||||
public_token = None
|
||||
if webhook_nodes or has_campaign:
|
||||
public_token = await db_client.ensure_public_access_token(workflow_run_id)
|
||||
public_token = await db_client.ensure_public_access_token(workflow_run_id)
|
||||
|
||||
# Step 5: Run QA analysis before webhooks
|
||||
if qa_nodes:
|
||||
|
|
@ -263,17 +271,37 @@ async def run_integrations_post_workflow_run(_ctx, workflow_run_id: int):
|
|||
workflow_run_id
|
||||
)
|
||||
|
||||
# Step 6: Execute webhooks
|
||||
# Step 6: Run registered third-party integrations after uploads are complete
|
||||
integration_results = await run_completion_handlers(
|
||||
context=IntegrationCompletionContext(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run=workflow_run,
|
||||
workflow_definition=workflow_definition,
|
||||
definition_id=definition_id,
|
||||
organization_id=organization_id,
|
||||
public_token=public_token,
|
||||
)
|
||||
)
|
||||
|
||||
if integration_results:
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id, annotations=integration_results
|
||||
)
|
||||
workflow_run, _ = await db_client.get_workflow_run_with_context(
|
||||
workflow_run_id
|
||||
)
|
||||
|
||||
# Step 7: Execute webhooks
|
||||
if not webhook_nodes:
|
||||
logger.debug("No webhook nodes in workflow")
|
||||
return
|
||||
|
||||
logger.info(f"Found {len(webhook_nodes)} webhook nodes to execute")
|
||||
|
||||
# Step 7: Build render context (includes annotations from QA)
|
||||
# Step 8: Build render context (includes annotations from QA and integrations)
|
||||
render_context = _build_render_context(workflow_run, public_token)
|
||||
|
||||
# Step 8: Execute each webhook node
|
||||
# Step 9: Execute each webhook node
|
||||
for node in webhook_nodes:
|
||||
node_id = node.get("id", "unknown")
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -15,16 +15,14 @@ import pytest
|
|||
|
||||
from api.services.workflow.dto import (
|
||||
AgentNodeData,
|
||||
AgentRFNode,
|
||||
EdgeDataDTO,
|
||||
EndCallNodeData,
|
||||
EndCallRFNode,
|
||||
ExtractionVariableDTO,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
StartCallNodeData,
|
||||
StartCallRFNode,
|
||||
VariableType,
|
||||
)
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
|
|
@ -270,8 +268,9 @@ def simple_workflow() -> WorkflowGraph:
|
|||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -290,8 +289,9 @@ def simple_workflow() -> WorkflowGraph:
|
|||
],
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
@ -333,8 +333,9 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -353,8 +354,9 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
],
|
||||
),
|
||||
),
|
||||
AgentRFNode(
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type="agentNode",
|
||||
position=Position(x=0, y=200),
|
||||
data=AgentNodeData(
|
||||
name="Collect Info",
|
||||
|
|
@ -372,8 +374,9 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
],
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=400),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
@ -424,8 +427,9 @@ def three_node_workflow_extraction_start_only() -> WorkflowGraph:
|
|||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -444,8 +448,9 @@ def three_node_workflow_extraction_start_only() -> WorkflowGraph:
|
|||
],
|
||||
),
|
||||
),
|
||||
AgentRFNode(
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type="agentNode",
|
||||
position=Position(x=0, y=200),
|
||||
data=AgentNodeData(
|
||||
name="Collect Info",
|
||||
|
|
@ -455,8 +460,9 @@ def three_node_workflow_extraction_start_only() -> WorkflowGraph:
|
|||
extraction_enabled=False, # Explicitly disabled for testing
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=400),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
@ -503,8 +509,9 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
|||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -515,8 +522,9 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
AgentRFNode(
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type="agentNode",
|
||||
position=Position(x=0, y=200),
|
||||
data=AgentNodeData(
|
||||
name="Collect Info",
|
||||
|
|
@ -526,8 +534,9 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
|||
extraction_enabled=False, # Explicitly disabled for testing
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=400),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
|
|||
|
|
@ -63,7 +63,6 @@
|
|||
},
|
||||
"data": {
|
||||
"prompt": "Hello, I am Abhishek from Dograh. ",
|
||||
"is_static": true,
|
||||
"name": "Start Call",
|
||||
"is_start": true
|
||||
},
|
||||
|
|
@ -83,7 +82,6 @@
|
|||
},
|
||||
"data": {
|
||||
"prompt": "Thank you for calling Dograh. Have a great day!",
|
||||
"is_static": true,
|
||||
"name": "End Call"
|
||||
},
|
||||
"measured": {
|
||||
|
|
@ -161,4 +159,4 @@
|
|||
"y": 0,
|
||||
"zoom": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
0
api/tests/support/__init__.py
Normal file
0
api/tests/support/__init__.py
Normal file
103
api/tests/support/mcp_mock_server.py
Normal file
103
api/tests/support/mcp_mock_server.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""A real FastMCP server exposing 2 tools over streamable-HTTP, run in a
|
||||
background uvicorn thread on an ephemeral port. Used to exercise the real
|
||||
MCP protocol path in tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
import threading
|
||||
from typing import AsyncIterator
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastmcp import FastMCP
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
def _build_app(required_headers: dict[str, str] | None = None):
|
||||
mcp = FastMCP("mock-mcp")
|
||||
|
||||
@mcp.tool()
|
||||
def echo(text: str) -> str:
|
||||
"""Echo the provided text back."""
|
||||
return f"echo:{text}"
|
||||
|
||||
@mcp.tool()
|
||||
def add(a: int, b: int) -> int:
|
||||
"""Add two integers."""
|
||||
return a + b
|
||||
|
||||
# FastMCP 3.x: ASGI app for streamable-HTTP transport at "/mcp".
|
||||
app = mcp.http_app()
|
||||
if not required_headers:
|
||||
return app
|
||||
|
||||
normalized = {k.lower(): v for k, v in required_headers.items()}
|
||||
|
||||
async def guarded_app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
headers = {
|
||||
key.decode("latin-1").lower(): value.decode("latin-1")
|
||||
for key, value in scope.get("headers", [])
|
||||
}
|
||||
for header_name, expected_value in normalized.items():
|
||||
if headers.get(header_name) != expected_value:
|
||||
response = JSONResponse(
|
||||
{"detail": f"Missing or invalid header: {header_name}"},
|
||||
status_code=401,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
await app(scope, receive, send)
|
||||
|
||||
return guarded_app
|
||||
|
||||
|
||||
def _free_port() -> int:
|
||||
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def running_mcp_server(
|
||||
*, required_headers: dict[str, str] | None = None
|
||||
) -> AsyncIterator[str]:
|
||||
"""Yield the base streamable-HTTP URL of a live mock MCP server."""
|
||||
port = _free_port()
|
||||
config = uvicorn.Config(
|
||||
_build_app(required_headers), host="127.0.0.1", port=port, log_level="warning"
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
|
||||
base_url = f"http://127.0.0.1:{port}/mcp"
|
||||
server_ready = False
|
||||
for _ in range(50):
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
await client.get(base_url, timeout=0.5)
|
||||
server_ready = True
|
||||
break
|
||||
except Exception:
|
||||
await asyncio.sleep(0.1)
|
||||
if not server_ready:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
raise RuntimeError(f"Mock MCP server at {base_url} failed to start within 5s")
|
||||
try:
|
||||
yield base_url
|
||||
finally:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
if thread.is_alive():
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"Mock MCP server thread did not terminate within 5s",
|
||||
ResourceWarning,
|
||||
)
|
||||
|
|
@ -153,45 +153,16 @@ async def test_verify_inbound_signature_rejects_missing_config_public_key():
|
|||
_, headers = _signed_headers(body)
|
||||
provider = _provider()
|
||||
|
||||
# REMOVE-AFTER 2026-05-15: drop the patch wrapper once
|
||||
# TELNYX_WEBHOOK_VERIFICATION_OPTIONAL is removed; the bare call below
|
||||
# will then assert the only path.
|
||||
with patch(
|
||||
"api.services.telephony.providers.telnyx.provider.TELNYX_WEBHOOK_VERIFICATION_OPTIONAL",
|
||||
False,
|
||||
):
|
||||
result = await provider.verify_inbound_signature(
|
||||
"https://example.test/api/v1/telephony/inbound/run",
|
||||
json.loads(body),
|
||||
headers,
|
||||
body,
|
||||
)
|
||||
result = await provider.verify_inbound_signature(
|
||||
"https://example.test/api/v1/telephony/inbound/run",
|
||||
json.loads(body),
|
||||
headers,
|
||||
body,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
# REMOVE-AFTER 2026-05-15: delete this whole test along with the
|
||||
# TELNYX_WEBHOOK_VERIFICATION_OPTIONAL flag.
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_inbound_signature_allows_missing_key_when_optional_flag_set():
|
||||
body = _body()
|
||||
_, headers = _signed_headers(body)
|
||||
provider = _provider()
|
||||
|
||||
with patch(
|
||||
"api.services.telephony.providers.telnyx.provider.TELNYX_WEBHOOK_VERIFICATION_OPTIONAL",
|
||||
True,
|
||||
):
|
||||
result = await provider.verify_inbound_signature(
|
||||
"https://example.test/api/v1/telephony/inbound/run",
|
||||
json.loads(body),
|
||||
headers,
|
||||
body,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_inbound_signature_reads_headers_case_insensitively():
|
||||
body = _body()
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from dograh_sdk.typed import (
|
|||
Qa,
|
||||
StartCall,
|
||||
Trigger,
|
||||
Tuner,
|
||||
TypedNode,
|
||||
Webhook,
|
||||
)
|
||||
|
|
@ -50,6 +51,7 @@ def client() -> _StubClient:
|
|||
(Trigger, "trigger"),
|
||||
(Webhook, "webhook"),
|
||||
(Qa, "qa"),
|
||||
(Tuner, "tuner"),
|
||||
],
|
||||
ids=lambda v: v.__name__ if isinstance(v, type) else v,
|
||||
)
|
||||
|
|
@ -68,8 +70,15 @@ def test_typed_class_declares_spec_name(cls: type[TypedNode], expected_type: str
|
|||
inst = cls(name="t")
|
||||
elif cls is Webhook:
|
||||
inst = cls(name="wh")
|
||||
else: # Qa
|
||||
elif cls is Qa:
|
||||
inst = cls(name="qa")
|
||||
else: # Tuner
|
||||
inst = cls(
|
||||
name="tuner",
|
||||
tuner_agent_id="agent",
|
||||
tuner_workspace_id=1,
|
||||
tuner_api_key="secret",
|
||||
)
|
||||
assert inst.type == expected_type
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,37 @@ async def test_dto():
|
|||
assert dto is not None
|
||||
|
||||
|
||||
def test_dto_ignores_legacy_unknown_node_data_fields():
|
||||
dto = ReactFlowDTO.model_validate(
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "n1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"name": "Start",
|
||||
"prompt": "Hello",
|
||||
"is_static": True,
|
||||
"detect_voicemail": True,
|
||||
"wait_for_user_response": False,
|
||||
"wait_for_user_response_timeout": 2.5,
|
||||
"legacy_field": "ignored",
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
)
|
||||
|
||||
data = dto.nodes[0].data.model_dump()
|
||||
assert "is_static" not in data
|
||||
assert "detect_voicemail" not in data
|
||||
assert "wait_for_user_response" not in data
|
||||
assert "wait_for_user_response_timeout" not in data
|
||||
assert "legacy_field" not in data
|
||||
|
||||
|
||||
def test_sanitize_strips_ui_runtime_fields():
|
||||
definition = {
|
||||
"viewport": {"x": 0, "y": 0, "zoom": 1},
|
||||
|
|
|
|||
|
|
@ -1,4 +1,11 @@
|
|||
from api.routes.webrtc_signaling import is_private_ip_candidate
|
||||
from api.enums import Environment
|
||||
from api.routes.webrtc_signaling import (
|
||||
NonRelayFilterPolicy,
|
||||
_keep_candidate,
|
||||
is_local_or_cgnat_ip,
|
||||
is_private_ip_candidate,
|
||||
resolve_ice_filter_policies,
|
||||
)
|
||||
|
||||
|
||||
class TestIsPrivateIpCandidate:
|
||||
|
|
@ -142,3 +149,78 @@ class TestIsPrivateIpCandidate:
|
|||
"candidate:999 1 tcp 1518280447 192.168.1.100 9 typ host tcptype active"
|
||||
)
|
||||
assert is_private_ip_candidate(candidate) is True
|
||||
|
||||
|
||||
class TestIsLocalOrCgnatIp:
|
||||
def test_loopback_is_local(self):
|
||||
assert is_local_or_cgnat_ip("127.0.0.1") is True
|
||||
|
||||
def test_link_local_is_local(self):
|
||||
assert is_local_or_cgnat_ip("169.254.1.1") is True
|
||||
|
||||
def test_cgnat_is_local(self):
|
||||
assert is_local_or_cgnat_ip("100.64.0.1") is True
|
||||
|
||||
def test_public_ipv4_is_not_local(self):
|
||||
assert is_local_or_cgnat_ip("8.8.8.8") is False
|
||||
|
||||
|
||||
class TestKeepCandidate:
|
||||
def test_private_relay_candidate_survives_private_policy(self):
|
||||
candidate = (
|
||||
"candidate:111 1 udp 41885439 192.168.1.50 50000 typ relay raddr 0.0.0.0 rport 0"
|
||||
)
|
||||
assert _keep_candidate(candidate, NonRelayFilterPolicy.PRIVATE) is True
|
||||
|
||||
def test_private_host_candidate_drops_under_private_policy(self):
|
||||
candidate = (
|
||||
"candidate:123 1 udp 2122260223 192.168.50.24 63603 typ host generation 0"
|
||||
)
|
||||
assert _keep_candidate(candidate, NonRelayFilterPolicy.PRIVATE) is False
|
||||
|
||||
|
||||
class TestResolveIceFilterPolicies:
|
||||
def test_local_deployment_keeps_all_candidates(self):
|
||||
outbound, inbound = resolve_ice_filter_policies(
|
||||
Environment.LOCAL.value,
|
||||
False,
|
||||
"",
|
||||
)
|
||||
assert outbound == NonRelayFilterPolicy.NONE
|
||||
assert inbound == NonRelayFilterPolicy.NONE
|
||||
|
||||
def test_private_lan_remote_keeps_all_candidates(self):
|
||||
outbound, inbound = resolve_ice_filter_policies(
|
||||
Environment.PRODUCTION.value,
|
||||
False,
|
||||
"192.168.50.24",
|
||||
)
|
||||
assert outbound == NonRelayFilterPolicy.NONE
|
||||
assert inbound == NonRelayFilterPolicy.NONE
|
||||
|
||||
def test_public_remote_filters_private_candidates(self):
|
||||
outbound, inbound = resolve_ice_filter_policies(
|
||||
Environment.PRODUCTION.value,
|
||||
False,
|
||||
"8.8.8.8",
|
||||
)
|
||||
assert outbound == NonRelayFilterPolicy.PRIVATE
|
||||
assert inbound == NonRelayFilterPolicy.PRIVATE
|
||||
|
||||
def test_force_turn_relay_stays_relay_only_on_private_lan(self):
|
||||
outbound, inbound = resolve_ice_filter_policies(
|
||||
Environment.PRODUCTION.value,
|
||||
True,
|
||||
"192.168.50.24",
|
||||
)
|
||||
assert outbound == NonRelayFilterPolicy.ALL
|
||||
assert inbound == NonRelayFilterPolicy.NONE
|
||||
|
||||
def test_force_turn_relay_keeps_public_remote_private_filter(self):
|
||||
outbound, inbound = resolve_ice_filter_policies(
|
||||
Environment.PRODUCTION.value,
|
||||
True,
|
||||
"8.8.8.8",
|
||||
)
|
||||
assert outbound == NonRelayFilterPolicy.ALL
|
||||
assert inbound == NonRelayFilterPolicy.PRIVATE
|
||||
|
|
|
|||
63
api/tests/test_mcp_auth.py
Normal file
63
api/tests/test_mcp_auth.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_mcp_request_accepts_bearer_authorization():
|
||||
user = MagicMock()
|
||||
user.id = 1
|
||||
user.selected_organization_id = 90
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.auth.get_http_headers",
|
||||
return_value={"authorization": "Bearer secret-api-key"},
|
||||
) as get_headers,
|
||||
patch(
|
||||
"api.mcp_server.auth._handle_api_key_auth",
|
||||
AsyncMock(return_value=user),
|
||||
) as handle_auth,
|
||||
):
|
||||
authed = await authenticate_mcp_request()
|
||||
|
||||
assert authed is user
|
||||
get_headers.assert_called_once_with(include={"authorization"})
|
||||
handle_auth.assert_awaited_once_with("secret-api-key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_mcp_request_accepts_x_api_key():
|
||||
user = MagicMock()
|
||||
user.id = 2
|
||||
user.selected_organization_id = 91
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.auth.get_http_headers",
|
||||
return_value={"x-api-key": "secret-api-key"},
|
||||
) as get_headers,
|
||||
patch(
|
||||
"api.mcp_server.auth._handle_api_key_auth",
|
||||
AsyncMock(return_value=user),
|
||||
) as handle_auth,
|
||||
):
|
||||
authed = await authenticate_mcp_request()
|
||||
|
||||
assert authed is user
|
||||
get_headers.assert_called_once_with(include={"authorization"})
|
||||
handle_auth.assert_awaited_once_with("secret-api-key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_mcp_request_rejects_missing_api_key():
|
||||
with patch("api.mcp_server.auth.get_http_headers", return_value={}) as get_headers:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await authenticate_mcp_request()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Missing API key" in str(exc_info.value.detail)
|
||||
get_headers.assert_called_once_with(include={"authorization"})
|
||||
181
api/tests/test_mcp_custom_tool_manager.py
Normal file
181
api/tests/test_mcp_custom_tool_manager.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.tests.support.mcp_mock_server import running_mcp_server
|
||||
|
||||
|
||||
def _mcp_tool():
|
||||
t = MagicMock()
|
||||
t.tool_uuid = "uuid-" + uuid.uuid4().hex[:8]
|
||||
t.name = "Acme MCP"
|
||||
t.category = ToolCategory.MCP.value
|
||||
t.definition = {"type": "mcp", "config": {"url": "https://x/mcp"}}
|
||||
return t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tool_schemas_and_handler_for_mcp(monkeypatch):
|
||||
async with running_mcp_server() as base_url:
|
||||
tool = _mcp_tool()
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
|
||||
engine = MagicMock()
|
||||
engine._mcp_sessions = {tool.tool_uuid: session}
|
||||
registered = {}
|
||||
reg_kwargs = {}
|
||||
|
||||
def _reg(name, fn, **kw):
|
||||
registered[name] = fn
|
||||
reg_kwargs[name] = kw
|
||||
|
||||
engine.llm.register_function = _reg
|
||||
|
||||
mgr = CustomToolManager(engine)
|
||||
mgr.get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
|
||||
)
|
||||
|
||||
try:
|
||||
schemas = await mgr.get_tool_schemas([tool.tool_uuid])
|
||||
names = sorted(s.name for s in schemas)
|
||||
assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"]
|
||||
|
||||
await mgr.register_handlers([tool.tool_uuid])
|
||||
assert "mcp__acme_mcp__echo" in registered
|
||||
assert reg_kwargs["mcp__acme_mcp__echo"]["timeout_secs"] == pytest.approx(
|
||||
15.0
|
||||
)
|
||||
|
||||
captured = {}
|
||||
|
||||
class P:
|
||||
function_name = "mcp__acme_mcp__echo"
|
||||
arguments = {"text": "yo"}
|
||||
|
||||
async def result_callback(self, r, *, properties=None):
|
||||
captured["r"] = r
|
||||
|
||||
await registered["mcp__acme_mcp__echo"](P())
|
||||
assert "echo:yo" in str(captured["r"])
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unavailable_mcp_session_contributes_nothing(monkeypatch):
|
||||
tool = _mcp_tool()
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=1,
|
||||
sse_read_timeout_secs=1,
|
||||
)
|
||||
await session.start() # degrades
|
||||
|
||||
engine = MagicMock()
|
||||
engine._mcp_sessions = {tool.tool_uuid: session}
|
||||
mgr = CustomToolManager(engine)
|
||||
mgr.get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]))
|
||||
|
||||
schemas = await mgr.get_tool_schemas([tool.tool_uuid])
|
||||
assert schemas == []
|
||||
await mgr.register_handlers([tool.tool_uuid]) # must not raise
|
||||
|
||||
|
||||
def test_call_timeout_secs_is_read_timeout_plus_buffer():
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-abc123",
|
||||
tool_name="Acme MCP",
|
||||
url="https://x/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
assert session.call_timeout_secs == 25.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_node_mcp_filter_intersection(monkeypatch):
|
||||
async with running_mcp_server() as base_url:
|
||||
tool = _mcp_tool()
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
|
||||
engine = MagicMock()
|
||||
engine._mcp_sessions = {tool.tool_uuid: session}
|
||||
registered = {}
|
||||
engine.llm.register_function = lambda name, fn, **kw: registered.__setitem__(
|
||||
name, fn
|
||||
)
|
||||
|
||||
mgr = CustomToolManager(engine)
|
||||
mgr.get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
|
||||
)
|
||||
try:
|
||||
# Allow only raw "echo" for this node
|
||||
filters = {tool.tool_uuid: ["echo"]}
|
||||
schemas = await mgr.get_tool_schemas(
|
||||
[tool.tool_uuid], mcp_tool_filters=filters
|
||||
)
|
||||
# Check only "echo" schema returned (namespaced name depends on tool.name)
|
||||
assert len(schemas) == 1
|
||||
assert all("echo" in s.name for s in schemas)
|
||||
|
||||
await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters=filters)
|
||||
assert len(registered) == 1
|
||||
assert all("echo" in k for k in registered)
|
||||
|
||||
# No filter entry for this uuid = none (default-none)
|
||||
registered.clear()
|
||||
result = await mgr.get_tool_schemas([tool.tool_uuid], mcp_tool_filters={})
|
||||
assert result == []
|
||||
await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters={})
|
||||
assert registered == {}
|
||||
|
||||
# mcp_tool_filters=None = backward-compatible (all tools)
|
||||
registered.clear()
|
||||
all_schemas = await mgr.get_tool_schemas([tool.tool_uuid])
|
||||
assert len(all_schemas) == 2 # both echo and add
|
||||
await mgr.register_handlers([tool.tool_uuid])
|
||||
assert len(registered) == 2 # both handlers registered
|
||||
finally:
|
||||
await session.close()
|
||||
359
api/tests/test_mcp_docs_search.py
Normal file
359
api/tests/test_mcp_docs_search.py
Normal file
|
|
@ -0,0 +1,359 @@
|
|||
"""Unit tests for the MCP docs discovery tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.mcp_server.tools import docs_search as docs_search_module
|
||||
from api.mcp_server.tools.docs_search import (
|
||||
_docs_index,
|
||||
_extract_page_title,
|
||||
_resolve_docs_root,
|
||||
_score_page,
|
||||
_strip_frontmatter,
|
||||
_tokenize_query,
|
||||
list_docs,
|
||||
read_doc,
|
||||
search_docs,
|
||||
)
|
||||
|
||||
|
||||
def _clear_docs_caches() -> None:
|
||||
docs_search_module._docs_index.cache_clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_docs_root(tmp_path: Path) -> Path:
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
|
||||
(docs_root / "getting-started").mkdir()
|
||||
(docs_root / "getting-started" / "index.mdx").write_text(
|
||||
"---\n"
|
||||
'title: "Getting started"\n'
|
||||
'description: "Start using Dograh."\n'
|
||||
"---\n\n"
|
||||
"# Getting started\n\n"
|
||||
"Welcome to Dograh.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
(docs_root / "voice-agent").mkdir()
|
||||
(docs_root / "voice-agent" / "introduction.mdx").write_text(
|
||||
"---\n"
|
||||
'title: "Voice Agent Builder"\n'
|
||||
'description: "Build conversational workflows."\n'
|
||||
"---\n\n"
|
||||
"# Voice Agent Builder\n\n"
|
||||
"Build workflows with nodes and tools.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
(docs_root / "voice-agent" / "tools").mkdir()
|
||||
(docs_root / "voice-agent" / "tools" / "mcp-tool.mdx").write_text(
|
||||
"---\n"
|
||||
'title: "MCP Tool"\n'
|
||||
'description: "Connect external MCP servers."\n'
|
||||
'llm_hint: "Use for MCP server setup, remote tools, or model context protocol questions."\n'
|
||||
"aliases:\n"
|
||||
' - "model context protocol"\n'
|
||||
"---\n\n"
|
||||
"# MCP Tool\n\n"
|
||||
"Connect an external MCP server to your voice agent.\n\n"
|
||||
"## Authentication\n\n"
|
||||
"Provide the MCP endpoint URL and headers.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
(docs_root / "deployment").mkdir()
|
||||
(docs_root / "deployment" / "docker.mdx").write_text(
|
||||
"---\n"
|
||||
'title: "Docker"\n'
|
||||
'description: "Deploy Dograh with Docker."\n'
|
||||
'llm_hint: "Use for Docker deployment, local setup, remote setup, TURN server, coturn, or WebRTC connectivity questions."\n'
|
||||
"aliases:\n"
|
||||
' - "coturn"\n'
|
||||
' - "turn server"\n'
|
||||
"---\n\n"
|
||||
"# Docker\n\n"
|
||||
"Run Dograh with Docker.\n\n"
|
||||
"## Troubleshooting WebRTC Connectivity\n\n"
|
||||
"If audio fails or ICE fails, configure a TURN server. Coturn is the recommended choice.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# Hidden/orphaned docs page: present on disk but not in docs.json, so it
|
||||
# must not be indexed by the MCP tools.
|
||||
(docs_root / "internal-only.mdx").write_text(
|
||||
"---\n"
|
||||
'title: "Internal TURN Notes"\n'
|
||||
"---\n\n"
|
||||
"# Internal TURN Notes\n\n"
|
||||
"This page mentions zyxinternalturntoken but is not user-facing.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
(docs_root / "AGENTS.md").write_text("# Internal instructions\n", encoding="utf-8")
|
||||
|
||||
(docs_root / "docs.json").write_text(
|
||||
"""{
|
||||
"navigation": {
|
||||
"tabs": [
|
||||
{
|
||||
"tab": "Guides",
|
||||
"groups": [
|
||||
{
|
||||
"group": "Getting started",
|
||||
"pages": [
|
||||
"getting-started/index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Voice Agent Builder",
|
||||
"pages": [
|
||||
"voice-agent/introduction",
|
||||
{
|
||||
"group": "Tools",
|
||||
"pages": [
|
||||
"voice-agent/tools/mcp-tool"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"tab": "Developer",
|
||||
"groups": [
|
||||
{
|
||||
"group": "Deployment",
|
||||
"pages": [
|
||||
"deployment/docker"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
_clear_docs_caches()
|
||||
with patch.dict(os.environ, {"DOGRAH_DOCS_PATH": str(docs_root)}):
|
||||
yield docs_root
|
||||
_clear_docs_caches()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authed_user():
|
||||
class _FakeUser:
|
||||
selected_organization_id = 1
|
||||
id = 42
|
||||
|
||||
with patch(
|
||||
"api.mcp_server.tools.docs_search.authenticate_mcp_request",
|
||||
new=AsyncMock(return_value=_FakeUser()),
|
||||
):
|
||||
yield _FakeUser()
|
||||
|
||||
|
||||
def test_tokenize_query_dedupes_and_drops_stopwords():
|
||||
assert _tokenize_query("How do I configure a TURN server TURN?") == [
|
||||
"configure",
|
||||
"turn",
|
||||
"server",
|
||||
]
|
||||
|
||||
|
||||
def test_tokenize_query_empty_input_returns_empty():
|
||||
assert _tokenize_query("") == []
|
||||
assert _tokenize_query("?? // !!") == []
|
||||
|
||||
|
||||
def test_strip_frontmatter_removes_yaml_block():
|
||||
body = '---\ntitle: "X"\n---\n\n# Heading\n'
|
||||
assert _strip_frontmatter(body).startswith("# Heading")
|
||||
|
||||
|
||||
def test_extract_page_title_prefers_frontmatter():
|
||||
body = '---\ntitle: "Front Title"\n---\n\n# Heading Title\n'
|
||||
assert _extract_page_title(body, fallback="x.mdx") == "Front Title"
|
||||
|
||||
|
||||
def test_extract_page_title_falls_back_to_first_heading():
|
||||
body = "# Heading Title\nbody\n"
|
||||
assert _extract_page_title(body, fallback="x.mdx") == "Heading Title"
|
||||
|
||||
|
||||
def test_score_page_uses_llm_hint_and_aliases():
|
||||
page = docs_search_module.DocPage(
|
||||
path="deployment/docker",
|
||||
file_path="deployment/docker.mdx",
|
||||
title="Docker",
|
||||
description="Deploy Dograh with Docker.",
|
||||
llm_hint="Use for TURN server and coturn setup.",
|
||||
aliases=("coturn",),
|
||||
breadcrumb=("Developer", "Deployment"),
|
||||
content="Docker deployment.",
|
||||
sections=(
|
||||
docs_search_module.DocSection(
|
||||
title="Troubleshooting WebRTC Connectivity",
|
||||
slug="troubleshooting-webrtc-connectivity",
|
||||
level=2,
|
||||
content="Configure a TURN server with coturn.",
|
||||
),
|
||||
),
|
||||
order=0,
|
||||
)
|
||||
score, section = _score_page(page, ["coturn"])
|
||||
assert score > 0
|
||||
assert section is not None
|
||||
assert section.slug == "troubleshooting-webrtc-connectivity"
|
||||
|
||||
|
||||
def test_resolve_docs_root_honors_env_override(tmp_path: Path):
|
||||
docs = tmp_path / "custom_docs"
|
||||
docs.mkdir()
|
||||
(docs / "docs.json").write_text("{}", encoding="utf-8")
|
||||
with patch.dict(os.environ, {"DOGRAH_DOCS_PATH": str(docs)}):
|
||||
assert _resolve_docs_root() == docs.resolve()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_docs_ranks_turn_doc_and_uses_route_path(
|
||||
fake_docs_root, authed_user
|
||||
):
|
||||
results = await search_docs("How do I configure coturn for WebRTC?")
|
||||
assert results
|
||||
assert results[0]["path"] == "deployment/docker"
|
||||
assert results[0]["section_slug"] == "troubleshooting-webrtc-connectivity"
|
||||
assert "TURN server" in results[0]["llm_hint"]
|
||||
assert "snippet" not in results[0]
|
||||
assert "score" not in results[0]
|
||||
assert "url" not in results[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_docs_indexes_only_docs_json_pages(fake_docs_root, authed_user):
|
||||
results = await search_docs("zyxinternalturntoken")
|
||||
assert results == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_docs_respects_limit(fake_docs_root, authed_user):
|
||||
results = await search_docs("dograh", limit=1)
|
||||
assert len(results) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_docs_returns_empty_when_no_match(fake_docs_root, authed_user):
|
||||
assert await search_docs("xyzzy unrelated zzz") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_docs_returns_empty_when_no_corpus(
|
||||
tmp_path, authed_user, monkeypatch
|
||||
):
|
||||
nonexistent = tmp_path / "no-docs-here"
|
||||
monkeypatch.setenv("DOGRAH_DOCS_PATH", str(nonexistent))
|
||||
_clear_docs_caches()
|
||||
with patch(
|
||||
"api.mcp_server.tools.docs_search._resolve_docs_root", return_value=None
|
||||
):
|
||||
assert await search_docs("anything") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_docs_rejects_empty_query(fake_docs_root, authed_user):
|
||||
with pytest.raises(ValueError, match="non-empty string"):
|
||||
await search_docs("")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_docs_rejects_query_with_only_stopwords(
|
||||
fake_docs_root, authed_user
|
||||
):
|
||||
with pytest.raises(ValueError, match="non-stopword"):
|
||||
await search_docs("how do I")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_docs_rejects_zero_limit(fake_docs_root, authed_user):
|
||||
with pytest.raises(ValueError, match="at least 1"):
|
||||
await search_docs("Dograh", limit=0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_docs_returns_top_level_sections(fake_docs_root, authed_user):
|
||||
results = await list_docs()
|
||||
assert results[0]["kind"] == "section"
|
||||
assert results[0]["path"] == "guides/getting-started"
|
||||
assert results[1]["path"] == "guides/voice-agent-builder"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_docs_depth_expands_children(fake_docs_root, authed_user):
|
||||
results = await list_docs("guides/voice-agent-builder", depth=2)
|
||||
paths = [item["path"] for item in results]
|
||||
assert "voice-agent/introduction" in paths
|
||||
assert "guides/voice-agent-builder/tools" in paths
|
||||
assert "voice-agent/tools/mcp-tool" in paths
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_docs_rejects_unknown_section(fake_docs_root, authed_user):
|
||||
with pytest.raises(HTTPException, match="Unknown docs section"):
|
||||
await list_docs("nope")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_doc_returns_full_page_and_sections(fake_docs_root, authed_user):
|
||||
result = await read_doc("deployment/docker")
|
||||
assert result["path"] == "deployment/docker"
|
||||
assert result["title"] == "Docker"
|
||||
assert "url" not in result
|
||||
section_slugs = [section["slug"] for section in result["sections"]]
|
||||
assert "docker" in section_slugs
|
||||
assert "troubleshooting-webrtc-connectivity" in section_slugs
|
||||
assert "Coturn" in result["content"] or "coturn" in result["content"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_doc_can_target_section(fake_docs_root, authed_user):
|
||||
result = await read_doc(
|
||||
"deployment/docker",
|
||||
section="troubleshooting-webrtc-connectivity",
|
||||
)
|
||||
assert result["section_slug"] == "troubleshooting-webrtc-connectivity"
|
||||
assert "ICE fails" in result["content"] or "TURN server" in result["content"]
|
||||
assert "Run Dograh with Docker." not in result["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_doc_rejects_unknown_page(fake_docs_root, authed_user):
|
||||
with pytest.raises(HTTPException, match="Unknown docs page"):
|
||||
await read_doc("missing/page")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_doc_rejects_unknown_section(fake_docs_root, authed_user):
|
||||
with pytest.raises(HTTPException, match="Unknown section"):
|
||||
await read_doc("deployment/docker", section="missing-section")
|
||||
|
||||
|
||||
def test_docs_index_uses_docs_json_navigation(fake_docs_root):
|
||||
index = _docs_index()
|
||||
assert "internal-only" not in index.pages_by_path
|
||||
assert "guides/voice-agent-builder/tools" in index.sections_by_path
|
||||
assert index.pages_by_path["voice-agent/tools/mcp-tool"].breadcrumb == (
|
||||
"Guides",
|
||||
"Voice Agent Builder",
|
||||
"Tools",
|
||||
)
|
||||
115
api/tests/test_mcp_instructions_drift.py
Normal file
115
api/tests/test_mcp_instructions_drift.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""Drift guards between the static MCP guide and the live tool surface.
|
||||
|
||||
`api/mcp_server/instructions.py` is free text baked into the client
|
||||
system prompt. It is *not* the authoritative description of the tools —
|
||||
names, signatures, and per-tool error codes reach the model dynamically
|
||||
via `tools/list`, derived from each tool's own function signature and
|
||||
docstring. These tests fail on the two classic drift modes:
|
||||
|
||||
1. The guide references a tool that is no longer registered (renamed or
|
||||
removed) — the model would be told to call something that 404s.
|
||||
2. A tool returns an `error_code` that is absent from the description it
|
||||
ships via `tools/list` — the model can't learn to recover from it.
|
||||
|
||||
Keep the guide about orchestration (call order, hard constraints) and let
|
||||
the tools describe themselves.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from api.mcp_server import instructions as instructions_module
|
||||
from api.mcp_server.server import mcp
|
||||
from api.mcp_server.tools import create_workflow as create_workflow_module
|
||||
from api.mcp_server.tools import save_workflow as save_workflow_module
|
||||
|
||||
# Every registered MCP tool name starts with one of these verbs. A
|
||||
# backticked snake_case token in the guide whose leading word is a verb is
|
||||
# treated as a tool reference; field/reference names like `tool_refs`,
|
||||
# `credential_ref`, or `pre_call_fetch` don't start with a verb and are
|
||||
# correctly ignored. Extend this only when a new tool introduces a new
|
||||
# leading verb (a missing verb under-checks, it never false-fails).
|
||||
_TOOL_VERB_PREFIXES = frozenset(
|
||||
{
|
||||
"search",
|
||||
"read",
|
||||
"list",
|
||||
"get",
|
||||
"save",
|
||||
"create",
|
||||
"update",
|
||||
"delete",
|
||||
"add",
|
||||
"remove",
|
||||
"set",
|
||||
}
|
||||
)
|
||||
|
||||
# A backtick immediately followed by a snake_case identifier (>= 1
|
||||
# underscore). Anchoring on the opening backtick captures the leading
|
||||
# identifier of a code span whether it is bare (`read_doc`) or a call
|
||||
# (`read_doc(path)`), while skipping DSL constructs like `wf.edge` or
|
||||
# `new Workflow` whose first char after the backtick isn't `[a-z_]`.
|
||||
_BACKTICKED_SNAKE_RE = re.compile(r"`([a-z][a-z0-9]*(?:_[a-z0-9]+)+)")
|
||||
|
||||
# Error codes are emitted as the first string arg to `_error_result(...)`.
|
||||
_ERROR_RESULT_LITERAL_RE = re.compile(r'_error_result\(\s*"([a-z_]+)"')
|
||||
# `parse_error` / `validation_error` are picked by a `code_key` ternary
|
||||
# rather than passed as a literal to `_error_result`, so match them too.
|
||||
_CODE_KEY_LITERAL_RE = re.compile(r'"(parse_error|validation_error)"')
|
||||
|
||||
|
||||
def _referenced_tool_names(text: str) -> set[str]:
|
||||
return {
|
||||
token
|
||||
for token in _BACKTICKED_SNAKE_RE.findall(text)
|
||||
if token.split("_", 1)[0] in _TOOL_VERB_PREFIXES
|
||||
}
|
||||
|
||||
|
||||
def _returned_error_codes(module) -> set[str]:
|
||||
source = Path(module.__file__).read_text(encoding="utf-8")
|
||||
return set(_ERROR_RESULT_LITERAL_RE.findall(source)) | set(
|
||||
_CODE_KEY_LITERAL_RE.findall(source)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guide_only_references_registered_tools():
|
||||
registered = {tool.name for tool in await mcp.list_tools()}
|
||||
referenced = _referenced_tool_names(instructions_module.DOGRAH_MCP_INSTRUCTIONS)
|
||||
|
||||
assert referenced, "no tool references extracted — the regex likely broke"
|
||||
unknown = sorted(referenced - registered)
|
||||
assert not unknown, (
|
||||
f"instructions.py references tools that are not registered: {unknown}. "
|
||||
f"Rename/remove the reference or register the tool. "
|
||||
f"Registered tools: {sorted(registered)}."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"tool_name, module",
|
||||
[
|
||||
("save_workflow", save_workflow_module),
|
||||
("create_workflow", create_workflow_module),
|
||||
],
|
||||
)
|
||||
async def test_tool_documents_every_error_code_it_returns(tool_name, module):
|
||||
descriptions = {
|
||||
tool.name: tool.description or "" for tool in await mcp.list_tools()
|
||||
}
|
||||
description = descriptions[tool_name]
|
||||
returned = _returned_error_codes(module)
|
||||
|
||||
assert returned, f"no error codes detected in {tool_name} source — regex broke"
|
||||
undocumented = sorted(code for code in returned if code not in description)
|
||||
assert not undocumented, (
|
||||
f"{tool_name} returns error_code(s) {undocumented} absent from the description "
|
||||
f"shipped via tools/list. Document them in the {tool_name} docstring."
|
||||
)
|
||||
107
api/tests/test_mcp_integration.py
Normal file
107
api/tests/test_mcp_integration.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.tests.support.mcp_mock_server import running_mcp_server
|
||||
|
||||
|
||||
def _mcp_tool(url: str):
|
||||
t = MagicMock()
|
||||
t.tool_uuid = "uuid-" + uuid.uuid4().hex[:8]
|
||||
t.name = "Acme MCP"
|
||||
t.category = ToolCategory.MCP.value
|
||||
t.definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http", "url": url},
|
||||
}
|
||||
return t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_opens_and_closes_mcp_sessions(monkeypatch):
|
||||
async with running_mcp_server() as base_url:
|
||||
tool = _mcp_tool(base_url)
|
||||
|
||||
engine = PipecatEngine.__new__(PipecatEngine)
|
||||
node = MagicMock()
|
||||
node.tool_uuids = [tool.tool_uuid]
|
||||
workflow = MagicMock()
|
||||
workflow.nodes = {"n1": node}
|
||||
engine.workflow = workflow
|
||||
engine._mcp_sessions = {}
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_credential_by_uuid", AsyncMock(return_value=None)
|
||||
)
|
||||
engine._get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
await engine._open_mcp_sessions()
|
||||
try:
|
||||
assert tool.tool_uuid in engine._mcp_sessions
|
||||
sess = engine._mcp_sessions[tool.tool_uuid]
|
||||
assert sess.available is True
|
||||
assert len(sess.function_schemas()) == 2
|
||||
finally:
|
||||
await engine._close_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_mcp_sessions_swallows_db_error(monkeypatch):
|
||||
engine = PipecatEngine.__new__(PipecatEngine)
|
||||
node = MagicMock()
|
||||
node.tool_uuids = ["uuid-deadbeef"]
|
||||
workflow = MagicMock()
|
||||
workflow.nodes = {"n1": node}
|
||||
engine.workflow = workflow
|
||||
engine._mcp_sessions = {}
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client,
|
||||
"get_tools_by_uuids",
|
||||
AsyncMock(side_effect=RuntimeError("db down")),
|
||||
)
|
||||
engine._get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
# Must NOT raise
|
||||
await engine._open_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_mcp_sessions_skips_tool_when_credential_fetch_fails(monkeypatch):
|
||||
tool = _mcp_tool("http://127.0.0.1:1/mcp")
|
||||
tool.definition["config"]["credential_uuid"] = "cred-1234"
|
||||
|
||||
engine = PipecatEngine.__new__(PipecatEngine)
|
||||
node = MagicMock()
|
||||
node.tool_uuids = [tool.tool_uuid]
|
||||
workflow = MagicMock()
|
||||
workflow.nodes = {"n1": node}
|
||||
engine.workflow = workflow
|
||||
engine._mcp_sessions = {}
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]))
|
||||
monkeypatch.setattr(
|
||||
db_client,
|
||||
"get_credential_by_uuid",
|
||||
AsyncMock(side_effect=RuntimeError("cred store down")),
|
||||
)
|
||||
engine._get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
# Must NOT raise, and must skip the tool (no futile unauthenticated start)
|
||||
await engine._open_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
112
api/tests/test_mcp_tool_definition.py
Normal file
112
api/tests/test_mcp_tool_definition.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
import importlib
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.routes.tool import McpToolConfig as RouteMcpToolConfig
|
||||
from api.routes.tool import McpToolDefinition as RouteMcpToolDefinition
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpDefinitionError,
|
||||
McpToolConfig,
|
||||
McpToolDefinition,
|
||||
namespace_function_name,
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_category_exists():
|
||||
assert ToolCategory.MCP.value == "mcp"
|
||||
assert ToolCategory("mcp") is ToolCategory.MCP
|
||||
|
||||
|
||||
def test_mcp_migration_present_and_chained(monkeypatch):
|
||||
mod = importlib.import_module(
|
||||
"api.alembic.versions.0a1b2c3d4e5f_add_mcp_in_toolcategory"
|
||||
)
|
||||
assert mod.revision == "0a1b2c3d4e5f"
|
||||
assert mod.down_revision == "4c1f1e3e8ef2"
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_sync_enum_values(**kwargs):
|
||||
calls.append(kwargs)
|
||||
|
||||
monkeypatch.setattr(mod.op, "sync_enum_values", fake_sync_enum_values)
|
||||
|
||||
mod.upgrade()
|
||||
mod.downgrade()
|
||||
|
||||
assert len(calls) == 2
|
||||
assert calls[0]["enum_name"] == "tool_category"
|
||||
assert "mcp" in calls[0]["new_values"]
|
||||
assert "mcp" not in calls[1]["new_values"]
|
||||
|
||||
|
||||
def test_route_reuses_shared_mcp_models():
|
||||
assert RouteMcpToolConfig is McpToolConfig
|
||||
assert RouteMcpToolDefinition is McpToolDefinition
|
||||
|
||||
|
||||
def test_validate_mcp_definition_ok():
|
||||
cfg = validate_mcp_definition(
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://acme.example.com/mcp",
|
||||
"credential_uuid": "cred-123",
|
||||
"tools_filter": ["lookup_patient"],
|
||||
"timeout_secs": 30,
|
||||
"sse_read_timeout_secs": 300,
|
||||
},
|
||||
}
|
||||
)
|
||||
assert cfg["url"] == "https://acme.example.com/mcp"
|
||||
assert cfg["transport"] == "streamable_http"
|
||||
assert cfg["tools_filter"] == ["lookup_patient"]
|
||||
assert cfg["timeout_secs"] == 30
|
||||
assert cfg["sse_read_timeout_secs"] == 300
|
||||
assert cfg["credential_uuid"] == "cred-123"
|
||||
|
||||
|
||||
def test_validate_mcp_definition_defaults():
|
||||
cfg = validate_mcp_definition({"type": "mcp", "config": {"url": "https://x/mcp"}})
|
||||
assert cfg["transport"] == "streamable_http"
|
||||
assert cfg["tools_filter"] == []
|
||||
assert cfg["timeout_secs"] == 30
|
||||
assert cfg["sse_read_timeout_secs"] == 300
|
||||
assert cfg["credential_uuid"] is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"definition",
|
||||
[
|
||||
{"type": "mcp", "config": {}},
|
||||
{"type": "mcp", "config": {"url": ""}},
|
||||
{"type": "mcp", "config": {"url": "ftp://x"}},
|
||||
{"type": "mcp"},
|
||||
{"type": "mcp", "config": {"url": "https://x", "transport": "stdio"}},
|
||||
],
|
||||
)
|
||||
def test_validate_mcp_definition_rejects(definition):
|
||||
with pytest.raises(McpDefinitionError):
|
||||
validate_mcp_definition(definition)
|
||||
|
||||
|
||||
def test_validate_mcp_definition_zero_timeout_preserved():
|
||||
cfg = validate_mcp_definition(
|
||||
{"type": "mcp", "config": {"url": "https://x/mcp", "timeout_secs": 0}}
|
||||
)
|
||||
assert cfg["timeout_secs"] == 0
|
||||
|
||||
|
||||
def test_namespace_function_name():
|
||||
assert (
|
||||
namespace_function_name("Acme MCP", "lookup_patient")
|
||||
== "mcp__acme_mcp__lookup_patient"
|
||||
)
|
||||
assert (
|
||||
namespace_function_name("", "ping", fallback="abcd1234")
|
||||
== "mcp__abcd1234__ping"
|
||||
)
|
||||
437
api/tests/test_mcp_tool_route.py
Normal file
437
api/tests/test_mcp_tool_route.py
Normal file
|
|
@ -0,0 +1,437 @@
|
|||
"""Route-level tests for the MCP tool definition schema.
|
||||
|
||||
These tests exercise the Pydantic request models (CreateToolRequest /
|
||||
UpdateToolRequest) to catch schema gaps at the route/request-model layer —
|
||||
the layer where the pre-fix defect lived (HTTP 422 on every MCP tool
|
||||
creation attempt).
|
||||
|
||||
Test coverage:
|
||||
- CreateToolRequest validates a valid MCP definition (was 422 before Part A).
|
||||
- UpdateToolRequest validates a valid MCP definition.
|
||||
- Invalid MCP bodies are rejected (ftp:// url, missing url).
|
||||
- Round-trip: validated definition dict passes through validate_mcp_definition
|
||||
unchanged, proving the request schema and call-time validator agree.
|
||||
- Full HTTP round-trip via the ASGI test client (POST /api/v1/tools/).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.routes.tool import CreateToolRequest, McpToolDefinition, UpdateToolRequest
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
||||
# ── Canonical valid MCP request body ─────────────────────────────────────────
|
||||
|
||||
VALID_MCP_DEFINITION = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://x/mcp",
|
||||
"credential_uuid": None,
|
||||
"tools_filter": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ── Part A regression: CreateToolRequest / UpdateToolRequest validation ───────
|
||||
|
||||
|
||||
def test_create_tool_request_accepts_mcp_definition():
|
||||
"""CreateToolRequest must accept an MCP definition (was HTTP 422 before fix)."""
|
||||
req = CreateToolRequest(
|
||||
name="My MCP Tool",
|
||||
description="Integration via MCP",
|
||||
category="mcp",
|
||||
definition=VALID_MCP_DEFINITION,
|
||||
)
|
||||
assert isinstance(req.definition, McpToolDefinition)
|
||||
assert req.definition.type == "mcp"
|
||||
assert req.definition.config.url == "https://x/mcp"
|
||||
assert req.definition.config.transport == "streamable_http"
|
||||
assert req.definition.config.credential_uuid is None
|
||||
assert req.definition.config.tools_filter == []
|
||||
assert req.definition.config.timeout_secs == 30
|
||||
assert req.definition.config.sse_read_timeout_secs == 300
|
||||
|
||||
|
||||
def test_update_tool_request_accepts_mcp_definition():
|
||||
"""UpdateToolRequest must also accept an MCP definition."""
|
||||
req = UpdateToolRequest(
|
||||
name="Updated MCP Tool",
|
||||
definition=VALID_MCP_DEFINITION,
|
||||
)
|
||||
assert isinstance(req.definition, McpToolDefinition)
|
||||
assert req.definition.type == "mcp"
|
||||
assert req.definition.config.url == "https://x/mcp"
|
||||
|
||||
|
||||
def test_create_tool_request_accepts_mcp_with_all_fields():
|
||||
"""All optional MCP config fields are accepted and preserved."""
|
||||
req = CreateToolRequest(
|
||||
name="Full MCP Tool",
|
||||
category="mcp",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://acme.example.com/mcp",
|
||||
"credential_uuid": "cred-abc-123",
|
||||
"tools_filter": ["lookup_patient", "schedule_appointment"],
|
||||
"timeout_secs": 60,
|
||||
"sse_read_timeout_secs": 600,
|
||||
},
|
||||
},
|
||||
)
|
||||
cfg = req.definition.config # type: ignore[union-attr]
|
||||
assert cfg.url == "https://acme.example.com/mcp"
|
||||
assert cfg.credential_uuid == "cred-abc-123"
|
||||
assert cfg.tools_filter == ["lookup_patient", "schedule_appointment"]
|
||||
assert cfg.timeout_secs == 60
|
||||
assert cfg.sse_read_timeout_secs == 600
|
||||
|
||||
|
||||
# ── Invalid bodies are rejected ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"definition",
|
||||
[
|
||||
# ftp:// URL — rejected by McpToolConfig.validate_url
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http", "url": "ftp://x/mcp"},
|
||||
},
|
||||
# Empty url — rejected by McpToolConfig.validate_url
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http", "url": ""},
|
||||
},
|
||||
# Missing url — rejected by McpToolConfig (required field)
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http"},
|
||||
},
|
||||
# Unsupported transport — rejected because Literal["streamable_http"] constraint
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"url": "https://x/mcp", "transport": "stdio"},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_create_tool_request_rejects_invalid_mcp_definition(definition):
|
||||
"""Invalid MCP definitions must raise ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
CreateToolRequest(
|
||||
name="Bad MCP Tool",
|
||||
category="mcp",
|
||||
definition=definition,
|
||||
)
|
||||
|
||||
|
||||
# ── Round-trip compatibility: request schema ↔ validate_mcp_definition ───────
|
||||
|
||||
|
||||
def test_mcp_definition_round_trips_through_validate_mcp_definition():
|
||||
"""The dict produced by CreateToolRequest.definition.model_dump() must be
|
||||
accepted by validate_mcp_definition without raising, and the result must
|
||||
contain the expected fields. This proves the request-layer schema and the
|
||||
call-time validator agree on the stored config shape."""
|
||||
req = CreateToolRequest(
|
||||
name="Round-Trip MCP Tool",
|
||||
category="mcp",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://roundtrip.example.com/mcp",
|
||||
"credential_uuid": "cred-rt-456",
|
||||
"tools_filter": ["ping"],
|
||||
"timeout_secs": 45,
|
||||
"sse_read_timeout_secs": 400,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Simulate what the route does: persist definition as a plain dict
|
||||
persisted = req.definition.model_dump() # type: ignore[union-attr]
|
||||
|
||||
# validate_mcp_definition must accept the persisted shape without raising
|
||||
normalized = validate_mcp_definition(persisted)
|
||||
|
||||
assert normalized["url"] == "https://roundtrip.example.com/mcp"
|
||||
assert normalized["transport"] == "streamable_http"
|
||||
assert normalized["credential_uuid"] == "cred-rt-456"
|
||||
assert normalized["tools_filter"] == ["ping"]
|
||||
assert normalized["timeout_secs"] == 45
|
||||
assert normalized["sse_read_timeout_secs"] == 400
|
||||
|
||||
|
||||
def test_mcp_definition_round_trip_defaults():
|
||||
"""Round-trip with minimal body: defaults fill in correctly and
|
||||
validate_mcp_definition agrees on them."""
|
||||
req = CreateToolRequest(
|
||||
name="Minimal MCP Tool",
|
||||
category="mcp",
|
||||
definition=VALID_MCP_DEFINITION,
|
||||
)
|
||||
|
||||
persisted = req.definition.model_dump() # type: ignore[union-attr]
|
||||
normalized = validate_mcp_definition(persisted)
|
||||
|
||||
assert normalized["transport"] == "streamable_http"
|
||||
assert normalized["tools_filter"] == []
|
||||
assert normalized["timeout_secs"] == 30
|
||||
assert normalized["sse_read_timeout_secs"] == 300
|
||||
assert normalized["credential_uuid"] is None
|
||||
# Part B: auth_header / auth_scheme must NOT be present in the normalized
|
||||
# config dict (they were dead config removed in the fix)
|
||||
assert "auth_header" not in normalized
|
||||
assert "auth_scheme" not in normalized
|
||||
|
||||
|
||||
# ── Full HTTP round-trip via ASGI test client ─────────────────────────────────
|
||||
|
||||
|
||||
async def test_post_tool_mcp_returns_200(test_client_factory, db_session):
|
||||
"""POST /api/v1/tools/ with an MCP definition must return HTTP 200 and
|
||||
persist the definition with type='mcp'. Before Part A this always
|
||||
returned 422."""
|
||||
# Create a user and an organization, then link them so the route's
|
||||
# selected_organization_id check passes.
|
||||
user, _ = await db_session.get_or_create_user_by_provider_id("mcp_route_test_user")
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
"mcp_route_test_org", user.id
|
||||
)
|
||||
await db_session.update_user_selected_organization(user.id, org.id)
|
||||
# Reload the user so selected_organization_id is populated on the object.
|
||||
user = await db_session.get_user_by_id(user.id)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/tools/",
|
||||
json={
|
||||
"name": "HTTP Round-Trip MCP Tool",
|
||||
"description": "Testing the full route",
|
||||
"category": "mcp",
|
||||
"definition": {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://roundtrip.example.com/mcp",
|
||||
"credential_uuid": None,
|
||||
"tools_filter": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200, (
|
||||
f"Expected 200, got {response.status_code}: {response.text}"
|
||||
)
|
||||
body = response.json()
|
||||
assert body["definition"]["type"] == "mcp"
|
||||
assert body["definition"]["config"]["url"] == "https://roundtrip.example.com/mcp"
|
||||
assert body["category"] == "mcp"
|
||||
|
||||
|
||||
async def test_post_tool_mcp_invalid_url_returns_422(test_client_factory, db_session):
|
||||
"""POST /api/v1/tools/ with an ftp:// URL must return HTTP 422."""
|
||||
user, _ = await db_session.get_or_create_user_by_provider_id(
|
||||
"mcp_route_test_user_422"
|
||||
)
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
"mcp_route_test_org_422", user.id
|
||||
)
|
||||
await db_session.update_user_selected_organization(user.id, org.id)
|
||||
user = await db_session.get_user_by_id(user.id)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/tools/",
|
||||
json={
|
||||
"name": "Bad MCP Tool",
|
||||
"category": "mcp",
|
||||
"definition": {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "ftp://invalid.example.com/mcp",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ── Task 6: discovered_tools field and _populate_discovered_tools helper ──────
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from api.routes.tool import McpToolConfig, _populate_discovered_tools
|
||||
|
||||
|
||||
def test_mcp_config_accepts_discovered_tools():
|
||||
cfg = McpToolConfig(
|
||||
url="https://x/mcp",
|
||||
discovered_tools=[{"name": "echo", "description": "Echo"}],
|
||||
)
|
||||
assert cfg.discovered_tools == [{"name": "echo", "description": "Echo"}]
|
||||
# Defaults to [] when omitted
|
||||
assert McpToolConfig(url="https://x/mcp").discovered_tools == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_overwrites_cache(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
|
||||
)
|
||||
definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"url": "https://x/mcp",
|
||||
"tools_filter": [],
|
||||
"discovered_tools": [{"name": "stale", "description": "old"}],
|
||||
},
|
||||
}
|
||||
out = await _populate_discovered_tools(definition, organization_id=1)
|
||||
assert out["config"]["discovered_tools"] == [
|
||||
{"name": "echo", "description": "Echo"}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_non_mcp_is_noop():
|
||||
definition = {"schema_version": 1, "type": "http_api", "config": {}}
|
||||
out = await _populate_discovered_tools(definition, organization_id=1)
|
||||
assert out == definition # untouched
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_server_down_sets_empty(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(side_effect=RuntimeError("connection refused")),
|
||||
)
|
||||
definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"url": "https://x/mcp", "tools_filter": []},
|
||||
}
|
||||
out = await _populate_discovered_tools(definition, organization_id=1)
|
||||
assert out["config"]["discovered_tools"] == []
|
||||
|
||||
|
||||
# ── Task 7: POST /{tool_uuid}/mcp/refresh ─────────────────────────────────────
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.routes.tool import refresh_mcp_tools
|
||||
|
||||
|
||||
def _fake_user(org_id=1):
|
||||
u = MagicMock()
|
||||
u.selected_organization_id = org_id
|
||||
u.id = 1
|
||||
u.provider_id = "p1"
|
||||
return u
|
||||
|
||||
|
||||
def _mcp_tool_model(org_id=1):
|
||||
t = MagicMock()
|
||||
t.tool_uuid = "tu-mcp"
|
||||
t.name = "Mock MCP"
|
||||
t.category = "mcp"
|
||||
t.definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"url": "https://x/mcp", "tools_filter": []},
|
||||
}
|
||||
return t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_success(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client,
|
||||
"update_tool",
|
||||
AsyncMock(return_value=tool),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
|
||||
)
|
||||
resp = await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert resp.discovered_tools == [{"name": "echo", "description": "Echo"}]
|
||||
assert resp.error is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_server_down_returns_200_with_error(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
monkeypatch.setattr(tool_mod.db_client, "update_tool", AsyncMock(return_value=tool))
|
||||
monkeypatch.setattr(tool_mod, "discover_mcp_tools", AsyncMock(return_value=[]))
|
||||
resp = await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert resp.discovered_tools == []
|
||||
assert resp.error # non-empty human-readable message
|
||||
# update_tool should NOT be called when discovery returns empty
|
||||
tool_mod.db_client.update_tool.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_non_mcp_is_400(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
tool.category = "http_api"
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert ei.value.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_not_found_is_404(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=None)
|
||||
)
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
await refresh_mcp_tools("nope", user=_fake_user())
|
||||
assert ei.value.status_code == 404
|
||||
274
api/tests/test_mcp_tool_session.py
Normal file
274
api/tests/test_mcp_tool_session.py
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.mcp_tool_session import (
|
||||
McpToolSession,
|
||||
build_streamable_http_params,
|
||||
discover_mcp_tools,
|
||||
)
|
||||
from api.tests.support.mcp_mock_server import running_mcp_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_server_starts_and_serves():
|
||||
async with running_mcp_server() as base_url:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(base_url, timeout=5.0)
|
||||
assert resp.status_code in (400, 404, 405, 406)
|
||||
|
||||
|
||||
def test_build_streamable_http_params_with_credential():
|
||||
cred = MagicMock()
|
||||
cred.credential_type = "bearer_token"
|
||||
cred.credential_data = {"token": "abc"}
|
||||
params = build_streamable_http_params(
|
||||
url="https://acme.example.com/mcp",
|
||||
credential=cred,
|
||||
timeout_secs=30,
|
||||
sse_read_timeout_secs=300,
|
||||
)
|
||||
assert params.url == "https://acme.example.com/mcp"
|
||||
assert params.headers == {"Authorization": "Bearer abc"}
|
||||
assert params.timeout == timedelta(seconds=30)
|
||||
assert params.sse_read_timeout == timedelta(seconds=300)
|
||||
|
||||
|
||||
def test_build_streamable_http_params_no_credential():
|
||||
params = build_streamable_http_params(
|
||||
url="https://acme.example.com/mcp",
|
||||
credential=None,
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
assert params.headers is None or params.headers == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_start_passes_auth_header_to_real_server():
|
||||
cred = MagicMock()
|
||||
cred.credential_type = "bearer_token"
|
||||
cred.credential_data = {"token": "abc"}
|
||||
|
||||
async with running_mcp_server(
|
||||
required_headers={"Authorization": "Bearer abc"}
|
||||
) as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-auth-ok",
|
||||
tool_name="Secure MCP",
|
||||
url=base_url,
|
||||
credential=cred,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
assert session.available is True
|
||||
names = sorted(s.name for s in session.function_schemas())
|
||||
assert names == ["mcp__secure_mcp__add", "mcp__secure_mcp__echo"]
|
||||
result = await session.call("mcp__secure_mcp__echo", {"text": "hi"})
|
||||
assert "echo:hi" in result
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_auth_failure_degrades_not_raises():
|
||||
async with running_mcp_server(
|
||||
required_headers={"Authorization": "Bearer abc"}
|
||||
) as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-auth-fail",
|
||||
tool_name="Secure MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=2,
|
||||
sse_read_timeout_secs=2,
|
||||
)
|
||||
await session.start() # must degrade instead of raising on 401
|
||||
try:
|
||||
assert session.available is False
|
||||
assert session.function_schemas() == []
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_start_lists_and_calls_real_server():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
assert session.available is True
|
||||
schemas = session.function_schemas()
|
||||
names = sorted(s.name for s in schemas)
|
||||
assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"]
|
||||
result = await session.call("mcp__acme_mcp__echo", {"text": "hi"})
|
||||
assert "echo:hi" in result
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_tools_filter_applied():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=["echo"],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
names = sorted(s.name for s in session.function_schemas())
|
||||
assert names == ["mcp__acme_mcp__echo"]
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_unreachable_degrades_not_raises():
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=2,
|
||||
sse_read_timeout_secs=2,
|
||||
)
|
||||
await session.start() # must NOT raise
|
||||
assert session.available is False
|
||||
assert session.function_schemas() == []
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_on_unavailable_session_raises():
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=2,
|
||||
sse_read_timeout_secs=2,
|
||||
)
|
||||
await session.start()
|
||||
with pytest.raises(RuntimeError):
|
||||
await session.call("mcp__acme_mcp__echo", {"text": "x"})
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_unknown_function_raises():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
with pytest.raises(RuntimeError):
|
||||
await session.call("mcp__acme_mcp__does_not_exist", {})
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_schemas_filter_by_raw_name():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="t-filter",
|
||||
tool_name="Mock MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
# No arg = all (backward compatible)
|
||||
all_names = sorted(s.name for s in session.function_schemas())
|
||||
assert all_names == ["mcp__mock_mcp__add", "mcp__mock_mcp__echo"]
|
||||
|
||||
# Allow only raw "echo"
|
||||
only_echo = session.function_schemas(allowed_raw_names={"echo"})
|
||||
assert [s.name for s in only_echo] == ["mcp__mock_mcp__echo"]
|
||||
|
||||
# Empty set = none (default-none semantics)
|
||||
assert session.function_schemas(allowed_raw_names=set()) == []
|
||||
|
||||
# Unknown raw name = skipped (pure intersection)
|
||||
assert session.function_schemas(allowed_raw_names={"nope"}) == []
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_mcp_tools_success():
|
||||
async with running_mcp_server() as base_url:
|
||||
tools = await discover_mcp_tools(
|
||||
url=base_url,
|
||||
credential=None,
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
names = sorted(t["name"] for t in tools)
|
||||
assert names == ["add", "echo"]
|
||||
by_name = {t["name"]: t for t in tools}
|
||||
assert by_name["echo"]["description"] # non-empty description
|
||||
assert set(by_name["echo"]) == {"name", "description"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_mcp_tools_server_down_returns_empty():
|
||||
# Unroutable port, short timeouts: must degrade to [] (never raise).
|
||||
tools = await discover_mcp_tools(
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
timeout_secs=1,
|
||||
sse_read_timeout_secs=1,
|
||||
)
|
||||
assert tools == []
|
||||
|
||||
|
||||
def test_agent_node_data_carries_mcp_tool_filters():
|
||||
from api.services.workflow.dto import AgentNodeData, NodeType
|
||||
from api.services.workflow.workflow_graph import Node
|
||||
|
||||
data = AgentNodeData(
|
||||
name="N1",
|
||||
tool_uuids=["tu-1"],
|
||||
mcp_tool_filters={"tu-1": ["echo"]},
|
||||
)
|
||||
assert data.mcp_tool_filters == {"tu-1": ["echo"]}
|
||||
|
||||
node = Node("n1", NodeType.agentNode, data)
|
||||
assert node.mcp_tool_filters == {"tu-1": ["echo"]}
|
||||
|
||||
# Absent field defaults to None (backward compatible)
|
||||
data2 = AgentNodeData(name="N2")
|
||||
assert data2.mcp_tool_filters is None
|
||||
assert Node("n2", NodeType.agentNode, data2).mcp_tool_filters is None
|
||||
42
api/tests/test_message_sanitization.py
Normal file
42
api/tests/test_message_sanitization.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from pipecat.utils.context.message_sanitization import (
|
||||
strip_thought_from_id,
|
||||
strip_thought_ids_from_messages,
|
||||
)
|
||||
|
||||
|
||||
def test_strip_thought_from_id():
|
||||
assert strip_thought_from_id("call_123__thought__abc") == "call_123"
|
||||
assert strip_thought_from_id("call_123") == "call_123"
|
||||
assert strip_thought_from_id(None) is None
|
||||
|
||||
|
||||
def test_strip_thought_ids_from_messages_does_not_mutate_input():
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1__thought__hidden",
|
||||
"type": "function",
|
||||
"function": {"name": "lookup", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1__thought__hidden",
|
||||
"content": '{"status":"ok"}',
|
||||
},
|
||||
]
|
||||
original = deepcopy(messages)
|
||||
|
||||
cleaned = strip_thought_ids_from_messages(messages)
|
||||
|
||||
assert messages == original
|
||||
assert cleaned is not messages
|
||||
assert cleaned[0]["tool_calls"][0]["id"] == "call_1"
|
||||
assert cleaned[1]["tool_call_id"] == "call_1"
|
||||
|
|
@ -14,7 +14,12 @@ import re
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.dto import NodeType, ReactFlowDTO
|
||||
from api.services.workflow.dto import (
|
||||
ReactFlowDTO,
|
||||
all_node_type_names,
|
||||
get_node_data_model,
|
||||
)
|
||||
from api.services.workflow.node_data import BaseNodeData
|
||||
from api.services.workflow.node_specs import (
|
||||
NodeSpec,
|
||||
PropertySpec,
|
||||
|
|
@ -118,9 +123,9 @@ def test_fixed_collection_has_sub_properties(spec: NodeSpec):
|
|||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
|
||||
def test_spec_name_matches_dto_discriminator(spec: NodeSpec):
|
||||
valid_names = {t.value for t in NodeType}
|
||||
valid_names = all_node_type_names()
|
||||
assert spec.name in valid_names, (
|
||||
f"NodeSpec {spec.name!r} doesn't match any NodeType discriminator. "
|
||||
f"NodeSpec {spec.name!r} doesn't match any registered node type. "
|
||||
f"Valid: {sorted(valid_names)}"
|
||||
)
|
||||
|
||||
|
|
@ -187,10 +192,226 @@ def test_examples_validate_against_dto(spec: NodeSpec):
|
|||
|
||||
|
||||
def test_all_dto_types_have_specs():
|
||||
"""Every NodeType discriminator value must have a registered NodeSpec —
|
||||
catches the case where someone adds a new node type to dto.py but
|
||||
forgets to author a spec."""
|
||||
"""Every registered node type must have a registered NodeSpec."""
|
||||
spec_names = {s.name for s in all_specs()}
|
||||
type_values = {t.value for t in NodeType}
|
||||
type_values = all_node_type_names()
|
||||
missing = type_values - spec_names
|
||||
assert not missing, f"NodeType discriminators without specs: {sorted(missing)}"
|
||||
assert not missing, f"Registered node types without specs: {sorted(missing)}"
|
||||
|
||||
|
||||
def test_all_registered_node_models_inherit_base_node_data():
|
||||
for type_name in sorted(all_node_type_names()):
|
||||
data_model = get_node_data_model(type_name)
|
||||
assert data_model is not None, f"{type_name}: missing node data model"
|
||||
assert issubclass(data_model, BaseNodeData), (
|
||||
f"{type_name}: node data model must inherit BaseNodeData"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("spec_name", "expected_order"),
|
||||
[
|
||||
(
|
||||
"startCall",
|
||||
[
|
||||
"name",
|
||||
"greeting_type",
|
||||
"greeting",
|
||||
"greeting_recording_id",
|
||||
"prompt",
|
||||
"allow_interrupt",
|
||||
"add_global_prompt",
|
||||
"delayed_start",
|
||||
"delayed_start_duration",
|
||||
"extraction_enabled",
|
||||
"extraction_prompt",
|
||||
"extraction_variables",
|
||||
"tool_uuids",
|
||||
"document_uuids",
|
||||
"pre_call_fetch_enabled",
|
||||
"pre_call_fetch_url",
|
||||
"pre_call_fetch_credential_uuid",
|
||||
],
|
||||
),
|
||||
(
|
||||
"agentNode",
|
||||
[
|
||||
"name",
|
||||
"prompt",
|
||||
"allow_interrupt",
|
||||
"add_global_prompt",
|
||||
"extraction_enabled",
|
||||
"extraction_prompt",
|
||||
"extraction_variables",
|
||||
"tool_uuids",
|
||||
"document_uuids",
|
||||
],
|
||||
),
|
||||
(
|
||||
"endCall",
|
||||
[
|
||||
"name",
|
||||
"prompt",
|
||||
"add_global_prompt",
|
||||
"extraction_enabled",
|
||||
"extraction_prompt",
|
||||
"extraction_variables",
|
||||
],
|
||||
),
|
||||
("globalNode", ["name", "prompt"]),
|
||||
("trigger", ["name", "enabled", "trigger_path"]),
|
||||
(
|
||||
"webhook",
|
||||
[
|
||||
"name",
|
||||
"enabled",
|
||||
"http_method",
|
||||
"endpoint_url",
|
||||
"credential_uuid",
|
||||
"custom_headers",
|
||||
"payload_template",
|
||||
],
|
||||
),
|
||||
(
|
||||
"qa",
|
||||
[
|
||||
"name",
|
||||
"qa_enabled",
|
||||
"qa_system_prompt",
|
||||
"qa_min_call_duration",
|
||||
"qa_voicemail_calls",
|
||||
"qa_sample_rate",
|
||||
"qa_use_workflow_llm",
|
||||
"qa_provider",
|
||||
"qa_model",
|
||||
"qa_api_key",
|
||||
"qa_endpoint",
|
||||
],
|
||||
),
|
||||
(
|
||||
"tuner",
|
||||
[
|
||||
"name",
|
||||
"tuner_enabled",
|
||||
"tuner_agent_id",
|
||||
"tuner_workspace_id",
|
||||
"tuner_api_key",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_node_spec_property_order_stable(spec_name: str, expected_order: list[str]):
|
||||
spec = next(spec for spec in all_specs() if spec.name == spec_name)
|
||||
assert [prop.name for prop in spec.properties] == expected_order
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# `to_mcp_dict` projection — the lean view served by the `get_node_type`
|
||||
# MCP tool. UI-only metadata is dropped so it doesn't poison LLM context;
|
||||
# the full spec stays available to the frontend and SDK via other paths.
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
# Keys that are UI-rendering concerns and must never reach the LLM view, at
|
||||
# either the node or property level.
|
||||
_UI_ONLY_KEYS = frozenset(
|
||||
{
|
||||
"display_name",
|
||||
"icon",
|
||||
"category",
|
||||
"version",
|
||||
"placeholder",
|
||||
"display_options",
|
||||
"editor",
|
||||
"extra",
|
||||
"label", # PropertyOption display string
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _walk_dicts(node):
|
||||
"""Yield every dict nested anywhere inside a projected structure."""
|
||||
if isinstance(node, dict):
|
||||
yield node
|
||||
for value in node.values():
|
||||
yield from _walk_dicts(value)
|
||||
elif isinstance(node, list):
|
||||
for item in node:
|
||||
yield from _walk_dicts(item)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
|
||||
def test_to_mcp_dict_drops_ui_only_keys(spec: NodeSpec):
|
||||
projected = spec.to_mcp_dict()
|
||||
for d in _walk_dicts(projected):
|
||||
leaked = _UI_ONLY_KEYS & d.keys()
|
||||
assert not leaked, f"{spec.name}: UI-only keys leaked into LLM view: {leaked}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
|
||||
def test_to_mcp_dict_omits_null_and_empty(spec: NodeSpec):
|
||||
"""The lean view never emits null values — absent means unset/optional,
|
||||
which is what halves the noise versus the full model dump."""
|
||||
for d in _walk_dicts(spec.to_mcp_dict()):
|
||||
for key, value in d.items():
|
||||
assert value is not None, f"{spec.name}: {key!r} emitted as null"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
|
||||
def test_to_mcp_dict_keeps_property_essentials(spec: NodeSpec):
|
||||
"""Every property in the LLM view carries the minimum an LLM needs to
|
||||
author a value: machine name, type, and a description."""
|
||||
|
||||
def _check(props: list[dict]):
|
||||
for prop in props:
|
||||
assert prop.get("name"), f"{spec.name}: property missing name"
|
||||
assert prop.get("type"), f"{spec.name}.{prop.get('name')}: missing type"
|
||||
assert prop.get("description"), (
|
||||
f"{spec.name}.{prop.get('name')}: missing description"
|
||||
)
|
||||
if prop.get("properties"):
|
||||
_check(prop["properties"])
|
||||
|
||||
_check(spec.to_mcp_dict()["properties"])
|
||||
|
||||
|
||||
def test_to_mcp_dict_retains_authoring_signal_startcall():
|
||||
"""startCall is the richest core node — lock in that the projection
|
||||
keeps the fields an LLM actually authors against while shedding the rest."""
|
||||
spec = next(s for s in all_specs() if s.name == "startCall")
|
||||
projected = spec.to_mcp_dict()
|
||||
|
||||
assert set(projected) == {
|
||||
"name",
|
||||
"description",
|
||||
"llm_hint",
|
||||
"properties",
|
||||
"examples",
|
||||
"graph_constraints",
|
||||
}
|
||||
|
||||
props = {p["name"]: p for p in projected["properties"]}
|
||||
|
||||
# Required field keeps `required`; optional fields omit it.
|
||||
assert props["prompt"]["required"] is True
|
||||
assert "required" not in props["greeting"]
|
||||
|
||||
# Enum options project to bare values, dropping the UI label.
|
||||
assert props["greeting_type"]["options"] == [{"value": "text"}, {"value": "audio"}]
|
||||
|
||||
# Validation bounds survive (they constrain valid authored values).
|
||||
assert props["delayed_start_duration"]["min_value"] == 0.1
|
||||
assert props["delayed_start_duration"]["max_value"] == 10.0
|
||||
|
||||
# llm_hint survives where present (catalog-tool references).
|
||||
assert "list_recordings" in props["greeting_recording_id"]["llm_hint"]
|
||||
|
||||
# fixed_collection rows recurse through the same projection.
|
||||
var_rows = {p["name"]: p for p in props["extraction_variables"]["properties"]}
|
||||
assert var_rows["type"]["options"] == [
|
||||
{"value": "string"},
|
||||
{"value": "number"},
|
||||
{"value": "boolean"},
|
||||
]
|
||||
|
||||
# graph_constraints drops its null sub-fields.
|
||||
assert projected["graph_constraints"] == {"min_incoming": 0, "max_incoming": 0}
|
||||
|
|
|
|||
|
|
@ -45,12 +45,11 @@ from api.enums import ToolCategory
|
|||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
EndCallNodeData,
|
||||
EndCallRFNode,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
StartCallNodeData,
|
||||
StartCallRFNode,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
|
|
@ -1014,8 +1013,9 @@ class TestEndCallExtractionBehavior:
|
|||
# Create a workflow where start node has NO extraction
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -1026,8 +1026,9 @@ class TestEndCallExtractionBehavior:
|
|||
extraction_enabled=False, # No extraction
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
|
|||
|
|
@ -34,12 +34,11 @@ from api.services.pipecat.recording_audio_cache import RecordingAudio
|
|||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
EndCallNodeData,
|
||||
EndCallRFNode,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
StartCallNodeData,
|
||||
StartCallRFNode,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
|
|
@ -65,8 +64,9 @@ def text_workflow() -> WorkflowGraph:
|
|||
"""Start->End workflow with text greeting and text transition speech."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -79,8 +79,9 @@ def text_workflow() -> WorkflowGraph:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
@ -114,8 +115,9 @@ def audio_workflow() -> WorkflowGraph:
|
|||
"""Start->End workflow with audio greeting and audio transition speech."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -128,8 +130,9 @@ def audio_workflow() -> WorkflowGraph:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
@ -290,8 +293,9 @@ class TestStartGreeting:
|
|||
"""No greeting configured should return None."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start",
|
||||
|
|
@ -301,8 +305,9 @@ class TestStartGreeting:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End",
|
||||
|
|
@ -333,8 +338,9 @@ class TestStartGreeting:
|
|||
"""Text greeting with {{variable}} placeholders should be rendered."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start",
|
||||
|
|
@ -346,8 +352,9 @@ class TestStartGreeting:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End",
|
||||
|
|
|
|||
|
|
@ -18,6 +18,25 @@ def _qa_node(node_id="qa-1", api_key="", **extra_data):
|
|||
return {"id": node_id, "type": "qa", "position": {"x": 0, "y": 0}, "data": data}
|
||||
|
||||
|
||||
def _tuner_node(node_id="tuner-1", api_key="", **extra_data):
|
||||
"""Helper to build a Tuner node."""
|
||||
data = {
|
||||
"name": "Tuner",
|
||||
"tuner_enabled": True,
|
||||
"tuner_agent_id": "sales-bot",
|
||||
"tuner_workspace_id": 7,
|
||||
**extra_data,
|
||||
}
|
||||
if api_key:
|
||||
data["tuner_api_key"] = api_key
|
||||
return {
|
||||
"id": node_id,
|
||||
"type": "tuner",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": data,
|
||||
}
|
||||
|
||||
|
||||
def _agent_node(node_id="agent-1"):
|
||||
"""Helper to build a non-QA node."""
|
||||
return {
|
||||
|
|
@ -66,6 +85,19 @@ class TestMaskWorkflowDefinition:
|
|||
assert "qa_api_key" not in masked["nodes"][0]["data"]
|
||||
assert masked["nodes"][1]["data"]["qa_api_key"] == mask_key("sk-secret1234")
|
||||
|
||||
def test_masks_tuner_api_key(self):
|
||||
"""Tuner node api_key is masked, showing only last 4 chars."""
|
||||
real_key = "tuner_live_abcdefghijklmnop"
|
||||
wf = _make_workflow_def([_tuner_node(api_key=real_key)])
|
||||
|
||||
masked = mask_workflow_definition(wf)
|
||||
|
||||
masked_key = masked["nodes"][0]["data"]["tuner_api_key"]
|
||||
assert masked_key == mask_key(real_key)
|
||||
assert masked_key.endswith("mnop")
|
||||
assert masked_key.startswith("*")
|
||||
assert real_key not in str(masked)
|
||||
|
||||
def test_qa_node_without_api_key(self):
|
||||
"""QA node with no api_key is left as-is."""
|
||||
wf = _make_workflow_def([_qa_node()])
|
||||
|
|
@ -154,6 +186,16 @@ class TestMergeWorkflowApiKeys:
|
|||
|
||||
assert result["nodes"][0]["data"]["qa_api_key"] == new_key
|
||||
|
||||
def test_masked_tuner_key_is_restored(self):
|
||||
"""Masked Tuner keys round-trip without losing the stored secret."""
|
||||
real_key = "tuner_live_abcdefghijklmnop"
|
||||
existing = _make_workflow_def([_tuner_node(api_key=real_key)])
|
||||
incoming = _make_workflow_def([_tuner_node(api_key=mask_key(real_key))])
|
||||
|
||||
result = merge_workflow_api_keys(incoming, existing)
|
||||
|
||||
assert result["nodes"][0]["data"]["tuner_api_key"] == real_key
|
||||
|
||||
def test_no_incoming_api_key(self):
|
||||
"""QA node without api_key in incoming is left alone."""
|
||||
existing = _make_workflow_def([_qa_node(api_key="sk-existing-key1")])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue