mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-12 09:12:40 +02:00
Add helpers to open and close task delegation span ids.
This commit is contained in:
parent
f0f87107f2
commit
f944cdacb7
2 changed files with 143 additions and 0 deletions
|
|
@ -0,0 +1,74 @@
|
||||||
|
"""Open/close ``active_span_id`` around a delegating ``task`` tool run."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
|
||||||
|
|
||||||
|
|
||||||
|
def new_span_id() -> str:
|
||||||
|
"""One delegation-episode id (shared by activity under an open ``task``)."""
|
||||||
|
return f"spn_{uuid.uuid4().hex}"
|
||||||
|
|
||||||
|
|
||||||
|
def _run_key(run_id: str) -> str:
|
||||||
|
return (run_id or "").strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _lc_key(langchain_tool_call_id: str | None) -> str:
|
||||||
|
return (langchain_tool_call_id or "").strip()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_pending_task_span_for_lc(state: AgentEventRelayState, lc_id: str) -> str:
|
||||||
|
"""Return span id for this LangChain tool call id, storing it in ``pending`` if new.
|
||||||
|
|
||||||
|
Used from ``chat_model_stream`` when the first ``task`` chunk registers so
|
||||||
|
early ``tool-input-start`` can carry ``metadata.spanId`` before ``on_tool_start``.
|
||||||
|
"""
|
||||||
|
key = _lc_key(lc_id)
|
||||||
|
if not key:
|
||||||
|
return new_span_id()
|
||||||
|
existing = state.pending_task_span_by_lc.get(key)
|
||||||
|
if existing:
|
||||||
|
return existing
|
||||||
|
sid = new_span_id()
|
||||||
|
state.pending_task_span_by_lc[key] = sid
|
||||||
|
return sid
|
||||||
|
|
||||||
|
|
||||||
|
def open_task_span(
|
||||||
|
state: AgentEventRelayState,
|
||||||
|
*,
|
||||||
|
run_id: str,
|
||||||
|
langchain_tool_call_id: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Set ``active_span_id`` from pending (same lc) or mint; remember ``active_task_run_id``.
|
||||||
|
|
||||||
|
Call when the ``task`` tool **starts**. Nested ``task`` is not supported:
|
||||||
|
a second call replaces the previous span without restoring it.
|
||||||
|
"""
|
||||||
|
key = _lc_key(langchain_tool_call_id)
|
||||||
|
sid: str | None = state.pending_task_span_by_lc.pop(key, None) if key else None
|
||||||
|
if not sid:
|
||||||
|
sid = new_span_id()
|
||||||
|
state.active_span_id = sid
|
||||||
|
state.active_task_run_id = _run_key(run_id) or None
|
||||||
|
return sid
|
||||||
|
|
||||||
|
|
||||||
|
def clear_task_span_if_delegating_task_ended(
|
||||||
|
state: AgentEventRelayState,
|
||||||
|
*,
|
||||||
|
tool_name: str,
|
||||||
|
run_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Clear span state only when this event is the end of the opening ``task`` run."""
|
||||||
|
if tool_name != "task":
|
||||||
|
return
|
||||||
|
if state.active_task_run_id is None:
|
||||||
|
return
|
||||||
|
if state.active_task_run_id != _run_key(run_id):
|
||||||
|
return
|
||||||
|
state.active_span_id = None
|
||||||
|
state.active_task_run_id = None
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""Unit tests for ``task_span`` open/close helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
|
||||||
|
from app.tasks.chat.streaming.relay.task_span import (
|
||||||
|
clear_task_span_if_delegating_task_ended,
|
||||||
|
ensure_pending_task_span_for_lc,
|
||||||
|
open_task_span,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_open_task_span_sets_span_and_run_id() -> None:
|
||||||
|
state = AgentEventRelayState.for_invocation()
|
||||||
|
sid = open_task_span(state, run_id="run-abc")
|
||||||
|
assert sid.startswith("spn_")
|
||||||
|
assert state.active_span_id == sid
|
||||||
|
assert state.active_task_run_id == "run-abc"
|
||||||
|
assert state.span_metadata_if_active() == {"spanId": sid}
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_ignored_for_non_task_tool() -> None:
|
||||||
|
state = AgentEventRelayState.for_invocation()
|
||||||
|
open_task_span(state, run_id="run-1")
|
||||||
|
sid = state.active_span_id
|
||||||
|
clear_task_span_if_delegating_task_ended(
|
||||||
|
state, tool_name="web_search", run_id="run-1"
|
||||||
|
)
|
||||||
|
assert state.active_span_id == sid
|
||||||
|
assert state.active_task_run_id == "run-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_ignored_when_task_run_id_mismatches() -> None:
|
||||||
|
state = AgentEventRelayState.for_invocation()
|
||||||
|
open_task_span(state, run_id="run-open")
|
||||||
|
clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-other")
|
||||||
|
assert state.active_span_id is not None
|
||||||
|
assert state.active_task_run_id == "run-open"
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_on_matching_task_end() -> None:
|
||||||
|
state = AgentEventRelayState.for_invocation()
|
||||||
|
open_task_span(state, run_id="run-x")
|
||||||
|
clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-x")
|
||||||
|
assert state.active_span_id is None
|
||||||
|
assert state.active_task_run_id is None
|
||||||
|
assert state.span_metadata_if_active() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_noop_when_no_open_span() -> None:
|
||||||
|
state = AgentEventRelayState.for_invocation()
|
||||||
|
clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-x")
|
||||||
|
assert state.active_span_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_pending_then_open_reuses_same_span_id() -> None:
|
||||||
|
state = AgentEventRelayState.for_invocation()
|
||||||
|
sid_pending = ensure_pending_task_span_for_lc(state, "lc-task-1")
|
||||||
|
assert state.pending_task_span_by_lc["lc-task-1"] == sid_pending
|
||||||
|
sid_active = open_task_span(
|
||||||
|
state, run_id="run-1", langchain_tool_call_id="lc-task-1"
|
||||||
|
)
|
||||||
|
assert sid_active == sid_pending
|
||||||
|
assert state.active_span_id == sid_pending
|
||||||
|
assert "lc-task-1" not in state.pending_task_span_by_lc
|
||||||
Loading…
Add table
Add a link
Reference in a new issue