From f944cdacb753369af8e117e34cdaca39c3c48c73 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 8 May 2026 22:47:03 +0200 Subject: [PATCH] Add helpers to open and close task delegation span ids. --- .../tasks/chat/streaming/relay/task_span.py | 74 +++++++++++++++++++ .../tasks/chat/streaming/test_task_span.py | 69 +++++++++++++++++ 2 files changed, 143 insertions(+) create mode 100644 surfsense_backend/app/tasks/chat/streaming/relay/task_span.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/streaming/test_task_span.py diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/task_span.py b/surfsense_backend/app/tasks/chat/streaming/relay/task_span.py new file mode 100644 index 000000000..c4cdf24ba --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/task_span.py @@ -0,0 +1,74 @@ +"""Open/close ``active_span_id`` around a delegating ``task`` tool run.""" + +from __future__ import annotations + +import uuid + +from app.tasks.chat.streaming.relay.state import AgentEventRelayState + + +def new_span_id() -> str: + """One delegation-episode id (shared by activity under an open ``task``).""" + return f"spn_{uuid.uuid4().hex}" + + +def _run_key(run_id: str) -> str: + return (run_id or "").strip() + + +def _lc_key(langchain_tool_call_id: str | None) -> str: + return (langchain_tool_call_id or "").strip() + + +def ensure_pending_task_span_for_lc(state: AgentEventRelayState, lc_id: str) -> str: + """Return span id for this LangChain tool call id, storing it in ``pending`` if new. + + Used from ``chat_model_stream`` when the first ``task`` chunk registers so + early ``tool-input-start`` can carry ``metadata.spanId`` before ``on_tool_start``. + """ + key = _lc_key(lc_id) + if not key: + return new_span_id() + existing = state.pending_task_span_by_lc.get(key) + if existing: + return existing + sid = new_span_id() + state.pending_task_span_by_lc[key] = sid + return sid + + +def open_task_span( + state: AgentEventRelayState, + *, + run_id: str, + langchain_tool_call_id: str | None = None, +) -> str: + """Set ``active_span_id`` from pending (same lc) or mint; remember ``active_task_run_id``. + + Call when the ``task`` tool **starts**. Nested ``task`` is not supported: + a second call replaces the previous span without restoring it. + """ + key = _lc_key(langchain_tool_call_id) + sid: str | None = state.pending_task_span_by_lc.pop(key, None) if key else None + if not sid: + sid = new_span_id() + state.active_span_id = sid + state.active_task_run_id = _run_key(run_id) or None + return sid + + +def clear_task_span_if_delegating_task_ended( + state: AgentEventRelayState, + *, + tool_name: str, + run_id: str, +) -> None: + """Clear span state only when this event is the end of the opening ``task`` run.""" + if tool_name != "task": + return + if state.active_task_run_id is None: + return + if state.active_task_run_id != _run_key(run_id): + return + state.active_span_id = None + state.active_task_run_id = None diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_task_span.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_task_span.py new file mode 100644 index 000000000..349c9879c --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_task_span.py @@ -0,0 +1,69 @@ +"""Unit tests for ``task_span`` open/close helpers.""" + +from __future__ import annotations + +import pytest + +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.task_span import ( + clear_task_span_if_delegating_task_ended, + ensure_pending_task_span_for_lc, + open_task_span, +) + +pytestmark = pytest.mark.unit + + +def test_open_task_span_sets_span_and_run_id() -> None: + state = AgentEventRelayState.for_invocation() + sid = open_task_span(state, run_id="run-abc") + assert sid.startswith("spn_") + assert state.active_span_id == sid + assert state.active_task_run_id == "run-abc" + assert state.span_metadata_if_active() == {"spanId": sid} + + +def test_clear_ignored_for_non_task_tool() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-1") + sid = state.active_span_id + clear_task_span_if_delegating_task_ended( + state, tool_name="web_search", run_id="run-1" + ) + assert state.active_span_id == sid + assert state.active_task_run_id == "run-1" + + +def test_clear_ignored_when_task_run_id_mismatches() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-open") + clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-other") + assert state.active_span_id is not None + assert state.active_task_run_id == "run-open" + + +def test_clear_on_matching_task_end() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-x") + clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-x") + assert state.active_span_id is None + assert state.active_task_run_id is None + assert state.span_metadata_if_active() is None + + +def test_clear_noop_when_no_open_span() -> None: + state = AgentEventRelayState.for_invocation() + clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-x") + assert state.active_span_id is None + + +def test_pending_then_open_reuses_same_span_id() -> None: + state = AgentEventRelayState.for_invocation() + sid_pending = ensure_pending_task_span_for_lc(state, "lc-task-1") + assert state.pending_task_span_by_lc["lc-task-1"] == sid_pending + sid_active = open_task_span( + state, run_id="run-1", langchain_tool_call_id="lc-task-1" + ) + assert sid_active == sid_pending + assert state.active_span_id == sid_pending + assert "lc-task-1" not in state.pending_task_span_by_lc