feat: enhance chat event streaming by tracking active tool depth to prevent inner-tool LLM token leakage

This commit is contained in:
Anish Sarkar 2026-02-19 13:11:18 +05:30
parent ce110faa5a
commit f3ec48fb00

View file

@ -226,6 +226,7 @@ async def _stream_agent_events(
last_active_step_title: str = initial_step_title last_active_step_title: str = initial_step_title
last_active_step_items: list[str] = initial_step_items or [] last_active_step_items: list[str] = initial_step_items or []
just_finished_tool: bool = False just_finished_tool: bool = False
active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool
def next_thinking_step_id() -> str: def next_thinking_step_id() -> str:
nonlocal thinking_step_counter nonlocal thinking_step_counter
@ -250,6 +251,8 @@ async def _stream_agent_events(
event_type = event.get("event", "") event_type = event.get("event", "")
if event_type == "on_chat_model_stream": if event_type == "on_chat_model_stream":
if active_tool_depth > 0:
continue # Suppress inner-tool LLM tokens from leaking into chat
chunk = event.get("data", {}).get("chunk") chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content"): if chunk and hasattr(chunk, "content"):
content = chunk.content content = chunk.content
@ -269,6 +272,7 @@ async def _stream_agent_events(
accumulated_text += content accumulated_text += content
elif event_type == "on_tool_start": elif event_type == "on_tool_start":
active_tool_depth += 1
tool_name = event.get("name", "unknown_tool") tool_name = event.get("name", "unknown_tool")
run_id = event.get("run_id", "") run_id = event.get("run_id", "")
tool_input = event.get("data", {}).get("input", {}) tool_input = event.get("data", {}).get("input", {})
@ -428,6 +432,7 @@ async def _stream_agent_events(
) )
elif event_type == "on_tool_end": elif event_type == "on_tool_end":
active_tool_depth = max(0, active_tool_depth - 1)
run_id = event.get("run_id", "") run_id = event.get("run_id", "")
tool_name = event.get("name", "unknown_tool") tool_name = event.get("name", "unknown_tool")
raw_output = event.get("data", {}).get("output", "") raw_output = event.get("data", {}).get("output", "")
@ -1309,4 +1314,4 @@ async def stream_resume_chat(
yield streaming_service.format_done() yield streaming_service.format_done()
finally: finally:
await clear_ai_responding(session, chat_id) await clear_ai_responding(session, chat_id)