mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
agent: retire eager KB priority/planner path and its dead flags
The pull-based KB design (on-demand search_knowledge_base tool + pre-injected workspace tree) fully replaced the old eager retrieval path. Remove its last remnants: - Delete KnowledgePriorityMiddleware (knowledge_search.py) and its tests. - Drop the kb_priority state field + reducer default; trim KbContextProjectionMiddleware to project only workspace_tree_text. - Remove the now-dead feature flags enable_kb_priority_preinjection and enable_kb_planner_runnable across backend (flags, route schema, tests, env examples) and frontend (settings toggle, zod schema). - Scrub <priority_documents> and stale KnowledgePriorityMiddleware references from prompts, docstrings, and the ADR. No functional change: nothing wrote kb_priority and neither flag gated live behavior after the cutover. Full backend suite green (pre-existing unrelated failures aside).
This commit is contained in:
parent
0148647b98
commit
2beafbdec8
34 changed files with 62 additions and 1890 deletions
|
|
@ -38,7 +38,7 @@ class TestIsProtectedSystemMessage:
|
|||
)
|
||||
|
||||
def test_tolerates_leading_whitespace(self) -> None:
|
||||
msg = SystemMessage(content=" \n<priority_documents>\n...")
|
||||
msg = SystemMessage(content=" \n<workspace_tree>\n...")
|
||||
assert _is_protected_system_message(msg) is True
|
||||
|
||||
|
||||
|
|
@ -89,7 +89,7 @@ class TestPartitionMessages:
|
|||
|
||||
def test_protected_system_message_preserved_even_in_summarize_half(self) -> None:
|
||||
partitioner = self._build_partitioner()
|
||||
protected = SystemMessage(content="<priority_documents>\n...")
|
||||
protected = SystemMessage(content="<workspace_tree>\n...")
|
||||
msgs = [
|
||||
HumanMessage(content="old human"),
|
||||
AIMessage(content="old ai"),
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
|
||||
"SURFSENSE_ENABLE_SKILLS",
|
||||
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
|
||||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||
"SURFSENSE_ENABLE_ACTION_LOG",
|
||||
"SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||
"SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||
|
|
@ -57,7 +56,6 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) ->
|
|||
assert flags.enable_llm_tool_selector is False
|
||||
assert flags.enable_skills is True
|
||||
assert flags.enable_specialized_subagents is True
|
||||
assert flags.enable_kb_planner_runnable is True
|
||||
assert flags.enable_action_log is True
|
||||
assert flags.enable_revert_route is True
|
||||
assert flags.enable_plugin_loader is False
|
||||
|
|
@ -122,7 +120,6 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
|
|||
"enable_llm_tool_selector": "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
|
||||
"enable_skills": "SURFSENSE_ENABLE_SKILLS",
|
||||
"enable_specialized_subagents": "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
|
||||
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
|
||||
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||
|
|
|
|||
|
|
@ -90,8 +90,8 @@ class TestSubstituteInText:
|
|||
|
||||
class TestResolveMentions:
|
||||
"""``resolve_mentions`` resolves chip ids → virtual paths and emits
|
||||
a ``ResolvedMentionSet`` whose id partitions feed
|
||||
``KnowledgePriorityMiddleware``."""
|
||||
a ``ResolvedMentionSet`` whose id partitions feed the
|
||||
``search_knowledge_base`` retrieval scope."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_when_no_mentions(self):
|
||||
|
|
|
|||
|
|
@ -161,7 +161,6 @@ class TestInitialFilesystemState:
|
|||
assert state["doc_id_by_path"] == {}
|
||||
assert state["dirty_paths"] == []
|
||||
assert state["dirty_path_tool_calls"] == {}
|
||||
assert state["kb_priority"] == []
|
||||
assert state["kb_anon_doc"] is None
|
||||
assert state["tree_version"] == 0
|
||||
|
||||
|
|
|
|||
|
|
@ -1,604 +0,0 @@
|
|||
"""Unit tests for knowledge_search middleware helpers."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware import knowledge_search as ks
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
|
||||
KBSearchPlan,
|
||||
KnowledgePriorityMiddleware,
|
||||
_normalize_optional_date_range,
|
||||
_parse_kb_search_plan_response,
|
||||
_render_recent_conversation,
|
||||
_resolve_search_types,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ── _resolve_search_types ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResolveSearchTypes:
|
||||
def test_returns_none_when_no_inputs(self):
|
||||
assert _resolve_search_types(None, None) is None
|
||||
|
||||
def test_returns_none_when_both_empty(self):
|
||||
assert _resolve_search_types([], []) is None
|
||||
|
||||
def test_includes_legacy_type_for_google_gmail(self):
|
||||
result = _resolve_search_types(["GOOGLE_GMAIL_CONNECTOR"], None)
|
||||
assert "GOOGLE_GMAIL_CONNECTOR" in result
|
||||
assert "COMPOSIO_GMAIL_CONNECTOR" in result
|
||||
|
||||
def test_includes_legacy_type_for_google_drive(self):
|
||||
result = _resolve_search_types(None, ["GOOGLE_DRIVE_FILE"])
|
||||
assert "GOOGLE_DRIVE_FILE" in result
|
||||
assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in result
|
||||
|
||||
def test_includes_legacy_type_for_google_calendar(self):
|
||||
result = _resolve_search_types(["GOOGLE_CALENDAR_CONNECTOR"], None)
|
||||
assert "GOOGLE_CALENDAR_CONNECTOR" in result
|
||||
assert "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR" in result
|
||||
|
||||
def test_no_legacy_expansion_for_unrelated_types(self):
|
||||
result = _resolve_search_types(["FILE", "NOTE"], None)
|
||||
assert set(result) == {"FILE", "NOTE"}
|
||||
|
||||
def test_combines_connectors_and_document_types(self):
|
||||
result = _resolve_search_types(["FILE"], ["NOTE", "CRAWLED_URL"])
|
||||
assert {"FILE", "NOTE", "CRAWLED_URL"}.issubset(set(result))
|
||||
|
||||
def test_deduplicates(self):
|
||||
result = _resolve_search_types(["FILE", "FILE"], ["FILE"])
|
||||
assert result.count("FILE") == 1
|
||||
|
||||
|
||||
# ── planner parsing / date normalization ───────────────────────────────
|
||||
|
||||
|
||||
class TestPlannerHelpers:
|
||||
def test_parse_kb_search_plan_response_accepts_plain_json(self):
|
||||
plan = _parse_kb_search_plan_response(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "ocv meeting decisions summary",
|
||||
"start_date": "2026-03-01",
|
||||
"end_date": "2026-03-31",
|
||||
}
|
||||
)
|
||||
)
|
||||
assert plan.optimized_query == "ocv meeting decisions summary"
|
||||
assert plan.start_date == "2026-03-01"
|
||||
assert plan.end_date == "2026-03-31"
|
||||
|
||||
def test_parse_kb_search_plan_response_accepts_fenced_json(self):
|
||||
plan = _parse_kb_search_plan_response(
|
||||
"""```json
|
||||
{"optimized_query":"deel founders guide","start_date":null,"end_date":null}
|
||||
```"""
|
||||
)
|
||||
assert plan.optimized_query == "deel founders guide"
|
||||
assert plan.start_date is None
|
||||
assert plan.end_date is None
|
||||
|
||||
def test_normalize_optional_date_range_returns_none_when_absent(self):
|
||||
start_date, end_date = _normalize_optional_date_range(None, None)
|
||||
assert start_date is None
|
||||
assert end_date is None
|
||||
|
||||
def test_normalize_optional_date_range_resolves_single_bound(self):
|
||||
start_date, end_date = _normalize_optional_date_range("2026-03-01", None)
|
||||
assert start_date is not None
|
||||
assert end_date is not None
|
||||
assert start_date.date().isoformat() == "2026-03-01"
|
||||
assert end_date >= start_date
|
||||
|
||||
|
||||
class FakeLLM:
|
||||
def __init__(self, response_text: str):
|
||||
self.response_text = response_text
|
||||
self.calls: list[dict] = []
|
||||
|
||||
async def ainvoke(self, messages, config=None):
|
||||
self.calls.append({"messages": messages, "config": config})
|
||||
return AIMessage(content=self.response_text)
|
||||
|
||||
|
||||
class FakeBudgetLLM:
|
||||
def __init__(self, *, max_input_tokens: int):
|
||||
self._max_input_tokens_value = max_input_tokens
|
||||
|
||||
def _get_max_input_tokens(self) -> int:
|
||||
return self._max_input_tokens_value
|
||||
|
||||
def _count_tokens(self, messages) -> int:
|
||||
# Deterministic, simple proxy for tests: count characters as tokens.
|
||||
return sum(len(msg.get("content", "")) for msg in messages)
|
||||
|
||||
|
||||
class TestKnowledgePriorityMiddlewarePlanner:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _disable_planner_runnable(self, monkeypatch):
|
||||
# ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the
|
||||
# planner Runnable path is enabled) calls ``.bind()`` on the LLM,
|
||||
# which the mock does not implement. Pin the flag off so the
|
||||
# planner falls through to the legacy ``self.llm.ainvoke`` path
|
||||
# these tests assert against (``llm.calls[0]["config"]``).
|
||||
monkeypatch.setenv("SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "false")
|
||||
|
||||
def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
|
||||
messages = [
|
||||
HumanMessage(content="old user context " * 40),
|
||||
AIMessage(content="old assistant answer " * 35),
|
||||
HumanMessage(content="recent user context " * 20),
|
||||
AIMessage(content="recent assistant answer " * 18),
|
||||
HumanMessage(content="latest question"),
|
||||
]
|
||||
|
||||
rendered = _render_recent_conversation(
|
||||
messages,
|
||||
llm=FakeBudgetLLM(max_input_tokens=900),
|
||||
user_text="latest question",
|
||||
)
|
||||
|
||||
assert "recent user context" in rendered
|
||||
assert "recent assistant answer" in rendered
|
||||
assert "latest question" not in rendered
|
||||
assert rendered.index("recent user context") < rendered.index(
|
||||
"recent assistant answer"
|
||||
)
|
||||
|
||||
def test_render_recent_conversation_falls_back_to_legacy_without_budgeting(self):
|
||||
messages = [
|
||||
HumanMessage(content="message one"),
|
||||
AIMessage(content="message two"),
|
||||
HumanMessage(content="latest question"),
|
||||
]
|
||||
|
||||
rendered = _render_recent_conversation(
|
||||
messages,
|
||||
llm=None,
|
||||
user_text="latest question",
|
||||
)
|
||||
|
||||
assert "user: message one" in rendered
|
||||
assert "assistant: message two" in rendered
|
||||
assert "latest question" not in rendered
|
||||
|
||||
async def test_middleware_uses_optimized_query_and_dates(self, monkeypatch):
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_search_knowledge_base(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
|
||||
llm = FakeLLM(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "ocv meeting decisions action items",
|
||||
"start_date": "2026-03-01",
|
||||
"end_date": "2026-03-31",
|
||||
}
|
||||
)
|
||||
)
|
||||
middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=37)
|
||||
|
||||
result = await middleware.abefore_agent(
|
||||
{
|
||||
"messages": [
|
||||
HumanMessage(content="what happened in our OCV meeting last month?")
|
||||
]
|
||||
},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert captured["query"] == "ocv meeting decisions action items"
|
||||
assert captured["start_date"] is not None
|
||||
assert captured["end_date"] is not None
|
||||
assert captured["start_date"].date().isoformat() == "2026-03-01"
|
||||
assert captured["end_date"].date().isoformat() == "2026-03-31"
|
||||
assert llm.calls[0]["config"] == {"tags": ["surfsense:internal"]}
|
||||
|
||||
async def test_middleware_falls_back_when_planner_returns_invalid_json(
|
||||
self,
|
||||
monkeypatch,
|
||||
):
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_search_knowledge_base(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
|
||||
middleware = KnowledgePriorityMiddleware(
|
||||
llm=FakeLLM("not json"),
|
||||
search_space_id=37,
|
||||
)
|
||||
|
||||
await middleware.abefore_agent(
|
||||
{"messages": [HumanMessage(content="summarize founders guide by deel")]},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
assert captured["query"] == "summarize founders guide by deel"
|
||||
assert captured["start_date"] is None
|
||||
assert captured["end_date"] is None
|
||||
|
||||
async def test_middleware_passes_none_dates_when_planner_returns_nulls(
|
||||
self,
|
||||
monkeypatch,
|
||||
):
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_search_knowledge_base(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
|
||||
middleware = KnowledgePriorityMiddleware(
|
||||
llm=FakeLLM(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "deel founders guide summary",
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
}
|
||||
)
|
||||
),
|
||||
search_space_id=37,
|
||||
)
|
||||
|
||||
await middleware.abefore_agent(
|
||||
{"messages": [HumanMessage(content="summarize founders guide by deel")]},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
assert captured["query"] == "deel founders guide summary"
|
||||
assert captured["start_date"] is None
|
||||
assert captured["end_date"] is None
|
||||
|
||||
async def test_middleware_routes_to_recency_browse_when_flagged(
|
||||
self,
|
||||
monkeypatch,
|
||||
):
|
||||
"""When the planner sets is_recency_query=true, browse_recent_documents
|
||||
is called instead of search_knowledge_base."""
|
||||
browse_captured: dict = {}
|
||||
search_called = False
|
||||
|
||||
async def fake_browse_recent_documents(**kwargs):
|
||||
browse_captured.update(kwargs)
|
||||
return []
|
||||
|
||||
async def fake_search_knowledge_base(**kwargs):
|
||||
nonlocal search_called
|
||||
search_called = True
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"browse_recent_documents",
|
||||
fake_browse_recent_documents,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
|
||||
llm = FakeLLM(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "latest uploaded file",
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
"is_recency_query": True,
|
||||
}
|
||||
)
|
||||
)
|
||||
middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42)
|
||||
|
||||
result = await middleware.abefore_agent(
|
||||
{"messages": [HumanMessage(content="what's my latest file?")]},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert browse_captured["search_space_id"] == 42
|
||||
assert not search_called
|
||||
|
||||
async def test_middleware_uses_hybrid_search_when_not_recency(
|
||||
self,
|
||||
monkeypatch,
|
||||
):
|
||||
"""When is_recency_query is false (default), hybrid search is used."""
|
||||
search_captured: dict = {}
|
||||
browse_called = False
|
||||
|
||||
async def fake_browse_recent_documents(**kwargs):
|
||||
nonlocal browse_called
|
||||
browse_called = True
|
||||
return []
|
||||
|
||||
async def fake_search_knowledge_base(**kwargs):
|
||||
search_captured.update(kwargs)
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"browse_recent_documents",
|
||||
fake_browse_recent_documents,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
|
||||
llm = FakeLLM(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "quarterly revenue report analysis",
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
"is_recency_query": False,
|
||||
}
|
||||
)
|
||||
)
|
||||
middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42)
|
||||
|
||||
await middleware.abefore_agent(
|
||||
{"messages": [HumanMessage(content="find the quarterly revenue report")]},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
assert search_captured["query"] == "quarterly revenue report analysis"
|
||||
assert not browse_called
|
||||
|
||||
|
||||
# ── KBSearchPlan schema ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestKBSearchPlanSchema:
|
||||
def test_is_recency_query_defaults_to_false(self):
|
||||
plan = KBSearchPlan(optimized_query="test query")
|
||||
assert plan.is_recency_query is False
|
||||
|
||||
def test_is_recency_query_parses_true(self):
|
||||
plan = _parse_kb_search_plan_response(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "latest uploaded file",
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
"is_recency_query": True,
|
||||
}
|
||||
)
|
||||
)
|
||||
assert plan.is_recency_query is True
|
||||
assert plan.optimized_query == "latest uploaded file"
|
||||
|
||||
def test_missing_is_recency_query_defaults_to_false(self):
|
||||
plan = _parse_kb_search_plan_response(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "meeting notes",
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
}
|
||||
)
|
||||
)
|
||||
assert plan.is_recency_query is False
|
||||
|
||||
|
||||
# ── mentioned_document_ids cross-turn drain ────────────────────────────
|
||||
|
||||
|
||||
class TestKnowledgePriorityMentionDrain:
|
||||
"""Regression tests for the cross-turn ``mentioned_document_ids`` drain.
|
||||
|
||||
The compiled-agent cache reuses a single :class:`KnowledgePriorityMiddleware`
|
||||
instance across turns of the same thread. ``mentioned_document_ids``
|
||||
can therefore enter the middleware via two paths:
|
||||
|
||||
1. The constructor closure (``__init__(mentioned_document_ids=...)``) —
|
||||
seeded by the cache-miss build on turn 1.
|
||||
2. ``runtime.context.mentioned_document_ids`` — supplied freshly per
|
||||
turn by the streaming task.
|
||||
|
||||
Without the drain fix, an empty ``runtime.context.mentioned_document_ids``
|
||||
on turn 2 would fall through to the closure (because ``[]`` is falsy in
|
||||
Python) and replay turn 1's mentions. This class pins down the
|
||||
correct behaviour: the runtime path is authoritative even when empty,
|
||||
and the closure is drained the first time the runtime path fires so
|
||||
no later turn can ever resurrect stale state.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_runtime(mention_ids: list[int]):
|
||||
"""Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
return SimpleNamespace(
|
||||
context=SimpleNamespace(mentioned_document_ids=mention_ids),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _planner_llm() -> "FakeLLM":
|
||||
# Planner returns a stable, non-recency plan so we always land in
|
||||
# the hybrid-search branch (where ``fetch_mentioned_documents`` is
|
||||
# invoked alongside the main search).
|
||||
return FakeLLM(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "follow up question",
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
"is_recency_query": False,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch):
|
||||
"""Turn 1 with mentions in BOTH closure and runtime context: the
|
||||
runtime path wins AND the closure is drained so a future turn
|
||||
cannot replay it.
|
||||
"""
|
||||
fetched_ids: list[list[int]] = []
|
||||
|
||||
async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
|
||||
fetched_ids.append(list(document_ids))
|
||||
return []
|
||||
|
||||
async def fake_search_knowledge_base(**_kwargs):
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"fetch_mentioned_documents",
|
||||
fake_fetch_mentioned_documents,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
|
||||
middleware = KnowledgePriorityMiddleware(
|
||||
llm=self._planner_llm(),
|
||||
search_space_id=42,
|
||||
mentioned_document_ids=[1, 2, 3],
|
||||
)
|
||||
|
||||
await middleware.abefore_agent(
|
||||
{"messages": [HumanMessage(content="what is in those docs?")]},
|
||||
runtime=self._make_runtime([1, 2, 3]),
|
||||
)
|
||||
|
||||
assert fetched_ids == [[1, 2, 3]], (
|
||||
"runtime.context mentions must be the source of truth on turn 1"
|
||||
)
|
||||
assert middleware.mentioned_document_ids == [], (
|
||||
"closure must be drained the first time the runtime path fires "
|
||||
"so no later turn can replay stale mentions"
|
||||
)
|
||||
|
||||
async def test_empty_runtime_context_does_not_replay_closure_mentions(
|
||||
self, monkeypatch
|
||||
):
|
||||
"""Regression: turn 2 with NO mentions must not surface turn 1's
|
||||
mentions from the constructor closure.
|
||||
|
||||
Before the fix, ``if ctx_mentions:`` treated an empty list as
|
||||
absent and fell through to ``elif self.mentioned_document_ids:``,
|
||||
replaying turn 1's mentions. This test pins down the corrected
|
||||
behaviour.
|
||||
"""
|
||||
fetched_ids: list[list[int]] = []
|
||||
|
||||
async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
|
||||
fetched_ids.append(list(document_ids))
|
||||
return []
|
||||
|
||||
async def fake_search_knowledge_base(**_kwargs):
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"fetch_mentioned_documents",
|
||||
fake_fetch_mentioned_documents,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
|
||||
# Simulate a cached middleware instance whose closure was seeded
|
||||
# by a previous turn's cache-miss build (mentions=[1,2,3]).
|
||||
middleware = KnowledgePriorityMiddleware(
|
||||
llm=self._planner_llm(),
|
||||
search_space_id=42,
|
||||
mentioned_document_ids=[1, 2, 3],
|
||||
)
|
||||
|
||||
# Turn 2: streaming task supplies an EMPTY mention list (no
|
||||
# mentions on this follow-up turn).
|
||||
await middleware.abefore_agent(
|
||||
{"messages": [HumanMessage(content="what about the next steps?")]},
|
||||
runtime=self._make_runtime([]),
|
||||
)
|
||||
|
||||
assert fetched_ids == [], (
|
||||
"fetch_mentioned_documents must NOT be called when the runtime "
|
||||
"context says there are no mentions for this turn"
|
||||
)
|
||||
|
||||
async def test_legacy_path_fires_only_when_runtime_context_absent(
|
||||
self, monkeypatch
|
||||
):
|
||||
"""Backward-compat: if a caller doesn't supply runtime.context (old
|
||||
non-streaming code path), the closure-injected mentions are still
|
||||
honoured exactly once and then drained.
|
||||
"""
|
||||
fetched_ids: list[list[int]] = []
|
||||
|
||||
async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
|
||||
fetched_ids.append(list(document_ids))
|
||||
return []
|
||||
|
||||
async def fake_search_knowledge_base(**_kwargs):
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"fetch_mentioned_documents",
|
||||
fake_fetch_mentioned_documents,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ks,
|
||||
"search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
|
||||
middleware = KnowledgePriorityMiddleware(
|
||||
llm=self._planner_llm(),
|
||||
search_space_id=42,
|
||||
mentioned_document_ids=[7, 8],
|
||||
)
|
||||
|
||||
# First call: no runtime → legacy path uses the closure.
|
||||
await middleware.abefore_agent(
|
||||
{"messages": [HumanMessage(content="initial question")]},
|
||||
runtime=None,
|
||||
)
|
||||
# Second call: still no runtime — closure already drained, so no replay.
|
||||
await middleware.abefore_agent(
|
||||
{"messages": [HumanMessage(content="follow up")]},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
assert fetched_ids == [[7, 8]], (
|
||||
"legacy path must honour the closure exactly once and then drain it"
|
||||
)
|
||||
assert middleware.mentioned_document_ids == []
|
||||
Loading…
Add table
Add a link
Reference in a new issue