mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
Merge remote-tracking branch 'upstream/dev' into fix/zero-cache-stale-replica-1355
This commit is contained in:
commit
af1d2fa430
601 changed files with 45027 additions and 4681 deletions
|
|
@ -76,6 +76,9 @@ from app.services.chat_session_state_service import (
|
|||
from app.services.connector_service import ConnectorService
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
|
||||
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
||||
all_interrupt_values,
|
||||
)
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
|
||||
from app.utils.user_message_multimodal import build_human_message_content
|
||||
|
|
@ -89,6 +92,21 @@ TURN_CANCELLING_BACKOFF_FACTOR = 2
|
|||
TURN_CANCELLING_MAX_DELAY_MS = 1500
|
||||
|
||||
|
||||
def _resume_step_prefix(turn_id: str) -> str:
|
||||
"""Build the per-turn ``step_prefix`` for a resume invocation.
|
||||
|
||||
Each ``_stream_agent_events`` call constructs a fresh
|
||||
:class:`AgentEventRelayState` with ``thinking_step_counter=0``, so two
|
||||
consecutive resume turns would otherwise both emit ``thinking-resume-1``,
|
||||
``-2`` etc. The frontend rehydrates ``currentThinkingSteps`` from the
|
||||
immediate prior assistant message at the start of every resume — if the
|
||||
new stream's IDs collide with the seeded ones, React renders sibling
|
||||
Timeline rows with the same key. Salting with ``turn_id`` guarantees
|
||||
disjoint IDs across resumes within one thread.
|
||||
"""
|
||||
return f"thinking-resume-{turn_id}"
|
||||
|
||||
|
||||
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
|
||||
if attempt < 1:
|
||||
attempt = 1
|
||||
|
|
@ -98,47 +116,6 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
|
|||
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
||||
|
||||
|
||||
def _first_interrupt_value(state: Any) -> dict[str, Any] | None:
|
||||
"""Return the first LangGraph interrupt payload across all snapshot tasks."""
|
||||
|
||||
def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None:
|
||||
if isinstance(candidate, dict):
|
||||
value = candidate.get("value", candidate)
|
||||
return value if isinstance(value, dict) else None
|
||||
value = getattr(candidate, "value", None)
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(candidate, (list, tuple)):
|
||||
for item in candidate:
|
||||
extracted = _extract_interrupt_value(item)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
return None
|
||||
|
||||
for task in getattr(state, "tasks", ()) or ():
|
||||
try:
|
||||
interrupts = getattr(task, "interrupts", ()) or ()
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
interrupts = ()
|
||||
if not interrupts:
|
||||
extracted = _extract_interrupt_value(task)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
continue
|
||||
for interrupt_item in interrupts:
|
||||
extracted = _extract_interrupt_value(interrupt_item)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
try:
|
||||
state_interrupts = getattr(state, "interrupts", ()) or ()
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
state_interrupts = ()
|
||||
extracted = _extract_interrupt_value(state_interrupts)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
return None
|
||||
|
||||
|
||||
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
||||
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
|
||||
|
||||
|
|
@ -301,7 +278,6 @@ def extract_todos_from_deepagents(command_output) -> dict:
|
|||
class StreamResult:
|
||||
accumulated_text: str = ""
|
||||
is_interrupted: bool = False
|
||||
interrupt_value: dict[str, Any] | None = None
|
||||
sandbox_files: list[str] = field(default_factory=list)
|
||||
agent_called_update_memory: bool = False
|
||||
request_id: str | None = None
|
||||
|
|
@ -915,11 +891,15 @@ async def _stream_agent_events(
|
|||
result.accumulated_text = accumulated_text
|
||||
_log_file_contract("turn_outcome", result)
|
||||
|
||||
interrupt_value = _first_interrupt_value(state)
|
||||
if interrupt_value is not None:
|
||||
pending_values = all_interrupt_values(state)
|
||||
if pending_values:
|
||||
result.is_interrupted = True
|
||||
result.interrupt_value = interrupt_value
|
||||
yield streaming_service.format_interrupt_request(result.interrupt_value)
|
||||
# One frame per paused subagent so each parallel HITL renders its own
|
||||
# approval card on the wire. Order matches ``state.interrupts``, which
|
||||
# the resume slicer in ``checkpointed_subagent_middleware.resume_routing``
|
||||
# consumes in the same order — keeping emit and resume in lock-step.
|
||||
for interrupt_value in pending_values:
|
||||
yield streaming_service.format_interrupt_request(interrupt_value)
|
||||
|
||||
|
||||
async def stream_new_chat(
|
||||
|
|
@ -2871,14 +2851,40 @@ async def stream_resume_chat(
|
|||
|
||||
from langgraph.types import Command
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
build_lg_resume_map,
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
|
||||
# Each pending interrupt is stamped with its originating ``tool_call_id``
|
||||
# (see ``checkpointed_subagent_middleware.propagation``) so we can route
|
||||
# a flat ``decisions`` list back to the right paused subagent.
|
||||
parent_state = await agent.aget_state(
|
||||
{"configurable": {"thread_id": str(chat_id)}}
|
||||
)
|
||||
pending = collect_pending_tool_calls(parent_state)
|
||||
_perf_log.info(
|
||||
"[hitl_route] resume_entry chat_id=%s decisions=%d pending_subagents=%d",
|
||||
chat_id,
|
||||
len(decisions),
|
||||
len(pending),
|
||||
)
|
||||
routed_resume_value = slice_decisions_by_tool_call(decisions, pending)
|
||||
# LangGraph rejects scalar ``Command(resume=...)`` when multiple
|
||||
# interrupts are pending (parallel HITL); the mapped form works
|
||||
# for the single-pause case too, so we always use it.
|
||||
lg_resume_map = build_lg_resume_map(parent_state, routed_resume_value)
|
||||
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": str(chat_id),
|
||||
"request_id": request_id or "unknown",
|
||||
"turn_id": stream_result.turn_id,
|
||||
# Side-channel consumed by ``SurfSenseCheckpointedSubAgentMiddleware``
|
||||
# to forward the resume into a subagent's pending ``interrupt()``.
|
||||
"surfsense_resume_value": {"decisions": decisions},
|
||||
# Per-``tool_call_id`` resume slices read by
|
||||
# ``SurfSenseCheckpointedSubAgentMiddleware``. Parallel
|
||||
# siblings each pop their own entry, so they never race.
|
||||
"surfsense_resume_value": routed_resume_value,
|
||||
},
|
||||
# See ``stream_new_chat`` above for rationale: effectively
|
||||
# uncapped to mirror the agent default and OpenCode's
|
||||
|
|
@ -2960,10 +2966,10 @@ async def stream_resume_chat(
|
|||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
input_data=Command(resume={"decisions": decisions}),
|
||||
input_data=Command(resume=lg_resume_map),
|
||||
streaming_service=streaming_service,
|
||||
result=stream_result,
|
||||
step_prefix="thinking-resume",
|
||||
step_prefix=_resume_step_prefix(stream_result.turn_id),
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from typing import Any
|
|||
class StreamingResult:
|
||||
accumulated_text: str = ""
|
||||
is_interrupted: bool = False
|
||||
interrupt_value: dict[str, Any] | None = None
|
||||
sandbox_files: list[str] = field(default_factory=list)
|
||||
agent_called_update_memory: bool = False
|
||||
request_id: str | None = None
|
||||
|
|
|
|||
|
|
@ -1,12 +1,30 @@
|
|||
"""Read the first interrupt payload from a LangGraph state snapshot."""
|
||||
"""Read every pending interrupt payload from a LangGraph state snapshot.
|
||||
|
||||
The chat-stream emit loop yields one ``data-interrupt-request`` SSE frame per
|
||||
pending interrupt so parallel HITL across siblings stays addressable on the
|
||||
wire (the resume slicer in ``checkpointed_subagent_middleware.resume_routing``
|
||||
correlates each frame back to the right paused subagent via the stamped
|
||||
``tool_call_id``). This helper produces that flat, ordered list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def first_interrupt_value(state: Any) -> dict[str, Any] | None:
|
||||
"""Return the first interrupt payload across all snapshot tasks."""
|
||||
def all_interrupt_values(state: Any) -> list[dict[str, Any]]:
|
||||
"""Return every interrupt payload across the snapshot, in traversal order.
|
||||
|
||||
Walks ``state.tasks[*].interrupts`` first (langgraph's per-task buckets,
|
||||
which carry one interrupt per paused subagent) and falls back to
|
||||
``state.interrupts`` when the per-task lists are empty. Order matches the
|
||||
snapshot's iteration order so the emit-time order on the SSE stream agrees
|
||||
with ``collect_pending_tool_calls`` consumption order on resume.
|
||||
|
||||
Defensive against malformed snapshots: tasks/interrupts that raise on
|
||||
attribute access are skipped silently. Non-dict values are skipped — the
|
||||
chat-stream contract requires structured interrupt payloads.
|
||||
"""
|
||||
|
||||
def _extract(candidate: Any) -> dict[str, Any] | None:
|
||||
if isinstance(candidate, dict):
|
||||
|
|
@ -15,33 +33,32 @@ def first_interrupt_value(state: Any) -> dict[str, Any] | None:
|
|||
value = getattr(candidate, "value", None)
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(candidate, list | tuple):
|
||||
for item in candidate:
|
||||
extracted = _extract(item)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
return None
|
||||
|
||||
values: list[dict[str, Any]] = []
|
||||
saw_task_interrupt = False
|
||||
|
||||
for task in getattr(state, "tasks", ()) or ():
|
||||
try:
|
||||
interrupts = getattr(task, "interrupts", ()) or ()
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
interrupts = ()
|
||||
if not interrupts:
|
||||
extracted = _extract(task)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
continue
|
||||
for interrupt_item in interrupts:
|
||||
extracted = _extract(interrupt_item)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
if interrupts:
|
||||
saw_task_interrupt = True
|
||||
for interrupt_item in interrupts:
|
||||
extracted = _extract(interrupt_item)
|
||||
if extracted is not None:
|
||||
values.append(extracted)
|
||||
|
||||
if saw_task_interrupt:
|
||||
return values
|
||||
|
||||
try:
|
||||
state_interrupts = getattr(state, "interrupts", ()) or ()
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
state_interrupts = ()
|
||||
extracted = _extract(state_interrupts)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
return None
|
||||
for interrupt_item in state_interrupts:
|
||||
extracted = _extract(interrupt_item)
|
||||
if extracted is not None:
|
||||
values.append(extracted)
|
||||
return values
|
||||
|
|
|
|||
|
|
@ -123,10 +123,6 @@ async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | No
|
|||
"""Extract content from a non-document file (plaintext/direct_convert/audio/image) via the unified ETL pipeline."""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
from app.etl_pipeline.file_classifier import (
|
||||
FileCategory,
|
||||
classify_file as etl_classify,
|
||||
)
|
||||
|
||||
await _notify(ctx, "parsing", "Processing file")
|
||||
await ctx.task_logger.log_task_progress(
|
||||
|
|
@ -135,8 +131,12 @@ async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | No
|
|||
{"processing_stage": "extracting"},
|
||||
)
|
||||
|
||||
# Fetch the vision LLM whenever the operator opts in. The ETL
|
||||
# pipeline decides what to do with it: image files run through the
|
||||
# vision LLM directly; document files (PDFs) get per-image
|
||||
# descriptions appended via picture_describer.
|
||||
vision_llm = None
|
||||
if ctx.use_vision_llm and etl_classify(ctx.filename) == FileCategory.IMAGE:
|
||||
if ctx.use_vision_llm:
|
||||
from app.services.llm_service import get_vision_llm
|
||||
|
||||
vision_llm = await get_vision_llm(ctx.session, ctx.search_space_id)
|
||||
|
|
@ -230,7 +230,16 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
|
|||
|
||||
await _notify(ctx, "parsing", "Extracting content")
|
||||
|
||||
etl_result = await EtlPipelineService().extract(
|
||||
# Document files (PDF, docx, etc.) get vision LLM treatment too:
|
||||
# the ETL pipeline appends a per-image description section when
|
||||
# vision_llm is provided. See picture_describer.describe_pictures.
|
||||
vision_llm = None
|
||||
if ctx.use_vision_llm:
|
||||
from app.services.llm_service import get_vision_llm
|
||||
|
||||
vision_llm = await get_vision_llm(ctx.session, ctx.search_space_id)
|
||||
|
||||
etl_result = await EtlPipelineService(vision_llm=vision_llm).extract(
|
||||
EtlRequest(
|
||||
file_path=ctx.file_path,
|
||||
filename=ctx.filename,
|
||||
|
|
@ -418,8 +427,12 @@ async def _extract_file_content(
|
|||
billable_pages = estimated_pages * mode.page_multiplier
|
||||
await page_limit_service.check_page_limit(user_id, billable_pages)
|
||||
|
||||
# Vision LLM is provided to the ETL pipeline for any file category
|
||||
# when the operator opts in. Image files run through it directly;
|
||||
# document files (PDFs) get per-image descriptions appended via
|
||||
# picture_describer.
|
||||
vision_llm = None
|
||||
if use_vision_llm and category == FileCategory.IMAGE:
|
||||
if use_vision_llm:
|
||||
from app.services.llm_service import get_vision_llm
|
||||
|
||||
vision_llm = await get_vision_llm(session, search_space_id)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue