From 772150eb66b10ab04db64affca842607927a2093 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 21 Mar 2026 13:19:58 +0530 Subject: [PATCH] feat: add unit tests for DedupHITLToolCallsMiddleware --- .../tests/unit/middleware/__init__.py | 0 .../middleware/test_dedup_hitl_tool_calls.py | 41 +++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 surfsense_backend/tests/unit/middleware/__init__.py create mode 100644 surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py diff --git a/surfsense_backend/tests/unit/middleware/__init__.py b/surfsense_backend/tests/unit/middleware/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py b/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py new file mode 100644 index 000000000..add0105e4 --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py @@ -0,0 +1,41 @@ +import pytest +from langchain_core.messages import AIMessage + +from app.agents.new_chat.middleware.dedup_tool_calls import ( + DedupHITLToolCallsMiddleware, +) + +pytestmark = pytest.mark.unit + + +def _make_state(tool_calls: list[dict]) -> dict: + """Build a minimal agent state with one AIMessage carrying *tool_calls*.""" + msg = AIMessage(content="", tool_calls=tool_calls) + return {"messages": [msg]} + + +def test_duplicate_hitl_calls_reduced_to_first(): + """When the LLM emits the same HITL tool call twice, only the first is kept.""" + mw = DedupHITLToolCallsMiddleware() + + state = _make_state( + [ + { + "id": "call_1", + "name": "delete_calendar_event", + "args": {"event_title_or_id": "Doctor Appointment"}, + }, + { + "id": "call_2", + "name": "delete_calendar_event", + "args": {"event_title_or_id": "Doctor Appointment"}, + }, + ] + ) + + result = mw.after_model(state, runtime=None) # type: ignore[arg-type] + + assert result is not None, "Expected middleware to return updated state" + updated_calls = result["messages"][0].tool_calls + assert len(updated_calls) == 1 + assert updated_calls[0]["id"] == "call_1"