2026-04-23 15:03:32 +05:30
""" Semantic file-intent routing middleware for new chat turns.
This middleware classifies the latest human turn into a small intent set:
- chat_only
- file_write
- file_read
For ``file_write`` turns it injects a strict system contract so the model
uses filesystem tools before claiming success, and provides a deterministic
fallback path when no filename is specified by the user.
"""
from __future__ import annotations
import json
import logging
import re
from datetime import UTC , datetime
from enum import StrEnum
from typing import Any
from langchain . agents . middleware import AgentMiddleware , AgentState
from langchain_core . language_models import BaseChatModel
2026-04-28 04:32:52 -07:00
from langchain_core . messages import AIMessage , BaseMessage , HumanMessage , SystemMessage
2026-04-23 15:03:32 +05:30
from langgraph . runtime import Runtime
from pydantic import BaseModel , Field , ValidationError
logger = logging . getLogger ( __name__ )
class FileOperationIntent ( StrEnum ) :
CHAT_ONLY = " chat_only "
FILE_WRITE = " file_write "
FILE_READ = " file_read "
class FileIntentPlan(BaseModel):
    """Structured classifier verdict for a single user turn.

    Only ``intent`` is required; every other field has a default so a minimal
    classifier reply still validates.
    """

    intent: FileOperationIntent = Field(
        description="Primary user intent for this turn.",
    )
    # Bounded to the closed interval [0.0, 1.0] by the validators below.
    confidence: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Model confidence in the selected intent.",
    )
    suggested_filename: str | None = Field(
        default=None,
        description="Optional filename (e.g. notes.md) inferred from user request.",
    )
    suggested_directory: str | None = Field(
        default=None,
        description=(
            "Optional directory path (e.g. /reports/q2 or reports/q2) "
            "inferred from user request."
        ),
    )
    suggested_path: str | None = Field(
        default=None,
        description=(
            "Optional full file path (e.g. /reports/q2/summary.md). "
            "If present, this takes precedence over "
            "suggested_directory + suggested_filename."
        ),
    )
2026-04-23 15:03:32 +05:30
def _extract_text_from_message ( message : BaseMessage ) - > str :
content = getattr ( message , " content " , " " )
if isinstance ( content , str ) :
return content
if isinstance ( content , list ) :
parts : list [ str ] = [ ]
for item in content :
if isinstance ( item , str ) :
parts . append ( item )
elif isinstance ( item , dict ) and item . get ( " type " ) == " text " :
parts . append ( str ( item . get ( " text " , " " ) ) )
return " \n " . join ( part for part in parts if part )
return str ( content )
def _extract_json_payload ( text : str ) - > str :
stripped = text . strip ( )
fenced = re . search ( r " ```(?:json)? \ s*( \ { .*? \ }) \ s*``` " , stripped , re . DOTALL )
if fenced :
return fenced . group ( 1 )
start = stripped . find ( " { " )
end = stripped . rfind ( " } " )
if start != - 1 and end != - 1 and end > start :
return stripped [ start : end + 1 ]
return stripped
def _sanitize_filename ( value : str ) - > str :
name = re . sub ( r " [ \\ /:*? \" <>|]+ " , " _ " , value ) . strip ( )
name = re . sub ( r " \ s+ " , " - " , name )
name = name . strip ( " ._- " )
if not name :
name = " note "
if len ( name ) > 80 :
name = name [ : 80 ] . rstrip ( " -_. " )
return name
2026-04-24 05:59:21 +05:30
def _sanitize_path_segment ( value : str ) - > str :
segment = re . sub ( r " [ \\ /:*? \" <>|]+ " , " _ " , value ) . strip ( )
segment = re . sub ( r " \ s+ " , " _ " , segment )
segment = segment . strip ( " ._- " )
return segment
def _normalize_directory(value: str) -> str:
    """Normalize a directory string to ``seg/seg/...`` form.

    Backslashes are treated as separators, each component is cleaned with
    ``_sanitize_path_segment``, and components that end up empty are dropped.

    Args:
        value: Raw directory text; may be empty.

    Returns:
        The cleaned components joined by ``/`` with no leading or trailing
        slash, or ``""`` when nothing usable remains.
    """
    trimmed = value.strip().replace("\\", "/").strip("/")
    if not trimmed:
        return ""

    segments: list[str] = []
    for chunk in trimmed.split("/"):
        if not chunk.strip():
            continue
        cleaned = _sanitize_path_segment(chunk)
        if cleaned:
            segments.append(cleaned)
    return "/".join(segments)
def _normalize_file_path(value: str) -> str:
    """Normalize a full path suggestion to a rooted, sanitized path.

    Args:
        value: Raw path text; backslashes are accepted as separators.

    Returns:
        ``"/seg/.../seg"`` built from the cleaned components. A trailing
        ``/`` is kept when the input ended with a separator, so callers can
        tell a directory-style suggestion from a file path. Returns ``""``
        when no usable component remains.
    """
    candidate = value.strip().replace("\\", "/").strip()
    if not candidate:
        return ""

    # Record this before the surrounding slashes are stripped away.
    keep_trailing_slash = candidate.endswith("/")

    cleaned = (
        _sanitize_path_segment(piece)
        for piece in candidate.strip("/").split("/")
        if piece.strip()
    )
    segments = [segment for segment in cleaned if segment]
    if not segments:
        return ""

    joined = "/".join(segments)
    return f"/{joined}/" if keep_trailing_slash else f"/{joined}"
def _infer_directory_from_user_text(user_text: str) -> str | None:
    """Guess a target directory from phrasing such as "in the X folder".

    Two patterns are tried in order against the lower-cased text: an explicit
    "in/inside/under [the] <name> folder" phrase, then a bare
    "in/inside/under <name>". A pattern whose capture is just an article, or
    that normalizes to nothing, is skipped.

    Args:
        user_text: The latest user message.

    Returns:
        The normalized directory, or ``None`` when no pattern yields one.
    """
    articles = {"the", "a", "an"}
    folder_phrase = (
        r"\b(?:in|inside|under)\s+(?:the\s+)?([a-zA-Z0-9 _\-/]+?)\s+folder\b"
    )
    bare_phrase = r"\b(?:in|inside|under)\s+([a-zA-Z0-9 _\-/]+?)\b"

    haystack = user_text.lower()
    for expression in (folder_phrase, bare_phrase):
        hit = re.search(expression, haystack, flags=re.IGNORECASE)
        if hit is None:
            continue
        phrase = hit.group(1).strip()
        if phrase in articles:
            continue
        directory = _normalize_directory(phrase)
        if directory:
            return directory
    return None
def _fallback_path(
    suggested_filename: str | None,
    *,
    suggested_directory: str | None = None,
    suggested_path: str | None = None,
    user_text: str,
) -> str:
    """Resolve a deterministic default path for a file-write turn.

    Precedence:
      1. ``suggested_path``: used as-is after normalization, or combined with
         the filename when it ends in ``/``.
      2. ``suggested_directory``, else a directory inferred from
         ``user_text``, combined with the filename.
      3. The filename at the root.

    The filename is sanitized, ``.txt`` is rewritten to ``.md``, a missing
    extension defaults to ``.md``, and a missing filename to ``notes.md``.

    Args:
        suggested_filename: Classifier-proposed filename, if any.
        suggested_directory: Classifier-proposed directory, if any.
        suggested_path: Classifier-proposed full path, if any.
        user_text: Latest user message, used for directory inference.

    Returns:
        A path beginning with ``/``.
    """
    directory_from_text = _infer_directory_from_user_text(user_text)

    filename = _sanitize_filename(suggested_filename) if suggested_filename else ""
    if filename.lower().endswith(".txt"):
        filename = f"{filename[:-4]}.md"
    if not filename:
        filename = "notes.md"
    elif "." not in filename:
        filename = f"{filename}.md"

    explicit_path = _normalize_file_path(suggested_path) if suggested_path else ""
    if explicit_path:
        if not explicit_path.endswith("/"):
            return explicit_path
        # Directory-style suggestion: place the filename inside it.
        return f"{explicit_path.rstrip('/')}/{filename}"

    directory = _normalize_directory(suggested_directory or "") or directory_from_text
    if not directory:
        return f"/{filename}"
    return f"/{directory}/{filename}"
2026-04-23 15:03:32 +05:30
def _build_classifier_prompt ( * , recent_conversation : str , user_text : str ) - > str :
return (
" Classify the latest user request into a filesystem intent for an AI agent. \n "
" Return JSON only with this exact schema: \n "
2026-04-24 05:59:21 +05:30
' { " intent " : " chat_only|file_write|file_read " , " confidence " :0.0, " suggested_filename " : " string or null " , " suggested_directory " : " string or null " , " suggested_path " : " string or null " } \n \n '
2026-04-23 15:03:32 +05:30
" Rules: \n "
" - Use semantic intent, not literal keywords. \n "
" - file_write: user asks to create/save/write/update/edit content as a file. \n "
" - file_read: user asks to open/read/list/search existing files. \n "
" - chat_only: conversational/analysis responses without required file operations. \n "
" - For file_write, choose a concise semantic suggested_filename and match the requested format. \n "
2026-04-24 05:59:21 +05:30
" - If the user mentions a folder/directory, populate suggested_directory. \n "
" - If user specifies an explicit full path, populate suggested_path. \n "
2026-04-23 15:03:32 +05:30
" - Use extensions that match user intent (e.g. .md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql). \n "
" - Do not use .txt; prefer .md for generic text notes. \n "
" - Do not include dates or timestamps in suggested_filename unless explicitly requested. \n "
" - Never include markdown or explanation. \n \n "
f " Recent conversation: \n { recent_conversation or ' (none) ' } \n \n "
f " Latest user message: \n { user_text } "
)
2026-04-27 14:04:50 -07:00
def _build_recent_conversation (
messages : list [ BaseMessage ] , * , max_messages : int = 6
) - > str :
2026-04-23 15:03:32 +05:30
rows : list [ str ] = [ ]
2026-04-28 04:32:52 -07:00
filtered : list [ tuple [ str , BaseMessage ] ] = [ ]
for msg in messages :
role : str | None = None
if isinstance ( msg , HumanMessage ) :
role = " user "
elif isinstance ( msg , AIMessage ) :
if getattr ( msg , " tool_calls " , None ) :
continue
role = " assistant "
else :
continue
filtered . append ( ( role , msg ) )
for role , msg in filtered [ - max_messages : ] :
2026-04-23 15:03:32 +05:30
text = re . sub ( r " \ s+ " , " " , _extract_text_from_message ( msg ) ) . strip ( )
if text :
rows . append ( f " { role } : { text [ : 280 ] } " )
return " \n " . join ( rows )
class FileIntentMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Classify file intent and inject a strict file-write contract.

    Before the agent runs, the latest human message is classified. The result
    is always returned under the ``file_operation_contract`` state key; for
    ``file_write`` turns a ``SystemMessage`` carrying the contract rules is
    also inserted into the message list.
    """

    tools = ()

    def __init__(self, *, llm: BaseChatModel | None = None) -> None:
        """Store the classifier model.

        Args:
            llm: Chat model used for classification. When ``None`` every turn
                is classified as ``chat_only`` with confidence ``0.0``.
        """
        self.llm = llm

    async def _classify_intent(
        self, *, messages: list[BaseMessage], user_text: str
    ) -> FileIntentPlan:
        """Ask the model for a ``FileIntentPlan``, falling back to chat-only.

        Args:
            messages: Conversation history used to build the context excerpt.
            user_text: Text of the latest human message.

        Returns:
            The validated plan, or a ``chat_only`` plan with confidence ``0.0``
            when no model is configured, the call fails, or the reply cannot
            be parsed/validated. Failures are logged, never raised.
        """
        if self.llm is None:
            return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)

        prompt = _build_classifier_prompt(
            recent_conversation=_build_recent_conversation(messages),
            user_text=user_text,
        )
        try:
            response = await self.llm.ainvoke(
                [HumanMessage(content=prompt)],
                config={"tags": ["surfsense:internal"]},
            )
            payload = json.loads(
                _extract_json_payload(_extract_text_from_message(response))
            )
            return FileIntentPlan.model_validate(payload)
        except (json.JSONDecodeError, ValidationError, ValueError) as exc:
            logger.warning("File intent classifier returned invalid output: %s", exc)
        except Exception as exc:  # pragma: no cover - defensive fallback
            # Deliberate best-effort: classification must never break the turn.
            logger.warning("File intent classifier failed: %s", exc)

        return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Classify the latest human turn and attach the file-operation contract.

        Args:
            state: Agent state; ``messages`` and ``turn_id`` are read.
            runtime: Unused.

        Returns:
            ``None`` when there is no non-empty human message. Otherwise a
            state update containing ``file_operation_contract`` and, for
            ``file_write`` turns, ``messages`` with the contract
            ``SystemMessage`` inserted directly before the latest human
            message.
        """
        del runtime

        messages = state.get("messages") or []
        # Locate the newest human message by position so the contract can be
        # placed right before it even if other messages follow it.
        last_human_index = next(
            (
                index
                for index in range(len(messages) - 1, -1, -1)
                if isinstance(messages[index], HumanMessage)
            ),
            None,
        )
        if last_human_index is None:
            return None

        user_text = _extract_text_from_message(messages[last_human_index]).strip()
        if not user_text:
            return None

        plan = await self._classify_intent(messages=messages, user_text=user_text)
        suggested_path = _fallback_path(
            plan.suggested_filename,
            suggested_directory=plan.suggested_directory,
            suggested_path=plan.suggested_path,
            user_text=user_text,
        )
        contract = {
            "intent": plan.intent.value,
            "confidence": plan.confidence,
            "suggested_path": suggested_path,
            "timestamp": datetime.now(UTC).isoformat(),
            "turn_id": state.get("turn_id", ""),
        }

        if plan.intent != FileOperationIntent.FILE_WRITE:
            return {"file_operation_contract": contract}

        contract_msg = SystemMessage(
            content=(
                "<file_operation_contract>\n"
                "This turn intent is file_write.\n"
                f"Suggested default path: {suggested_path}\n"
                "Rules:\n"
                "- You MUST call write_file or edit_file before claiming success.\n"
                "- If no path is provided by the user, use the suggested default path.\n"
                "- Do not claim a file was created/updated unless tool output confirms it.\n"
                "- If the write/edit fails, clearly report failure instead of success.\n"
                "- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n"
                "- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n"
                "</file_operation_contract>"
            )
        )

        # Insert just before the latest human turn so it applies to this request.
        # NOTE(review): how this list is merged back into state depends on the
        # messages reducer in use; confirm the inserted position is preserved.
        new_messages = list(messages)
        new_messages.insert(last_human_index, contract_msg)
        return {"messages": new_messages, "file_operation_contract": contract}