fix: fixed composio issues

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-02 21:16:03 -07:00
parent 47b2994ec7
commit cea8618aed
25 changed files with 1756 additions and 461 deletions

View file

@ -51,22 +51,34 @@ class _FakeToolMessage:
tool_call_id: str | None = None
@dataclass
class _FakeInterrupt:
value: dict[str, Any]
@dataclass
class _FakeTask:
interrupts: tuple[_FakeInterrupt, ...] = ()
class _FakeAgentState:
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
def __init__(self) -> None:
def __init__(self, tasks: list[Any] | None = None) -> None:
# Empty values keeps the cloud-fallback safety-net branch a no-op,
# and an empty ``tasks`` list keeps the post-stream interrupt
# check a no-op too.
# and empty ``tasks`` keep the post-stream interrupt check a no-op too.
self.values: dict[str, Any] = {}
self.tasks: list[Any] = []
self.tasks: list[Any] = tasks or []
class _FakeAgent:
"""Replays a list of ``astream_events`` events."""
def __init__(self, events: list[dict[str, Any]]) -> None:
def __init__(
self, events: list[dict[str, Any]], state: _FakeAgentState | None = None
) -> None:
self._events = events
self._state = state or _FakeAgentState()
async def astream_events( # type: ignore[no-untyped-def]
self, _input_data: Any, *, config: dict[str, Any], version: str
@ -79,7 +91,7 @@ class _FakeAgent:
# Called once after astream_events drains so the cloud-fallback
# safety net can inspect staged filesystem work. The fake stays
# empty so the safety net is a no-op.
return _FakeAgentState()
return self._state
def _model_stream(
@ -170,11 +182,13 @@ def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
)
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
async def _drain(
events: list[dict[str, Any]], state: _FakeAgentState | None = None
) -> list[dict[str, Any]]:
"""Run ``_stream_agent_events`` against a fake agent and return the
SSE payloads (parsed JSON) it yielded.
"""
agent = _FakeAgent(events)
agent = _FakeAgent(events, state=state)
service = VercelStreamingService()
result = StreamResult()
config = {"configurable": {"thread_id": "test-thread"}}
@ -525,3 +539,29 @@ async def test_unmatched_fallback_still_attaches_lc_id(
assert len(starts) == 1
assert starts[0]["toolCallId"].startswith("call_run-1")
assert starts[0]["langchainToolCallId"] == "lc-orphan"
@pytest.mark.asyncio
async def test_interrupt_request_uses_task_that_contains_interrupt(
parity_v2_on: None,
) -> None:
interrupt_payload = {
"type": "calendar_event_create",
"action": {
"tool": "create_calendar_event",
"params": {"summary": "mom bday"},
},
"context": {},
}
state = _FakeAgentState(
tasks=[
_FakeTask(interrupts=()),
_FakeTask(interrupts=(_FakeInterrupt(value=interrupt_payload),)),
]
)
payloads = await _drain([], state=state)
interrupts = _of_type(payloads, "data-interrupt-request")
assert len(interrupts) == 1
assert interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event"