mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-13 17:52:38 +02:00
Add chat streaming error classification, helpers, and StreamResult.
This commit is contained in:
parent
366122da6e
commit
c25b78c304
10 changed files with 444 additions and 0 deletions
|
|
@ -0,0 +1,3 @@
|
|||
"""Pure helpers for chat streaming."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
"""Split a model chunk into text, reasoning, and tool-call fragment lists."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
||||
"""Return dict with keys text, reasoning, and tool_call_chunks (merged from chunk fields)."""
|
||||
out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []}
|
||||
if chunk is None:
|
||||
return out
|
||||
|
||||
content = getattr(chunk, "content", None)
|
||||
if isinstance(content, str):
|
||||
if content:
|
||||
out["text"] = content
|
||||
elif isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
reasoning_parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
block_type = block.get("type")
|
||||
if block_type == "text":
|
||||
value = block.get("text") or block.get("content") or ""
|
||||
if isinstance(value, str) and value:
|
||||
text_parts.append(value)
|
||||
elif block_type == "reasoning":
|
||||
value = (
|
||||
block.get("reasoning")
|
||||
or block.get("text")
|
||||
or block.get("content")
|
||||
or ""
|
||||
)
|
||||
if isinstance(value, str) and value:
|
||||
reasoning_parts.append(value)
|
||||
elif block_type in ("tool_call_chunk", "tool_use"):
|
||||
out["tool_call_chunks"].append(block)
|
||||
if text_parts:
|
||||
out["text"] = "".join(text_parts)
|
||||
if reasoning_parts:
|
||||
out["reasoning"] = "".join(reasoning_parts)
|
||||
|
||||
additional = getattr(chunk, "additional_kwargs", None) or {}
|
||||
if isinstance(additional, dict):
|
||||
extra_reasoning = additional.get("reasoning_content")
|
||||
if isinstance(extra_reasoning, str) and extra_reasoning:
|
||||
existing = out["reasoning"]
|
||||
out["reasoning"] = (
|
||||
(existing + extra_reasoning) if existing else extra_reasoning
|
||||
)
|
||||
|
||||
extra_tool_chunks = getattr(chunk, "tool_call_chunks", None)
|
||||
if isinstance(extra_tool_chunks, list):
|
||||
for tcc in extra_tool_chunks:
|
||||
if isinstance(tcc, dict):
|
||||
out["tool_call_chunks"].append(tcc)
|
||||
|
||||
return out
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
"""Read the first interrupt payload from a LangGraph state snapshot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def first_interrupt_value(state: Any) -> dict[str, Any] | None:
|
||||
"""Return the first interrupt payload across all snapshot tasks."""
|
||||
|
||||
def _extract(candidate: Any) -> dict[str, Any] | None:
|
||||
if isinstance(candidate, dict):
|
||||
value = candidate.get("value", candidate)
|
||||
return value if isinstance(value, dict) else None
|
||||
value = getattr(candidate, "value", None)
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(candidate, list | tuple):
|
||||
for item in candidate:
|
||||
extracted = _extract(item)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
return None
|
||||
|
||||
for task in getattr(state, "tasks", ()) or ():
|
||||
try:
|
||||
interrupts = getattr(task, "interrupts", ()) or ()
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
interrupts = ()
|
||||
if not interrupts:
|
||||
extracted = _extract(task)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
continue
|
||||
for interrupt_item in interrupts:
|
||||
extracted = _extract(interrupt_item)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
|
||||
try:
|
||||
state_interrupts = getattr(state, "interrupts", ()) or ()
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
state_interrupts = ()
|
||||
extracted = _extract(state_interrupts)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
return None
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
"""Match buffered model tool-call chunks to a tool start when ids were missing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def match_buffered_langchain_tool_call_id(
|
||||
pending_tool_call_chunks: list[dict[str, Any]],
|
||||
tool_name: str,
|
||||
run_id: str,
|
||||
lc_tool_call_id_by_run: dict[str, str],
|
||||
) -> str | None:
|
||||
matched_idx: int | None = None
|
||||
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||
if tcc.get("name") == tool_name and tcc.get("id"):
|
||||
matched_idx = idx
|
||||
break
|
||||
if matched_idx is None:
|
||||
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||
if tcc.get("id"):
|
||||
matched_idx = idx
|
||||
break
|
||||
if matched_idx is None:
|
||||
return None
|
||||
matched = pending_tool_call_chunks.pop(matched_idx)
|
||||
candidate = matched.get("id")
|
||||
if isinstance(candidate, str) and candidate:
|
||||
if run_id:
|
||||
lc_tool_call_id_by_run[run_id] = candidate
|
||||
return candidate
|
||||
return None
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
"""Normalize filesystem tool payloads for SSE cards and messages."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
def tool_output_to_text(tool_output: Any) -> str:
|
||||
if isinstance(tool_output, dict):
|
||||
if isinstance(tool_output.get("result"), str):
|
||||
return tool_output["result"]
|
||||
if isinstance(tool_output.get("error"), str):
|
||||
return tool_output["error"]
|
||||
return json.dumps(tool_output, ensure_ascii=False)
|
||||
return str(tool_output)
|
||||
|
||||
|
||||
def tool_output_has_error(tool_output: Any) -> bool:
|
||||
if isinstance(tool_output, dict):
|
||||
if tool_output.get("error"):
|
||||
return True
|
||||
result = tool_output.get("result")
|
||||
return bool(
|
||||
isinstance(result, str) and result.strip().lower().startswith("error:")
|
||||
)
|
||||
if isinstance(tool_output, str):
|
||||
return tool_output.strip().lower().startswith("error:")
|
||||
return False
|
||||
|
||||
|
||||
def extract_resolved_file_path(
|
||||
*, tool_name: str, tool_output: Any, tool_input: Any | None = None
|
||||
) -> str | None:
|
||||
if isinstance(tool_output, dict):
|
||||
path_value = tool_output.get("path")
|
||||
if isinstance(path_value, str) and path_value.strip():
|
||||
return path_value.strip()
|
||||
if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict):
|
||||
file_path = tool_input.get("file_path")
|
||||
if isinstance(file_path, str) and file_path.strip():
|
||||
return file_path.strip()
|
||||
return None
|
||||
Loading…
Add table
Add a link
Reference in a new issue