mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
Merge pull request #1443 from CREDO23/feature-automations
[Feat] Automation V1 — Scheduled Agent Tasks, Created via Chat (HITL) or JSON
This commit is contained in:
commit
4dda02c06c
219 changed files with 13821 additions and 55 deletions
179
surfsense_backend/alembic/versions/144_add_automation_tables.py
Normal file
179
surfsense_backend/alembic/versions/144_add_automation_tables.py
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
"""Add automation tables (automations, automation_triggers, automation_runs)
|
||||
|
||||
Revision ID: 144
|
||||
Revises: 143
|
||||
Create Date: 2026-05-26
|
||||
|
||||
Adds the three tables that back the v1 automation engine, plus the
|
||||
three PostgreSQL ENUM types they reference. Matches the SQLAlchemy
|
||||
models under ``app.automations.persistence.models`` and the v1 data
|
||||
model in ``automation-design-plan.md`` §9.
|
||||
|
||||
v1 ships these three tables only. ``domain_events`` is deferred to
|
||||
Phase 3 with the event trigger; ``mcp_connections`` / ``mcp_tools``
|
||||
are deferred to Phase 4 with the MCP integration.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "144"
|
||||
down_revision: str | None = "143"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ENUM types (PostgreSQL requires types created before tables that use them)
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TYPE automation_status AS ENUM (
|
||||
'active', 'paused', 'archived'
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TYPE automation_trigger_type AS ENUM (
|
||||
'schedule', 'manual'
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TYPE automation_run_status AS ENUM (
|
||||
'pending', 'running', 'succeeded', 'failed',
|
||||
'cancelled', 'timed_out'
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# automations — the editable, versioned automation definition
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE automations (
|
||||
id SERIAL PRIMARY KEY,
|
||||
search_space_id INTEGER NOT NULL
|
||||
REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||
created_by_user_id UUID
|
||||
REFERENCES "user"(id) ON DELETE SET NULL,
|
||||
name VARCHAR(200) NOT NULL,
|
||||
description TEXT,
|
||||
status automation_status NOT NULL DEFAULT 'active',
|
||||
definition JSONB NOT NULL,
|
||||
version INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX ix_automations_search_space_id ON automations(search_space_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX ix_automations_created_by_user_id ON automations(created_by_user_id);"
|
||||
)
|
||||
op.execute("CREATE INDEX ix_automations_status ON automations(status);")
|
||||
op.execute("CREATE INDEX ix_automations_created_at ON automations(created_at);")
|
||||
op.execute("CREATE INDEX ix_automations_updated_at ON automations(updated_at);")
|
||||
|
||||
# automation_triggers — one row per (automation, trigger-instance) pair
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE automation_triggers (
|
||||
id SERIAL PRIMARY KEY,
|
||||
automation_id INTEGER NOT NULL
|
||||
REFERENCES automations(id) ON DELETE CASCADE,
|
||||
type automation_trigger_type NOT NULL,
|
||||
params JSONB NOT NULL,
|
||||
static_inputs JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
last_fired_at TIMESTAMP WITH TIME ZONE,
|
||||
next_fire_at TIMESTAMP WITH TIME ZONE,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX ix_automation_triggers_automation_id ON automation_triggers(automation_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX ix_automation_triggers_type ON automation_triggers(type);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX ix_automation_triggers_enabled ON automation_triggers(enabled);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX ix_automation_triggers_created_at ON automation_triggers(created_at);"
|
||||
)
|
||||
# Partial index for the schedule tick: only enabled schedule triggers
|
||||
# with a scheduled next fire are ever scanned for due rows.
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX ix_automation_triggers_due
|
||||
ON automation_triggers (next_fire_at)
|
||||
WHERE enabled = true
|
||||
AND type = 'schedule'
|
||||
AND next_fire_at IS NOT NULL;
|
||||
"""
|
||||
)
|
||||
|
||||
# automation_runs — the immutable per-fire execution record
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE automation_runs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
automation_id INTEGER NOT NULL
|
||||
REFERENCES automations(id) ON DELETE CASCADE,
|
||||
trigger_id INTEGER
|
||||
REFERENCES automation_triggers(id) ON DELETE SET NULL,
|
||||
status automation_run_status NOT NULL DEFAULT 'pending',
|
||||
definition_snapshot JSONB NOT NULL,
|
||||
inputs JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
step_results JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
output JSONB,
|
||||
artifacts JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
error JSONB,
|
||||
started_at TIMESTAMP WITH TIME ZONE,
|
||||
finished_at TIMESTAMP WITH TIME ZONE,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX ix_automation_runs_automation_id ON automation_runs(automation_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX ix_automation_runs_trigger_id ON automation_runs(trigger_id);"
|
||||
)
|
||||
op.execute("CREATE INDEX ix_automation_runs_status ON automation_runs(status);")
|
||||
op.execute(
|
||||
"CREATE INDEX ix_automation_runs_created_at ON automation_runs(created_at);"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS ix_automation_runs_created_at;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_automation_runs_status;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_automation_runs_trigger_id;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_automation_runs_automation_id;")
|
||||
op.execute("DROP TABLE IF EXISTS automation_runs;")
|
||||
|
||||
op.execute("DROP INDEX IF EXISTS ix_automation_triggers_due;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_automation_triggers_created_at;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_automation_triggers_enabled;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_automation_triggers_type;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_automation_triggers_automation_id;")
|
||||
op.execute("DROP TABLE IF EXISTS automation_triggers;")
|
||||
|
||||
op.execute("DROP INDEX IF EXISTS ix_automations_updated_at;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_automations_created_at;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_automations_status;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_automations_created_by_user_id;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_automations_search_space_id;")
|
||||
op.execute("DROP TABLE IF EXISTS automations;")
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS automation_run_status;")
|
||||
op.execute("DROP TYPE IF EXISTS automation_trigger_type;")
|
||||
op.execute("DROP TYPE IF EXISTS automation_status;")
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
"""Add automations permissions to existing Editor/Viewer roles
|
||||
|
||||
Revision ID: 145
|
||||
Revises: 144
|
||||
Create Date: 2026-05-27
|
||||
|
||||
Owners already have ``*`` and need no backfill. Custom (non-system) roles
|
||||
are left untouched on purpose: workspace admins manage those explicitly.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "145"
|
||||
down_revision: str | None = "144"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
_EDITOR_PERMISSIONS = (
|
||||
"automations:create",
|
||||
"automations:read",
|
||||
"automations:update",
|
||||
"automations:execute",
|
||||
)
|
||||
_VIEWER_PERMISSIONS = ("automations:read",)
|
||||
|
||||
|
||||
def upgrade():
|
||||
connection = op.get_bind()
|
||||
|
||||
for permission in _EDITOR_PERMISSIONS:
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_append(permissions, :permission)
|
||||
WHERE name = 'Editor'
|
||||
AND NOT (:permission = ANY(permissions))
|
||||
"""
|
||||
),
|
||||
{"permission": permission},
|
||||
)
|
||||
|
||||
for permission in _VIEWER_PERMISSIONS:
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_append(permissions, :permission)
|
||||
WHERE name = 'Viewer'
|
||||
AND NOT (:permission = ANY(permissions))
|
||||
"""
|
||||
),
|
||||
{"permission": permission},
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
connection = op.get_bind()
|
||||
|
||||
for permission in _EDITOR_PERMISSIONS:
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_remove(permissions, :permission)
|
||||
WHERE name = 'Editor'
|
||||
"""
|
||||
),
|
||||
{"permission": permission},
|
||||
)
|
||||
|
||||
for permission in _VIEWER_PERMISSIONS:
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_remove(permissions, :permission)
|
||||
WHERE name = 'Viewer'
|
||||
"""
|
||||
),
|
||||
{"permission": permission},
|
||||
)
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""``create_automation`` — description + few-shot examples."""
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
- `create_automation` — Draft and author a new automation. You describe the
|
||||
user's intent; a focused drafter inside the tool turns it into the full
|
||||
automation JSON; the user sees a preview on an approval card and chooses
|
||||
approve or reject. All three phases happen in a single tool call.
|
||||
- Call when the user wants SurfSense to do something on its own: anything
|
||||
recurring or scheduled ("every morning…", "each Monday…", "weekly
|
||||
recap…").
|
||||
- Args:
|
||||
- `intent` (string): restate the user's request **concretely**, in one
|
||||
paragraph. Cover three things:
|
||||
- **What** should run (the action: summarize, recap, post, draft, …).
|
||||
- **When** it should run (schedule + timezone if the user mentioned one;
|
||||
otherwise leave the timezone for the drafter to default to UTC).
|
||||
- **Static values** the automation needs (folder ids, channel names,
|
||||
project keys, parent page ids, …) — list them with their values.
|
||||
If the user did NOT supply one the automation needs, say so
|
||||
explicitly ("the Notion parent page id was not specified") so the
|
||||
drafter leaves a placeholder.
|
||||
- Do NOT prompt the user to confirm before calling — the approval card
|
||||
IS the confirmation. The card shows a structured preview plus the raw
|
||||
JSON; it offers approve/reject only. If the user wants changes after
|
||||
seeing the draft, they reply in chat and you call this tool again with
|
||||
a refined `intent` — that's the edit path.
|
||||
- Returns:
|
||||
- `{status: "saved", automation_id, name}` — confirm briefly to the
|
||||
user ("Saved as automation #N — runs <when>."). Don't dump JSON back.
|
||||
- `{status: "rejected", message}` — the user declined on the card.
|
||||
Acknowledge once ("Understood, I didn't create it.") and stop. Do
|
||||
NOT retry or pitch variants without a fresh user request.
|
||||
- `{status: "invalid", issues, raw?}` — drafting/validation failed
|
||||
before the card was shown. Read the issues, refine your `intent`
|
||||
with the missing details, call again.
|
||||
- `{status: "error", message}` — surface the message verbatim and
|
||||
offer to retry.
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
<example>
|
||||
user: "Every weekday at 9am, summarize new documents in folder 12 and post the summary to Slack channel #daily-digest."
|
||||
→ create_automation(intent="Every weekday at 09:00 UTC, summarize documents added to folder_id=12 since the last run, then post the summary to Slack channel '#daily-digest'. Static inputs: folder_id=12, slack_channel='#daily-digest'.")
|
||||
tool returns: {"status": "saved", "automation_id": 42, "name": "Daily folder 12 digest"}
|
||||
(Reply briefly: "Saved as automation #42 — runs weekdays at 9am UTC.")
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: "Once a week on Mondays at 7am Paris time, draft a Notion page recapping last week's Jira tickets in project CORE."
|
||||
→ create_automation(intent="Every Monday at 07:00 Europe/Paris, read last week's Jira issues in project CORE, then draft a Notion page recapping them. Static inputs: jira_project_key='CORE'. The user did NOT specify which Notion page the recap should sit under — leave notion_parent_page_id as a placeholder.")
|
||||
tool returns: {"status": "saved", "automation_id": 51, "name": "Weekly CORE Jira recap"}
|
||||
(Reply: "Saved as automation #51. I left the Notion parent page id as a placeholder — set it on the automation before next Monday.")
|
||||
</example>
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""``create_automation`` — author + persist an automation via a HITL card."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .create import create_create_automation_tool
|
||||
|
||||
__all__ = ["create_create_automation_tool"]
|
||||
|
|
@ -0,0 +1,208 @@
|
|||
"""``create_automation`` — NL intent → drafted JSON → HITL approval card → persisted.
|
||||
|
||||
Single tool that:
|
||||
|
||||
1. Drafts a structured automation from the user's intent via a focused sub-LLM
|
||||
(system prompt in :mod:`.prompt`).
|
||||
2. Surfaces the validated draft in a HITL approval card
|
||||
(``action_type="automation_create"``).
|
||||
3. On approval, validates the (possibly edited) payload again and persists
|
||||
it via :class:`AutomationService`.
|
||||
|
||||
The main agent only restates the user's request as a single ``intent`` string.
|
||||
The drafting sub-LLM owns the JSON shape; the HITL card is the user's review.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
|
||||
request_approval,
|
||||
)
|
||||
from app.automations.schemas.api import AutomationCreate
|
||||
from app.automations.services.automation import AutomationService
|
||||
from app.db import User, async_session_maker
|
||||
from app.utils.content_utils import extract_text_content
|
||||
|
||||
from .prompt import build_draft_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_JSON_FENCE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)
|
||||
|
||||
|
||||
def create_create_automation_tool(
|
||||
*,
|
||||
search_space_id: int,
|
||||
user_id: str | UUID,
|
||||
llm: Any,
|
||||
):
|
||||
"""Factory for the ``create_automation`` tool.
|
||||
|
||||
``search_space_id`` is injected from the chat session (the model never
|
||||
has to guess it). ``llm`` is the drafting sub-model — we reuse the main
|
||||
agent's LLM and tag the call so it's identifiable in traces. A fresh
|
||||
``AsyncSession`` is opened per call to avoid stale sessions on
|
||||
compiled-agent cache hits (same pattern as the Notion / memory tools).
|
||||
"""
|
||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
@tool
|
||||
async def create_automation(intent: str, runtime: ToolRuntime) -> dict[str, Any]:
|
||||
"""Draft + save an automation from a natural-language intent.
|
||||
|
||||
Use this when the user wants SurfSense to do something on its own
|
||||
on a schedule (e.g. "every morning summarize folder 12 to Slack").
|
||||
Restate the user's request as ONE concrete ``intent`` string: what
|
||||
should run, when, and which static values (folder ids, channel
|
||||
names, …) it needs.
|
||||
|
||||
The tool drafts the full automation JSON internally, shows the user
|
||||
a structured preview on an approval card, and persists on approval.
|
||||
The card supports approve/reject only — if the user wants edits
|
||||
after seeing the draft, they say so in chat and you call this tool
|
||||
again with a refined intent. Do NOT prompt the user to confirm
|
||||
before calling — the card IS the confirmation.
|
||||
|
||||
Args:
|
||||
intent: Concrete restatement of the user's request. Include
|
||||
the schedule (with timezone if mentioned), the action to
|
||||
take, and any static values. Example: "Every weekday at
|
||||
09:00 UTC, summarize new docs added to folder_id=12 since
|
||||
the last run, then post the summary to Slack channel
|
||||
'#daily-digest'."
|
||||
|
||||
Returns:
|
||||
``{"status": "saved", "automation_id": int, "name": str}`` on
|
||||
approval + save.
|
||||
``{"status": "rejected", "message": "..."}`` when the user
|
||||
declines on the card.
|
||||
``{"status": "invalid", "issues": [...], "raw": ...}`` when
|
||||
the drafter produced output that did not validate (call again
|
||||
with a more precise intent).
|
||||
``{"status": "error", "message": "..."}`` on drafter or
|
||||
persistence failure.
|
||||
|
||||
IMPORTANT: when status is ``"rejected"`` the user explicitly
|
||||
declined. Acknowledge once and stop — do NOT retry or pitch
|
||||
variants without a fresh user request.
|
||||
"""
|
||||
# --- 1. Draft via sub-LLM ---
|
||||
prompt = build_draft_prompt(search_space_id=search_space_id, intent=intent)
|
||||
try:
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "automation-draft"]},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("create_automation drafting LLM call failed")
|
||||
return {"status": "error", "message": f"drafting failed: {exc}"}
|
||||
|
||||
raw_text = extract_text_content(response.content).strip()
|
||||
draft = _extract_json(raw_text)
|
||||
if draft is None:
|
||||
return {
|
||||
"status": "invalid",
|
||||
"issues": ["model output was not parseable JSON"],
|
||||
"raw": raw_text,
|
||||
}
|
||||
|
||||
# search_space_id is injected here so the sub-LLM never has to guess.
|
||||
draft["search_space_id"] = search_space_id
|
||||
try:
|
||||
validated_draft = AutomationCreate.model_validate(draft)
|
||||
except ValidationError as exc:
|
||||
return {
|
||||
"status": "invalid",
|
||||
"issues": _format_validation_issues(exc),
|
||||
"raw": draft,
|
||||
}
|
||||
|
||||
# --- 2. HITL approval card ---
|
||||
try:
|
||||
card_params = validated_draft.model_dump(mode="json", by_alias=True)
|
||||
# search_space_id is session-scoped, not user-editable.
|
||||
card_params.pop("search_space_id", None)
|
||||
|
||||
result = request_approval(
|
||||
action_type="automation_create",
|
||||
tool_name="create_automation",
|
||||
params=card_params,
|
||||
context={"search_space_id": search_space_id},
|
||||
tool_call_id=runtime.tool_call_id,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
# --- 3. Persist (re-validate in case the user edited) ---
|
||||
final_payload = {**result.params, "search_space_id": search_space_id}
|
||||
try:
|
||||
final_validated = AutomationCreate.model_validate(final_payload)
|
||||
except ValidationError as exc:
|
||||
return {
|
||||
"status": "invalid",
|
||||
"issues": _format_validation_issues(exc),
|
||||
}
|
||||
|
||||
async with async_session_maker() as session:
|
||||
user = await session.get(User, uid)
|
||||
if user is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "user not found in this session",
|
||||
}
|
||||
service = AutomationService(session=session, user=user)
|
||||
created = await service.create(final_validated)
|
||||
return {
|
||||
"status": "saved",
|
||||
"automation_id": created.id,
|
||||
"name": created.name,
|
||||
}
|
||||
|
||||
except HTTPException as exc:
|
||||
return {"status": "error", "message": exc.detail}
|
||||
except Exception as exc:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(exc, GraphInterrupt):
|
||||
raise
|
||||
logger.exception("create_automation failed")
|
||||
return {"status": "error", "message": f"persistence failed: {exc}"}
|
||||
|
||||
return create_automation
|
||||
|
||||
|
||||
def _extract_json(text: str) -> dict[str, Any] | None:
|
||||
"""Pull a JSON object out of the model response, tolerating ``` fences."""
|
||||
if not text:
|
||||
return None
|
||||
candidate = text
|
||||
fence_match = _JSON_FENCE.search(text)
|
||||
if fence_match:
|
||||
candidate = fence_match.group(1)
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return parsed if isinstance(parsed, dict) else None
|
||||
|
||||
|
||||
def _format_validation_issues(exc: ValidationError) -> list[str]:
|
||||
return [
|
||||
f"{'.'.join(str(p) for p in err['loc'])}: {err['msg']}" for err in exc.errors()
|
||||
]
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
"""System prompt for the drafting sub-LLM inside ``create_automation``.
|
||||
|
||||
Converts a natural-language ``intent`` into a structured ``AutomationCreate``
|
||||
JSON object. That object becomes the payload the HITL approval card surfaces.
|
||||
|
||||
Scope split:
|
||||
Real automation JSONs live here — this is the graph that *generates*
|
||||
the JSON. The main agent's prompt fragments (``description.md`` /
|
||||
``example.md``) only carry intent-string examples; the main agent
|
||||
never sees the schema.
|
||||
|
||||
Layout:
|
||||
The prompt is concatenated from four format-safe pieces. ``_HEADER`` /
|
||||
``_FOOTER`` carry the only ``str.format`` placeholders; ``_SCHEMA`` and
|
||||
``_FEW_SHOTS`` are plain strings so their JSON literals (and the
|
||||
``{{ inputs.X }}`` Jinja references in queries) can stay readable
|
||||
without doubled-brace escaping.
|
||||
|
||||
Catalog handling:
|
||||
v1 hard-codes the action/trigger catalog (one action, one trigger).
|
||||
When new types ship, swap the inline lines for a render-time pull
|
||||
from ``app.automations.actions`` / ``app.automations.triggers`` via
|
||||
lazy imports inside :func:`build_draft_prompt` so this module never
|
||||
participates in the ``multi_agent_chat`` import cycle.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
_HEADER = """\
|
||||
You are the SurfSense automation drafter. Convert the user intent below
|
||||
into a SINGLE JSON object matching the AutomationCreate schema. Output
|
||||
ONLY that JSON object — no prose, no markdown fence, no commentary.
|
||||
|
||||
Current UTC time (for cron context): {now}
|
||||
Target search_space_id: {search_space_id}
|
||||
"""
|
||||
|
||||
|
||||
_SCHEMA = """
|
||||
Required JSON shape:
|
||||
{
|
||||
"name": "<1-200 char identifier>",
|
||||
"description": "<one-liner or null>",
|
||||
"definition": {
|
||||
"schema_version": "1.0",
|
||||
"name": "<same as outer name>",
|
||||
"goal": "<one sentence>",
|
||||
"plan": [
|
||||
{
|
||||
"step_id": "<slug>",
|
||||
"action": "agent_task",
|
||||
"params": {
|
||||
"query": "<Jinja string referencing {{ inputs.X }}>",
|
||||
"auto_approve_all": true
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {"tags": ["..."]}
|
||||
},
|
||||
"triggers": [
|
||||
{
|
||||
"type": "schedule",
|
||||
"params": {"cron": "<5-field cron>", "timezone": "<IANA tz, default UTC>"},
|
||||
"static_inputs": {"<key>": <value>, ...},
|
||||
"enabled": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
v1 catalog (only these are valid):
|
||||
- Actions: agent_task — params: query (string, Jinja), auto_approve_all (bool).
|
||||
- Triggers: schedule — params: cron (5-field), timezone (IANA, e.g. "UTC",
|
||||
"Europe/Paris"). Has static_inputs (object).
|
||||
|
||||
Conventions:
|
||||
- Whatever the plan references via {{ inputs.X }} MUST appear either in a
|
||||
trigger's static_inputs OR in definition.inputs.schema_.properties so the
|
||||
executor can resolve it at fire time.
|
||||
- static_inputs carries values that stay the same across every fire
|
||||
(folder ids, channel names, project keys, parent page ids). Put them on
|
||||
the trigger that supplies them, not in the plan.
|
||||
- If the user did NOT supply a value the plan needs, put "REPLACE_ME" in
|
||||
static_inputs. Do NOT invent ids, channels, or paths.
|
||||
- Cron is 5-field (minute hour day-of-month month day-of-week). Use the
|
||||
timezone the user mentioned; default "UTC" when unspecified.
|
||||
- Templating variables available at fire time: inputs.* (merged
|
||||
static_inputs + runtime), inputs.fired_at, inputs.last_fired_at.
|
||||
"""
|
||||
|
||||
|
||||
_FEW_SHOTS = """
|
||||
Few-shot examples (intent → JSON output):
|
||||
|
||||
### Example 1 — schedule with all static values supplied
|
||||
intent: "Every weekday at 09:00 UTC, summarize documents added to folder_id=12 since the last run, then post the summary to Slack channel '#daily-digest'. Static inputs: folder_id=12, slack_channel='#daily-digest'."
|
||||
output:
|
||||
{
|
||||
"name": "Daily folder 12 digest",
|
||||
"description": "Weekday 09:00 UTC summary of folder 12 documents posted to #daily-digest",
|
||||
"definition": {
|
||||
"schema_version": "1.0",
|
||||
"name": "Daily folder 12 digest",
|
||||
"goal": "Summarize new docs in folder 12 since the last run and post to #daily-digest",
|
||||
"plan": [
|
||||
{
|
||||
"step_id": "summarize_and_post",
|
||||
"action": "agent_task",
|
||||
"params": {
|
||||
"query": "Summarize documents added to folder {{ inputs.folder_id }} since {{ inputs.last_fired_at or 'yesterday' }}, then send the summary to Slack channel {{ inputs.slack_channel }}.",
|
||||
"auto_approve_all": true
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {"tags": ["daily", "digest", "slack"]}
|
||||
},
|
||||
"triggers": [
|
||||
{
|
||||
"type": "schedule",
|
||||
"params": {"cron": "0 9 * * 1-5", "timezone": "UTC"},
|
||||
"static_inputs": {"folder_id": 12, "slack_channel": "#daily-digest"},
|
||||
"enabled": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
### Example 2 — schedule with a missing value (REPLACE_ME placeholder)
|
||||
intent: "Every Monday at 07:00 Europe/Paris, read last week's Jira issues in project CORE, then draft a Notion page recapping them. Static inputs: jira_project_key='CORE'. The user did NOT specify the Notion parent page id — leave it as a placeholder."
|
||||
output:
|
||||
{
|
||||
"name": "Weekly CORE Jira recap",
|
||||
"description": "Monday 07:00 Europe/Paris recap of last week's CORE Jira issues, drafted to Notion",
|
||||
"definition": {
|
||||
"schema_version": "1.0",
|
||||
"name": "Weekly CORE Jira recap",
|
||||
"goal": "Recap last week's CORE Jira issues into a Notion page",
|
||||
"plan": [
|
||||
{
|
||||
"step_id": "recap",
|
||||
"action": "agent_task",
|
||||
"params": {
|
||||
"query": "List Jira issues in project {{ inputs.jira_project_key }} updated in the 7 days before {{ inputs.fired_at }}. Draft a Notion page under parent id {{ inputs.notion_parent_page_id }} titled 'CORE recap — week of {{ inputs.fired_at }}'.",
|
||||
"auto_approve_all": true
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {"tags": ["weekly", "recap", "jira", "notion"]}
|
||||
},
|
||||
"triggers": [
|
||||
{
|
||||
"type": "schedule",
|
||||
"params": {"cron": "0 7 * * 1", "timezone": "Europe/Paris"},
|
||||
"static_inputs": {"jira_project_key": "CORE", "notion_parent_page_id": "REPLACE_ME"},
|
||||
"enabled": true
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
_FOOTER = """
|
||||
User intent:
|
||||
{intent}
|
||||
"""
|
||||
|
||||
|
||||
def build_draft_prompt(*, search_space_id: int, intent: str) -> str:
|
||||
"""Render the drafting sub-LLM system prompt for the given intent."""
|
||||
return (
|
||||
_HEADER.format(
|
||||
now=datetime.now(UTC).isoformat(timespec="seconds"),
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
+ _SCHEMA
|
||||
+ _FEW_SHOTS
|
||||
+ _FOOTER.format(intent=intent.strip())
|
||||
)
|
||||
|
|
@ -10,6 +10,7 @@ MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED: tuple[str, ...] = (
|
|||
"web_search",
|
||||
"scrape_webpage",
|
||||
"update_memory",
|
||||
"create_automation",
|
||||
)
|
||||
|
||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES: frozenset[str] = frozenset(
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ def request_approval(
|
|||
params: dict[str, Any],
|
||||
context: dict[str, Any] | None = None,
|
||||
trusted_tools: list[str] | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
) -> HITLResult:
|
||||
"""Pause the graph for user approval and return the user's decision.
|
||||
|
||||
|
|
@ -64,6 +65,10 @@ def request_approval(
|
|||
forwarded verbatim to the FE for richer card chrome.
|
||||
trusted_tools: Per-session allowlist; when ``tool_name`` is in it the
|
||||
interrupt is skipped and the tool runs immediately.
|
||||
tool_call_id: Caller's LangChain tool-call id. Required for tools
|
||||
running directly on the main agent; subagent-mounted tools omit
|
||||
it (the ``task`` chokepoint stamps it on re-raise — see
|
||||
:mod:`...checkpointed_subagent_middleware.propagation`).
|
||||
|
||||
Returns:
|
||||
:class:`HITLResult` with ``rejected=True`` if the user declined or
|
||||
|
|
@ -90,6 +95,8 @@ def request_approval(
|
|||
interrupt_type=action_type,
|
||||
context=context,
|
||||
)
|
||||
if tool_call_id:
|
||||
payload["tool_call_id"] = tool_call_id
|
||||
approval = interrupt(payload)
|
||||
|
||||
parsed = parse_lc_envelope(approval)
|
||||
|
|
|
|||
|
|
@ -150,6 +150,28 @@ class ToolDefinition:
|
|||
reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Deferred-import factories
|
||||
# =============================================================================
|
||||
# Used for tools whose impls live under ``multi_agent_chat``. Importing those
|
||||
# at module-load time would cycle (``multi_agent_chat`` middleware imports
|
||||
# this registry). The import inside the factory runs only when
|
||||
# ``build_tools`` is called, by which point ``multi_agent_chat`` is fully
|
||||
# initialised.
|
||||
|
||||
|
||||
def _build_create_automation_tool(deps: dict[str, Any]) -> BaseTool:
|
||||
from app.agents.multi_agent_chat.main_agent.tools.automation import (
|
||||
create_create_automation_tool,
|
||||
)
|
||||
|
||||
return create_create_automation_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
llm=deps["llm"],
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Built-in Tools Registry
|
||||
# =============================================================================
|
||||
|
|
@ -261,6 +283,21 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# AUTOMATION AUTHORING - single HITL tool. The tool takes an NL ``intent``
|
||||
# from the main agent, drafts the full AutomationCreate JSON via a focused
|
||||
# sub-LLM, surfaces it on an approval card, and persists on approval. The
|
||||
# factory defers its import because the impl lives under ``multi_agent_chat``
|
||||
# and that package transitively pulls this registry via middleware;
|
||||
# deferring to ``build_tools`` call-time breaks the cycle without a
|
||||
# parallel registry.
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_automation",
|
||||
description="Draft an automation from an NL intent; user approves the card; tool saves",
|
||||
factory=_build_create_automation_tool,
|
||||
requires=["search_space_id", "user_id", "llm"],
|
||||
),
|
||||
# =========================================================================
|
||||
# MEMORY TOOL - single update_memory, private or team by thread_visibility
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
|
|
|
|||
5
surfsense_backend/app/automations/__init__.py
Normal file
5
surfsense_backend/app/automations/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
"""Automations engine — see automation-design-plan.md."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__: list[str] = []
|
||||
24
surfsense_backend/app/automations/actions/__init__.py
Normal file
24
surfsense_backend/app/automations/actions/__init__.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
"""Actions domain: registry surface + built-in action packages.
|
||||
|
||||
Each action lives in its own subpackage (``agent_task/``, ...) and self-registers
|
||||
at import time via its ``definition`` module. Side-effect imports below ensure
|
||||
the registry is populated whenever anyone touches the actions package.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .store import all_actions, get_action, register_action
|
||||
from .types import ActionContext, ActionDefinition, ActionHandler, ActionHandlerFactory
|
||||
|
||||
__all__ = [
|
||||
"ActionContext",
|
||||
"ActionDefinition",
|
||||
"ActionHandler",
|
||||
"ActionHandlerFactory",
|
||||
"all_actions",
|
||||
"get_action",
|
||||
"register_action",
|
||||
]
|
||||
|
||||
# Built-in actions self-register at import time.
|
||||
from . import agent_task # noqa: E402, F401
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""``agent_task`` action: spin up multi_agent_chat for one rendered query.
|
||||
|
||||
Imports ``definition`` for its side-effect (self-registration on the actions
|
||||
registry) and re-exports ``build_handler`` for direct consumers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .factory import build_handler
|
||||
from .params import AgentTaskActionParams
|
||||
|
||||
__all__ = ["AgentTaskActionParams", "build_handler"]
|
||||
|
||||
# Side-effect: register on the actions store.
|
||||
from . import definition # noqa: E402, F401
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""Synthesize HITL decisions for every pending interrupt (approve-all or reject-all)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def build_auto_decisions(
|
||||
state: Any, decision: str
|
||||
) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]]]:
|
||||
"""Return ``(lg_resume_map, surfsense_resume_value)`` covering every pending interrupt.
|
||||
|
||||
``lg_resume_map`` is keyed by ``Interrupt.id`` for ``Command(resume=...)``;
|
||||
``surfsense_resume_value`` is keyed by ``tool_call_id`` for the subagent
|
||||
middleware bridge. Action count is read from ``value.action_requests`` when
|
||||
present and falls back to ``1`` for wrapped scalar interrupts.
|
||||
"""
|
||||
lg_resume_map: dict[str, dict[str, Any]] = {}
|
||||
routed: dict[str, dict[str, Any]] = {}
|
||||
|
||||
for interrupt_obj in getattr(state, "interrupts", ()) or ():
|
||||
value = getattr(interrupt_obj, "value", None)
|
||||
if not isinstance(value, dict):
|
||||
continue
|
||||
interrupt_id = getattr(interrupt_obj, "id", None)
|
||||
if not isinstance(interrupt_id, str):
|
||||
continue
|
||||
|
||||
action_requests = value.get("action_requests")
|
||||
count = len(action_requests) if isinstance(action_requests, list) else 1
|
||||
decisions = [{"type": decision} for _ in range(count)]
|
||||
|
||||
lg_resume_map[interrupt_id] = {"decisions": decisions}
|
||||
|
||||
tool_call_id = value.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str):
|
||||
routed[tool_call_id] = {"decisions": decisions}
|
||||
|
||||
return lg_resume_map, routed
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
"""``agent_task`` ``ActionDefinition`` registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..store import register_action
|
||||
from ..types import ActionDefinition
|
||||
from .factory import build_handler
|
||||
from .params import AgentTaskActionParams
|
||||
|
||||
AGENT_TASK_ACTION = ActionDefinition(
|
||||
type="agent_task",
|
||||
name="Agent task",
|
||||
description="Run a multi_agent_chat turn from an automation step.",
|
||||
params_model=AgentTaskActionParams,
|
||||
build_handler=build_handler,
|
||||
)
|
||||
|
||||
register_action(AGENT_TASK_ACTION)
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
"""Build the per-invocation dependencies the multi_agent_chat factory needs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
|
||||
from app.tasks.chat.streaming.flows.shared.pre_stream_setup import (
|
||||
setup_connector_and_firecrawl,
|
||||
)
|
||||
|
||||
|
||||
class DependencyError(Exception):
|
||||
"""An external dependency (LLM config, connector service, ...) refused to load."""
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AgentDependencies:
|
||||
"""Everything ``create_multi_agent_chat_deep_agent`` needs from the environment."""
|
||||
|
||||
llm: Any
|
||||
agent_config: Any
|
||||
connector_service: Any
|
||||
firecrawl_api_key: str | None
|
||||
checkpointer: Any
|
||||
|
||||
|
||||
async def build_dependencies(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> AgentDependencies:
|
||||
"""Load the LLM bundle, connector service, and a per-invoke in-memory checkpointer.
|
||||
|
||||
Uses the search space's default LLM config (``config_id=-1``). Per-step
|
||||
model overrides land in a future iteration alongside the ``model`` param.
|
||||
"""
|
||||
llm, agent_config, err = await load_llm_bundle(
|
||||
session, config_id=-1, search_space_id=search_space_id
|
||||
)
|
||||
if err is not None or llm is None:
|
||||
raise DependencyError(err or "failed to load default LLM config")
|
||||
|
||||
connector_service, firecrawl_api_key = await setup_connector_and_firecrawl(
|
||||
session, search_space_id=search_space_id
|
||||
)
|
||||
# Quick fix: use an in-memory checkpointer for automation runs.
|
||||
#
|
||||
# The shared Postgres checkpointer caches DB connections in a
|
||||
# module-level pool. Each cached connection is bound to the asyncio
|
||||
# loop that opened it. Celery throws away the loop after every task,
|
||||
# so the pool ends up full of connections pointing to a dead loop,
|
||||
# and the next Celery task (running on a fresh loop) can't use any
|
||||
# of them — it hangs 30s and fails with
|
||||
# `PoolTimeout: couldn't get a connection after 30.00 sec`.
|
||||
#
|
||||
# InMemorySaver has no cached connections, no loop binding — each
|
||||
# Celery task creates one and drops it on exit.
|
||||
#
|
||||
# TODO(checkpointer): proper fix is to dispose the checkpointer
|
||||
# pool around each Celery task in `run_async_celery_task`, the same
|
||||
# way `_dispose_shared_db_engine` already does for the SQLAlchemy
|
||||
# pool. Then this site can switch back to the shared checkpointer.
|
||||
checkpointer = InMemorySaver()
|
||||
return AgentDependencies(
|
||||
llm=llm,
|
||||
agent_config=agent_config,
|
||||
connector_service=connector_service,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""Bind ``ActionContext`` to a callable that runs one ``agent_task`` step."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ..types import ActionContext, ActionHandler
|
||||
from .invoke import run_agent_task
|
||||
from .params import AgentTaskActionParams
|
||||
|
||||
|
||||
def build_handler(ctx: ActionContext) -> ActionHandler:
|
||||
"""Return a handler closure that validates params and runs the agent task."""
|
||||
|
||||
async def handle(params: dict[str, Any]) -> dict[str, Any]:
|
||||
validated = AgentTaskActionParams.model_validate(params)
|
||||
return await run_agent_task(
|
||||
ctx=ctx,
|
||||
query=validated.query,
|
||||
auto_approve_all=validated.auto_approve_all,
|
||||
)
|
||||
|
||||
return handle
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
"""Extract the agent's final assistant text from the terminal invoke result."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
def extract_final_assistant_message(result: Any) -> str | None:
|
||||
"""Return the last ``AIMessage`` text content, or ``None`` if there isn't one.
|
||||
|
||||
Multi-part messages (content lists) are flattened by concatenating ``text``
|
||||
parts in order. Non-string content (tool calls, images) is skipped.
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
messages = result.get("messages")
|
||||
if not isinstance(messages, list):
|
||||
return None
|
||||
|
||||
for msg in reversed(messages):
|
||||
if not isinstance(msg, AIMessage):
|
||||
continue
|
||||
return _content_to_text(msg.content)
|
||||
return None
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str | None:
|
||||
if isinstance(content, str):
|
||||
text = content.strip()
|
||||
return text or None
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
parts.append(part)
|
||||
elif isinstance(part, dict) and part.get("type") == "text":
|
||||
text = part.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
joined = "".join(parts).strip()
|
||||
return joined or None
|
||||
return None
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
"""Run one ``agent_task`` invocation: ainvoke + auto-decision resume loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
|
||||
from app.db import ChatVisibility, async_session_maker
|
||||
|
||||
from ..types import ActionContext
|
||||
|
||||
from .auto_decide import build_auto_decisions
|
||||
from .dependencies import build_dependencies
|
||||
from .finalize import extract_final_assistant_message
|
||||
|
||||
# Cap on HITL resume iterations. The agent should not need this many turns in one
|
||||
# step; treat overshoot as a runaway and fail the step.
|
||||
_MAX_RESUMES = 50
|
||||
|
||||
|
||||
async def run_agent_task(
|
||||
*,
|
||||
ctx: ActionContext,
|
||||
query: str,
|
||||
auto_approve_all: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Invoke multi_agent_chat for one rendered query and return its outcome.
|
||||
|
||||
Opens its own DB session so the executor's bookkeeping session isn't tied
|
||||
up for the entire invocation. The LangGraph ``thread_id`` (a fresh UUID)
|
||||
is returned as ``agent_session_id`` for later inspection.
|
||||
"""
|
||||
agent_session_id = str(uuid.uuid4())
|
||||
user_id = str(ctx.creator_user_id) if ctx.creator_user_id else None
|
||||
decision = "approve" if auto_approve_all else "reject"
|
||||
|
||||
async with async_session_maker() as agent_session:
|
||||
deps = await build_dependencies(
|
||||
session=agent_session,
|
||||
search_space_id=ctx.search_space_id,
|
||||
)
|
||||
|
||||
agent = await create_multi_agent_chat_deep_agent(
|
||||
llm=deps.llm,
|
||||
search_space_id=ctx.search_space_id,
|
||||
db_session=agent_session,
|
||||
connector_service=deps.connector_service,
|
||||
checkpointer=deps.checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=None,
|
||||
agent_config=deps.agent_config,
|
||||
firecrawl_api_key=deps.firecrawl_api_key,
|
||||
thread_visibility=ChatVisibility.PRIVATE,
|
||||
)
|
||||
|
||||
request_id = f"automation:{ctx.run_id}:{ctx.step_id}"
|
||||
turn_id = f"{request_id}:{int(time.time() * 1000)}"
|
||||
input_state: dict[str, Any] = {
|
||||
"messages": [HumanMessage(content=query)],
|
||||
"search_space_id": ctx.search_space_id,
|
||||
"request_id": request_id,
|
||||
"turn_id": turn_id,
|
||||
}
|
||||
config: dict[str, Any] = {
|
||||
"configurable": {
|
||||
"thread_id": agent_session_id,
|
||||
"request_id": request_id,
|
||||
"turn_id": turn_id,
|
||||
},
|
||||
"recursion_limit": 10_000,
|
||||
}
|
||||
|
||||
result = await agent.ainvoke(input_state, config=config)
|
||||
|
||||
resumes = 0
|
||||
while True:
|
||||
state = await agent.aget_state(config)
|
||||
if not getattr(state, "interrupts", None):
|
||||
break
|
||||
if resumes >= _MAX_RESUMES:
|
||||
raise RuntimeError(
|
||||
f"agent_task exceeded {_MAX_RESUMES} HITL resume iterations"
|
||||
)
|
||||
lg_resume_map, routed = build_auto_decisions(state, decision)
|
||||
config["configurable"]["surfsense_resume_value"] = routed
|
||||
result = await agent.ainvoke(Command(resume=lg_resume_map), config=config)
|
||||
resumes += 1
|
||||
|
||||
return {
|
||||
"agent_session_id": agent_session_id,
|
||||
"final_message": extract_final_assistant_message(result),
|
||||
"resumes": resumes,
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""``AgentTaskActionParams`` — params for the ``agent_task`` action type."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class AgentTaskActionParams(BaseModel):
|
||||
"""Run a multi_agent_chat turn from an automation step."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
query: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="User query for the agent; rendered at execute time.",
|
||||
)
|
||||
auto_approve_all: bool = Field(
|
||||
default=False,
|
||||
description="If true, every HITL approval is auto-approved; otherwise rejected.",
|
||||
)
|
||||
23
surfsense_backend/app/automations/actions/store.py
Normal file
23
surfsense_backend/app/automations/actions/store.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
"""In-memory action registry. Populated once at process startup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .types import ActionDefinition
|
||||
|
||||
_REGISTRY: dict[str, ActionDefinition] = {}
|
||||
|
||||
|
||||
def register_action(action: ActionDefinition) -> None:
|
||||
"""Register an action. Raises on duplicate type."""
|
||||
if action.type in _REGISTRY:
|
||||
raise ValueError(f"Action already registered: {action.type!r}")
|
||||
_REGISTRY[action.type] = action
|
||||
|
||||
|
||||
def get_action(action_type: str) -> ActionDefinition | None:
|
||||
return _REGISTRY.get(action_type)
|
||||
|
||||
|
||||
def all_actions() -> dict[str, ActionDefinition]:
|
||||
"""Defensive snapshot of the registry."""
|
||||
return dict(_REGISTRY)
|
||||
40
surfsense_backend/app/automations/actions/types.py
Normal file
40
surfsense_backend/app/automations/actions/types.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""``ActionDefinition``, ``ActionContext``, and handler/factory signatures."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ActionContext:
|
||||
"""Per-invocation dependencies bound to an action handler at execute time."""
|
||||
|
||||
session: AsyncSession
|
||||
run_id: int
|
||||
step_id: str
|
||||
search_space_id: int
|
||||
creator_user_id: UUID | None
|
||||
|
||||
|
||||
ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]]
|
||||
ActionHandlerFactory = Callable[[ActionContext], ActionHandler]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ActionDefinition:
|
||||
type: str
|
||||
name: str
|
||||
description: str
|
||||
params_model: type[BaseModel]
|
||||
build_handler: ActionHandlerFactory
|
||||
|
||||
@property
|
||||
def params_schema(self) -> dict[str, Any]:
|
||||
"""JSON Schema (draft 2020-12) derived from ``params_model``."""
|
||||
return self.params_model.model_json_schema()
|
||||
16
surfsense_backend/app/automations/api/__init__.py
Normal file
16
surfsense_backend/app/automations/api/__init__.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""HTTP layer for the automations feature."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .automation import router as automation_router
|
||||
from .run import router as run_router
|
||||
from .trigger import router as trigger_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(automation_router)
|
||||
router.include_router(trigger_router)
|
||||
router.include_router(run_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
80
surfsense_backend/app/automations/api/automation.py
Normal file
80
surfsense_backend/app/automations/api/automation.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
"""HTTP routes for the ``Automation`` resource."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
|
||||
from app.automations.schemas.api import (
|
||||
AutomationCreate,
|
||||
AutomationDetail,
|
||||
AutomationList,
|
||||
AutomationSummary,
|
||||
AutomationUpdate,
|
||||
)
|
||||
from app.automations.services import AutomationService, get_automation_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/automations",
|
||||
response_model=AutomationDetail,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_automation(
|
||||
payload: AutomationCreate,
|
||||
service: AutomationService = Depends(get_automation_service),
|
||||
) -> AutomationDetail:
|
||||
"""Create an automation, optionally with initial triggers (atomic)."""
|
||||
automation = await service.create(payload)
|
||||
return AutomationDetail.model_validate(automation)
|
||||
|
||||
|
||||
@router.get("/automations", response_model=AutomationList)
|
||||
async def list_automations(
|
||||
search_space_id: int = Query(...),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
service: AutomationService = Depends(get_automation_service),
|
||||
) -> AutomationList:
|
||||
"""List automations in a search space."""
|
||||
items, total = await service.list(
|
||||
search_space_id=search_space_id, limit=limit, offset=offset
|
||||
)
|
||||
return AutomationList(
|
||||
items=[AutomationSummary.model_validate(a) for a in items],
|
||||
total=total,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/automations/{automation_id}", response_model=AutomationDetail)
|
||||
async def get_automation(
|
||||
automation_id: int,
|
||||
service: AutomationService = Depends(get_automation_service),
|
||||
) -> AutomationDetail:
|
||||
"""Get one automation with its definition and triggers."""
|
||||
automation = await service.get(automation_id)
|
||||
return AutomationDetail.model_validate(automation)
|
||||
|
||||
|
||||
@router.patch("/automations/{automation_id}", response_model=AutomationDetail)
|
||||
async def update_automation(
|
||||
automation_id: int,
|
||||
patch: AutomationUpdate,
|
||||
service: AutomationService = Depends(get_automation_service),
|
||||
) -> AutomationDetail:
|
||||
"""Partially update an automation. Triggers are managed separately."""
|
||||
automation = await service.update(automation_id, patch)
|
||||
return AutomationDetail.model_validate(automation)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/automations/{automation_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def delete_automation(
|
||||
automation_id: int,
|
||||
service: AutomationService = Depends(get_automation_service),
|
||||
) -> None:
|
||||
"""Delete an automation; triggers and runs are removed by FK cascade."""
|
||||
await service.delete(automation_id)
|
||||
44
surfsense_backend/app/automations/api/run.py
Normal file
44
surfsense_backend/app/automations/api/run.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""HTTP routes for automation run history."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from app.automations.schemas.api import RunDetail, RunList, RunSummary
|
||||
from app.automations.services import RunService, get_run_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/automations/{automation_id}/runs",
|
||||
response_model=RunList,
|
||||
)
|
||||
async def list_runs(
|
||||
automation_id: int,
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
service: RunService = Depends(get_run_service),
|
||||
) -> RunList:
|
||||
"""List run history for an automation, newest first."""
|
||||
items, total = await service.list(
|
||||
automation_id=automation_id, limit=limit, offset=offset
|
||||
)
|
||||
return RunList(
|
||||
items=[RunSummary.model_validate(r) for r in items],
|
||||
total=total,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/automations/{automation_id}/runs/{run_id}",
|
||||
response_model=RunDetail,
|
||||
)
|
||||
async def get_run(
|
||||
automation_id: int,
|
||||
run_id: int,
|
||||
service: RunService = Depends(get_run_service),
|
||||
) -> RunDetail:
|
||||
"""Get the full record of a single run, including step results and artifacts."""
|
||||
run = await service.get(automation_id=automation_id, run_id=run_id)
|
||||
return RunDetail.model_validate(run)
|
||||
55
surfsense_backend/app/automations/api/trigger.py
Normal file
55
surfsense_backend/app/automations/api/trigger.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
"""HTTP routes for triggers attached to an automation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
|
||||
from app.automations.schemas.api import TriggerCreate, TriggerDetail, TriggerUpdate
|
||||
from app.automations.services import TriggerService, get_trigger_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/automations/{automation_id}/triggers",
|
||||
response_model=TriggerDetail,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def add_trigger(
|
||||
automation_id: int,
|
||||
payload: TriggerCreate,
|
||||
service: TriggerService = Depends(get_trigger_service),
|
||||
) -> TriggerDetail:
|
||||
"""Attach a new trigger to an automation."""
|
||||
trigger = await service.add(automation_id=automation_id, payload=payload)
|
||||
return TriggerDetail.model_validate(trigger)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/automations/{automation_id}/triggers/{trigger_id}",
|
||||
response_model=TriggerDetail,
|
||||
)
|
||||
async def update_trigger(
|
||||
automation_id: int,
|
||||
trigger_id: int,
|
||||
patch: TriggerUpdate,
|
||||
service: TriggerService = Depends(get_trigger_service),
|
||||
) -> TriggerDetail:
|
||||
"""Toggle ``enabled`` or replace ``params``. Trigger type is immutable."""
|
||||
trigger = await service.update(
|
||||
automation_id=automation_id, trigger_id=trigger_id, patch=patch
|
||||
)
|
||||
return TriggerDetail.model_validate(trigger)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/automations/{automation_id}/triggers/{trigger_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def remove_trigger(
|
||||
automation_id: int,
|
||||
trigger_id: int,
|
||||
service: TriggerService = Depends(get_trigger_service),
|
||||
) -> None:
|
||||
"""Detach a trigger from an automation."""
|
||||
await service.remove(automation_id=automation_id, trigger_id=trigger_id)
|
||||
8
surfsense_backend/app/automations/dispatch/__init__.py
Normal file
8
surfsense_backend/app/automations/dispatch/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
"""Generic dispatch primitives shared across trigger types."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .errors import DispatchError
|
||||
from .run import dispatch_run
|
||||
|
||||
__all__ = ["DispatchError", "dispatch_run"]
|
||||
7
surfsense_backend/app/automations/dispatch/errors.py
Normal file
7
surfsense_backend/app/automations/dispatch/errors.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""Dispatch errors raised when a fire request cannot be turned into a run."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class DispatchError(Exception):
|
||||
"""A dispatch could not proceed (missing trigger, invalid inputs, ...)."""
|
||||
83
surfsense_backend/app/automations/dispatch/run.py
Normal file
83
surfsense_backend/app/automations/dispatch/run.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
"""Generic run dispatch: validate, snapshot, persist, enqueue. Shared by every trigger."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import jsonschema
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.persistence.enums.run_status import RunStatus
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.schemas.definition.envelope import AutomationDefinition
|
||||
from app.automations.tasks.execute_run import automation_run_execute
|
||||
|
||||
from .errors import DispatchError
|
||||
|
||||
|
||||
async def dispatch_run(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
automation: Automation,
|
||||
trigger: AutomationTrigger,
|
||||
runtime_inputs: dict[str, Any] | None = None,
|
||||
) -> AutomationRun:
|
||||
"""Validate, snapshot the definition, persist an ``AutomationRun``, enqueue execution.
|
||||
|
||||
Final inputs = ``trigger.static_inputs`` merged with ``runtime_inputs``,
|
||||
static winning on key collision. The merged dict is validated against
|
||||
``automation.definition.inputs.schema_`` and stored on the run.
|
||||
|
||||
Callers (trigger-specific adapters) are responsible for resolving
|
||||
``automation`` and ``trigger`` and for the trigger-side ``ACTIVE`` /
|
||||
``enabled`` guards. This function only handles what's identical across
|
||||
every trigger type.
|
||||
"""
|
||||
try:
|
||||
definition = AutomationDefinition.model_validate(automation.definition)
|
||||
except Exception as exc:
|
||||
raise DispatchError(f"invalid automation definition: {exc}") from exc
|
||||
|
||||
merged_inputs = {**(runtime_inputs or {}), **(trigger.static_inputs or {})}
|
||||
validated_inputs = _validate_inputs(definition, merged_inputs)
|
||||
snapshot = definition.model_dump(mode="json", by_alias=True)
|
||||
|
||||
run = AutomationRun(
|
||||
automation_id=automation.id,
|
||||
trigger_id=trigger.id,
|
||||
status=RunStatus.PENDING,
|
||||
definition_snapshot=snapshot,
|
||||
inputs=validated_inputs,
|
||||
step_results=[],
|
||||
artifacts=[],
|
||||
)
|
||||
session.add(run)
|
||||
await session.commit()
|
||||
await session.refresh(run)
|
||||
|
||||
automation_run_execute.apply_async(
|
||||
args=[run.id],
|
||||
time_limit=definition.execution.timeout_seconds,
|
||||
)
|
||||
return run
|
||||
|
||||
|
||||
def _validate_inputs(
|
||||
definition: AutomationDefinition, inputs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Validate merged inputs against the optional declared schema.
|
||||
|
||||
No declared schema → pass through (runtime inputs like ``fired_at`` /
|
||||
``last_fired_at`` and trigger ``static_inputs`` must still reach the
|
||||
template context). Returning ``{}`` here strips them and makes Jinja
|
||||
blow up on any ``{{ inputs.* }}`` reference.
|
||||
"""
|
||||
if definition.inputs is None or not definition.inputs.schema_:
|
||||
return inputs
|
||||
try:
|
||||
jsonschema.validate(instance=inputs, schema=definition.inputs.schema_)
|
||||
except jsonschema.ValidationError as exc:
|
||||
raise DispatchError(f"inputs: {exc.message}") from exc
|
||||
return inputs
|
||||
15
surfsense_backend/app/automations/persistence/__init__.py
Normal file
15
surfsense_backend/app/automations/persistence/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""Models and enums for the automation tables."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .enums import AutomationStatus, RunStatus, TriggerType
|
||||
from .models import Automation, AutomationRun, AutomationTrigger
|
||||
|
||||
__all__ = [
|
||||
"Automation",
|
||||
"AutomationRun",
|
||||
"AutomationStatus",
|
||||
"AutomationTrigger",
|
||||
"RunStatus",
|
||||
"TriggerType",
|
||||
]
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
"""Enums for the automation tables."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .automation_status import AutomationStatus
|
||||
from .run_status import RunStatus
|
||||
from .trigger_type import TriggerType
|
||||
|
||||
__all__ = [
|
||||
"AutomationStatus",
|
||||
"RunStatus",
|
||||
"TriggerType",
|
||||
]
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
"""Automation lifecycle status."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class AutomationStatus(StrEnum):
|
||||
ACTIVE = "active" # eligible to fire
|
||||
PAUSED = "paused" # kept, but triggers don't fire
|
||||
ARCHIVED = "archived" # read-only history
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
"""AutomationRun state machine: pending → running → (succeeded|failed|cancelled|timed_out)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class RunStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
TIMED_OUT = "timed_out"
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""Trigger-kind discriminator.
|
||||
|
||||
v1 only registers ``schedule``. ``manual`` is reserved in the enum (mirrors the
|
||||
postgres enum) but is intentionally unregistered pending a redesign of the
|
||||
"Run now" UX.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class TriggerType(StrEnum):
|
||||
SCHEDULE = "schedule"
|
||||
MANUAL = "manual"
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
"""Models, one per table."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .automation import Automation
|
||||
from .run import AutomationRun
|
||||
from .trigger import AutomationTrigger
|
||||
|
||||
__all__ = [
|
||||
"Automation",
|
||||
"AutomationRun",
|
||||
"AutomationTrigger",
|
||||
]
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
"""``automations`` table — editable, versioned automation definition."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
TIMESTAMP,
|
||||
Column,
|
||||
Enum as SQLAlchemyEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db import BaseModel, TimestampMixin
|
||||
|
||||
from ..enums.automation_status import AutomationStatus
|
||||
|
||||
|
||||
class Automation(BaseModel, TimestampMixin):
|
||||
__tablename__ = "automations"
|
||||
|
||||
search_space_id = Column(
|
||||
Integer,
|
||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
created_by_user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
name = Column(String(200), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
status = Column(
|
||||
SQLAlchemyEnum(
|
||||
AutomationStatus,
|
||||
name="automation_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
default=AutomationStatus.ACTIVE,
|
||||
server_default=AutomationStatus.ACTIVE.value,
|
||||
index=True,
|
||||
)
|
||||
|
||||
definition = Column(JSONB, nullable=False)
|
||||
|
||||
version = Column(Integer, nullable=False, default=1, server_default="1")
|
||||
|
||||
updated_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
|
||||
search_space = relationship("SearchSpace", back_populates="automations")
|
||||
created_by = relationship("User", back_populates="automations")
|
||||
triggers = relationship(
|
||||
"AutomationTrigger",
|
||||
back_populates="automation",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
runs = relationship(
|
||||
"AutomationRun",
|
||||
back_populates="automation",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
66
surfsense_backend/app/automations/persistence/models/run.py
Normal file
66
surfsense_backend/app/automations/persistence/models/run.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""``automation_runs`` table — immutable per-fire execution record."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import (
|
||||
TIMESTAMP,
|
||||
Column,
|
||||
Enum as SQLAlchemyEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db import BaseModel, TimestampMixin
|
||||
|
||||
from ..enums.run_status import RunStatus
|
||||
|
||||
|
||||
class AutomationRun(BaseModel, TimestampMixin):
|
||||
__tablename__ = "automation_runs"
|
||||
|
||||
automation_id = Column(
|
||||
Integer,
|
||||
ForeignKey("automations.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
trigger_id = Column(
|
||||
Integer,
|
||||
ForeignKey("automation_triggers.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
status = Column(
|
||||
SQLAlchemyEnum(
|
||||
RunStatus,
|
||||
name="automation_run_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
default=RunStatus.PENDING,
|
||||
server_default=RunStatus.PENDING.value,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# locked at fire time so historical runs always show the exact code path
|
||||
definition_snapshot = Column(JSONB, nullable=False)
|
||||
|
||||
# merged & validated inputs the run was dispatched with
|
||||
# (trigger.static_inputs ∪ producer runtime data, static wins on collision)
|
||||
inputs = Column(JSONB, nullable=False, server_default="{}")
|
||||
# one entry per executed step; agent_task entries carry their own
|
||||
# `agent_session_id` inside their entry
|
||||
step_results = Column(JSONB, nullable=False, server_default="[]")
|
||||
output = Column(JSONB, nullable=True)
|
||||
artifacts = Column(JSONB, nullable=False, server_default="[]")
|
||||
error = Column(JSONB, nullable=True)
|
||||
|
||||
started_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
finished_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
automation = relationship("Automation", back_populates="runs")
|
||||
trigger = relationship("AutomationTrigger", back_populates="runs")
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""``automation_triggers`` table — one row per (automation, trigger-instance) pair."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import (
|
||||
TIMESTAMP,
|
||||
Boolean,
|
||||
Column,
|
||||
Enum as SQLAlchemyEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db import BaseModel, TimestampMixin
|
||||
|
||||
from ..enums.trigger_type import TriggerType
|
||||
|
||||
|
||||
class AutomationTrigger(BaseModel, TimestampMixin):
|
||||
__tablename__ = "automation_triggers"
|
||||
|
||||
automation_id = Column(
|
||||
Integer,
|
||||
ForeignKey("automations.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
type = Column(
|
||||
SQLAlchemyEnum(
|
||||
TriggerType,
|
||||
name="automation_trigger_type",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
params = Column(JSONB, nullable=False)
|
||||
|
||||
# Per-attachment domain values merged into every dispatched run's inputs.
|
||||
# Static wins over runtime data on key collision.
|
||||
static_inputs = Column(JSONB, nullable=False, server_default="{}")
|
||||
|
||||
enabled = Column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
server_default="true",
|
||||
index=True,
|
||||
)
|
||||
|
||||
last_fired_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
# Precomputed next fire moment in UTC; advanced after each fire by the
|
||||
# schedule tick. NULL means the trigger has never been scheduled (the
|
||||
# tick self-heals on first sight).
|
||||
next_fire_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
automation = relationship("Automation", back_populates="triggers")
|
||||
runs = relationship(
|
||||
"AutomationRun",
|
||||
back_populates="trigger",
|
||||
passive_deletes=True,
|
||||
)
|
||||
7
surfsense_backend/app/automations/runtime/__init__.py
Normal file
7
surfsense_backend/app/automations/runtime/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""Automation run executor: plan walker, step dispatch, retries, persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .executor import execute_run
|
||||
|
||||
__all__ = ["execute_run"]
|
||||
124
surfsense_backend/app/automations/runtime/executor.py
Normal file
124
surfsense_backend/app/automations/runtime/executor.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""Walk an ``AutomationRun``'s snapshot plan to terminal state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.persistence.enums.run_status import RunStatus
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.automations.actions.types import ActionContext
|
||||
from app.automations.schemas.definition.envelope import AutomationDefinition
|
||||
from app.automations.schemas.definition.plan_step import PlanStep
|
||||
from app.automations.templating import build_run_context
|
||||
|
||||
from . import repository
|
||||
from .step import execute_step
|
||||
|
||||
|
||||
async def execute_run(session: AsyncSession, run_id: int) -> None:
|
||||
"""Load run ``run_id`` and execute its snapshot plan to a terminal state."""
|
||||
run = await repository.load_run(session, run_id)
|
||||
if run is None:
|
||||
raise ValueError(f"automation_run {run_id} not found")
|
||||
|
||||
if run.status != RunStatus.PENDING:
|
||||
return
|
||||
|
||||
try:
|
||||
definition = AutomationDefinition.model_validate(run.definition_snapshot)
|
||||
except Exception as exc:
|
||||
await repository.mark_failed(
|
||||
session,
|
||||
run,
|
||||
{"message": f"definition_snapshot invalid: {exc}", "type": type(exc).__name__},
|
||||
)
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
await repository.mark_running(session, run)
|
||||
await session.commit()
|
||||
|
||||
step_outputs: dict[str, Any] = {}
|
||||
|
||||
for step in definition.plan:
|
||||
template_ctx = _build_template_ctx(run, step_outputs)
|
||||
action_ctx = _build_action_ctx(session, run, step)
|
||||
result = await execute_step(
|
||||
step=step,
|
||||
template_context=template_ctx,
|
||||
action_context=action_ctx,
|
||||
default_max_retries=definition.execution.max_retries,
|
||||
default_retry_backoff=definition.execution.retry_backoff,
|
||||
default_timeout_seconds=definition.execution.timeout_seconds,
|
||||
)
|
||||
await repository.append_step_result(session, run, result)
|
||||
await session.commit()
|
||||
|
||||
if result["status"] == "failed":
|
||||
await _run_on_failure(session, run, definition)
|
||||
await repository.mark_failed(session, run, result.get("error"))
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
if result["status"] == "succeeded":
|
||||
step_outputs[step.output_as or step.step_id] = result.get("result")
|
||||
|
||||
await repository.mark_succeeded(session, run)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def _run_on_failure(
|
||||
session: AsyncSession,
|
||||
run: AutomationRun,
|
||||
definition: AutomationDefinition,
|
||||
) -> None:
|
||||
"""Run the on_failure steps. Their failures don't recurse into more on_failure."""
|
||||
if not definition.execution.on_failure:
|
||||
return
|
||||
template_ctx = _build_template_ctx(run, step_outputs={})
|
||||
for step in definition.execution.on_failure:
|
||||
action_ctx = _build_action_ctx(session, run, step)
|
||||
result = await execute_step(
|
||||
step=step,
|
||||
template_context=template_ctx,
|
||||
action_context=action_ctx,
|
||||
default_max_retries=definition.execution.max_retries,
|
||||
default_retry_backoff=definition.execution.retry_backoff,
|
||||
default_timeout_seconds=definition.execution.timeout_seconds,
|
||||
)
|
||||
await repository.append_step_result(session, run, result)
|
||||
await session.commit()
|
||||
|
||||
|
||||
def _build_template_ctx(run: AutomationRun, step_outputs: dict[str, Any]) -> dict[str, Any]:
|
||||
automation = run.automation
|
||||
trigger = run.trigger
|
||||
return build_run_context(
|
||||
run_id=run.id,
|
||||
automation_id=run.automation_id,
|
||||
automation_name=automation.name if automation else None,
|
||||
automation_version=automation.version if automation else None,
|
||||
search_space_id=automation.search_space_id if automation else None,
|
||||
creator_id=automation.created_by_user_id if automation else None,
|
||||
trigger_id=run.trigger_id,
|
||||
trigger_type=trigger.type.value if trigger else None,
|
||||
started_at=run.started_at,
|
||||
attempt=1,
|
||||
inputs=run.inputs or {},
|
||||
step_outputs=step_outputs,
|
||||
)
|
||||
|
||||
|
||||
def _build_action_ctx(
|
||||
session: AsyncSession, run: AutomationRun, step: PlanStep
|
||||
) -> ActionContext:
|
||||
automation = run.automation
|
||||
return ActionContext(
|
||||
session=session,
|
||||
run_id=run.id,
|
||||
step_id=step.step_id,
|
||||
search_space_id=automation.search_space_id,
|
||||
creator_user_id=automation.created_by_user_id,
|
||||
)
|
||||
62
surfsense_backend/app/automations/runtime/repository.py
Normal file
62
surfsense_backend/app/automations/runtime/repository.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""Persistence operations on ``AutomationRun``. Pure SQL, no business logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.automations.persistence.enums.run_status import RunStatus
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
|
||||
|
||||
async def load_run(session: AsyncSession, run_id: int) -> AutomationRun | None:
|
||||
"""Load a run with its automation and trigger eagerly loaded."""
|
||||
stmt = (
|
||||
select(AutomationRun)
|
||||
.where(AutomationRun.id == run_id)
|
||||
.options(
|
||||
selectinload(AutomationRun.automation),
|
||||
selectinload(AutomationRun.trigger),
|
||||
)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def mark_running(session: AsyncSession, run: AutomationRun) -> None:
|
||||
run.status = RunStatus.RUNNING
|
||||
run.started_at = datetime.now(UTC)
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def mark_succeeded(session: AsyncSession, run: AutomationRun) -> None:
|
||||
run.status = RunStatus.SUCCEEDED
|
||||
run.finished_at = datetime.now(UTC)
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def mark_failed(
|
||||
session: AsyncSession,
|
||||
run: AutomationRun,
|
||||
error: dict[str, Any] | None,
|
||||
) -> None:
|
||||
run.status = RunStatus.FAILED
|
||||
run.finished_at = datetime.now(UTC)
|
||||
run.error = error
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def append_step_result(
|
||||
session: AsyncSession,
|
||||
run: AutomationRun,
|
||||
step_result: dict[str, Any],
|
||||
) -> None:
|
||||
"""Append one step result. Reassigns the list so SQLAlchemy detects the change."""
|
||||
current = list(run.step_results or [])
|
||||
current.append(step_result)
|
||||
run.step_results = current
|
||||
await session.flush()
|
||||
36
surfsense_backend/app/automations/runtime/retries.py
Normal file
36
surfsense_backend/app/automations/runtime/retries.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""Retry policy enforcement for action handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
|
||||
async def with_retries[T](
|
||||
coro_factory: Callable[[], Awaitable[T]],
|
||||
*,
|
||||
max_retries: int,
|
||||
backoff: str,
|
||||
timeout: int | None,
|
||||
) -> tuple[T, int]:
|
||||
"""Call ``coro_factory`` up to ``1 + max_retries`` times. Return ``(result, attempts)``."""
|
||||
total = 1 + max(0, max_retries)
|
||||
for attempt in range(1, total + 1):
|
||||
try:
|
||||
coro = coro_factory()
|
||||
if timeout is not None and timeout > 0:
|
||||
return await asyncio.wait_for(coro, timeout=timeout), attempt
|
||||
return await coro, attempt
|
||||
except Exception:
|
||||
if attempt >= total:
|
||||
raise
|
||||
await asyncio.sleep(_backoff_seconds(backoff, attempt))
|
||||
raise RuntimeError("with_retries exhausted without raising or returning")
|
||||
|
||||
|
||||
def _backoff_seconds(strategy: str, attempt: int) -> float:
|
||||
if strategy == "exponential":
|
||||
return float(2 ** (attempt - 1))
|
||||
if strategy == "linear":
|
||||
return float(attempt)
|
||||
return 0.0
|
||||
96
surfsense_backend/app/automations/runtime/step.py
Normal file
96
surfsense_backend/app/automations/runtime/step.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
"""Execute one plan step: when-predicate, params render, handler dispatch, retries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from app.automations.actions import get_action
|
||||
from app.automations.actions.types import ActionContext
|
||||
from app.automations.schemas.definition.plan_step import PlanStep
|
||||
from app.automations.templating import evaluate_predicate, render_value
|
||||
|
||||
from .retries import with_retries
|
||||
|
||||
|
||||
async def execute_step(
|
||||
*,
|
||||
step: PlanStep,
|
||||
template_context: Mapping[str, Any],
|
||||
action_context: ActionContext,
|
||||
default_max_retries: int,
|
||||
default_retry_backoff: str,
|
||||
default_timeout_seconds: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Run one step and return its structured result entry."""
|
||||
started_at = datetime.now(UTC)
|
||||
|
||||
if step.when is not None:
|
||||
try:
|
||||
should_run = evaluate_predicate(step.when, template_context)
|
||||
except Exception as exc:
|
||||
return _result(step, "failed", started_at, attempts=0, error=_error(exc, "when"))
|
||||
if not should_run:
|
||||
return _result(step, "skipped", started_at, attempts=0)
|
||||
|
||||
try:
|
||||
resolved_params = render_value(step.params, template_context)
|
||||
except Exception as exc:
|
||||
return _result(step, "failed", started_at, attempts=0, error=_error(exc, "render"))
|
||||
|
||||
action = get_action(step.action)
|
||||
if action is None:
|
||||
return _result(
|
||||
step,
|
||||
"failed",
|
||||
started_at,
|
||||
attempts=0,
|
||||
error={"message": f"action not registered: {step.action}", "type": "ActionNotFound"},
|
||||
)
|
||||
|
||||
handler = action.build_handler(action_context)
|
||||
|
||||
max_retries = step.max_retries if step.max_retries is not None else default_max_retries
|
||||
timeout = step.timeout_seconds or default_timeout_seconds
|
||||
|
||||
try:
|
||||
result, attempts = await with_retries(
|
||||
lambda: handler(resolved_params),
|
||||
max_retries=max_retries,
|
||||
backoff=default_retry_backoff,
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as exc:
|
||||
return _result(step, "failed", started_at, attempts=max_retries + 1, error=_error(exc))
|
||||
|
||||
return _result(step, "succeeded", started_at, attempts=attempts, result=result)
|
||||
|
||||
|
||||
def _result(
|
||||
step: PlanStep,
|
||||
status: str,
|
||||
started_at: datetime,
|
||||
*,
|
||||
attempts: int,
|
||||
result: Any = None,
|
||||
error: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
entry: dict[str, Any] = {
|
||||
"step_id": step.step_id,
|
||||
"action": step.action,
|
||||
"status": status,
|
||||
"started_at": started_at.isoformat(),
|
||||
"finished_at": datetime.now(UTC).isoformat(),
|
||||
"attempts": attempts,
|
||||
}
|
||||
if result is not None:
|
||||
entry["result"] = result
|
||||
if error is not None:
|
||||
entry["error"] = error
|
||||
return entry
|
||||
|
||||
|
||||
def _error(exc: Exception, phase: str | None = None) -> dict[str, Any]:
|
||||
msg = f"{phase}: {exc}" if phase else str(exc)
|
||||
return {"message": msg, "type": type(exc).__name__}
|
||||
27
surfsense_backend/app/automations/schemas/__init__.py
Normal file
27
surfsense_backend/app/automations/schemas/__init__.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Schemas for the automation definition envelope.
|
||||
|
||||
Per-action and per-trigger params schemas live with the action/trigger
|
||||
implementations (``app.automations.actions.<name>.params`` /
|
||||
``app.automations.triggers.<name>.params``); only the cross-cutting envelope
|
||||
lives here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .definition import (
|
||||
AutomationDefinition,
|
||||
Execution,
|
||||
Inputs,
|
||||
Metadata,
|
||||
PlanStep,
|
||||
TriggerSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AutomationDefinition",
|
||||
"Execution",
|
||||
"Inputs",
|
||||
"Metadata",
|
||||
"PlanStep",
|
||||
"TriggerSpec",
|
||||
]
|
||||
27
surfsense_backend/app/automations/schemas/api/__init__.py
Normal file
27
surfsense_backend/app/automations/schemas/api/__init__.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Request/response schemas for the automations HTTP layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .automation import (
|
||||
AutomationCreate,
|
||||
AutomationDetail,
|
||||
AutomationList,
|
||||
AutomationSummary,
|
||||
AutomationUpdate,
|
||||
)
|
||||
from .run import RunDetail, RunList, RunSummary
|
||||
from .trigger import TriggerCreate, TriggerDetail, TriggerUpdate
|
||||
|
||||
__all__ = [
|
||||
"AutomationCreate",
|
||||
"AutomationDetail",
|
||||
"AutomationList",
|
||||
"AutomationSummary",
|
||||
"AutomationUpdate",
|
||||
"RunDetail",
|
||||
"RunList",
|
||||
"RunSummary",
|
||||
"TriggerCreate",
|
||||
"TriggerDetail",
|
||||
"TriggerUpdate",
|
||||
]
|
||||
64
surfsense_backend/app/automations/schemas/api/automation.py
Normal file
64
surfsense_backend/app/automations/schemas/api/automation.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
"""Request/response schemas for the ``Automation`` resource."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.automations.persistence.enums.automation_status import AutomationStatus
|
||||
from app.automations.schemas.definition import AutomationDefinition
|
||||
|
||||
from .trigger import TriggerCreate, TriggerDetail
|
||||
|
||||
|
||||
class AutomationCreate(BaseModel):
|
||||
"""Create an automation, optionally with initial triggers (atomic)."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
search_space_id: int
|
||||
name: str = Field(..., min_length=1, max_length=200)
|
||||
description: str | None = None
|
||||
definition: AutomationDefinition
|
||||
triggers: list[TriggerCreate] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AutomationUpdate(BaseModel):
|
||||
"""Partial update of an automation. Triggers are managed separately."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=200)
|
||||
description: str | None = None
|
||||
status: AutomationStatus | None = None
|
||||
definition: AutomationDefinition | None = None
|
||||
|
||||
|
||||
class AutomationSummary(BaseModel):
|
||||
"""Lightweight automation view for list endpoints."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
search_space_id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
status: AutomationStatus
|
||||
version: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class AutomationDetail(AutomationSummary):
|
||||
"""Full automation view including definition and attached triggers."""
|
||||
|
||||
definition: AutomationDefinition
|
||||
triggers: list[TriggerDetail] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AutomationList(BaseModel):
|
||||
"""Paginated list of automations."""
|
||||
|
||||
items: list[AutomationSummary]
|
||||
total: int
|
||||
42
surfsense_backend/app/automations/schemas/api/run.py
Normal file
42
surfsense_backend/app/automations/schemas/api/run.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
"""Response schemas for run sub-resources."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from app.automations.persistence.enums.run_status import RunStatus
|
||||
|
||||
|
||||
class RunSummary(BaseModel):
|
||||
"""Lightweight run view for list endpoints."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
automation_id: int
|
||||
trigger_id: int | None = None
|
||||
status: RunStatus
|
||||
started_at: datetime | None = None
|
||||
finished_at: datetime | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class RunDetail(RunSummary):
|
||||
"""Full run view including snapshot, results and artifacts."""
|
||||
|
||||
definition_snapshot: dict[str, Any]
|
||||
inputs: dict[str, Any]
|
||||
step_results: list[dict[str, Any]]
|
||||
output: dict[str, Any] | None = None
|
||||
artifacts: list[dict[str, Any]]
|
||||
error: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class RunList(BaseModel):
|
||||
"""Paginated list of runs."""
|
||||
|
||||
items: list[RunSummary]
|
||||
total: int
|
||||
46
surfsense_backend/app/automations/schemas/api/trigger.py
Normal file
46
surfsense_backend/app/automations/schemas/api/trigger.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
"""Request/response schemas for trigger sub-resources."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
|
||||
|
||||
class TriggerCreate(BaseModel):
|
||||
"""Attach a trigger to an automation."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: TriggerType
|
||||
params: dict[str, Any] = Field(default_factory=dict)
|
||||
static_inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class TriggerUpdate(BaseModel):
|
||||
"""Partial update of an existing trigger."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
enabled: bool | None = None
|
||||
params: dict[str, Any] | None = None
|
||||
static_inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TriggerDetail(BaseModel):
|
||||
"""Trigger as returned to clients."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
type: TriggerType
|
||||
params: dict[str, Any]
|
||||
static_inputs: dict[str, Any]
|
||||
enabled: bool
|
||||
last_fired_at: datetime | None = None
|
||||
next_fire_at: datetime | None = None
|
||||
created_at: datetime
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""Automation definition envelope and its components."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .envelope import AutomationDefinition
|
||||
from .execution import Execution
|
||||
from .inputs import Inputs
|
||||
from .metadata import Metadata
|
||||
from .plan_step import PlanStep
|
||||
from .trigger_spec import TriggerSpec
|
||||
|
||||
__all__ = [
|
||||
"AutomationDefinition",
|
||||
"Execution",
|
||||
"Inputs",
|
||||
"Metadata",
|
||||
"PlanStep",
|
||||
"TriggerSpec",
|
||||
]
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
"""``AutomationDefinition`` — top-level envelope persisted in ``automations.definition``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from .execution import Execution
|
||||
from .inputs import Inputs
|
||||
from .metadata import Metadata
|
||||
from .plan_step import PlanStep
|
||||
from .trigger_spec import TriggerSpec
|
||||
|
||||
|
||||
class AutomationDefinition(BaseModel):
|
||||
"""Top-level shape of an automation."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
schema_version: str = "1.0"
|
||||
name: str = Field(..., min_length=1, max_length=200)
|
||||
goal: str | None = None
|
||||
inputs: Inputs | None = None
|
||||
triggers: list[TriggerSpec] = Field(default_factory=list)
|
||||
plan: list[PlanStep] = Field(..., min_length=1)
|
||||
execution: Execution = Field(default_factory=Execution)
|
||||
metadata: Metadata = Field(default_factory=Metadata)
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
"""``Execution`` — automation-wide execution defaults (overridable per step)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from .plan_step import PlanStep
|
||||
|
||||
|
||||
class Execution(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
timeout_seconds: int = Field(default=600, gt=0, description="Wall-clock cap for the run.")
|
||||
max_retries: int = Field(default=2, ge=0, description="Per-step retry budget.")
|
||||
retry_backoff: Literal["exponential", "linear", "none"] = "exponential"
|
||||
concurrency: Literal["drop_if_running", "queue", "always"] = "drop_if_running"
|
||||
on_failure: list[PlanStep] = Field(
|
||||
default_factory=list,
|
||||
description="Steps run when the main plan fails after retries.",
|
||||
)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""``Inputs`` — JSON Schema for inputs an automation accepts at fire time."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class Inputs(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
)
|
||||
|
||||
schema_: dict[str, Any] = Field(
|
||||
...,
|
||||
alias="schema",
|
||||
description="JSON Schema (draft 2020-12) for accepted inputs.",
|
||||
)
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
"""``Metadata`` — free-form metadata on a definition. Extra keys allowed."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class Metadata(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
"""``PlanStep`` — one step in the sequential plan."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class PlanStep(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
step_id: str = Field(..., min_length=1, description="Unique within the plan.")
|
||||
action: str = Field(..., min_length=1, description="Action type; resolved via registry.")
|
||||
when: str | None = Field(
|
||||
default=None,
|
||||
description="Optional predicate; step is skipped when falsy.",
|
||||
)
|
||||
params: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Action-type-specific params; rendered at execute time.",
|
||||
)
|
||||
output_as: str | None = Field(
|
||||
default=None,
|
||||
description="Bind step output under this name. Defaults to step_id.",
|
||||
)
|
||||
max_retries: int | None = Field(default=None, ge=0)
|
||||
timeout_seconds: int | None = Field(default=None, gt=0)
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
"""``TriggerSpec`` — one entry in the definition's ``triggers[]`` array."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class TriggerSpec(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: str = Field(..., min_length=1, description="Trigger type; resolved via registry.")
|
||||
params: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Type-specific params; validated against the trigger's schema.",
|
||||
)
|
||||
16
surfsense_backend/app/automations/services/__init__.py
Normal file
16
surfsense_backend/app/automations/services/__init__.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""Services for the automations HTTP layer (one service per resource)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .automation import AutomationService, get_automation_service
|
||||
from .run import RunService, get_run_service
|
||||
from .trigger import TriggerService, get_trigger_service
|
||||
|
||||
__all__ = [
|
||||
"AutomationService",
|
||||
"RunService",
|
||||
"TriggerService",
|
||||
"get_automation_service",
|
||||
"get_run_service",
|
||||
"get_trigger_service",
|
||||
]
|
||||
172
surfsense_backend/app/automations/services/automation.py
Normal file
172
surfsense_backend/app/automations/services/automation.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
"""``AutomationService`` — orchestration for the ``Automation`` resource."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.automations.schemas.api import (
|
||||
AutomationCreate,
|
||||
AutomationUpdate,
|
||||
TriggerCreate,
|
||||
)
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.schedule import compute_next_fire_at
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
||||
class AutomationService:
|
||||
"""Lifecycle of the ``Automation`` resource."""
|
||||
|
||||
def __init__(self, *, session: AsyncSession, user: User) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
|
||||
async def create(self, payload: AutomationCreate) -> Automation:
|
||||
"""Create an automation and its initial triggers in one transaction."""
|
||||
await self._authorize(payload.search_space_id, Permission.AUTOMATIONS_CREATE.value)
|
||||
|
||||
automation = Automation(
|
||||
search_space_id=payload.search_space_id,
|
||||
created_by_user_id=self.user.id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
definition=payload.definition.model_dump(mode="json", by_alias=True),
|
||||
version=1,
|
||||
)
|
||||
for spec in payload.triggers:
|
||||
automation.triggers.append(_build_trigger(spec))
|
||||
|
||||
self.session.add(automation)
|
||||
await self.session.commit()
|
||||
return await self._get_with_triggers_or_raise(automation.id)
|
||||
|
||||
async def list(
|
||||
self,
|
||||
*,
|
||||
search_space_id: int,
|
||||
limit: int,
|
||||
offset: int,
|
||||
) -> tuple[list[Automation], int]:
|
||||
"""Return a page of automations and the total count."""
|
||||
await self._authorize(search_space_id, Permission.AUTOMATIONS_READ.value)
|
||||
|
||||
base = select(Automation).where(Automation.search_space_id == search_space_id)
|
||||
total = await self.session.scalar(
|
||||
select(func.count()).select_from(base.subquery())
|
||||
)
|
||||
|
||||
rows = (
|
||||
await self.session.execute(
|
||||
base.order_by(Automation.created_at.desc()).limit(limit).offset(offset)
|
||||
)
|
||||
).scalars().all()
|
||||
return list(rows), int(total or 0)
|
||||
|
||||
async def get(self, automation_id: int) -> Automation:
|
||||
"""Get an automation with its triggers loaded."""
|
||||
automation = await self._get_with_triggers_or_raise(automation_id)
|
||||
await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_READ.value)
|
||||
return automation
|
||||
|
||||
async def update(self, automation_id: int, patch: AutomationUpdate) -> Automation:
|
||||
"""Patch fields. Bumps ``version`` when ``definition`` changes."""
|
||||
automation = await self._get_with_triggers_or_raise(automation_id)
|
||||
await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_UPDATE.value)
|
||||
|
||||
data = patch.model_dump(exclude_unset=True)
|
||||
|
||||
if "name" in data:
|
||||
automation.name = data["name"]
|
||||
if "description" in data:
|
||||
automation.description = data["description"]
|
||||
if "status" in data:
|
||||
automation.status = data["status"]
|
||||
if "definition" in data:
|
||||
automation.definition = patch.definition.model_dump(mode="json", by_alias=True)
|
||||
automation.version += 1
|
||||
|
||||
await self.session.commit()
|
||||
return await self._get_with_triggers_or_raise(automation_id)
|
||||
|
||||
async def delete(self, automation_id: int) -> None:
|
||||
"""Delete an automation; FK cascades remove triggers and runs."""
|
||||
automation = await self._get_or_raise(automation_id)
|
||||
await self._authorize(automation.search_space_id, Permission.AUTOMATIONS_DELETE.value)
|
||||
await self.session.delete(automation)
|
||||
await self.session.commit()
|
||||
|
||||
async def _get_or_raise(self, automation_id: int) -> Automation:
|
||||
automation = await self.session.get(Automation, automation_id)
|
||||
if automation is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"automation {automation_id} not found"
|
||||
)
|
||||
return automation
|
||||
|
||||
async def _get_with_triggers_or_raise(self, automation_id: int) -> Automation:
|
||||
stmt = (
|
||||
select(Automation)
|
||||
.where(Automation.id == automation_id)
|
||||
.options(selectinload(Automation.triggers))
|
||||
)
|
||||
automation = (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
if automation is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"automation {automation_id} not found"
|
||||
)
|
||||
return automation
|
||||
|
||||
async def _authorize(self, search_space_id: int, permission: str) -> None:
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
search_space_id,
|
||||
permission,
|
||||
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
|
||||
)
|
||||
|
||||
|
||||
def _build_trigger(spec: TriggerCreate) -> AutomationTrigger:
|
||||
"""Validate trigger params via its registered Pydantic model and build the ORM row."""
|
||||
definition = get_trigger(spec.type.value)
|
||||
if definition is None:
|
||||
raise HTTPException(status_code=422, detail=f"unknown trigger type {spec.type.value!r}")
|
||||
|
||||
try:
|
||||
validated = definition.params_model.model_validate(spec.params)
|
||||
except ValidationError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
params = validated.model_dump(mode="json")
|
||||
|
||||
next_fire_at = None
|
||||
if spec.type == TriggerType.SCHEDULE and spec.enabled:
|
||||
next_fire_at = compute_next_fire_at(
|
||||
params["cron"], params["timezone"], after=datetime.now(UTC)
|
||||
)
|
||||
|
||||
return AutomationTrigger(
|
||||
type=spec.type,
|
||||
params=params,
|
||||
static_inputs=spec.static_inputs,
|
||||
enabled=spec.enabled,
|
||||
next_fire_at=next_fire_at,
|
||||
)
|
||||
|
||||
|
||||
def get_automation_service(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> AutomationService:
|
||||
return AutomationService(session=session, user=user)
|
||||
72
surfsense_backend/app/automations/services/run.py
Normal file
72
surfsense_backend/app/automations/services/run.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
"""``RunService`` — read-only access to automation run history."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
||||
class RunService:
|
||||
"""Read-only access to ``AutomationRun`` history."""
|
||||
|
||||
def __init__(self, *, session: AsyncSession, user: User) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
|
||||
async def list(
|
||||
self,
|
||||
*,
|
||||
automation_id: int,
|
||||
limit: int,
|
||||
offset: int,
|
||||
) -> tuple[list[AutomationRun], int]:
|
||||
"""Return a page of runs for an automation, newest first."""
|
||||
await self._authorize(automation_id, Permission.AUTOMATIONS_READ.value)
|
||||
|
||||
base = select(AutomationRun).where(AutomationRun.automation_id == automation_id)
|
||||
total = await self.session.scalar(
|
||||
select(func.count()).select_from(base.subquery())
|
||||
)
|
||||
|
||||
rows = (
|
||||
await self.session.execute(
|
||||
base.order_by(AutomationRun.created_at.desc()).limit(limit).offset(offset)
|
||||
)
|
||||
).scalars().all()
|
||||
return list(rows), int(total or 0)
|
||||
|
||||
async def get(self, *, automation_id: int, run_id: int) -> AutomationRun:
|
||||
await self._authorize(automation_id, Permission.AUTOMATIONS_READ.value)
|
||||
run = await self.session.get(AutomationRun, run_id)
|
||||
if run is None or run.automation_id != automation_id:
|
||||
raise HTTPException(status_code=404, detail=f"run {run_id} not found")
|
||||
return run
|
||||
|
||||
async def _authorize(self, automation_id: int, permission: str) -> Automation:
|
||||
automation = await self.session.get(Automation, automation_id)
|
||||
if automation is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"automation {automation_id} not found"
|
||||
)
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
automation.search_space_id,
|
||||
permission,
|
||||
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
|
||||
)
|
||||
return automation
|
||||
|
||||
|
||||
def get_run_service(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> RunService:
|
||||
return RunService(session=session, user=user)
|
||||
143
surfsense_backend/app/automations/services/trigger.py
Normal file
143
surfsense_backend/app/automations/services/trigger.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
"""``TriggerService`` — lifecycle of triggers attached to an automation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.schedule import compute_next_fire_at
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
||||
class TriggerService:
|
||||
"""Lifecycle of the ``AutomationTrigger`` sub-resource."""
|
||||
|
||||
def __init__(self, *, session: AsyncSession, user: User) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
|
||||
async def add(
|
||||
self, *, automation_id: int, payload: TriggerCreate
|
||||
) -> AutomationTrigger:
|
||||
automation = await self._authorize_automation(
|
||||
automation_id, Permission.AUTOMATIONS_UPDATE.value
|
||||
)
|
||||
|
||||
validated_params = _validate_params(payload.type, payload.params)
|
||||
trigger = AutomationTrigger(
|
||||
automation_id=automation.id,
|
||||
type=payload.type,
|
||||
params=validated_params,
|
||||
static_inputs=payload.static_inputs,
|
||||
enabled=payload.enabled,
|
||||
next_fire_at=_initial_next_fire(payload.type, validated_params, payload.enabled),
|
||||
)
|
||||
self.session.add(trigger)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(trigger)
|
||||
return trigger
|
||||
|
||||
async def update(
|
||||
self,
|
||||
*,
|
||||
automation_id: int,
|
||||
trigger_id: int,
|
||||
patch: TriggerUpdate,
|
||||
) -> AutomationTrigger:
|
||||
await self._authorize_automation(automation_id, Permission.AUTOMATIONS_UPDATE.value)
|
||||
trigger = await self._get_trigger_or_raise(automation_id, trigger_id)
|
||||
|
||||
data = patch.model_dump(exclude_unset=True)
|
||||
|
||||
if "params" in data:
|
||||
trigger.params = _validate_params(trigger.type, data["params"])
|
||||
|
||||
if "static_inputs" in data:
|
||||
trigger.static_inputs = data["static_inputs"]
|
||||
|
||||
if "enabled" in data:
|
||||
trigger.enabled = data["enabled"]
|
||||
|
||||
# Recompute next_fire_at when schedule timing changed or the trigger was
|
||||
# toggled back on.
|
||||
if trigger.type == TriggerType.SCHEDULE:
|
||||
trigger.next_fire_at = _initial_next_fire(
|
||||
trigger.type, trigger.params, trigger.enabled
|
||||
)
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(trigger)
|
||||
return trigger
|
||||
|
||||
async def remove(self, *, automation_id: int, trigger_id: int) -> None:
|
||||
await self._authorize_automation(automation_id, Permission.AUTOMATIONS_UPDATE.value)
|
||||
trigger = await self._get_trigger_or_raise(automation_id, trigger_id)
|
||||
await self.session.delete(trigger)
|
||||
await self.session.commit()
|
||||
|
||||
async def _authorize_automation(
|
||||
self, automation_id: int, permission: str
|
||||
) -> Automation:
|
||||
automation = await self.session.get(Automation, automation_id)
|
||||
if automation is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"automation {automation_id} not found"
|
||||
)
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
automation.search_space_id,
|
||||
permission,
|
||||
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
|
||||
)
|
||||
return automation
|
||||
|
||||
async def _get_trigger_or_raise(
|
||||
self, automation_id: int, trigger_id: int
|
||||
) -> AutomationTrigger:
|
||||
trigger = await self.session.get(AutomationTrigger, trigger_id)
|
||||
if trigger is None or trigger.automation_id != automation_id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"trigger {trigger_id} not found"
|
||||
)
|
||||
return trigger
|
||||
|
||||
|
||||
def _validate_params(trigger_type: TriggerType, raw: dict) -> dict:
|
||||
definition = get_trigger(trigger_type.value)
|
||||
if definition is None:
|
||||
raise HTTPException(
|
||||
status_code=422, detail=f"unknown trigger type {trigger_type.value!r}"
|
||||
)
|
||||
try:
|
||||
validated = definition.params_model.model_validate(raw)
|
||||
except ValidationError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
return validated.model_dump(mode="json")
|
||||
|
||||
|
||||
def _initial_next_fire(
|
||||
trigger_type: TriggerType, params: dict, enabled: bool
|
||||
) -> datetime | None:
|
||||
if trigger_type != TriggerType.SCHEDULE or not enabled:
|
||||
return None
|
||||
return compute_next_fire_at(
|
||||
params["cron"], params["timezone"], after=datetime.now(UTC)
|
||||
)
|
||||
|
||||
|
||||
def get_trigger_service(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> TriggerService:
|
||||
return TriggerService(session=session, user=user)
|
||||
3
surfsense_backend/app/automations/tasks/__init__.py
Normal file
3
surfsense_backend/app/automations/tasks/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
"""Celery task wrappers for the automation runtime."""
|
||||
|
||||
from __future__ import annotations
|
||||
33
surfsense_backend/app/automations/tasks/execute_run.py
Normal file
33
surfsense_backend/app/automations/tasks/execute_run.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
"""Celery task that runs one automation. Thin wrapper over ``runtime.executor``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.automations.runtime import execute_run
|
||||
from app.celery_app import celery_app
|
||||
from app.tasks.celery_tasks import (
|
||||
get_celery_session_maker,
|
||||
run_async_celery_task,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TASK_NAME = "automation_run_execute"
|
||||
|
||||
|
||||
@celery_app.task(name=TASK_NAME, bind=True)
|
||||
def automation_run_execute(self, run_id: int) -> None: # noqa: ARG001 — Celery bind
|
||||
"""Execute one ``AutomationRun``. Idempotent: terminal runs no-op."""
|
||||
return run_async_celery_task(lambda: _impl(run_id))
|
||||
|
||||
|
||||
async def _impl(run_id: int) -> None:
|
||||
session_maker = get_celery_session_maker()
|
||||
async with session_maker() as session:
|
||||
try:
|
||||
await execute_run(session, run_id)
|
||||
except Exception:
|
||||
logger.exception("automation_run %d failed unexpectedly", run_id)
|
||||
await session.rollback()
|
||||
raise
|
||||
187
surfsense_backend/app/automations/tasks/schedule_tick.py
Normal file
187
surfsense_backend/app/automations/tasks/schedule_tick.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
"""Celery Beat tick that fires due ``schedule`` triggers.
|
||||
|
||||
Runs every minute. Each tick performs two passes:
|
||||
|
||||
1. **Self-heal**: enabled schedule triggers with NULL ``next_fire_at`` get
|
||||
it computed from their ``cron`` + ``timezone`` (e.g. fresh inserts or
|
||||
rows restored from backup).
|
||||
2. **Claim & fire**: due rows are locked with ``FOR UPDATE SKIP LOCKED``,
|
||||
their ``next_fire_at`` is advanced and ``last_fired_at`` is set, and
|
||||
``dispatch_schedule_run`` is invoked for each. Dispatch errors are
|
||||
logged; a missed fire stays missed (matches K8s CronJob / Airflow
|
||||
``catchup=False`` semantics).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.triggers.schedule import (
|
||||
InvalidCronError,
|
||||
compute_next_fire_at,
|
||||
dispatch_schedule_run,
|
||||
)
|
||||
from app.celery_app import celery_app
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TASK_NAME = "automation_schedule_tick"
|
||||
|
||||
# Cap rows touched per tick so a backlog of due triggers can't starve the
|
||||
# worker; remaining rows fire on the next tick.
|
||||
_TICK_BATCH = 200
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _Claim:
|
||||
"""Per-trigger fire context captured before row state is mutated."""
|
||||
|
||||
trigger_id: int
|
||||
scheduled_for: datetime
|
||||
previous_last_fired_at: datetime | None
|
||||
|
||||
|
||||
@celery_app.task(name=TASK_NAME)
|
||||
def automation_schedule_tick() -> None:
|
||||
"""Tick once: self-heal NULL next_fire_at, claim due rows, fire each."""
|
||||
return run_async_celery_task(_tick)
|
||||
|
||||
|
||||
async def _tick() -> None:
|
||||
session_maker = get_celery_session_maker()
|
||||
async with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
await _self_heal_null_next_fire(session, now=now)
|
||||
|
||||
claims = await _claim_due_triggers(session, now=now)
|
||||
if not claims:
|
||||
return
|
||||
|
||||
for claim in claims:
|
||||
await _fire_one(session, claim=claim, fired_at=now)
|
||||
|
||||
|
||||
async def _self_heal_null_next_fire(session: AsyncSession, *, now: datetime) -> None:
|
||||
"""Backfill ``next_fire_at`` for enabled schedule triggers missing it."""
|
||||
stmt = (
|
||||
select(AutomationTrigger)
|
||||
.where(
|
||||
AutomationTrigger.type == TriggerType.SCHEDULE,
|
||||
AutomationTrigger.enabled.is_(True),
|
||||
AutomationTrigger.next_fire_at.is_(None),
|
||||
)
|
||||
.limit(_TICK_BATCH)
|
||||
)
|
||||
triggers = (await session.execute(stmt)).scalars().all()
|
||||
if not triggers:
|
||||
return
|
||||
|
||||
for trigger in triggers:
|
||||
try:
|
||||
trigger.next_fire_at = compute_next_fire_at(
|
||||
trigger.params["cron"],
|
||||
trigger.params["timezone"],
|
||||
after=now,
|
||||
)
|
||||
except (InvalidCronError, KeyError, TypeError) as exc:
|
||||
logger.warning(
|
||||
"automation_trigger %d has invalid schedule params, disabling: %s",
|
||||
trigger.id,
|
||||
exc,
|
||||
)
|
||||
trigger.enabled = False
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def _claim_due_triggers(
|
||||
session: AsyncSession, *, now: datetime
|
||||
) -> list[_Claim]:
|
||||
"""Lock and advance due rows; return per-trigger fire context."""
|
||||
stmt = (
|
||||
select(AutomationTrigger)
|
||||
.where(
|
||||
AutomationTrigger.type == TriggerType.SCHEDULE,
|
||||
AutomationTrigger.enabled.is_(True),
|
||||
AutomationTrigger.next_fire_at.isnot(None),
|
||||
AutomationTrigger.next_fire_at <= now,
|
||||
)
|
||||
.order_by(AutomationTrigger.next_fire_at)
|
||||
.limit(_TICK_BATCH)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
triggers = (await session.execute(stmt)).scalars().all()
|
||||
if not triggers:
|
||||
return []
|
||||
|
||||
claims: list[_Claim] = []
|
||||
for trigger in triggers:
|
||||
# Snapshot fire-context BEFORE we advance the row.
|
||||
scheduled_for = trigger.next_fire_at
|
||||
previous_last_fired_at = trigger.last_fired_at
|
||||
|
||||
try:
|
||||
trigger.next_fire_at = compute_next_fire_at(
|
||||
trigger.params["cron"],
|
||||
trigger.params["timezone"],
|
||||
after=now,
|
||||
)
|
||||
except (InvalidCronError, KeyError, TypeError) as exc:
|
||||
logger.warning(
|
||||
"automation_trigger %d has invalid schedule params, disabling: %s",
|
||||
trigger.id,
|
||||
exc,
|
||||
)
|
||||
trigger.enabled = False
|
||||
continue
|
||||
|
||||
trigger.last_fired_at = now
|
||||
claims.append(
|
||||
_Claim(
|
||||
trigger_id=trigger.id,
|
||||
scheduled_for=scheduled_for,
|
||||
previous_last_fired_at=previous_last_fired_at,
|
||||
)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
return claims
|
||||
|
||||
|
||||
async def _fire_one(
|
||||
session: AsyncSession, *, claim: _Claim, fired_at: datetime
|
||||
) -> None:
|
||||
"""Reload the trigger post-commit and dispatch a run for it."""
|
||||
trigger = await session.get(AutomationTrigger, claim.trigger_id)
|
||||
if trigger is None:
|
||||
return
|
||||
|
||||
try:
|
||||
run = await dispatch_schedule_run(
|
||||
session=session,
|
||||
trigger=trigger,
|
||||
fired_at=fired_at,
|
||||
scheduled_for=claim.scheduled_for,
|
||||
previous_last_fired_at=claim.previous_last_fired_at,
|
||||
)
|
||||
logger.info(
|
||||
"scheduled fire: trigger=%d automation=%d run=%d",
|
||||
claim.trigger_id,
|
||||
trigger.automation_id,
|
||||
run.id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"scheduled fire failed for trigger %d (next attempt at next match)",
|
||||
claim.trigger_id,
|
||||
)
|
||||
await session.rollback()
|
||||
13
surfsense_backend/app/automations/templating/__init__.py
Normal file
13
surfsense_backend/app/automations/templating/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
"""Sandboxed template engine for automation definitions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .context import build_run_context
|
||||
from .render import evaluate_predicate, render_template, render_value
|
||||
|
||||
__all__ = [
|
||||
"build_run_context",
|
||||
"evaluate_predicate",
|
||||
"render_template",
|
||||
"render_value",
|
||||
]
|
||||
31
surfsense_backend/app/automations/templating/allowlist.py
Normal file
31
surfsense_backend/app/automations/templating/allowlist.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"""Filter and test names admitted into the sandboxed environment."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
ALLOWED_FILTERS: tuple[str, ...] = (
|
||||
"default",
|
||||
"first",
|
||||
"join",
|
||||
"last",
|
||||
"length",
|
||||
"lower",
|
||||
"replace",
|
||||
"reverse",
|
||||
"sort",
|
||||
"tojson",
|
||||
"trim",
|
||||
"truncate",
|
||||
"upper",
|
||||
"date",
|
||||
"slugify",
|
||||
)
|
||||
|
||||
ALLOWED_TESTS: tuple[str, ...] = (
|
||||
"defined",
|
||||
"none",
|
||||
"number",
|
||||
"string",
|
||||
"mapping",
|
||||
"sequence",
|
||||
"boolean",
|
||||
)
|
||||
41
surfsense_backend/app/automations/templating/context.py
Normal file
41
surfsense_backend/app/automations/templating/context.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Builder for the ``{run, inputs, steps}`` namespace exposed to every template."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
def build_run_context(
|
||||
*,
|
||||
run_id: int,
|
||||
automation_id: int,
|
||||
automation_name: str | None,
|
||||
automation_version: int | None,
|
||||
search_space_id: int | None,
|
||||
creator_id: Any,
|
||||
trigger_id: int | None,
|
||||
trigger_type: str | None,
|
||||
started_at: datetime | None,
|
||||
attempt: int,
|
||||
inputs: Mapping[str, Any],
|
||||
step_outputs: Mapping[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Build the ``{run, inputs, steps}`` namespace exposed to every template."""
|
||||
return {
|
||||
"run": {
|
||||
"id": run_id,
|
||||
"automation_id": automation_id,
|
||||
"automation_name": automation_name,
|
||||
"automation_version": automation_version,
|
||||
"search_space_id": search_space_id,
|
||||
"creator_id": creator_id,
|
||||
"trigger_id": trigger_id,
|
||||
"trigger_type": trigger_type,
|
||||
"started_at": started_at,
|
||||
"attempt": attempt,
|
||||
},
|
||||
"inputs": dict(inputs),
|
||||
"steps": dict(step_outputs),
|
||||
}
|
||||
43
surfsense_backend/app/automations/templating/environment.py
Normal file
43
surfsense_backend/app/automations/templating/environment.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
"""SandboxedEnvironment construction with the audited filter/test allowlist."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import StrictUndefined
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
from .allowlist import ALLOWED_FILTERS, ALLOWED_TESTS
|
||||
from .filters import filter_date, filter_slugify
|
||||
|
||||
|
||||
def _finalize(value: Any) -> Any:
|
||||
"""Stringify common non-string values at output sites."""
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, datetime):
|
||||
return value.isoformat()
|
||||
if isinstance(value, list | dict):
|
||||
return json.dumps(value, ensure_ascii=False, default=str)
|
||||
return value
|
||||
|
||||
|
||||
def _build_env() -> SandboxedEnvironment:
|
||||
env = SandboxedEnvironment(
|
||||
autoescape=False,
|
||||
undefined=StrictUndefined,
|
||||
finalize=_finalize,
|
||||
)
|
||||
env.globals.clear()
|
||||
env.filters = {k: v for k, v in env.filters.items() if k in ALLOWED_FILTERS}
|
||||
env.filters["date"] = filter_date
|
||||
env.filters["slugify"] = filter_slugify
|
||||
env.tests = {k: v for k, v in env.tests.items() if k in ALLOWED_TESTS}
|
||||
return env
|
||||
|
||||
|
||||
ENV: SandboxedEnvironment = _build_env()
|
||||
29
surfsense_backend/app/automations/templating/filters.py
Normal file
29
surfsense_backend/app/automations/templating/filters.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""Custom Jinja filters registered into the sandboxed environment."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
def filter_date(value: Any, fmt: str = "%Y-%m-%d") -> str:
|
||||
"""Format a datetime-like value with ``strftime``. Strings pass through."""
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if hasattr(value, "strftime"):
|
||||
return value.strftime(fmt)
|
||||
raise ValueError(f"date filter requires datetime-like, got {type(value).__name__}")
|
||||
|
||||
|
||||
_SLUG_NONALNUM = re.compile(r"[^a-z0-9]+")
|
||||
_SLUG_DASHES = re.compile(r"-+")
|
||||
|
||||
|
||||
def filter_slugify(value: Any) -> str:
|
||||
"""Lowercase, replace non-alphanumerics with hyphens, collapse and trim."""
|
||||
s = str(value).lower()
|
||||
s = _SLUG_NONALNUM.sub("-", s)
|
||||
s = _SLUG_DASHES.sub("-", s)
|
||||
return s.strip("-")
|
||||
29
surfsense_backend/app/automations/templating/render.py
Normal file
29
surfsense_backend/app/automations/templating/render.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""Render templates and evaluate predicates against the sandboxed environment."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from .environment import ENV
|
||||
|
||||
|
||||
def render_template(template: str, context: Mapping[str, Any]) -> str:
|
||||
"""Render ``template`` with ``context``."""
|
||||
return ENV.from_string(template).render(**context)
|
||||
|
||||
|
||||
def evaluate_predicate(expression: str, context: Mapping[str, Any]) -> bool:
|
||||
"""Evaluate a Jinja expression (not a template body) and coerce to bool."""
|
||||
return bool(ENV.compile_expression(expression)(**context))
|
||||
|
||||
|
||||
def render_value(value: Any, context: Mapping[str, Any]) -> Any:
|
||||
"""Recursively render every string in a JSON-like value structure."""
|
||||
if isinstance(value, str):
|
||||
return render_template(value, context)
|
||||
if isinstance(value, dict):
|
||||
return {k: render_value(v, context) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [render_value(v, context) for v in value]
|
||||
return value
|
||||
20
surfsense_backend/app/automations/triggers/__init__.py
Normal file
20
surfsense_backend/app/automations/triggers/__init__.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
"""Triggers domain: registry surface + built-in trigger packages.
|
||||
|
||||
Each trigger lives in its own subpackage (``schedule/``, ...) and
|
||||
self-registers at import time via its ``definition`` module.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .store import all_triggers, get_trigger, register_trigger
|
||||
from .types import TriggerDefinition
|
||||
|
||||
__all__ = [
|
||||
"TriggerDefinition",
|
||||
"all_triggers",
|
||||
"get_trigger",
|
||||
"register_trigger",
|
||||
]
|
||||
|
||||
# Built-in triggers self-register at import time.
|
||||
from . import schedule # noqa: E402, F401
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
"""``schedule`` trigger: fired on a cron schedule in a given timezone."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .cron import InvalidCronError, compute_next_fire_at, validate_cron
|
||||
from .dispatch import dispatch_schedule_run
|
||||
from .params import ScheduleTriggerParams
|
||||
|
||||
__all__ = [
|
||||
"InvalidCronError",
|
||||
"ScheduleTriggerParams",
|
||||
"compute_next_fire_at",
|
||||
"dispatch_schedule_run",
|
||||
"validate_cron",
|
||||
]
|
||||
|
||||
# Side-effect: register on the triggers store.
|
||||
from . import definition # noqa: E402, F401
|
||||
37
surfsense_backend/app/automations/triggers/schedule/cron.py
Normal file
37
surfsense_backend/app/automations/triggers/schedule/cron.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
"""Cron math for the ``schedule`` trigger: validate + advance ``next_fire_at``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from croniter import CroniterBadCronError, croniter
|
||||
|
||||
|
||||
class InvalidCronError(ValueError):
|
||||
"""Raised when a cron expression or timezone fails validation."""
|
||||
|
||||
|
||||
def validate_cron(cron: str, timezone: str) -> None:
|
||||
"""Raise ``InvalidCronError`` if cron or timezone are unusable."""
|
||||
try:
|
||||
ZoneInfo(timezone)
|
||||
except ZoneInfoNotFoundError as exc:
|
||||
raise InvalidCronError(f"unknown timezone {timezone!r}") from exc
|
||||
|
||||
try:
|
||||
croniter(cron)
|
||||
except (CroniterBadCronError, ValueError) as exc:
|
||||
raise InvalidCronError(f"invalid cron {cron!r}: {exc}") from exc
|
||||
|
||||
|
||||
def compute_next_fire_at(cron: str, timezone: str, *, after: datetime) -> datetime:
|
||||
"""Return the next moment matching ``cron`` in ``timezone`` strictly after ``after``.
|
||||
|
||||
The result is normalized to UTC for storage. ``after`` is converted into the
|
||||
given timezone before evaluation so DST and IANA rules apply correctly.
|
||||
"""
|
||||
tz = ZoneInfo(timezone)
|
||||
base = after.astimezone(tz) if after.tzinfo else after.replace(tzinfo=UTC).astimezone(tz)
|
||||
nxt: datetime = croniter(cron, base).get_next(datetime)
|
||||
return nxt.astimezone(UTC)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""``schedule`` ``TriggerDefinition`` registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..store import register_trigger
|
||||
from ..types import TriggerDefinition
|
||||
from .params import ScheduleTriggerParams
|
||||
|
||||
SCHEDULE_TRIGGER = TriggerDefinition(
|
||||
type="schedule",
|
||||
description="Fire on a cron schedule in a given timezone.",
|
||||
params_model=ScheduleTriggerParams,
|
||||
)
|
||||
|
||||
register_trigger(SCHEDULE_TRIGGER)
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Schedule dispatch adapter: load + guard, then call generic dispatch."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.automations.dispatch import DispatchError, dispatch_run
|
||||
from app.automations.persistence.enums.automation_status import AutomationStatus
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
|
||||
|
||||
async def dispatch_schedule_run(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
trigger: AutomationTrigger,
|
||||
fired_at: datetime,
|
||||
scheduled_for: datetime,
|
||||
previous_last_fired_at: datetime | None,
|
||||
) -> AutomationRun:
|
||||
"""Fire one scheduled run for ``trigger``.
|
||||
|
||||
Emits calendar context as runtime inputs:
|
||||
|
||||
- ``fired_at`` — actual fire time
|
||||
- ``scheduled_for`` — cron-derived target time for this fire
|
||||
- ``last_fired_at`` — fire time of the previous run, or null on first fire
|
||||
|
||||
The caller (the schedule tick) is responsible for selecting due triggers
|
||||
and advancing ``next_fire_at`` / ``last_fired_at`` before invoking this.
|
||||
"""
|
||||
automation = await _load_automation(session, trigger.automation_id)
|
||||
if automation is None:
|
||||
raise DispatchError(
|
||||
f"automation {trigger.automation_id} not found for trigger {trigger.id}"
|
||||
)
|
||||
|
||||
if automation.status != AutomationStatus.ACTIVE:
|
||||
raise DispatchError(
|
||||
f"automation {trigger.automation_id} is {automation.status.value}, not active"
|
||||
)
|
||||
|
||||
runtime_inputs = {
|
||||
"fired_at": fired_at.isoformat(),
|
||||
"scheduled_for": scheduled_for.isoformat(),
|
||||
"last_fired_at": (
|
||||
previous_last_fired_at.isoformat() if previous_last_fired_at else None
|
||||
),
|
||||
}
|
||||
|
||||
return await dispatch_run(
|
||||
session=session,
|
||||
automation=automation,
|
||||
trigger=trigger,
|
||||
runtime_inputs=runtime_inputs,
|
||||
)
|
||||
|
||||
|
||||
async def _load_automation(
|
||||
session: AsyncSession, automation_id: int
|
||||
) -> Automation | None:
|
||||
stmt = select(Automation).where(Automation.id == automation_id)
|
||||
return (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
"""``ScheduleTriggerParams`` — params for the ``schedule`` trigger type."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from .cron import InvalidCronError, validate_cron
|
||||
|
||||
|
||||
class ScheduleTriggerParams(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
cron: str = Field(..., description="Five-field cron expression.", examples=["0 9 * * 1-5"])
|
||||
timezone: str = Field(..., description="IANA timezone.", examples=["Africa/Kigali"])
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate(self) -> ScheduleTriggerParams:
|
||||
try:
|
||||
validate_cron(self.cron, self.timezone)
|
||||
except InvalidCronError as exc:
|
||||
raise ValueError(str(exc)) from exc
|
||||
return self
|
||||
23
surfsense_backend/app/automations/triggers/store.py
Normal file
23
surfsense_backend/app/automations/triggers/store.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
"""In-memory trigger registry. Populated once at process startup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .types import TriggerDefinition
|
||||
|
||||
_REGISTRY: dict[str, TriggerDefinition] = {}
|
||||
|
||||
|
||||
def register_trigger(trigger: TriggerDefinition) -> None:
|
||||
"""Register a trigger. Raises on duplicate type."""
|
||||
if trigger.type in _REGISTRY:
|
||||
raise ValueError(f"Trigger already registered: {trigger.type!r}")
|
||||
_REGISTRY[trigger.type] = trigger
|
||||
|
||||
|
||||
def get_trigger(trigger_type: str) -> TriggerDefinition | None:
|
||||
return _REGISTRY.get(trigger_type)
|
||||
|
||||
|
||||
def all_triggers() -> dict[str, TriggerDefinition]:
|
||||
"""Defensive snapshot of the registry."""
|
||||
return dict(_REGISTRY)
|
||||
20
surfsense_backend/app/automations/triggers/types.py
Normal file
20
surfsense_backend/app/automations/triggers/types.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
"""``TriggerDefinition`` dataclass. Declarative; firing is the dispatcher's job."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class TriggerDefinition:
|
||||
type: str
|
||||
description: str
|
||||
params_model: type[BaseModel]
|
||||
|
||||
@property
|
||||
def params_schema(self) -> dict[str, Any]:
|
||||
"""JSON Schema (draft 2020-12) derived from ``params_model``."""
|
||||
return self.params_model.model_json_schema()
|
||||
|
|
@ -188,6 +188,8 @@ celery_app = Celery(
|
|||
"app.tasks.celery_tasks.document_reindex_tasks",
|
||||
"app.tasks.celery_tasks.stale_notification_cleanup_task",
|
||||
"app.tasks.celery_tasks.stripe_reconciliation_task",
|
||||
"app.automations.tasks.execute_run",
|
||||
"app.automations.tasks.schedule_tick",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -282,4 +284,14 @@ celery_app.conf.beat_schedule = {
|
|||
"expires": 60,
|
||||
},
|
||||
},
|
||||
# Fire due automation schedule triggers. Ticks every minute; per-row cron
|
||||
# math is precomputed (next_fire_at column) so the tick is an indexed
|
||||
# lookup, not N cron evaluations.
|
||||
"automation-schedule-tick": {
|
||||
"task": "automation_schedule_tick",
|
||||
"schedule": crontab(minute="*"),
|
||||
"options": {
|
||||
"expires": 50,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -439,6 +439,13 @@ class Permission(StrEnum):
|
|||
PUBLIC_SHARING_CREATE = "public_sharing:create"
|
||||
PUBLIC_SHARING_DELETE = "public_sharing:delete"
|
||||
|
||||
# Automations
|
||||
AUTOMATIONS_CREATE = "automations:create"
|
||||
AUTOMATIONS_READ = "automations:read"
|
||||
AUTOMATIONS_UPDATE = "automations:update"
|
||||
AUTOMATIONS_DELETE = "automations:delete"
|
||||
AUTOMATIONS_EXECUTE = "automations:execute"
|
||||
|
||||
# Full access wildcard
|
||||
FULL_ACCESS = "*"
|
||||
|
||||
|
|
@ -494,6 +501,11 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
# Public Sharing (can create and view, no delete)
|
||||
Permission.PUBLIC_SHARING_VIEW.value,
|
||||
Permission.PUBLIC_SHARING_CREATE.value,
|
||||
# Automations (no delete)
|
||||
Permission.AUTOMATIONS_CREATE.value,
|
||||
Permission.AUTOMATIONS_READ.value,
|
||||
Permission.AUTOMATIONS_UPDATE.value,
|
||||
Permission.AUTOMATIONS_EXECUTE.value,
|
||||
],
|
||||
"Viewer": [
|
||||
# Documents (read only)
|
||||
|
|
@ -525,6 +537,8 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
Permission.SETTINGS_VIEW.value,
|
||||
# Public Sharing (view only)
|
||||
Permission.PUBLIC_SHARING_VIEW.value,
|
||||
# Automations (read only)
|
||||
Permission.AUTOMATIONS_READ.value,
|
||||
],
|
||||
}
|
||||
|
||||
|
|
@ -1533,6 +1547,14 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
automations = relationship(
|
||||
"Automation",
|
||||
back_populates="search_space",
|
||||
order_by="Automation.id",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# RBAC relationships
|
||||
roles = relationship(
|
||||
"SearchSpaceRole",
|
||||
|
|
@ -2125,6 +2147,13 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Automations created by this user
|
||||
automations = relationship(
|
||||
"Automation",
|
||||
back_populates="created_by",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Incentive tasks completed by this user
|
||||
incentive_tasks = relationship(
|
||||
"UserIncentiveTask",
|
||||
|
|
@ -2257,6 +2286,13 @@ else:
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Automations created by this user
|
||||
automations = relationship(
|
||||
"Automation",
|
||||
back_populates="created_by",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Incentive tasks completed by this user
|
||||
incentive_tasks = relationship(
|
||||
"UserIncentiveTask",
|
||||
|
|
@ -2560,6 +2596,16 @@ class RefreshToken(Base, TimestampMixin):
|
|||
return not self.is_expired and not self.is_revoked
|
||||
|
||||
|
||||
# Register model packages that live outside this file so their classes
|
||||
# are present in Base.metadata before configure_mappers() resolves any
|
||||
# string-based relationship() references.
|
||||
from app.automations.persistence import ( # noqa: E402, F401
|
||||
Automation,
|
||||
AutomationRun,
|
||||
AutomationTrigger,
|
||||
)
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
DATABASE_URL,
|
||||
pool_size=30,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from .agent_revert_route import router as agent_revert_router
|
|||
from .airtable_add_connector_route import (
|
||||
router as airtable_add_connector_router,
|
||||
)
|
||||
from app.automations.api import router as automations_router
|
||||
from .chat_comments_routes import router as chat_comments_router
|
||||
from .circleback_webhook_route import router as circleback_webhook_router
|
||||
from .clickup_add_connector_route import router as clickup_add_connector_router
|
||||
|
|
@ -119,3 +120,4 @@ router.include_router(youtube_router) # YouTube playlist resolution
|
|||
router.include_router(prompts_router)
|
||||
router.include_router(memory_router) # User personal memory (memory.md style)
|
||||
router.include_router(team_memory_router) # Search-space team memory
|
||||
router.include_router(automations_router) # Automations CRUD + run history
|
||||
|
|
|
|||
|
|
@ -107,6 +107,12 @@ PERMISSION_DESCRIPTIONS = {
|
|||
"settings:view": "View search space settings",
|
||||
"settings:update": "Modify search space settings",
|
||||
"settings:delete": "Delete the entire search space",
|
||||
# Automations
|
||||
"automations:create": "Create automations from chat or JSON",
|
||||
"automations:read": "View automations, their triggers, and run history",
|
||||
"automations:update": "Edit automations and manage their triggers",
|
||||
"automations:delete": "Remove automations from the search space",
|
||||
"automations:execute": "Manually fire automations",
|
||||
# Full access
|
||||
"*": "Full access to all features and settings",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
"""Agent construction and per-turn event-loop drivers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread
|
||||
from app.tasks.chat.streaming.agent.event_loop import stream_agent_events
|
||||
|
||||
__all__ = ["build_main_agent_for_thread", "stream_agent_events"]
|
||||
49
surfsense_backend/app/tasks/chat/streaming/agent/builder.py
Normal file
49
surfsense_backend/app/tasks/chat/streaming/agent/builder.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
"""Single per-thread agent (re)build path.
|
||||
|
||||
A graph swap mid-turn would corrupt checkpointer state for the same
|
||||
``thread_id``, so both the initial build and any mid-stream 429 recovery rebuild
|
||||
must funnel through this single function.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
|
||||
async def build_main_agent_for_thread(
|
||||
agent_factory: Any,
|
||||
*,
|
||||
llm: Any,
|
||||
search_space_id: int,
|
||||
db_session: Any,
|
||||
connector_service: ConnectorService,
|
||||
checkpointer: Any,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
agent_config: AgentConfig | None,
|
||||
firecrawl_api_key: str | None,
|
||||
thread_visibility: ChatVisibility | None,
|
||||
filesystem_selection: FilesystemSelection | None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
) -> Any:
|
||||
return await agent_factory(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=db_session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=thread_visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
)
|
||||
175
surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py
Normal file
175
surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
"""Per-turn agent event-loop driver.
|
||||
|
||||
Drives ``stream_output`` (graph_stream relay) for one agent turn, then runs the
|
||||
post-stream agent-state inspection: safety-net commit of any staged filesystem
|
||||
state (in case ``aafter_agent`` was skipped), file-operation contract scoring,
|
||||
intent classification, and interrupt detection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware.kb_persistence import (
|
||||
commit_staged_filesystem_state,
|
||||
)
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.streaming.contract.file_contract import (
|
||||
contract_enforcement_active,
|
||||
evaluate_file_contract_outcome,
|
||||
log_file_contract,
|
||||
)
|
||||
from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
|
||||
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
||||
all_interrupt_values,
|
||||
)
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.tasks.chat.streaming.shared.utils import safe_float
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
async def stream_agent_events(
|
||||
agent: Any,
|
||||
config: dict[str, Any],
|
||||
input_data: Any,
|
||||
streaming_service: VercelStreamingService,
|
||||
result: StreamResult,
|
||||
step_prefix: str = "thinking",
|
||||
initial_step_id: str | None = None,
|
||||
initial_step_title: str = "",
|
||||
initial_step_items: list[str] | None = None,
|
||||
*,
|
||||
fallback_commit_search_space_id: int | None = None,
|
||||
fallback_commit_created_by_id: str | None = None,
|
||||
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||
fallback_commit_thread_id: int | None = None,
|
||||
runtime_context: Any = None,
|
||||
content_builder: Any | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream and format ``astream_events`` from the agent.
|
||||
|
||||
Yields SSE-formatted strings; after exhausting, ``result`` carries
|
||||
``accumulated_text`` and interrupt state. See ``StreamResult`` for the
|
||||
side-channel surface populated by the underlying relay.
|
||||
"""
|
||||
async for sse in stream_output(
|
||||
agent=agent,
|
||||
config=config,
|
||||
input_data=input_data,
|
||||
streaming_service=streaming_service,
|
||||
result=result,
|
||||
step_prefix=step_prefix,
|
||||
initial_step_id=initial_step_id,
|
||||
initial_step_title=initial_step_title,
|
||||
initial_step_items=initial_step_items,
|
||||
content_builder=content_builder,
|
||||
runtime_context=runtime_context,
|
||||
):
|
||||
yield sse
|
||||
|
||||
accumulated_text = result.accumulated_text
|
||||
|
||||
state = await agent.aget_state(config)
|
||||
state_values = getattr(state, "values", {}) or {}
|
||||
|
||||
# Safety net: if astream_events was cancelled before
|
||||
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
|
||||
# (dirty_paths / staged_dirs / pending_moves / pending_deletes /
|
||||
# pending_dir_deletes) is still in the checkpointed state. Run the SAME
|
||||
# shared commit helper so the turn's writes don't get lost on client
|
||||
# disconnect, then push the delta back into the graph using ``as_node=...``
|
||||
# so reducers fire as if the after_agent hook produced it.
|
||||
if (
|
||||
fallback_commit_filesystem_mode == FilesystemMode.CLOUD
|
||||
and fallback_commit_search_space_id is not None
|
||||
and (
|
||||
(state_values.get("dirty_paths") or [])
|
||||
or (state_values.get("staged_dirs") or [])
|
||||
or (state_values.get("pending_moves") or [])
|
||||
or (state_values.get("pending_deletes") or [])
|
||||
or (state_values.get("pending_dir_deletes") or [])
|
||||
)
|
||||
):
|
||||
try:
|
||||
delta = await commit_staged_filesystem_state(
|
||||
state_values,
|
||||
search_space_id=fallback_commit_search_space_id,
|
||||
created_by_id=fallback_commit_created_by_id,
|
||||
filesystem_mode=fallback_commit_filesystem_mode,
|
||||
thread_id=fallback_commit_thread_id,
|
||||
dispatch_events=False,
|
||||
)
|
||||
if delta:
|
||||
await agent.aupdate_state(
|
||||
config,
|
||||
delta,
|
||||
as_node="KnowledgeBasePersistenceMiddleware.after_agent",
|
||||
)
|
||||
except Exception as exc:
|
||||
_perf_log.warning("[stream_agent_events] safety-net commit failed: %s", exc)
|
||||
|
||||
contract_state = state_values.get("file_operation_contract") or {}
|
||||
contract_turn_id = contract_state.get("turn_id")
|
||||
current_turn_id = config.get("configurable", {}).get("turn_id", "")
|
||||
intent_value = contract_state.get("intent")
|
||||
if (
|
||||
isinstance(intent_value, str)
|
||||
and intent_value in ("chat_only", "file_write", "file_read")
|
||||
and contract_turn_id == current_turn_id
|
||||
):
|
||||
result.intent_detected = intent_value
|
||||
if (
|
||||
isinstance(intent_value, str)
|
||||
and intent_value in ("chat_only", "file_write", "file_read")
|
||||
and contract_turn_id != current_turn_id
|
||||
):
|
||||
# Ignore stale intent contracts from previous turns/checkpoints.
|
||||
result.intent_detected = "chat_only"
|
||||
result.intent_confidence = (
|
||||
safe_float(contract_state.get("confidence"), default=0.0)
|
||||
if contract_turn_id == current_turn_id
|
||||
else 0.0
|
||||
)
|
||||
|
||||
if result.intent_detected == "file_write":
|
||||
result.commit_gate_passed, result.commit_gate_reason = (
|
||||
evaluate_file_contract_outcome(result)
|
||||
)
|
||||
if not result.commit_gate_passed and contract_enforcement_active(result):
|
||||
gate_notice = (
|
||||
"I could not complete the requested file write because no successful "
|
||||
"write_file/edit_file operation was confirmed."
|
||||
)
|
||||
gate_text_id = streaming_service.generate_text_id()
|
||||
yield streaming_service.format_text_start(gate_text_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_start(gate_text_id)
|
||||
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_delta(gate_text_id, gate_notice)
|
||||
yield streaming_service.format_text_end(gate_text_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_end(gate_text_id)
|
||||
yield streaming_service.format_terminal_info(gate_notice, "error")
|
||||
accumulated_text = gate_notice
|
||||
else:
|
||||
result.commit_gate_passed = True
|
||||
result.commit_gate_reason = ""
|
||||
|
||||
result.accumulated_text = accumulated_text
|
||||
log_file_contract("turn_outcome", result)
|
||||
|
||||
pending_values = all_interrupt_values(state)
|
||||
if pending_values:
|
||||
result.is_interrupted = True
|
||||
# One frame per paused subagent so each parallel HITL renders its own
|
||||
# approval card on the wire. Order matches ``state.interrupts``, which
|
||||
# the resume slicer in
|
||||
# ``checkpointed_subagent_middleware.resume_routing`` consumes in the
|
||||
# same order — keeping emit and resume in lock-step.
|
||||
for interrupt_value in pending_values:
|
||||
yield streaming_service.format_interrupt_request(interrupt_value)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""Pre-agent context shaping: mentioned-doc rendering and todos extraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.context.deepagents_todos import (
|
||||
extract_todos_from_deepagents,
|
||||
)
|
||||
from app.tasks.chat.streaming.context.mentioned_docs import (
|
||||
format_mentioned_surfsense_docs_as_context,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"extract_todos_from_deepagents",
|
||||
"format_mentioned_surfsense_docs_as_context",
|
||||
]
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""Extract todos from a deepagents ``TodoListMiddleware`` ``Command`` output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def extract_todos_from_deepagents(command_output: Any) -> dict:
|
||||
"""Normalize todos out of a deepagents ``Command`` or dict payload.
|
||||
|
||||
deepagents returns a ``Command`` whose ``update['todos']`` is a list of
|
||||
``{'content': str, 'status': str}`` dicts. The UI expects the same shape,
|
||||
so no transformation is required — only extraction.
|
||||
"""
|
||||
todos_data: list = []
|
||||
if hasattr(command_output, "update"):
|
||||
update = command_output.update
|
||||
todos_data = update.get("todos", [])
|
||||
elif isinstance(command_output, dict):
|
||||
if "todos" in command_output:
|
||||
todos_data = command_output.get("todos", [])
|
||||
elif "update" in command_output and isinstance(
|
||||
command_output["update"], dict
|
||||
):
|
||||
todos_data = command_output["update"].get("todos", [])
|
||||
|
||||
return {"todos": todos_data}
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
"""Render user-mentioned SurfSense docs as XML context for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from app.db import SurfsenseDocsDocument
|
||||
from app.utils.surfsense_docs import surfsense_docs_public_url
|
||||
|
||||
|
||||
def format_mentioned_surfsense_docs_as_context(
|
||||
documents: list[SurfsenseDocsDocument],
|
||||
) -> str:
|
||||
if not documents:
|
||||
return ""
|
||||
|
||||
context_parts = ["<mentioned_surfsense_docs>"]
|
||||
context_parts.append(
|
||||
"The user has explicitly mentioned the following SurfSense documentation pages. "
|
||||
"These are official documentation about how to use SurfSense and should be used to answer questions about the application. "
|
||||
"Use [citation:CHUNK_ID] format for citations (e.g., [citation:doc-123])."
|
||||
)
|
||||
|
||||
for doc in documents:
|
||||
public_url = surfsense_docs_public_url(doc.source)
|
||||
metadata_json = json.dumps(
|
||||
{"source": doc.source, "public_url": public_url}, ensure_ascii=False
|
||||
)
|
||||
|
||||
context_parts.append("<document>")
|
||||
context_parts.append("<document_metadata>")
|
||||
context_parts.append(f" <document_id>doc-{doc.id}</document_id>")
|
||||
context_parts.append(" <document_type>SURFSENSE_DOCS</document_type>")
|
||||
context_parts.append(f" <title><![CDATA[{doc.title}]]></title>")
|
||||
context_parts.append(f" <url><![CDATA[{public_url}]]></url>")
|
||||
context_parts.append(
|
||||
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>"
|
||||
)
|
||||
context_parts.append("</document_metadata>")
|
||||
context_parts.append("")
|
||||
context_parts.append("<document_content>")
|
||||
|
||||
if hasattr(doc, "chunks") and doc.chunks:
|
||||
for chunk in doc.chunks:
|
||||
context_parts.append(
|
||||
f" <chunk id='doc-{chunk.id}'><![CDATA[{chunk.content}]]></chunk>"
|
||||
)
|
||||
else:
|
||||
context_parts.append(
|
||||
f" <chunk id='doc-0'><![CDATA[{doc.content}]]></chunk>"
|
||||
)
|
||||
|
||||
context_parts.append("</document_content>")
|
||||
context_parts.append("</document>")
|
||||
context_parts.append("")
|
||||
|
||||
context_parts.append("</mentioned_surfsense_docs>")
|
||||
return "\n".join(context_parts)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""File-operation contract evaluation and logging."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.contract.file_contract import (
|
||||
contract_enforcement_active,
|
||||
evaluate_file_contract_outcome,
|
||||
log_file_contract,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"contract_enforcement_active",
|
||||
"evaluate_file_contract_outcome",
|
||||
"log_file_contract",
|
||||
]
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
"""File-operation contract: when to enforce, how to score, how to log."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
def contract_enforcement_active(result: StreamResult) -> bool:
|
||||
# Enforce only in desktop local-folder mode. Kept deterministic, no
|
||||
# env-driven progression modes.
|
||||
return result.filesystem_mode == "desktop_local_folder"
|
||||
|
||||
|
||||
def evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]:
|
||||
if result.intent_detected != "file_write":
|
||||
return True, ""
|
||||
if not result.write_attempted:
|
||||
return False, "no_write_attempt"
|
||||
if not result.write_succeeded:
|
||||
return False, "write_failed"
|
||||
if not result.verification_succeeded:
|
||||
return False, "verification_failed"
|
||||
return True, ""
|
||||
|
||||
|
||||
def log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
|
||||
payload: dict[str, Any] = {
|
||||
"stage": stage,
|
||||
"request_id": result.request_id or "unknown",
|
||||
"turn_id": result.turn_id or "unknown",
|
||||
"chat_id": (
|
||||
result.turn_id.split(":", 1)[0] if ":" in result.turn_id else "unknown"
|
||||
),
|
||||
"filesystem_mode": result.filesystem_mode,
|
||||
"client_platform": result.client_platform,
|
||||
"intent_detected": result.intent_detected,
|
||||
"intent_confidence": result.intent_confidence,
|
||||
"write_attempted": result.write_attempted,
|
||||
"write_succeeded": result.write_succeeded,
|
||||
"verification_succeeded": result.verification_succeeded,
|
||||
"commit_gate_passed": result.commit_gate_passed,
|
||||
"commit_gate_reason": result.commit_gate_reason or None,
|
||||
}
|
||||
payload.update(extra)
|
||||
_perf_log.info(
|
||||
"[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False)
|
||||
)
|
||||
17
surfsense_backend/app/tasks/chat/streaming/flows/__init__.py
Normal file
17
surfsense_backend/app/tasks/chat/streaming/flows/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""Top-level streaming flows: ``new_chat`` and ``resume_chat`` orchestrators.
|
||||
|
||||
Re-exports the public entry points so callers can write::
|
||||
|
||||
from app.tasks.chat.streaming.flows import stream_new_chat, stream_resume_chat
|
||||
|
||||
The orchestrators themselves live under ``new_chat/orchestrator.py`` and
|
||||
``resume_chat/orchestrator.py`` (slim composition of the per-concern modules in
|
||||
each flow folder and the building blocks in ``shared/``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.flows.new_chat import stream_new_chat
|
||||
from app.tasks.chat.streaming.flows.resume_chat import stream_resume_chat
|
||||
|
||||
__all__ = ["stream_new_chat", "stream_resume_chat"]
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""New-chat streaming flow.
|
||||
|
||||
The public entry point ``stream_new_chat`` is the slim coroutine in
|
||||
``orchestrator.py`` that composes the per-concern modules in this folder and
|
||||
the building blocks under ``flows/shared/``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.flows.new_chat.orchestrator import stream_new_chat
|
||||
|
||||
__all__ = ["stream_new_chat"]
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
"""Resolve the auto-pin for the *initial* turn config.
|
||||
|
||||
Auto-pin (``selected_llm_config_id=0``) picks the best eligible LLM config for
|
||||
this thread / search space / user, optionally filtered to vision-capable
|
||||
configs when the turn carries images.
|
||||
|
||||
Errors classified here:
|
||||
|
||||
* ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` — the auto-pin pool has no
|
||||
vision-capable cfg for an image-bearing turn. The same gate fires later
|
||||
in ``llm_capability`` for explicit selections; mapping both to the same
|
||||
code keeps the FE error UI consistent.
|
||||
* ``SERVER_ERROR`` — any other ``ValueError`` from the resolver.
|
||||
|
||||
This module owns *initial* pin resolution; the rate-limit recovery loop has
|
||||
its own narrower auto-pin call (with ``exclude_config_ids``) in
|
||||
``flows/shared/rate_limit_recovery``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.observability import otel as ot
|
||||
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoPinResult:
|
||||
"""Outcome of ``resolve_initial_auto_pin``.
|
||||
|
||||
``llm_config_id`` is set when ``error`` is ``None``; ``error`` carries the
|
||||
classified user-facing message plus error code/kind so the orchestrator can
|
||||
emit one terminal-error SSE frame.
|
||||
"""
|
||||
|
||||
llm_config_id: int | None
|
||||
error: tuple[str, str, Literal["user_error", "server_error"]] | None
|
||||
|
||||
|
||||
async def resolve_initial_auto_pin(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
selected_llm_config_id: int,
|
||||
requires_image_input: bool,
|
||||
requested_llm_config_id: int,
|
||||
) -> AutoPinResult:
|
||||
"""Run the resolver and classify any ``ValueError`` for the SSE error path."""
|
||||
try:
|
||||
pinned = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=selected_llm_config_id,
|
||||
requires_image_input=requires_image_input,
|
||||
)
|
||||
ot.add_event(
|
||||
"model.pin.resolved",
|
||||
{
|
||||
"pin.requested_id": requested_llm_config_id,
|
||||
"pin.resolved_id": pinned.resolved_llm_config_id,
|
||||
"pin.requires_image_input": requires_image_input,
|
||||
},
|
||||
)
|
||||
return AutoPinResult(
|
||||
llm_config_id=pinned.resolved_llm_config_id, error=None
|
||||
)
|
||||
except ValueError as pin_error:
|
||||
# The "no vision-capable cfg" path raises a ValueError whose message
|
||||
# we map to the friendly image-input SSE error so the user sees the
|
||||
# same message regardless of whether the gate fired in the resolver or
|
||||
# in ``llm_capability.assert_vision_capability_for_image_turn``.
|
||||
is_vision_failure = (
|
||||
requires_image_input and "vision-capable" in str(pin_error)
|
||||
)
|
||||
error_code = (
|
||||
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
|
||||
if is_vision_failure
|
||||
else "SERVER_ERROR"
|
||||
)
|
||||
error_kind: Literal["user_error", "server_error"] = (
|
||||
"user_error" if is_vision_failure else "server_error"
|
||||
)
|
||||
if is_vision_failure:
|
||||
ot.add_event("quota.denied", {"quota.code": error_code})
|
||||
return AutoPinResult(
|
||||
llm_config_id=None, error=(str(pin_error), error_code, error_kind)
|
||||
)
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
"""Build and emit the first ``thinking-1`` step for a new-chat turn.
|
||||
|
||||
The step title and "Processing X" items are derived from what the user sent
|
||||
(text snippet, image count, mentioned doc titles) so the FE can render a
|
||||
meaningful placeholder while the agent stream warms up.
|
||||
|
||||
``thinking-1`` is the canonical id for this step — every subsequent
|
||||
``thinking-N`` produced by ``stream_agent_events`` folds into the same
|
||||
singleton ``data-thinking-steps`` part on the FE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.db import SurfsenseDocsDocument
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitialThinkingStep:
|
||||
"""Resolved fields passed both into the SSE frame and the builder hook.
|
||||
|
||||
``items`` is the bullet list under the step title; ``title`` is the
|
||||
one-line step header. ``step_id`` is hard-coded ``thinking-1`` so the FE
|
||||
Timeline can de-duplicate against the prior assistant message on resume.
|
||||
"""
|
||||
|
||||
step_id: str
|
||||
title: str
|
||||
items: list[str]
|
||||
|
||||
|
||||
def build_initial_thinking_step(
|
||||
*,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument],
|
||||
) -> InitialThinkingStep:
|
||||
if mentioned_surfsense_docs:
|
||||
title = "Analyzing referenced content"
|
||||
action_verb = "Analyzing"
|
||||
else:
|
||||
title = "Understanding your request"
|
||||
action_verb = "Processing"
|
||||
|
||||
processing_parts: list[str] = []
|
||||
if user_query.strip():
|
||||
query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
|
||||
processing_parts.append(query_text)
|
||||
elif user_image_data_urls:
|
||||
processing_parts.append(f"[{len(user_image_data_urls)} image(s)]")
|
||||
else:
|
||||
processing_parts.append("(message)")
|
||||
|
||||
if mentioned_surfsense_docs:
|
||||
doc_names: list[str] = []
|
||||
for doc in mentioned_surfsense_docs:
|
||||
t = doc.title
|
||||
if len(t) > 30:
|
||||
t = t[:27] + "..."
|
||||
doc_names.append(t)
|
||||
if len(doc_names) == 1:
|
||||
processing_parts.append(f"[{doc_names[0]}]")
|
||||
else:
|
||||
processing_parts.append(f"[{len(doc_names)} docs]")
|
||||
|
||||
items = [f"{action_verb}: {' '.join(processing_parts)}"]
|
||||
return InitialThinkingStep(step_id="thinking-1", title=title, items=items)
|
||||
|
||||
|
||||
def iter_initial_thinking_step_frame(
|
||||
step: InitialThinkingStep,
|
||||
*,
|
||||
streaming_service: VercelStreamingService,
|
||||
content_builder: Any | None,
|
||||
) -> Iterator[str]:
|
||||
"""Drive both the SSE emission and the builder hook for the initial step.
|
||||
|
||||
The FE folds this step into the same singleton ``data-thinking-steps`` part
|
||||
as everything the agent stream emits later, so we mirror that fold
|
||||
server-side by driving the builder lifecycle ourselves.
|
||||
"""
|
||||
if content_builder is not None:
|
||||
content_builder.on_thinking_step(
|
||||
step.step_id, step.title, "in_progress", step.items
|
||||
)
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=step.step_id,
|
||||
title=step.title,
|
||||
status="in_progress",
|
||||
items=step.items,
|
||||
)
|
||||
|
|
@ -0,0 +1,264 @@
|
|||
r"""Assemble the LangGraph ``input_state`` for the new-chat turn.
|
||||
|
||||
Pipeline:
|
||||
|
||||
1. **History bootstrap** — only for cloned chats with no LangGraph checkpoint
|
||||
yet; flips the per-thread ``needs_history_bootstrap`` flag back to False
|
||||
once the rows are loaded.
|
||||
2. **Mentioned SurfSense docs** — eager-load chunks so the formatter has the
|
||||
full content without a second roundtrip.
|
||||
3. **Recent reports** — top 3 by id desc with non-null content, so the LLM
|
||||
can resolve ``report_id`` for versioning without spelunking history.
|
||||
4. **@-mention resolve** (cloud mode) — substitute ``@title`` tokens in the
|
||||
query with canonical ``\`/documents/...\``` paths the LLM expects.
|
||||
5. **Context block render** — XML-wrap surfsense docs + reports, prepend to
|
||||
the rewritten query, optionally prefix with display name for SEARCH_SPACE
|
||||
visibility.
|
||||
6. **HumanMessage** — multimodal content if images are attached.
|
||||
|
||||
Returns the assembled ``input_state`` dict plus side-channel data the
|
||||
orchestrator needs downstream (``accepted_folder_ids`` for runtime context;
|
||||
``mentioned_surfsense_docs`` for the initial thinking step).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatThread,
|
||||
Report,
|
||||
SurfsenseDocsDocument,
|
||||
)
|
||||
from app.tasks.chat.streaming.context.mentioned_docs import (
|
||||
format_mentioned_surfsense_docs_as_context,
|
||||
)
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
from app.utils.user_message_multimodal import build_human_message_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewChatInputState:
|
||||
"""Everything ``build_new_chat_input_state`` produces.
|
||||
|
||||
``input_state`` is fed straight to the agent. ``accepted_folder_ids``
|
||||
feeds the runtime context (the resolver may have dropped some chips).
|
||||
``mentioned_surfsense_docs`` is consumed by the initial thinking-step
|
||||
builder for the FE placeholder before the agent stream starts.
|
||||
"""
|
||||
|
||||
input_state: dict[str, Any]
|
||||
accepted_folder_ids: list[int]
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument]
|
||||
|
||||
|
||||
async def build_new_chat_input_state(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None,
|
||||
mentioned_folder_ids: list[int] | None,
|
||||
mentioned_documents: list[dict[str, Any]] | None,
|
||||
needs_history_bootstrap: bool,
|
||||
thread_visibility: ChatVisibility,
|
||||
current_user_display_name: str | None,
|
||||
filesystem_mode: str,
|
||||
request_id: str | None,
|
||||
turn_id: str,
|
||||
) -> NewChatInputState:
|
||||
langchain_messages: list[Any] = []
|
||||
|
||||
if needs_history_bootstrap:
|
||||
langchain_messages = await bootstrap_history_from_db(
|
||||
session, chat_id, thread_visibility=thread_visibility
|
||||
)
|
||||
thread_result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
)
|
||||
thread = thread_result.scalars().first()
|
||||
if thread:
|
||||
thread.needs_history_bootstrap = False
|
||||
await session.commit()
|
||||
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
|
||||
if mentioned_surfsense_doc_ids:
|
||||
result = await session.execute(
|
||||
select(SurfsenseDocsDocument)
|
||||
.options(selectinload(SurfsenseDocsDocument.chunks))
|
||||
.filter(SurfsenseDocsDocument.id.in_(mentioned_surfsense_doc_ids))
|
||||
)
|
||||
mentioned_surfsense_docs = list(result.scalars().all())
|
||||
|
||||
# Top 3 reports keyed by id desc (newest first) with content present,
|
||||
# surfaced inline so the LLM resolves ``report_id`` for versioning without
|
||||
# digging through conversation history.
|
||||
recent_reports_result = await session.execute(
|
||||
select(Report)
|
||||
.filter(
|
||||
Report.thread_id == chat_id,
|
||||
Report.content.isnot(None),
|
||||
)
|
||||
.order_by(Report.id.desc())
|
||||
.limit(3)
|
||||
)
|
||||
recent_reports = list(recent_reports_result.scalars().all())
|
||||
|
||||
agent_user_query, accepted_folder_ids = await _resolve_mentions_for_query(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_query=user_query,
|
||||
filesystem_mode=filesystem_mode,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=mentioned_folder_ids,
|
||||
mentioned_documents=mentioned_documents,
|
||||
)
|
||||
|
||||
final_query = _render_query_with_context(
|
||||
agent_user_query=agent_user_query,
|
||||
mentioned_surfsense_docs=mentioned_surfsense_docs,
|
||||
recent_reports=recent_reports,
|
||||
)
|
||||
|
||||
if thread_visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name:
|
||||
final_query = f"**[{current_user_display_name}]:** {final_query}"
|
||||
|
||||
human_content = build_human_message_content(
|
||||
final_query, list(user_image_data_urls or ())
|
||||
)
|
||||
langchain_messages.append(HumanMessage(content=human_content))
|
||||
|
||||
input_state = {
|
||||
"messages": langchain_messages,
|
||||
"search_space_id": search_space_id,
|
||||
"request_id": request_id or "unknown",
|
||||
"turn_id": turn_id,
|
||||
}
|
||||
|
||||
return NewChatInputState(
|
||||
input_state=input_state,
|
||||
accepted_folder_ids=accepted_folder_ids,
|
||||
mentioned_surfsense_docs=mentioned_surfsense_docs,
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_mentions_for_query(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
user_query: str,
|
||||
filesystem_mode: str,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None,
|
||||
mentioned_folder_ids: list[int] | None,
|
||||
mentioned_documents: list[dict[str, Any]] | None,
|
||||
) -> tuple[str, list[int]]:
|
||||
r"""Resolve @-mention chips and rewrite the user query to canonical paths.
|
||||
|
||||
Cloud mode only: local-folder mode keeps the legacy ``@title`` text path
|
||||
(mention support there is a follow-up task — the path scheme is
|
||||
mount-rooted and the picker UI both need separate work).
|
||||
|
||||
The substitution lands in the returned ``agent_user_query`` ONLY — the
|
||||
original ``user_query`` (with ``@title`` tokens) flows untouched into
|
||||
``persist_user_turn`` so chip rendering on reload still works
|
||||
(``UserTextPart`` → ``parseMentionSegments`` matches ``@title``, not
|
||||
``\`/documents/...\```). It also feeds the human-readable surfaces — SSE
|
||||
"Processing X" status, auto thread title, memory seed — which all want
|
||||
what the user typed.
|
||||
"""
|
||||
agent_user_query = user_query
|
||||
accepted_folder_ids: list[int] = []
|
||||
|
||||
has_any_mention = bool(
|
||||
mentioned_document_ids
|
||||
or mentioned_surfsense_doc_ids
|
||||
or mentioned_folder_ids
|
||||
or mentioned_documents
|
||||
)
|
||||
if filesystem_mode != FilesystemMode.CLOUD.value or not has_any_mention:
|
||||
return agent_user_query, accepted_folder_ids
|
||||
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
chip_objs: list[MentionedDocumentInfo] | None = None
|
||||
if mentioned_documents:
|
||||
chip_objs = []
|
||||
for raw in mentioned_documents:
|
||||
if isinstance(raw, MentionedDocumentInfo):
|
||||
chip_objs.append(raw)
|
||||
continue
|
||||
try:
|
||||
chip_objs.append(MentionedDocumentInfo.model_validate(raw))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"stream_new_chat: dropping malformed mention chip %r", raw
|
||||
)
|
||||
|
||||
resolved = await resolve_mentions(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
mentioned_documents=chip_objs,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=mentioned_folder_ids,
|
||||
)
|
||||
agent_user_query = substitute_in_text(user_query, resolved.token_to_path)
|
||||
accepted_folder_ids = resolved.mentioned_folder_ids
|
||||
return agent_user_query, accepted_folder_ids
|
||||
|
||||
|
||||
def _render_query_with_context(
|
||||
*,
|
||||
agent_user_query: str,
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument],
|
||||
recent_reports: list[Report],
|
||||
) -> str:
|
||||
"""Prepend surfsense-docs + recent-reports XML blocks to the user query."""
|
||||
context_parts: list[str] = []
|
||||
|
||||
if mentioned_surfsense_docs:
|
||||
context_parts.append(
|
||||
format_mentioned_surfsense_docs_as_context(mentioned_surfsense_docs)
|
||||
)
|
||||
|
||||
if recent_reports:
|
||||
report_lines: list[str] = []
|
||||
for r in recent_reports:
|
||||
report_lines.append(
|
||||
f' - report_id={r.id}, title="{r.title}", '
|
||||
f'style="{r.report_style or "detailed"}"'
|
||||
)
|
||||
reports_listing = "\n".join(report_lines)
|
||||
context_parts.append(
|
||||
"<report_context>\n"
|
||||
"Previously generated reports in this conversation:\n"
|
||||
f"{reports_listing}\n\n"
|
||||
"If the user wants to MODIFY, REVISE, UPDATE, or ADD to one of "
|
||||
"these reports, set parent_report_id to the relevant report_id above.\n"
|
||||
"If the user wants a completely NEW report on a different topic, "
|
||||
"leave parent_report_id unset.\n"
|
||||
"</report_context>"
|
||||
)
|
||||
|
||||
if context_parts:
|
||||
context = "\n\n".join(context_parts)
|
||||
return f"{context}\n\n<user_query>{agent_user_query}</user_query>"
|
||||
|
||||
return agent_user_query
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
"""Vision-capability gate for image-bearing turns.
|
||||
|
||||
Capability safety net for explicit (non-auto-pin) selections: a turn carrying
|
||||
user-uploaded images cannot be routed to a chat config that LiteLLM's
|
||||
authoritative model map *explicitly* marks as text-only (``supports_vision``
|
||||
set to False). The check is intentionally narrow — it only fires when LiteLLM
|
||||
is *certain* the model can't accept image input; unknown / unmapped /
|
||||
vision-capable models pass through.
|
||||
|
||||
Without this guard a known-text-only model would 404 at the provider with
|
||||
``"No endpoints found that support image input"``, surfacing as an opaque
|
||||
``SERVER_ERROR`` SSE chunk; failing here lets us return a friendly message that
|
||||
tells the user what to change.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.observability import otel as ot
|
||||
|
||||
|
||||
def check_image_input_capability(
|
||||
*,
|
||||
user_image_data_urls: list[str] | None,
|
||||
agent_config: AgentConfig | None,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Return ``(user_message, error_code)`` when the gate trips, else ``None``.
|
||||
|
||||
The caller emits one terminal-error SSE frame on a non-``None`` return.
|
||||
"""
|
||||
if not (user_image_data_urls and agent_config is not None):
|
||||
return None
|
||||
|
||||
from app.services.provider_capabilities import is_known_text_only_chat_model
|
||||
|
||||
agent_litellm_params = agent_config.litellm_params or {}
|
||||
agent_base_model = (
|
||||
agent_litellm_params.get("base_model")
|
||||
if isinstance(agent_litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
if not is_known_text_only_chat_model(
|
||||
provider=agent_config.provider,
|
||||
model_name=agent_config.model_name,
|
||||
base_model=agent_base_model,
|
||||
custom_provider=agent_config.custom_provider,
|
||||
):
|
||||
return None
|
||||
|
||||
model_label = agent_config.config_name or agent_config.model_name or "model"
|
||||
ot.add_event(
|
||||
"quota.denied", {"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"}
|
||||
)
|
||||
return (
|
||||
(
|
||||
f"The selected model ({model_label}) does not support "
|
||||
"image input. Switch to a vision-capable model "
|
||||
"(e.g. GPT-4o, Claude, Gemini) or remove the image "
|
||||
"attachment and try again."
|
||||
),
|
||||
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
|
||||
)
|
||||
|
|
@ -0,0 +1,868 @@
|
|||
"""``stream_new_chat`` — public entry point for a fresh chat turn.
|
||||
|
||||
Slim composition layer over the per-concern modules in this folder and the
|
||||
building blocks under ``flows/shared/``. Each phase corresponds to a numbered
|
||||
block in the surrounding code so the on-the-wire ordering stays explicit:
|
||||
|
||||
1. Validation / config — auto-pin, LLM bundle, capability, premium reserve.
|
||||
2. Concurrent persistence + pre-stream setup — spawn DB writes, build the
|
||||
connector, fetch the checkpointer, build the agent.
|
||||
3. Input assembly — history bootstrap, mentions, surfsense docs, reports.
|
||||
4. First SSE frames — message_start, start_step, turn-info, turn-status.
|
||||
5. Persistence join + message-id frames (ghost-thread protection).
|
||||
6. Initial thinking step + title task + runtime context.
|
||||
7. Stream loop with in-stream rate-limit recovery + mid-stream title emit.
|
||||
8. Finalize — premium debit, token-usage SSE, finish frames.
|
||||
9. Exception branch — classify, emit terminal error, finish frames.
|
||||
10. Finally — premium release, session close, assistant finalize, GC, span.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from functools import partial
|
||||
from typing import Any, Literal
|
||||
|
||||
import anyio
|
||||
|
||||
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.middleware.busy_mutex import end_turn
|
||||
from app.config import config as _app_config
|
||||
from app.db import ChatVisibility, async_session_maker
|
||||
from app.observability import otel as ot
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.content_builder import AssistantContentBuilder
|
||||
from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread
|
||||
from app.tasks.chat.streaming.contract.file_contract import log_file_contract
|
||||
from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error
|
||||
from app.tasks.chat.streaming.flows.new_chat.auto_pin import resolve_initial_auto_pin
|
||||
from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import (
|
||||
build_initial_thinking_step,
|
||||
iter_initial_thinking_step_frame,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.input_state import (
|
||||
build_new_chat_input_state,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.llm_capability import (
|
||||
check_image_input_capability,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import (
|
||||
await_persist_task,
|
||||
spawn_persist_assistant_shell_task,
|
||||
spawn_persist_user_task,
|
||||
spawn_set_ai_responding_bg,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.runtime_context import (
|
||||
build_new_chat_runtime_context,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.title_gen import (
|
||||
await_pending_title_update,
|
||||
maybe_emit_title_update,
|
||||
spawn_title_task,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.assistant_finalize import (
|
||||
finalize_assistant_message,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame
|
||||
from app.tasks.chat.streaming.flows.shared.finally_cleanup import (
|
||||
close_session_and_clear_ai_responding,
|
||||
run_gc_pass,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.first_frames import (
|
||||
iter_final_frames,
|
||||
iter_initial_frames,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
|
||||
from app.tasks.chat.streaming.flows.shared.pre_stream_setup import (
|
||||
get_chat_checkpointer,
|
||||
setup_connector_and_firecrawl,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.premium_quota import (
|
||||
PremiumReservation,
|
||||
finalize_premium,
|
||||
needs_premium_quota,
|
||||
release_premium,
|
||||
reserve_premium,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import (
|
||||
can_recover_provider_rate_limit,
|
||||
log_rate_limit_recovered,
|
||||
reroute_to_next_auto_pin,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.span import (
|
||||
close_chat_request_span,
|
||||
open_chat_request_span,
|
||||
set_agent_mode,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.stream_loop import run_stream_loop
|
||||
from app.tasks.chat.streaming.flows.shared.terminal_error import (
|
||||
handle_terminal_exception,
|
||||
)
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
# Holds spawned background tasks (set_ai_responding, persist_user, persist_asst)
|
||||
# so the GC doesn't drop them before they finish. Kept at module level so it
|
||||
# survives across turns within one process.
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
async def stream_new_chat(
|
||||
user_query: str,
|
||||
search_space_id: int,
|
||||
chat_id: int,
|
||||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None,
|
||||
mentioned_folder_ids: list[int] | None = None,
|
||||
mentioned_documents: list[dict[str, Any]] | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
needs_history_bootstrap: bool = False,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
current_user_display_name: str | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
request_id: str | None = None,
|
||||
user_image_data_urls: list[str] | None = None,
|
||||
flow: Literal["new", "regenerate"] = "new",
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream a new chat turn using the SurfSense deep agent.
|
||||
|
||||
Uses the Vercel AI SDK Data Stream Protocol (SSE). ``chat_id`` is the
|
||||
LangGraph thread id (durable conversation memory via the checkpointer).
|
||||
Manages its own database session so cleanup runs even when Starlette
|
||||
cancels the task on client disconnect.
|
||||
"""
|
||||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
|
||||
fs_platform = (
|
||||
filesystem_selection.client_platform.value if filesystem_selection else "web"
|
||||
)
|
||||
stream_result.request_id = request_id
|
||||
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
|
||||
stream_result.filesystem_mode = fs_mode
|
||||
stream_result.client_platform = fs_platform
|
||||
|
||||
chat_agent_mode = "unknown"
|
||||
chat_outcome = "success"
|
||||
chat_error_category: str | None = None
|
||||
chat_span_cm, chat_span = open_chat_request_span(
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
flow=flow,
|
||||
request_id=request_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
filesystem_mode=fs_mode,
|
||||
client_platform=fs_platform,
|
||||
agent_mode=chat_agent_mode,
|
||||
)
|
||||
log_file_contract("turn_start", stream_result)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] filesystem_mode=%s client_platform=%s",
|
||||
fs_mode,
|
||||
fs_platform,
|
||||
)
|
||||
log_system_snapshot("stream_new_chat_START")
|
||||
|
||||
from app.services.token_tracking_service import start_turn
|
||||
|
||||
accumulator = start_turn()
|
||||
|
||||
premium_reservation: PremiumReservation | None = None
|
||||
busy_error_raised = False
|
||||
|
||||
emit_stream_error = partial(
|
||||
emit_stream_terminal_error,
|
||||
streaming_service=streaming_service,
|
||||
flow=flow,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
session = async_session_maker()
|
||||
# Declared at function scope so SSE-yield join points and the finally
|
||||
# clause see them on every exit path.
|
||||
persist_user_task: asyncio.Task[int | None] | None = None
|
||||
persist_asst_task: asyncio.Task[int | None] | None = None
|
||||
try:
|
||||
spawn_set_ai_responding_bg(
|
||||
chat_id=chat_id, user_id=user_id, background_tasks=_background_tasks
|
||||
)
|
||||
|
||||
# --- Block 1: LLM config + capability ---
|
||||
|
||||
requested_llm_config_id = llm_config_id
|
||||
requires_image_input = bool(user_image_data_urls)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
pin_result = await resolve_initial_auto_pin(
|
||||
session,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=llm_config_id,
|
||||
requires_image_input=requires_image_input,
|
||||
requested_llm_config_id=requested_llm_config_id,
|
||||
)
|
||||
if pin_result.error is not None:
|
||||
message, error_code, error_kind = pin_result.error
|
||||
yield emit_stream_error(
|
||||
message=message, error_kind=error_kind, error_code=error_code
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
llm_config_id = pin_result.llm_config_id # type: ignore[assignment]
|
||||
|
||||
llm, agent_config, llm_load_error = await load_llm_bundle(
|
||||
session, config_id=llm_config_id, search_space_id=search_space_id
|
||||
)
|
||||
if llm_load_error:
|
||||
yield emit_stream_error(
|
||||
message=llm_load_error,
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
|
||||
time.perf_counter() - _t0,
|
||||
llm_config_id,
|
||||
)
|
||||
|
||||
capability_error = check_image_input_capability(
|
||||
user_image_data_urls=user_image_data_urls, agent_config=agent_config
|
||||
)
|
||||
if capability_error is not None:
|
||||
message, error_code = capability_error
|
||||
yield emit_stream_error(
|
||||
message=message,
|
||||
error_kind="user_error",
|
||||
error_code=error_code,
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if needs_premium_quota(agent_config, user_id):
|
||||
premium_reservation = await reserve_premium(
|
||||
agent_config=agent_config, user_id=user_id # type: ignore[arg-type]
|
||||
)
|
||||
if not premium_reservation.allowed:
|
||||
ot.add_event("quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"})
|
||||
if requested_llm_config_id == 0:
|
||||
pin_fallback = await resolve_initial_auto_pin(
|
||||
session,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
requires_image_input=requires_image_input,
|
||||
requested_llm_config_id=requested_llm_config_id,
|
||||
)
|
||||
if pin_fallback.error is not None:
|
||||
message, error_code, error_kind = pin_fallback.error
|
||||
yield emit_stream_error(
|
||||
message=message,
|
||||
error_kind=error_kind,
|
||||
error_code=error_code,
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
llm_config_id = pin_fallback.llm_config_id # type: ignore[assignment]
|
||||
ot.add_event(
|
||||
"model.repin",
|
||||
{
|
||||
"repin.reason": "premium_quota_exhausted",
|
||||
"repin.to_config_id": llm_config_id,
|
||||
},
|
||||
)
|
||||
llm, agent_config, llm_load_error = await load_llm_bundle(
|
||||
session,
|
||||
config_id=llm_config_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if llm_load_error:
|
||||
yield emit_stream_error(
|
||||
message=llm_load_error,
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
premium_reservation = None
|
||||
# Re-route to free fallback logged via the structured
|
||||
# stream-error logger so cost/analytics see the auto-switch.
|
||||
from app.tasks.chat.streaming.errors.classifier import (
|
||||
log_chat_stream_error,
|
||||
)
|
||||
|
||||
log_chat_stream_error(
|
||||
flow=flow,
|
||||
error_kind="premium_quota_exhausted",
|
||||
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=(
|
||||
"Premium quota exhausted on pinned model; "
|
||||
"auto-fallback switched to a free model"
|
||||
),
|
||||
extra={
|
||||
"fallback_config_id": llm_config_id,
|
||||
"auto_fallback": True,
|
||||
},
|
||||
)
|
||||
else:
|
||||
yield emit_stream_error(
|
||||
message=(
|
||||
"Buy more tokens to continue with this model, or "
|
||||
"switch to a free model"
|
||||
),
|
||||
error_kind="premium_quota_exhausted",
|
||||
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
extra={
|
||||
"resolved_config_id": llm_config_id,
|
||||
"auto_fallback": False,
|
||||
},
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if not llm:
|
||||
yield emit_stream_error(
|
||||
message="Failed to create LLM instance",
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# --- Block 2: Spawn concurrent persistence; build pre-stream setup ---
|
||||
|
||||
persist_user_task = spawn_persist_user_task(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
mentioned_documents=mentioned_documents,
|
||||
background_tasks=_background_tasks,
|
||||
)
|
||||
persist_asst_task = spawn_persist_assistant_shell_task(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
background_tasks=_background_tasks,
|
||||
)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
connector_service, firecrawl_api_key = await setup_connector_and_firecrawl(
|
||||
session, search_space_id=search_space_id
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
checkpointer = await get_chat_checkpointer()
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED)
|
||||
chat_agent_mode = "multi" if use_multi_agent else "single"
|
||||
set_agent_mode(chat_span, chat_agent_mode)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent_factory = (
|
||||
create_multi_agent_chat_deep_agent
|
||||
if use_multi_agent
|
||||
else create_surfsense_deep_agent
|
||||
)
|
||||
# Build the agent inline. Provider 429s surface through the in-stream
|
||||
# recovery loop below, which repins the thread to an eligible
|
||||
# alternative config and rebuilds the agent before the user sees any
|
||||
# output.
|
||||
agent = await build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=chat_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# --- Block 3: Input assembly ---
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
assembled = await build_new_chat_input_state(
|
||||
session,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=mentioned_folder_ids,
|
||||
mentioned_documents=mentioned_documents,
|
||||
needs_history_bootstrap=needs_history_bootstrap,
|
||||
thread_visibility=visibility,
|
||||
current_user_display_name=current_user_display_name,
|
||||
filesystem_mode=fs_mode,
|
||||
request_id=request_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
)
|
||||
input_state = assembled.input_state
|
||||
accepted_folder_ids = assembled.accepted_folder_ids
|
||||
mentioned_surfsense_docs = assembled.mentioned_surfsense_docs
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] History bootstrap + doc/report queries in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
# All pre-streaming DB reads done. Commit to release the transaction
|
||||
# and its ACCESS SHARE locks so we don't block DDL (e.g. migrations)
|
||||
# for the entire LLM streaming duration. Tools that need DB access
|
||||
# during streaming start their own short-lived transactions (or use
|
||||
# isolated sessions).
|
||||
await session.commit()
|
||||
# Detach heavy ORM objects (documents with chunks, reports, etc.)
|
||||
# from the session identity map now that we've extracted what we
|
||||
# need. Without this they accumulate in memory for the entire
|
||||
# streaming duration (which can be several minutes).
|
||||
session.expunge_all()
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
|
||||
configurable: dict[str, Any] = {
|
||||
"thread_id": str(chat_id),
|
||||
"request_id": request_id or "unknown",
|
||||
"turn_id": stream_result.turn_id,
|
||||
}
|
||||
if checkpoint_id:
|
||||
configurable["checkpoint_id"] = checkpoint_id
|
||||
|
||||
config = {
|
||||
"configurable": configurable,
|
||||
# Effectively uncapped, matching the agent-level ``with_config``
|
||||
# default in ``chat_deepagent.create_agent`` and the unbounded
|
||||
# ``while(true)`` in OpenCode's ``session/processor.ts``. Real
|
||||
# circuit-breakers live in middleware (``DoomLoopMiddleware``,
|
||||
# plus ``enable_tool_call_limit`` / ``enable_model_call_limit``).
|
||||
# The original 25 (and our previous 80 bump) hit users on
|
||||
# legitimate multi-tool plans.
|
||||
"recursion_limit": 10_000,
|
||||
}
|
||||
|
||||
# --- Block 4: First SSE frames ---
|
||||
|
||||
for sse in iter_initial_frames(streaming_service, turn_id=stream_result.turn_id):
|
||||
yield sse
|
||||
|
||||
# --- Block 5: Persistence join + message-id frames ---
|
||||
|
||||
user_message_id = await await_persist_task(
|
||||
persist_user_task,
|
||||
chat_id=chat_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
log_label="persist_user_task",
|
||||
)
|
||||
if user_message_id is None:
|
||||
yield emit_stream_error(
|
||||
message="We couldn't save your message. Please try again in a moment.",
|
||||
error_kind="server_error",
|
||||
error_code="MESSAGE_PERSIST_FAILED",
|
||||
)
|
||||
for sse in iter_final_frames(streaming_service):
|
||||
yield sse
|
||||
return
|
||||
|
||||
# Emit canonical user message id BEFORE any LLM streaming so the FE
|
||||
# can rename its optimistic ``msg-user-XXX`` placeholder to
|
||||
# ``msg-{user_message_id}`` and unlock features gated on a real DB id
|
||||
# (comments, edit-from-this-message). See B4 in the
|
||||
# ``sse-based_message_id_handshake`` plan.
|
||||
yield streaming_service.format_data(
|
||||
"user-message-id",
|
||||
{"message_id": user_message_id, "turn_id": stream_result.turn_id},
|
||||
)
|
||||
|
||||
assistant_message_id = await await_persist_task(
|
||||
persist_asst_task,
|
||||
chat_id=chat_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
log_label="persist_asst_task",
|
||||
)
|
||||
if assistant_message_id is None:
|
||||
# Genuine DB failure — abort the turn rather than stream into a
|
||||
# void. The user row is already persisted so the legacy
|
||||
# ghost-thread gate isn't reopened.
|
||||
yield emit_stream_error(
|
||||
message=(
|
||||
"We couldn't initialize the assistant message. Please try again."
|
||||
),
|
||||
error_kind="server_error",
|
||||
error_code="MESSAGE_PERSIST_FAILED",
|
||||
)
|
||||
for sse in iter_final_frames(streaming_service):
|
||||
yield sse
|
||||
return
|
||||
|
||||
yield streaming_service.format_data(
|
||||
"assistant-message-id",
|
||||
{"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
|
||||
)
|
||||
|
||||
stream_result.assistant_message_id = assistant_message_id
|
||||
stream_result.content_builder = AssistantContentBuilder()
|
||||
|
||||
# --- Block 6: Initial thinking step + title task + runtime context ---
|
||||
|
||||
initial_step = build_initial_thinking_step(
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
mentioned_surfsense_docs=mentioned_surfsense_docs,
|
||||
)
|
||||
for sse in iter_initial_thinking_step_frame(
|
||||
initial_step,
|
||||
streaming_service=streaming_service,
|
||||
content_builder=stream_result.content_builder,
|
||||
):
|
||||
yield sse
|
||||
|
||||
initial_step_id = initial_step.step_id
|
||||
initial_step_title = initial_step.title
|
||||
initial_step_items = initial_step.items
|
||||
# Drop the heavy ORM objects + the container that holds them so they
|
||||
# aren't retained for the entire streaming duration. ``input_state``
|
||||
# already carries the langchain_messages list independently.
|
||||
del assembled, mentioned_surfsense_docs
|
||||
|
||||
title_task = spawn_title_task(
|
||||
chat_id=chat_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
assistant_message_id=assistant_message_id,
|
||||
llm=llm,
|
||||
agent_config=agent_config,
|
||||
)
|
||||
title_emitted = False
|
||||
|
||||
runtime_context = build_new_chat_runtime_context(
|
||||
search_space_id=search_space_id,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
accepted_folder_ids=accepted_folder_ids,
|
||||
mentioned_folder_ids=mentioned_folder_ids,
|
||||
request_id=request_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
)
|
||||
|
||||
# --- Block 7: Stream loop ---
|
||||
|
||||
_t_stream_start = time.perf_counter()
|
||||
runtime_rate_limit_recovered = False
|
||||
|
||||
def _on_first_event() -> None:
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
|
||||
"%.3fs (total since request start) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
|
||||
async def _recover(exc: BaseException, first_event_seen: bool):
|
||||
nonlocal llm_config_id, llm, agent_config, runtime_rate_limit_recovered
|
||||
nonlocal title_task
|
||||
if not can_recover_provider_rate_limit(
|
||||
exc,
|
||||
first_event_seen=first_event_seen,
|
||||
runtime_rate_limit_recovered=runtime_rate_limit_recovered,
|
||||
requested_llm_config_id=requested_llm_config_id,
|
||||
current_llm_config_id=llm_config_id,
|
||||
):
|
||||
return None
|
||||
runtime_rate_limit_recovered = True
|
||||
previous_config_id = llm_config_id
|
||||
llm_config_id = await reroute_to_next_auto_pin(
|
||||
session,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
current_llm_config_id=llm_config_id,
|
||||
requires_image_input=requires_image_input,
|
||||
)
|
||||
new_llm, new_agent_config, llm_load_err = await load_llm_bundle(
|
||||
session, config_id=llm_config_id, search_space_id=search_space_id
|
||||
)
|
||||
if llm_load_err:
|
||||
# Re-raise the original so the terminal-error path classifies
|
||||
# it correctly (don't swallow as "config load error").
|
||||
return None
|
||||
llm = new_llm
|
||||
agent_config = new_agent_config
|
||||
|
||||
# Title gen used the initial llm object. After a runtime repin we
|
||||
# keep the stream focused on response recovery and skip title gen
|
||||
# for this turn.
|
||||
if title_task is not None and not title_task.done():
|
||||
title_task.cancel()
|
||||
title_task = None
|
||||
|
||||
_t_rebuild = time.perf_counter()
|
||||
new_agent = await build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=chat_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Runtime rate-limit recovery repinned "
|
||||
"config_id=%s -> %s and rebuilt agent in %.3fs",
|
||||
previous_config_id,
|
||||
llm_config_id,
|
||||
time.perf_counter() - _t_rebuild,
|
||||
)
|
||||
log_rate_limit_recovered(
|
||||
flow=flow,
|
||||
request_id=request_id,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
previous_config_id=previous_config_id,
|
||||
new_config_id=llm_config_id,
|
||||
)
|
||||
return new_agent
|
||||
|
||||
async for sse in run_stream_loop(
|
||||
agent=agent,
|
||||
streaming_service=streaming_service,
|
||||
config=config,
|
||||
input_data=input_state,
|
||||
stream_result=stream_result,
|
||||
step_prefix="thinking",
|
||||
initial_step_id=initial_step_id,
|
||||
initial_step_title=initial_step_title,
|
||||
initial_step_items=initial_step_items,
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode if filesystem_selection else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
runtime_context=runtime_context,
|
||||
content_builder=stream_result.content_builder,
|
||||
recover=_recover,
|
||||
on_first_event=_on_first_event,
|
||||
):
|
||||
yield sse
|
||||
# Inject the title update mid-stream as soon as the background
|
||||
# task finishes; gated so we emit at most once.
|
||||
async for title_sse in maybe_emit_title_update(
|
||||
title_task=title_task,
|
||||
title_emitted=title_emitted,
|
||||
chat_id=chat_id,
|
||||
accumulator=accumulator,
|
||||
streaming_service=streaming_service,
|
||||
):
|
||||
yield title_sse
|
||||
title_emitted = True
|
||||
# Account for the case where the task completed but produced no
|
||||
# title — flip the flag anyway so we don't keep checking it.
|
||||
if (
|
||||
title_task is not None
|
||||
and title_task.done()
|
||||
and not title_emitted
|
||||
):
|
||||
title_emitted = True
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
chat_id,
|
||||
)
|
||||
log_system_snapshot("stream_new_chat_END")
|
||||
|
||||
# --- Block 8: Finalize ---
|
||||
|
||||
if stream_result.is_interrupted:
|
||||
ot.add_event("chat.interrupted", {"chat.flow": flow})
|
||||
if title_task is not None and not title_task.done():
|
||||
title_task.cancel()
|
||||
for sse in iter_token_usage_frame(
|
||||
streaming_service,
|
||||
accumulator=accumulator,
|
||||
log_label="interrupted new_chat",
|
||||
):
|
||||
yield sse
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
async for title_sse in await_pending_title_update(
|
||||
title_task=title_task,
|
||||
title_emitted=title_emitted,
|
||||
chat_id=chat_id,
|
||||
accumulator=accumulator,
|
||||
streaming_service=streaming_service,
|
||||
):
|
||||
yield title_sse
|
||||
|
||||
# Finalize premium credit debit with the actual provider cost reported
|
||||
# by LiteLLM, summed across every call in the turn. Mirrors the
|
||||
# pre-cost behaviour of "premium turn → all calls count" so free
|
||||
# sub-agent calls during a premium turn still contribute to the bill
|
||||
# (they're $0 in practice anyway).
|
||||
if premium_reservation is not None and user_id:
|
||||
await finalize_premium(
|
||||
reservation=premium_reservation,
|
||||
user_id=user_id,
|
||||
accumulator=accumulator,
|
||||
)
|
||||
premium_reservation = None
|
||||
|
||||
for sse in iter_token_usage_frame(
|
||||
streaming_service, accumulator=accumulator, log_label="normal new_chat"
|
||||
):
|
||||
yield sse
|
||||
|
||||
for sse in iter_final_frames(streaming_service):
|
||||
yield sse
|
||||
|
||||
except Exception as exc:
|
||||
frames, summary = handle_terminal_exception(
|
||||
exc,
|
||||
flow=flow,
|
||||
flow_label="chat",
|
||||
log_prefix="stream_new_chat",
|
||||
streaming_service=streaming_service,
|
||||
request_id=request_id,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
chat_span=chat_span,
|
||||
)
|
||||
if summary["busy_error_raised"]:
|
||||
busy_error_raised = True
|
||||
chat_outcome = summary["chat_outcome"]
|
||||
chat_error_category = summary["chat_error_category"]
|
||||
for sse in frames:
|
||||
yield sse
|
||||
|
||||
finally:
|
||||
# Shield the ENTIRE async cleanup from anyio cancel-scope cancellation.
|
||||
# Starlette's BaseHTTPMiddleware uses anyio task groups; on client
|
||||
# disconnect, it cancels the scope with level-triggered cancellation
|
||||
# — every unshielded ``await`` would raise CancelledError immediately.
|
||||
# Without this the very first ``await`` (session.rollback) would
|
||||
# raise, ``except Exception`` wouldn't catch it (CancelledError is a
|
||||
# BaseException), and the rest of cleanup — including session.close()
|
||||
# — would never run.
|
||||
with anyio.CancelScope(shield=True):
|
||||
# Authoritative fallback cleanup for lock/cancel state. Middleware
|
||||
# teardown can be skipped on some client-abort paths.
|
||||
end_turn(str(chat_id))
|
||||
|
||||
if premium_reservation is not None and user_id:
|
||||
await release_premium(
|
||||
reservation=premium_reservation, user_id=user_id
|
||||
)
|
||||
|
||||
await close_session_and_clear_ai_responding(session, chat_id)
|
||||
|
||||
await finalize_assistant_message(
|
||||
stream_result=stream_result,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
accumulator=accumulator,
|
||||
log_prefix="stream_new_chat",
|
||||
)
|
||||
|
||||
# Persist any sandbox-produced files to local storage so they remain
|
||||
# downloadable after the Daytona sandbox auto-deletes.
|
||||
if stream_result and stream_result.sandbox_files:
|
||||
with contextlib.suppress(Exception):
|
||||
from app.agents.new_chat.sandbox import (
|
||||
is_sandbox_enabled,
|
||||
persist_and_delete_sandbox,
|
||||
)
|
||||
|
||||
if is_sandbox_enabled():
|
||||
with anyio.CancelScope(shield=True):
|
||||
await persist_and_delete_sandbox(
|
||||
chat_id, stream_result.sandbox_files
|
||||
)
|
||||
|
||||
# ``aafter_agent`` doesn't fire on ``interrupt()`` or early bailout.
|
||||
# Skip on ``BusyError`` (caller never acquired the lock).
|
||||
if not busy_error_raised:
|
||||
with contextlib.suppress(Exception):
|
||||
end_turn(str(chat_id))
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] end_turn cleanup (chat_id=%s)", chat_id
|
||||
)
|
||||
|
||||
# Break circular refs held by the agent graph, tools, and LLM
|
||||
# wrappers so the GC can reclaim them in a single pass.
|
||||
agent = llm = connector_service = None # noqa: F841
|
||||
input_state = stream_result = None # noqa: F841
|
||||
session = None # noqa: F841
|
||||
|
||||
run_gc_pass(log_prefix="stream_new_chat", chat_id=chat_id)
|
||||
close_chat_request_span(
|
||||
span_cm=chat_span_cm,
|
||||
span=chat_span,
|
||||
chat_outcome=chat_outcome,
|
||||
chat_agent_mode=chat_agent_mode,
|
||||
flow=flow,
|
||||
chat_error_category=chat_error_category,
|
||||
duration_seconds=time.perf_counter() - _t_total,
|
||||
)
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
"""Concurrent persistence tasks spawned right after the initial validation gate.
|
||||
|
||||
These run *during* the rest of the pre-stream setup so we don't serialize
|
||||
their latency against agent construction. Awaiting them at the SSE message-id
|
||||
yield sites preserves the ghost-thread protection (the user-row INSERT must
|
||||
succeed before any LLM streaming begins).
|
||||
|
||||
The ``set_ai_responding`` flag flip runs fully fire-and-forget on its own
|
||||
shielded session — failures only delay the "AI is responding…" UI flag, not
|
||||
the response itself.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.db import shielded_async_session
|
||||
from app.services.chat_session_state_service import set_ai_responding
|
||||
from app.tasks.chat.persistence import (
|
||||
persist_assistant_shell,
|
||||
persist_user_turn,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def spawn_set_ai_responding_bg(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: str | None,
|
||||
background_tasks: set[asyncio.Task[Any]],
|
||||
) -> None:
|
||||
"""Fire-and-forget: flip the per-thread AI-responding flag on its own session.
|
||||
|
||||
Errors are swallowed and logged — the worst case is a stale UI flag, which
|
||||
is preferable to delaying the SSE stream behind a flag write.
|
||||
"""
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
async def _bg_set_ai_responding() -> None:
|
||||
try:
|
||||
async with shielded_async_session() as s:
|
||||
await set_ai_responding(s, chat_id, UUID(user_id))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"set_ai_responding failed (chat_id=%s)",
|
||||
chat_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
t = asyncio.create_task(_bg_set_ai_responding())
|
||||
background_tasks.add(t)
|
||||
t.add_done_callback(background_tasks.discard)
|
||||
|
||||
|
||||
def spawn_persist_user_task(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: str | None,
|
||||
turn_id: str,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
mentioned_documents: list[dict[str, Any]] | None,
|
||||
background_tasks: set[asyncio.Task[Any]],
|
||||
) -> asyncio.Task[int | None]:
|
||||
"""Spawn the user-row INSERT; await at the user-message-id yield site."""
|
||||
task = asyncio.create_task(
|
||||
persist_user_turn(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=turn_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
mentioned_documents=mentioned_documents,
|
||||
)
|
||||
)
|
||||
background_tasks.add(task)
|
||||
task.add_done_callback(background_tasks.discard)
|
||||
return task
|
||||
|
||||
|
||||
def spawn_persist_assistant_shell_task(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: str | None,
|
||||
turn_id: str,
|
||||
background_tasks: set[asyncio.Task[Any]],
|
||||
) -> asyncio.Task[int | None]:
|
||||
"""Spawn the assistant-shell INSERT; await at the assistant-message-id yield site."""
|
||||
task = asyncio.create_task(
|
||||
persist_assistant_shell(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
)
|
||||
background_tasks.add(task)
|
||||
task.add_done_callback(background_tasks.discard)
|
||||
return task
|
||||
|
||||
|
||||
async def await_persist_task(
|
||||
task: asyncio.Task[int | None] | None,
|
||||
*,
|
||||
chat_id: int,
|
||||
turn_id: str,
|
||||
log_label: str,
|
||||
) -> int | None:
|
||||
"""Join a spawned persistence task with ``shield`` + uniform error handling.
|
||||
|
||||
``shield`` keeps the DB write alive if the SSE generator is cancelled by
|
||||
client disconnect mid-await. Returns ``None`` on failure; the caller
|
||||
abort-paths the turn with a friendly error SSE.
|
||||
"""
|
||||
if task is None:
|
||||
return None
|
||||
try:
|
||||
return await asyncio.shield(task)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s failed (chat_id=%s, turn_id=%s)", log_label, chat_id, turn_id
|
||||
)
|
||||
return None
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
"""Build the per-invocation ``SurfSenseContextSchema`` for a new-chat turn.
|
||||
|
||||
Carries the per-turn read inputs that middlewares read via
|
||||
``runtime.context.*`` instead of from their ``__init__`` closures, so the same
|
||||
compiled-agent instance can serve multiple turns with different
|
||||
mention lists / request ids / turn ids without rebuilding the graph.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
|
||||
|
||||
def build_new_chat_runtime_context(
|
||||
*,
|
||||
search_space_id: int,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
accepted_folder_ids: list[int],
|
||||
mentioned_folder_ids: list[int] | None,
|
||||
request_id: str | None,
|
||||
turn_id: str,
|
||||
) -> SurfSenseContextSchema:
|
||||
"""``mentioned_document_ids`` is consumed by ``KnowledgePriorityMiddleware``.
|
||||
|
||||
``accepted_folder_ids`` (post-resolve) wins over the raw
|
||||
``mentioned_folder_ids`` from the request: the resolver drops chips that
|
||||
pointed at deleted folders or folders the caller can't see, so middlewares
|
||||
only get authorized ids.
|
||||
"""
|
||||
return SurfSenseContextSchema(
|
||||
search_space_id=search_space_id,
|
||||
mentioned_document_ids=list(mentioned_document_ids or []),
|
||||
mentioned_folder_ids=list(
|
||||
accepted_folder_ids or mentioned_folder_ids or []
|
||||
),
|
||||
request_id=request_id,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
"""Background thread-title generation (first-response only).
|
||||
|
||||
The first assistant response in a thread gets a short auto-generated title
|
||||
inserted into ``new_chat_threads.title``. We:
|
||||
|
||||
1. Spawn the generation as an ``asyncio.Task`` so it runs in parallel with
|
||||
the agent stream (no extra TTFT).
|
||||
2. Probe inside the task (on its own shielded session) whether this is
|
||||
actually the first response — newer turns short-circuit to ``None``.
|
||||
3. Inject the resulting ``thread-title-update`` SSE frame on the first agent
|
||||
event after the task completes (mid-stream interlock), or right before
|
||||
the finish frames (post-stream join) if the task hadn't finished yet.
|
||||
|
||||
Usage tokens come directly off the response (LiteLLM's async callback fires
|
||||
via fire-and-forget ``create_task``, so the ``TokenTrackingCallback`` would
|
||||
run too late). We also blank the per-task accumulator so the late callback
|
||||
doesn't double-count.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import NewChatMessage, NewChatThread, shielded_async_session
|
||||
from app.prompts import TITLE_GENERATION_PROMPT
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.services.token_tracking_service import TokenAccumulator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def spawn_title_task(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
assistant_message_id: int | None,
|
||||
llm: Any,
|
||||
agent_config: AgentConfig | None,
|
||||
) -> asyncio.Task[tuple[str | None, dict | None]] | None:
|
||||
"""Spawn ``_generate_title``; returns ``None`` when prerequisites aren't met.
|
||||
|
||||
Title gen is gated on a real ``assistant_message_id`` so a stream that
|
||||
aborts before persistence can never leave a thread with a title and no
|
||||
anchoring rows.
|
||||
"""
|
||||
if assistant_message_id is None:
|
||||
return None
|
||||
return asyncio.create_task(
|
||||
_generate_title(
|
||||
chat_id=chat_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
assistant_message_id=assistant_message_id,
|
||||
llm=llm,
|
||||
agent_config=agent_config,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _generate_title(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
assistant_message_id: int,
|
||||
llm: Any,
|
||||
agent_config: AgentConfig | None,
|
||||
) -> tuple[str | None, dict | None]:
|
||||
"""Probe is-first-response, then call ``acompletion``. Returns ``(title, usage)``."""
|
||||
try:
|
||||
from litellm import acompletion
|
||||
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.token_tracking_service import _turn_accumulator
|
||||
|
||||
# Excludes this turn's own assistant row (pre-written by
|
||||
# ``persist_assistant_shell``) — without the ``!=`` filter the gate
|
||||
# would false-negative on every turn after the first.
|
||||
try:
|
||||
async with shielded_async_session() as probe_session:
|
||||
probe_result = await probe_session.execute(
|
||||
select(NewChatMessage.id)
|
||||
.filter(
|
||||
NewChatMessage.thread_id == chat_id,
|
||||
NewChatMessage.role == "assistant",
|
||||
NewChatMessage.id != assistant_message_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
is_first_response = probe_result.scalars().first() is None
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[TitleGen] first-response probe failed (chat_id=%s)",
|
||||
chat_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return None, None
|
||||
|
||||
if not is_first_response:
|
||||
return None, None
|
||||
|
||||
_turn_accumulator.set(None)
|
||||
|
||||
title_seed = user_query.strip() or (
|
||||
f"[{len(user_image_data_urls or [])} image(s)]"
|
||||
if user_image_data_urls
|
||||
else ""
|
||||
)
|
||||
prompt = TITLE_GENERATION_PROMPT.replace(
|
||||
"{user_query}", title_seed[:500] or "(message)"
|
||||
)
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
if getattr(llm, "model", None) == "auto":
|
||||
router = LLMRouterService.get_router()
|
||||
response = await router.acompletion(model="auto", messages=messages)
|
||||
else:
|
||||
# Apply the same ``api_base`` cascade chat / vision / image-gen
|
||||
# call sites use so we never inherit ``litellm.api_base``
|
||||
# (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
|
||||
# config itself ships an empty ``api_base``. Without this the
|
||||
# title-gen on an OpenRouter chat config would 404 against the
|
||||
# inherited Azure endpoint — see ``provider_api_base`` for the
|
||||
# same bug repro on the image-gen / vision paths.
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
provider_prefix = (
|
||||
raw_model.split("/", 1)[0] if "/" in raw_model else None
|
||||
)
|
||||
provider_value = (
|
||||
agent_config.provider if agent_config is not None else None
|
||||
)
|
||||
title_api_base = resolve_api_base(
|
||||
provider=provider_value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=getattr(llm, "api_base", None),
|
||||
)
|
||||
response = await acompletion(
|
||||
model=raw_model,
|
||||
messages=messages,
|
||||
api_key=getattr(llm, "api_key", None),
|
||||
api_base=title_api_base,
|
||||
)
|
||||
|
||||
usage_info = None
|
||||
usage = getattr(response, "usage", None)
|
||||
if usage:
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
model_name = (
|
||||
raw_model.split("/", 1)[-1]
|
||||
if "/" in raw_model
|
||||
else (raw_model or response.model or "unknown")
|
||||
)
|
||||
usage_info = {
|
||||
"model": model_name,
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
|
||||
"total_tokens": getattr(usage, "total_tokens", 0) or 0,
|
||||
}
|
||||
|
||||
raw_title = response.choices[0].message.content.strip()
|
||||
if raw_title and len(raw_title) <= 100:
|
||||
return raw_title.strip("\"'"), usage_info
|
||||
return None, usage_info
|
||||
except Exception:
|
||||
logger.exception("[TitleGen] _generate_title failed")
|
||||
return None, None
|
||||
|
||||
|
||||
async def maybe_emit_title_update(
|
||||
*,
|
||||
title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
|
||||
title_emitted: bool,
|
||||
chat_id: int,
|
||||
accumulator: TokenAccumulator,
|
||||
streaming_service: VercelStreamingService,
|
||||
):
|
||||
"""Inject one ``thread-title-update`` SSE if the task completed.
|
||||
|
||||
Yields the SSE frame (when applicable). Returns nothing; the orchestrator
|
||||
flips ``title_emitted`` itself after iterating so we don't fight Python's
|
||||
nonlocal-in-generator semantics.
|
||||
"""
|
||||
if title_task is None or title_emitted or not title_task.done():
|
||||
return
|
||||
generated_title, title_usage = title_task.result()
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
)
|
||||
title_thread = title_thread_result.scalars().first()
|
||||
if title_thread:
|
||||
title_thread.title = generated_title
|
||||
await title_session.commit()
|
||||
yield streaming_service.format_thread_title_update(chat_id, generated_title)
|
||||
|
||||
|
||||
async def await_pending_title_update(
|
||||
*,
|
||||
title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
|
||||
title_emitted: bool,
|
||||
chat_id: int,
|
||||
accumulator: TokenAccumulator,
|
||||
streaming_service: VercelStreamingService,
|
||||
):
|
||||
"""If the task hadn't completed during the stream, await it now and emit.
|
||||
|
||||
Used right before the finish frames in the success path. Mirror of
|
||||
``maybe_emit_title_update`` but unconditionally awaits.
|
||||
"""
|
||||
if title_task is None or title_emitted:
|
||||
return
|
||||
generated_title, title_usage = await title_task
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
)
|
||||
title_thread = title_thread_result.scalars().first()
|
||||
if title_thread:
|
||||
title_thread.title = generated_title
|
||||
await title_session.commit()
|
||||
yield streaming_service.format_thread_title_update(chat_id, generated_title)
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Resume-chat streaming flow.
|
||||
|
||||
Public entry point ``stream_resume_chat`` is the slim coroutine in
|
||||
``orchestrator.py`` that composes the per-concern modules in this folder and
|
||||
the building blocks under ``flows/shared/``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.flows.resume_chat.orchestrator import stream_resume_chat
|
||||
|
||||
__all__ = ["stream_resume_chat"]
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue