mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
fix: fixed composio issues
This commit is contained in:
parent
47b2994ec7
commit
cea8618aed
25 changed files with 1756 additions and 461 deletions
|
|
@ -51,22 +51,34 @@ class _FakeToolMessage:
|
|||
tool_call_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeInterrupt:
|
||||
value: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeTask:
|
||||
interrupts: tuple[_FakeInterrupt, ...] = ()
|
||||
|
||||
|
||||
class _FakeAgentState:
|
||||
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, tasks: list[Any] | None = None) -> None:
|
||||
# Empty values keeps the cloud-fallback safety-net branch a no-op,
|
||||
# and an empty ``tasks`` list keeps the post-stream interrupt
|
||||
# check a no-op too.
|
||||
# and empty ``tasks`` keep the post-stream interrupt check a no-op too.
|
||||
self.values: dict[str, Any] = {}
|
||||
self.tasks: list[Any] = []
|
||||
self.tasks: list[Any] = tasks or []
|
||||
|
||||
|
||||
class _FakeAgent:
|
||||
"""Replays a list of ``astream_events`` events."""
|
||||
|
||||
def __init__(self, events: list[dict[str, Any]]) -> None:
|
||||
def __init__(
|
||||
self, events: list[dict[str, Any]], state: _FakeAgentState | None = None
|
||||
) -> None:
|
||||
self._events = events
|
||||
self._state = state or _FakeAgentState()
|
||||
|
||||
async def astream_events( # type: ignore[no-untyped-def]
|
||||
self, _input_data: Any, *, config: dict[str, Any], version: str
|
||||
|
|
@ -79,7 +91,7 @@ class _FakeAgent:
|
|||
# Called once after astream_events drains so the cloud-fallback
|
||||
# safety net can inspect staged filesystem work. The fake stays
|
||||
# empty so the safety net is a no-op.
|
||||
return _FakeAgentState()
|
||||
return self._state
|
||||
|
||||
|
||||
def _model_stream(
|
||||
|
|
@ -170,11 +182,13 @@ def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
)
|
||||
|
||||
|
||||
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
async def _drain(
|
||||
events: list[dict[str, Any]], state: _FakeAgentState | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Run ``_stream_agent_events`` against a fake agent and return the
|
||||
SSE payloads (parsed JSON) it yielded.
|
||||
"""
|
||||
agent = _FakeAgent(events)
|
||||
agent = _FakeAgent(events, state=state)
|
||||
service = VercelStreamingService()
|
||||
result = StreamResult()
|
||||
config = {"configurable": {"thread_id": "test-thread"}}
|
||||
|
|
@ -525,3 +539,29 @@ async def test_unmatched_fallback_still_attaches_lc_id(
|
|||
assert len(starts) == 1
|
||||
assert starts[0]["toolCallId"].startswith("call_run-1")
|
||||
assert starts[0]["langchainToolCallId"] == "lc-orphan"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interrupt_request_uses_task_that_contains_interrupt(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
interrupt_payload = {
|
||||
"type": "calendar_event_create",
|
||||
"action": {
|
||||
"tool": "create_calendar_event",
|
||||
"params": {"summary": "mom bday"},
|
||||
},
|
||||
"context": {},
|
||||
}
|
||||
state = _FakeAgentState(
|
||||
tasks=[
|
||||
_FakeTask(interrupts=()),
|
||||
_FakeTask(interrupts=(_FakeInterrupt(value=interrupt_payload),)),
|
||||
]
|
||||
)
|
||||
|
||||
payloads = await _drain([], state=state)
|
||||
|
||||
interrupts = _of_type(payloads, "data-interrupt-request")
|
||||
assert len(interrupts) == 1
|
||||
assert interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue