Add LangGraph handlers for chat model, chain, tool, and custom events.

This commit is contained in:
CREDO23 2026-05-06 20:08:48 +02:00
parent 7581a7c9c3
commit ee16e1d5f9
8 changed files with 586 additions and 0 deletions

View file

@ -0,0 +1,3 @@
"""LangGraph stream handlers by event kind."""
from __future__ import annotations

View file

@ -0,0 +1,23 @@
"""Close open text when a LangGraph chain or agent node finishes."""
from __future__ import annotations
from collections.abc import Iterator
from typing import Any
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
def iter_chain_end_frames(
    _event: dict[str, Any],
    *,
    state: AgentEventRelayState,
    streaming_service: Any,
    content_builder: Any | None,
) -> Iterator[str]:
    """Emit a text-end frame when a chain/agent node finishes with text open.

    Mutates ``state.current_text_id`` back to ``None`` after closing; yields
    nothing when no text stream is open.
    """
    open_text_id = state.current_text_id
    if open_text_id is None:
        return
    yield streaming_service.format_text_end(open_text_id)
    if content_builder is not None:
        content_builder.on_text_end(open_text_id)
    state.current_text_id = None

View file

@ -0,0 +1,149 @@
"""Chat model stream: text, reasoning, and tool-call chunk SSE."""
from __future__ import annotations
from collections.abc import Iterator
from typing import Any
from app.tasks.chat.streaming.helpers.chunk_parts import extract_chunk_parts
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.thinking_step_completion import (
complete_active_thinking_step,
)
def iter_chat_model_stream_frames(
    event: dict[str, Any],
    *,
    state: AgentEventRelayState,
    streaming_service: Any,
    content_builder: Any | None,
    step_prefix: str,
) -> Iterator[str]:
    """Yield SSE frames for one chat-model stream chunk.

    A chunk may carry a reasoning delta, a text delta, and/or tool-call
    chunks; each kind is routed to its own SSE stream.  Open reasoning/text
    streams are closed whenever the chunk kind switches, and every emission
    is mirrored into ``content_builder`` when one is provided.

    Args:
        event: Raw LangGraph event; the model chunk sits at ``data.chunk``.
        state: Mutable relay state (open stream ids, thinking-step
            bookkeeping, tool-call index metadata) shared across handlers.
        streaming_service: Formats SSE frames and generates stream ids.
        content_builder: Optional accumulator mirroring streamed content.
        step_prefix: Unused here; kept so all handlers share one signature.

    Yields:
        Formatted SSE frame strings (possibly none).
    """
    # While a tool is running, its model traffic is not surfaced; the tool
    # handlers own the stream at that point.
    if state.active_tool_depth > 0:
        return
    # Model calls tagged as internal are never shown to the client.
    if "surfsense:internal" in event.get("tags", []):
        return
    chunk = event.get("data", {}).get("chunk")
    if not chunk:
        return
    parts = extract_chunk_parts(chunk)
    reasoning_delta = parts["reasoning"]
    text_delta = parts["text"]
    # --- Reasoning stream (parity-v2 mode only) ---
    if state.parity_v2 and reasoning_delta:
        # Reasoning interrupts any open text stream.
        if state.current_text_id is not None:
            yield streaming_service.format_text_end(state.current_text_id)
            if content_builder is not None:
                content_builder.on_text_end(state.current_text_id)
            state.current_text_id = None
        if state.current_reasoning_id is None:
            # Opening a reasoning stream completes the active thinking step.
            comp, new_active = complete_active_thinking_step(
                streaming_service=streaming_service,
                content_builder=content_builder,
                last_active_step_id=state.last_active_step_id,
                last_active_step_title=state.last_active_step_title,
                last_active_step_items=state.last_active_step_items,
                completed_step_ids=state.completed_step_ids,
            )
            if comp:
                yield comp
            state.last_active_step_id = new_active
            # Right after a tool finished, drop the stale step context.
            if state.just_finished_tool:
                state.last_active_step_id = None
                state.last_active_step_title = ""
                state.last_active_step_items = []
                state.just_finished_tool = False
            state.current_reasoning_id = streaming_service.generate_reasoning_id()
            yield streaming_service.format_reasoning_start(state.current_reasoning_id)
            if content_builder is not None:
                content_builder.on_reasoning_start(state.current_reasoning_id)
        yield streaming_service.format_reasoning_delta(
            state.current_reasoning_id, reasoning_delta
        )
        if content_builder is not None:
            content_builder.on_reasoning_delta(
                state.current_reasoning_id, reasoning_delta
            )
    # --- Text stream ---
    if text_delta:
        # Text interrupts any open reasoning stream.
        if state.current_reasoning_id is not None:
            yield streaming_service.format_reasoning_end(state.current_reasoning_id)
            if content_builder is not None:
                content_builder.on_reasoning_end(state.current_reasoning_id)
            state.current_reasoning_id = None
        if state.current_text_id is None:
            # Opening a text stream likewise completes the active step.
            comp, new_active = complete_active_thinking_step(
                streaming_service=streaming_service,
                content_builder=content_builder,
                last_active_step_id=state.last_active_step_id,
                last_active_step_title=state.last_active_step_title,
                last_active_step_items=state.last_active_step_items,
                completed_step_ids=state.completed_step_ids,
            )
            if comp:
                yield comp
            state.last_active_step_id = new_active
            if state.just_finished_tool:
                state.last_active_step_id = None
                state.last_active_step_title = ""
                state.last_active_step_items = []
                state.just_finished_tool = False
            state.current_text_id = streaming_service.generate_text_id()
            yield streaming_service.format_text_start(state.current_text_id)
            if content_builder is not None:
                content_builder.on_text_start(state.current_text_id)
        yield streaming_service.format_text_delta(state.current_text_id, text_delta)
        state.accumulated_text += text_delta
        if content_builder is not None:
            content_builder.on_text_delta(state.current_text_id, text_delta)
    # --- Streamed tool-call chunks (parity-v2 mode only) ---
    if state.parity_v2 and parts["tool_call_chunks"]:
        for tcc in parts["tool_call_chunks"]:
            idx = tcc.get("index")
            # The first chunk for a given index carries id+name: open the
            # tool-input stream and remember the id mapping for that index.
            if idx is not None and idx not in state.index_to_meta:
                lc_id = tcc.get("id")
                name = tcc.get("name")
                if lc_id and name:
                    # The LangChain tool-call id doubles as the UI id.
                    ui_id = lc_id
                    if state.current_text_id is not None:
                        yield streaming_service.format_text_end(state.current_text_id)
                        if content_builder is not None:
                            content_builder.on_text_end(state.current_text_id)
                        state.current_text_id = None
                    if state.current_reasoning_id is not None:
                        yield streaming_service.format_reasoning_end(
                            state.current_reasoning_id
                        )
                        if content_builder is not None:
                            content_builder.on_reasoning_end(state.current_reasoning_id)
                        state.current_reasoning_id = None
                    state.index_to_meta[idx] = {
                        "ui_id": ui_id,
                        "lc_id": lc_id,
                        "name": name,
                    }
                    yield streaming_service.format_tool_input_start(
                        ui_id,
                        name,
                        langchain_tool_call_id=lc_id,
                    )
                    if content_builder is not None:
                        content_builder.on_tool_input_start(ui_id, name, lc_id)
            meta = state.index_to_meta.get(idx) if idx is not None else None
            if meta:
                # Follow-up chunks stream partial JSON argument text.
                args_chunk = tcc.get("args") or ""
                if args_chunk:
                    yield streaming_service.format_tool_input_delta(
                        meta["ui_id"], args_chunk
                    )
                    if content_builder is not None:
                        content_builder.on_tool_input_delta(meta["ui_id"], args_chunk)
            else:
                # No metadata yet for this chunk: buffer it so the tool-start
                # handler can match the LangChain tool-call id later.
                state.pending_tool_call_chunks.append(tcc)

View file

@ -0,0 +1,56 @@
"""Custom graph events routed to SSE (documents, action logs, report progress)."""
from __future__ import annotations
from collections.abc import Iterator
from typing import Any
from app.tasks.chat.streaming.handlers.custom_events import (
handle_action_log,
handle_action_log_updated,
handle_document_created,
handle_report_progress,
)
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
def iter_custom_event_frames(
    event: dict[str, Any],
    *,
    state: AgentEventRelayState,
    streaming_service: Any,
    content_builder: Any | None,
) -> Iterator[str]:
    """Yield any SSE produced by ad-hoc graph events.

    Handles documents, action logs, and report progress; unknown event names
    yield nothing.
    """
    name = event.get("name")
    data = event.get("data", {})
    # report_progress is special: it also rewrites the active step's items.
    if name == "report_progress":
        frame, state.last_active_step_items = handle_report_progress(
            data,
            last_active_step_id=state.last_active_step_id,
            last_active_step_title=state.last_active_step_title,
            last_active_step_items=state.last_active_step_items,
            streaming_service=streaming_service,
            content_builder=content_builder,
        )
        if frame:
            yield frame
        return
    # The remaining events map one-to-one onto stateless payload handlers.
    dispatch = {
        "document_created": handle_document_created,
        "action_log": handle_action_log,
        "action_log_updated": handle_action_log_updated,
    }
    handler = dispatch.get(name)
    if handler is None:
        return
    frame = handler(data, streaming_service=streaming_service)
    if frame:
        yield frame

View file

@ -0,0 +1,77 @@
"""Custom-event payloads turned into SSE (no model/tool stream handling)."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
def handle_report_progress(
    data: dict[str, Any],
    *,
    last_active_step_id: str | None,
    last_active_step_title: str,
    last_active_step_items: list[str],
    streaming_service: Any,
    content_builder: Any | None,
) -> tuple[str | None, list[str]]:
    """Fold a report-progress message into the active thinking step's items.

    Returns a ``(frame, items)`` pair: the SSE frame to emit (``None`` when
    there is no message or no active step) and the item list after the
    update.
    """
    message = data.get("message", "")
    if not message or not last_active_step_id:
        return None, last_active_step_items
    phase = data.get("phase", "")
    if phase in ("revising_section", "adding_section"):
        # Keep the plan lines (topics plus modify/add/remove entries) that
        # are not trailing "..." progress placeholders.
        plan_prefixes = ("Topic:", "Modifying ", "Adding ", "Removing ")
        kept = [
            entry
            for entry in last_active_step_items
            if entry.startswith(plan_prefixes) and not entry.endswith("...")
        ]
    else:
        # Other phases keep only the topic lines.
        kept = [entry for entry in last_active_step_items if entry.startswith("Topic:")]
    new_items = [*kept, message]
    frame = emit_thinking_step_frame(
        streaming_service=streaming_service,
        content_builder=content_builder,
        step_id=last_active_step_id,
        title=last_active_step_title,
        status="in_progress",
        items=new_items,
    )
    return frame, new_items
def handle_document_created(data: dict[str, Any], *, streaming_service: Any) -> str | None:
    """SSE frame announcing a freshly created document.

    Returns ``None`` when the payload lacks a truthy ``id``.
    """
    if not data.get("id"):
        return None
    payload = {"action": "created", "document": data}
    return streaming_service.format_data("documents-updated", payload)
def handle_action_log(data: dict[str, Any], *, streaming_service: Any) -> str | None:
    """Forward an action-log payload as an SSE data frame.

    Returns ``None`` when the payload has no ``id`` to correlate on.
    """
    has_id = data.get("id") is not None
    return streaming_service.format_data("action-log", data) if has_id else None
def handle_action_log_updated(
    data: dict[str, Any], *, streaming_service: Any
) -> str | None:
    """Forward an action-log update as an SSE data frame.

    Returns ``None`` when the payload has no ``id`` to correlate on.
    """
    if data.get("id") is None:
        return None
    return streaming_service.format_data("action-log-updated", data)

View file

@ -0,0 +1,112 @@
"""Tool end: thinking completion, tool output, and terminal SSE."""
from __future__ import annotations
import json
from collections.abc import Iterator
from typing import Any
from app.tasks.chat.streaming.handlers.tools import (
ToolCompletionEmissionContext,
iter_tool_completion_emission_frames,
resolve_tool_completed_thinking_step,
)
from app.tasks.chat.streaming.helpers.tool_output import tool_output_has_error
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
def iter_tool_end_frames(
    event: dict[str, Any],
    *,
    state: AgentEventRelayState,
    streaming_service: Any,
    content_builder: Any | None,
    result: Any,
    step_prefix: str,
    config: dict[str, Any],
) -> Iterator[str]:
    """Yield SSE frames when one tool run finishes.

    Normalizes the tool output to a dict, completes the tool's thinking
    step, resolves the LangChain tool-call id, and delegates the terminal
    tool-output emission to ``iter_tool_completion_emission_frames``.

    Args:
        event: Raw LangGraph tool-end event.
        state: Mutable relay state shared across handlers.
        streaming_service: Formats SSE frames.
        content_builder: Optional accumulator mirroring streamed content.
        result: Stream-result object; write/verification flags are set here.
        step_prefix: Prefix used for fallback thinking-step ids.
        config: LangGraph config forwarded to the emission context.
    """
    state.active_tool_depth = max(0, state.active_tool_depth - 1)
    run_id = event.get("run_id", "")
    tool_name = event.get("name", "unknown_tool")
    raw_output = event.get("data", {}).get("output", "")
    # File path staged by the matching tool-start event, if any.
    staged_file_path = (
        state.file_path_by_run.pop(run_id, None) if run_id else None
    )
    if tool_name == "update_memory":
        state.called_update_memory = True
    # Normalize the raw output into a dict: message-like objects (.content)
    # become parsed JSON or a wrapped string; dicts pass through; anything
    # else is stringified.
    if hasattr(raw_output, "content"):
        content = raw_output.content
        if isinstance(content, str):
            try:
                tool_output = json.loads(content)
            except (json.JSONDecodeError, TypeError):
                tool_output = {"result": content}
        elif isinstance(content, dict):
            tool_output = content
        else:
            tool_output = {"result": str(content)}
    elif isinstance(raw_output, dict):
        tool_output = raw_output
    else:
        tool_output = {"result": str(raw_output) if raw_output else "completed"}
    if tool_name in ("write_file", "edit_file"):
        if tool_output_has_error(tool_output):
            # Errors deliberately leave the success flags untouched.
            pass
        else:
            result.write_succeeded = True
            result.verification_succeeded = True
    tool_call_id = state.ui_tool_call_id_by_run.get(
        run_id,
        f"call_{run_id[:32]}" if run_id else "call_unknown",
    )
    original_step_id = state.tool_step_ids.get(
        run_id, f"{step_prefix}-unknown-{run_id[:8]}"
    )
    state.completed_step_ids.add(original_step_id)
    # Resolve the LangChain tool-call id: prefer the id attached to the
    # output message, then fall back to the id recorded at tool start.
    holder = state.current_lc_tool_call_id
    holder["value"] = None
    authoritative = getattr(raw_output, "tool_call_id", None)
    if isinstance(authoritative, str) and authoritative:
        holder["value"] = authoritative
        if run_id:
            state.lc_tool_call_id_by_run[run_id] = authoritative
    elif run_id and run_id in state.lc_tool_call_id_by_run:
        holder["value"] = state.lc_tool_call_id_by_run[run_id]
    items = state.last_active_step_items
    title, completed_items = resolve_tool_completed_thinking_step(
        tool_name, tool_output, items
    )
    # Mark the tool's thinking step as completed.
    yield emit_thinking_step_frame(
        streaming_service=streaming_service,
        content_builder=content_builder,
        step_id=original_step_id,
        title=title,
        status="completed",
        items=completed_items,
    )
    state.just_finished_tool = True
    state.last_active_step_id = None
    state.last_active_step_title = ""
    state.last_active_step_items = []
    emission_ctx = ToolCompletionEmissionContext(
        tool_name=tool_name,
        tool_call_id=tool_call_id,
        tool_output=tool_output,
        streaming_service=streaming_service,
        content_builder=content_builder,
        langchain_tool_call_id_holder=holder,
        stream_result=result,
        langgraph_config=config,
        staged_workspace_file_path=staged_file_path,
    )
    yield from iter_tool_completion_emission_frames(emission_ctx)

View file

@ -0,0 +1,24 @@
"""Emit tool-output SSE and optional assistant content updates."""
from __future__ import annotations
from typing import Any
def emit_tool_output_available_frame(
    *,
    streaming_service: Any,
    content_builder: Any | None,
    langchain_id_holder: dict[str, str | None],
    call_id: str,
    output: Any,
) -> str:
    """Return the tool-output-available SSE frame for *call_id*.

    The content builder, when present, is notified first so the assistant
    message mirrors the streamed frame.
    """
    if content_builder is not None:
        content_builder.on_tool_output_available(
            call_id, output, langchain_id_holder["value"]
        )
    frame = streaming_service.format_tool_output_available(
        call_id,
        output,
        langchain_tool_call_id=langchain_id_holder["value"],
    )
    return frame

View file

@ -0,0 +1,142 @@
"""Tool start: thinking-step and tool-input SSE."""
from __future__ import annotations
import json
from collections.abc import Iterator
from typing import Any
from app.tasks.chat.streaming.handlers.tools import resolve_tool_start_thinking
from app.tasks.chat.streaming.helpers.tool_call_matching import (
match_buffered_langchain_tool_call_id,
)
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.thinking_step_completion import (
complete_active_thinking_step,
)
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
def iter_tool_start_frames(
    event: dict[str, Any],
    *,
    state: AgentEventRelayState,
    streaming_service: Any,
    content_builder: Any | None,
    result: Any,
    step_prefix: str,
) -> Iterator[str]:
    """Yield SSE frames for the start of one tool run.

    Closes any open text stream, completes the previous thinking step,
    opens a new thinking step for this tool, and emits tool-input
    start/available frames with the best-matching LangChain tool-call id.

    Args:
        event: Raw LangGraph tool-start event.
        state: Mutable relay state shared across handlers.
        streaming_service: Formats SSE frames and generates ids.
        content_builder: Optional accumulator mirroring streamed content.
        result: Stream-result object; ``write_attempted`` is set here.
        step_prefix: Prefix for newly allocated thinking-step ids.
    """
    state.active_tool_depth += 1
    tool_name = event.get("name", "unknown_tool")
    run_id = event.get("run_id", "")
    tool_input = event.get("data", {}).get("input", {})
    if tool_name in ("write_file", "edit_file"):
        result.write_attempted = True
        # Stage the target path so the tool-end handler can pick it up.
        if isinstance(tool_input, dict):
            file_path = tool_input.get("file_path")
            if isinstance(file_path, str) and file_path.strip() and run_id:
                state.file_path_by_run[run_id] = file_path.strip()
    # A tool run interrupts any open text stream.
    if state.current_text_id is not None:
        yield streaming_service.format_text_end(state.current_text_id)
        if content_builder is not None:
            content_builder.on_text_end(state.current_text_id)
        state.current_text_id = None
    # Complete the previous thinking step; the "Synthesizing response" step
    # is the exception and stays open across tool runs.
    if state.last_active_step_title != "Synthesizing response":
        comp, new_active = complete_active_thinking_step(
            streaming_service=streaming_service,
            content_builder=content_builder,
            last_active_step_id=state.last_active_step_id,
            last_active_step_title=state.last_active_step_title,
            last_active_step_items=state.last_active_step_items,
            completed_step_ids=state.completed_step_ids,
        )
        if comp:
            yield comp
        state.last_active_step_id = new_active
    # A fresh tool run supersedes any pending post-tool reset.
    state.just_finished_tool = False
    # Open a new thinking step describing this tool.
    tool_step_id = state.next_thinking_step_id(step_prefix)
    state.tool_step_ids[run_id] = tool_step_id
    state.last_active_step_id = tool_step_id
    thinking = resolve_tool_start_thinking(tool_name, tool_input)
    state.last_active_step_title = thinking.title
    state.last_active_step_items = thinking.items
    frame_kw: dict[str, Any] = {
        "streaming_service": streaming_service,
        "content_builder": content_builder,
        "step_id": tool_step_id,
        "title": thinking.title,
        "status": "in_progress",
    }
    if thinking.include_items_on_frame:
        frame_kw["items"] = thinking.items
    yield emit_thinking_step_frame(**frame_kw)
    # In parity-v2 mode, try to reuse the id announced by the streamed
    # tool-call chunks so the UI correlates input and output frames.
    matched_meta: dict[str, str] | None = None
    if state.parity_v2:
        taken_ui_ids = set(state.ui_tool_call_id_by_run.values())
        for meta in state.index_to_meta.values():
            if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
                matched_meta = meta
                break
    tool_call_id: str
    langchain_tool_call_id: str | None = None
    if matched_meta is not None:
        tool_call_id = matched_meta["ui_id"]
        langchain_tool_call_id = matched_meta["lc_id"]
        if run_id:
            state.lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"]
    else:
        # Fall back to a run-id-derived (or freshly generated) call id.
        tool_call_id = (
            f"call_{run_id[:32]}"
            if run_id
            else streaming_service.generate_tool_call_id()
        )
        if state.parity_v2:
            langchain_tool_call_id = match_buffered_langchain_tool_call_id(
                state.pending_tool_call_chunks,
                tool_name,
                run_id,
                state.lc_tool_call_id_by_run,
            )
    yield streaming_service.format_tool_input_start(
        tool_call_id,
        tool_name,
        langchain_tool_call_id=langchain_tool_call_id,
    )
    if content_builder is not None:
        content_builder.on_tool_input_start(
            tool_call_id, tool_name, langchain_tool_call_id
        )
    if run_id:
        state.ui_tool_call_id_by_run[run_id] = tool_call_id
    # Keep only JSON-serializable input values for the SSE payload.
    if isinstance(tool_input, dict):
        _safe_input: dict[str, Any] = {}
        for _k, _v in tool_input.items():
            try:
                json.dumps(_v)
                _safe_input[_k] = _v
            except (TypeError, ValueError, OverflowError):
                pass
    else:
        _safe_input = {"input": tool_input}
    yield streaming_service.format_tool_input_available(
        tool_call_id,
        tool_name,
        _safe_input,
        langchain_tool_call_id=langchain_tool_call_id,
    )
    if content_builder is not None:
        content_builder.on_tool_input_available(
            tool_call_id,
            tool_name,
            _safe_input,
            langchain_tool_call_id,
        )