Merge pull request #1443 from CREDO23/feature-automations

[Feat] Automation V1 — Scheduled Agent Tasks, Created via Chat (HITL) or JSON
This commit is contained in:
Rohan Verma 2026-05-28 12:41:41 -07:00 committed by GitHub
commit 4dda02c06c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
219 changed files with 13821 additions and 55 deletions

View file

@ -0,0 +1,179 @@
"""Add automation tables (automations, automation_triggers, automation_runs)
Revision ID: 144
Revises: 143
Create Date: 2026-05-26
Adds the three tables that back the v1 automation engine, plus the
three PostgreSQL ENUM types they reference. Matches the SQLAlchemy
models under ``app.automations.persistence.models`` and the v1 data
model in ``automation-design-plan.md`` §9.
v1 ships these three tables only. ``domain_events`` is deferred to
Phase 3 with the event trigger; ``mcp_connections`` / ``mcp_tools``
are deferred to Phase 4 with the MCP integration.
"""
from collections.abc import Sequence
from alembic import op
# Revision identifiers, used by Alembic to order this migration in the chain
# (this revision follows "143", matching the module docstring above).
revision: str = "144"
down_revision: str | None = "143"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Create the automation ENUM types, tables and indexes.

    Every statement is issued through ``op.execute`` in the order listed
    below: the three ENUM types first (the tables reference them), then
    ``automations``, ``automation_triggers`` and ``automation_runs``, each
    immediately followed by its indexes.
    """
    ddl_statements = (
        # ENUM types (PostgreSQL requires types created before tables that use them)
        """
        CREATE TYPE automation_status AS ENUM (
            'active', 'paused', 'archived'
        );
        """,
        """
        CREATE TYPE automation_trigger_type AS ENUM (
            'schedule', 'manual'
        );
        """,
        """
        CREATE TYPE automation_run_status AS ENUM (
            'pending', 'running', 'succeeded', 'failed',
            'cancelled', 'timed_out'
        );
        """,
        # automations — the editable, versioned automation definition
        """
        CREATE TABLE automations (
            id SERIAL PRIMARY KEY,
            search_space_id INTEGER NOT NULL
                REFERENCES searchspaces(id) ON DELETE CASCADE,
            created_by_user_id UUID
                REFERENCES "user"(id) ON DELETE SET NULL,
            name VARCHAR(200) NOT NULL,
            description TEXT,
            status automation_status NOT NULL DEFAULT 'active',
            definition JSONB NOT NULL,
            version INTEGER NOT NULL DEFAULT 1,
            created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
            updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
        );
        """,
        "CREATE INDEX ix_automations_search_space_id ON automations(search_space_id);",
        "CREATE INDEX ix_automations_created_by_user_id ON automations(created_by_user_id);",
        "CREATE INDEX ix_automations_status ON automations(status);",
        "CREATE INDEX ix_automations_created_at ON automations(created_at);",
        "CREATE INDEX ix_automations_updated_at ON automations(updated_at);",
        # automation_triggers — one row per (automation, trigger-instance) pair
        """
        CREATE TABLE automation_triggers (
            id SERIAL PRIMARY KEY,
            automation_id INTEGER NOT NULL
                REFERENCES automations(id) ON DELETE CASCADE,
            type automation_trigger_type NOT NULL,
            params JSONB NOT NULL,
            static_inputs JSONB NOT NULL DEFAULT '{}'::jsonb,
            enabled BOOLEAN NOT NULL DEFAULT true,
            last_fired_at TIMESTAMP WITH TIME ZONE,
            next_fire_at TIMESTAMP WITH TIME ZONE,
            created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
        );
        """,
        "CREATE INDEX ix_automation_triggers_automation_id ON automation_triggers(automation_id);",
        "CREATE INDEX ix_automation_triggers_type ON automation_triggers(type);",
        "CREATE INDEX ix_automation_triggers_enabled ON automation_triggers(enabled);",
        "CREATE INDEX ix_automation_triggers_created_at ON automation_triggers(created_at);",
        # Partial index for the schedule tick: only enabled schedule triggers
        # with a scheduled next fire are ever scanned for due rows.
        """
        CREATE INDEX ix_automation_triggers_due
            ON automation_triggers (next_fire_at)
            WHERE enabled = true
                AND type = 'schedule'
                AND next_fire_at IS NOT NULL;
        """,
        # automation_runs — the immutable per-fire execution record
        """
        CREATE TABLE automation_runs (
            id SERIAL PRIMARY KEY,
            automation_id INTEGER NOT NULL
                REFERENCES automations(id) ON DELETE CASCADE,
            trigger_id INTEGER
                REFERENCES automation_triggers(id) ON DELETE SET NULL,
            status automation_run_status NOT NULL DEFAULT 'pending',
            definition_snapshot JSONB NOT NULL,
            inputs JSONB NOT NULL DEFAULT '{}'::jsonb,
            step_results JSONB NOT NULL DEFAULT '[]'::jsonb,
            output JSONB,
            artifacts JSONB NOT NULL DEFAULT '[]'::jsonb,
            error JSONB,
            started_at TIMESTAMP WITH TIME ZONE,
            finished_at TIMESTAMP WITH TIME ZONE,
            created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
        );
        """,
        "CREATE INDEX ix_automation_runs_automation_id ON automation_runs(automation_id);",
        "CREATE INDEX ix_automation_runs_trigger_id ON automation_runs(trigger_id);",
        "CREATE INDEX ix_automation_runs_status ON automation_runs(status);",
        "CREATE INDEX ix_automation_runs_created_at ON automation_runs(created_at);",
    )
    # One op.execute call per statement, in listing order.
    for ddl in ddl_statements:
        op.execute(ddl)
def downgrade() -> None:
    """Drop the automation indexes, tables and ENUM types.

    Tables are removed child-first (``automation_runs``, then
    ``automation_triggers``, then ``automations``); each table's indexes are
    dropped just before the table. The ENUM types go last, once no table
    references them. Every statement uses ``IF EXISTS``.
    """
    # (table, indexes to drop before the table) — order is significant.
    teardown_plan = (
        (
            "automation_runs",
            (
                "ix_automation_runs_created_at",
                "ix_automation_runs_status",
                "ix_automation_runs_trigger_id",
                "ix_automation_runs_automation_id",
            ),
        ),
        (
            "automation_triggers",
            (
                "ix_automation_triggers_due",
                "ix_automation_triggers_created_at",
                "ix_automation_triggers_enabled",
                "ix_automation_triggers_type",
                "ix_automation_triggers_automation_id",
            ),
        ),
        (
            "automations",
            (
                "ix_automations_updated_at",
                "ix_automations_created_at",
                "ix_automations_status",
                "ix_automations_created_by_user_id",
                "ix_automations_search_space_id",
            ),
        ),
    )
    for table_name, index_names in teardown_plan:
        for index_name in index_names:
            op.execute(f"DROP INDEX IF EXISTS {index_name};")
        op.execute(f"DROP TABLE IF EXISTS {table_name};")

    for enum_name in (
        "automation_run_status",
        "automation_trigger_type",
        "automation_status",
    ):
        op.execute(f"DROP TYPE IF EXISTS {enum_name};")

View file

@ -0,0 +1,87 @@
"""Add automations permissions to existing Editor/Viewer roles
Revision ID: 145
Revises: 144
Create Date: 2026-05-27
Owners already have ``*`` and need no backfill. Custom (non-system) roles
are left untouched on purpose: workspace admins manage those explicitly.
"""
from collections.abc import Sequence
from sqlalchemy import text
from alembic import op
# Revision identifiers, used by Alembic to order this migration in the chain
# (this revision follows "144", the migration that creates the automation tables).
revision: str = "145"
down_revision: str | None = "144"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
# Permission strings appended to (upgrade) / removed from (downgrade) every
# ``search_space_roles`` row whose name is 'Editor'.
_EDITOR_PERMISSIONS = (
    "automations:create",
    "automations:read",
    "automations:update",
    "automations:execute",
)
# Same, for rows whose name is 'Viewer': read access only.
_VIEWER_PERMISSIONS = ("automations:read",)
def upgrade():
    """Append the automations permissions to the Editor and Viewer roles.

    One ``UPDATE`` is issued per permission. The ``NOT (... = ANY(...))``
    guard skips rows that already carry the permission, so no duplicate
    array entries are added.
    """
    bind = op.get_bind()
    grant_to_editor = text(
        """
        UPDATE search_space_roles
        SET permissions = array_append(permissions, :permission)
        WHERE name = 'Editor'
            AND NOT (:permission = ANY(permissions))
        """
    )
    grant_to_viewer = text(
        """
        UPDATE search_space_roles
        SET permissions = array_append(permissions, :permission)
        WHERE name = 'Viewer'
            AND NOT (:permission = ANY(permissions))
        """
    )
    # Editor permissions are applied first, then Viewer permissions.
    grants = (
        (grant_to_editor, _EDITOR_PERMISSIONS),
        (grant_to_viewer, _VIEWER_PERMISSIONS),
    )
    for statement, granted in grants:
        for permission in granted:
            bind.execute(statement, {"permission": permission})
def downgrade():
    """Strip the automations permissions from the Editor and Viewer roles.

    One ``UPDATE`` is issued per permission; ``array_remove`` leaves rows
    that never had the permission unchanged.
    """
    bind = op.get_bind()
    revoke_from_editor = text(
        """
        UPDATE search_space_roles
        SET permissions = array_remove(permissions, :permission)
        WHERE name = 'Editor'
        """
    )
    revoke_from_viewer = text(
        """
        UPDATE search_space_roles
        SET permissions = array_remove(permissions, :permission)
        WHERE name = 'Viewer'
        """
    )
    # Editor permissions are revoked first, then Viewer permissions.
    revocations = (
        (revoke_from_editor, _EDITOR_PERMISSIONS),
        (revoke_from_viewer, _VIEWER_PERMISSIONS),
    )
    for statement, revoked in revocations:
        for permission in revoked:
            bind.execute(statement, {"permission": permission})

View file

@ -0,0 +1 @@
"""``create_automation`` — description + few-shot examples."""

View file

@ -0,0 +1,34 @@
- `create_automation` — Draft and author a new automation. You describe the
user's intent; a focused drafter inside the tool turns it into the full
automation JSON; the user sees a preview on an approval card and chooses
approve or reject. All three phases happen in a single tool call.
- Call when the user wants SurfSense to do something on its own: anything
recurring or scheduled ("every morning…", "each Monday…", "weekly
recap…").
- Args:
- `intent` (string): restate the user's request **concretely**, in one
paragraph. Cover three things:
- **What** should run (the action: summarize, recap, post, draft, …).
- **When** it should run (schedule + timezone if the user mentioned one;
otherwise leave the timezone for the drafter to default to UTC).
- **Static values** the automation needs (folder ids, channel names,
project keys, parent page ids, …) — list them with their values.
If the user did NOT supply one the automation needs, say so
explicitly ("the Notion parent page id was not specified") so the
drafter leaves a placeholder.
- Do NOT prompt the user to confirm before calling — the approval card
IS the confirmation. The card shows a structured preview plus the raw
JSON; it offers approve/reject only. If the user wants changes after
seeing the draft, they reply in chat and you call this tool again with
a refined `intent` — that's the edit path.
- Returns:
- `{status: "saved", automation_id, name}` — confirm briefly to the
user ("Saved as automation #N — runs <when>."). Don't dump JSON back.
- `{status: "rejected", message}` — the user declined on the card.
Acknowledge once ("Understood, I didn't create it.") and stop. Do
NOT retry or pitch variants without a fresh user request.
- `{status: "invalid", issues, raw?}` — drafting/validation failed
before the card was shown. Read the issues, refine your `intent`
with the missing details, call again.
- `{status: "error", message}` — surface the message verbatim and
offer to retry.

View file

@ -0,0 +1,13 @@
<example>
user: "Every weekday at 9am, summarize new documents in folder 12 and post the summary to Slack channel #daily-digest."
→ create_automation(intent="Every weekday at 09:00 UTC, summarize documents added to folder_id=12 since the last run, then post the summary to Slack channel '#daily-digest'. Static inputs: folder_id=12, slack_channel='#daily-digest'.")
tool returns: {"status": "saved", "automation_id": 42, "name": "Daily folder 12 digest"}
(Reply briefly: "Saved as automation #42 — runs weekdays at 9am UTC.")
</example>
<example>
user: "Once a week on Mondays at 7am Paris time, draft a Notion page recapping last week's Jira tickets in project CORE."
→ create_automation(intent="Every Monday at 07:00 Europe/Paris, read last week's Jira issues in project CORE, then draft a Notion page recapping them. Static inputs: jira_project_key='CORE'. The user did NOT specify which Notion page the recap should sit under — leave notion_parent_page_id as a placeholder.")
tool returns: {"status": "saved", "automation_id": 51, "name": "Weekly CORE Jira recap"}
(Reply: "Saved as automation #51. I left the Notion parent page id as a placeholder — set it on the automation before next Monday.")
</example>

View file

@ -0,0 +1,7 @@
"""``create_automation`` — author + persist an automation via a HITL card."""
from __future__ import annotations
from .create import create_create_automation_tool
__all__ = ["create_create_automation_tool"]

View file

@ -0,0 +1,208 @@
"""``create_automation`` — NL intent → drafted JSON → HITL approval card → persisted.
Single tool that:
1. Drafts a structured automation from the user's intent via a focused sub-LLM
(system prompt in :mod:`.prompt`).
2. Surfaces the validated draft in a HITL approval card
(``action_type="automation_create"``).
3. On approval, validates the (possibly edited) payload again and persists
it via :class:`AutomationService`.
The main agent only restates the user's request as a single ``intent`` string.
The drafting sub-LLM owns the JSON shape; the HITL card is the user's review.
"""
from __future__ import annotations
import json
import logging
import re
from typing import Any
from uuid import UUID
from fastapi import HTTPException
from langchain.tools import ToolRuntime
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from pydantic import ValidationError
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
request_approval,
)
from app.automations.schemas.api import AutomationCreate
from app.automations.services.automation import AutomationService
from app.db import User, async_session_maker
from app.utils.content_utils import extract_text_content
from .prompt import build_draft_prompt
logger = logging.getLogger(__name__)
_JSON_FENCE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)
def create_create_automation_tool(
    *,
    search_space_id: int,
    user_id: str | UUID,
    llm: Any,
):
    """Factory for the ``create_automation`` tool.

    ``search_space_id`` is injected from the chat session (the model never
    has to guess it). ``llm`` is the drafting sub-model — we reuse the main
    agent's LLM and tag the call so it's identifiable in traces. A fresh
    ``AsyncSession`` is opened per call to avoid stale sessions on
    compiled-agent cache hits (same pattern as the Notion / memory tools).

    Args:
        search_space_id: Search space the created automation belongs to;
            forced onto the draft and onto the approved payload.
        user_id: Owner of the automation, as a ``UUID`` or its string form.
        llm: Chat model used for drafting; only ``ainvoke`` is called on it.

    Returns:
        The ``create_automation`` tool, closed over the arguments above.
    """
    # Normalise once at build time so every tool call looks the user up by UUID.
    uid = UUID(user_id) if isinstance(user_id, str) else user_id

    @tool
    async def create_automation(intent: str, runtime: ToolRuntime) -> dict[str, Any]:
        """Draft + save an automation from a natural-language intent.

        Use this when the user wants SurfSense to do something on its own
        on a schedule (e.g. "every morning summarize folder 12 to Slack").
        Restate the user's request as ONE concrete ``intent`` string: what
        should run, when, and which static values (folder ids, channel
        names, …) it needs.

        The tool drafts the full automation JSON internally, shows the user
        a structured preview on an approval card, and persists on approval.
        The card supports approve/reject only — if the user wants edits
        after seeing the draft, they say so in chat and you call this tool
        again with a refined intent. Do NOT prompt the user to confirm
        before calling — the card IS the confirmation.

        Args:
            intent: Concrete restatement of the user's request. Include
                the schedule (with timezone if mentioned), the action to
                take, and any static values. Example: "Every weekday at
                09:00 UTC, summarize new docs added to folder_id=12 since
                the last run, then post the summary to Slack channel
                '#daily-digest'."

        Returns:
            ``{"status": "saved", "automation_id": int, "name": str}`` on
            approval + save.
            ``{"status": "rejected", "message": "..."}`` when the user
            declines on the card.
            ``{"status": "invalid", "issues": [...], "raw": ...}`` when
            the drafter produced output that did not validate (call again
            with a more precise intent).
            ``{"status": "error", "message": "..."}`` on drafter or
            persistence failure.

        IMPORTANT: when status is ``"rejected"`` the user explicitly
        declined. Acknowledge once and stop — do NOT retry or pitch
        variants without a fresh user request.
        """
        # --- 1. Draft via sub-LLM ---
        prompt = build_draft_prompt(search_space_id=search_space_id, intent=intent)
        try:
            response = await llm.ainvoke(
                [HumanMessage(content=prompt)],
                config={"tags": ["surfsense:internal", "automation-draft"]},
            )
        except Exception as exc:
            # Any drafting failure is reported to the agent as a result dict
            # rather than raised.
            logger.exception("create_automation drafting LLM call failed")
            return {"status": "error", "message": f"drafting failed: {exc}"}
        raw_text = extract_text_content(response.content).strip()
        draft = _extract_json(raw_text)
        if draft is None:
            # Hand the raw text back so the caller can see what the drafter said.
            return {
                "status": "invalid",
                "issues": ["model output was not parseable JSON"],
                "raw": raw_text,
            }
        # search_space_id is injected here so the sub-LLM never has to guess.
        draft["search_space_id"] = search_space_id
        try:
            validated_draft = AutomationCreate.model_validate(draft)
        except ValidationError as exc:
            return {
                "status": "invalid",
                "issues": _format_validation_issues(exc),
                "raw": draft,
            }
        # --- 2. HITL approval card ---
        try:
            card_params = validated_draft.model_dump(mode="json", by_alias=True)
            # search_space_id is session-scoped, not user-editable.
            card_params.pop("search_space_id", None)
            result = request_approval(
                action_type="automation_create",
                tool_name="create_automation",
                params=card_params,
                context={"search_space_id": search_space_id},
                tool_call_id=runtime.tool_call_id,
            )
            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. Do not retry or suggest alternatives.",
                }
            # --- 3. Persist (re-validate in case the user edited) ---
            # The session's search_space_id always wins over whatever came
            # back from the card.
            final_payload = {**result.params, "search_space_id": search_space_id}
            try:
                final_validated = AutomationCreate.model_validate(final_payload)
            except ValidationError as exc:
                return {
                    "status": "invalid",
                    "issues": _format_validation_issues(exc),
                }
            # Fresh session per call (see the factory docstring).
            async with async_session_maker() as session:
                user = await session.get(User, uid)
                if user is None:
                    return {
                        "status": "error",
                        "message": "user not found in this session",
                    }
                service = AutomationService(session=session, user=user)
                created = await service.create(final_validated)
                return {
                    "status": "saved",
                    "automation_id": created.id,
                    "name": created.name,
                }
        except HTTPException as exc:
            # HTTPException raised inside the block above is surfaced as a
            # tool-level error message.
            return {"status": "error", "message": exc.detail}
        except Exception as exc:
            # Imported lazily, only on this failure path.
            from langgraph.errors import GraphInterrupt

            # The approval pause must keep propagating out of the tool;
            # swallowing it here would turn the pause into an error result.
            if isinstance(exc, GraphInterrupt):
                raise
            logger.exception("create_automation failed")
            return {"status": "error", "message": f"persistence failed: {exc}"}

    return create_automation
def _extract_json(text: str) -> dict[str, Any] | None:
"""Pull a JSON object out of the model response, tolerating ``` fences."""
if not text:
return None
candidate = text
fence_match = _JSON_FENCE.search(text)
if fence_match:
candidate = fence_match.group(1)
try:
parsed = json.loads(candidate)
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _format_validation_issues(exc: ValidationError) -> list[str]:
return [
f"{'.'.join(str(p) for p in err['loc'])}: {err['msg']}" for err in exc.errors()
]

View file

@ -0,0 +1,179 @@
"""System prompt for the drafting sub-LLM inside ``create_automation``.
Converts a natural-language ``intent`` into a structured ``AutomationCreate``
JSON object. That object becomes the payload the HITL approval card surfaces.
Scope split:
Real automation JSONs live here — this is the graph that *generates*
the JSON. The main agent's prompt fragments (``description.md`` /
``example.md``) only carry intent-string examples; the main agent
never sees the schema.
Layout:
The prompt is concatenated from four format-safe pieces. ``_HEADER`` /
``_FOOTER`` carry the only ``str.format`` placeholders; ``_SCHEMA`` and
``_FEW_SHOTS`` are plain strings so their JSON literals (and the
``{{ inputs.X }}`` Jinja references in queries) can stay readable
without doubled-brace escaping.
Catalog handling:
v1 hard-codes the action/trigger catalog (one action, one trigger).
When new types ship, swap the inline lines for a render-time pull
from ``app.automations.actions`` / ``app.automations.triggers`` via
lazy imports inside :func:`build_draft_prompt` so this module never
participates in the ``multi_agent_chat`` import cycle.
"""
from __future__ import annotations
from datetime import UTC, datetime
_HEADER = """\
You are the SurfSense automation drafter. Convert the user intent below
into a SINGLE JSON object matching the AutomationCreate schema. Output
ONLY that JSON object no prose, no markdown fence, no commentary.
Current UTC time (for cron context): {now}
Target search_space_id: {search_space_id}
"""
_SCHEMA = """
Required JSON shape:
{
"name": "<1-200 char identifier>",
"description": "<one-liner or null>",
"definition": {
"schema_version": "1.0",
"name": "<same as outer name>",
"goal": "<one sentence>",
"plan": [
{
"step_id": "<slug>",
"action": "agent_task",
"params": {
"query": "<Jinja string referencing {{ inputs.X }}>",
"auto_approve_all": true
}
}
],
"metadata": {"tags": ["..."]}
},
"triggers": [
{
"type": "schedule",
"params": {"cron": "<5-field cron>", "timezone": "<IANA tz, default UTC>"},
"static_inputs": {"<key>": <value>, ...},
"enabled": true
}
]
}
v1 catalog (only these are valid):
- Actions: agent_task params: query (string, Jinja), auto_approve_all (bool).
- Triggers: schedule params: cron (5-field), timezone (IANA, e.g. "UTC",
"Europe/Paris"). Has static_inputs (object).
Conventions:
- Whatever the plan references via {{ inputs.X }} MUST appear either in a
trigger's static_inputs OR in definition.inputs.schema_.properties so the
executor can resolve it at fire time.
- static_inputs carries values that stay the same across every fire
(folder ids, channel names, project keys, parent page ids). Put them on
the trigger that supplies them, not in the plan.
- If the user did NOT supply a value the plan needs, put "REPLACE_ME" in
static_inputs. Do NOT invent ids, channels, or paths.
- Cron is 5-field (minute hour day-of-month month day-of-week). Use the
timezone the user mentioned; default "UTC" when unspecified.
- Templating variables available at fire time: inputs.* (merged
static_inputs + runtime), inputs.fired_at, inputs.last_fired_at.
"""
_FEW_SHOTS = """
Few-shot examples (intent JSON output):
### Example 1 — schedule with all static values supplied
intent: "Every weekday at 09:00 UTC, summarize documents added to folder_id=12 since the last run, then post the summary to Slack channel '#daily-digest'. Static inputs: folder_id=12, slack_channel='#daily-digest'."
output:
{
"name": "Daily folder 12 digest",
"description": "Weekday 09:00 UTC summary of folder 12 documents posted to #daily-digest",
"definition": {
"schema_version": "1.0",
"name": "Daily folder 12 digest",
"goal": "Summarize new docs in folder 12 since the last run and post to #daily-digest",
"plan": [
{
"step_id": "summarize_and_post",
"action": "agent_task",
"params": {
"query": "Summarize documents added to folder {{ inputs.folder_id }} since {{ inputs.last_fired_at or 'yesterday' }}, then send the summary to Slack channel {{ inputs.slack_channel }}.",
"auto_approve_all": true
}
}
],
"metadata": {"tags": ["daily", "digest", "slack"]}
},
"triggers": [
{
"type": "schedule",
"params": {"cron": "0 9 * * 1-5", "timezone": "UTC"},
"static_inputs": {"folder_id": 12, "slack_channel": "#daily-digest"},
"enabled": true
}
]
}
### Example 2 — schedule with a missing value (REPLACE_ME placeholder)
intent: "Every Monday at 07:00 Europe/Paris, read last week's Jira issues in project CORE, then draft a Notion page recapping them. Static inputs: jira_project_key='CORE'. The user did NOT specify the Notion parent page id — leave it as a placeholder."
output:
{
"name": "Weekly CORE Jira recap",
"description": "Monday 07:00 Europe/Paris recap of last week's CORE Jira issues, drafted to Notion",
"definition": {
"schema_version": "1.0",
"name": "Weekly CORE Jira recap",
"goal": "Recap last week's CORE Jira issues into a Notion page",
"plan": [
{
"step_id": "recap",
"action": "agent_task",
"params": {
"query": "List Jira issues in project {{ inputs.jira_project_key }} updated in the 7 days before {{ inputs.fired_at }}. Draft a Notion page under parent id {{ inputs.notion_parent_page_id }} titled 'CORE recap — week of {{ inputs.fired_at }}'.",
"auto_approve_all": true
}
}
],
"metadata": {"tags": ["weekly", "recap", "jira", "notion"]}
},
"triggers": [
{
"type": "schedule",
"params": {"cron": "0 7 * * 1", "timezone": "Europe/Paris"},
"static_inputs": {"jira_project_key": "CORE", "notion_parent_page_id": "REPLACE_ME"},
"enabled": true
}
]
}
"""
_FOOTER = """
User intent:
{intent}
"""
def build_draft_prompt(*, search_space_id: int, intent: str) -> str:
    """Assemble the drafting prompt for one user intent.

    Only the header and footer are ``str.format`` templates: the header
    receives the current UTC time (second precision, ISO-8601) and the
    target search space id, the footer receives the whitespace-stripped
    intent. The schema and few-shot sections are appended verbatim, so
    their literal braces need no escaping.
    """
    utc_now = datetime.now(UTC).isoformat(timespec="seconds")
    header = _HEADER.format(now=utc_now, search_space_id=search_space_id)
    footer = _FOOTER.format(intent=intent.strip())
    return "".join((header, _SCHEMA, _FEW_SHOTS, footer))

View file

@ -10,6 +10,7 @@ MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED: tuple[str, ...] = (
"web_search", "web_search",
"scrape_webpage", "scrape_webpage",
"update_memory", "update_memory",
"create_automation",
) )
MAIN_AGENT_SURFSENSE_TOOL_NAMES: frozenset[str] = frozenset( MAIN_AGENT_SURFSENSE_TOOL_NAMES: frozenset[str] = frozenset(

View file

@ -49,6 +49,7 @@ def request_approval(
params: dict[str, Any], params: dict[str, Any],
context: dict[str, Any] | None = None, context: dict[str, Any] | None = None,
trusted_tools: list[str] | None = None, trusted_tools: list[str] | None = None,
tool_call_id: str | None = None,
) -> HITLResult: ) -> HITLResult:
"""Pause the graph for user approval and return the user's decision. """Pause the graph for user approval and return the user's decision.
@ -64,6 +65,10 @@ def request_approval(
forwarded verbatim to the FE for richer card chrome. forwarded verbatim to the FE for richer card chrome.
trusted_tools: Per-session allowlist; when ``tool_name`` is in it the trusted_tools: Per-session allowlist; when ``tool_name`` is in it the
interrupt is skipped and the tool runs immediately. interrupt is skipped and the tool runs immediately.
tool_call_id: Caller's LangChain tool-call id. Required for tools
running directly on the main agent; subagent-mounted tools omit
it (the ``task`` chokepoint stamps it on re-raise see
:mod:`...checkpointed_subagent_middleware.propagation`).
Returns: Returns:
:class:`HITLResult` with ``rejected=True`` if the user declined or :class:`HITLResult` with ``rejected=True`` if the user declined or
@ -90,6 +95,8 @@ def request_approval(
interrupt_type=action_type, interrupt_type=action_type,
context=context, context=context,
) )
if tool_call_id:
payload["tool_call_id"] = tool_call_id
approval = interrupt(payload) approval = interrupt(payload)
parsed = parse_lc_envelope(approval) parsed = parse_lc_envelope(approval)

View file

@ -150,6 +150,28 @@ class ToolDefinition:
reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None
# =============================================================================
# Deferred-import factories
# =============================================================================
# Used for tools whose impls live under ``multi_agent_chat``. Importing those
# at module-load time would cycle (``multi_agent_chat`` middleware imports
# this registry). The import inside the factory runs only when
# ``build_tools`` is called, by which point ``multi_agent_chat`` is fully
# initialised.
def _build_create_automation_tool(deps: dict[str, Any]) -> BaseTool:
    """Build the ``create_automation`` tool from the registry's dependency map.

    The import is done inside the function so the ``multi_agent_chat``
    package is only loaded when the tool is actually built, not when this
    module is imported.

    Args:
        deps: Dependency map; ``search_space_id``, ``user_id`` and ``llm``
            are read from it.
    """
    from app.agents.multi_agent_chat.main_agent.tools.automation import (
        create_create_automation_tool as make_tool,
    )

    space_id = deps["search_space_id"]
    owner_id = deps["user_id"]
    drafting_llm = deps["llm"]
    return make_tool(
        search_space_id=space_id,
        user_id=owner_id,
        llm=drafting_llm,
    )
# ============================================================================= # =============================================================================
# Built-in Tools Registry # Built-in Tools Registry
# ============================================================================= # =============================================================================
@ -261,6 +283,21 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
requires=["db_session", "search_space_id", "user_id"], requires=["db_session", "search_space_id", "user_id"],
), ),
# ========================================================================= # =========================================================================
# AUTOMATION AUTHORING - single HITL tool. The tool takes an NL ``intent``
# from the main agent, drafts the full AutomationCreate JSON via a focused
# sub-LLM, surfaces it on an approval card, and persists on approval. The
# factory defers its import because the impl lives under ``multi_agent_chat``
# and that package transitively pulls this registry via middleware;
# deferring to ``build_tools`` call-time breaks the cycle without a
# parallel registry.
# =========================================================================
ToolDefinition(
name="create_automation",
description="Draft an automation from an NL intent; user approves the card; tool saves",
factory=_build_create_automation_tool,
requires=["search_space_id", "user_id", "llm"],
),
# =========================================================================
# MEMORY TOOL - single update_memory, private or team by thread_visibility # MEMORY TOOL - single update_memory, private or team by thread_visibility
# ========================================================================= # =========================================================================
ToolDefinition( ToolDefinition(

View file

@ -0,0 +1,5 @@
"""Automations engine — see automation-design-plan.md."""
from __future__ import annotations
__all__: list[str] = []

View file

@ -0,0 +1,24 @@
"""Actions domain: registry surface + built-in action packages.
Each action lives in its own subpackage (``agent_task/``, ...) and self-registers
at import time via its ``definition`` module. Side-effect imports below ensure
the registry is populated whenever anyone touches the actions package.
"""
from __future__ import annotations
from .store import all_actions, get_action, register_action
from .types import ActionContext, ActionDefinition, ActionHandler, ActionHandlerFactory
__all__ = [
"ActionContext",
"ActionDefinition",
"ActionHandler",
"ActionHandlerFactory",
"all_actions",
"get_action",
"register_action",
]
# Built-in actions self-register at import time.
from . import agent_task # noqa: E402, F401

View file

@ -0,0 +1,15 @@
"""``agent_task`` action: spin up multi_agent_chat for one rendered query.
Imports ``definition`` for its side-effect (self-registration on the actions
registry) and re-exports ``build_handler`` for direct consumers.
"""
from __future__ import annotations
from .factory import build_handler
from .params import AgentTaskActionParams
__all__ = ["AgentTaskActionParams", "build_handler"]
# Side-effect: register on the actions store.
from . import definition # noqa: E402, F401

View file

@ -0,0 +1,39 @@
"""Synthesize HITL decisions for every pending interrupt (approve-all or reject-all)."""
from __future__ import annotations
from typing import Any
def build_auto_decisions(
    state: Any, decision: str
) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]]]:
    """Answer every pending interrupt on ``state`` with the same decision.

    Two maps are returned:

    * the first is keyed by each interrupt's ``id`` (for
      ``Command(resume=...)``);
    * the second is keyed by the ``tool_call_id`` found in the interrupt
      payload (for the subagent middleware bridge) and only contains
      interrupts whose payload carries a string ``tool_call_id``.

    Each entry is ``{"decisions": [{"type": decision}, ...]}`` with one
    decision per item of the payload's ``action_requests`` list, or a single
    decision when that key is not a list. Interrupts whose payload is not a
    dict, or whose id is not a string, are skipped.
    """
    by_interrupt_id: dict[str, dict[str, Any]] = {}
    by_tool_call_id: dict[str, dict[str, Any]] = {}

    pending = getattr(state, "interrupts", ()) or ()
    for pending_interrupt in pending:
        payload = getattr(pending_interrupt, "value", None)
        if isinstance(payload, dict):
            key = getattr(pending_interrupt, "id", None)
            if isinstance(key, str):
                requests = payload.get("action_requests")
                if isinstance(requests, list):
                    n_actions = len(requests)
                else:
                    n_actions = 1
                verdicts = [{"type": decision} for _ in range(n_actions)]
                by_interrupt_id[key] = {"decisions": verdicts}

                call_id = payload.get("tool_call_id")
                if isinstance(call_id, str):
                    by_tool_call_id[call_id] = {"decisions": verdicts}

    return by_interrupt_id, by_tool_call_id

View file

@ -0,0 +1,18 @@
"""``agent_task`` ``ActionDefinition`` registration."""
from __future__ import annotations
from ..store import register_action
from ..types import ActionDefinition
from .factory import build_handler
from .params import AgentTaskActionParams
# Definition of the built-in ``agent_task`` action: ties the action type name
# to its params model and to the factory that builds its handler.
AGENT_TASK_ACTION = ActionDefinition(
    type="agent_task",
    name="Agent task",
    description="Run a multi_agent_chat turn from an automation step.",
    params_model=AgentTaskActionParams,
    build_handler=build_handler,
)
# Import-time side effect: importing this module registers the action.
register_action(AGENT_TASK_ACTION)

View file

@ -0,0 +1,75 @@
"""Build the per-invocation dependencies the multi_agent_chat factory needs."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from langgraph.checkpoint.memory import InMemorySaver
from sqlalchemy.ext.asyncio import AsyncSession
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
from app.tasks.chat.streaming.flows.shared.pre_stream_setup import (
setup_connector_and_firecrawl,
)
class DependencyError(Exception):
    """An external dependency (LLM config, connector service, ...) refused to load.

    Raised by ``build_dependencies`` when the default LLM bundle reports an
    error or yields no LLM.
    """
@dataclass(frozen=True, slots=True)
class AgentDependencies:
    """Everything ``create_multi_agent_chat_deep_agent`` needs from the environment."""

    # LLM handle and its agent configuration, as returned by ``load_llm_bundle``.
    llm: Any
    agent_config: Any
    # Connector service and optional Firecrawl key, as returned by
    # ``setup_connector_and_firecrawl``.
    connector_service: Any
    firecrawl_api_key: str | None
    # LangGraph checkpointer handed to the agent factory.
    checkpointer: Any
async def build_dependencies(
    *,
    session: AsyncSession,
    search_space_id: int,
) -> AgentDependencies:
    """Assemble the LLM, connector service and checkpointer for one agent invocation.

    The LLM comes from the search space's default config (``config_id=-1``);
    per-step model overrides are not supported yet. Raises ``DependencyError``
    when the bundle loader reports an error or returns no LLM.
    """
    bundle = await load_llm_bundle(
        session, config_id=-1, search_space_id=search_space_id
    )
    llm, agent_config, load_error = bundle
    if llm is None or load_error is not None:
        raise DependencyError(load_error or "failed to load default LLM config")

    connectors = await setup_connector_and_firecrawl(
        session, search_space_id=search_space_id
    )
    connector_service, firecrawl_api_key = connectors

    # Quick fix: automation runs get a throwaway in-memory checkpointer.
    #
    # The shared Postgres checkpointer keeps DB connections in a module-level
    # pool, and each cached connection is bound to the asyncio loop that
    # opened it. Celery discards the loop after every task, so the pool fills
    # with connections tied to a dead loop; the next task (on a fresh loop)
    # cannot use them, hangs 30s and fails with
    # `PoolTimeout: couldn't get a connection after 30.00 sec`.
    #
    # InMemorySaver holds no cached connections and has no loop binding: each
    # Celery task creates one and drops it on exit.
    #
    # TODO(checkpointer): the proper fix is to dispose the checkpointer pool
    # around each Celery task in `run_async_celery_task`, the same way
    # `_dispose_shared_db_engine` already does for the SQLAlchemy pool. Then
    # this site can switch back to the shared checkpointer.
    return AgentDependencies(
        llm=llm,
        agent_config=agent_config,
        connector_service=connector_service,
        firecrawl_api_key=firecrawl_api_key,
        checkpointer=InMemorySaver(),
    )

View file

@ -0,0 +1,23 @@
"""Bind ``ActionContext`` to a callable that runs one ``agent_task`` step."""
from __future__ import annotations
from typing import Any
from ..types import ActionContext, ActionHandler
from .invoke import run_agent_task
from .params import AgentTaskActionParams
def build_handler(ctx: ActionContext) -> ActionHandler:
    """Bind ``ctx`` into an async handler that runs one ``agent_task`` step."""

    async def handle(params: dict[str, Any]) -> dict[str, Any]:
        # Validate the raw params first; validation errors propagate to the caller.
        task_params = AgentTaskActionParams.model_validate(params)
        outcome = await run_agent_task(
            ctx=ctx,
            query=task_params.query,
            auto_approve_all=task_params.auto_approve_all,
        )
        return outcome

    return handle

View file

@ -0,0 +1,44 @@
"""Extract the agent's final assistant text from the terminal invoke result."""
from __future__ import annotations
from typing import Any
from langchain_core.messages import AIMessage
def extract_final_assistant_message(result: Any) -> str | None:
"""Return the last ``AIMessage`` text content, or ``None`` if there isn't one.
Multi-part messages (content lists) are flattened by concatenating ``text``
parts in order. Non-string content (tool calls, images) is skipped.
"""
if not isinstance(result, dict):
return None
messages = result.get("messages")
if not isinstance(messages, list):
return None
for msg in reversed(messages):
if not isinstance(msg, AIMessage):
continue
return _content_to_text(msg.content)
return None
def _content_to_text(content: Any) -> str | None:
if isinstance(content, str):
text = content.strip()
return text or None
if isinstance(content, list):
parts: list[str] = []
for part in content:
if isinstance(part, str):
parts.append(part)
elif isinstance(part, dict) and part.get("type") == "text":
text = part.get("text")
if isinstance(text, str):
parts.append(text)
joined = "".join(parts).strip()
return joined or None
return None

View file

@ -0,0 +1,98 @@
"""Run one ``agent_task`` invocation: ainvoke + auto-decision resume loop."""
from __future__ import annotations
import time
import uuid
from typing import Any
from langchain_core.messages import HumanMessage
from langgraph.types import Command
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
from app.db import ChatVisibility, async_session_maker
from ..types import ActionContext
from .auto_decide import build_auto_decisions
from .dependencies import build_dependencies
from .finalize import extract_final_assistant_message
# Cap on HITL resume iterations. The agent should not need this many turns in one
# step; treat overshoot as a runaway and fail the step. ``run_agent_task`` raises
# ``RuntimeError`` when interrupts are still pending after this many resumes.
_MAX_RESUMES = 50
async def run_agent_task(
    *,
    ctx: ActionContext,
    query: str,
    auto_approve_all: bool,
) -> dict[str, Any]:
    """Invoke multi_agent_chat for one rendered query and return its outcome.

    Opens its own DB session so the executor's bookkeeping session isn't tied
    up for the entire invocation. The LangGraph ``thread_id`` (a fresh UUID)
    is returned as ``agent_session_id`` for later inspection.

    After the first ``ainvoke`` the graph state is polled; while interrupts are
    pending, every one of them is answered with the same decision (``approve``
    when ``auto_approve_all`` is true, ``reject`` otherwise) and the graph is
    resumed.

    Returns:
        ``{"agent_session_id": str, "final_message": str | None, "resumes": int}``.

    Raises:
        RuntimeError: interrupts are still pending after ``_MAX_RESUMES``
            resumes, or interrupts are pending but none of them could be
            turned into a resume decision.
    """
    agent_session_id = str(uuid.uuid4())
    user_id = str(ctx.creator_user_id) if ctx.creator_user_id else None
    decision = "approve" if auto_approve_all else "reject"

    async with async_session_maker() as agent_session:
        deps = await build_dependencies(
            session=agent_session,
            search_space_id=ctx.search_space_id,
        )
        agent = await create_multi_agent_chat_deep_agent(
            llm=deps.llm,
            search_space_id=ctx.search_space_id,
            db_session=agent_session,
            connector_service=deps.connector_service,
            checkpointer=deps.checkpointer,
            user_id=user_id,
            thread_id=None,
            agent_config=deps.agent_config,
            firecrawl_api_key=deps.firecrawl_api_key,
            thread_visibility=ChatVisibility.PRIVATE,
        )

        request_id = f"automation:{ctx.run_id}:{ctx.step_id}"
        turn_id = f"{request_id}:{int(time.time() * 1000)}"
        input_state: dict[str, Any] = {
            "messages": [HumanMessage(content=query)],
            "search_space_id": ctx.search_space_id,
            "request_id": request_id,
            "turn_id": turn_id,
        }
        config: dict[str, Any] = {
            "configurable": {
                "thread_id": agent_session_id,
                "request_id": request_id,
                "turn_id": turn_id,
            },
            "recursion_limit": 10_000,
        }

        result = await agent.ainvoke(input_state, config=config)

        resumes = 0
        while True:
            state = await agent.aget_state(config)
            if not getattr(state, "interrupts", None):
                break
            if resumes >= _MAX_RESUMES:
                raise RuntimeError(
                    f"agent_task exceeded {_MAX_RESUMES} HITL resume iterations"
                )
            lg_resume_map, routed = build_auto_decisions(state, decision)
            if not lg_resume_map:
                # Interrupts are pending but none has a dict value and a string
                # id, so there is nothing to resume with. Re-invoking with an
                # empty resume map cannot clear them; fail now instead of
                # spinning until the resume cap is hit.
                raise RuntimeError(
                    "agent_task has pending interrupts that cannot be auto-decided"
                )
            # The subagent middleware bridge reads its decisions from config.
            config["configurable"]["surfsense_resume_value"] = routed
            result = await agent.ainvoke(Command(resume=lg_resume_map), config=config)
            resumes += 1

    return {
        "agent_session_id": agent_session_id,
        "final_message": extract_final_assistant_message(result),
        "resumes": resumes,
    }

View file

@ -0,0 +1,21 @@
"""``AgentTaskActionParams`` — params for the ``agent_task`` action type."""
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, Field
class AgentTaskActionParams(BaseModel):
    """Run a multi_agent_chat turn from an automation step."""

    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    # Required, non-empty query text passed to the agent as the user message.
    query: str = Field(
        ...,
        min_length=1,
        description="User query for the agent; rendered at execute time.",
    )
    # Controls the blanket decision applied to every HITL interrupt;
    # defaults to rejecting.
    auto_approve_all: bool = Field(
        default=False,
        description="If true, every HITL approval is auto-approved; otherwise rejected.",
    )

View file

@ -0,0 +1,23 @@
"""In-memory action registry. Populated once at process startup."""
from __future__ import annotations
from .types import ActionDefinition
# Process-wide mapping from action type key to its definition.
_REGISTRY: dict[str, ActionDefinition] = {}


def register_action(action: ActionDefinition) -> None:
    """Add ``action`` to the registry under ``action.type``.

    Raises ``ValueError`` if that type key is already taken.
    """
    type_key = action.type
    if type_key in _REGISTRY:
        raise ValueError(f"Action already registered: {type_key!r}")
    _REGISTRY[type_key] = action


def get_action(action_type: str) -> ActionDefinition | None:
    """Look up an action by type key; ``None`` when it is not registered."""
    return _REGISTRY.get(action_type)


def all_actions() -> dict[str, ActionDefinition]:
    """Return a shallow copy of the registry so callers cannot mutate it."""
    return {**_REGISTRY}

View file

@ -0,0 +1,40 @@
"""``ActionDefinition``, ``ActionContext``, and handler/factory signatures."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
@dataclass(frozen=True, slots=True)
class ActionContext:
    """Per-invocation dependencies bound to an action handler at execute time."""

    # Database session supplied by whoever builds the context.
    session: AsyncSession
    # Identifiers of the run and of the plan step being executed.
    run_id: int
    step_id: str
    # Search space the action operates in.
    search_space_id: int
    # User the action runs on behalf of; may be absent.
    creator_user_id: UUID | None
# A handler takes the step's params dict and returns an awaitable result.
ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]]
# A factory binds an ``ActionContext`` and returns the handler to call.
ActionHandlerFactory = Callable[[ActionContext], ActionHandler]
@dataclass(frozen=True, slots=True)
class ActionDefinition:
    """Immutable description of one action type."""

    # Type key identifying the action.
    type: str
    # Human-readable name and description.
    name: str
    description: str
    # Pydantic model describing the action's params.
    params_model: type[BaseModel]
    # Factory that turns an ``ActionContext`` into a callable handler.
    build_handler: ActionHandlerFactory

    @property
    def params_schema(self) -> dict[str, Any]:
        """JSON Schema (draft 2020-12) derived from ``params_model``."""
        return self.params_model.model_json_schema()

View file

@ -0,0 +1,16 @@
"""HTTP layer for the automations feature."""
from __future__ import annotations
from fastapi import APIRouter
from .automation import router as automation_router
from .run import router as run_router
from .trigger import router as trigger_router
# Aggregate router: combines the automation, trigger and run sub-routers
# (included in that order) behind a single exported ``router``.
router = APIRouter()
router.include_router(automation_router)
router.include_router(trigger_router)
router.include_router(run_router)
__all__ = ["router"]

View file

@ -0,0 +1,80 @@
"""HTTP routes for the ``Automation`` resource."""
from __future__ import annotations
from fastapi import APIRouter, Depends, Query, status
from app.automations.schemas.api import (
AutomationCreate,
AutomationDetail,
AutomationList,
AutomationSummary,
AutomationUpdate,
)
from app.automations.services import AutomationService, get_automation_service
# Router that the ``/automations`` endpoints below are registered on.
router = APIRouter()
@router.post("/automations", response_model=AutomationDetail, status_code=status.HTTP_201_CREATED)
async def create_automation(
    payload: AutomationCreate,
    service: AutomationService = Depends(get_automation_service),
) -> AutomationDetail:
    """Create an automation, optionally with initial triggers (atomic)."""
    # Persistence is delegated to the service; only response shaping happens here.
    created = await service.create(payload)
    response = AutomationDetail.model_validate(created)
    return response
@router.get(
    "/automations",
    response_model=AutomationList,
)
async def list_automations(
    search_space_id: int = Query(...),
    limit: int = Query(default=50, ge=1, le=200),
    offset: int = Query(default=0, ge=0),
    service: AutomationService = Depends(get_automation_service),
) -> AutomationList:
    """List automations in a search space."""
    # The service returns the requested page together with the overall count.
    page, count = await service.list(
        search_space_id=search_space_id,
        limit=limit,
        offset=offset,
    )
    summaries = [AutomationSummary.model_validate(row) for row in page]
    return AutomationList(items=summaries, total=count)
@router.get(
    "/automations/{automation_id}",
    response_model=AutomationDetail,
)
async def get_automation(
    automation_id: int,
    service: AutomationService = Depends(get_automation_service),
) -> AutomationDetail:
    """Get one automation with its definition and triggers."""
    found = await service.get(automation_id)
    response = AutomationDetail.model_validate(found)
    return response
@router.patch(
    "/automations/{automation_id}",
    response_model=AutomationDetail,
)
async def update_automation(
    automation_id: int,
    patch: AutomationUpdate,
    service: AutomationService = Depends(get_automation_service),
) -> AutomationDetail:
    """Partially update an automation. Triggers are managed separately."""
    updated = await service.update(automation_id, patch)
    response = AutomationDetail.model_validate(updated)
    return response
@router.delete("/automations/{automation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_automation(
    automation_id: int,
    service: AutomationService = Depends(get_automation_service),
) -> None:
    """Delete an automation; triggers and runs are removed by FK cascade."""
    # No response body: the route answers 204 once the service call returns.
    await service.delete(automation_id)
    return None

View file

@ -0,0 +1,44 @@
"""HTTP routes for automation run history."""
from __future__ import annotations
from fastapi import APIRouter, Depends, Query
from app.automations.schemas.api import RunDetail, RunList, RunSummary
from app.automations.services import RunService, get_run_service
# Router that the run-history endpoints below are registered on.
router = APIRouter()
@router.get("/automations/{automation_id}/runs", response_model=RunList)
async def list_runs(
    automation_id: int,
    limit: int = Query(default=50, ge=1, le=200),
    offset: int = Query(default=0, ge=0),
    service: RunService = Depends(get_run_service),
) -> RunList:
    """List run history for an automation, newest first."""
    # The service returns the requested page together with the overall count.
    page, count = await service.list(
        automation_id=automation_id,
        limit=limit,
        offset=offset,
    )
    summaries = [RunSummary.model_validate(row) for row in page]
    return RunList(items=summaries, total=count)
@router.get("/automations/{automation_id}/runs/{run_id}", response_model=RunDetail)
async def get_run(
    automation_id: int,
    run_id: int,
    service: RunService = Depends(get_run_service),
) -> RunDetail:
    """Get the full record of a single run, including step results and artifacts."""
    # Both ids are forwarded so the service can scope the lookup.
    found = await service.get(automation_id=automation_id, run_id=run_id)
    response = RunDetail.model_validate(found)
    return response

View file

@ -0,0 +1,55 @@
"""HTTP routes for triggers attached to an automation."""
from __future__ import annotations
from fastapi import APIRouter, Depends, status
from app.automations.schemas.api import TriggerCreate, TriggerDetail, TriggerUpdate
from app.automations.services import TriggerService, get_trigger_service
# Router that the trigger endpoints below are registered on.
router = APIRouter()
@router.post(
    "/automations/{automation_id}/triggers",
    status_code=status.HTTP_201_CREATED,
    response_model=TriggerDetail,
)
async def add_trigger(
    automation_id: int,
    payload: TriggerCreate,
    service: TriggerService = Depends(get_trigger_service),
) -> TriggerDetail:
    """Attach a new trigger to an automation."""
    created = await service.add(automation_id=automation_id, payload=payload)
    response = TriggerDetail.model_validate(created)
    return response
@router.patch("/automations/{automation_id}/triggers/{trigger_id}", response_model=TriggerDetail)
async def update_trigger(
    automation_id: int,
    trigger_id: int,
    patch: TriggerUpdate,
    service: TriggerService = Depends(get_trigger_service),
) -> TriggerDetail:
    """Toggle ``enabled`` or replace ``params``. Trigger type is immutable."""
    updated = await service.update(
        automation_id=automation_id,
        trigger_id=trigger_id,
        patch=patch,
    )
    response = TriggerDetail.model_validate(updated)
    return response
@router.delete("/automations/{automation_id}/triggers/{trigger_id}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_trigger(
    automation_id: int,
    trigger_id: int,
    service: TriggerService = Depends(get_trigger_service),
) -> None:
    """Detach a trigger from an automation."""
    # No response body: the route answers 204 once the service call returns.
    await service.remove(trigger_id=trigger_id, automation_id=automation_id)
    return None

View file

@ -0,0 +1,8 @@
"""Generic dispatch primitives shared across trigger types."""
from __future__ import annotations
from .errors import DispatchError
from .run import dispatch_run
# Public names re-exported by this package.
__all__ = ["DispatchError", "dispatch_run"]

View file

@ -0,0 +1,7 @@
"""Dispatch errors raised when a fire request cannot be turned into a run."""
from __future__ import annotations
class DispatchError(Exception):
    """A dispatch could not proceed (missing trigger, invalid inputs, ...).

    ``dispatch_run`` raises it for an invalid automation definition and for
    inputs that fail the declared JSON schema.
    """

View file

@ -0,0 +1,83 @@
"""Generic run dispatch: validate, snapshot, persist, enqueue. Shared by every trigger."""
from __future__ import annotations
from typing import Any
import jsonschema
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.enums.run_status import RunStatus
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.run import AutomationRun
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.tasks.execute_run import automation_run_execute
from .errors import DispatchError
async def dispatch_run(
    *,
    session: AsyncSession,
    automation: Automation,
    trigger: AutomationTrigger,
    runtime_inputs: dict[str, Any] | None = None,
) -> AutomationRun:
    """Create a ``PENDING`` ``AutomationRun`` for one fire and enqueue its execution.

    Steps, in order: parse ``automation.definition``; merge
    ``runtime_inputs`` with ``trigger.static_inputs`` (static values win on
    key collision); validate the merged inputs against the definition's
    optional input schema; persist the run with a JSON snapshot of the
    definition; enqueue ``automation_run_execute`` with the definition's
    timeout as the task time limit.

    Resolving ``automation`` / ``trigger`` and checking the trigger-side
    ``ACTIVE`` / ``enabled`` guards is the caller's job; only the logic shared
    by every trigger type lives here.

    Raises ``DispatchError`` when the definition or the inputs are invalid.
    """
    try:
        parsed = AutomationDefinition.model_validate(automation.definition)
    except Exception as exc:
        raise DispatchError(f"invalid automation definition: {exc}") from exc

    # Runtime values go in first so static trigger inputs overwrite them.
    combined: dict[str, Any] = dict(runtime_inputs or {})
    combined.update(trigger.static_inputs or {})
    checked_inputs = _validate_inputs(parsed, combined)

    new_run = AutomationRun(
        automation_id=automation.id,
        trigger_id=trigger.id,
        status=RunStatus.PENDING,
        definition_snapshot=parsed.model_dump(mode="json", by_alias=True),
        inputs=checked_inputs,
        step_results=[],
        artifacts=[],
    )
    session.add(new_run)
    # Commit before enqueueing so the run row exists when the task is sent.
    await session.commit()
    await session.refresh(new_run)

    automation_run_execute.apply_async(
        args=[new_run.id],
        time_limit=parsed.execution.timeout_seconds,
    )
    return new_run
def _validate_inputs(
    definition: AutomationDefinition, inputs: dict[str, Any]
) -> dict[str, Any]:
    """Check merged inputs against the definition's optional input schema.

    No declared schema -> pass through unchanged: runtime inputs such as
    ``fired_at`` / ``last_fired_at`` and the trigger's ``static_inputs`` must
    still reach the template context. Returning ``{}`` instead would strip
    them and make Jinja fail on any ``{{ inputs.* }}`` reference.

    With a schema, a validation failure is re-raised as ``DispatchError``;
    on success the same ``inputs`` dict is returned.
    """
    declared = definition.inputs
    schema = declared.schema_ if declared is not None else None
    if not schema:
        return inputs

    try:
        jsonschema.validate(instance=inputs, schema=schema)
    except jsonschema.ValidationError as exc:
        raise DispatchError(f"inputs: {exc.message}") from exc
    return inputs

View file

@ -0,0 +1,15 @@
"""Models and enums for the automation tables."""
from __future__ import annotations
from .enums import AutomationStatus, RunStatus, TriggerType
from .models import Automation, AutomationRun, AutomationTrigger
# Public names re-exported by this package: the three models and three enums.
__all__ = [
    "Automation",
    "AutomationRun",
    "AutomationStatus",
    "AutomationTrigger",
    "RunStatus",
    "TriggerType",
]

View file

@ -0,0 +1,13 @@
"""Enums for the automation tables."""
from __future__ import annotations
from .automation_status import AutomationStatus
from .run_status import RunStatus
from .trigger_type import TriggerType
# Public names re-exported by this package.
__all__ = [
    "AutomationStatus",
    "RunStatus",
    "TriggerType",
]

View file

@ -0,0 +1,11 @@
"""Automation lifecycle status."""
from __future__ import annotations
from enum import StrEnum
class AutomationStatus(StrEnum):
    """Lifecycle states of an automation; values are the stored strings."""

    ACTIVE = "active"  # eligible to fire
    PAUSED = "paused"  # kept, but triggers don't fire
    ARCHIVED = "archived"  # read-only history

View file

@ -0,0 +1,14 @@
"""AutomationRun state machine: pending → running → (succeeded|failed|cancelled|timed_out)."""
from __future__ import annotations
from enum import StrEnum
class RunStatus(StrEnum):
    """States of an ``AutomationRun``; values are the stored strings."""

    # Non-terminal states.
    PENDING = "pending"
    RUNNING = "running"
    # Terminal states.
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    CANCELLED = "cancelled"
    TIMED_OUT = "timed_out"

View file

@ -0,0 +1,15 @@
"""Trigger-kind discriminator.
v1 only registers ``schedule``. ``manual`` is reserved in the enum (mirrors the
postgres enum) but is intentionally unregistered pending a redesign of the
"Run now" UX.
"""
from __future__ import annotations
from enum import StrEnum
class TriggerType(StrEnum):
    """Kinds of trigger; values are the stored strings."""

    SCHEDULE = "schedule"
    # Reserved: present in the enum but not registered as a trigger in v1.
    MANUAL = "manual"

View file

@ -0,0 +1,13 @@
"""Models, one per table."""
from __future__ import annotations
from .automation import Automation
from .run import AutomationRun
from .trigger import AutomationTrigger
# Public names re-exported by this package.
__all__ = [
    "Automation",
    "AutomationRun",
    "AutomationTrigger",
]

View file

@ -0,0 +1,81 @@
"""``automations`` table — editable, versioned automation definition."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import (
TIMESTAMP,
Column,
Enum as SQLAlchemyEnum,
ForeignKey,
Integer,
String,
Text,
)
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import relationship
from app.db import BaseModel, TimestampMixin
from ..enums.automation_status import AutomationStatus
class Automation(BaseModel, TimestampMixin):
    """ORM model for the ``automations`` table."""

    __tablename__ = "automations"

    # Owning search space; rows go away with it (ON DELETE CASCADE).
    search_space_id = Column(
        Integer,
        ForeignKey("searchspaces.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Creating user; nulled out if that user is deleted (ON DELETE SET NULL).
    created_by_user_id = Column(
        UUID(as_uuid=True),
        ForeignKey("user.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )
    name = Column(String(200), nullable=False)
    description = Column(Text, nullable=True)
    # Stored as the enum's string values (not member names) via
    # ``values_callable``; defaults to ``active``.
    status = Column(
        SQLAlchemyEnum(
            AutomationStatus,
            name="automation_status",
            values_callable=lambda x: [e.value for e in x],
        ),
        nullable=False,
        default=AutomationStatus.ACTIVE,
        server_default=AutomationStatus.ACTIVE.value,
        index=True,
    )
    # Automation definition document, stored as JSONB.
    definition = Column(JSONB, nullable=False)
    # Integer version of the definition; starts at 1.
    version = Column(Integer, nullable=False, default=1, server_default="1")
    # Timezone-aware timestamp set on insert and refreshed on every ORM update.
    updated_at = Column(
        TIMESTAMP(timezone=True),
        nullable=False,
        default=lambda: datetime.now(UTC),
        onupdate=lambda: datetime.now(UTC),
        index=True,
    )

    search_space = relationship("SearchSpace", back_populates="automations")
    created_by = relationship("User", back_populates="automations")
    # ``passive_deletes=True`` leaves child removal to the database FK cascade.
    triggers = relationship(
        "AutomationTrigger",
        back_populates="automation",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )
    runs = relationship(
        "AutomationRun",
        back_populates="automation",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )

View file

@ -0,0 +1,66 @@
"""``automation_runs`` table — immutable per-fire execution record."""
from __future__ import annotations
from sqlalchemy import (
TIMESTAMP,
Column,
Enum as SQLAlchemyEnum,
ForeignKey,
Integer,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
from app.db import BaseModel, TimestampMixin
from ..enums.run_status import RunStatus
class AutomationRun(BaseModel, TimestampMixin):
    """ORM model for the ``automation_runs`` table."""

    __tablename__ = "automation_runs"

    # Parent automation; runs are removed with it (ON DELETE CASCADE).
    automation_id = Column(
        Integer,
        ForeignKey("automations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Trigger that fired the run; nulled if the trigger is later deleted.
    trigger_id = Column(
        Integer,
        ForeignKey("automation_triggers.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )
    # Stored as the enum's string values via ``values_callable``;
    # defaults to ``pending``.
    status = Column(
        SQLAlchemyEnum(
            RunStatus,
            name="automation_run_status",
            values_callable=lambda x: [e.value for e in x],
        ),
        nullable=False,
        default=RunStatus.PENDING,
        server_default=RunStatus.PENDING.value,
        index=True,
    )
    # locked at fire time so historical runs always show the exact code path
    definition_snapshot = Column(JSONB, nullable=False)
    # merged & validated inputs the run was dispatched with
    # (trigger.static_inputs merged with producer runtime data; static wins on
    # collision)
    inputs = Column(JSONB, nullable=False, server_default="{}")
    # one entry per executed step; agent_task entries carry their own
    # `agent_session_id` inside their entry
    step_results = Column(JSONB, nullable=False, server_default="[]")
    output = Column(JSONB, nullable=True)
    artifacts = Column(JSONB, nullable=False, server_default="[]")
    # Error payload stored when the run is marked failed; NULL otherwise.
    error = Column(JSONB, nullable=True)
    # Execution window; both stay NULL until set by the executor.
    started_at = Column(TIMESTAMP(timezone=True), nullable=True)
    finished_at = Column(TIMESTAMP(timezone=True), nullable=True)

    automation = relationship("Automation", back_populates="runs")
    trigger = relationship("AutomationTrigger", back_populates="runs")

View file

@ -0,0 +1,67 @@
"""``automation_triggers`` table — one row per (automation, trigger-instance) pair."""
from __future__ import annotations
from sqlalchemy import (
TIMESTAMP,
Boolean,
Column,
Enum as SQLAlchemyEnum,
ForeignKey,
Integer,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
from app.db import BaseModel, TimestampMixin
from ..enums.trigger_type import TriggerType
class AutomationTrigger(BaseModel, TimestampMixin):
    """ORM model for the ``automation_triggers`` table."""

    __tablename__ = "automation_triggers"

    # Parent automation; triggers are removed with it (ON DELETE CASCADE).
    automation_id = Column(
        Integer,
        ForeignKey("automations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Trigger kind, stored as the enum's string values via ``values_callable``.
    type = Column(
        SQLAlchemyEnum(
            TriggerType,
            name="automation_trigger_type",
            values_callable=lambda x: [e.value for e in x],
        ),
        nullable=False,
        index=True,
    )
    # Trigger parameters, stored as JSONB.
    params = Column(JSONB, nullable=False)
    # Per-attachment domain values merged into every dispatched run's inputs.
    # Static wins over runtime data on key collision.
    static_inputs = Column(JSONB, nullable=False, server_default="{}")
    # Enabled flag; defaults to true.
    enabled = Column(
        Boolean,
        nullable=False,
        default=True,
        server_default="true",
        index=True,
    )
    last_fired_at = Column(TIMESTAMP(timezone=True), nullable=True)
    # Precomputed next fire moment in UTC; advanced after each fire by the
    # schedule tick. NULL means the trigger has never been scheduled (the
    # tick self-heals on first sight).
    next_fire_at = Column(TIMESTAMP(timezone=True), nullable=True)

    automation = relationship("Automation", back_populates="triggers")
    # Run rows are not deleted through the ORM; the FK on the run side
    # declares ON DELETE SET NULL.
    runs = relationship(
        "AutomationRun",
        back_populates="trigger",
        passive_deletes=True,
    )

View file

@ -0,0 +1,7 @@
"""Automation run executor: plan walker, step dispatch, retries, persistence."""
from __future__ import annotations
from .executor import execute_run
# Public entry point of the executor package.
__all__ = ["execute_run"]

View file

@ -0,0 +1,124 @@
"""Walk an ``AutomationRun``'s snapshot plan to terminal state."""
from __future__ import annotations
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.enums.run_status import RunStatus
from app.automations.persistence.models.run import AutomationRun
from app.automations.actions.types import ActionContext
from app.automations.schemas.definition.envelope import AutomationDefinition
from app.automations.schemas.definition.plan_step import PlanStep
from app.automations.templating import build_run_context
from . import repository
from .step import execute_step
async def execute_run(session: AsyncSession, run_id: int) -> None:
    """Drive run ``run_id`` through its snapshot plan until it reaches a terminal state.

    Raises ``ValueError`` when the run does not exist. A run that is not
    ``PENDING`` is left untouched. An unparsable ``definition_snapshot`` marks
    the run failed. Otherwise the plan is walked in order: each step's result
    is appended and committed; the first failed step triggers the
    ``on_failure`` steps and marks the run failed; if no step fails the run is
    marked succeeded.
    """
    run = await repository.load_run(session, run_id)
    if run is None:
        raise ValueError(f"automation_run {run_id} not found")
    if run.status != RunStatus.PENDING:
        return

    try:
        definition = AutomationDefinition.model_validate(run.definition_snapshot)
    except Exception as exc:
        failure = {
            "message": f"definition_snapshot invalid: {exc}",
            "type": type(exc).__name__,
        }
        await repository.mark_failed(session, run, failure)
        await session.commit()
        return

    await repository.mark_running(session, run)
    await session.commit()

    # Outputs of succeeded steps, keyed by ``output_as`` (or the step id).
    outputs_by_name: dict[str, Any] = {}
    for step in definition.plan:
        template_ctx = _build_template_ctx(run, outputs_by_name)
        action_ctx = _build_action_ctx(session, run, step)
        outcome = await execute_step(
            step=step,
            template_context=template_ctx,
            action_context=action_ctx,
            default_max_retries=definition.execution.max_retries,
            default_retry_backoff=definition.execution.retry_backoff,
            default_timeout_seconds=definition.execution.timeout_seconds,
        )
        # Persist each step result immediately so progress survives later failures.
        await repository.append_step_result(session, run, outcome)
        await session.commit()

        step_status = outcome["status"]
        if step_status == "failed":
            await _run_on_failure(session, run, definition)
            await repository.mark_failed(session, run, outcome.get("error"))
            await session.commit()
            return
        if step_status == "succeeded":
            output_key = step.output_as or step.step_id
            outputs_by_name[output_key] = outcome.get("result")

    await repository.mark_succeeded(session, run)
    await session.commit()
async def _run_on_failure(
    session: AsyncSession,
    run: AutomationRun,
    definition: AutomationDefinition,
) -> None:
    """Execute the definition's ``on_failure`` steps, recording each result.

    The steps share one template context built with no step outputs. Their own
    results are appended and committed but never inspected, so a failing
    ``on_failure`` step does not trigger further ``on_failure`` handling.
    """
    execution = definition.execution
    failure_steps = execution.on_failure
    if not failure_steps:
        return

    shared_ctx = _build_template_ctx(run, step_outputs={})
    for failure_step in failure_steps:
        outcome = await execute_step(
            step=failure_step,
            template_context=shared_ctx,
            action_context=_build_action_ctx(session, run, failure_step),
            default_max_retries=execution.max_retries,
            default_retry_backoff=execution.retry_backoff,
            default_timeout_seconds=execution.timeout_seconds,
        )
        await repository.append_step_result(session, run, outcome)
        await session.commit()
def _build_template_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, Any]:
    """Build the template context for ``run`` from its automation, trigger and inputs.

    Automation- and trigger-derived values fall back to ``None`` when the
    related object is missing. ``attempt`` is always ``1``.
    """
    owner = run.automation
    source = run.trigger

    if owner:
        name = owner.name
        version = owner.version
        space_id = owner.search_space_id
        creator = owner.created_by_user_id
    else:
        name = version = space_id = creator = None

    kind = source.type.value if source else None

    return build_run_context(
        run_id=run.id,
        automation_id=run.automation_id,
        automation_name=name,
        automation_version=version,
        search_space_id=space_id,
        creator_id=creator,
        trigger_id=run.trigger_id,
        trigger_type=kind,
        started_at=run.started_at,
        attempt=1,
        inputs=run.inputs or {},
        step_outputs=step_outputs,
    )
def _build_action_ctx(
    session: AsyncSession, run: AutomationRun, step: PlanStep
) -> ActionContext:
    """Build the ``ActionContext`` for executing ``step`` within ``run``."""
    owner = run.automation
    ctx_fields = {
        "session": session,
        "run_id": run.id,
        "step_id": step.step_id,
        "search_space_id": owner.search_space_id,
        "creator_user_id": owner.created_by_user_id,
    }
    return ActionContext(**ctx_fields)

View file

@ -0,0 +1,62 @@
"""Persistence operations on ``AutomationRun``. Pure SQL, no business logic."""
from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.automations.persistence.enums.run_status import RunStatus
from app.automations.persistence.models.run import AutomationRun
async def load_run(session: AsyncSession, run_id: int) -> AutomationRun | None:
    """Fetch run ``run_id`` with ``automation`` and ``trigger`` eagerly loaded.

    Returns ``None`` when no such run exists.
    """
    eager = (
        selectinload(AutomationRun.automation),
        selectinload(AutomationRun.trigger),
    )
    query = select(AutomationRun).where(AutomationRun.id == run_id).options(*eager)
    rows = await session.execute(query)
    return rows.scalar_one_or_none()
async def mark_running(session: AsyncSession, run: AutomationRun) -> None:
    """Set ``run`` to ``RUNNING``, stamp ``started_at`` (UTC now) and flush."""
    run.started_at = datetime.now(UTC)
    run.status = RunStatus.RUNNING
    await session.flush()
async def mark_succeeded(session: AsyncSession, run: AutomationRun) -> None:
    """Set ``run`` to ``SUCCEEDED``, stamp ``finished_at`` (UTC now) and flush."""
    run.finished_at = datetime.now(UTC)
    run.status = RunStatus.SUCCEEDED
    await session.flush()
async def mark_failed(
    session: AsyncSession,
    run: AutomationRun,
    error: dict[str, Any] | None,
) -> None:
    """Set ``run`` to ``FAILED``, store ``error``, stamp ``finished_at`` and flush."""
    run.error = error
    run.finished_at = datetime.now(UTC)
    run.status = RunStatus.FAILED
    await session.flush()
async def append_step_result(
    session: AsyncSession,
    run: AutomationRun,
    step_result: dict[str, Any],
) -> None:
    """Add ``step_result`` to the end of ``run.step_results`` and flush.

    A new list is assigned rather than mutating in place, so SQLAlchemy sees
    the attribute change.
    """
    run.step_results = [*(run.step_results or []), step_result]
    await session.flush()

View file

@ -0,0 +1,36 @@
"""Retry policy enforcement for action handlers."""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
async def with_retries[T](
coro_factory: Callable[[], Awaitable[T]],
*,
max_retries: int,
backoff: str,
timeout: int | None,
) -> tuple[T, int]:
"""Call ``coro_factory`` up to ``1 + max_retries`` times. Return ``(result, attempts)``."""
total = 1 + max(0, max_retries)
for attempt in range(1, total + 1):
try:
coro = coro_factory()
if timeout is not None and timeout > 0:
return await asyncio.wait_for(coro, timeout=timeout), attempt
return await coro, attempt
except Exception:
if attempt >= total:
raise
await asyncio.sleep(_backoff_seconds(backoff, attempt))
raise RuntimeError("with_retries exhausted without raising or returning")
def _backoff_seconds(strategy: str, attempt: int) -> float:
if strategy == "exponential":
return float(2 ** (attempt - 1))
if strategy == "linear":
return float(attempt)
return 0.0

View file

@ -0,0 +1,96 @@
"""Execute one plan step: when-predicate, params render, handler dispatch, retries."""
from __future__ import annotations
from collections.abc import Mapping
from datetime import UTC, datetime
from typing import Any
from app.automations.actions import get_action
from app.automations.actions.types import ActionContext
from app.automations.schemas.definition.plan_step import PlanStep
from app.automations.templating import evaluate_predicate, render_value
from .retries import with_retries
async def execute_step(
    *,
    step: PlanStep,
    template_context: Mapping[str, Any],
    action_context: ActionContext,
    default_max_retries: int,
    default_retry_backoff: str,
    default_timeout_seconds: int,
) -> dict[str, Any]:
    """Run a single plan step and describe the outcome as a result entry.

    Stages, in order: evaluate the optional ``when`` predicate, render the
    step params, resolve the action from the registry, then invoke its
    handler under the retry policy. A failure in any stage is reported as a
    ``"failed"`` entry rather than raised; a falsy predicate yields
    ``"skipped"``. Step-level ``max_retries`` / ``timeout_seconds`` override
    the supplied defaults; the backoff strategy always comes from
    ``default_retry_backoff``.
    """
    began = datetime.now(UTC)

    # Stage 1: the optional gate.
    predicate = step.when
    if predicate is not None:
        try:
            gate_open = evaluate_predicate(predicate, template_context)
        except Exception as exc:
            return _result(step, "failed", began, attempts=0, error=_error(exc, "when"))
        if not gate_open:
            return _result(step, "skipped", began, attempts=0)

    # Stage 2: render templated params against the run context.
    try:
        resolved_params = render_value(step.params, template_context)
    except Exception as exc:
        return _result(step, "failed", began, attempts=0, error=_error(exc, "render"))

    # Stage 3: look the action up in the registry.
    action = get_action(step.action)
    if action is None:
        missing = {"message": f"action not registered: {step.action}", "type": "ActionNotFound"}
        return _result(step, "failed", began, attempts=0, error=missing)

    # Stage 4: call the handler under the retry/timeout policy.
    handler = action.build_handler(action_context)
    max_retries = default_max_retries if step.max_retries is None else step.max_retries
    timeout = step.timeout_seconds or default_timeout_seconds

    def invoke():
        return handler(resolved_params)

    try:
        result, attempts = await with_retries(
            invoke,
            max_retries=max_retries,
            backoff=default_retry_backoff,
            timeout=timeout,
        )
    except Exception as exc:
        # Every attempt in the budget was used before the error escaped.
        return _result(step, "failed", began, attempts=max_retries + 1, error=_error(exc))
    return _result(step, "succeeded", began, attempts=attempts, result=result)
def _result(
step: PlanStep,
status: str,
started_at: datetime,
*,
attempts: int,
result: Any = None,
error: dict[str, Any] | None = None,
) -> dict[str, Any]:
entry: dict[str, Any] = {
"step_id": step.step_id,
"action": step.action,
"status": status,
"started_at": started_at.isoformat(),
"finished_at": datetime.now(UTC).isoformat(),
"attempts": attempts,
}
if result is not None:
entry["result"] = result
if error is not None:
entry["error"] = error
return entry
def _error(exc: Exception, phase: str | None = None) -> dict[str, Any]:
    """Describe ``exc`` as a ``{"message", "type"}`` dict, prefixing the message with ``phase`` when given."""
    text = str(exc)
    if phase:
        text = f"{phase}: {text}"
    return {"message": text, "type": type(exc).__name__}

View file

@ -0,0 +1,27 @@
"""Schemas for the automation definition envelope.
Per-action and per-trigger params schemas live with the action/trigger
implementations (``app.automations.actions.<name>.params`` /
``app.automations.triggers.<name>.params``); only the cross-cutting envelope
lives here.
"""
from __future__ import annotations
from .definition import (
AutomationDefinition,
Execution,
Inputs,
Metadata,
PlanStep,
TriggerSpec,
)
__all__ = [
"AutomationDefinition",
"Execution",
"Inputs",
"Metadata",
"PlanStep",
"TriggerSpec",
]

View file

@ -0,0 +1,27 @@
"""Request/response schemas for the automations HTTP layer."""
from __future__ import annotations
from .automation import (
AutomationCreate,
AutomationDetail,
AutomationList,
AutomationSummary,
AutomationUpdate,
)
from .run import RunDetail, RunList, RunSummary
from .trigger import TriggerCreate, TriggerDetail, TriggerUpdate
__all__ = [
"AutomationCreate",
"AutomationDetail",
"AutomationList",
"AutomationSummary",
"AutomationUpdate",
"RunDetail",
"RunList",
"RunSummary",
"TriggerCreate",
"TriggerDetail",
"TriggerUpdate",
]

View file

@ -0,0 +1,64 @@
"""Request/response schemas for the ``Automation`` resource."""
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field
from app.automations.persistence.enums.automation_status import AutomationStatus
from app.automations.schemas.definition import AutomationDefinition
from .trigger import TriggerCreate, TriggerDetail
class AutomationCreate(BaseModel):
    """Create an automation, optionally with initial triggers (atomic)."""

    # Unknown fields in the request body are rejected.
    model_config = ConfigDict(extra="forbid")

    # Search space the new automation is created in.
    search_space_id: int
    # Display name, 1-200 characters.
    name: str = Field(..., min_length=1, max_length=200)
    description: str | None = None
    # Full definition envelope, validated as ``AutomationDefinition``.
    definition: AutomationDefinition
    # Triggers to attach together with the automation; empty by default.
    triggers: list[TriggerCreate] = Field(default_factory=list)
class AutomationUpdate(BaseModel):
    """Partial update of an automation. Triggers are managed separately."""

    # Unknown fields in the request body are rejected.
    model_config = ConfigDict(extra="forbid")

    # Every field defaults to ``None`` so a client can send any subset of them.
    name: str | None = Field(default=None, min_length=1, max_length=200)
    description: str | None = None
    status: AutomationStatus | None = None
    definition: AutomationDefinition | None = None
class AutomationSummary(BaseModel):
    """Lightweight automation view for list endpoints."""

    # ``from_attributes`` lets the model be populated from attribute-bearing
    # objects (e.g. ORM rows) rather than only from dicts.
    model_config = ConfigDict(from_attributes=True)

    id: int
    search_space_id: int
    name: str
    description: str | None = None
    status: AutomationStatus
    version: int
    created_at: datetime
    updated_at: datetime
class AutomationDetail(AutomationSummary):
    """Full automation view including definition and attached triggers."""

    # Adds the definition envelope and trigger list on top of the summary fields.
    definition: AutomationDefinition
    triggers: list[TriggerDetail] = Field(default_factory=list)
class AutomationList(BaseModel):
    """Paginated list of automations."""

    # One page of summaries.
    items: list[AutomationSummary]
    # Count reported alongside the page.
    # NOTE(review): presumably the total across all pages — confirm at the route.
    total: int

View file

@ -0,0 +1,42 @@
"""Response schemas for run sub-resources."""
from __future__ import annotations
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict
from app.automations.persistence.enums.run_status import RunStatus
class RunSummary(BaseModel):
    """Lightweight run view for list endpoints."""

    # ``from_attributes`` lets the model be populated from attribute-bearing
    # objects (e.g. ORM rows) rather than only from dicts.
    model_config = ConfigDict(from_attributes=True)

    id: int
    automation_id: int
    # Optional: a run may be reported without a trigger id.
    trigger_id: int | None = None
    status: RunStatus
    # Both timestamps are optional and default to ``None``.
    started_at: datetime | None = None
    finished_at: datetime | None = None
    created_at: datetime
class RunDetail(RunSummary):
    """Full run view including snapshot, results and artifacts."""

    # Free-form JSON payloads; their inner shape is not constrained here.
    definition_snapshot: dict[str, Any]
    inputs: dict[str, Any]
    step_results: list[dict[str, Any]]
    output: dict[str, Any] | None = None
    artifacts: list[dict[str, Any]]
    error: dict[str, Any] | None = None
class RunList(BaseModel):
    """Paginated list of runs."""

    # One page of run summaries.
    items: list[RunSummary]
    # Count reported alongside the page.
    # NOTE(review): presumably the total across all pages — confirm at the route.
    total: int

View file

@ -0,0 +1,46 @@
"""Request/response schemas for trigger sub-resources."""
from __future__ import annotations
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.automations.persistence.enums.trigger_type import TriggerType
class TriggerCreate(BaseModel):
    """Attach a trigger to an automation."""

    # Unknown fields in the request body are rejected.
    model_config = ConfigDict(extra="forbid")

    type: TriggerType
    # Raw type-specific params; only checked to be a dict at this layer.
    params: dict[str, Any] = Field(default_factory=dict)
    static_inputs: dict[str, Any] = Field(default_factory=dict)
    # Triggers are enabled unless the client says otherwise.
    enabled: bool = True
class TriggerUpdate(BaseModel):
    """Partial update of an existing trigger."""

    # Unknown fields in the request body are rejected.
    model_config = ConfigDict(extra="forbid")

    # Every field defaults to ``None`` so a client can send any subset of them.
    # The trigger ``type`` is not part of this schema.
    enabled: bool | None = None
    params: dict[str, Any] | None = None
    static_inputs: dict[str, Any] | None = None
class TriggerDetail(BaseModel):
    """Trigger as returned to clients."""

    # ``from_attributes`` lets the model be populated from attribute-bearing
    # objects (e.g. ORM rows) rather than only from dicts.
    model_config = ConfigDict(from_attributes=True)

    id: int
    type: TriggerType
    params: dict[str, Any]
    static_inputs: dict[str, Any]
    enabled: bool
    # Scheduling bookkeeping; either may be absent.
    last_fired_at: datetime | None = None
    next_fire_at: datetime | None = None
    created_at: datetime

View file

@ -0,0 +1,19 @@
"""Automation definition envelope and its components."""
from __future__ import annotations
from .envelope import AutomationDefinition
from .execution import Execution
from .inputs import Inputs
from .metadata import Metadata
from .plan_step import PlanStep
from .trigger_spec import TriggerSpec
__all__ = [
"AutomationDefinition",
"Execution",
"Inputs",
"Metadata",
"PlanStep",
"TriggerSpec",
]

View file

@ -0,0 +1,26 @@
"""``AutomationDefinition`` — top-level envelope persisted in ``automations.definition``."""
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, Field
from .execution import Execution
from .inputs import Inputs
from .metadata import Metadata
from .plan_step import PlanStep
from .trigger_spec import TriggerSpec
class AutomationDefinition(BaseModel):
    """Top-level shape of an automation."""

    # Unknown top-level keys are rejected.
    model_config = ConfigDict(extra="forbid")

    schema_version: str = "1.0"
    name: str = Field(..., min_length=1, max_length=200)
    goal: str | None = None
    inputs: Inputs | None = None
    triggers: list[TriggerSpec] = Field(default_factory=list)
    # A definition must contain at least one plan step.
    plan: list[PlanStep] = Field(..., min_length=1)
    # Omitted sections fall back to the defaults of their own models.
    execution: Execution = Field(default_factory=Execution)
    metadata: Metadata = Field(default_factory=Metadata)

View file

@ -0,0 +1,22 @@
"""``Execution`` — automation-wide execution defaults (overridable per step)."""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
from .plan_step import PlanStep
class Execution(BaseModel):
    """Automation-wide execution defaults; all fields have defaults."""

    # Unknown keys are rejected.
    model_config = ConfigDict(extra="forbid")

    # Must be strictly positive.
    timeout_seconds: int = Field(default=600, gt=0, description="Wall-clock cap for the run.")
    # Zero disables retries.
    max_retries: int = Field(default=2, ge=0, description="Per-step retry budget.")
    retry_backoff: Literal["exponential", "linear", "none"] = "exponential"
    concurrency: Literal["drop_if_running", "queue", "always"] = "drop_if_running"
    on_failure: list[PlanStep] = Field(
        default_factory=list,
        description="Steps run when the main plan fails after retries.",
    )

View file

@ -0,0 +1,21 @@
"""``Inputs`` — JSON Schema for inputs an automation accepts at fire time."""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
class Inputs(BaseModel):
    """Wrapper around the JSON Schema describing an automation's accepted inputs."""

    # The attribute is ``schema_`` in Python but travels as ``schema``:
    # ``populate_by_name`` accepts either spelling on input and
    # ``serialize_by_alias`` emits the alias on output.
    model_config = ConfigDict(
        extra="forbid",
        populate_by_name=True,
        serialize_by_alias=True,
    )

    schema_: dict[str, Any] = Field(
        ...,
        alias="schema",
        description="JSON Schema (draft 2020-12) for accepted inputs.",
    )

View file

@ -0,0 +1,11 @@
"""``Metadata`` — free-form metadata on a definition. Extra keys allowed."""
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, Field
class Metadata(BaseModel):
    """Free-form metadata attached to a definition."""

    # Unlike the other definition models, unknown keys are kept, not rejected.
    model_config = ConfigDict(extra="allow")

    tags: list[str] = Field(default_factory=list)

View file

@ -0,0 +1,28 @@
"""``PlanStep`` — one step in the sequential plan."""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
class PlanStep(BaseModel):
    """One step of an automation plan: which action to run, when, and with what params."""

    # Unknown keys are rejected.
    model_config = ConfigDict(extra="forbid")

    step_id: str = Field(..., min_length=1, description="Unique within the plan.")
    action: str = Field(..., min_length=1, description="Action type; resolved via registry.")
    when: str | None = Field(
        default=None,
        description="Optional predicate; step is skipped when falsy.",
    )
    params: dict[str, Any] = Field(
        default_factory=dict,
        description="Action-type-specific params; rendered at execute time.",
    )
    output_as: str | None = Field(
        default=None,
        description="Bind step output under this name. Defaults to step_id.",
    )
    # ``None`` on either field means "not set on this step".
    max_retries: int | None = Field(default=None, ge=0)
    timeout_seconds: int | None = Field(default=None, gt=0)

View file

@ -0,0 +1,17 @@
"""``TriggerSpec`` — one entry in the definition's ``triggers[]`` array."""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
class TriggerSpec(BaseModel):
    """One trigger declaration inside a definition's ``triggers`` list."""

    # Unknown keys are rejected.
    model_config = ConfigDict(extra="forbid")

    # A plain string here, not an enum.
    type: str = Field(..., min_length=1, description="Trigger type; resolved via registry.")
    params: dict[str, Any] = Field(
        default_factory=dict,
        description="Type-specific params; validated against the trigger's schema.",
    )

View file

@ -0,0 +1,16 @@
"""Services for the automations HTTP layer (one service per resource)."""
from __future__ import annotations
from .automation import AutomationService, get_automation_service
from .run import RunService, get_run_service
from .trigger import TriggerService, get_trigger_service
__all__ = [
"AutomationService",
"RunService",
"TriggerService",
"get_automation_service",
"get_run_service",
"get_trigger_service",
]

View file

@ -0,0 +1,172 @@
"""``AutomationService`` — orchestration for the ``Automation`` resource."""
from __future__ import annotations
from datetime import UTC, datetime
from fastapi import Depends, HTTPException
from pydantic import ValidationError
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.automations.schemas.api import (
AutomationCreate,
AutomationUpdate,
TriggerCreate,
)
from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.triggers import get_trigger
from app.automations.triggers.schedule import compute_next_fire_at
from app.db import Permission, User, get_async_session
from app.users import current_active_user
from app.utils.rbac import check_permission
class AutomationService:
    """Lifecycle of the ``Automation`` resource.

    Each public method checks the caller's permission on the owning search
    space through ``check_permission`` and raises ``HTTPException`` for
    missing rows (404) or invalid payloads (422).
    """

    # Fields of ``AutomationUpdate`` that map to NOT NULL columns. The schema
    # types them as optional so they can be omitted, which also lets an
    # explicit JSON ``null`` through validation.
    _NON_NULLABLE_PATCH_FIELDS = ("name", "status", "definition")

    def __init__(self, *, session: AsyncSession, user: User) -> None:
        self.session = session
        self.user = user

    async def create(self, payload: AutomationCreate) -> Automation:
        """Create an automation and its initial triggers in one transaction.

        Returns the freshly persisted automation reloaded with its triggers.
        Raises ``HTTPException`` 422 when a trigger spec is invalid.
        """
        await self._authorize(payload.search_space_id, Permission.AUTOMATIONS_CREATE.value)
        automation = Automation(
            search_space_id=payload.search_space_id,
            created_by_user_id=self.user.id,
            name=payload.name,
            description=payload.description,
            definition=payload.definition.model_dump(mode="json", by_alias=True),
            version=1,
        )
        for spec in payload.triggers:
            automation.triggers.append(_build_trigger(spec))
        self.session.add(automation)
        await self.session.commit()
        return await self._get_with_triggers_or_raise(automation.id)

    async def list(
        self,
        *,
        search_space_id: int,
        limit: int,
        offset: int,
    ) -> tuple[list[Automation], int]:
        """Return a page of automations (newest first) and the total count."""
        await self._authorize(search_space_id, Permission.AUTOMATIONS_READ.value)
        base = select(Automation).where(Automation.search_space_id == search_space_id)
        total = await self.session.scalar(
            select(func.count()).select_from(base.subquery())
        )
        rows = (
            await self.session.execute(
                base.order_by(Automation.created_at.desc()).limit(limit).offset(offset)
            )
        ).scalars().all()
        return list(rows), int(total or 0)

    async def get(self, automation_id: int) -> Automation:
        """Get an automation with its triggers loaded."""
        # The row is loaded first because the permission check needs its
        # search space id.
        automation = await self._get_with_triggers_or_raise(automation_id)
        await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_READ.value)
        return automation

    async def update(self, automation_id: int, patch: AutomationUpdate) -> Automation:
        """Patch fields. Bumps ``version`` when ``definition`` changes.

        Only fields the client actually sent are applied. ``description`` may
        be cleared with an explicit null; ``name``, ``status`` and
        ``definition`` may not, and an explicit null for any of them is
        rejected with ``HTTPException`` 422 before anything is modified.
        """
        automation = await self._get_with_triggers_or_raise(automation_id)
        await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_UPDATE.value)
        data = patch.model_dump(exclude_unset=True)
        # Validate up front so a rejected request never leaves the ORM object
        # half-modified.
        for field in self._NON_NULLABLE_PATCH_FIELDS:
            if field in data and data[field] is None:
                raise HTTPException(status_code=422, detail=f"{field} cannot be null")
        if "name" in data:
            automation.name = data["name"]
        if "description" in data:
            automation.description = data["description"]
        if "status" in data:
            automation.status = data["status"]
        if "definition" in data:
            # Dump from the model (not ``data``) to get the JSON-mode,
            # by-alias form that is stored on create.
            automation.definition = patch.definition.model_dump(mode="json", by_alias=True)
            automation.version += 1
        await self.session.commit()
        return await self._get_with_triggers_or_raise(automation_id)

    async def delete(self, automation_id: int) -> None:
        """Delete an automation; FK cascades remove triggers and runs."""
        automation = await self._get_or_raise(automation_id)
        await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_DELETE.value)
        await self.session.delete(automation)
        await self.session.commit()

    async def _get_or_raise(self, automation_id: int) -> Automation:
        """Load an automation by id or raise ``HTTPException`` 404."""
        automation = await self.session.get(Automation, automation_id)
        if automation is None:
            raise HTTPException(
                status_code=404, detail=f"automation {automation_id} not found"
            )
        return automation

    async def _get_with_triggers_or_raise(self, automation_id: int) -> Automation:
        """Load an automation with ``triggers`` eagerly loaded, or raise 404."""
        stmt = (
            select(Automation)
            .where(Automation.id == automation_id)
            .options(selectinload(Automation.triggers))
        )
        automation = (await self.session.execute(stmt)).scalar_one_or_none()
        if automation is None:
            raise HTTPException(
                status_code=404, detail=f"automation {automation_id} not found"
            )
        return automation

    async def _authorize(self, search_space_id: int, permission: str) -> None:
        """Delegate the RBAC check for ``permission`` on the search space.

        The denial message uses the part of ``permission`` after the first
        ``:`` as the verb.
        """
        await check_permission(
            self.session,
            self.user,
            search_space_id,
            permission,
            f"You don't have permission to {permission.split(':')[1]} automations in this search space",
        )
def _build_trigger(spec: TriggerCreate) -> AutomationTrigger:
    """Turn a ``TriggerCreate`` payload into an unsaved ``AutomationTrigger`` row.

    The params are validated through the params model registered for the
    trigger type; an unknown type or invalid params raise ``HTTPException``
    422. Enabled schedule triggers get their first ``next_fire_at`` computed
    from the current UTC time; every other trigger gets ``None``.
    """
    type_name = spec.type.value
    registered = get_trigger(type_name)
    if registered is None:
        raise HTTPException(status_code=422, detail=f"unknown trigger type {type_name!r}")

    try:
        validated = registered.params_model.model_validate(spec.params)
    except ValidationError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    params = validated.model_dump(mode="json")

    is_live_schedule = spec.type == TriggerType.SCHEDULE and spec.enabled
    first_fire = (
        compute_next_fire_at(params["cron"], params["timezone"], after=datetime.now(UTC))
        if is_live_schedule
        else None
    )
    return AutomationTrigger(
        type=spec.type,
        params=params,
        static_inputs=spec.static_inputs,
        enabled=spec.enabled,
        next_fire_at=first_fire,
    )
def get_automation_service(
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
) -> AutomationService:
    """FastAPI dependency: build an ``AutomationService`` from the injected session and user."""
    return AutomationService(session=session, user=user)

View file

@ -0,0 +1,72 @@
"""``RunService`` — read-only access to automation run history."""
from __future__ import annotations
from fastapi import Depends, HTTPException
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.run import AutomationRun
from app.db import Permission, User, get_async_session
from app.users import current_active_user
from app.utils.rbac import check_permission
class RunService:
    """Read-only view over ``AutomationRun`` history, guarded by the search-space permission check."""

    def __init__(self, *, session: AsyncSession, user: User) -> None:
        self.session, self.user = session, user

    async def list(
        self,
        *,
        automation_id: int,
        limit: int,
        offset: int,
    ) -> tuple[list[AutomationRun], int]:
        """Return one page of an automation's runs (newest first) plus the total count."""
        await self._authorize(automation_id, Permission.AUTOMATIONS_READ.value)

        runs_of_automation = select(AutomationRun).where(
            AutomationRun.automation_id == automation_id
        )
        count_query = select(func.count()).select_from(runs_of_automation.subquery())
        page_query = (
            runs_of_automation.order_by(AutomationRun.created_at.desc())
            .limit(limit)
            .offset(offset)
        )

        total = await self.session.scalar(count_query)
        page = await self.session.execute(page_query)
        return [*page.scalars().all()], int(total or 0)

    async def get(self, *, automation_id: int, run_id: int) -> AutomationRun:
        """Return one run, raising 404 if it is missing or belongs to another automation."""
        await self._authorize(automation_id, Permission.AUTOMATIONS_READ.value)
        run = await self.session.get(AutomationRun, run_id)
        belongs_here = run is not None and run.automation_id == automation_id
        if not belongs_here:
            raise HTTPException(status_code=404, detail=f"run {run_id} not found")
        return run

    async def _authorize(self, automation_id: int, permission: str) -> Automation:
        """Resolve the automation (404 if absent) and check ``permission`` on its search space."""
        automation = await self.session.get(Automation, automation_id)
        if automation is None:
            raise HTTPException(
                status_code=404, detail=f"automation {automation_id} not found"
            )
        # The verb shown to the user is the part after the first ":".
        verb = permission.split(":")[1]
        denial = f"You don't have permission to {verb} automations in this search space"
        await check_permission(
            self.session,
            self.user,
            automation.search_space_id,
            permission,
            denial,
        )
        return automation
def get_run_service(
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
) -> RunService:
    """FastAPI dependency: build a ``RunService`` from the injected session and user."""
    return RunService(session=session, user=user)

View file

@ -0,0 +1,143 @@
"""``TriggerService`` — lifecycle of triggers attached to an automation."""
from __future__ import annotations
from datetime import UTC, datetime
from fastapi import Depends, HTTPException
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.triggers import get_trigger
from app.automations.triggers.schedule import compute_next_fire_at
from app.db import Permission, User, get_async_session
from app.users import current_active_user
from app.utils.rbac import check_permission
class TriggerService:
    """Lifecycle of the ``AutomationTrigger`` sub-resource.

    All mutations require the ``AUTOMATIONS_UPDATE`` permission on the search
    space of the parent automation.
    """

    # ``TriggerUpdate`` types these as optional so they can be omitted, which
    # also lets an explicit JSON ``null`` through validation; the trigger
    # itself must always carry a value for both.
    _NON_NULLABLE_PATCH_FIELDS = ("enabled", "static_inputs")

    def __init__(self, *, session: AsyncSession, user: User) -> None:
        self.session = session
        self.user = user

    async def add(
        self, *, automation_id: int, payload: TriggerCreate
    ) -> AutomationTrigger:
        """Validate ``payload`` and attach a new trigger to the automation.

        Raises ``HTTPException`` 404 when the automation does not exist and
        422 when the trigger type is unknown or its params are invalid.
        """
        automation = await self._authorize_automation(
            automation_id, Permission.AUTOMATIONS_UPDATE.value
        )
        validated_params = _validate_params(payload.type, payload.params)
        trigger = AutomationTrigger(
            automation_id=automation.id,
            type=payload.type,
            params=validated_params,
            static_inputs=payload.static_inputs,
            enabled=payload.enabled,
            next_fire_at=_initial_next_fire(payload.type, validated_params, payload.enabled),
        )
        self.session.add(trigger)
        await self.session.commit()
        await self.session.refresh(trigger)
        return trigger

    async def update(
        self,
        *,
        automation_id: int,
        trigger_id: int,
        patch: TriggerUpdate,
    ) -> AutomationTrigger:
        """Apply the fields the client sent to an existing trigger.

        An explicit null for ``enabled`` or ``static_inputs`` is rejected with
        ``HTTPException`` 422 before anything is modified. For schedule
        triggers ``next_fire_at`` is recomputed only when ``params`` or
        ``enabled`` is part of the patch, so editing ``static_inputs`` alone
        never discards a fire that is already due.
        """
        await self._authorize_automation(automation_id, Permission.AUTOMATIONS_UPDATE.value)
        trigger = await self._get_trigger_or_raise(automation_id, trigger_id)
        data = patch.model_dump(exclude_unset=True)
        # Validate up front so a rejected request never leaves the ORM object
        # half-modified.
        for field in self._NON_NULLABLE_PATCH_FIELDS:
            if field in data and data[field] is None:
                raise HTTPException(status_code=422, detail=f"{field} cannot be null")
        if "params" in data:
            trigger.params = _validate_params(trigger.type, data["params"])
        if "static_inputs" in data:
            trigger.static_inputs = data["static_inputs"]
        if "enabled" in data:
            trigger.enabled = data["enabled"]
        # Recompute next_fire_at when schedule timing changed or the trigger
        # was toggled; a disabled trigger ends up with ``None``.
        timing_touched = "params" in data or "enabled" in data
        if trigger.type == TriggerType.SCHEDULE and timing_touched:
            trigger.next_fire_at = _initial_next_fire(
                trigger.type, trigger.params, trigger.enabled
            )
        await self.session.commit()
        await self.session.refresh(trigger)
        return trigger

    async def remove(self, *, automation_id: int, trigger_id: int) -> None:
        """Delete one trigger of the automation (404 if either is missing)."""
        await self._authorize_automation(automation_id, Permission.AUTOMATIONS_UPDATE.value)
        trigger = await self._get_trigger_or_raise(automation_id, trigger_id)
        await self.session.delete(trigger)
        await self.session.commit()

    async def _authorize_automation(
        self, automation_id: int, permission: str
    ) -> Automation:
        """Resolve the automation (404 if absent) and check ``permission`` on its search space."""
        automation = await self.session.get(Automation, automation_id)
        if automation is None:
            raise HTTPException(
                status_code=404, detail=f"automation {automation_id} not found"
            )
        await check_permission(
            self.session,
            self.user,
            automation.search_space_id,
            permission,
            f"You don't have permission to {permission.split(':')[1]} automations in this search space",
        )
        return automation

    async def _get_trigger_or_raise(
        self, automation_id: int, trigger_id: int
    ) -> AutomationTrigger:
        """Load a trigger, raising 404 if it is missing or belongs to another automation."""
        trigger = await self.session.get(AutomationTrigger, trigger_id)
        if trigger is None or trigger.automation_id != automation_id:
            raise HTTPException(
                status_code=404, detail=f"trigger {trigger_id} not found"
            )
        return trigger
def _validate_params(trigger_type: TriggerType, raw: dict) -> dict:
    """Validate ``raw`` against the params model registered for ``trigger_type``.

    Returns the validated params dumped in JSON mode. Raises
    ``HTTPException`` 422 for an unregistered type or invalid params.
    """
    type_name = trigger_type.value
    registered = get_trigger(type_name)
    if registered is None:
        raise HTTPException(
            status_code=422, detail=f"unknown trigger type {type_name!r}"
        )

    try:
        checked = registered.params_model.model_validate(raw)
    except ValidationError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    return checked.model_dump(mode="json")
def _initial_next_fire(
    trigger_type: TriggerType, params: dict, enabled: bool
) -> datetime | None:
    """Return the next fire time for an enabled schedule trigger, else ``None``.

    The time is computed from ``params["cron"]`` and ``params["timezone"]``,
    starting at the current UTC time.
    """
    is_live_schedule = trigger_type == TriggerType.SCHEDULE and enabled
    if is_live_schedule:
        return compute_next_fire_at(
            params["cron"], params["timezone"], after=datetime.now(UTC)
        )
    return None
def get_trigger_service(
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
) -> TriggerService:
    """FastAPI dependency: build a ``TriggerService`` from the injected session and user."""
    return TriggerService(session=session, user=user)

View file

@ -0,0 +1,3 @@
"""Celery task wrappers for the automation runtime."""
from __future__ import annotations

View file

@ -0,0 +1,33 @@
"""Celery task that runs one automation. Thin wrapper over ``runtime.executor``."""
from __future__ import annotations
import logging
from app.automations.runtime import execute_run
from app.celery_app import celery_app
from app.tasks.celery_tasks import (
get_celery_session_maker,
run_async_celery_task,
)
logger = logging.getLogger(__name__)
# Name under which the run-execution task is registered with Celery.
TASK_NAME = "automation_run_execute"
@celery_app.task(name=TASK_NAME, bind=True)
def automation_run_execute(self, run_id: int) -> None:  # noqa: ARG001 — Celery bind
    """Execute one ``AutomationRun``. Idempotent: terminal runs no-op.

    Synchronous Celery entry point: hands the async ``_impl`` coroutine to
    ``run_async_celery_task`` and returns whatever that helper returns.

    NOTE(review): the idempotence guarantee is presumably enforced inside
    ``execute_run``; nothing in this wrapper checks the run's status — confirm.
    """
    return run_async_celery_task(lambda: _impl(run_id))
async def _impl(run_id: int) -> None:
    """Open a Celery-scoped session and execute the run inside it.

    Any exception from ``execute_run`` is logged with its traceback, the
    session is rolled back, and the exception is re-raised.
    """
    open_session = get_celery_session_maker()
    async with open_session() as db:
        try:
            await execute_run(db, run_id)
        except Exception:
            logger.exception("automation_run %d failed unexpectedly", run_id)
            await db.rollback()
            raise

View file

@ -0,0 +1,187 @@
"""Celery Beat tick that fires due ``schedule`` triggers.
Runs every minute. Each tick performs two passes:
1. **Self-heal**: enabled schedule triggers with NULL ``next_fire_at`` get
it computed from their ``cron`` + ``timezone`` (e.g. fresh inserts or
rows restored from backup).
2. **Claim & fire**: due rows are locked with ``FOR UPDATE SKIP LOCKED``,
their ``next_fire_at`` is advanced and ``last_fired_at`` is set, and
``dispatch_schedule_run`` is invoked for each. Dispatch errors are
logged; a missed fire stays missed (matches K8s CronJob / Airflow
``catchup=False`` semantics).
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.triggers.schedule import (
InvalidCronError,
compute_next_fire_at,
dispatch_schedule_run,
)
from app.celery_app import celery_app
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
# Name under which the tick task is registered with Celery.
TASK_NAME = "automation_schedule_tick"
# Cap rows touched per tick so a backlog of due triggers can't starve the
# worker; remaining rows fire on the next tick. The same cap is applied to
# both the self-heal query and the claim query.
_TICK_BATCH = 200
@dataclass(frozen=True, slots=True)
class _Claim:
    """Per-trigger fire context captured before row state is mutated."""

    # Primary key of the claimed trigger; used to reload it after the commit.
    trigger_id: int
    # The ``next_fire_at`` value the trigger had when it was claimed.
    scheduled_for: datetime
    # The ``last_fired_at`` value before this claim overwrote it.
    previous_last_fired_at: datetime | None
@celery_app.task(name=TASK_NAME)
def automation_schedule_tick() -> None:
    """Tick once: self-heal NULL next_fire_at, claim due rows, fire each.

    Synchronous Celery entry point: hands the async ``_tick`` coroutine
    function to ``run_async_celery_task``.
    """
    return run_async_celery_task(_tick)
async def _tick() -> None:
    """Run one scheduler pass inside a single Celery-scoped session.

    The same ``now`` timestamp is used for the self-heal pass, for claiming
    due triggers, and as the ``fired_at`` of every dispatched run.
    """
    open_session = get_celery_session_maker()
    async with open_session() as db:
        tick_time = datetime.now(UTC)
        await _self_heal_null_next_fire(db, now=tick_time)
        # An empty claim list simply means nothing is dispatched this tick.
        for due in await _claim_due_triggers(db, now=tick_time):
            await _fire_one(db, claim=due, fired_at=tick_time)
async def _self_heal_null_next_fire(session: AsyncSession, *, now: datetime) -> None:
    """Give enabled schedule triggers without a ``next_fire_at`` a fresh one.

    At most ``_TICK_BATCH`` rows are handled per call. A trigger whose
    ``cron`` / ``timezone`` params are missing or invalid is disabled and a
    warning is logged. Commits only when at least one row was selected.
    """
    needs_backfill = (
        AutomationTrigger.type == TriggerType.SCHEDULE,
        AutomationTrigger.enabled.is_(True),
        AutomationTrigger.next_fire_at.is_(None),
    )
    query = select(AutomationTrigger).where(*needs_backfill).limit(_TICK_BATCH)
    orphans = (await session.execute(query)).scalars().all()
    if not orphans:
        return

    for orphan in orphans:
        try:
            orphan.next_fire_at = compute_next_fire_at(
                orphan.params["cron"],
                orphan.params["timezone"],
                after=now,
            )
        except (InvalidCronError, KeyError, TypeError) as exc:
            logger.warning(
                "automation_trigger %d has invalid schedule params, disabling: %s",
                orphan.id,
                exc,
            )
            orphan.enabled = False
    await session.commit()
async def _claim_due_triggers(
    session: AsyncSession, *, now: datetime
) -> list[_Claim]:
    """Lock due schedule triggers, advance them, and return what to fire.

    Up to ``_TICK_BATCH`` enabled schedule triggers whose ``next_fire_at`` is
    at or before ``now`` are selected oldest-first with
    ``FOR UPDATE SKIP LOCKED``. Each gets a new ``next_fire_at`` computed
    after ``now`` and ``last_fired_at = now``. Triggers with missing or
    invalid schedule params are disabled instead and produce no claim. The
    changes are committed before the claims are returned.
    """
    is_due = (
        AutomationTrigger.type == TriggerType.SCHEDULE,
        AutomationTrigger.enabled.is_(True),
        AutomationTrigger.next_fire_at.isnot(None),
        AutomationTrigger.next_fire_at <= now,
    )
    query = (
        select(AutomationTrigger)
        .where(*is_due)
        .order_by(AutomationTrigger.next_fire_at)
        .limit(_TICK_BATCH)
        .with_for_update(skip_locked=True)
    )
    due_rows = (await session.execute(query)).scalars().all()
    if not due_rows:
        return []

    claimed: list[_Claim] = []
    for row in due_rows:
        # Capture the pre-advance values first; the row is mutated below.
        snapshot = _Claim(
            trigger_id=row.id,
            scheduled_for=row.next_fire_at,
            previous_last_fired_at=row.last_fired_at,
        )
        try:
            upcoming = compute_next_fire_at(
                row.params["cron"],
                row.params["timezone"],
                after=now,
            )
        except (InvalidCronError, KeyError, TypeError) as exc:
            logger.warning(
                "automation_trigger %d has invalid schedule params, disabling: %s",
                row.id,
                exc,
            )
            row.enabled = False
            continue
        row.next_fire_at = upcoming
        row.last_fired_at = now
        claimed.append(snapshot)

    await session.commit()
    return claimed
async def _fire_one(
    session: AsyncSession, *, claim: _Claim, fired_at: datetime
) -> None:
    """Dispatch a run for one claimed trigger.

    The trigger is re-read by id; if it no longer exists nothing happens.
    A dispatch failure is logged with its traceback and the session is
    rolled back — the error is not propagated.
    """
    trigger_id = claim.trigger_id
    trigger = await session.get(AutomationTrigger, trigger_id)
    if trigger is None:
        return

    dispatch_args = dict(
        session=session,
        trigger=trigger,
        fired_at=fired_at,
        scheduled_for=claim.scheduled_for,
        previous_last_fired_at=claim.previous_last_fired_at,
    )
    try:
        run = await dispatch_schedule_run(**dispatch_args)
        logger.info(
            "scheduled fire: trigger=%d automation=%d run=%d",
            trigger_id,
            trigger.automation_id,
            run.id,
        )
    except Exception:
        logger.exception(
            "scheduled fire failed for trigger %d (next attempt at next match)",
            trigger_id,
        )
        await session.rollback()

View file

@ -0,0 +1,13 @@
"""Sandboxed template engine for automation definitions."""
from __future__ import annotations
from .context import build_run_context
from .render import evaluate_predicate, render_template, render_value
__all__ = [
"build_run_context",
"evaluate_predicate",
"render_template",
"render_value",
]

View file

@ -0,0 +1,31 @@
"""Filter and test names admitted into the sandboxed environment."""
from __future__ import annotations
# Names of the filters templates are allowed to use. Any filter of the
# environment whose name is not listed here is dropped when the sandboxed
# environment is built.
ALLOWED_FILTERS: tuple[str, ...] = (
    "default",
    "first",
    "join",
    "last",
    "length",
    "lower",
    "replace",
    "reverse",
    "sort",
    "tojson",
    "trim",
    "truncate",
    "upper",
    # The two names below are bound to this package's own implementations
    # (``filter_date`` / ``filter_slugify``) by the environment module.
    "date",
    "slugify",
)
# Names of the ``is <test>`` tests templates are allowed to use. Any test of
# the environment whose name is not listed here is dropped when the sandboxed
# environment is built.
ALLOWED_TESTS: tuple[str, ...] = (
    "defined",
    "none",
    "number",
    "string",
    "mapping",
    "sequence",
    "boolean",
)

View file

@ -0,0 +1,41 @@
"""Builder for the ``{run, inputs, steps}`` namespace exposed to every template."""
from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime
from typing import Any
def build_run_context(
*,
run_id: int,
automation_id: int,
automation_name: str | None,
automation_version: int | None,
search_space_id: int | None,
creator_id: Any,
trigger_id: int | None,
trigger_type: str | None,
started_at: datetime | None,
attempt: int,
inputs: Mapping[str, Any],
step_outputs: Mapping[str, Any],
) -> dict[str, Any]:
"""Build the ``{run, inputs, steps}`` namespace exposed to every template."""
return {
"run": {
"id": run_id,
"automation_id": automation_id,
"automation_name": automation_name,
"automation_version": automation_version,
"search_space_id": search_space_id,
"creator_id": creator_id,
"trigger_id": trigger_id,
"trigger_type": trigger_type,
"started_at": started_at,
"attempt": attempt,
},
"inputs": dict(inputs),
"steps": dict(step_outputs),
}

View file

@ -0,0 +1,43 @@
"""SandboxedEnvironment construction with the audited filter/test allowlist."""
from __future__ import annotations
import json
from datetime import datetime
from typing import Any
from jinja2 import StrictUndefined
from jinja2.sandbox import SandboxedEnvironment
from .allowlist import ALLOWED_FILTERS, ALLOWED_TESTS
from .filters import filter_date, filter_slugify
def _finalize(value: Any) -> Any:
"""Stringify common non-string values at output sites."""
if value is None:
return ""
if isinstance(value, str):
return value
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, list | dict):
return json.dumps(value, ensure_ascii=False, default=str)
return value
def _build_env() -> SandboxedEnvironment:
    """Create the sandboxed environment restricted to the audited allowlist.

    Undefined variables raise (``StrictUndefined``), output values pass
    through ``_finalize``, all globals are removed, and only allowlisted
    filters/tests stay registered alongside the custom ``date`` and
    ``slugify`` filters.
    """
    sandbox = SandboxedEnvironment(
        autoescape=False,
        undefined=StrictUndefined,
        finalize=_finalize,
    )
    sandbox.globals.clear()

    # Keep only the allowlisted built-in filters, then add the custom ones.
    kept_filters = {
        name: func
        for name, func in sandbox.filters.items()
        if name in ALLOWED_FILTERS
    }
    kept_filters["date"] = filter_date
    kept_filters["slugify"] = filter_slugify
    sandbox.filters = kept_filters

    # Same pruning for tests.
    sandbox.tests = {
        name: func for name, func in sandbox.tests.items() if name in ALLOWED_TESTS
    }
    return sandbox


# Shared environment instance, built once when this module is imported.
ENV: SandboxedEnvironment = _build_env()

View file

@ -0,0 +1,29 @@
"""Custom Jinja filters registered into the sandboxed environment."""
from __future__ import annotations
import re
from typing import Any
def filter_date(value: Any, fmt: str = "%Y-%m-%d") -> str:
    """Render a datetime-like ``value`` through its ``strftime`` method.

    ``None`` yields an empty string and strings are returned untouched.

    Raises:
        ValueError: if ``value`` is neither ``None``, a string, nor an
            object exposing ``strftime``.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if not hasattr(value, "strftime"):
        raise ValueError(
            f"date filter requires datetime-like, got {type(value).__name__}"
        )
    return value.strftime(fmt)
# One or more characters outside ``a-z0-9``. Because the pattern is greedy and
# hyphens are themselves non-alphanumeric, a single substitution already
# collapses every run (including runs of hyphens) into one separator.
_SLUG_NONALNUM = re.compile(r"[^a-z0-9]+")


def filter_slugify(value: Any) -> str:
    """Return a lowercase, hyphen-separated slug of ``str(value)``.

    Every run of characters outside ``a-z0-9`` becomes a single hyphen, and
    leading/trailing hyphens are removed. The result may be empty.
    """
    lowered = str(value).lower()
    # No separate "collapse repeated hyphens" pass is needed: the substitution
    # can never emit two adjacent hyphens.
    return _SLUG_NONALNUM.sub("-", lowered).strip("-")

View file

@ -0,0 +1,29 @@
"""Render templates and evaluate predicates against the sandboxed environment."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from .environment import ENV
def render_template(template: str, context: Mapping[str, Any]) -> str:
    """Compile ``template`` in the sandboxed environment and render it.

    The entries of ``context`` are supplied to the template as its top-level
    variables.
    """
    compiled = ENV.from_string(template)
    return compiled.render(**context)
def evaluate_predicate(expression: str, context: Mapping[str, Any]) -> bool:
    """Evaluate a bare Jinja expression against ``context`` as a boolean.

    ``expression`` is an expression, not a template body; its result is
    coerced with ``bool``.
    """
    predicate = ENV.compile_expression(expression)
    outcome = predicate(**context)
    return bool(outcome)
def render_value(value: Any, context: Mapping[str, Any]) -> Any:
    """Render every string found in a nested dict/list structure.

    Strings are rendered as templates; dict values and list items are
    processed recursively (dict keys are left as they are). Any other value
    is returned unchanged.
    """
    if isinstance(value, list):
        return [render_value(item, context) for item in value]
    if isinstance(value, dict):
        return {key: render_value(item, context) for key, item in value.items()}
    if isinstance(value, str):
        return render_template(value, context)
    return value

View file

@ -0,0 +1,20 @@
"""Triggers domain: registry surface + built-in trigger packages.
Each trigger lives in its own subpackage (``schedule/``, ...) and
self-registers at import time via its ``definition`` module.
"""
from __future__ import annotations
from .store import all_triggers, get_trigger, register_trigger
from .types import TriggerDefinition
__all__ = [
"TriggerDefinition",
"all_triggers",
"get_trigger",
"register_trigger",
]
# Built-in triggers self-register at import time.
from . import schedule # noqa: E402, F401

View file

@ -0,0 +1,18 @@
"""``schedule`` trigger: fired on a cron schedule in a given timezone."""
from __future__ import annotations
from .cron import InvalidCronError, compute_next_fire_at, validate_cron
from .dispatch import dispatch_schedule_run
from .params import ScheduleTriggerParams
__all__ = [
"InvalidCronError",
"ScheduleTriggerParams",
"compute_next_fire_at",
"dispatch_schedule_run",
"validate_cron",
]
# Side-effect: register on the triggers store.
from . import definition # noqa: E402, F401

View file

@ -0,0 +1,37 @@
"""Cron math for the ``schedule`` trigger: validate + advance ``next_fire_at``."""
from __future__ import annotations
from datetime import UTC, datetime
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from croniter import CroniterBadCronError, croniter
class InvalidCronError(ValueError):
    """Raised when a cron expression or timezone fails validation.

    Subclasses ``ValueError``, so handlers written for ``ValueError`` also
    catch it.
    """
def validate_cron(cron: str, timezone: str) -> None:
    """Check that ``timezone`` and ``cron`` are both usable.

    Args:
        cron: Cron expression, validated by constructing a ``croniter``.
        timezone: IANA timezone key, validated by constructing a ``ZoneInfo``.

    Raises:
        InvalidCronError: if the timezone cannot be loaded or the cron
            expression is rejected. The underlying error is chained.
    """
    try:
        ZoneInfo(timezone)
    except (ZoneInfoNotFoundError, ValueError, OSError) as exc:
        # Besides ZoneInfoNotFoundError, ZoneInfo raises a plain ValueError
        # for malformed keys (empty, absolute, or non-normalized paths) and
        # may surface an OSError for keys that name a directory. All of them
        # mean "unusable timezone" for our callers.
        raise InvalidCronError(f"unknown timezone {timezone!r}") from exc
    try:
        croniter(cron)
    except (CroniterBadCronError, ValueError) as exc:
        raise InvalidCronError(f"invalid cron {cron!r}: {exc}") from exc
def compute_next_fire_at(cron: str, timezone: str, *, after: datetime) -> datetime:
    """Compute the first ``cron`` match in ``timezone`` that follows ``after``.

    A naive ``after`` is treated as UTC. The cron expression is evaluated
    against ``after`` expressed in ``timezone``, and the match is converted
    back to UTC before being returned.
    """
    zone = ZoneInfo(timezone)
    # Treat naive datetimes as UTC before shifting into the target zone.
    aware_after = after if after.tzinfo else after.replace(tzinfo=UTC)
    local_after = aware_after.astimezone(zone)
    schedule = croniter(cron, local_after)
    nxt: datetime = schedule.get_next(datetime)
    return nxt.astimezone(UTC)

View file

@ -0,0 +1,15 @@
"""``schedule`` ``TriggerDefinition`` registration."""
from __future__ import annotations
from ..store import register_trigger
from ..types import TriggerDefinition
from .params import ScheduleTriggerParams
# Declarative definition of the ``schedule`` trigger type; its parameters are
# validated by ``ScheduleTriggerParams``.
SCHEDULE_TRIGGER = TriggerDefinition(
    type="schedule",
    description="Fire on a cron schedule in a given timezone.",
    params_model=ScheduleTriggerParams,
)
# Import-time side effect: hand the definition to the trigger store.
register_trigger(SCHEDULE_TRIGGER)

View file

@ -0,0 +1,67 @@
"""Schedule dispatch adapter: load + guard, then call generic dispatch."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.dispatch import DispatchError, dispatch_run
from app.automations.persistence.enums.automation_status import AutomationStatus
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.run import AutomationRun
from app.automations.persistence.models.trigger import AutomationTrigger
async def dispatch_schedule_run(
    *,
    session: AsyncSession,
    trigger: AutomationTrigger,
    fired_at: datetime,
    scheduled_for: datetime,
    previous_last_fired_at: datetime | None,
) -> AutomationRun:
    """Dispatch a single scheduled run for ``trigger``.

    The trigger's automation is loaded and must exist and be active. The run
    receives these runtime inputs (ISO-8601 strings):

    - ``fired_at``: actual fire time.
    - ``scheduled_for``: cron-derived target time for this fire.
    - ``last_fired_at``: fire time of the previous run, or ``None`` on the
      first fire.

    Selecting due triggers and advancing ``next_fire_at`` / ``last_fired_at``
    is the caller's (the schedule tick's) job and must happen before this
    call.

    Raises:
        DispatchError: if the automation is missing or not active.
    """
    automation_id = trigger.automation_id
    automation = await _load_automation(session, automation_id)

    # Guards: refuse to dispatch for a missing or non-active automation.
    if automation is None:
        raise DispatchError(
            f"automation {automation_id} not found for trigger {trigger.id}"
        )
    if automation.status != AutomationStatus.ACTIVE:
        raise DispatchError(
            f"automation {automation_id} is {automation.status.value}, not active"
        )

    last_fired_iso = None
    if previous_last_fired_at:
        last_fired_iso = previous_last_fired_at.isoformat()

    return await dispatch_run(
        session=session,
        automation=automation,
        trigger=trigger,
        runtime_inputs={
            "fired_at": fired_at.isoformat(),
            "scheduled_for": scheduled_for.isoformat(),
            "last_fired_at": last_fired_iso,
        },
    )
async def _load_automation(
    session: AsyncSession, automation_id: int
) -> Automation | None:
    """Fetch the automation with ``automation_id``, or ``None`` if absent."""
    query = select(Automation).where(Automation.id == automation_id)
    rows = await session.execute(query)
    return rows.scalar_one_or_none()

View file

@ -0,0 +1,22 @@
"""``ScheduleTriggerParams`` — params for the ``schedule`` trigger type."""
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, Field, model_validator
from .cron import InvalidCronError, validate_cron
class ScheduleTriggerParams(BaseModel):
    """Parameters of a ``schedule`` trigger: a cron expression plus timezone.

    Unknown fields are rejected (``extra="forbid"``), and both values are
    checked with ``validate_cron`` once the fields have been parsed.
    """

    model_config = ConfigDict(extra="forbid")
    cron: str = Field(..., description="Five-field cron expression.", examples=["0 9 * * 1-5"])
    timezone: str = Field(..., description="IANA timezone.", examples=["Africa/Kigali"])

    @model_validator(mode="after")
    def _validate(self) -> ScheduleTriggerParams:
        """Validate ``cron`` and ``timezone`` together; return the model."""
        try:
            validate_cron(self.cron, self.timezone)
        except InvalidCronError as exc:
            # Re-raised as a plain ValueError carrying the same message.
            raise ValueError(str(exc)) from exc
        return self

View file

@ -0,0 +1,23 @@
"""In-memory trigger registry. Populated once at process startup."""
from __future__ import annotations
from .types import TriggerDefinition
# Registered trigger definitions, keyed by their ``type`` string.
_REGISTRY: dict[str, TriggerDefinition] = {}


def register_trigger(trigger: TriggerDefinition) -> None:
    """Add ``trigger`` to the registry under its ``type``.

    Raises:
        ValueError: if a trigger with the same type is already registered.
    """
    key = trigger.type
    if key in _REGISTRY:
        raise ValueError(f"Trigger already registered: {key!r}")
    _REGISTRY[key] = trigger
def get_trigger(trigger_type: str) -> TriggerDefinition | None:
    """Look up a registered trigger by type; ``None`` when it is unknown."""
    try:
        return _REGISTRY[trigger_type]
    except KeyError:
        return None
def all_triggers() -> dict[str, TriggerDefinition]:
    """Return a shallow copy of the registry.

    Mutating the returned dict does not affect the registry itself.
    """
    snapshot = _REGISTRY.copy()
    return snapshot

View file

@ -0,0 +1,20 @@
"""``TriggerDefinition`` dataclass. Declarative; firing is the dispatcher's job."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from pydantic import BaseModel
@dataclass(frozen=True, slots=True)
class TriggerDefinition:
    """Immutable, declarative description of one trigger type.

    Holds metadata only; it contains no firing logic.
    """

    # Identifier of the trigger type, e.g. "schedule".
    type: str
    # Human-readable summary of when the trigger fires.
    description: str
    # Pydantic model class describing the trigger's parameters.
    params_model: type[BaseModel]

    @property
    def params_schema(self) -> dict[str, Any]:
        """JSON Schema (draft 2020-12) derived from ``params_model``."""
        return self.params_model.model_json_schema()

View file

@ -188,6 +188,8 @@ celery_app = Celery(
"app.tasks.celery_tasks.document_reindex_tasks", "app.tasks.celery_tasks.document_reindex_tasks",
"app.tasks.celery_tasks.stale_notification_cleanup_task", "app.tasks.celery_tasks.stale_notification_cleanup_task",
"app.tasks.celery_tasks.stripe_reconciliation_task", "app.tasks.celery_tasks.stripe_reconciliation_task",
"app.automations.tasks.execute_run",
"app.automations.tasks.schedule_tick",
], ],
) )
@ -282,4 +284,14 @@ celery_app.conf.beat_schedule = {
"expires": 60, "expires": 60,
}, },
}, },
# Fire due automation schedule triggers. Ticks every minute; per-row cron
# math is precomputed (next_fire_at column) so the tick is an indexed
# lookup, not N cron evaluations.
"automation-schedule-tick": {
"task": "automation_schedule_tick",
"schedule": crontab(minute="*"),
"options": {
"expires": 50,
},
},
} }

View file

@ -439,6 +439,13 @@ class Permission(StrEnum):
PUBLIC_SHARING_CREATE = "public_sharing:create" PUBLIC_SHARING_CREATE = "public_sharing:create"
PUBLIC_SHARING_DELETE = "public_sharing:delete" PUBLIC_SHARING_DELETE = "public_sharing:delete"
# Automations
AUTOMATIONS_CREATE = "automations:create"
AUTOMATIONS_READ = "automations:read"
AUTOMATIONS_UPDATE = "automations:update"
AUTOMATIONS_DELETE = "automations:delete"
AUTOMATIONS_EXECUTE = "automations:execute"
# Full access wildcard # Full access wildcard
FULL_ACCESS = "*" FULL_ACCESS = "*"
@ -494,6 +501,11 @@ DEFAULT_ROLE_PERMISSIONS = {
# Public Sharing (can create and view, no delete) # Public Sharing (can create and view, no delete)
Permission.PUBLIC_SHARING_VIEW.value, Permission.PUBLIC_SHARING_VIEW.value,
Permission.PUBLIC_SHARING_CREATE.value, Permission.PUBLIC_SHARING_CREATE.value,
# Automations (no delete)
Permission.AUTOMATIONS_CREATE.value,
Permission.AUTOMATIONS_READ.value,
Permission.AUTOMATIONS_UPDATE.value,
Permission.AUTOMATIONS_EXECUTE.value,
], ],
"Viewer": [ "Viewer": [
# Documents (read only) # Documents (read only)
@ -525,6 +537,8 @@ DEFAULT_ROLE_PERMISSIONS = {
Permission.SETTINGS_VIEW.value, Permission.SETTINGS_VIEW.value,
# Public Sharing (view only) # Public Sharing (view only)
Permission.PUBLIC_SHARING_VIEW.value, Permission.PUBLIC_SHARING_VIEW.value,
# Automations (read only)
Permission.AUTOMATIONS_READ.value,
], ],
} }
@ -1533,6 +1547,14 @@ class SearchSpace(BaseModel, TimestampMixin):
cascade="all, delete-orphan", cascade="all, delete-orphan",
) )
automations = relationship(
"Automation",
back_populates="search_space",
order_by="Automation.id",
cascade="all, delete-orphan",
passive_deletes=True,
)
# RBAC relationships # RBAC relationships
roles = relationship( roles = relationship(
"SearchSpaceRole", "SearchSpaceRole",
@ -2125,6 +2147,13 @@ if config.AUTH_TYPE == "GOOGLE":
passive_deletes=True, passive_deletes=True,
) )
# Automations created by this user
automations = relationship(
"Automation",
back_populates="created_by",
passive_deletes=True,
)
# Incentive tasks completed by this user # Incentive tasks completed by this user
incentive_tasks = relationship( incentive_tasks = relationship(
"UserIncentiveTask", "UserIncentiveTask",
@ -2257,6 +2286,13 @@ else:
passive_deletes=True, passive_deletes=True,
) )
# Automations created by this user
automations = relationship(
"Automation",
back_populates="created_by",
passive_deletes=True,
)
# Incentive tasks completed by this user # Incentive tasks completed by this user
incentive_tasks = relationship( incentive_tasks = relationship(
"UserIncentiveTask", "UserIncentiveTask",
@ -2560,6 +2596,16 @@ class RefreshToken(Base, TimestampMixin):
return not self.is_expired and not self.is_revoked return not self.is_expired and not self.is_revoked
# Register model packages that live outside this file so their classes
# are present in Base.metadata before configure_mappers() resolves any
# string-based relationship() references.
from app.automations.persistence import ( # noqa: E402, F401
Automation,
AutomationRun,
AutomationTrigger,
)
engine = create_async_engine( engine = create_async_engine(
DATABASE_URL, DATABASE_URL,
pool_size=30, pool_size=30,

View file

@ -7,6 +7,7 @@ from .agent_revert_route import router as agent_revert_router
from .airtable_add_connector_route import ( from .airtable_add_connector_route import (
router as airtable_add_connector_router, router as airtable_add_connector_router,
) )
from app.automations.api import router as automations_router
from .chat_comments_routes import router as chat_comments_router from .chat_comments_routes import router as chat_comments_router
from .circleback_webhook_route import router as circleback_webhook_router from .circleback_webhook_route import router as circleback_webhook_router
from .clickup_add_connector_route import router as clickup_add_connector_router from .clickup_add_connector_route import router as clickup_add_connector_router
@ -119,3 +120,4 @@ router.include_router(youtube_router) # YouTube playlist resolution
router.include_router(prompts_router) router.include_router(prompts_router)
router.include_router(memory_router) # User personal memory (memory.md style) router.include_router(memory_router) # User personal memory (memory.md style)
router.include_router(team_memory_router) # Search-space team memory router.include_router(team_memory_router) # Search-space team memory
router.include_router(automations_router) # Automations CRUD + run history

View file

@ -107,6 +107,12 @@ PERMISSION_DESCRIPTIONS = {
"settings:view": "View search space settings", "settings:view": "View search space settings",
"settings:update": "Modify search space settings", "settings:update": "Modify search space settings",
"settings:delete": "Delete the entire search space", "settings:delete": "Delete the entire search space",
# Automations
"automations:create": "Create automations from chat or JSON",
"automations:read": "View automations, their triggers, and run history",
"automations:update": "Edit automations and manage their triggers",
"automations:delete": "Remove automations from the search space",
"automations:execute": "Manually fire automations",
# Full access # Full access
"*": "Full access to all features and settings", "*": "Full access to all features and settings",
} }

View file

@ -0,0 +1,8 @@
"""Agent construction and per-turn event-loop drivers."""
from __future__ import annotations
from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread
from app.tasks.chat.streaming.agent.event_loop import stream_agent_events
__all__ = ["build_main_agent_for_thread", "stream_agent_events"]

View file

@ -0,0 +1,49 @@
"""Single per-thread agent (re)build path.
A graph swap mid-turn would corrupt checkpointer state for the same
``thread_id``, so both the initial build and any mid-stream 429 recovery rebuild
must funnel through this single function.
"""
from __future__ import annotations
from typing import Any
from app.agents.new_chat.filesystem_selection import FilesystemSelection
from app.agents.new_chat.llm_config import AgentConfig
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
async def build_main_agent_for_thread(
    agent_factory: Any,
    *,
    llm: Any,
    search_space_id: int,
    db_session: Any,
    connector_service: ConnectorService,
    checkpointer: Any,
    user_id: str | None,
    thread_id: int | None,
    agent_config: AgentConfig | None,
    firecrawl_api_key: str | None,
    thread_visibility: ChatVisibility | None,
    filesystem_selection: FilesystemSelection | None,
    disabled_tools: list[str] | None = None,
    mentioned_document_ids: list[int] | None = None,
) -> Any:
    """Build the agent for a thread by awaiting ``agent_factory``.

    Every keyword argument is forwarded to ``agent_factory`` unchanged, under
    the same name, and the factory's awaited result is returned.
    """
    factory_kwargs: dict[str, Any] = {
        "llm": llm,
        "search_space_id": search_space_id,
        "db_session": db_session,
        "connector_service": connector_service,
        "checkpointer": checkpointer,
        "user_id": user_id,
        "thread_id": thread_id,
        "agent_config": agent_config,
        "firecrawl_api_key": firecrawl_api_key,
        "thread_visibility": thread_visibility,
        "filesystem_selection": filesystem_selection,
        "disabled_tools": disabled_tools,
        "mentioned_document_ids": mentioned_document_ids,
    }
    agent = await agent_factory(**factory_kwargs)
    return agent

View file

@ -0,0 +1,175 @@
"""Per-turn agent event-loop driver.
Drives ``stream_output`` (graph_stream relay) for one agent turn, then runs the
post-stream agent-state inspection: safety-net commit of any staged filesystem
state (in case ``aafter_agent`` was skipped), file-operation contract scoring,
intent classification, and interrupt detection.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from typing import Any
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware.kb_persistence import (
commit_staged_filesystem_state,
)
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.streaming.contract.file_contract import (
contract_enforcement_active,
evaluate_file_contract_outcome,
log_file_contract,
)
from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
all_interrupt_values,
)
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.tasks.chat.streaming.shared.utils import safe_float
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
async def stream_agent_events(
    agent: Any,
    config: dict[str, Any],
    input_data: Any,
    streaming_service: VercelStreamingService,
    result: StreamResult,
    step_prefix: str = "thinking",
    initial_step_id: str | None = None,
    initial_step_title: str = "",
    initial_step_items: list[str] | None = None,
    *,
    fallback_commit_search_space_id: int | None = None,
    fallback_commit_created_by_id: str | None = None,
    fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
    fallback_commit_thread_id: int | None = None,
    runtime_context: Any = None,
    content_builder: Any | None = None,
) -> AsyncGenerator[str, None]:
    """Stream and format ``astream_events`` from the agent.

    Relays every frame produced by ``stream_output``, then inspects the
    agent state once the stream is exhausted:

    1. safety-net commit of staged filesystem state (cloud mode only),
    2. intent detection from the ``file_operation_contract`` state entry,
    3. commit-gate evaluation for ``file_write`` turns,
    4. emission of one interrupt-request frame per pending interrupt.

    Yields SSE-formatted strings; after exhausting, ``result`` carries
    ``accumulated_text`` and interrupt state. See ``StreamResult`` for the
    side-channel surface populated by the underlying relay.
    """
    # Phase 1: relay the live stream unchanged.
    async for sse in stream_output(
        agent=agent,
        config=config,
        input_data=input_data,
        streaming_service=streaming_service,
        result=result,
        step_prefix=step_prefix,
        initial_step_id=initial_step_id,
        initial_step_title=initial_step_title,
        initial_step_items=initial_step_items,
        content_builder=content_builder,
        runtime_context=runtime_context,
    ):
        yield sse
    accumulated_text = result.accumulated_text
    # Snapshot of the graph state; all checks below read from this snapshot.
    state = await agent.aget_state(config)
    state_values = getattr(state, "values", {}) or {}
    # Safety net: if astream_events was cancelled before
    # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
    # (dirty_paths / staged_dirs / pending_moves / pending_deletes /
    # pending_dir_deletes) is still in the checkpointed state. Run the SAME
    # shared commit helper so the turn's writes don't get lost on client
    # disconnect, then push the delta back into the graph using ``as_node=...``
    # so reducers fire as if the after_agent hook produced it.
    if (
        fallback_commit_filesystem_mode == FilesystemMode.CLOUD
        and fallback_commit_search_space_id is not None
        and (
            (state_values.get("dirty_paths") or [])
            or (state_values.get("staged_dirs") or [])
            or (state_values.get("pending_moves") or [])
            or (state_values.get("pending_deletes") or [])
            or (state_values.get("pending_dir_deletes") or [])
        )
    ):
        try:
            delta = await commit_staged_filesystem_state(
                state_values,
                search_space_id=fallback_commit_search_space_id,
                created_by_id=fallback_commit_created_by_id,
                filesystem_mode=fallback_commit_filesystem_mode,
                thread_id=fallback_commit_thread_id,
                dispatch_events=False,
            )
            if delta:
                await agent.aupdate_state(
                    config,
                    delta,
                    as_node="KnowledgeBasePersistenceMiddleware.after_agent",
                )
        except Exception as exc:
            # Best effort: a failed safety-net commit is logged and the rest
            # of the post-stream inspection still runs.
            _perf_log.warning("[stream_agent_events] safety-net commit failed: %s", exc)
    # Phase 2: read the intent contract and check it belongs to this turn.
    contract_state = state_values.get("file_operation_contract") or {}
    contract_turn_id = contract_state.get("turn_id")
    current_turn_id = config.get("configurable", {}).get("turn_id", "")
    intent_value = contract_state.get("intent")
    if (
        isinstance(intent_value, str)
        and intent_value in ("chat_only", "file_write", "file_read")
        and contract_turn_id == current_turn_id
    ):
        result.intent_detected = intent_value
    if (
        isinstance(intent_value, str)
        and intent_value in ("chat_only", "file_write", "file_read")
        and contract_turn_id != current_turn_id
    ):
        # Ignore stale intent contracts from previous turns/checkpoints.
        result.intent_detected = "chat_only"
    # Confidence is only trusted when the contract is from the current turn.
    result.intent_confidence = (
        safe_float(contract_state.get("confidence"), default=0.0)
        if contract_turn_id == current_turn_id
        else 0.0
    )
    # Phase 3: commit gate. Only ``file_write`` turns are scored; every other
    # intent passes unconditionally.
    # NOTE(review): the ``else`` below is paired with this outer intent check
    # -- confirm against the original file's indentation.
    if result.intent_detected == "file_write":
        result.commit_gate_passed, result.commit_gate_reason = (
            evaluate_file_contract_outcome(result)
        )
        if not result.commit_gate_passed and contract_enforcement_active(result):
            # Replace the turn's text with an explicit failure notice, emitted
            # as its own text block and mirrored into ``content_builder``.
            gate_notice = (
                "I could not complete the requested file write because no successful "
                "write_file/edit_file operation was confirmed."
            )
            gate_text_id = streaming_service.generate_text_id()
            yield streaming_service.format_text_start(gate_text_id)
            if content_builder is not None:
                content_builder.on_text_start(gate_text_id)
            yield streaming_service.format_text_delta(gate_text_id, gate_notice)
            if content_builder is not None:
                content_builder.on_text_delta(gate_text_id, gate_notice)
            yield streaming_service.format_text_end(gate_text_id)
            if content_builder is not None:
                content_builder.on_text_end(gate_text_id)
            yield streaming_service.format_terminal_info(gate_notice, "error")
            accumulated_text = gate_notice
    else:
        result.commit_gate_passed = True
        result.commit_gate_reason = ""
    result.accumulated_text = accumulated_text
    log_file_contract("turn_outcome", result)
    # Phase 4: surface pending interrupts.
    pending_values = all_interrupt_values(state)
    if pending_values:
        result.is_interrupted = True
        # One frame per paused subagent so each parallel HITL renders its own
        # approval card on the wire. Order matches ``state.interrupts``, which
        # the resume slicer in
        # ``checkpointed_subagent_middleware.resume_routing`` consumes in the
        # same order — keeping emit and resume in lock-step.
        for interrupt_value in pending_values:
            yield streaming_service.format_interrupt_request(interrupt_value)

View file

@ -0,0 +1,15 @@
"""Pre-agent context shaping: mentioned-doc rendering and todos extraction."""
from __future__ import annotations
from app.tasks.chat.streaming.context.deepagents_todos import (
extract_todos_from_deepagents,
)
from app.tasks.chat.streaming.context.mentioned_docs import (
format_mentioned_surfsense_docs_as_context,
)
__all__ = [
"extract_todos_from_deepagents",
"format_mentioned_surfsense_docs_as_context",
]

View file

@ -0,0 +1,27 @@
"""Extract todos from a deepagents ``TodoListMiddleware`` ``Command`` output."""
from __future__ import annotations
from typing import Any
def extract_todos_from_deepagents(command_output: Any) -> dict:
    """Normalize todos out of a deepagents ``Command`` or dict payload.

    deepagents returns a ``Command`` whose ``update['todos']`` is a list of
    ``{'content': str, 'status': str}`` dicts. The UI expects the same shape,
    so no transformation is required, only extraction.

    Accepted inputs:

    - a dict with a top-level ``"todos"`` key,
    - a dict with an ``"update"`` dict that holds ``"todos"``,
    - any object whose ``update`` attribute is mapping-like (has ``get``).

    Anything else, including a ``Command`` whose ``update`` is ``None``,
    yields an empty list.

    Returns:
        ``{"todos": <extracted list>}``.
    """
    todos_data: list = []
    if isinstance(command_output, dict):
        # Dicts must be handled first: every dict has a bound ``update``
        # *method*, so an attribute probe would mistake a plain dict payload
        # for a ``Command`` and then fail calling ``.get`` on that method.
        if "todos" in command_output:
            todos_data = command_output.get("todos", [])
        elif isinstance(command_output.get("update"), dict):
            todos_data = command_output["update"].get("todos", [])
    else:
        update = getattr(command_output, "update", None)
        # Only read from mapping-like updates; ``None`` or other shapes carry
        # no todos we can extract.
        if callable(getattr(update, "get", None)):
            todos_data = update.get("todos", [])
    return {"todos": todos_data}

View file

@ -0,0 +1,58 @@
"""Render user-mentioned SurfSense docs as XML context for the agent."""
from __future__ import annotations
import json
from app.db import SurfsenseDocsDocument
from app.utils.surfsense_docs import surfsense_docs_public_url
def format_mentioned_surfsense_docs_as_context(
    documents: list[SurfsenseDocsDocument],
) -> str:
    """Serialize mentioned SurfSense docs into an XML-style context string.

    Returns an empty string when ``documents`` is empty. Otherwise the result
    is a ``<mentioned_surfsense_docs>`` wrapper holding an instruction
    paragraph and one ``<document>`` element per doc: metadata first, then
    one ``<chunk>`` per chunk, or a single ``doc-0`` chunk with the document
    content when the doc has no chunks.
    """
    if not documents:
        return ""

    lines = [
        "<mentioned_surfsense_docs>",
        "The user has explicitly mentioned the following SurfSense documentation pages. "
        "These are official documentation about how to use SurfSense and should be used to answer questions about the application. "
        "Use [citation:CHUNK_ID] format for citations (e.g., [citation:doc-123]).",
    ]
    for doc in documents:
        public_url = surfsense_docs_public_url(doc.source)
        metadata_json = json.dumps(
            {"source": doc.source, "public_url": public_url}, ensure_ascii=False
        )
        lines.extend(
            [
                "<document>",
                "<document_metadata>",
                f" <document_id>doc-{doc.id}</document_id>",
                " <document_type>SURFSENSE_DOCS</document_type>",
                f" <title><![CDATA[{doc.title}]]></title>",
                f" <url><![CDATA[{public_url}]]></url>",
                f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
                "</document_metadata>",
                "",
                "<document_content>",
            ]
        )
        if hasattr(doc, "chunks") and doc.chunks:
            lines.extend(
                f" <chunk id='doc-{chunk.id}'><![CDATA[{chunk.content}]]></chunk>"
                for chunk in doc.chunks
            )
        else:
            # No chunks: expose the whole document as one synthetic chunk.
            lines.append(f" <chunk id='doc-0'><![CDATA[{doc.content}]]></chunk>")
        lines.extend(["</document_content>", "</document>", ""])
    lines.append("</mentioned_surfsense_docs>")
    return "\n".join(lines)

View file

@ -0,0 +1,15 @@
"""File-operation contract evaluation and logging."""
from __future__ import annotations
from app.tasks.chat.streaming.contract.file_contract import (
contract_enforcement_active,
evaluate_file_contract_outcome,
log_file_contract,
)
__all__ = [
"contract_enforcement_active",
"evaluate_file_contract_outcome",
"log_file_contract",
]

View file

@ -0,0 +1,53 @@
"""File-operation contract: when to enforce, how to score, how to log."""
from __future__ import annotations
import json
from typing import Any
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
def contract_enforcement_active(result: StreamResult) -> bool:
    """Tell whether the file-operation contract is enforced for this turn.

    Enforcement applies only in desktop local-folder mode. The rule is
    deterministic; there are no env-driven progression modes.
    """
    mode = result.filesystem_mode
    return mode == "desktop_local_folder"
def evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]:
    """Score a turn against the file-write contract.

    Returns ``(passed, reason)``. Turns whose detected intent is not
    ``"file_write"`` always pass. A file-write turn passes only if a write
    was attempted, succeeded, and was verified; otherwise ``reason`` names
    the first stage that failed.
    """
    if result.intent_detected != "file_write":
        return True, ""

    # Stages in evaluation order, each paired with its failure reason.
    stages = (
        ("write_attempted", "no_write_attempt"),
        ("write_succeeded", "write_failed"),
        ("verification_succeeded", "verification_failed"),
    )
    for flag_name, failure_reason in stages:
        if not getattr(result, flag_name):
            return False, failure_reason
    return True, ""
def log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
    """Log one JSON-encoded file-operation-contract record at INFO level.

    Args:
        stage: Label for the point in the turn being logged.
        result: Stream result whose contract fields are captured.
        **extra: Additional entries merged into the payload; they override
            built-in keys of the same name.
    """
    # A missing turn id is tolerated for the "turn_id" key below, so guard the
    # chat-id derivation the same way instead of running ``in`` on ``None``.
    turn_id = result.turn_id or ""
    # The chat id is the part of the turn id before the first ":".
    chat_id = turn_id.split(":", 1)[0] if ":" in turn_id else "unknown"
    payload: dict[str, Any] = {
        "stage": stage,
        "request_id": result.request_id or "unknown",
        "turn_id": result.turn_id or "unknown",
        "chat_id": chat_id,
        "filesystem_mode": result.filesystem_mode,
        "client_platform": result.client_platform,
        "intent_detected": result.intent_detected,
        "intent_confidence": result.intent_confidence,
        "write_attempted": result.write_attempted,
        "write_succeeded": result.write_succeeded,
        "verification_succeeded": result.verification_succeeded,
        "commit_gate_passed": result.commit_gate_passed,
        "commit_gate_reason": result.commit_gate_reason or None,
    }
    payload.update(extra)
    _perf_log.info(
        "[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False)
    )

View file

@ -0,0 +1,17 @@
"""Top-level streaming flows: ``new_chat`` and ``resume_chat`` orchestrators.
Re-exports the public entry points so callers can write::
from app.tasks.chat.streaming.flows import stream_new_chat, stream_resume_chat
The orchestrators themselves live under ``new_chat/orchestrator.py`` and
``resume_chat/orchestrator.py`` (slim composition of the per-concern modules in
each flow folder and the building blocks in ``shared/``).
"""
from __future__ import annotations
from app.tasks.chat.streaming.flows.new_chat import stream_new_chat
from app.tasks.chat.streaming.flows.resume_chat import stream_resume_chat
__all__ = ["stream_new_chat", "stream_resume_chat"]

View file

@ -0,0 +1,12 @@
"""New-chat streaming flow.
The public entry point ``stream_new_chat`` is the slim coroutine in
``orchestrator.py`` that composes the per-concern modules in this folder and
the building blocks under ``flows/shared/``.
"""
from __future__ import annotations
from app.tasks.chat.streaming.flows.new_chat.orchestrator import stream_new_chat
__all__ = ["stream_new_chat"]

View file

@ -0,0 +1,95 @@
"""Resolve the auto-pin for the *initial* turn config.
Auto-pin (``selected_llm_config_id=0``) picks the best eligible LLM config for
this thread / search space / user, optionally filtered to vision-capable
configs when the turn carries images.
Errors classified here:
* ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` — the auto-pin pool has no
vision-capable cfg for an image-bearing turn. The same gate fires later
in ``llm_capability`` for explicit selections; mapping both to the same
code keeps the FE error UI consistent.
* ``SERVER_ERROR`` — any other ``ValueError`` from the resolver.
This module owns *initial* pin resolution; the rate-limit recovery loop has
its own narrower auto-pin call (with ``exclude_config_ids``) in
``flows/shared/rate_limit_recovery``.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.observability import otel as ot
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
@dataclass
class AutoPinResult:
    """Outcome of ``resolve_initial_auto_pin``.

    ``llm_config_id`` is set when ``error`` is ``None``; ``error`` carries the
    classified user-facing message plus error code/kind so the orchestrator can
    emit one terminal-error SSE frame.
    """

    # Resolved LLM config id; ``None`` when resolution failed.
    llm_config_id: int | None
    # ``(message, error_code, error_kind)`` on failure, otherwise ``None``.
    error: tuple[str, str, Literal["user_error", "server_error"]] | None
async def resolve_initial_auto_pin(
    session: AsyncSession,
    *,
    chat_id: int,
    search_space_id: int,
    user_id: str | None,
    selected_llm_config_id: int,
    requires_image_input: bool,
    requested_llm_config_id: int,
) -> AutoPinResult:
    """Resolve the pinned LLM config and turn resolver failures into data.

    Calls ``resolve_or_get_pinned_llm_config_id`` for the thread and records a
    ``model.pin.resolved`` telemetry event on success. A ``ValueError`` from
    the resolver is not re-raised; it is classified and returned inside the
    ``AutoPinResult`` so the caller can emit an SSE error frame.

    Args:
        session: Async DB session handed to the resolver.
        chat_id: Thread id, forwarded to the resolver as ``thread_id``.
        search_space_id: Search space the turn belongs to.
        user_id: Requesting user, if any.
        selected_llm_config_id: Config id passed to the resolver.
        requires_image_input: Whether the turn carries images.
        requested_llm_config_id: Originally requested id; only used as a
            telemetry attribute.

    Returns:
        ``AutoPinResult`` with ``llm_config_id`` set on success, or with
        ``error`` set to ``(message, error_code, error_kind)`` on failure.
    """
    try:
        resolution = await resolve_or_get_pinned_llm_config_id(
            session,
            thread_id=chat_id,
            search_space_id=search_space_id,
            user_id=user_id,
            selected_llm_config_id=selected_llm_config_id,
            requires_image_input=requires_image_input,
        )
        resolved_id = resolution.resolved_llm_config_id
        ot.add_event(
            "model.pin.resolved",
            {
                "pin.requested_id": requested_llm_config_id,
                "pin.resolved_id": resolved_id,
                "pin.requires_image_input": requires_image_input,
            },
        )
        return AutoPinResult(llm_config_id=resolved_id, error=None)
    except ValueError as resolver_error:
        detail = str(resolver_error)
        code: str
        kind: Literal["user_error", "server_error"]
        # A "vision-capable" message on an image-bearing turn gets the same
        # friendly image-input error code that the explicit-selection
        # capability gate uses; every other ValueError is a server error.
        if requires_image_input and "vision-capable" in detail:
            code = "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
            kind = "user_error"
            ot.add_event("quota.denied", {"quota.code": code})
        else:
            code = "SERVER_ERROR"
            kind = "server_error"
        return AutoPinResult(llm_config_id=None, error=(detail, code, kind))

View file

@ -0,0 +1,95 @@
"""Build and emit the first ``thinking-1`` step for a new-chat turn.
The step title and "Processing X" items are derived from what the user sent
(text snippet, image count, mentioned doc titles) so the FE can render a
meaningful placeholder while the agent stream warms up.
``thinking-1`` is the canonical id for this step — every subsequent
``thinking-N`` produced by ``stream_agent_events`` folds into the same
singleton ``data-thinking-steps`` part on the FE.
"""
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any
from app.db import SurfsenseDocsDocument
from app.services.new_streaming_service import VercelStreamingService
@dataclass
class InitialThinkingStep:
    """Fields of the first thinking step, shared by the SSE frame and builder hook.

    ``build_initial_thinking_step`` always produces ``step_id`` equal to
    ``"thinking-1"``.
    """

    # Identifier of the step.
    step_id: str
    # One-line header shown for the step.
    title: str
    # Bullet lines rendered underneath the title.
    items: list[str]
def build_initial_thinking_step(
    *,
    user_query: str,
    user_image_data_urls: list[str] | None,
    mentioned_surfsense_docs: list[SurfsenseDocsDocument],
) -> InitialThinkingStep:
    """Derive the title and single bullet line of the ``thinking-1`` step.

    Args:
        user_query: Text the user typed; shown truncated to 80 characters.
        user_image_data_urls: Attached images; only their count is shown, and
            only when the query text is blank.
        mentioned_surfsense_docs: Referenced docs; a single doc is shown by
            (truncated) title, several are shown as a count.

    Returns:
        An ``InitialThinkingStep`` with ``step_id`` fixed to ``"thinking-1"``.
    """
    has_docs = bool(mentioned_surfsense_docs)
    title, verb = (
        ("Analyzing referenced content", "Analyzing")
        if has_docs
        else ("Understanding your request", "Processing")
    )

    # Subject of the bullet: query snippet, else image count, else placeholder.
    if user_query.strip():
        suffix = "..." if len(user_query) > 80 else ""
        subject = user_query[:80] + suffix
    elif user_image_data_urls:
        subject = f"[{len(user_image_data_urls)} image(s)]"
    else:
        subject = "(message)"
    fragments: list[str] = [subject]

    if has_docs:
        # Titles longer than 30 characters are cut to 27 plus an ellipsis.
        labels = [
            name if len(name) <= 30 else name[:27] + "..."
            for name in (doc.title for doc in mentioned_surfsense_docs)
        ]
        if len(labels) == 1:
            fragments.append(f"[{labels[0]}]")
        else:
            fragments.append(f"[{len(labels)} docs]")

    summary = " ".join(fragments)
    return InitialThinkingStep(
        step_id="thinking-1", title=title, items=[f"{verb}: {summary}"]
    )
def iter_initial_thinking_step_frame(
    step: InitialThinkingStep,
    *,
    streaming_service: VercelStreamingService,
    content_builder: Any | None,
) -> Iterator[str]:
    """Yield the SSE frame for the initial step, notifying the builder first.

    When ``content_builder`` is provided, its ``on_thinking_step`` hook is
    called with the step in ``"in_progress"`` status before the frame is
    produced. Exactly one frame is yielded, formatted by
    ``streaming_service.format_thinking_step``.
    """
    status = "in_progress"
    step_id, title, items = step.step_id, step.title, step.items

    if content_builder is not None:
        content_builder.on_thinking_step(step_id, title, status, items)

    yield streaming_service.format_thinking_step(
        step_id=step_id,
        title=title,
        status=status,
        items=items,
    )

View file

@ -0,0 +1,264 @@
r"""Assemble the LangGraph ``input_state`` for the new-chat turn.
Pipeline:
1. **History bootstrap** — only for cloned chats with no LangGraph checkpoint
yet; flips the per-thread ``needs_history_bootstrap`` flag back to False
once the rows are loaded.
2. **Mentioned SurfSense docs** — eager-load chunks so the formatter has the
full content without a second roundtrip.
3. **Recent reports** — top 3 by id desc with non-null content, so the LLM
can resolve ``report_id`` for versioning without spelunking history.
4. **@-mention resolve** (cloud mode) — substitute ``@title`` tokens in the
query with canonical ``\`/documents/...\``` paths the LLM expects.
5. **Context block render** — XML-wrap surfsense docs + reports, prepend to
the rewritten query, optionally prefix with display name for SEARCH_SPACE
visibility.
6. **HumanMessage** — multimodal content if images are attached.
Returns the assembled ``input_state`` dict plus side-channel data the
orchestrator needs downstream (``accepted_folder_ids`` for runtime context;
``mentioned_surfsense_docs`` for the initial thinking step).
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
from langchain_core.messages import HumanMessage
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
from app.db import (
ChatVisibility,
NewChatThread,
Report,
SurfsenseDocsDocument,
)
from app.tasks.chat.streaming.context.mentioned_docs import (
format_mentioned_surfsense_docs_as_context,
)
from app.utils.content_utils import bootstrap_history_from_db
from app.utils.user_message_multimodal import build_human_message_content
logger = logging.getLogger(__name__)
@dataclass
class NewChatInputState:
    """Bundle returned by ``build_new_chat_input_state``.

    Carries the agent input together with the side-channel values the caller
    reads afterwards.
    """

    # State dict handed to the agent (messages plus request identifiers).
    input_state: dict[str, Any]
    # Folder ids returned by the mention-resolution step.
    accepted_folder_ids: list[int]
    # SurfSense docs loaded for this turn; used for the initial thinking step.
    mentioned_surfsense_docs: list[SurfsenseDocsDocument]
async def build_new_chat_input_state(
    session: AsyncSession,
    *,
    chat_id: int,
    search_space_id: int,
    user_query: str,
    user_image_data_urls: list[str] | None,
    mentioned_document_ids: list[int] | None,
    mentioned_surfsense_doc_ids: list[int] | None,
    mentioned_folder_ids: list[int] | None,
    mentioned_documents: list[dict[str, Any]] | None,
    needs_history_bootstrap: bool,
    thread_visibility: ChatVisibility,
    current_user_display_name: str | None,
    filesystem_mode: str,
    request_id: str | None,
    turn_id: str,
) -> NewChatInputState:
    """Assemble the agent ``input_state`` for one new-chat turn.

    Steps, in order: optional history bootstrap (clearing the thread's
    ``needs_history_bootstrap`` flag), loading mentioned SurfSense docs with
    their chunks, fetching up to three recent reports for the thread,
    resolving @-mentions into the agent-facing query, rendering the context
    blocks, and appending the final ``HumanMessage``.

    Returns:
        ``NewChatInputState`` holding the ``input_state`` dict, the folder ids
        returned by mention resolution, and the loaded SurfSense docs.
    """
    # 1. History bootstrap; the flag is cleared and committed right away.
    history: list[Any] = []
    if needs_history_bootstrap:
        history = await bootstrap_history_from_db(
            session, chat_id, thread_visibility=thread_visibility
        )
        thread_lookup = select(NewChatThread).filter(NewChatThread.id == chat_id)
        thread_row = (await session.execute(thread_lookup)).scalars().first()
        if thread_row:
            thread_row.needs_history_bootstrap = False
            await session.commit()

    # 2. Mentioned SurfSense docs, with chunks eager-loaded in the same query.
    surfsense_docs: list[SurfsenseDocsDocument] = []
    if mentioned_surfsense_doc_ids:
        docs_stmt = (
            select(SurfsenseDocsDocument)
            .options(selectinload(SurfsenseDocsDocument.chunks))
            .filter(SurfsenseDocsDocument.id.in_(mentioned_surfsense_doc_ids))
        )
        surfsense_docs = list((await session.execute(docs_stmt)).scalars().all())

    # 3. Newest three reports of this thread that have content, listed inline
    #    so the LLM can pick a ``report_id`` without digging through history.
    reports_stmt = (
        select(Report)
        .filter(
            Report.thread_id == chat_id,
            Report.content.isnot(None),
        )
        .order_by(Report.id.desc())
        .limit(3)
    )
    latest_reports = list((await session.execute(reports_stmt)).scalars().all())

    # 4. Rewrite @-mentions for the agent-facing copy of the query.
    agent_user_query, accepted_folder_ids = await _resolve_mentions_for_query(
        session,
        search_space_id=search_space_id,
        user_query=user_query,
        filesystem_mode=filesystem_mode,
        mentioned_document_ids=mentioned_document_ids,
        mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
        mentioned_folder_ids=mentioned_folder_ids,
        mentioned_documents=mentioned_documents,
    )

    # 5. Prepend context blocks; shared threads also get a speaker prefix.
    rendered_query = _render_query_with_context(
        agent_user_query=agent_user_query,
        mentioned_surfsense_docs=surfsense_docs,
        recent_reports=latest_reports,
    )
    if thread_visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name:
        rendered_query = f"**[{current_user_display_name}]:** {rendered_query}"

    # 6. Final human message (text plus any attached images).
    image_urls = list(user_image_data_urls or ())
    history.append(
        HumanMessage(content=build_human_message_content(rendered_query, image_urls))
    )

    return NewChatInputState(
        input_state={
            "messages": history,
            "search_space_id": search_space_id,
            "request_id": request_id or "unknown",
            "turn_id": turn_id,
        },
        accepted_folder_ids=accepted_folder_ids,
        mentioned_surfsense_docs=surfsense_docs,
    )
async def _resolve_mentions_for_query(
    session: AsyncSession,
    *,
    search_space_id: int,
    user_query: str,
    filesystem_mode: str,
    mentioned_document_ids: list[int] | None,
    mentioned_surfsense_doc_ids: list[int] | None,
    mentioned_folder_ids: list[int] | None,
    mentioned_documents: list[dict[str, Any]] | None,
) -> tuple[str, list[int]]:
    """Resolve @-mention chips and build the agent-facing copy of the query.

    Resolution only runs in cloud filesystem mode and only when at least one
    mention input is non-empty; otherwise the query is returned unchanged with
    an empty folder-id list.

    Only the returned string is rewritten. ``user_query`` itself is left as
    typed, so the caller can keep using the original text elsewhere.

    Returns:
        ``(agent_user_query, accepted_folder_ids)`` where the folder ids are
        the ones reported back by ``resolve_mentions``.
    """
    mention_inputs = (
        mentioned_document_ids,
        mentioned_surfsense_doc_ids,
        mentioned_folder_ids,
        mentioned_documents,
    )
    is_cloud = filesystem_mode == FilesystemMode.CLOUD.value
    if not is_cloud or not any(mention_inputs):
        return user_query, []

    from app.schemas.new_chat import MentionedDocumentInfo

    # Chips may already be model instances or raw payloads; raw ones are
    # validated here and malformed ones are dropped rather than failing the turn.
    chips: list[MentionedDocumentInfo] | None = None
    if mentioned_documents:
        chips = []
        for candidate in mentioned_documents:
            if not isinstance(candidate, MentionedDocumentInfo):
                try:
                    candidate = MentionedDocumentInfo.model_validate(candidate)
                except Exception:
                    logger.debug(
                        "stream_new_chat: dropping malformed mention chip %r",
                        candidate,
                    )
                    continue
            chips.append(candidate)

    resolution = await resolve_mentions(
        session,
        search_space_id=search_space_id,
        mentioned_documents=chips,
        mentioned_document_ids=mentioned_document_ids,
        mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
        mentioned_folder_ids=mentioned_folder_ids,
    )
    rewritten_query = substitute_in_text(user_query, resolution.token_to_path)
    return rewritten_query, resolution.mentioned_folder_ids
def _render_query_with_context(
    *,
    agent_user_query: str,
    mentioned_surfsense_docs: list[SurfsenseDocsDocument],
    recent_reports: list[Report],
) -> str:
    """Return the query, prefixed with context blocks when there are any.

    With no docs and no reports the query comes back untouched. Otherwise the
    formatted docs block and/or a ``<report_context>`` listing are joined by
    blank lines and placed ahead of the query, which is then wrapped in
    ``<user_query>`` tags.
    """
    blocks: list[str] = []

    if mentioned_surfsense_docs:
        blocks.append(
            format_mentioned_surfsense_docs_as_context(mentioned_surfsense_docs)
        )

    if recent_reports:
        # One line per report; a missing style is shown as "detailed".
        listing = "\n".join(
            f' - report_id={report.id}, title="{report.title}", '
            f'style="{report.report_style or "detailed"}"'
            for report in recent_reports
        )
        blocks.append(
            "<report_context>\n"
            "Previously generated reports in this conversation:\n"
            f"{listing}\n\n"
            "If the user wants to MODIFY, REVISE, UPDATE, or ADD to one of "
            "these reports, set parent_report_id to the relevant report_id above.\n"
            "If the user wants a completely NEW report on a different topic, "
            "leave parent_report_id unset.\n"
            "</report_context>"
        )

    if not blocks:
        return agent_user_query

    preamble = "\n\n".join(blocks)
    return f"{preamble}\n\n<user_query>{agent_user_query}</user_query>"

View file

@ -0,0 +1,62 @@
"""Vision-capability gate for image-bearing turns.
Capability safety net for explicit (non-auto-pin) selections: a turn carrying
user-uploaded images cannot be routed to a chat config that LiteLLM's
authoritative model map *explicitly* marks as text-only (``supports_vision``
set to False). The check is intentionally narrow — it only fires when LiteLLM
is *certain* the model can't accept image input; unknown / unmapped /
vision-capable models pass through.
Without this guard a known-text-only model would 404 at the provider with
``"No endpoints found that support image input"``, surfacing as an opaque
``SERVER_ERROR`` SSE chunk; failing here lets us return a friendly message that
tells the user what to change.
"""
from __future__ import annotations
from app.agents.new_chat.llm_config import AgentConfig
from app.observability import otel as ot
def check_image_input_capability(
    *,
    user_image_data_urls: list[str] | None,
    agent_config: AgentConfig | None,
) -> tuple[str, str] | None:
    """Reject image-bearing turns aimed at a known text-only chat model.

    Args:
        user_image_data_urls: Images attached to the turn, if any.
        agent_config: Config of the selected model, if one was loaded.

    Returns:
        ``None`` when there are no images, no config, or the model is not
        flagged by ``is_known_text_only_chat_model``. Otherwise
        ``(user_message, error_code)`` for the caller to surface as a
        terminal error.
    """
    if not user_image_data_urls or agent_config is None:
        return None

    from app.services.provider_capabilities import is_known_text_only_chat_model

    # ``litellm_params`` may be unset or not a dict; only a dict can carry
    # a ``base_model`` override.
    params = agent_config.litellm_params or {}
    base_model = params.get("base_model") if isinstance(params, dict) else None

    text_only = is_known_text_only_chat_model(
        provider=agent_config.provider,
        model_name=agent_config.model_name,
        base_model=base_model,
        custom_provider=agent_config.custom_provider,
    )
    if not text_only:
        return None

    error_code = "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
    label = agent_config.config_name or agent_config.model_name or "model"
    ot.add_event("quota.denied", {"quota.code": error_code})
    user_message = (
        f"The selected model ({label}) does not support "
        "image input. Switch to a vision-capable model "
        "(e.g. GPT-4o, Claude, Gemini) or remove the image "
        "attachment and try again."
    )
    return user_message, error_code

View file

@ -0,0 +1,868 @@
"""``stream_new_chat`` — public entry point for a fresh chat turn.
Slim composition layer over the per-concern modules in this folder and the
building blocks under ``flows/shared/``. Each phase corresponds to a numbered
block in the surrounding code so the on-the-wire ordering stays explicit:
1. Validation / config — auto-pin, LLM bundle, capability, premium reserve.
2. Concurrent persistence + pre-stream setup — spawn DB writes, build the
connector, fetch the checkpointer, build the agent.
3. Input assembly — history bootstrap, mentions, surfsense docs, reports.
4. First SSE frames — message_start, start_step, turn-info, turn-status.
5. Persistence join + message-id frames (ghost-thread protection).
6. Initial thinking step + title task + runtime context.
7. Stream loop with in-stream rate-limit recovery + mid-stream title emit.
8. Finalize — premium debit, token-usage SSE, finish frames.
9. Exception branch — classify, emit terminal error, finish frames.
10. Finally — premium release, session close, assistant finalize, GC, span.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
import time
from collections.abc import AsyncGenerator
from functools import partial
from typing import Any, Literal
import anyio
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.middleware.busy_mutex import end_turn
from app.config import config as _app_config
from app.db import ChatVisibility, async_session_maker
from app.observability import otel as ot
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.content_builder import AssistantContentBuilder
from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread
from app.tasks.chat.streaming.contract.file_contract import log_file_contract
from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error
from app.tasks.chat.streaming.flows.new_chat.auto_pin import resolve_initial_auto_pin
from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import (
build_initial_thinking_step,
iter_initial_thinking_step_frame,
)
from app.tasks.chat.streaming.flows.new_chat.input_state import (
build_new_chat_input_state,
)
from app.tasks.chat.streaming.flows.new_chat.llm_capability import (
check_image_input_capability,
)
from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import (
await_persist_task,
spawn_persist_assistant_shell_task,
spawn_persist_user_task,
spawn_set_ai_responding_bg,
)
from app.tasks.chat.streaming.flows.new_chat.runtime_context import (
build_new_chat_runtime_context,
)
from app.tasks.chat.streaming.flows.new_chat.title_gen import (
await_pending_title_update,
maybe_emit_title_update,
spawn_title_task,
)
from app.tasks.chat.streaming.flows.shared.assistant_finalize import (
finalize_assistant_message,
)
from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame
from app.tasks.chat.streaming.flows.shared.finally_cleanup import (
close_session_and_clear_ai_responding,
run_gc_pass,
)
from app.tasks.chat.streaming.flows.shared.first_frames import (
iter_final_frames,
iter_initial_frames,
)
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
from app.tasks.chat.streaming.flows.shared.pre_stream_setup import (
get_chat_checkpointer,
setup_connector_and_firecrawl,
)
from app.tasks.chat.streaming.flows.shared.premium_quota import (
PremiumReservation,
finalize_premium,
needs_premium_quota,
release_premium,
reserve_premium,
)
from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import (
can_recover_provider_rate_limit,
log_rate_limit_recovered,
reroute_to_next_auto_pin,
)
from app.tasks.chat.streaming.flows.shared.span import (
close_chat_request_span,
open_chat_request_span,
set_agent_mode,
)
from app.tasks.chat.streaming.flows.shared.stream_loop import run_stream_loop
from app.tasks.chat.streaming.flows.shared.terminal_error import (
handle_terminal_exception,
)
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.utils.perf import get_perf_logger, log_system_snapshot
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
# Holds spawned background tasks (set_ai_responding, persist_user, persist_asst)
# so the GC doesn't drop them before they finish. Kept at module level so it
# survives across turns within one process.
_background_tasks: set[asyncio.Task] = set()
async def stream_new_chat(
user_query: str,
search_space_id: int,
chat_id: int,
user_id: str | None = None,
llm_config_id: int = -1,
mentioned_document_ids: list[int] | None = None,
mentioned_surfsense_doc_ids: list[int] | None = None,
mentioned_folder_ids: list[int] | None = None,
mentioned_documents: list[dict[str, Any]] | None = None,
checkpoint_id: str | None = None,
needs_history_bootstrap: bool = False,
thread_visibility: ChatVisibility | None = None,
current_user_display_name: str | None = None,
disabled_tools: list[str] | None = None,
filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None,
user_image_data_urls: list[str] | None = None,
flow: Literal["new", "regenerate"] = "new",
) -> AsyncGenerator[str, None]:
"""Stream a new chat turn using the SurfSense deep agent.
Uses the Vercel AI SDK Data Stream Protocol (SSE). ``chat_id`` is the
LangGraph thread id (durable conversation memory via the checkpointer).
Manages its own database session so cleanup runs even when Starlette
cancels the task on client disconnect.
"""
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
fs_platform = (
filesystem_selection.client_platform.value if filesystem_selection else "web"
)
stream_result.request_id = request_id
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
stream_result.filesystem_mode = fs_mode
stream_result.client_platform = fs_platform
chat_agent_mode = "unknown"
chat_outcome = "success"
chat_error_category: str | None = None
chat_span_cm, chat_span = open_chat_request_span(
chat_id=chat_id,
search_space_id=search_space_id,
flow=flow,
request_id=request_id,
turn_id=stream_result.turn_id,
filesystem_mode=fs_mode,
client_platform=fs_platform,
agent_mode=chat_agent_mode,
)
log_file_contract("turn_start", stream_result)
_perf_log.info(
"[stream_new_chat] filesystem_mode=%s client_platform=%s",
fs_mode,
fs_platform,
)
log_system_snapshot("stream_new_chat_START")
from app.services.token_tracking_service import start_turn
accumulator = start_turn()
premium_reservation: PremiumReservation | None = None
busy_error_raised = False
emit_stream_error = partial(
emit_stream_terminal_error,
streaming_service=streaming_service,
flow=flow,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
)
session = async_session_maker()
# Declared at function scope so SSE-yield join points and the finally
# clause see them on every exit path.
persist_user_task: asyncio.Task[int | None] | None = None
persist_asst_task: asyncio.Task[int | None] | None = None
try:
spawn_set_ai_responding_bg(
chat_id=chat_id, user_id=user_id, background_tasks=_background_tasks
)
# --- Block 1: LLM config + capability ---
requested_llm_config_id = llm_config_id
requires_image_input = bool(user_image_data_urls)
_t0 = time.perf_counter()
pin_result = await resolve_initial_auto_pin(
session,
chat_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=llm_config_id,
requires_image_input=requires_image_input,
requested_llm_config_id=requested_llm_config_id,
)
if pin_result.error is not None:
message, error_code, error_kind = pin_result.error
yield emit_stream_error(
message=message, error_kind=error_kind, error_code=error_code
)
yield streaming_service.format_done()
return
llm_config_id = pin_result.llm_config_id # type: ignore[assignment]
llm, agent_config, llm_load_error = await load_llm_bundle(
session, config_id=llm_config_id, search_space_id=search_space_id
)
if llm_load_error:
yield emit_stream_error(
message=llm_load_error,
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
_perf_log.info(
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
time.perf_counter() - _t0,
llm_config_id,
)
capability_error = check_image_input_capability(
user_image_data_urls=user_image_data_urls, agent_config=agent_config
)
if capability_error is not None:
message, error_code = capability_error
yield emit_stream_error(
message=message,
error_kind="user_error",
error_code=error_code,
)
yield streaming_service.format_done()
return
if needs_premium_quota(agent_config, user_id):
premium_reservation = await reserve_premium(
agent_config=agent_config, user_id=user_id # type: ignore[arg-type]
)
if not premium_reservation.allowed:
ot.add_event("quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"})
if requested_llm_config_id == 0:
pin_fallback = await resolve_initial_auto_pin(
session,
chat_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
requires_image_input=requires_image_input,
requested_llm_config_id=requested_llm_config_id,
)
if pin_fallback.error is not None:
message, error_code, error_kind = pin_fallback.error
yield emit_stream_error(
message=message,
error_kind=error_kind,
error_code=error_code,
)
yield streaming_service.format_done()
return
llm_config_id = pin_fallback.llm_config_id # type: ignore[assignment]
ot.add_event(
"model.repin",
{
"repin.reason": "premium_quota_exhausted",
"repin.to_config_id": llm_config_id,
},
)
llm, agent_config, llm_load_error = await load_llm_bundle(
session,
config_id=llm_config_id,
search_space_id=search_space_id,
)
if llm_load_error:
yield emit_stream_error(
message=llm_load_error,
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
premium_reservation = None
# Re-route to free fallback logged via the structured
# stream-error logger so cost/analytics see the auto-switch.
from app.tasks.chat.streaming.errors.classifier import (
log_chat_stream_error,
)
log_chat_stream_error(
flow=flow,
error_kind="premium_quota_exhausted",
error_code="PREMIUM_QUOTA_EXHAUSTED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Premium quota exhausted on pinned model; "
"auto-fallback switched to a free model"
),
extra={
"fallback_config_id": llm_config_id,
"auto_fallback": True,
},
)
else:
yield emit_stream_error(
message=(
"Buy more tokens to continue with this model, or "
"switch to a free model"
),
error_kind="premium_quota_exhausted",
error_code="PREMIUM_QUOTA_EXHAUSTED",
severity="info",
is_expected=True,
extra={
"resolved_config_id": llm_config_id,
"auto_fallback": False,
},
)
yield streaming_service.format_done()
return
if not llm:
yield emit_stream_error(
message="Failed to create LLM instance",
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
# --- Block 2: Spawn concurrent persistence; build pre-stream setup ---
persist_user_task = spawn_persist_user_task(
chat_id=chat_id,
user_id=user_id,
turn_id=stream_result.turn_id,
user_query=user_query,
user_image_data_urls=user_image_data_urls,
mentioned_documents=mentioned_documents,
background_tasks=_background_tasks,
)
persist_asst_task = spawn_persist_assistant_shell_task(
chat_id=chat_id,
user_id=user_id,
turn_id=stream_result.turn_id,
background_tasks=_background_tasks,
)
_t0 = time.perf_counter()
connector_service, firecrawl_api_key = await setup_connector_and_firecrawl(
session, search_space_id=search_space_id
)
_perf_log.info(
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
_t0 = time.perf_counter()
checkpointer = await get_chat_checkpointer()
_perf_log.info(
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
visibility = thread_visibility or ChatVisibility.PRIVATE
use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED)
chat_agent_mode = "multi" if use_multi_agent else "single"
set_agent_mode(chat_span, chat_agent_mode)
_t0 = time.perf_counter()
agent_factory = (
create_multi_agent_chat_deep_agent
if use_multi_agent
else create_surfsense_deep_agent
)
# Build the agent inline. Provider 429s surface through the in-stream
# recovery loop below, which repins the thread to an eligible
# alternative config and rebuilds the agent before the user sees any
# output.
agent = await build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
)
_perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
)
# --- Block 3: Input assembly ---
_t0 = time.perf_counter()
assembled = await build_new_chat_input_state(
session,
chat_id=chat_id,
search_space_id=search_space_id,
user_query=user_query,
user_image_data_urls=user_image_data_urls,
mentioned_document_ids=mentioned_document_ids,
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
mentioned_folder_ids=mentioned_folder_ids,
mentioned_documents=mentioned_documents,
needs_history_bootstrap=needs_history_bootstrap,
thread_visibility=visibility,
current_user_display_name=current_user_display_name,
filesystem_mode=fs_mode,
request_id=request_id,
turn_id=stream_result.turn_id,
)
input_state = assembled.input_state
accepted_folder_ids = assembled.accepted_folder_ids
mentioned_surfsense_docs = assembled.mentioned_surfsense_docs
_perf_log.info(
"[stream_new_chat] History bootstrap + doc/report queries in %.3fs",
time.perf_counter() - _t0,
)
# All pre-streaming DB reads done. Commit to release the transaction
# and its ACCESS SHARE locks so we don't block DDL (e.g. migrations)
# for the entire LLM streaming duration. Tools that need DB access
# during streaming start their own short-lived transactions (or use
# isolated sessions).
await session.commit()
# Detach heavy ORM objects (documents with chunks, reports, etc.)
# from the session identity map now that we've extracted what we
# need. Without this they accumulate in memory for the entire
# streaming duration (which can be several minutes).
session.expunge_all()
_perf_log.info(
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
time.perf_counter() - _t_total,
chat_id,
)
configurable: dict[str, Any] = {
"thread_id": str(chat_id),
"request_id": request_id or "unknown",
"turn_id": stream_result.turn_id,
}
if checkpoint_id:
configurable["checkpoint_id"] = checkpoint_id
config = {
"configurable": configurable,
# Effectively uncapped, matching the agent-level ``with_config``
# default in ``chat_deepagent.create_agent`` and the unbounded
# ``while(true)`` in OpenCode's ``session/processor.ts``. Real
# circuit-breakers live in middleware (``DoomLoopMiddleware``,
# plus ``enable_tool_call_limit`` / ``enable_model_call_limit``).
# The original 25 (and our previous 80 bump) hit users on
# legitimate multi-tool plans.
"recursion_limit": 10_000,
}
# --- Block 4: First SSE frames ---
for sse in iter_initial_frames(streaming_service, turn_id=stream_result.turn_id):
yield sse
# --- Block 5: Persistence join + message-id frames ---
user_message_id = await await_persist_task(
persist_user_task,
chat_id=chat_id,
turn_id=stream_result.turn_id,
log_label="persist_user_task",
)
if user_message_id is None:
yield emit_stream_error(
message="We couldn't save your message. Please try again in a moment.",
error_kind="server_error",
error_code="MESSAGE_PERSIST_FAILED",
)
for sse in iter_final_frames(streaming_service):
yield sse
return
# Emit canonical user message id BEFORE any LLM streaming so the FE
# can rename its optimistic ``msg-user-XXX`` placeholder to
# ``msg-{user_message_id}`` and unlock features gated on a real DB id
# (comments, edit-from-this-message). See B4 in the
# ``sse-based_message_id_handshake`` plan.
yield streaming_service.format_data(
"user-message-id",
{"message_id": user_message_id, "turn_id": stream_result.turn_id},
)
assistant_message_id = await await_persist_task(
persist_asst_task,
chat_id=chat_id,
turn_id=stream_result.turn_id,
log_label="persist_asst_task",
)
if assistant_message_id is None:
# Genuine DB failure — abort the turn rather than stream into a
# void. The user row is already persisted so the legacy
# ghost-thread gate isn't reopened.
yield emit_stream_error(
message=(
"We couldn't initialize the assistant message. Please try again."
),
error_kind="server_error",
error_code="MESSAGE_PERSIST_FAILED",
)
for sse in iter_final_frames(streaming_service):
yield sse
return
yield streaming_service.format_data(
"assistant-message-id",
{"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
)
stream_result.assistant_message_id = assistant_message_id
stream_result.content_builder = AssistantContentBuilder()
# --- Block 6: Initial thinking step + title task + runtime context ---
initial_step = build_initial_thinking_step(
user_query=user_query,
user_image_data_urls=user_image_data_urls,
mentioned_surfsense_docs=mentioned_surfsense_docs,
)
for sse in iter_initial_thinking_step_frame(
initial_step,
streaming_service=streaming_service,
content_builder=stream_result.content_builder,
):
yield sse
initial_step_id = initial_step.step_id
initial_step_title = initial_step.title
initial_step_items = initial_step.items
# Drop the heavy ORM objects + the container that holds them so they
# aren't retained for the entire streaming duration. ``input_state``
# already carries the langchain_messages list independently.
del assembled, mentioned_surfsense_docs
title_task = spawn_title_task(
chat_id=chat_id,
user_query=user_query,
user_image_data_urls=user_image_data_urls,
assistant_message_id=assistant_message_id,
llm=llm,
agent_config=agent_config,
)
title_emitted = False
runtime_context = build_new_chat_runtime_context(
search_space_id=search_space_id,
mentioned_document_ids=mentioned_document_ids,
accepted_folder_ids=accepted_folder_ids,
mentioned_folder_ids=mentioned_folder_ids,
request_id=request_id,
turn_id=stream_result.turn_id,
)
# --- Block 7: Stream loop ---
_t_stream_start = time.perf_counter()
runtime_rate_limit_recovered = False
def _on_first_event() -> None:
_perf_log.info(
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
"%.3fs (total since request start) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
async def _recover(exc: BaseException, first_event_seen: bool):
nonlocal llm_config_id, llm, agent_config, runtime_rate_limit_recovered
nonlocal title_task
if not can_recover_provider_rate_limit(
exc,
first_event_seen=first_event_seen,
runtime_rate_limit_recovered=runtime_rate_limit_recovered,
requested_llm_config_id=requested_llm_config_id,
current_llm_config_id=llm_config_id,
):
return None
runtime_rate_limit_recovered = True
previous_config_id = llm_config_id
llm_config_id = await reroute_to_next_auto_pin(
session,
chat_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
current_llm_config_id=llm_config_id,
requires_image_input=requires_image_input,
)
new_llm, new_agent_config, llm_load_err = await load_llm_bundle(
session, config_id=llm_config_id, search_space_id=search_space_id
)
if llm_load_err:
# Re-raise the original so the terminal-error path classifies
# it correctly (don't swallow as "config load error").
return None
llm = new_llm
agent_config = new_agent_config
# Title gen used the initial llm object. After a runtime repin we
# keep the stream focused on response recovery and skip title gen
# for this turn.
if title_task is not None and not title_task.done():
title_task.cancel()
title_task = None
_t_rebuild = time.perf_counter()
new_agent = await build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
)
_perf_log.info(
"[stream_new_chat] Runtime rate-limit recovery repinned "
"config_id=%s -> %s and rebuilt agent in %.3fs",
previous_config_id,
llm_config_id,
time.perf_counter() - _t_rebuild,
)
log_rate_limit_recovered(
flow=flow,
request_id=request_id,
chat_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
previous_config_id=previous_config_id,
new_config_id=llm_config_id,
)
return new_agent
async for sse in run_stream_loop(
agent=agent,
streaming_service=streaming_service,
config=config,
input_data=input_state,
stream_result=stream_result,
step_prefix="thinking",
initial_step_id=initial_step_id,
initial_step_title=initial_step_title,
initial_step_items=initial_step_items,
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode if filesystem_selection else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
runtime_context=runtime_context,
content_builder=stream_result.content_builder,
recover=_recover,
on_first_event=_on_first_event,
):
yield sse
# Inject the title update mid-stream as soon as the background
# task finishes; gated so we emit at most once.
async for title_sse in maybe_emit_title_update(
title_task=title_task,
title_emitted=title_emitted,
chat_id=chat_id,
accumulator=accumulator,
streaming_service=streaming_service,
):
yield title_sse
title_emitted = True
# Account for the case where the task completed but produced no
# title — flip the flag anyway so we don't keep checking it.
if (
title_task is not None
and title_task.done()
and not title_emitted
):
title_emitted = True
_perf_log.info(
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
time.perf_counter() - _t_stream_start,
chat_id,
)
log_system_snapshot("stream_new_chat_END")
# --- Block 8: Finalize ---
if stream_result.is_interrupted:
ot.add_event("chat.interrupted", {"chat.flow": flow})
if title_task is not None and not title_task.done():
title_task.cancel()
for sse in iter_token_usage_frame(
streaming_service,
accumulator=accumulator,
log_label="interrupted new_chat",
):
yield sse
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
return
async for title_sse in await_pending_title_update(
title_task=title_task,
title_emitted=title_emitted,
chat_id=chat_id,
accumulator=accumulator,
streaming_service=streaming_service,
):
yield title_sse
# Finalize premium credit debit with the actual provider cost reported
# by LiteLLM, summed across every call in the turn. Mirrors the
# pre-cost behaviour of "premium turn → all calls count" so free
# sub-agent calls during a premium turn still contribute to the bill
# (they're $0 in practice anyway).
if premium_reservation is not None and user_id:
await finalize_premium(
reservation=premium_reservation,
user_id=user_id,
accumulator=accumulator,
)
premium_reservation = None
for sse in iter_token_usage_frame(
streaming_service, accumulator=accumulator, log_label="normal new_chat"
):
yield sse
for sse in iter_final_frames(streaming_service):
yield sse
except Exception as exc:
frames, summary = handle_terminal_exception(
exc,
flow=flow,
flow_label="chat",
log_prefix="stream_new_chat",
streaming_service=streaming_service,
request_id=request_id,
chat_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
chat_span=chat_span,
)
if summary["busy_error_raised"]:
busy_error_raised = True
chat_outcome = summary["chat_outcome"]
chat_error_category = summary["chat_error_category"]
for sse in frames:
yield sse
finally:
# Shield the ENTIRE async cleanup from anyio cancel-scope cancellation.
# Starlette's BaseHTTPMiddleware uses anyio task groups; on client
# disconnect, it cancels the scope with level-triggered cancellation
# — every unshielded ``await`` would raise CancelledError immediately.
# Without this the very first ``await`` (session.rollback) would
# raise, ``except Exception`` wouldn't catch it (CancelledError is a
# BaseException), and the rest of cleanup — including session.close()
# — would never run.
with anyio.CancelScope(shield=True):
# Authoritative fallback cleanup for lock/cancel state. Middleware
# teardown can be skipped on some client-abort paths.
end_turn(str(chat_id))
if premium_reservation is not None and user_id:
await release_premium(
reservation=premium_reservation, user_id=user_id
)
await close_session_and_clear_ai_responding(session, chat_id)
await finalize_assistant_message(
stream_result=stream_result,
chat_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
accumulator=accumulator,
log_prefix="stream_new_chat",
)
# Persist any sandbox-produced files to local storage so they remain
# downloadable after the Daytona sandbox auto-deletes.
if stream_result and stream_result.sandbox_files:
with contextlib.suppress(Exception):
from app.agents.new_chat.sandbox import (
is_sandbox_enabled,
persist_and_delete_sandbox,
)
if is_sandbox_enabled():
with anyio.CancelScope(shield=True):
await persist_and_delete_sandbox(
chat_id, stream_result.sandbox_files
)
# ``aafter_agent`` doesn't fire on ``interrupt()`` or early bailout.
# Skip on ``BusyError`` (caller never acquired the lock).
if not busy_error_raised:
with contextlib.suppress(Exception):
end_turn(str(chat_id))
_perf_log.info(
"[stream_new_chat] end_turn cleanup (chat_id=%s)", chat_id
)
# Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass.
agent = llm = connector_service = None # noqa: F841
input_state = stream_result = None # noqa: F841
session = None # noqa: F841
run_gc_pass(log_prefix="stream_new_chat", chat_id=chat_id)
close_chat_request_span(
span_cm=chat_span_cm,
span=chat_span,
chat_outcome=chat_outcome,
chat_agent_mode=chat_agent_mode,
flow=flow,
chat_error_category=chat_error_category,
duration_seconds=time.perf_counter() - _t_total,
)

View file

@ -0,0 +1,129 @@
"""Concurrent persistence tasks spawned right after the initial validation gate.
These run *during* the rest of the pre-stream setup so we don't serialize
their latency against agent construction. Awaiting them at the SSE message-id
yield sites preserves the ghost-thread protection (the user-row INSERT must
succeed before any LLM streaming begins).
The ``set_ai_responding`` flag flip runs fully fire-and-forget on its own
shielded session — failures only delay the "AI is responding…" UI flag, not
the response itself.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from uuid import UUID
from app.db import shielded_async_session
from app.services.chat_session_state_service import set_ai_responding
from app.tasks.chat.persistence import (
persist_assistant_shell,
persist_user_turn,
)
logger = logging.getLogger(__name__)
def spawn_set_ai_responding_bg(
    *,
    chat_id: int,
    user_id: str | None,
    background_tasks: set[asyncio.Task[Any]],
) -> None:
    """Schedule a best-effort write of the per-thread AI-responding flag.

    Nothing is scheduled when ``user_id`` is falsy. Otherwise a task is
    created that opens its own ``shielded_async_session`` and calls
    ``set_ai_responding``. Any ``Exception`` raised inside that task
    (including an unparsable ``user_id``) is logged as a warning and not
    re-raised, so a failed flag write never reaches the caller.

    Args:
        chat_id: Thread whose flag is being set.
        user_id: String form of the user's UUID, or ``None``.
        background_tasks: Set that holds the task until it completes; the
            task removes itself from the set via a done-callback.
    """
    if not user_id:
        return

    async def _flip_flag() -> None:
        # ``UUID(user_id)`` is evaluated inside the ``try`` on purpose: a
        # malformed id is handled like any other failure of the flag write.
        try:
            async with shielded_async_session() as flag_session:
                await set_ai_responding(flag_session, chat_id, UUID(user_id))
        except Exception:
            logger.warning(
                "set_ai_responding failed (chat_id=%s)",
                chat_id,
                exc_info=True,
            )

    flag_task = asyncio.create_task(_flip_flag())
    background_tasks.add(flag_task)
    flag_task.add_done_callback(background_tasks.discard)
def spawn_persist_user_task(
    *,
    chat_id: int,
    user_id: str | None,
    turn_id: str,
    user_query: str,
    user_image_data_urls: list[str] | None,
    mentioned_documents: list[dict[str, Any]] | None,
    background_tasks: set[asyncio.Task[Any]],
) -> asyncio.Task[int | None]:
    """Start ``persist_user_turn`` as a task and return that task.

    The turn fields are forwarded to ``persist_user_turn`` unchanged. The
    task is tracked in ``background_tasks`` until it finishes; the caller is
    expected to join it at the user-message-id yield site.

    Returns:
        The running task; its result type is ``int | None``.
    """
    insert_coro = persist_user_turn(
        chat_id=chat_id,
        user_id=user_id,
        turn_id=turn_id,
        user_query=user_query,
        user_image_data_urls=user_image_data_urls,
        mentioned_documents=mentioned_documents,
    )
    user_row_task = asyncio.create_task(insert_coro)
    background_tasks.add(user_row_task)
    # Drop the task from the tracking set as soon as it completes.
    user_row_task.add_done_callback(background_tasks.discard)
    return user_row_task
def spawn_persist_assistant_shell_task(
    *,
    chat_id: int,
    user_id: str | None,
    turn_id: str,
    background_tasks: set[asyncio.Task[Any]],
) -> asyncio.Task[int | None]:
    """Start ``persist_assistant_shell`` as a task and return that task.

    The task is tracked in ``background_tasks`` until it finishes; the
    caller is expected to join it at the assistant-message-id yield site.

    Returns:
        The running task; its result type is ``int | None``.
    """
    shell_coro = persist_assistant_shell(
        chat_id=chat_id,
        user_id=user_id,
        turn_id=turn_id,
    )
    shell_task = asyncio.create_task(shell_coro)
    background_tasks.add(shell_task)
    # Drop the task from the tracking set as soon as it completes.
    shell_task.add_done_callback(background_tasks.discard)
    return shell_task
async def await_persist_task(
    task: asyncio.Task[int | None] | None,
    *,
    chat_id: int,
    turn_id: str,
    log_label: str,
) -> int | None:
    """Wait for a spawned persistence task and normalise failures to ``None``.

    The task is awaited through ``asyncio.shield`` so that cancelling the
    awaiting caller does not cancel ``task`` itself.

    Args:
        task: Task to join, or ``None`` when nothing was spawned.
        chat_id: Included in the failure log line.
        turn_id: Included in the failure log line.
        log_label: Prefix identifying which write failed in the log line.

    Returns:
        The task's result, or ``None`` when ``task`` is ``None`` or the task
        raised an ``Exception`` (which is logged with its traceback).

    Raises:
        asyncio.CancelledError: Propagated untouched.
    """
    if task is None:
        return None

    row_id: int | None = None
    try:
        row_id = await asyncio.shield(task)
    except asyncio.CancelledError:
        # Cancellation is never converted into a "failed write" result.
        raise
    except Exception:
        logger.exception(
            "%s failed (chat_id=%s, turn_id=%s)", log_label, chat_id, turn_id
        )
    return row_id

View file

@ -0,0 +1,38 @@
"""Build the per-invocation ``SurfSenseContextSchema`` for a new-chat turn.
Carries the per-turn read inputs that middlewares read via
``runtime.context.*`` instead of from their ``__init__`` closures, so the same
compiled-agent instance can serve multiple turns with different
mention lists / request ids / turn ids without rebuilding the graph.
"""
from __future__ import annotations
from app.agents.new_chat.context import SurfSenseContextSchema
def build_new_chat_runtime_context(
    *,
    search_space_id: int,
    mentioned_document_ids: list[int] | None,
    accepted_folder_ids: list[int],
    mentioned_folder_ids: list[int] | None,
    request_id: str | None,
    turn_id: str,
) -> SurfSenseContextSchema:
    """Assemble the per-turn ``SurfSenseContextSchema``.

    ``mentioned_document_ids`` is consumed by ``KnowledgePriorityMiddleware``.

    ``accepted_folder_ids`` (post-resolve) wins over the raw
    ``mentioned_folder_ids`` from the request: the resolver drops chips that
    pointed at deleted folders or folders the caller can't see, so middlewares
    only get authorized ids. An *empty* accepted list is therefore
    authoritative ("nothing survived the resolver") and must not fall back to
    the raw request ids; the raw list is used only when no resolver output is
    available at all (``accepted_folder_ids is None``).

    Args:
        search_space_id: Search space the turn runs in.
        mentioned_document_ids: Document ids mentioned in the request, if any.
        accepted_folder_ids: Folder ids that survived the resolver.
        mentioned_folder_ids: Raw folder ids from the request, if any.
        request_id: Request correlation id, if any.
        turn_id: Identifier of the current turn.

    Returns:
        A context object holding fresh list copies of the id collections.
    """
    # ``accepted_folder_ids or mentioned_folder_ids`` would treat ``[]`` as
    # "missing" and re-admit ids the resolver had just rejected, so the
    # fallback is keyed on ``None`` rather than on truthiness.
    if accepted_folder_ids is not None:
        folder_ids = list(accepted_folder_ids)
    else:
        folder_ids = list(mentioned_folder_ids or [])

    return SurfSenseContextSchema(
        search_space_id=search_space_id,
        mentioned_document_ids=list(mentioned_document_ids or []),
        mentioned_folder_ids=folder_ids,
        request_id=request_id,
        turn_id=turn_id,
    )

View file

@ -0,0 +1,237 @@
"""Background thread-title generation (first-response only).
The first assistant response in a thread gets a short auto-generated title
inserted into ``new_chat_threads.title``. We:
1. Spawn the generation as an ``asyncio.Task`` so it runs in parallel with
the agent stream (no extra TTFT).
2. Probe inside the task (on its own shielded session) whether this is
actually the first response — newer turns short-circuit to ``None``.
3. Inject the resulting ``thread-title-update`` SSE frame on the first agent
event after the task completes (mid-stream interlock), or right before
the finish frames (post-stream join) if the task hadn't finished yet.
Usage tokens come directly off the response (LiteLLM's async callback fires
via fire-and-forget ``create_task``, so the ``TokenTrackingCallback`` would
run too late). We also blank the per-task accumulator so the late callback
doesn't double-count.
"""
from __future__ import annotations
import asyncio
import logging
from typing import TYPE_CHECKING, Any
from sqlalchemy.future import select
from app.db import NewChatMessage, NewChatThread, shielded_async_session
from app.prompts import TITLE_GENERATION_PROMPT
from app.services.new_streaming_service import VercelStreamingService
if TYPE_CHECKING:
from app.agents.new_chat.llm_config import AgentConfig
from app.services.token_tracking_service import TokenAccumulator
logger = logging.getLogger(__name__)
def spawn_title_task(
    *,
    chat_id: int,
    user_query: str,
    user_image_data_urls: list[str] | None,
    assistant_message_id: int | None,
    llm: Any,
    agent_config: AgentConfig | None,
) -> asyncio.Task[tuple[str | None, dict | None]] | None:
    """Start background title generation, or return ``None`` if not eligible.

    A task running ``_generate_title`` is created only when
    ``assistant_message_id`` is not ``None``. Gating on a persisted assistant
    row means a stream that aborts before persistence cannot end up with a
    titled thread that has no anchoring rows.

    Returns:
        The title task (resolving to ``(title, usage)``), or ``None`` when
        ``assistant_message_id`` is ``None``.
    """
    if assistant_message_id is None:
        return None

    title_coro = _generate_title(
        chat_id=chat_id,
        user_query=user_query,
        user_image_data_urls=user_image_data_urls,
        assistant_message_id=assistant_message_id,
        llm=llm,
        agent_config=agent_config,
    )
    return asyncio.create_task(title_coro)
async def _generate_title(
    *,
    chat_id: int,
    user_query: str,
    user_image_data_urls: list[str] | None,
    assistant_message_id: int,
    llm: Any,
    agent_config: AgentConfig | None,
) -> tuple[str | None, dict | None]:
    """Probe is-first-response, then call ``acompletion``. Returns ``(title, usage)``.

    Args:
        chat_id: Thread to title.
        user_query: The user's message; its first 500 characters seed the prompt.
        user_image_data_urls: Used only to build a placeholder seed when the
            query text is blank.
        assistant_message_id: This turn's own assistant row, excluded from
            the first-response probe.
        llm: Object whose ``model`` / ``api_key`` / ``api_base`` attributes
            are read via ``getattr``.
        agent_config: Source of the ``provider`` value for api-base
            resolution, or ``None``.

    Returns:
        ``(title, usage)``. ``title`` is ``None`` when this is not the first
        response, the probe fails, the model returns no usable text, or the
        text is longer than 100 characters. ``usage`` is a dict with
        ``model`` / ``prompt_tokens`` / ``completion_tokens`` /
        ``total_tokens`` when the response reported usage, else ``None``.
        Any unexpected ``Exception`` is logged and yields ``(None, None)``.
    """
    try:
        from litellm import acompletion

        from app.services.llm_router_service import LLMRouterService
        from app.services.provider_api_base import resolve_api_base
        from app.services.token_tracking_service import _turn_accumulator

        # Excludes this turn's own assistant row (pre-written by
        # ``persist_assistant_shell``) — without the ``!=`` filter the gate
        # would false-negative on every turn after the first.
        try:
            async with shielded_async_session() as probe_session:
                probe_result = await probe_session.execute(
                    select(NewChatMessage.id)
                    .filter(
                        NewChatMessage.thread_id == chat_id,
                        NewChatMessage.role == "assistant",
                        NewChatMessage.id != assistant_message_id,
                    )
                    .limit(1)
                )
                is_first_response = probe_result.scalars().first() is None
        except Exception:
            logger.warning(
                "[TitleGen] first-response probe failed (chat_id=%s)",
                chat_id,
                exc_info=True,
            )
            return None, None

        if not is_first_response:
            return None, None

        # Blank the per-task accumulator so the late token-tracking callback
        # doesn't double-count; usage is read off the response below instead.
        _turn_accumulator.set(None)

        title_seed = user_query.strip() or (
            f"[{len(user_image_data_urls or [])} image(s)]"
            if user_image_data_urls
            else ""
        )
        prompt = TITLE_GENERATION_PROMPT.replace(
            "{user_query}", title_seed[:500] or "(message)"
        )
        messages = [{"role": "user", "content": prompt}]

        # Read the model name once; both the routing decision and the usage
        # record below derive from it.
        raw_model = getattr(llm, "model", "") or ""

        if raw_model == "auto":
            router = LLMRouterService.get_router()
            response = await router.acompletion(model="auto", messages=messages)
        else:
            # Apply the same ``api_base`` cascade chat / vision / image-gen
            # call sites use so we never inherit ``litellm.api_base``
            # (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
            # config itself ships an empty ``api_base``. Without this the
            # title-gen on an OpenRouter chat config would 404 against the
            # inherited Azure endpoint — see ``provider_api_base`` for the
            # same bug repro on the image-gen / vision paths.
            provider_prefix = (
                raw_model.split("/", 1)[0] if "/" in raw_model else None
            )
            provider_value = (
                agent_config.provider if agent_config is not None else None
            )
            title_api_base = resolve_api_base(
                provider=provider_value,
                provider_prefix=provider_prefix,
                config_api_base=getattr(llm, "api_base", None),
            )
            response = await acompletion(
                model=raw_model,
                messages=messages,
                api_key=getattr(llm, "api_key", None),
                api_base=title_api_base,
            )

        usage_info = None
        usage = getattr(response, "usage", None)
        if usage:
            model_name = (
                raw_model.split("/", 1)[-1]
                if "/" in raw_model
                else (raw_model or response.model or "unknown")
            )
            usage_info = {
                "model": model_name,
                "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
                "completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
                "total_tokens": getattr(usage, "total_tokens", 0) or 0,
            }

        # A response with no choices or ``content=None`` is treated as "no
        # title" rather than an error: raising here would fall into the outer
        # ``except`` and discard ``usage_info`` for tokens already spent.
        choices = getattr(response, "choices", None) or []
        content = choices[0].message.content if choices else None
        raw_title = (content or "").strip()

        if raw_title and len(raw_title) <= 100:
            return raw_title.strip("\"'"), usage_info
        return None, usage_info
    except Exception:
        logger.exception("[TitleGen] _generate_title failed")
        return None, None
async def maybe_emit_title_update(
    *,
    title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
    title_emitted: bool,
    chat_id: int,
    accumulator: TokenAccumulator,
    streaming_service: VercelStreamingService,
):
    """Emit a ``thread-title-update`` frame if the title task has finished.

    Non-blocking: when there is no task, a title was already emitted, or the
    task is still running, the generator ends without yielding. Otherwise
    the task's usage (if any) is added to ``accumulator`` and, when a title
    was produced, it is written to the thread row and one SSE frame is
    yielded.

    The caller owns ``title_emitted`` and flips it after iterating; this
    generator only reads it.
    """
    task_ready = (
        title_task is not None and not title_emitted and title_task.done()
    )
    if not task_ready:
        return

    new_title, usage_kwargs = title_task.result()
    if usage_kwargs:
        accumulator.add(**usage_kwargs)
    if not new_title:
        return

    async with shielded_async_session() as db:
        lookup = await db.execute(
            select(NewChatThread).filter(NewChatThread.id == chat_id)
        )
        thread_row = lookup.scalars().first()
        # A missing thread row skips the write; the frame is still emitted.
        if thread_row:
            thread_row.title = new_title
            await db.commit()
    yield streaming_service.format_thread_title_update(chat_id, new_title)
async def await_pending_title_update(
    *,
    title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
    title_emitted: bool,
    chat_id: int,
    accumulator: TokenAccumulator,
    streaming_service: VercelStreamingService,
):
    """Join the title task and emit its ``thread-title-update`` frame.

    Blocking counterpart of ``maybe_emit_title_update``, used right before
    the finish frames on the success path. When there is no task or a title
    was already emitted the generator ends immediately; otherwise the task
    is awaited, its usage (if any) is added to ``accumulator`` and, when a
    title was produced, it is written to the thread row and one SSE frame is
    yielded.
    """
    nothing_pending = title_task is None or title_emitted
    if nothing_pending:
        return

    new_title, usage_kwargs = await title_task
    if usage_kwargs:
        accumulator.add(**usage_kwargs)
    if not new_title:
        return

    async with shielded_async_session() as db:
        lookup = await db.execute(
            select(NewChatThread).filter(NewChatThread.id == chat_id)
        )
        thread_row = lookup.scalars().first()
        # A missing thread row skips the write; the frame is still emitted.
        if thread_row:
            thread_row.title = new_title
            await db.commit()
    yield streaming_service.format_thread_title_update(chat_id, new_title)

View file

@ -0,0 +1,12 @@
"""Resume-chat streaming flow.
Public entry point ``stream_resume_chat`` is the slim coroutine in
``orchestrator.py`` that composes the per-concern modules in this folder and
the building blocks under ``flows/shared/``.
"""
from __future__ import annotations
from app.tasks.chat.streaming.flows.resume_chat.orchestrator import stream_resume_chat
__all__ = ["stream_resume_chat"]

Some files were not shown because too many files have changed in this diff Show more