diff --git a/surfsense_backend/app/agents/new_chat/middleware/file_intent.py b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py
new file mode 100644
index 000000000..e264a939c
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py
@@ -0,0 +1,253 @@
+"""Semantic file-intent routing middleware for new chat turns.
+
+This middleware classifies the latest human turn into a small intent set:
+- chat_only
+- file_write
+- file_read
+
+For ``file_write`` turns it injects a strict system contract so the model
+uses filesystem tools before claiming success, and provides a deterministic
+fallback path when no filename is specified by the user.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import re
+from datetime import UTC, datetime
+from enum import StrEnum
+from typing import Any
+
+from langchain.agents.middleware import AgentMiddleware, AgentState
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from langgraph.runtime import Runtime
+from pydantic import BaseModel, Field, ValidationError
+
+logger = logging.getLogger(__name__)
+
+
+class FileOperationIntent(StrEnum):
+    CHAT_ONLY = "chat_only"
+    FILE_WRITE = "file_write"
+    FILE_READ = "file_read"
+
+
+class FileIntentPlan(BaseModel):
+    intent: FileOperationIntent = Field(
+        description="Primary user intent for this turn."
+    )
+    confidence: float = Field(
+        ge=0.0,
+        le=1.0,
+        default=0.5,
+        description="Model confidence in the selected intent.",
+    )
+    suggested_filename: str | None = Field(
+        default=None,
+        description="Optional filename (e.g. notes.md) inferred from user request.",
+    )
+
+
+def _extract_text_from_message(message: BaseMessage) -> str:
+    # Normalizes string, content-block list, and fallback repr forms to plain text.
+    content = getattr(message, "content", "")
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        parts: list[str] = []
+        for item in content:
+            if isinstance(item, str):
+                parts.append(item)
+            elif isinstance(item, dict) and item.get("type") == "text":
+                parts.append(str(item.get("text", "")))
+        return "\n".join(part for part in parts if part)
+    return str(content)
+
+
+def _extract_json_payload(text: str) -> str:
+    # Prefer a fenced ```json``` block; otherwise take the outermost {...} span.
+    stripped = text.strip()
+    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
+    if fenced:
+        return fenced.group(1)
+    start = stripped.find("{")
+    end = stripped.rfind("}")
+    if start != -1 and end != -1 and end > start:
+        return stripped[start : end + 1]
+    return stripped
+
+
+def _sanitize_filename(value: str) -> str:
+    # Strip path separators and shell-hostile characters, collapse whitespace.
+    name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
+    name = re.sub(r"\s+", "-", name)
+    name = name.strip("._-")
+    if not name:
+        name = "note"
+    if len(name) > 80:
+        name = name[:80].rstrip("-_.")
+    return name
+
+
+def _infer_text_file_extension(user_text: str) -> str:
+    # Keyword heuristic over the raw user text; defaults to Markdown.
+    lowered = user_text.lower()
+    if any(token in lowered for token in ("json", ".json")):
+        return ".json"
+    if any(token in lowered for token in ("yaml", "yml", ".yaml", ".yml")):
+        return ".yaml"
+    if any(token in lowered for token in ("csv", ".csv")):
+        return ".csv"
+    if any(token in lowered for token in ("python", ".py")):
+        return ".py"
+    if any(token in lowered for token in ("typescript", ".ts", ".tsx")):
+        return ".ts"
+    if any(token in lowered for token in ("javascript", ".js", ".mjs", ".cjs")):
+        return ".js"
+    if any(token in lowered for token in ("html", ".html")):
+        return ".html"
+    if any(token in lowered for token in ("css", ".css")):
+        return ".css"
+    if any(token in lowered for token in ("sql", ".sql")):
+        return ".sql"
+    if any(token in lowered for token in ("toml", ".toml")):
+        return ".toml"
+    if any(token in lowered for token in ("ini", ".ini")):
+        return ".ini"
+    if any(token in lowered for token in ("xml", ".xml")):
+        return ".xml"
+    if any(token in lowered for token in ("markdown", ".md", "readme")):
+        return ".md"
+    return ".md"
+
+
+def _fallback_path(suggested_filename: str | None, *, user_text: str) -> str:
+    # Returns an absolute default path; .txt is normalized to .md by policy.
+    default_extension = _infer_text_file_extension(user_text)
+    if suggested_filename:
+        sanitized = _sanitize_filename(suggested_filename)
+        if sanitized.lower().endswith(".txt"):
+            sanitized = f"{sanitized[:-4]}.md"
+        if "." not in sanitized:
+            sanitized = f"{sanitized}{default_extension}"
+        return f"/{sanitized}"
+    return f"/notes{default_extension}"
+
+
+def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str:
+    return (
+        "Classify the latest user request into a filesystem intent for an AI agent.\n"
+        "Return JSON only with this exact schema:\n"
+        '{"intent":"chat_only|file_write|file_read","confidence":0.0,"suggested_filename":"string or null"}\n\n'
+        "Rules:\n"
+        "- Use semantic intent, not literal keywords.\n"
+        "- file_write: user asks to create/save/write/update/edit content as a file.\n"
+        "- file_read: user asks to open/read/list/search existing files.\n"
+        "- chat_only: conversational/analysis responses without required file operations.\n"
+        "- For file_write, choose a concise semantic suggested_filename and match the requested format.\n"
+        "- Use extensions that match user intent (e.g. .md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql).\n"
+        "- Do not use .txt; prefer .md for generic text notes.\n"
+        "- Do not include dates or timestamps in suggested_filename unless explicitly requested.\n"
+        "- Never include markdown or explanation.\n\n"
+        f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
+        f"Latest user message:\n{user_text}"
+    )
+
+
+def _build_recent_conversation(messages: list[BaseMessage], *, max_messages: int = 6) -> str:
+    # Compact one-line-per-message transcript, capped at 280 chars per turn.
+    rows: list[str] = []
+    for msg in messages[-max_messages:]:
+        role = "user" if isinstance(msg, HumanMessage) else "assistant"
+        text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip()
+        if text:
+            rows.append(f"{role}: {text[:280]}")
+    return "\n".join(rows)
+
+
+class FileIntentMiddleware(AgentMiddleware):  # type: ignore[type-arg]
+    """Classify file intent and inject a strict file-write contract."""
+
+    tools = ()
+
+    def __init__(self, *, llm: BaseChatModel | None = None) -> None:
+        self.llm = llm
+
+    async def _classify_intent(
+        self, *, messages: list[BaseMessage], user_text: str
+    ) -> FileIntentPlan:
+        # Without a classifier LLM we conservatively treat the turn as chat-only.
+        if self.llm is None:
+            return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
+
+        prompt = _build_classifier_prompt(
+            recent_conversation=_build_recent_conversation(messages),
+            user_text=user_text,
+        )
+        try:
+            response = await self.llm.ainvoke(
+                [HumanMessage(content=prompt)],
+                config={"tags": ["surfsense:internal"]},
+            )
+            payload = json.loads(_extract_json_payload(_extract_text_from_message(response)))
+            plan = FileIntentPlan.model_validate(payload)
+            return plan
+        except (json.JSONDecodeError, ValidationError, ValueError) as exc:
+            logger.warning("File intent classifier returned invalid output: %s", exc)
+        except Exception as exc:  # pragma: no cover - defensive fallback
+            logger.warning("File intent classifier failed: %s", exc)
+
+        return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
+
+    async def abefore_agent(  # type: ignore[override]
+        self,
+        state: AgentState,
+        runtime: Runtime[Any],
+    ) -> dict[str, Any] | None:
+        del runtime
+        messages = state.get("messages") or []
+        if not messages:
+            return None
+
+        last_human: HumanMessage | None = None
+        for msg in reversed(messages):
+            if isinstance(msg, HumanMessage):
+                last_human = msg
+                break
+        if last_human is None:
+            return None
+
+        user_text = _extract_text_from_message(last_human).strip()
+        if not user_text:
+            return None
+
+        plan = await self._classify_intent(messages=messages, user_text=user_text)
+        suggested_path = _fallback_path(plan.suggested_filename, user_text=user_text)
+        contract = {
+            "intent": plan.intent.value,
+            "confidence": plan.confidence,
+            "suggested_path": suggested_path,
+            "timestamp": datetime.now(UTC).isoformat(),
+            "turn_id": state.get("turn_id", ""),
+        }
+
+        if plan.intent != FileOperationIntent.FILE_WRITE:
+            return {"file_operation_contract": contract}
+
+        # NOTE(review): the tag strings below appeared stripped in the submitted
+        # text; restored so the injected message carries the contract marker that
+        # test_file_write_intent_injects_contract_message asserts on.
+        contract_msg = SystemMessage(
+            content=(
+                "<file_operation_contract>\n"
+                "This turn intent is file_write.\n"
+                f"Suggested default path: {suggested_path}\n"
+                "Rules:\n"
+                "- You MUST call write_file or edit_file before claiming success.\n"
+                "- If no path is provided by the user, use the suggested default path.\n"
+                "- Do not claim a file was created/updated unless tool output confirms it.\n"
+                "- If the write/edit fails, clearly report failure instead of success.\n"
+                "- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n"
+                "- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n"
+                "</file_operation_contract>"
+            )
+        )
+
+        # Insert just before the latest human turn so it applies to this request.
+        new_messages = list(messages)
+        insert_at = max(len(new_messages) - 1, 0)
+        new_messages.insert(insert_at, contract_msg)
+        return {"messages": new_messages, "file_operation_contract": contract}
+
diff --git a/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py b/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py
new file mode 100644
index 000000000..68876dfeb
--- /dev/null
+++ b/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py
@@ -0,0 +1,116 @@
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage
+
+from app.agents.new_chat.middleware.file_intent import (
+    FileIntentMiddleware,
+    FileOperationIntent,
+)
+
+pytestmark = pytest.mark.unit
+
+
+class _FakeLLM:  # minimal async stand-in for the classifier chat model
+    def __init__(self, response_text: str):
+        self._response_text = response_text
+
+    async def ainvoke(self, *_args, **_kwargs):  # same call shape middleware uses
+        return AIMessage(content=self._response_text)
+
+
+@pytest.mark.asyncio
+async def test_file_write_intent_injects_contract_message():
+    llm = _FakeLLM(
+        '{"intent":"file_write","confidence":0.93,"suggested_filename":"ideas.md"}'
+    )
+    middleware = FileIntentMiddleware(llm=llm)
+    state = {
+        "messages": [HumanMessage(content="Create another random note for me")],
+        "turn_id": "123:456",
+    }
+
+    result = await middleware.abefore_agent(state, runtime=None)  # type: ignore[arg-type]
+
+    assert result is not None
+    contract = result["file_operation_contract"]
+    assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
+    assert contract["suggested_path"] == "/ideas.md"
+    assert contract["turn_id"] == "123:456"
+    assert any(  # the injected system message must carry the contract marker
+        "file_operation_contract" in str(msg.content)
+        for msg in result["messages"]
+        if hasattr(msg, "content")
+    )
+
+
+@pytest.mark.asyncio
+async def test_non_write_intent_does_not_inject_contract_message():
+    llm = _FakeLLM(
+        '{"intent":"file_read","confidence":0.88,"suggested_filename":null}'
+    )
+    middleware = FileIntentMiddleware(llm=llm)
+    original_messages = [HumanMessage(content="Read /notes.md")]
+    state = {"messages": original_messages, "turn_id": "abc:def"}
+
+    result = await middleware.abefore_agent(state, runtime=None)  # type: ignore[arg-type]
+
+    assert result is not None
+    assert result["file_operation_contract"]["intent"] == FileOperationIntent.FILE_READ.value
+    assert "messages" not in result  # non-write turns must leave the transcript untouched
+
+
+@pytest.mark.asyncio
+async def test_file_write_null_filename_uses_semantic_default_path():
+    llm = _FakeLLM(
+        '{"intent":"file_write","confidence":0.74,"suggested_filename":null}'
+    )
+    middleware = FileIntentMiddleware(llm=llm)
+    state = {
+        "messages": [HumanMessage(content="create a random markdown file")],
+        "turn_id": "turn:1",
+    }
+
+    result = await middleware.abefore_agent(state, runtime=None)  # type: ignore[arg-type]
+
+    assert result is not None
+    contract = result["file_operation_contract"]
+    assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
+    assert contract["suggested_path"] == "/notes.md"
+
+
+@pytest.mark.asyncio
+async def test_file_write_null_filename_infers_json_extension():
+    llm = _FakeLLM(
+        '{"intent":"file_write","confidence":0.71,"suggested_filename":null}'
+    )
+    middleware = FileIntentMiddleware(llm=llm)
+    state = {
+        "messages": [HumanMessage(content="create a sample json config file")],
+        "turn_id": "turn:2",
+    }
+
+    result = await middleware.abefore_agent(state, runtime=None)  # type: ignore[arg-type]
+
+    assert result is not None
+    contract = result["file_operation_contract"]
+    assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
+    assert contract["suggested_path"] == "/notes.json"
+
+
+@pytest.mark.asyncio
+async def test_file_write_txt_suggestion_is_normalized_to_markdown():
+    llm = _FakeLLM(
+        '{"intent":"file_write","confidence":0.82,"suggested_filename":"random.txt"}'
+    )
+    middleware = FileIntentMiddleware(llm=llm)
+    state = {
+        "messages": [HumanMessage(content="create a random file")],
+        "turn_id": "turn:3",
+    }
+
+    result = await middleware.abefore_agent(state, runtime=None)  # type: ignore[arg-type]
+
+    assert result is not None
+    contract = result["file_operation_contract"]
+    assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
+    assert contract["suggested_path"] == "/random.md"
+