Merge pull request #1423 from MODSetter/dev

feat: improved agent speed and fixed its citations
This commit is contained in:
Rohan Verma 2026-05-21 14:47:13 -07:00 committed by GitHub
commit 49dd8409d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
133 changed files with 3249 additions and 2971 deletions

View file

@ -4,6 +4,9 @@ on:
pull_request:
branches: [main, dev]
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'surfsense_backend/**'
- '.github/workflows/backend-tests.yml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
@ -21,26 +24,15 @@ jobs:
- name: Checkout code
uses: actions/checkout@v6
- name: Check if backend files changed
id: backend-changes
uses: dorny/paths-filter@v3
with:
filters: |
backend:
- 'surfsense_backend/**'
- name: Set up Python
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install UV
if: steps.backend-changes.outputs.backend == 'true'
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v8.1.0
- name: Cache dependencies
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/cache@v5
with:
path: |
@ -51,19 +43,16 @@ jobs:
python-deps-
- name: Cache HuggingFace models
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/cache@v5
with:
path: ~/.cache/huggingface
key: hf-models-${{ env.EMBEDDING_MODEL }}
- name: Install dependencies
if: steps.backend-changes.outputs.backend == 'true'
working-directory: surfsense_backend
run: uv sync
- name: Run unit tests
if: steps.backend-changes.outputs.backend == 'true'
working-directory: surfsense_backend
run: uv run pytest -m unit
@ -93,26 +82,15 @@ jobs:
- name: Checkout code
uses: actions/checkout@v6
- name: Check if backend files changed
id: backend-changes
uses: dorny/paths-filter@v3
with:
filters: |
backend:
- 'surfsense_backend/**'
- name: Set up Python
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install UV
if: steps.backend-changes.outputs.backend == 'true'
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v8.1.0
- name: Cache dependencies
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/cache@v5
with:
path: |
@ -123,19 +101,16 @@ jobs:
python-deps-
- name: Cache HuggingFace models
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/cache@v5
with:
path: ~/.cache/huggingface
key: hf-models-${{ env.EMBEDDING_MODEL }}
- name: Install dependencies
if: steps.backend-changes.outputs.backend == 'true'
working-directory: surfsense_backend
run: uv sync
- name: Run integration tests
if: steps.backend-changes.outputs.backend == 'true'
working-directory: surfsense_backend
env:
TEST_DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test

View file

@ -11,13 +11,13 @@ concurrency:
jobs:
file-quality:
name: File Quality Checks
name: File Quality
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
@ -27,7 +27,7 @@ jobs:
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.12'
@ -35,7 +35,7 @@ jobs:
run: pip install pre-commit
- name: Cache pre-commit hooks
uses: actions/cache@v4
uses: actions/cache@v5
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
@ -74,7 +74,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
@ -83,7 +83,7 @@ jobs:
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.12'
@ -91,7 +91,7 @@ jobs:
run: pip install pre-commit
- name: Cache pre-commit hooks
uses: actions/cache@v4
uses: actions/cache@v5
with:
path: ~/.cache/pre-commit
key: pre-commit-security-${{ hashFiles('.pre-commit-config.yaml') }}
@ -125,35 +125,36 @@ jobs:
exit ${exit_code:-0}
python-backend:
name: Python Backend Quality
name: Backend Quality
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install UV
uses: astral-sh/setup-uv@v3
uses: astral-sh/setup-uv@v8.1.0
- name: Check if backend files changed
id: backend-changes
uses: dorny/paths-filter@v3
uses: dorny/paths-filter@v4
with:
filters: |
backend:
- 'surfsense_backend/**'
- '.github/workflows/code-quality.yml'
- name: Cache dependencies
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/cache@v4
uses: actions/cache@v5
with:
path: |
~/.cache/uv
@ -171,7 +172,7 @@ jobs:
- name: Cache pre-commit hooks
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/cache@v4
uses: actions/cache@v5
with:
path: ~/.cache/pre-commit
key: pre-commit-backend-${{ hashFiles('.pre-commit-config.yaml') }}
@ -206,13 +207,13 @@ jobs:
exit ${exit_code:-0}
typescript-frontend:
name: TypeScript/JavaScript Quality
name: Frontend Quality
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
@ -221,24 +222,24 @@ jobs:
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
- name: Setup Node.js
uses: actions/setup-node@v4
uses: actions/setup-node@v6
with:
node-version: '18'
node-version: '20'
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
version: latest
uses: pnpm/action-setup@v6
- name: Check if frontend files changed
id: frontend-changes
uses: dorny/paths-filter@v3
uses: dorny/paths-filter@v4
with:
filters: |
web:
- 'surfsense_web/**'
- '.github/workflows/code-quality.yml'
extension:
- 'surfsense_browser_extension/**'
- '.github/workflows/code-quality.yml'
- name: Install dependencies for web
if: steps.frontend-changes.outputs.web == 'true'
@ -254,7 +255,7 @@ jobs:
run: pip install pre-commit
- name: Cache pre-commit hooks
uses: actions/cache@v4
uses: actions/cache@v5
with:
path: ~/.cache/pre-commit
key: pre-commit-frontend-${{ hashFiles('.pre-commit-config.yaml') }}

View file

@ -67,7 +67,7 @@ repos:
# Biome check for surfsense_web
- id: biome-check-web
name: biome-check-web
entry: bash -c 'cd surfsense_web && npx @biomejs/biome check --diagnostic-level=error .'
entry: bash -c 'cd surfsense_web && npx @biomejs/biome@2.4.6 check --diagnostic-level=error .'
language: system
files: ^surfsense_web/
pass_filenames: false

View file

@ -1 +1 @@
0.0.24
0.0.25

View file

@ -1,11 +1,42 @@
<citations>
Apply chunk citations only when the runtime injects `<document>` /
`<chunk id='…'>` blocks.
Citations reach the answer through two channels. Use whichever applies — and
never invent ids you didn't see. Citation ids are resolved by exact-match
lookup; a wrong id silently breaks the link, so when in doubt, omit.
### Channel A — chunk blocks injected this turn
When `search_surfsense_docs` or `web_search` returns `<document>` /
`<chunk id='…'>` blocks in this turn:
1. For each factual statement taken from those chunks, add
`[citation:chunk_id]` using the exact id from `<chunk id='…'>`.
2. Multiple chunks → `[citation:id1], [citation:id2]` (comma-separated).
3. Never invent or normalise ids; if unsure, omit.
4. Plain brackets only — no markdown links, no footnote numbering.
5. If no chunk-tagged documents appear this turn, do not fabricate citations.
`[citation:chunk_id]` using the **exact** id from a visible
`<chunk id='…'>` tag. Copy digit-for-digit (or the URL verbatim);
do not retype from memory.
2. `<document_id>` is the parent doc id, **not** a citation source —
only ids inside `<chunk id='…'>` count.
3. Multiple chunks → `[citation:id1], [citation:id2]` (comma-separated,
each id copied individually).
4. Never invent, normalise, or guess at adjacent ids; if unsure, omit.
5. Plain brackets only — no markdown links, no footnote numbering.
### Channel B — citations relayed by a `task` specialist
A `task(...)` tool message may contain `[citation:<chunk_id>]` markers
the specialist already attached to its prose. The specialist saw the
underlying `<chunk id='…'>` blocks; you didn't. So:
1. **Preserve those markers verbatim** in your final answer — do not
reformat, renumber, drop, or wrap them in markdown links. When you
paraphrase a specialist sentence, copy the marker character-for-
character; do not regenerate the id from memory (LLMs reliably
corrupt nearby digits).
2. Keep each marker attached to the sentence the specialist attached
it to.
3. Do **not** add new `[citation:…]` markers of your own to a
specialist's prose; if a fact has no marker, the specialist
couldn't tie it to a chunk and neither can you.
4. When a specialist returns JSON, the citation markers live inside
the prose-bearing fields (e.g. a summary or excerpt). Pull them
along with the surrounding sentence when you quote.
If neither channel surfaces citation markers this turn, do not fabricate
them.
</citations>

View file

@ -6,4 +6,10 @@ standing instructions?
If yes, call `update_memory` **alongside** your normal response — don't
defer it to a later turn. Skip ephemeral chat noise (one-off Q/A, greetings,
session logistics). Stay within the budget shown in `<user_memory>`.
Memory is heading-based markdown. New entries should be under `##` headings
such as `## Facts`, `## Preferences`, or `## Instructions`, with bullets like
`- YYYY-MM-DD: text`. If existing memory contains legacy
`(YYYY-MM-DD) [fact|pref|instr]` markers, preserve the information but write
new saves in the heading-based format.
</memory_protocol>

View file

@ -6,4 +6,12 @@ key facts?
If yes, call `update_memory` **alongside** your normal response — don't
defer it to a later turn. Skip ephemeral chat noise (one-off Q/A, greetings,
session logistics). Stay within the budget shown in `<team_memory>`.
Team memory is heading-based markdown. New entries should be under `##`
headings such as `## Product Decisions`, `## Engineering Conventions`,
`## Project Facts`, or `## Open Questions`, with bullets like
`- YYYY-MM-DD: text`. If existing memory contains legacy `(YYYY-MM-DD) [fact]`
markers, preserve the information but write new saves in the heading-based
format. Do not create personal headings such as `## Preferences` or
`## Instructions`.
</memory_protocol>

View file

@ -9,7 +9,9 @@
- Skip ephemeral chat noise (one-off Q/A, greetings, session logistics).
- Args: `updated_memory` — FULL replacement markdown (merge and curate,
don't only append).
- Formatting: bullets `- (YYYY-MM-DD) [marker] text` with markers `[fact]`,
`[pref]`, `[instr]` (priority when trimming: `instr > pref > fact`).
Group bullets under short `##` headings; stay under the limit shown in
`<user_memory>`.
- Formatting: heading-based markdown with entries under `##` headings.
Recommended headings are `## Facts`, `## Preferences`, `## Instructions`,
though clearer natural headings are allowed. New bullets should look like
`- YYYY-MM-DD: text`; stay under the limit shown in `<user_memory>`.
- If existing memory uses legacy `(YYYY-MM-DD) [fact|pref|instr]` markers,
preserve the information but write the updated document in the new format.

View file

@ -1,28 +1,28 @@
<example>
<user_name>Alex</user_name>, <user_memory> is empty.
user: "I'm a space enthusiast, explain astrophage to me"
→ update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n")
→ update_memory(updated_memory="## Facts\n- 2025-03-15: Alex is a space enthusiast\n")
(Casual durable fact; use first name, neutral heading.)
</example>
<example>
user: "Remember that I prefer concise answers over detailed explanations"
→ update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n\n## Response style\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\n")
→ update_memory(updated_memory="## Facts\n- 2025-03-15: Alex is a space enthusiast\n\n## Preferences\n- 2025-03-15: Alex prefers concise answers over detailed explanations\n")
(Durable preference; merge with existing memory.)
</example>
<example>
user: "I actually moved to Tokyo last month"
→ update_memory(updated_memory="...\n\n## Personal context\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\n...")
→ update_memory(updated_memory="...\n\n## Facts\n- 2025-03-15: Alex lives in Tokyo (previously London)\n...")
(Updated fact; date reflects when recorded.)
</example>
<example>
user: "I'm a freelance photographer working on a nature documentary"
→ update_memory(updated_memory="...\n\n## Current focus\n- (2025-03-15) [fact] Alex is a freelance photographer\n- (2025-03-15) [fact] Alex is working on a nature documentary\n")
→ update_memory(updated_memory="...\n\n## Current Focus\n- 2025-03-15: Alex is a freelance photographer\n- 2025-03-15: Alex is working on a nature documentary\n")
</example>
<example>
user: "Always respond in bullet points"
→ update_memory(updated_memory="...\n\n## Response style\n- (2025-03-15) [instr] Always respond to Alex in bullet points\n")
→ update_memory(updated_memory="...\n\n## Instructions\n- 2025-03-15: Always respond to Alex in bullet points\n")
</example>

View file

@ -9,8 +9,14 @@
- Skip ephemeral chat noise (one-off Q/A, greetings, session logistics).
- Args: `updated_memory` — FULL replacement markdown (merge and curate,
don't only append).
- Formatting: bullets `- (YYYY-MM-DD) [fact] text`. Team memory uses ONLY
the `[fact]` marker (never `[pref]` or `[instr]`). Group bullets under
short `##` headings (2-3 words each); stay under the limit shown in
`<team_memory>`. When trimming, prioritise: decisions/conventions > key
facts > current priorities.
- Formatting: heading-based markdown with entries under `##` headings.
Recommended headings are `## Product Decisions`,
`## Engineering Conventions`, `## Project Facts`, and `## Open Questions`.
New bullets should look like `- YYYY-MM-DD: text`; stay under the limit
shown in `<team_memory>`.
- If existing memory uses legacy `(YYYY-MM-DD) [fact]` markers, preserve the
information but write the updated document in the new format.
- Do not create personal headings such as `## Preferences`,
`## Instructions`, `## Personal Notes`, or `## Personal Instructions`.
When trimming, prioritise: decisions/conventions > key facts > current
priorities.

View file

@ -1,9 +1,9 @@
<example>
user: "Let's remember that we decided to do weekly standup meetings on Mondays"
→ update_memory(updated_memory="...\n\n## Team rituals\n- (2025-03-15) [fact] Weekly standup meetings on Mondays\n...")
→ update_memory(updated_memory="...\n\n## Product Decisions\n- 2025-03-15: Weekly standup meetings happen on Mondays\n...")
</example>
<example>
user: "Our office is in downtown Seattle, 5th floor"
→ update_memory(updated_memory="...\n\n## Workspace\n- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\n...")
→ update_memory(updated_memory="...\n\n## Project Facts\n- 2025-03-15: Office location is downtown Seattle, 5th floor\n...")
</example>

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import time
from typing import Any, cast
from deepagents.backends.protocol import BackendFactory, BackendProtocol
@ -15,8 +16,12 @@ from langchain.agents import create_agent
from langchain.chat_models import init_chat_model
from langgraph.types import Checkpointer
from app.utils.perf import get_perf_logger
from .task_tool import build_task_tool_with_parent_config
_perf_log = get_perf_logger()
class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
"""``SubAgentMiddleware`` variant that compiles each subagent against the parent checkpointer."""
@ -54,8 +59,11 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
"""Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer."""
specs: list[dict[str, Any]] = []
loop_start = time.perf_counter()
timings: list[tuple[str, float, str]] = [] # (name, elapsed, source)
for spec in self._subagents:
spec_start = time.perf_counter()
if "runnable" in spec:
compiled = cast(CompiledSubAgent, spec)
specs.append(
@ -65,6 +73,9 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
"runnable": compiled["runnable"],
}
)
timings.append(
(compiled["name"], time.perf_counter() - spec_start, "precompiled")
)
continue
if "model" not in spec:
@ -79,20 +90,44 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
model = init_chat_model(model)
middleware: list[Any] = list(spec.get("middleware", []))
tools_count = len(spec.get("tools") or [])
mw_count = len(middleware)
compile_start = time.perf_counter()
runnable = create_agent(
model,
system_prompt=spec["system_prompt"],
tools=spec["tools"],
middleware=middleware,
name=spec["name"],
checkpointer=self._surf_checkpointer,
)
compile_elapsed = time.perf_counter() - compile_start
specs.append(
{
"name": spec["name"],
"description": spec["description"],
"runnable": create_agent(
model,
system_prompt=spec["system_prompt"],
tools=spec["tools"],
middleware=middleware,
name=spec["name"],
checkpointer=self._surf_checkpointer,
),
"runnable": runnable,
}
)
timings.append(
(
spec["name"],
compile_elapsed,
f"compiled tools={tools_count} mw={mw_count}",
)
)
total_elapsed = time.perf_counter() - loop_start
per_subagent = ", ".join(
f"{name}={elapsed * 1000:.0f}ms[{source}]"
for name, elapsed, source in timings
)
_perf_log.info(
"[subagent_compile] total=%.3fs count=%d details=[%s]",
total_elapsed,
len(timings),
per_subagent,
)
return specs

View file

@ -9,6 +9,7 @@ re-raises any new pending interrupt back to the parent.
from __future__ import annotations
import logging
import time
from typing import Annotated, Any, NoReturn
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
@ -19,6 +20,8 @@ from langchain_core.tools import StructuredTool
from langgraph.errors import GraphInterrupt
from langgraph.types import Command, Interrupt
from app.utils.perf import get_perf_logger
from .config import (
consume_surfsense_resume,
drain_parent_null_resume,
@ -35,6 +38,7 @@ from .resume import (
)
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
def _reraise_stamped_subagent_interrupt(
@ -209,6 +213,7 @@ def build_task_tool_with_parent_config(
],
runtime: ToolRuntime,
) -> str | Command:
atask_start = time.perf_counter()
logger.info(
"[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s",
subagent_type,
@ -230,8 +235,10 @@ def build_task_tool_with_parent_config(
# Resume bridge — see ``task`` above.
pending_id: str | None = None
pending_value: Any = None
aget_state_elapsed = 0.0
aget_state = getattr(subagent, "aget_state", None)
if callable(aget_state):
aget_state_start = time.perf_counter()
try:
snapshot = await aget_state(sub_config)
pending_id, pending_value = get_first_pending_subagent_interrupt(
@ -248,32 +255,78 @@ def build_task_tool_with_parent_config(
"Subagent aget_state failed; falling back to fresh ainvoke",
exc_info=True,
)
finally:
aget_state_elapsed = time.perf_counter() - aget_state_start
if pending_value is not None:
resume_value = consume_surfsense_resume(runtime)
if resume_value is None:
raise RuntimeError(
f"Subagent {subagent_type!r} has a pending interrupt but no "
"surfsense_resume_value on config; resume bridge is broken."
)
expected = hitlrequest_action_count(pending_value)
resume_value = fan_out_decisions_to_match(resume_value, expected)
# Prevent the parent's resume payload from leaking into subagent
# interrupts via langgraph's parent_scratchpad fallback.
drain_parent_null_resume(runtime)
try:
result = await subagent.ainvoke(
build_resume_command(resume_value, pending_id),
config=sub_config,
)
except GraphInterrupt as gi:
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
else:
try:
result = await subagent.ainvoke(subagent_state, config=sub_config)
except GraphInterrupt as gi:
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
return _return_command_with_state_update(result, runtime.tool_call_id)
invoke_path = "resume" if pending_value is not None else "fresh"
ainvoke_start = time.perf_counter()
ainvoke_outcome = "ok"
try:
if pending_value is not None:
resume_value = consume_surfsense_resume(runtime)
if resume_value is None:
raise RuntimeError(
f"Subagent {subagent_type!r} has a pending interrupt but no "
"surfsense_resume_value on config; resume bridge is broken."
)
expected = hitlrequest_action_count(pending_value)
resume_value = fan_out_decisions_to_match(resume_value, expected)
# Prevent the parent's resume payload from leaking into subagent
# interrupts via langgraph's parent_scratchpad fallback.
drain_parent_null_resume(runtime)
try:
result = await subagent.ainvoke(
build_resume_command(resume_value, pending_id),
config=sub_config,
)
except GraphInterrupt as gi:
ainvoke_outcome = "interrupted"
_perf_log.info(
"[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s "
"aget_state=%.3fs ainvoke=%.3fs total=%.3fs",
subagent_type,
invoke_path,
ainvoke_outcome,
aget_state_elapsed,
time.perf_counter() - ainvoke_start,
time.perf_counter() - atask_start,
)
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
else:
try:
result = await subagent.ainvoke(subagent_state, config=sub_config)
except GraphInterrupt as gi:
ainvoke_outcome = "interrupted"
_perf_log.info(
"[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s "
"aget_state=%.3fs ainvoke=%.3fs total=%.3fs",
subagent_type,
invoke_path,
ainvoke_outcome,
aget_state_elapsed,
time.perf_counter() - ainvoke_start,
time.perf_counter() - atask_start,
)
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
ainvoke_elapsed = time.perf_counter() - ainvoke_start
except GraphInterrupt:
raise
merge_start = time.perf_counter()
cmd = _return_command_with_state_update(result, runtime.tool_call_id)
merge_elapsed = time.perf_counter() - merge_start
_perf_log.info(
"[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s "
"aget_state=%.3fs ainvoke=%.3fs merge=%.3fs total=%.3fs",
subagent_type,
invoke_path,
ainvoke_outcome,
aget_state_elapsed,
ainvoke_elapsed,
merge_elapsed,
time.perf_counter() - atask_start,
)
return cmd
return StructuredTool.from_function(
name="task",

View file

@ -6,6 +6,7 @@ from langchain_core.language_models import BaseChatModel
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
from app.services.llm_service import get_planner_llm
def build_knowledge_priority_mw(
@ -19,6 +20,7 @@ def build_knowledge_priority_mw(
) -> KnowledgePriorityMiddleware:
return KnowledgePriorityMiddleware(
llm=llm,
planner_llm=get_planner_llm(),
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
available_connectors=available_connectors,

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import time
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
@ -10,6 +11,9 @@ from langgraph.runtime import Runtime
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
from app.agents.new_chat.middleware.knowledge_search import _render_priority_message
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
@ -30,17 +34,34 @@ class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
start = time.perf_counter()
tree_text = state.get("workspace_tree_text")
priority = state.get("kb_priority")
if not tree_text and not priority:
_perf_log.info(
"[kb_context_projection] tree=0 priority=0 elapsed=%.3fs",
time.perf_counter() - start,
)
return None
messages = list(state.get("messages") or [])
insert_at = max(len(messages) - 1, 0)
tree_chars = 0
if tree_text:
tree_chars = len(tree_text)
messages.insert(insert_at, SystemMessage(content=tree_text))
priority_count = 0
if priority:
priority_count = (
len(priority) if hasattr(priority, "__len__") else 1
)
messages.insert(insert_at, _render_priority_message(priority))
_perf_log.info(
"[kb_context_projection] tree_chars=%d priority_items=%d elapsed=%.3fs",
tree_chars,
priority_count,
time.perf_counter() - start,
)
return {"messages": messages}

View file

@ -2,4 +2,4 @@ Read-only specialist for the user's workspace (documents and folders). Use to fi
Pass your full question as one string. The specialist runs in isolation: it cannot see this thread, so include any path hints, filters, or constraints it needs.
The specialist returns plain prose with absolute paths.
The specialist returns plain prose with absolute paths and `[citation:<chunk_id>]` markers when claims came from KB-indexed chunks. Preserve those markers verbatim if you forward the answer.

View file

@ -35,6 +35,43 @@ Map outcomes to your `status`:
You construct the structured `evidence` fields from your own knowledge of what you called and what you observed — the tools do not return them. Never report values you did not actually see.
## Chunk citations in your prose
When `read_file` returns a KB-indexed document under `/documents/`, the response includes `<chunk id='…'>` blocks. Whenever a fact in your `action_summary` or `evidence.content_excerpt` came from a specific chunk, append `[citation:<chunk_id>]` to the sentence stating that fact, using the **exact** id from the `<chunk id='…'>` tag. The caller relays these markers to the end user verbatim, and the UI resolves each id by exact match against the database, so a wrong id silently breaks the citation.
### Where chunk ids live in `read_file` output
A KB document's XML carries three kinds of numeric id — only **one** is a citation source:
```
<document>
<document_metadata>
<document_id>42</document_id> ← NOT a citation. Parent doc id; ignore for citations.
...
</document_metadata>
<chunk_index>
<entry chunk_id="128" lines="14-22"/> ← Index hint; the same id also appears below.
<entry chunk_id="129" lines="23-30" matched="true"/>
</chunk_index>
<document_content>
<chunk id='128'><![CDATA[…]]></chunk> ← This is the citation source.
<chunk id='129'><![CDATA[…]]></chunk>
</document_content>
</document>
```
### Rules
- Use the **exact** id from a `<chunk id='…'>` tag whose content you actually quoted or paraphrased. Copy digit-for-digit; do **not** retype from memory.
- Before emitting `[citation:N]`, confirm the literal substring `<chunk id='N'>` (or its index twin `chunk_id="N"`) appears in the tool result you are summarising this turn. If you can't see it, omit the citation.
- Never cite `<document_id>` — that's the parent doc, not a chunk.
- Never invent, normalise, shorten, or guess at adjacent ids. If unsure between two candidates, omit rather than pick.
- Prefer **fewer accurate citations** over many speculative ones.
- Multiple chunks supporting the same point → comma-separated and copied individually: `[citation:128], [citation:129]`.
- Plain square brackets only — no markdown links, no parentheses, no footnote numbers.
- Tool results without `<chunk id='…'>` (write/edit/move confirmations, `ls` / `glob` / `grep` listings, error strings) carry no chunk id and need none.
- Populate `evidence.chunk_ids` with **only** ids you actually emitted in `[citation:…]` markers — same set, same digits.
## Examples
**Example 1 — happy path write (path discovered from existing convention):**
@ -118,5 +155,6 @@ Rules:
- `status=success` → `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- `evidence.content_excerpt`: max ~500 characters. Surface a short excerpt or a one-sentence summary, not the full file body. The supervisor already sees the tool's raw output.
Infer before you call; map every tool outcome faithfully.

View file

@ -35,6 +35,10 @@ Map outcomes to your `status`:
You construct the structured `evidence` fields from your own knowledge of what you called and what you observed — the tools do not return them. `chunk_ids` apply only to `<priority_documents>` hits; for local-file operations leave them `null`. Never report values you did not actually see.
## Chunk citations in your prose
In desktop mode your filesystem tools read local files only, and local-file tool results do **not** carry `<chunk id='…'>` tags. Do not emit `[citation:…]` markers in `action_summary` or `evidence.content_excerpt`, and leave `evidence.chunk_ids` `null` — the absolute path is the only reference for local-file work.
## Examples
**Example 1 — happy path write (path discovered from existing convention):**
@ -118,5 +122,6 @@ Rules:
- `status=success` → `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- `evidence.content_excerpt`: max ~500 characters. Surface a short excerpt or a one-sentence summary, not the full file body. The supervisor already sees the tool's raw output.
Infer before you call; map every tool outcome faithfully.

View file

@ -27,3 +27,42 @@ Reply in plain prose:
- Cite every claim with an absolute path under `/documents/`.
- If the workspace does not contain the requested information, say so explicitly. Do not fabricate paths or content.
- If the question is genuinely ambiguous after a thorough lookup, list the candidates with their paths and stop.
## Chunk citations
When the evidence for a claim came from a `read_file` response that included `<chunk id='…'>` blocks (i.e. a KB-indexed document under `/documents/`), append `[citation:<chunk_id>]` to the sentence stating that claim. The caller passes these markers through to the end user verbatim, and the UI resolves each id by exact match against the database, so a wrong id silently breaks the citation.
### Where chunk ids live in `read_file` output
A KB document's XML carries three kinds of numeric id — only **one** is a citation source:
```
<document>
<document_metadata>
<document_id>42</document_id> ← NOT a citation. Parent doc id; ignore for citations.
...
</document_metadata>
<chunk_index>
<entry chunk_id="128" lines="14-22"/> ← Index hint; the same id also appears below.
<entry chunk_id="129" lines="23-30" matched="true"/>
</chunk_index>
<document_content>
<chunk id='128'><![CDATA[…]]></chunk> ← This is the citation source.
<chunk id='129'><![CDATA[…]]></chunk>
</document_content>
</document>
```
### Rules
- Use the **exact** id from a `<chunk id='…'>` tag whose content you actually quoted or paraphrased. Copy digit-for-digit; do **not** retype from memory.
- Before emitting `[citation:N]`, confirm the literal substring `<chunk id='N'>` (or its index twin `chunk_id="N"`) appears in the tool result you are summarising this turn. If you can't see it, omit the citation.
- Never cite `<document_id>` — that's the parent doc, not a chunk.
- Never invent, normalise, shorten, or guess at adjacent ids. If unsure between two candidates, omit rather than pick.
- Prefer **fewer accurate citations** over many speculative ones. One correct `[citation:128]` is more useful than a string of wrong ids.
- Multiple chunks supporting the same point → comma-separated and copied individually: `[citation:128], [citation:129]`.
- Plain square brackets only — no markdown links, no parentheses, no footnote numbers.
- If a claim came from a tool result that did **not** carry a chunk id (`ls`, `glob`, `grep` listings, error strings, or files without `<chunk id='…'>`), skip the citation.
- The absolute path under `/documents/` is always required; chunk citations are additive, they do not replace the path reference.
Example: `The Q2 roadmap lists three milestones (/documents/planning/q2-roadmap.md) [citation:128], [citation:129].`

View file

@ -28,3 +28,7 @@ Reply in plain prose:
- Cite every claim with an absolute path.
- If the workspace does not contain the requested information, say so explicitly. Do not fabricate paths or content.
- If the question is genuinely ambiguous after a thorough lookup, list the candidates with their paths and stop.
## Chunk citations
In desktop mode your filesystem tools read local files only, and local-file `read_file` responses do **not** carry `<chunk id='…'>` tags. Cite each claim with the absolute local path; do not emit `[citation:…]` markers — your caller has nothing to resolve them against.

View file

@ -18,6 +18,10 @@ Persist durable preferences/facts/instructions with `update_memory` while avoidi
- Do not store transient chatter.
- Do not store secrets unless explicitly instructed.
- If memory intent is unclear, return `status=blocked` with the missing intent signal.
- Persisted memory is heading-based markdown. New saved bullets should look like
`- YYYY-MM-DD: text` under `##` headings. If existing memory has legacy
`(YYYY-MM-DD) [fact|pref|instr]` markers, preserve the information but write
the updated document in the heading-based format.
</tool_policy>
<out_of_scope>
@ -53,4 +57,7 @@ Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
- `evidence.memory_category` is a semantic classification for supervisor logs
only. It is not the persisted storage format and must not force inline
`[fact|preference|instruction]` markers into saved memory.
</output_contract>

View file

@ -1,280 +1,23 @@
"""Overwrite one markdown memory document per user or team, with size and shrink guards."""
"""Memory update tools backed by the canonical memory service."""
from __future__ import annotations
import logging
import re
from typing import Any, Literal
from typing import Any
from uuid import UUID
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User
from app.services.memory import (
MEMORY_HARD_LIMIT,
MEMORY_SOFT_LIMIT,
MemoryScope,
save_memory,
)
logger = logging.getLogger(__name__)
# Size thresholds, measured with ``len()`` on the memory document (characters).
# Above the soft limit a save still succeeds but carries a consolidation warning;
# above the hard limit the save is rejected.
MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000
# Captures the text of each ``## `` heading line (one match per line).
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
# Matches any run of whitespace; used to collapse it when normalizing headings.
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
# Captures the marker name inside ``[fact]`` / ``[pref]`` / ``[instr]`` tags.
_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
# Full-line shape of a well-formed bullet: ``- (YYYY-MM-DD) [marker] text``.
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
# Marker names that are rejected when they appear in team-scoped memory.
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
# ---------------------------------------------------------------------------
# Diff validation
# ---------------------------------------------------------------------------
def _extract_headings(memory: str) -> set[str]:
    """Collect the text of every ``## …`` heading line in *memory*.

    The ``## `` prefix itself is not included in the returned strings, and
    duplicate headings collapse into a single set member.
    """
    return {found.group(1) for found in _SECTION_HEADING_RE.finditer(memory)}
def _normalize_heading(heading: str) -> str:
    """Return *heading* trimmed, lower-cased, with whitespace runs squeezed to one space."""
    trimmed_lower = heading.strip().lower()
    return _HEADING_NORMALIZE_RE.sub(" ", trimmed_lower)
def _validate_memory_scope(
    content: str, scope: Literal["user", "team"]
) -> dict[str, Any] | None:
    """Check that team memory carries no personal-only markers.

    Only ``scope == "team"`` is checked; any other scope always passes.

    Returns:
        ``None`` when the content is acceptable, otherwise an error dict
        (``status`` / ``message``) naming the offending ``[pref]`` / ``[instr]``
        markers in sorted order.
    """
    if scope == "team":
        present = set(_MARKER_RE.findall(content))
        offending = sorted(present.intersection(_PERSONAL_ONLY_MARKERS))
        if offending:
            tags = ", ".join(f"[{m}]" for m in offending)
            message = (
                f"Team memory cannot include personal markers: {tags}. "
                "Use [fact] only in team memory."
            )
            return {"status": "error", "message": message}
    return None
def _validate_bullet_format(content: str) -> list[str]:
    """List a warning for every bullet line that breaks the required shape.

    A bullet is any line that, once stripped, starts with ``- ``. It must look
    like ``- (YYYY-MM-DD) [fact|pref|instr] text``. Non-bullet lines are ignored.
    Each warning quotes at most the first 80 characters of the bullet.
    """
    problems: list[str] = []
    for bullet in (raw.strip() for raw in content.splitlines()):
        if bullet.startswith("- ") and not _BULLET_FORMAT_RE.match(bullet):
            short = bullet[:80]
            if len(bullet) > 80:
                short += "..."
            problems.append(f"Malformed bullet: {short}")
    return problems
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
    """Compare the old and new documents and describe suspicious changes.

    Two checks are made, both skipped when there is no previous document:
    ``##`` headings that disappeared, and a new document shorter than 40% of
    the old one. The result is a (possibly empty) list of warning strings.
    """
    notes: list[str] = []
    if old_memory:
        # Headings present before but absent now.
        vanished = _extract_headings(old_memory).difference(
            _extract_headings(new_memory)
        )
        if vanished:
            names = ", ".join(sorted(vanished))
            notes.append(
                f"Sections removed: {names}. "
                "If unintentional, the user can restore from the settings page."
            )

        old_len, new_len = len(old_memory), len(new_memory)
        if old_len > 0 and new_len < old_len * 0.4:
            notes.append(
                f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
                "Possible data loss."
            )
    return notes
# ---------------------------------------------------------------------------
# Size validation & soft warning
# ---------------------------------------------------------------------------
def _validate_memory_size(content: str) -> dict[str, Any] | None:
    """Enforce the hard character limit on *content*.

    Returns:
        ``None`` when ``len(content)`` is within ``MEMORY_HARD_LIMIT``,
        otherwise an error dict telling the caller to consolidate and retry.
    """
    length = len(content)
    if length <= MEMORY_HARD_LIMIT:
        return None
    message = (
        f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
        f"({length:,} chars). Consolidate by merging related items, "
        "removing outdated entries, and shortening descriptions. "
        "Then call update_memory again."
    )
    return {"status": "error", "message": message}
def _soft_warning(content: str) -> str | None:
    """Produce a consolidation hint once *content* passes the soft limit.

    Returns ``None`` while ``len(content)`` is at or below ``MEMORY_SOFT_LIMIT``.
    """
    length = len(content)
    if length <= MEMORY_SOFT_LIMIT:
        return None
    return (
        f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
        "Consolidate by merging related items and removing less important "
        "entries on your next update."
    )
# ---------------------------------------------------------------------------
# Forced rewrite when memory exceeds the hard limit
# ---------------------------------------------------------------------------
# Prompt template for the forced-rewrite LLM call. It is filled with
# ``str.format`` and has two placeholders: ``{target}`` (the character budget)
# and ``{content}`` (the oversized memory document).
_FORCED_REWRITE_PROMPT = """\
You are a memory curator. The following memory document exceeds the character \
limit and must be shortened.
RULES:
1. Rewrite the document to be under {target} characters.
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
or rename headings to consolidate, but keep names personal and descriptive.
3. Priority for keeping content: [instr] > [pref] > [fact].
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
6. Preserve the user's first name in entries — do not replace it with "the user".
7. Output ONLY the consolidated markdown no explanations, no wrapping.
<memory_document>
{content}
</memory_document>"""
async def _forced_rewrite(content: str, llm: Any) -> str | None:
    """Ask *llm* to compress *content* so it fits under ``MEMORY_HARD_LIMIT``.

    Args:
        content: The oversized memory document.
        llm: Object exposing an async ``ainvoke(messages, config=...)`` whose
            result has a ``content`` attribute.

    Returns:
        The rewritten text with surrounding whitespace stripped, or ``None``
        when the LLM call raises or the model returns nothing but whitespace.
        The result is not checked against the limit here; the caller does that.

    Exceptions raised by the LLM call are logged and swallowed.
    """
    try:
        prompt = _FORCED_REWRITE_PROMPT.format(
            target=MEMORY_HARD_LIMIT, content=content
        )
        response = await llm.ainvoke(
            [HumanMessage(content=prompt)],
            # Tag the call as internal.
            config={"tags": ["surfsense:internal"]},
        )
        raw = response.content
        text = (raw if isinstance(raw, str) else str(raw)).strip()
    except Exception:
        logger.exception("Forced rewrite LLM call failed")
        return None

    # An empty reply is a failed rewrite, not a valid "shorter" document.
    # Returning "" here would let the caller persist an empty memory and
    # silently wipe everything the user had stored.
    if not text:
        logger.warning("Forced rewrite returned empty output; ignoring it")
        return None
    return text
# ---------------------------------------------------------------------------
# Shared save-and-respond logic
# ---------------------------------------------------------------------------
async def _save_memory(
    *,
    updated_memory: str,
    old_memory: str | None,
    llm: Any | None,
    apply_fn,
    commit_fn,
    rollback_fn,
    label: str,
    scope: Literal["user", "team"],
) -> dict[str, Any]:
    """Validate, optionally force-rewrite if over the hard limit, save, and
    return a response dict.

    Parameters
    ----------
    updated_memory : str
        The new document the agent submitted.
    old_memory : str | None
        The previously persisted document (for diff checks).
    llm : Any | None
        LLM instance for forced rewrite (may be ``None``, which disables it).
    apply_fn : callable(str) -> None
        Callback that sets the new memory on the ORM object.
    commit_fn : coroutine
        ``session.commit``.
    rollback_fn : coroutine
        ``session.rollback``.
    label : str
        Human label for log messages (e.g. "user memory", "team memory").
    scope : Literal["user", "team"]
        Memory scope; ``"team"`` enables the personal-marker check.

    Returns
    -------
    dict[str, Any]
        ``{"status": "error", "message": ...}`` when validation or persistence
        fails. Otherwise ``{"status": "saved", "message": ...}`` plus optional
        ``notice``, ``diff_warnings``, ``format_warnings`` and ``warning`` keys.
    """
    content = updated_memory
    was_rewritten = False

    # --- forced rewrite if over the hard limit ---
    if len(content) > MEMORY_HARD_LIMIT and llm is not None:
        rewritten = await _forced_rewrite(content, llm)
        # Accept the rewrite only if it is non-blank AND actually shorter.
        # A blank rewrite would otherwise pass the "shorter" test and wipe
        # the stored memory.
        if rewritten and rewritten.strip() and len(rewritten) < len(content):
            content = rewritten
            was_rewritten = True

    # --- hard-limit gate (reject if still too large after rewrite) ---
    size_err = _validate_memory_size(content)
    if size_err:
        return size_err

    scope_err = _validate_memory_scope(content, scope)
    if scope_err:
        return scope_err

    # --- persist ---
    try:
        apply_fn(content)
        await commit_fn()
    except Exception as e:
        logger.exception("Failed to update %s: %s", label, e)
        await rollback_fn()
        return {"status": "error", "message": f"Failed to update {label}: {e}"}

    # --- build response ---
    resp: dict[str, Any] = {
        "status": "saved",
        "message": f"{label.capitalize()} updated.",
    }

    # Explicit flag instead of comparing string identity with ``is not``.
    if was_rewritten:
        resp["notice"] = "Memory was automatically rewritten to fit within limits."

    diff_warnings = _validate_diff(old_memory, content)
    if diff_warnings:
        resp["diff_warnings"] = diff_warnings

    format_warnings = _validate_bullet_format(content)
    if format_warnings:
        resp["format_warnings"] = format_warnings

    warning = _soft_warning(content)
    if warning:
        resp["warning"] = warning

    return resp
# ---------------------------------------------------------------------------
# Tool factories
# ---------------------------------------------------------------------------
def create_update_memory_tool(
user_id: str | UUID,
@ -287,40 +30,22 @@ def create_update_memory_tool(
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the user's personal memory document.
Your current memory is shown in <user_memory> in the system prompt.
When the user shares important long-term information (preferences,
facts, instructions, context), rewrite the memory document to include
the new information. Merge new facts with existing ones, update
contradictions, remove outdated entries, and keep it concise.
Args:
updated_memory: The FULL updated markdown document (not a diff).
The current memory is shown in <user_memory>. Pass the FULL updated
markdown document, not a diff.
"""
try:
result = await db_session.execute(select(User).where(User.id == uid))
user = result.scalars().first()
if not user:
return {"status": "error", "message": "User not found."}
old_memory = user.memory_md
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
result = await save_memory(
scope=MemoryScope.USER,
target_id=uid,
content=updated_memory,
session=db_session,
llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="memory",
scope="user",
)
return result.to_dict()
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
await db_session.rollback()
return {
"status": "error",
"message": f"Failed to update memory: {e}",
}
return {"status": "error", "message": f"Failed to update memory: {e}"}
return update_memory
@ -334,36 +59,18 @@ def create_update_team_memory_tool(
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the team's shared memory document for this search space.
Your current team memory is shown in <team_memory> in the system
prompt. When the team shares important long-term information
(decisions, conventions, key facts, priorities), rewrite the memory
document to include the new information. Merge new facts with
existing ones, update contradictions, remove outdated entries, and
keep it concise.
Args:
updated_memory: The FULL updated markdown document (not a diff).
The current team memory is shown in <team_memory>. Pass the FULL updated
markdown document, not a diff.
"""
try:
result = await db_session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
space = result.scalars().first()
if not space:
return {"status": "error", "message": "Search space not found."}
old_memory = space.shared_memory_md
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
result = await save_memory(
scope=MemoryScope.TEAM,
target_id=search_space_id,
content=updated_memory,
session=db_session,
llm=llm,
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="team memory",
scope="team",
)
return result.to_dict()
except Exception as e:
logger.exception("Failed to update team memory: %s", e)
await db_session.rollback()
@ -373,3 +80,11 @@ def create_update_team_memory_tool(
}
return update_memory
# Names exported by ``from <this module> import *``: the two size limits and
# the two tool factories.
__all__ = [
"MEMORY_HARD_LIMIT",
"MEMORY_SOFT_LIMIT",
"create_update_memory_tool",
"create_update_team_memory_tool",
]

View file

@ -50,4 +50,6 @@ Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
- `evidence.findings`: max 10 entries, each a single sentence stating one distinct fact. Do not paste raw paragraphs, scraped pages, or quote blocks.
- `evidence.sources`: max 10 URLs, one per finding when applicable. List each URL once.
</output_contract>

View file

@ -38,7 +38,7 @@ Supervisor: "List open tasks in the Project Tracker base."
2. List tables in that base → identify the Tasks table; capture its table ID.
3. Get table schema → identify the status field and the choice IDs that represent "open" states.
4. List records with a typed filter on the status field for those choice IDs.
5. Return `status=success` with the matched records in `evidence.items`.
5. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched records listed in `action_summary` (record id, primary-field value, and 1-2 most relevant fields; one line per record; up to 10 entries, then `"...and N more"`).
</example>
<example>
@ -97,7 +97,7 @@ Rules:
- `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: base, table, field, choice, record, etc.).
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (record id, primary-field value, and 1-2 most relevant fields; up to 10 entries, then `"...and N more"`).
</output_contract>
Discover before you mutate; never guess identifiers, choice IDs, or required fields.

View file

@ -29,7 +29,7 @@ You are a Google Calendar specialist for the user's connected calendar.
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
| tool raises / unknown | `error` | `"Calendar tool failed unexpectedly. Ask the user to retry shortly."` |
Surface the tool's `event_id`, `title` / `summary`, `start_at`, `end_at`, and `html_link` inside `evidence` when the tool returned them. For `search_calendar_events`, place the raw `events` array inside `evidence.items`. Never invent a field the tool did not return.
Surface the tool's `event_id`, `title` / `summary`, `start_at`, `end_at`, and `html_link` inside `evidence` when the tool returned them. For `search_calendar_events`, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (title, date, start time; one line per event; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
## Examples
@ -115,7 +115,7 @@ Rules:
- `status=success` → `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For `search_calendar_events` results, populate `evidence.items` with `{ "events": [...], "total": N }`.
- For `search_calendar_events` results, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (title, date, start time; up to 10 entries, then `"...and N more"`).
- For ambiguous matches across `update_calendar_event` / `delete_calendar_event`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`, where `label` should include the event title and start time for human readability).
Infer before you call; map every tool outcome faithfully.

View file

@ -36,7 +36,7 @@ Failure handling:
<example>
Supervisor: "Find tasks about the homepage redesign."
1. Workspace search for "homepage redesign" → matched tasks.
2. Return `status=success` with the matched tasks in `evidence.items`.
2. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched tasks listed in `action_summary` (task id, title, status, assignees; one line per task; up to 10 entries, then `"...and N more"`).
</example>
<example>
@ -98,7 +98,7 @@ Rules:
- `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: task, list, member, status, custom-field choice, etc.).
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (task id, title, status, assignees; up to 10 entries, then `"...and N more"`).
</output_contract>
Discover before you mutate; never guess identifiers, list statuses, or assignees.

View file

@ -24,7 +24,7 @@ You are a Discord specialist for the user's connected Discord server.
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
| tool raises / unknown | `error` | `"Discord tool failed unexpectedly. Ask the user to retry shortly."` |
Surface the tool's `message`, `channel_id`, `message_id`, and the listed channels/messages payload inside `evidence` when the tool returned them. Never invent a field the tool did not return.
Surface the tool's `message`, `channel_id`, and `message_id` inside `evidence` when the tool returned them. For `list_discord_channels` and `read_discord_messages`, set `evidence.items` to `{ "total": N }` and list the matched entries in `action_summary` (channel name or sender + timestamp + short text snippet; one line per entry; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
## Examples

View file

@ -33,7 +33,7 @@ You are a Gmail specialist for the user's connected Gmail mailbox.
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
| tool raises / unknown | `error` | `"Gmail tool failed unexpectedly. Ask the user to retry shortly."` |
Surface the tool's `message_id`, `thread_id`, `draft_id`, `subject`, and recipient fields inside `evidence` when the tool returned them. For `search_gmail`, place the raw `emails` array inside `evidence.items`. Never invent a field the tool did not return.
Surface the tool's `message_id`, `thread_id`, `draft_id`, `subject`, and recipient fields inside `evidence` when the tool returned them. For `search_gmail`, set `evidence.items` to `{ "total": N }` and list the matched emails in `action_summary` (sender, subject, date; one line per email; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
## Examples
@ -114,7 +114,7 @@ Rules:
- `status=success` → `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For `search_gmail` results, populate `evidence.items` with `{ "emails": [...], "total": N }`.
- For `search_gmail` results, set `evidence.items` to `{ "total": N }` and list the matched emails in `action_summary` (sender, subject, date; up to 10 entries, then `"...and N more"`).
- For ambiguous matches across `update_gmail_draft` / `trash_gmail_email` / `read_gmail_email`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`).
Infer before you call; verify before you send; map every tool outcome faithfully.

View file

@ -39,7 +39,7 @@ Failure handling:
<example>
Supervisor: "Find issues assigned to me with status 'In Progress'."
1. JQL search with `assignee = currentUser() AND status = "In Progress"`.
2. Return `status=success` with the matched issues in `evidence.items`.
2. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched issues listed in `action_summary` (issue key, summary, status, assignee; one line per issue; up to 10 entries, then `"...and N more"`).
</example>
<example>
@ -116,7 +116,7 @@ Rules:
- `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: site, project, issue, user, transition, etc.).
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (issue key, summary, status, assignee; up to 10 entries, then `"...and N more"`).
</output_contract>
Discover before you mutate; never guess identifiers, transitions, or required fields.

View file

@ -32,7 +32,7 @@ Failure handling:
<example>
Supervisor: "Find issues assigned to me with priority Urgent."
1. Discovery: list issues with filters `{assignee: "me", priority: 1}`.
2. Return `status=success` with the matched issues in `evidence.items`.
2. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched issues listed in `action_summary` (identifier, title, state, assignee; one line per issue; up to 10 entries, then `"...and N more"`).
</example>
<example>
@ -106,7 +106,7 @@ Rules:
- `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: issue, user, project, state, etc.).
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (identifier, title, state, assignee; up to 10 entries, then `"...and N more"`).
</output_contract>
Discover before you mutate; never guess identifiers.

View file

@ -26,7 +26,7 @@ You are a Luma specialist for the user's connected Luma account.
| `error` | `error` | Relay the tool's `message` verbatim as `next_step` (this covers Luma Plus 403s and other API errors). |
| tool raises / unknown | `error` | `"Luma tool failed unexpectedly. Ask the user to retry shortly."` |
Surface the tool's `message`, `event_id`, `name`, `start_at`, and `url` inside `evidence` when the tool returned them. Never invent a field the tool did not return.
Surface the tool's `message`, `event_id`, `name`, `start_at`, and `url` inside `evidence` when the tool returned them. For `list_luma_events`, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (event name, start date/time, location if present; one line per event; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
## Examples

View file

@ -37,7 +37,7 @@ Failure handling:
Supervisor: "Summarize the latest discussion in #marketing."
1. Search channels for "marketing" → one strong match. Capture the channel ID.
2. Read that channel's recent message history.
3. Return `status=success` with the message list in `evidence.items`.
3. Return `status=success` with `evidence.items` set to `{ "total": N }` and the messages listed in `action_summary` (sender, timestamp, text snippet; one line per message; up to 10 entries, then `"...and N more"`).
</example>
<example>
@ -92,7 +92,7 @@ Rules:
- `status=partial|blocked|error` → `next_step` must be non-null.
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: channel, user, message, thread).
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (channel/user, key identifier, timestamp, short snippet; up to 10 entries, then `"...and N more"`).
</output_contract>
Discover before you post; never guess channel, user, or thread targets.

View file

@ -26,7 +26,7 @@ You are a Microsoft Teams specialist for the user's connected Teams account.
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
| tool raises / unknown | `error` | `"Teams tool failed unexpectedly. Ask the user to retry shortly."` |
Surface the tool's `message`, `team_id`, `team_name`, `channel_id`, `channel_name`, and `message_id` inside `evidence` when the tool returned them. Never invent a field the tool did not return.
Surface the tool's `message`, `team_id`, `team_name`, `channel_id`, `channel_name`, and `message_id` inside `evidence` when the tool returned them. For `list_teams_channels` and `read_teams_messages`, set `evidence.items` to `{ "total": N }` and list the matched entries in `action_summary` (team channel, or sender + timestamp + short text snippet; one line per entry; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
## Examples

View file

@ -102,6 +102,7 @@ from app.agents.new_chat.tools.registry import (
)
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
from app.services.llm_service import get_planner_llm
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
@ -1077,6 +1078,7 @@ def _build_compiled_agent_blocking(
else None,
KnowledgePriorityMiddleware(
llm=llm,
planner_llm=get_planner_llm(),
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
available_connectors=available_connectors,

View file

@ -1,232 +0,0 @@
"""Background memory extraction for the SurfSense agent.
After each agent response, if the agent did not call ``update_memory`` during
the turn, this module can run a lightweight LLM call to decide whether the
latest message contains long-term information worth persisting.
"""
from __future__ import annotations
import logging
from typing import Any
from uuid import UUID
from langchain_core.messages import HumanMessage
from sqlalchemy import select
from app.agents.new_chat.tools.update_memory import _save_memory
from app.db import SearchSpace, User, shielded_async_session
from app.utils.content_utils import extract_text_content
logger = logging.getLogger(__name__)
# Prompt template for personal-memory extraction. It is filled with
# ``str.format`` using ``user_name``, ``current_memory`` and ``user_message``.
# The model is told to reply with either the full updated memory document or
# the literal sentinel ``NO_UPDATE``.
_MEMORY_EXTRACT_PROMPT = """\
You are a memory extraction assistant. Analyze the user's message and decide \
if it contains any long-term information worth persisting to memory.
Worth remembering: preferences, background/identity, goals, projects, \
instructions, tools/languages they use, decisions, expertise, workplace \
durable facts that will matter in future conversations.
NOT worth remembering: greetings, one-off factual questions, session \
logistics, ephemeral requests, follow-up clarifications with no new personal \
info, things that only matter for the current task.
If the message contains memorizable information, output the FULL updated \
memory document with the new facts merged into the existing content. Follow \
these rules:
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
freely. Keep heading names short (2-3 words) and natural. Do NOT include the user's
name in headings.
- Keep entries as single bullet points. Be descriptive but concise include relevant
details and context rather than just a few words.
- Every bullet MUST use format: - (YYYY-MM-DD) [fact|pref|instr] text
[fact] = durable facts, [pref] = preferences, [instr] = standing instructions.
- Use the user's first name (from <user_name>) in entry text, not "the user".
- If a new fact contradicts an existing entry, update the existing entry.
- Do not duplicate information that is already present.
If nothing is worth remembering, output exactly: NO_UPDATE
<user_name>{user_name}</user_name>
<current_memory>
{current_memory}
</current_memory>
<user_message>
{user_message}
</user_message>"""
# Prompt template for team-memory extraction. It is filled with ``str.format``
# using ``current_memory``, ``author`` and ``user_message``. The model is told
# to reply with either the full updated team memory document or the literal
# sentinel ``NO_UPDATE``.
_TEAM_MEMORY_EXTRACT_PROMPT = """\
You are a team-memory extraction assistant. Analyze the latest message and \
decide if it contains durable TEAM-level information worth persisting.
Decision policy:
- Prioritize recall for durable team context, while avoiding personal-only facts.
- Do NOT require explicit consensus language. A direct team-level statement can
be stored if it is stable and broadly useful for future team chats.
- If evidence is weak or clearly tentative, output NO_UPDATE.
Worth remembering (team-level only):
- Decisions and defaults that guide future team work
- Team conventions/standards (naming, review policy, coding norms)
- Stable org/project facts (locations, ownership, constraints)
- Long-lived architecture/process facts
- Ongoing priorities that are likely relevant beyond this turn
NOT worth remembering:
- Personal preferences or biography of one person
- Questions, brainstorming, tentative ideas, or speculation
- One-off requests, status updates, TODOs, logistics for this session
- Information scoped only to a single ephemeral task
If the message contains memorizable team information, output the FULL updated \
team memory document with new facts merged into existing content. Follow rules:
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
freely. Keep heading names short (2-3 words) and natural.
- Keep entries as single bullet points. Be descriptive but concise include relevant
details and context rather than just a few words.
- Every bullet MUST use format: - (YYYY-MM-DD) [fact] text
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr].
- If a new fact contradicts an existing entry, update the existing entry.
- Do not duplicate existing information.
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
If nothing is worth remembering, output exactly: NO_UPDATE
<current_team_memory>
{current_memory}
</current_team_memory>
<latest_message_author>
{author}
</latest_message_author>
<latest_message>
{user_message}
</latest_message>"""
async def extract_and_save_memory(
    *,
    user_message: str,
    user_id: str | None,
    llm: Any,
) -> None:
    """Background task: mine *user_message* for durable facts and persist them.

    Meant to be launched fire-and-forget: every ``Exception`` is caught and
    logged here, and nothing is returned. The function does nothing when
    *user_id* is falsy, when the user row is missing, or when the model
    answers with an empty string or the ``NO_UPDATE`` sentinel.

    Args:
        user_message: Latest user message to analyse.
        user_id: Owner of the memory document; a string is parsed as a UUID.
        llm: Object exposing an async ``ainvoke`` used for the extraction call
            (and passed on to ``_save_memory``).
    """
    if not user_id:
        return

    try:
        uid = user_id if not isinstance(user_id, str) else UUID(user_id)

        async with shielded_async_session() as session:
            lookup = await session.execute(select(User).where(User.id == uid))
            user = lookup.scalars().first()
            if not user:
                return

            old_memory = user.memory_md

            # First whitespace-separated token of the display name, with a
            # generic fallback when the name is missing or blank.
            if user.display_name and user.display_name.strip():
                first_name = user.display_name.strip().split()[0]
            else:
                first_name = "The user"

            filled_prompt = _MEMORY_EXTRACT_PROMPT.format(
                current_memory=old_memory or "(empty)",
                user_message=user_message,
                user_name=first_name,
            )
            reply = await llm.ainvoke(
                [HumanMessage(content=filled_prompt)],
                config={"tags": ["surfsense:internal", "memory-extraction"]},
            )

            text = extract_text_content(reply.content).strip()
            if not text or text == "NO_UPDATE":
                logger.debug("Memory extraction: no update needed (user %s)", uid)
                return

            def _store(content: str) -> None:
                # Write the accepted document onto the loaded user row.
                setattr(user, "memory_md", content)

            outcome = await _save_memory(
                updated_memory=text,
                old_memory=old_memory,
                llm=llm,
                apply_fn=_store,
                commit_fn=session.commit,
                rollback_fn=session.rollback,
                label="memory",
                scope="user",
            )
            logger.info(
                "Background memory extraction for user %s: %s",
                uid,
                outcome.get("status"),
            )
    except Exception:
        logger.exception("Background user memory extraction failed")
async def extract_and_save_team_memory(
*,
user_message: str,
search_space_id: int | None,
llm: Any,
author_display_name: str | None = None,
) -> None:
"""Background task: extract team-level memory and persist it.
Runs only for shared threads. Designed to be fire-and-forget and catches
exceptions internally.
"""
if not search_space_id:
return
try:
async with shielded_async_session() as session:
result = await session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
space = result.scalars().first()
if not space:
return
old_memory = space.shared_memory_md
prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
current_memory=old_memory or "(empty)",
author=author_display_name or "Unknown team member",
user_message=user_message,
)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
)
text = extract_text_content(response.content).strip()
if text == "NO_UPDATE" or not text:
logger.debug(
"Team memory extraction: no update needed (space %s)",
search_space_id,
)
return
save_result = await _save_memory(
updated_memory=text,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
commit_fn=session.commit,
rollback_fn=session.rollback,
label="team memory",
scope="team",
)
logger.info(
"Background team memory extraction for space %s: %s",
search_space_id,
save_result.get("status"),
)
except Exception:
logger.exception("Background team memory extraction failed")

View file

@ -32,6 +32,7 @@ exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect).
from __future__ import annotations
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any
@ -249,11 +250,11 @@ async def _create_document(
session.add(doc)
await session.flush()
summary_embedding = embed_texts([content])[0]
summary_embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
doc.embedding = summary_embedding
chunks = chunk_text(content)
if chunks:
chunk_embeddings = embed_texts(chunks)
chunk_embeddings = await asyncio.to_thread(embed_texts, chunks)
session.add_all(
[
Chunk(document_id=doc.id, content=text, embedding=embedding)
@ -295,13 +296,13 @@ async def _update_document(
search_space_id,
)
summary_embedding = embed_texts([content])[0]
summary_embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
document.embedding = summary_embedding
await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
chunks = chunk_text(content)
if chunks:
chunk_embeddings = embed_texts(chunks)
chunk_embeddings = await asyncio.to_thread(embed_texts, chunks)
session.add_all(
[
Chunk(document_id=document.id, content=text, embedding=embedding)

View file

@ -457,7 +457,7 @@ async def search_knowledge_base(
if not query:
return []
[embedding] = embed_texts([query])
[embedding] = await asyncio.to_thread(embed_texts, [query])
doc_types = _resolve_search_types(available_connectors, available_document_types)
retriever_top_k = min(top_k * 3, 30)
@ -579,6 +579,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
self,
*,
llm: BaseChatModel | None = None,
planner_llm: BaseChatModel | None = None,
search_space_id: int,
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
available_connectors: list[str] | None = None,
@ -588,6 +589,15 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
inject_system_message: bool = True, # For backwards compatibility
) -> None:
self.llm = llm
# The planner LLM handles short, structured internal tasks (query
# rewriting, date extraction, recency classification). When an
# operator marks a global config ``is_planner: true`` we route
# those calls to a cheap/fast model (e.g. gpt-4o-mini, Haiku, Azure
# gpt-5.x-nano) instead of the user's chat LLM — those classification
# tasks don't need frontier-tier capability. Falls back to the chat
# LLM when no planner config is wired up so deployments without one
# keep working unchanged.
self.planner_llm = planner_llm or llm
self.search_space_id = search_space_id
self.filesystem_mode = filesystem_mode
self.available_connectors = available_connectors
@ -598,7 +608,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
# Build the kb-planner private Runnable ONCE here so we don't pay
# the ``create_agent`` compile cost (50-200ms) on every turn.
# Disabled by default behind ``enable_kb_planner_runnable``; when
# off the planner falls back to the legacy ``self.llm.ainvoke``
# off the planner falls back to the legacy ``planner_llm.ainvoke``
# path.
self._planner: Runnable | None = None
self._planner_compile_failed = False
@ -608,7 +618,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
Returns ``None`` when the feature flag is disabled, when the LLM is
unavailable, or when ``create_agent`` raises (we fall back to the
legacy ``self.llm.ainvoke`` path in that case). Compilation happens
legacy ``planner_llm.ainvoke`` path in that case). Compilation happens
lazily on first call, then memoized via ``self._planner``.
The compiled agent is constructed without tools — the planner's
@ -618,7 +628,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""
if self._planner is not None or self._planner_compile_failed:
return self._planner
if self.llm is None:
if self.planner_llm is None:
return None
flags = get_flags()
if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
@ -628,13 +638,13 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
try:
self._planner = create_agent(
self.llm,
self.planner_llm,
tools=[],
middleware=[RetryAfterMiddleware(max_retries=2)],
)
except Exception as exc: # pragma: no cover - defensive
logger.warning(
"kb-planner Runnable compile failed; falling back to llm.ainvoke: %s",
"kb-planner Runnable compile failed; falling back to planner_llm.ainvoke: %s",
exc,
)
self._planner_compile_failed = True
@ -647,12 +657,12 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
messages: Sequence[BaseMessage],
user_text: str,
) -> tuple[str, datetime | None, datetime | None, bool]:
if self.llm is None:
if self.planner_llm is None:
return user_text, None, None, False
recent_conversation = _render_recent_conversation(
messages,
llm=self.llm,
llm=self.planner_llm,
user_text=user_text,
)
prompt = _build_kb_planner_prompt(
@ -663,8 +673,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
t0 = loop.time()
# Prefer the compiled-once planner Runnable when enabled; otherwise
# fall back to ``self.llm.ainvoke``. The ``surfsense:internal`` tag
# is preserved on both paths so ``_stream_agent_events`` still
# fall back to ``planner_llm.ainvoke``. The ``surfsense:internal``
# tag is preserved on both paths so ``_stream_agent_events`` still
# suppresses the planner's intermediate events from the UI.
planner = self._build_kb_planner_runnable()
try:
@ -684,7 +694,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
else AIMessage(content="")
)
else:
response = await self.llm.ainvoke(
response = await self.planner_llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]},
)

View file

@ -24,6 +24,7 @@ from __future__ import annotations
import asyncio
import logging
import time
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
@ -41,6 +42,9 @@ from app.agents.new_chat.path_resolver import (
doc_to_virtual_path,
)
from app.db import Document, shielded_async_session
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
try:
from litellm import token_counter
@ -124,6 +128,7 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
if self.filesystem_mode != FilesystemMode.CLOUD:
return None
start = time.perf_counter()
update: dict[str, Any] = {}
if not state.get("cwd"):
update["cwd"] = DOCUMENTS_ROOT
@ -131,7 +136,11 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
anon_doc = state.get("kb_anon_doc")
if anon_doc:
tree_msg = self._render_anon_tree(anon_doc)
cache_outcome = "anon"
else:
version = int(state.get("tree_version") or 0)
cache_key = (self.search_space_id, version, False)
cache_outcome = "hit" if cache_key in self._cache else "miss"
tree_msg = await self._render_kb_tree(state)
update["workspace_tree_text"] = tree_msg
@ -141,6 +150,14 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
insert_at = max(len(messages) - 1, 0)
messages.insert(insert_at, SystemMessage(content=tree_msg))
update["messages"] = messages
_perf_log.info(
"[knowledge_tree] cache=%s chars=%d elapsed=%.3fs space=%d",
cache_outcome,
len(tree_msg),
time.perf_counter() - start,
self.search_space_id,
)
return update
def before_agent( # type: ignore[override]

View file

@ -8,6 +8,7 @@ Injects memory markdown into the system prompt on every turn:
from __future__ import annotations
import logging
import time
from typing import Any
from uuid import UUID
@ -17,10 +18,12 @@ from langgraph.runtime import Runtime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
from app.services.memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
@ -53,9 +56,13 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
if not isinstance(last_message, HumanMessage):
return None
start = time.perf_counter()
db_elapsed = 0.0
memory_blocks: list[str] = []
scope = "team" if self.visibility == ChatVisibility.SEARCH_SPACE else "user"
async with shielded_async_session() as session:
db_start = time.perf_counter()
if self.visibility == ChatVisibility.SEARCH_SPACE:
team_memory = await self._load_team_memory(session)
if team_memory:
@ -96,7 +103,15 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
f"</memory_warning>"
)
db_elapsed = time.perf_counter() - db_start
if not memory_blocks:
_perf_log.info(
"[memory_injection] scope=%s injected=0 db=%.3fs total=%.3fs",
scope,
db_elapsed,
time.perf_counter() - start,
)
return None
memory_text = "\n\n".join(memory_blocks)
@ -106,6 +121,13 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
insert_idx = 1 if len(new_messages) > 1 else 0
new_messages.insert(insert_idx, memory_msg)
_perf_log.info(
"[memory_injection] scope=%s injected=1 chars=%d db=%.3fs total=%.3fs",
scope,
len(memory_text),
db_elapsed,
time.perf_counter() - start,
)
return {"messages": new_messages}
async def _load_user_memory(

View file

@ -39,9 +39,19 @@ For OpenAI-family configs we additionally pass:
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` routing hint that
raises hit rate by sending requests with a shared prefix to the same
backend.
backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and
``azure/`` (added to LiteLLM's Azure transformer in
https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified
against ``AzureOpenAIConfig.get_supported_openai_params`` in our
installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``,
``azure/gpt-5.4``, ``azure/gpt-5.4-mini``).
- ``prompt_cache_retention="24h"`` extends cache TTL beyond the default
5-10 min in-memory cache.
5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI: Azure's
server-side support landed in Microsoft's docs on 2026-05-13 but
LiteLLM 1.83.14's Azure transformer still omits it from its supported
params list, so it gets silently dropped by ``litellm.drop_params``.
Azure's default in-memory retention (5-10 min, max 1 h) already
bridges intra-conversation turns; revisit when LiteLLM bumps Azure.
Safety net: ``litellm.drop_params=True`` is set globally in
``app.services.llm_service`` at module-load time. Any kwarg the destination
@ -81,13 +91,31 @@ _DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
{"location": "message", "index": -1},
)
# Providers (uppercase ``AgentConfig.provider`` values) that natively expose
# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and
# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers
# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without
# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU,
# MINIMAX), so we can't infer family from the litellm prefix alone.
_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"})
# Providers (uppercase ``AgentConfig.provider`` values) that accept the
# OpenAI ``prompt_cache_key`` routing hint. Microsoft's Azure OpenAI docs
# (2026-05-13) confirm automatic prompt caching applies to every GPT-4o
# or newer Azure deployment at ≥1024 tokens with no configuration needed,
# and that ``prompt_cache_key`` is combined with the prefix hash to
# improve routing affinity and therefore cache hit rate. LiteLLM's Azure
# transformer ships ``prompt_cache_key`` in its supported params as of
# https://github.com/BerriAI/litellm/pull/20989.
#
# Strict whitelist — many other providers in ``PROVIDER_MAP`` route
# through litellm's ``openai`` prefix without implementing the OpenAI
# prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so we can't infer
# family from the litellm prefix alone.
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset(
{"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"}
)
# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accept
# ``prompt_cache_retention="24h"``. Azure is excluded: see module
# docstring — LiteLLM 1.83.14's Azure transformer omits the param so
# ``drop_params`` silently strips it. Re-add Azure once a future LiteLLM
# release wires it into ``AzureOpenAIConfig.get_supported_openai_params``.
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
{"OPENAI", "DEEPSEEK", "XAI"}
)
def _is_router_llm(llm: BaseChatModel) -> bool:
@ -101,13 +129,13 @@ def _is_router_llm(llm: BaseChatModel) -> bool:
return type(llm).__name__ == "ChatLiteLLMRouter"
def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
"""Whether the config targets an OpenAI-style prompt-cache surface.
def _provider_supports_prompt_cache_key(agent_config: AgentConfig | None) -> bool:
"""Whether the config targets a provider that accepts ``prompt_cache_key``.
Strict only returns True when the user explicitly chose OPENAI,
DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` /
``YAMLConfig``. Auto-mode and custom providers return False because
we can't statically know the destination.
Strict — only returns True for explicitly chosen OPENAI, DEEPSEEK,
XAI, AZURE, or AZURE_OPENAI providers. Auto-mode and custom
providers return False because we can't statically know the
destination and the router fans out across mixed providers.
"""
if agent_config is None or not agent_config.provider:
return False
@ -115,7 +143,25 @@ def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
return False
if agent_config.custom_provider:
return False
return agent_config.provider.upper() in _OPENAI_FAMILY_PROVIDERS
return agent_config.provider.upper() in _PROMPT_CACHE_KEY_PROVIDERS
def _provider_supports_prompt_cache_retention(
    agent_config: AgentConfig | None,
) -> bool:
    """Report whether ``prompt_cache_retention`` may be set for this config.

    Returns ``False`` when there is no config, no provider, the config is in
    auto mode, or a custom provider is set. Otherwise the upper-cased
    provider is looked up in ``_PROMPT_CACHE_RETENTION_PROVIDERS``.
    """
    if agent_config is None:
        return False

    provider = agent_config.provider
    # Any of these means the destination cannot be trusted to accept the
    # kwarg, so the lookup below is skipped.
    if not provider or agent_config.is_auto_mode or agent_config.custom_provider:
        return False

    return provider.upper() in _PROMPT_CACHE_RETENTION_PROVIDERS
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
@ -173,16 +219,23 @@ def apply_litellm_prompt_caching(
dict(point) for point in _DEFAULT_INJECTION_POINTS
]
# OpenAI-family extras only when we statically know the destination is
# OpenAI / DeepSeek / xAI. Auto-mode router fans out across providers
# so we can't safely set OpenAI-only kwargs there (drop_params would
# strip them but it's wasteful to set them in the first place).
# OpenAI-style extras only when we statically know the destination
# accepts them. Auto-mode router fans out across mixed providers so
# we can't safely set destination-specific kwargs there (drop_params
# would strip them but it's wasteful to set them in the first
# place).
if _is_router_llm(llm):
return
if not _is_openai_family_config(agent_config):
return
if thread_id is not None and "prompt_cache_key" not in model_kwargs:
if (
thread_id is not None
and "prompt_cache_key" not in model_kwargs
and _provider_supports_prompt_cache_key(agent_config)
):
model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
if "prompt_cache_retention" not in model_kwargs:
if (
"prompt_cache_retention" not in model_kwargs
and _provider_supports_prompt_cache_retention(agent_config)
):
model_kwargs["prompt_cache_retention"] = "24h"

View file

@ -3,4 +3,10 @@ IMPORTANT — After understanding each user message, ALWAYS check: does this mes
reveal durable facts about the user (role, interests, preferences, projects,
background, or standing instructions)? If yes, you MUST call update_memory
alongside your normal response — do not defer this to a later turn.
Memory is stored as a heading-based markdown document. New entries should be
under `##` headings such as `## Facts`, `## Preferences`, or `## Instructions`
with bullets like `- YYYY-MM-DD: text`. If existing memory contains legacy
`(YYYY-MM-DD) [fact|pref|instr]` markers, preserve the information but write
new saves in the heading-based format.
</memory_protocol>

View file

@ -3,4 +3,12 @@ IMPORTANT — After understanding each user message, ALWAYS check: does this mes
reveal durable facts about the team (decisions, conventions, architecture, processes,
or key facts)? If yes, you MUST call update_memory alongside your normal response —
do not defer this to a later turn.
Team memory is stored as a heading-based markdown document. New entries should
be under `##` headings such as `## Product Decisions`,
`## Engineering Conventions`, `## Project Facts`, or `## Open Questions` with
bullets like `- YYYY-MM-DD: text`. If existing memory contains legacy
`(YYYY-MM-DD) [fact]` markers, preserve the information but write new saves in
the heading-based format. Do not create personal headings such as
`## Preferences` or `## Instructions`.
</memory_protocol>

View file

@ -1,16 +1,16 @@
- <user_name>Alex</user_name>, <user_memory> is empty. User: "I'm a space enthusiast, explain astrophage to me"
- The user casually shared a durable fact. Use their first name in the entry, short neutral heading:
update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n")
- The user casually shared a durable fact:
update_memory(updated_memory="## Facts\n- 2025-03-15: Alex is a space enthusiast\n")
- User: "Remember that I prefer concise answers over detailed explanations"
- Durable preference. Merge with existing memory, add a new heading:
update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n\n## Response style\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\n")
- Durable preference. Merge with existing memory:
update_memory(updated_memory="## Facts\n- 2025-03-15: Alex is a space enthusiast\n\n## Preferences\n- 2025-03-15: Alex prefers concise answers over detailed explanations\n")
- User: "I actually moved to Tokyo last month"
- Updated fact, date prefix reflects when recorded:
update_memory(updated_memory="## Interests & background\n...\n\n## Personal context\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\n...")
update_memory(updated_memory="## Facts\n- 2025-03-15: Alex lives in Tokyo (previously London)\n...")
- User: "I'm a freelance photographer working on a nature documentary"
- Durable background info under a fitting heading:
update_memory(updated_memory="...\n\n## Current focus\n- (2025-03-15) [fact] Alex is a freelance photographer\n- (2025-03-15) [fact] Alex is working on a nature documentary\n")
update_memory(updated_memory="...\n\n## Current Focus\n- 2025-03-15: Alex is a freelance photographer\n- 2025-03-15: Alex is working on a nature documentary\n")
- User: "Always respond in bullet points"
- Standing instruction:
update_memory(updated_memory="...\n\n## Response style\n- (2025-03-15) [instr] Always respond to Alex in bullet points\n")
update_memory(updated_memory="...\n\n## Instructions\n- 2025-03-15: Always respond to Alex in bullet points\n")

View file

@ -1,7 +1,7 @@
- User: "Let's remember that we decided to do weekly standup meetings on Mondays"
- Durable team decision:
update_memory(updated_memory="- (2025-03-15) [fact] Weekly standup meetings on Mondays\n...")
update_memory(updated_memory="## Product Decisions\n- 2025-03-15: Weekly standup meetings happen on Mondays\n...")
- User: "Our office is in downtown Seattle, 5th floor"
- Durable team fact:
update_memory(updated_memory="- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\n...")
update_memory(updated_memory="## Project Facts\n- 2025-03-15: Office location is downtown Seattle, 5th floor\n...")

View file

@ -1,31 +1,26 @@
- update_memory: Update your personal memory document about the user.
- Your current memory is already in <user_memory> in your context. The `chars` and
`limit` attributes show your current usage and the maximum allowed size.
- This is your curated long-term memory — the distilled essence of what you know about
the user, not raw conversation logs.
- Call update_memory when:
* The user explicitly asks to remember or forget something
* The user shares durable facts or preferences that will matter in future conversations
- The user's first name is provided in <user_name>. Use it in memory entries
instead of "the user" (e.g. "{name} works at..." not "The user works at...").
Do not store the name itself as a separate memory entry.
- Do not store short-lived or ephemeral info: one-off questions, greetings,
session logistics, or things that only matter for the current task.
- Your current memory is already in <user_memory> in your context. The `chars`
and `limit` attributes show current usage and the maximum allowed size.
- This is curated long-term memory, not raw conversation logs.
- Call update_memory when the user explicitly asks to remember/forget
something or shares durable facts, preferences, or standing instructions.
- The user's first name is provided in <user_name>. Use it in entries instead
of "the user" when helpful. Do not store the name alone as a memory entry.
- Do not store short-lived info: one-off questions, greetings, session
logistics, or things that only matter for the current task.
- Args:
- updated_memory: The FULL updated markdown document (not a diff).
Merge new facts with existing ones, update contradictions, remove outdated entries.
Treat every update as a curation pass — consolidate, don't just append.
- Every bullet MUST use this format: - (YYYY-MM-DD) [marker] text
Markers:
[fact] — durable facts (role, background, projects, tools, expertise)
[pref] — preferences (response style, languages, formats, tools)
[instr] — standing instructions (always/never do, response rules)
- Keep it concise and well under the character limit shown in <user_memory>.
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
natural. Do NOT include the user's name in headings. Organize by context — e.g.
who they are, what they're focused on, how they prefer things. Create, split, or
merge headings freely as the memory grows.
- Each entry MUST be a single bullet point. Be descriptive but concise — include relevant
details and context rather than just a few words.
- During consolidation, prioritize keeping: [instr] > [pref] > [fact].
- updated_memory: The FULL updated markdown document, not a diff. Merge new
facts with existing ones, update contradictions, remove outdated entries,
and consolidate instead of only appending.
- Use heading-based Markdown:
* Every entry must be under a `##` heading.
* Recommended headings: `## Facts`, `## Preferences`, `## Instructions`.
Specific natural headings are allowed when clearer.
* New bullets should use `- YYYY-MM-DD: text`.
* Each entry should be one concise but descriptive bullet.
- If existing memory uses legacy `(YYYY-MM-DD) [fact|pref|instr]` markers,
preserve the information but write the updated document in the new
heading-based format.
- During consolidation, prioritize durable instructions and preferences before
generic facts.

View file

@ -1,26 +1,28 @@
- update_memory: Update the team's shared memory document for this search space.
- Your current team memory is already in <team_memory> in your context. The `chars`
and `limit` attributes show current usage and the maximum allowed size.
- This is the team's curated long-term memory — decisions, conventions, key facts.
- NEVER store personal memory in team memory (e.g. personal bio, individual
preferences, or user-only standing instructions).
- Call update_memory when:
* A team member explicitly asks to remember or forget something
* The conversation surfaces durable team decisions, conventions, or facts
that will matter in future conversations
- Do not store short-lived or ephemeral info: one-off questions, greetings,
session logistics, or things that only matter for the current task.
- Your current team memory is already in <team_memory> in your context. The
`chars` and `limit` attributes show current usage and the maximum allowed size.
- This is curated long-term team memory: decisions, conventions, architecture,
processes, and key shared facts.
- NEVER store personal memory in team memory: individual bios, personal
preferences, or user-only standing instructions.
- Call update_memory when a team member asks to remember/forget something, or
when the conversation surfaces durable team context that matters later.
- Do not store short-lived info: one-off questions, greetings, session
logistics, or things that only matter for the current task.
- Args:
- updated_memory: The FULL updated markdown document (not a diff).
Merge new facts with existing ones, update contradictions, remove outdated entries.
Treat every update as a curation pass — consolidate, don't just append.
- Every bullet MUST use this format: - (YYYY-MM-DD) [fact] text
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr] in team memory.
- Keep it concise and well under the character limit shown in <team_memory>.
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
natural. Organize by context — e.g. what the team decided, current architecture,
active processes. Create, split, or merge headings freely as the memory grows.
- Each entry MUST be a single bullet point. Be descriptive but concise — include relevant
details and context rather than just a few words.
- During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities.
- updated_memory: The FULL updated markdown document, not a diff. Merge new
facts with existing ones, update contradictions, remove outdated entries,
and consolidate instead of only appending.
- Use heading-based Markdown:
* Every entry must be under a `##` heading.
* Recommended headings: `## Product Decisions`, `## Engineering Conventions`,
`## Project Facts`, `## Open Questions`.
* New bullets should use `- YYYY-MM-DD: text`.
* Each entry should be one concise but descriptive bullet.
- If existing memory uses legacy `(YYYY-MM-DD) [fact]` markers, preserve the
information but write the updated document in the new heading-based format.
- Do not create personal headings such as `## Preferences`, `## Instructions`,
`## Personal Notes`, or `## Personal Instructions`.
- During consolidation, prioritize decisions/conventions, then key facts, then
current priorities.

View file

@ -36,8 +36,16 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
from app.agents.new_chat.tools.hitl import request_approval
from app.agents.new_chat.tools.mcp_client import MCPClient
from app.agents.new_chat.tools.mcp_tools_cache import (
CachedMCPTools,
read_cached_tools,
write_cached_tools,
)
from app.db import SearchSourceConnector
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
logger = logging.getLogger(__name__)
@ -293,15 +301,21 @@ async def _create_mcp_tool_from_definition_http(
timeout: float = 60.0,
) -> str:
"""Execute a single MCP HTTP call with the given headers."""
call_start = time.perf_counter()
async with (
streamablehttp_client(url, headers=call_headers) as (read, write, _),
ClientSession(read, write) as session,
):
init_start = time.perf_counter()
await session.initialize()
init_elapsed = time.perf_counter() - init_start
tool_start = time.perf_counter()
response = await asyncio.wait_for(
session.call_tool(original_tool_name, arguments=call_kwargs),
timeout=timeout,
)
tool_elapsed = time.perf_counter() - tool_start
result = []
for content in response.content:
@ -312,7 +326,18 @@ async def _create_mcp_tool_from_definition_http(
else:
result.append(str(content))
return "\n".join(result) if result else ""
payload = "\n".join(result) if result else ""
_perf_log.info(
"[mcp_http_call] connector=%s tool=%s init=%.3fs call=%.3fs total=%.3fs out_chars=%d",
connector_id,
original_tool_name,
init_elapsed,
tool_elapsed,
time.perf_counter() - call_start,
len(payload),
)
return payload
async def mcp_http_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via HTTP transport."""
@ -496,6 +521,7 @@ async def _load_http_mcp_tools(
is_generic_mcp: bool = False,
*,
bypass_internal_hitl: bool = False,
cached_tools: CachedMCPTools | None = None,
) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server.
@ -506,6 +532,8 @@ async def _load_http_mcp_tools(
readonly_tools: Tool names that skip HITL approval (read-only operations).
tool_name_prefix: If set, each tool name is prefixed for multi-account
disambiguation (e.g. ``linear_25``).
cached_tools: If provided, skip live discovery and rebuild wrappers
from the persisted definitions.
"""
tools: list[StructuredTool] = []
@ -529,15 +557,23 @@ async def _load_http_mcp_tools(
allowed_set = set(allowed_tools) if allowed_tools else None
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]:
"""Connect, initialize, and list tools from the MCP server."""
async def _discover(
disc_headers: dict[str, str],
) -> tuple[dict[str, str | None], list[dict[str, Any]]]:
"""Connect, initialize, and list tools — returns (serverInfo, tools)."""
async with (
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
ClientSession(read, write) as session,
):
await session.initialize()
init_result = await session.initialize()
server_info: dict[str, str | None] = {"name": None, "version": None}
si = getattr(init_result, "serverInfo", None)
if si is not None:
server_info["name"] = getattr(si, "name", None)
server_info["version"] = getattr(si, "version", None)
response = await session.list_tools()
return [
return server_info, [
{
"name": tool.name,
"description": tool.description or "",
@ -548,47 +584,65 @@ async def _load_http_mcp_tools(
for tool in response.tools
]
try:
tool_definitions = await _discover(headers)
except Exception as first_err:
if not _is_auth_error(first_err) or connector_id is None:
logger.exception(
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
url,
connector_id,
first_err,
)
return tools
logger.warning(
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
connector_id,
)
fresh_headers = await _force_refresh_and_get_headers(connector_id)
if fresh_headers is None:
await _mark_connector_auth_expired(connector_id)
logger.error(
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
connector_id,
)
return tools
if cached_tools is not None:
tool_definitions = [
{
"name": td.name,
"description": td.description,
"input_schema": td.input_schema,
}
for td in cached_tools.tools
]
else:
try:
tool_definitions = await _discover(fresh_headers)
headers = fresh_headers
logger.info(
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
server_info, tool_definitions = await _discover(headers)
except Exception as first_err:
if not _is_auth_error(first_err) or connector_id is None:
logger.exception(
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
url,
connector_id,
first_err,
)
return tools
logger.warning(
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
connector_id,
)
except Exception as retry_err:
logger.exception(
"HTTP MCP discovery for connector %d still failing after refresh: %s",
connector_id,
retry_err,
)
if _is_auth_error(retry_err):
fresh_headers = await _force_refresh_and_get_headers(connector_id)
if fresh_headers is None:
await _mark_connector_auth_expired(connector_id)
return tools
logger.error(
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
connector_id,
)
return tools
try:
server_info, tool_definitions = await _discover(fresh_headers)
headers = fresh_headers
logger.info(
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
connector_id,
)
except Exception as retry_err:
logger.exception(
"HTTP MCP discovery for connector %d still failing after refresh: %s",
connector_id,
retry_err,
)
if _is_auth_error(retry_err):
await _mark_connector_auth_expired(connector_id)
return tools
await write_cached_tools(
connector_id,
tool_definitions,
server_name=server_info.get("name"),
server_version=server_info.get("version"),
transport=server_config.get("transport", "streamable-http"),
)
total_discovered = len(tool_definitions)
@ -792,14 +846,25 @@ async def _maybe_refresh_mcp_oauth_token(
except (ValueError, TypeError):
return server_config
refresh_start = time.perf_counter()
try:
new_access = await _refresh_connector_token(session, connector)
if not new_access:
_perf_log.info(
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=no_token",
connector.id,
time.perf_counter() - refresh_start,
)
return server_config
logger.info(
"Proactively refreshed MCP OAuth token for connector %s", connector.id
)
_perf_log.info(
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=refreshed",
connector.id,
time.perf_counter() - refresh_start,
)
refreshed_config = dict(server_config)
refreshed_config["headers"] = {
@ -809,6 +874,11 @@ async def _maybe_refresh_mcp_oauth_token(
return refreshed_config
except Exception:
_perf_log.info(
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=failed",
connector.id,
time.perf_counter() - refresh_start,
)
logger.warning(
"Failed to refresh MCP OAuth token for connector %s",
connector.id,
@ -937,6 +1007,94 @@ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
_mcp_tools_cache.clear()
async def discover_single_mcp_connector(connector_id: int) -> None:
    """Force live MCP discovery for one connector so its ``cached_tools`` row is fresh.

    ``_load_http_mcp_tools`` persists ``cached_tools`` as a side effect of any
    live discovery; passing ``cached_tools=None`` here guarantees we go to the
    network. The returned wrappers are discarded; the in-process LRU is
    rebuilt lazily on the next user query. Stdio connectors are not cached and
    are skipped.

    Args:
        connector_id: Primary key of the ``SearchSourceConnector`` to refresh.

    Returns:
        None. A missing connector, an absent or non-dict ``server_config``, a
        non-HTTP transport, or unavailable OAuth headers all return early.
        Timeouts and any other exception are logged as warnings and swallowed.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import
    # with ``app.db`` — confirm.
    from app.db import async_session_maker

    started = time.perf_counter()

    try:
        async with async_session_maker() as session:
            connector = await session.get(SearchSourceConnector, connector_id)
            if connector is None:
                logger.info(
                    "discover_single_mcp_connector: connector %d not found",
                    connector_id,
                )
                return

            cfg = connector.config or {}
            server_config = cfg.get("server_config", {})
            if not server_config or not isinstance(server_config, dict):
                return

            # Only HTTP-style transports are discovered here; a missing
            # transport defaults to "stdio" and is therefore skipped.
            transport = server_config.get("transport", "stdio")
            if transport not in ("streamable-http", "http", "sse"):
                return

            if cfg.get("mcp_oauth"):
                server_config = await _maybe_refresh_mcp_oauth_token(
                    session, connector, cfg, server_config
                )
                # Re-read the config after the refresh attempt before the
                # OAuth headers are injected.
                cfg = connector.config or {}
                server_config = _inject_oauth_headers(cfg, server_config)
                if server_config is None:
                    logger.info(
                        "discover_single_mcp_connector: OAuth token unavailable for connector %d",
                        connector_id,
                    )
                    return

            # ``connector_type`` may be an enum (use ``.value``) or a plain value.
            ct = (
                connector.connector_type.value
                if hasattr(connector.connector_type, "value")
                else str(connector.connector_type)
            )
            svc_cfg = get_service_by_connector_type(ct)
            # No registered service config => treated as a generic MCP server.
            allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
            readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()

            # The result is intentionally not bound: only the persistence side
            # effect of live discovery is wanted here.
            await asyncio.wait_for(
                _load_http_mcp_tools(
                    connector.id,
                    connector.name,
                    server_config,
                    trusted_tools=cfg.get("trusted_tools", []),
                    allowed_tools=allowed_tools,
                    readonly_tools=readonly_tools,
                    tool_name_prefix=None,
                    is_generic_mcp=svc_cfg is None,
                    bypass_internal_hitl=True,
                    cached_tools=None,
                ),
                timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
            )
            _perf_log.info(
                "[mcp_prefetch] connector=%s elapsed=%.3fs",
                connector_id,
                time.perf_counter() - started,
            )
    except TimeoutError:
        logger.warning(
            "discover_single_mcp_connector: connector %d timed out after %ds",
            connector_id,
            _MCP_DISCOVERY_TIMEOUT_SECONDS,
        )
    except Exception:
        logger.warning(
            "discover_single_mcp_connector: failed for connector %d",
            connector_id,
            exc_info=True,
        )
async def load_mcp_tools(
session: AsyncSession,
search_space_id: int,
@ -1063,6 +1221,7 @@ async def load_mcp_tools(
"tool_name_prefix": tool_name_prefix,
"transport": server_config.get("transport", "stdio"),
"is_generic_mcp": svc_cfg is None,
"cached_tools": read_cached_tools(connector),
}
)
@ -1074,9 +1233,12 @@ async def load_mcp_tools(
)
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
discover_start = time.perf_counter()
transport = task["transport"]
cached_tools = task.get("cached_tools")
try:
if task["transport"] in ("streamable-http", "http", "sse"):
return await asyncio.wait_for(
if transport in ("streamable-http", "http", "sse"):
result = await asyncio.wait_for(
_load_http_mcp_tools(
task["connector_id"],
task["connector_name"],
@ -1087,11 +1249,12 @@ async def load_mcp_tools(
tool_name_prefix=task["tool_name_prefix"],
is_generic_mcp=task.get("is_generic_mcp", False),
bypass_internal_hitl=bypass_internal_hitl,
cached_tools=cached_tools,
),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
)
else:
return await asyncio.wait_for(
result = await asyncio.wait_for(
_load_stdio_mcp_tools(
task["connector_id"],
task["connector_name"],
@ -1101,7 +1264,24 @@ async def load_mcp_tools(
),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
)
_perf_log.info(
"[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs cache=%s",
task["connector_id"],
task["connector_name"],
transport,
len(result),
time.perf_counter() - discover_start,
"hit" if cached_tools is not None else "miss",
)
return result
except TimeoutError:
_perf_log.info(
"[mcp_discover] connector=%s name=%r transport=%s elapsed=%.3fs outcome=timeout",
task["connector_id"],
task["connector_name"],
transport,
time.perf_counter() - discover_start,
)
logger.error(
"MCP connector %d timed out after %ds during discovery",
task["connector_id"],
@ -1109,6 +1289,13 @@ async def load_mcp_tools(
)
return []
except Exception as e:
_perf_log.info(
"[mcp_discover] connector=%s name=%r transport=%s elapsed=%.3fs outcome=error",
task["connector_id"],
task["connector_name"],
transport,
time.perf_counter() - discover_start,
)
logger.exception(
"Failed to load tools from MCP connector %d: %s",
task["connector_id"],
@ -1116,7 +1303,14 @@ async def load_mcp_tools(
)
return []
gather_start = time.perf_counter()
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
_perf_log.info(
"[mcp_discover] gather_wall=%.3fs connectors=%d total_tools=%d",
time.perf_counter() - gather_start,
len(discovery_tasks),
sum(len(r) for r in results),
)
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
_mcp_tools_cache[cache_key] = (now, tools)

View file

@ -0,0 +1,145 @@
"""Persist MCP ``list_tools`` results in ``SearchSourceConnector.config.cached_tools``."""
from __future__ import annotations
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import SearchSourceConnector, async_session_maker
logger = logging.getLogger(__name__)
_pending_prefetch_tasks: set[asyncio.Task[None]] = set()
class CachedMCPToolDef(BaseModel):
    """A single MCP tool definition as stored inside ``cached_tools``."""

    # Tool name; the only required field.
    name: str
    # Tool description; defaults to the empty string.
    description: str = ""
    # Schema for the tool's arguments as a plain dict; a fresh empty dict is
    # created per instance when omitted.
    input_schema: dict[str, Any] = Field(default_factory=dict)
class CachedMCPTools(BaseModel):
    """Snapshot of one connector's discovered tools plus discovery metadata.

    This is the shape validated by ``read_cached_tools`` and produced by
    ``write_cached_tools`` under the ``cached_tools`` key of a connector's
    config.
    """

    # When the snapshot was taken (``write_cached_tools`` uses the current UTC time).
    discovered_at: datetime
    # Optional server version recorded alongside the tools.
    server_version: str | None = None
    # Optional server name recorded alongside the tools.
    server_name: str | None = None
    # Optional transport label recorded alongside the tools.
    transport: str | None = None
    # The cached tool definitions themselves (required).
    tools: list[CachedMCPToolDef]
def read_cached_tools(connector: SearchSourceConnector) -> CachedMCPTools | None:
    """Parse ``connector.config["cached_tools"]`` into a ``CachedMCPTools``.

    Returns ``None`` when the entry is missing, empty, or not a dict, and also
    when it fails model validation (a warning is logged in that case) so the
    caller can fall back to live discovery.
    """
    stored = (connector.config or {}).get("cached_tools")
    if not (stored and isinstance(stored, dict)):
        return None

    try:
        parsed = CachedMCPTools.model_validate(stored)
    except ValidationError as exc:
        logger.warning(
            "MCP connector %d has corrupt cached_tools — falling back to live discovery: %s",
            connector.id,
            exc,
        )
        return None
    return parsed
async def write_cached_tools(
    connector_id: int,
    tool_definitions: list[dict[str, Any]],
    *,
    server_name: str | None = None,
    server_version: str | None = None,
    transport: str | None = None,
) -> None:
    """Best-effort persist of discovered tools into ``config["cached_tools"]``.

    Uses its own session so a write failure cannot poison the caller's
    transaction. Any failure, including a tool definition that does not
    validate, is logged as a warning and swallowed; this function never raises
    into the discovery path.

    Args:
        connector_id: Primary key of the connector row to update.
        tool_definitions: Dicts validated into ``CachedMCPToolDef`` entries.
        server_name: Optional server name stored with the snapshot.
        server_version: Optional server version stored with the snapshot.
        transport: Optional transport label stored with the snapshot.

    Returns:
        None. If the connector row does not exist, nothing is written.
    """
    try:
        # Build the payload inside the guarded region: a malformed definition
        # must degrade to "cache not written", not abort tool loading.
        payload = CachedMCPTools(
            discovered_at=datetime.now(UTC),
            server_version=server_version,
            server_name=server_name,
            transport=transport,
            tools=[CachedMCPToolDef.model_validate(td) for td in tool_definitions],
        )
        async with async_session_maker() as session:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == connector_id,
                )
            )
            connector = result.scalars().first()
            if connector is None:
                return

            # Work on a copy and reassign, then flag the attribute so the
            # change to the config column is picked up on commit.
            cfg = dict(connector.config or {})
            cfg["cached_tools"] = payload.model_dump(mode="json")
            connector.config = cfg
            flag_modified(connector, "config")
            await session.commit()
            logger.info(
                "Persisted cached_tools for MCP connector %d (%d tools)",
                connector_id,
                len(payload.tools),
            )
    except Exception:
        logger.warning(
            "Failed to persist cached_tools for MCP connector %d",
            connector_id,
            exc_info=True,
        )
def refresh_mcp_tools_cache_for_connector(
    connector_id: int,
    search_space_id: int,
) -> None:
    """Evict the space's in-process MCP tool cache, then schedule a prefetch.

    First calls ``invalidate_mcp_tools_cache`` for ``search_space_id``; any
    failure there is logged at debug level and ignored. Then, only if an event
    loop is currently running, starts a background task that runs live
    discovery for ``connector_id``. Without a running loop the prefetch is
    skipped silently.
    """
    # Eviction: failures are deliberately non-fatal.
    try:
        from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache

        invalidate_mcp_tools_cache(search_space_id)
    except Exception:
        logger.debug(
            "MCP in-process cache eviction skipped for space %d",
            search_space_id,
            exc_info=True,
        )

    # Prefetch: needs a running loop to host the background task.
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        return
    else:
        prefetch = running_loop.create_task(_run_connector_prefetch(connector_id))
        # Hold a strong reference until the task finishes, then drop it.
        _pending_prefetch_tasks.add(prefetch)
        prefetch.add_done_callback(_pending_prefetch_tasks.discard)
async def _run_connector_prefetch(connector_id: int) -> None:
    """Background-task body: run live discovery for one connector.

    Every ordinary failure is logged as a warning and swallowed. That includes
    a failure of the lazy import, which is why the import sits inside the
    ``try``: otherwise it would surface only as an unretrieved task exception.

    Args:
        connector_id: Primary key of the connector to prefetch.
    """
    try:
        from app.agents.new_chat.tools.mcp_tool import discover_single_mcp_connector

        await discover_single_mcp_connector(connector_id)
    except Exception:
        logger.warning(
            "MCP background prefetch failed for connector_id=%d",
            connector_id,
            exc_info=True,
        )

View file

@ -1,369 +1,53 @@
"""Markdown-document memory tool for the SurfSense agent.
Replaces the old row-per-fact save_memory / recall_memory tools with a single
update_memory tool that overwrites a freeform markdown TEXT column. The LLM
always sees the current memory in <user_memory> / <team_memory> tags injected
by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
Overflow handling:
- Soft limit (18K chars): a warning is returned telling the agent to
consolidate on the next update.
- Hard limit (25K chars): a forced LLM-driven rewrite compresses the document.
If it still exceeds the limit after rewriting, the save is rejected.
- Diff validation: warns when entire ``##`` sections are dropped or when the
document shrinks by more than 60%.
"""
"""Memory update tools backed by the canonical memory service."""
from __future__ import annotations
import logging
import re
from typing import Any, Literal
from typing import Any
from uuid import UUID
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User, async_session_maker
from app.utils.content_utils import extract_text_content
from app.db import async_session_maker
from app.services.memory import MemoryScope, save_memory
logger = logging.getLogger(__name__)
MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
# ---------------------------------------------------------------------------
# Diff validation
# ---------------------------------------------------------------------------
def _extract_headings(memory: str) -> set[str]:
    """Collect the captured text of every ``_SECTION_HEADING_RE`` match in *memory*."""
    return {found.group(1) for found in _SECTION_HEADING_RE.finditer(memory)}
def _normalize_heading(heading: str) -> str:
    """Trim and lower-case *heading*, then replace ``_HEADING_NORMALIZE_RE`` matches with one space."""
    trimmed_lower = heading.strip().lower()
    return _HEADING_NORMALIZE_RE.sub(" ", trimmed_lower)
def _validate_memory_scope(
    content: str, scope: Literal["user", "team"]
) -> dict[str, Any] | None:
    """Return an error dict if team-scoped *content* carries personal-only markers.

    Only the ``"team"`` scope is checked. Markers found by ``_MARKER_RE`` that
    are also in ``_PERSONAL_ONLY_MARKERS`` are reported, sorted, in the error
    message. Returns ``None`` when nothing is wrong or the scope is not team.
    """
    if scope == "team":
        present = set(_MARKER_RE.findall(content))
        leaked = sorted(present & _PERSONAL_ONLY_MARKERS)
        if leaked:
            tags = ", ".join(f"[{m}]" for m in leaked)
            return {
                "status": "error",
                "message": (
                    f"Team memory cannot include personal markers: {tags}. "
                    "Use [fact] only in team memory."
                ),
            }
    return None
def _validate_bullet_format(content: str) -> list[str]:
    """List a warning for each bullet line that fails ``_BULLET_FORMAT_RE``.

    Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``

    Lines are stripped first; only those starting with ``"- "`` are checked.
    Each warning quotes at most the first 80 characters of the bullet, with
    ``...`` appended when it was longer.
    """

    def _preview(bullet: str) -> str:
        # Truncate long bullets so warnings stay short.
        suffix = "..." if len(bullet) > 80 else ""
        return bullet[:80] + suffix

    bullets = (
        candidate
        for candidate in (raw.strip() for raw in content.splitlines())
        if candidate.startswith("- ")
    )
    return [
        f"Malformed bullet: {_preview(bullet)}"
        for bullet in bullets
        if not _BULLET_FORMAT_RE.match(bullet)
    ]
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
    """Compare old and new memory and describe suspicious changes.

    Produces up to two warnings: one naming headings present in the old
    document but absent from the new one, and one when the new document is
    shorter than 40% of the old length. Returns an empty list when there is no
    previous memory.
    """
    findings: list[str] = []
    if not old_memory:
        return findings

    removed = _extract_headings(old_memory) - _extract_headings(new_memory)
    if removed:
        names = ", ".join(sorted(removed))
        findings.append(
            f"Sections removed: {names}. "
            "If unintentional, the user can restore from the settings page."
        )

    before, after = len(old_memory), len(new_memory)
    if before > 0 and after < before * 0.4:
        findings.append(
            f"Memory shrank significantly ({before:,} -> {after:,} chars). "
            "Possible data loss."
        )

    return findings
# ---------------------------------------------------------------------------
# Size validation & soft warning
# ---------------------------------------------------------------------------
def _validate_memory_size(content: str) -> dict[str, Any] | None:
    """Return an error dict when *content* is longer than ``MEMORY_HARD_LIMIT``, else ``None``."""
    size = len(content)
    if size <= MEMORY_HARD_LIMIT:
        return None
    return {
        "status": "error",
        "message": (
            f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
            f"({size:,} chars). Consolidate by merging related items, "
            "removing outdated entries, and shortening descriptions. "
            "Then call update_memory again."
        ),
    }
def _soft_warning(content: str) -> str | None:
    """Return a consolidation hint when *content* is longer than ``MEMORY_SOFT_LIMIT``, else ``None``."""
    size = len(content)
    if size <= MEMORY_SOFT_LIMIT:
        return None
    return (
        f"Memory is at {size:,}/{MEMORY_HARD_LIMIT:,} characters. "
        "Consolidate by merging related items and removing less important "
        "entries on your next update."
    )
# ---------------------------------------------------------------------------
# Forced rewrite when memory exceeds the hard limit
# ---------------------------------------------------------------------------
_FORCED_REWRITE_PROMPT = """\
You are a memory curator. The following memory document exceeds the character \
limit and must be shortened.
RULES:
1. Rewrite the document to be under {target} characters.
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
or rename headings to consolidate, but keep names personal and descriptive.
3. Priority for keeping content: [instr] > [pref] > [fact].
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
6. Preserve the user's first name in entries — do not replace it with "the user".
7. Output ONLY the consolidated markdown no explanations, no wrapping.
<memory_document>
{content}
</memory_document>"""
async def _forced_rewrite(content: str, llm: Any) -> str | None:
    """Ask *llm* to shorten *content* using ``_FORCED_REWRITE_PROMPT``.

    Returns the stripped reply text. Returns ``None`` when the reply is empty
    (a warning is logged) or when anything in the call raises (the exception
    is logged).
    """
    try:
        request = HumanMessage(
            content=_FORCED_REWRITE_PROMPT.format(
                target=MEMORY_HARD_LIMIT, content=content
            )
        )
        reply = await llm.ainvoke(
            [request],
            config={"tags": ["surfsense:internal"]},
        )
        rewritten = extract_text_content(reply.content).strip()
        if rewritten:
            return rewritten
        logger.warning("Forced rewrite returned empty text; aborting rewrite")
        return None
    except Exception:
        logger.exception("Forced rewrite LLM call failed")
        return None
# ---------------------------------------------------------------------------
# Shared save-and-respond logic
# ---------------------------------------------------------------------------
async def _save_memory(
    *,
    updated_memory: str,
    old_memory: str | None,
    llm: Any | None,
    apply_fn,
    commit_fn,
    rollback_fn,
    label: str,
    scope: Literal["user", "team"],
) -> dict[str, Any]:
    """Validate, optionally force-rewrite if over the hard limit, save, and
    return a response dict.

    Parameters
    ----------
    updated_memory : str
        The new document the agent submitted.
    old_memory : str | None
        The previously persisted document (for diff checks).
    llm : Any | None
        LLM instance for forced rewrite (may be ``None``).
    apply_fn : callable(str) -> None
        Callback that sets the new memory on the ORM object.
    commit_fn : coroutine
        ``session.commit``.
    rollback_fn : coroutine
        ``session.rollback``.
    label : str
        Human label for log messages (e.g. "user memory", "team memory").
    scope : Literal["user", "team"]
        Which memory is being written; passed to ``_validate_memory_scope``.

    Returns
    -------
    dict[str, Any]
        On success ``{"status": "saved", "message": ...}`` plus, when
        applicable, ``notice``, ``diff_warnings``, ``format_warnings`` and
        ``warning``. On rejection or persistence failure, a dict with
        ``"status": "error"`` and a ``message``.
    """
    if not isinstance(updated_memory, str):
        logger.warning(
            "Refusing non-string memory payload (type=%s)",
            type(updated_memory).__name__,
        )
        return {
            "status": "error",
            "message": "Internal error: memory payload must be a string.",
        }

    content = updated_memory

    # --- forced rewrite if over the hard limit ---
    # The rewrite is adopted only when it is actually shorter than the input.
    if len(content) > MEMORY_HARD_LIMIT and llm is not None:
        rewritten = await _forced_rewrite(content, llm)
        if rewritten is not None and len(rewritten) < len(content):
            content = rewritten

    # --- hard-limit gate (reject if still too large after rewrite) ---
    size_err = _validate_memory_size(content)
    if size_err:
        return size_err

    scope_err = _validate_memory_scope(content, scope)
    if scope_err:
        return scope_err

    # --- persist ---
    try:
        apply_fn(content)
        await commit_fn()
    except Exception as e:
        logger.exception("Failed to update %s: %s", label, e)
        await rollback_fn()
        return {"status": "error", "message": f"Failed to update {label}: {e}"}

    # --- build response ---
    resp: dict[str, Any] = {
        "status": "saved",
        "message": f"{label.capitalize()} updated.",
    }

    # Identity check: ``content`` is rebound only when a rewrite was adopted.
    if content is not updated_memory:
        resp["notice"] = "Memory was automatically rewritten to fit within limits."

    # The checks below are advisory: they annotate the response after the
    # save has already been committed.
    diff_warnings = _validate_diff(old_memory, content)
    if diff_warnings:
        resp["diff_warnings"] = diff_warnings

    format_warnings = _validate_bullet_format(content)
    if format_warnings:
        resp["format_warnings"] = format_warnings

    warning = _soft_warning(content)
    if warning:
        resp["warning"] = warning

    return resp
# ---------------------------------------------------------------------------
# Tool factories
# ---------------------------------------------------------------------------
def create_update_memory_tool(
user_id: str | UUID,
db_session: AsyncSession,
llm: Any | None = None,
):
"""Factory function to create the user-memory update tool.
"""Factory for the user-memory update tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
The session's bound ``commit``/``rollback`` methods are captured at
call time, after ``async with`` has bound ``db_session`` locally.
Args:
user_id: ID of the user whose memory document is being updated.
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
llm: Optional LLM for the forced-rewrite path.
Returns:
Configured update_memory tool for the user-memory scope.
Uses a fresh short-lived session per call so compiled-agent caches never
retain a stale request-scoped session.
"""
del db_session # per-call session — see docstring
del db_session
uid = UUID(user_id) if isinstance(user_id, str) else user_id
@tool
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the user's personal memory document.
Your current memory is shown in <user_memory> in the system prompt.
When the user shares important long-term information (preferences,
facts, instructions, context), rewrite the memory document to include
the new information. Merge new facts with existing ones, update
contradictions, remove outdated entries, and keep it concise.
Args:
updated_memory: The FULL updated markdown document (not a diff).
The current memory is shown in <user_memory>. Pass the FULL updated
markdown document, not a diff.
"""
try:
async with async_session_maker() as db_session:
result = await db_session.execute(select(User).where(User.id == uid))
user = result.scalars().first()
if not user:
return {"status": "error", "message": "User not found."}
old_memory = user.memory_md
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
result = await save_memory(
scope=MemoryScope.USER,
target_id=uid,
content=updated_memory,
session=db_session,
llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="memory",
scope="user",
)
return result.to_dict()
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
return {
"status": "error",
"message": f"Failed to update memory: {e}",
}
return {"status": "error", "message": f"Failed to update memory: {e}"}
return update_memory
@ -373,64 +57,26 @@ def create_update_team_memory_tool(
db_session: AsyncSession,
llm: Any | None = None,
):
"""Factory function to create the team-memory update tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
The session's bound ``commit``/``rollback`` methods are captured at
call time, after ``async with`` has bound ``db_session`` locally.
Args:
search_space_id: ID of the search space whose team memory is being
updated.
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
llm: Optional LLM for the forced-rewrite path.
Returns:
Configured update_memory tool for the team-memory scope.
"""
del db_session # per-call session — see docstring
"""Factory for the team-memory update tool."""
del db_session
@tool
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the team's shared memory document for this search space.
Your current team memory is shown in <team_memory> in the system
prompt. When the team shares important long-term information
(decisions, conventions, key facts, priorities), rewrite the memory
document to include the new information. Merge new facts with
existing ones, update contradictions, remove outdated entries, and
keep it concise.
Args:
updated_memory: The FULL updated markdown document (not a diff).
The current team memory is shown in <team_memory>. Pass the FULL updated
markdown document, not a diff.
"""
try:
async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
space = result.scalars().first()
if not space:
return {"status": "error", "message": "Search space not found."}
old_memory = space.shared_memory_md
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
result = await save_memory(
scope=MemoryScope.TEAM,
target_id=search_space_id,
content=updated_memory,
session=db_session,
llm=llm,
apply_fn=lambda content: setattr(
space, "shared_memory_md", content
),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="team memory",
scope="team",
)
return result.to_dict()
except Exception as e:
logger.exception("Failed to update team memory: %s", e)
return {
@ -439,3 +85,9 @@ def create_update_team_memory_tool(
}
return update_memory
__all__ = [
"create_update_memory_tool",
"create_update_team_memory_tool",
]

View file

@ -110,6 +110,19 @@ def load_global_llm_configs():
except Exception as e:
print(f"Warning: Failed to score global LLM configs: {e}")
# Planner LLM is a singleton role. If an operator accidentally
# marks multiple configs ``is_planner: true``, only the first one
# is used at runtime — surface the others at startup so the
# mistake is caught before traffic, not silently buried.
planner_cfgs = [c for c in configs if c.get("is_planner") is True]
if len(planner_cfgs) > 1:
extra_ids = [c.get("id") for c in planner_cfgs[1:]]
print(
"Warning: Multiple global LLM configs marked is_planner=true "
f"(ids {[c.get('id') for c in planner_cfgs]}); using id "
f"{planner_cfgs[0].get('id')} and ignoring {extra_ids}"
)
return configs
except Exception as e:
print(f"Warning: Failed to load global LLM configs: {e}")

View file

@ -258,6 +258,45 @@ global_llm_configs:
use_default_system_instructions: true
citations_enabled: true
# Example: Planner LLM - small, fast model used for internal utility tasks
#
# The PLANNER role handles short, structured internal calls (KB query
# rewriting, date extraction, recency classification, etc.) that don't
# need frontier-tier capability. Pointing the planner at a cheap+fast
# model (gpt-4o-mini, Claude Haiku, Azure gpt-5.x-nano, Groq Llama, ...)
# typically saves 500ms-1.5s per turn vs. routing those same internal
# calls through the user's chat model.
#
# Activation:
# - Mark EXACTLY ONE global config with ``is_planner: true``.
# - If multiple are marked, the first one wins and a WARNING is logged.
# - If none is marked, every internal call falls back to the user's
# chat LLM (same behavior as before this flag existed).
#
# This config is operator-only — it is NOT exposed in the user-facing
# model selector, never billed against premium quota, and the
# billing_tier / anonymous_enabled fields below are ignored.
- id: -9
name: "Global Planner (GPT-4o mini)"
description: "Internal-only planner LLM for query rewriting and classification"
is_planner: true
billing_tier: "free"
anonymous_enabled: false
seo_enabled: false
quota_reserve_tokens: 1000
provider: "OPENAI"
model_name: "gpt-4o-mini"
api_key: "sk-your-openai-api-key-here"
api_base: ""
rpm: 3500
tpm: 200000
litellm_params:
temperature: 0
max_tokens: 1000
system_instructions: ""
use_default_system_instructions: true
citations_enabled: false
# =============================================================================
# OpenRouter Integration
# =============================================================================
@ -493,6 +532,20 @@ global_vision_llm_configs:
# - Lower temperature (0.3) is recommended for accurate screenshot analysis
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
#
# PLANNER LLM NOTES:
# - is_planner: true marks a config as the internal-only planner LLM (small,
# fast model used for KB query rewriting, date extraction, recency
# classification, etc.). Only one config may carry this flag — if
# multiple do, the first one wins and a startup WARNING is logged.
# - When no config is marked is_planner, every internal utility call falls
# back to the user's chat LLM (the historical behavior).
# - Planner configs are NOT shown in the user-facing model selector and
# are NOT billed against the user's premium quota. Their billing_tier,
# anonymous_enabled, seo_* fields are ignored.
# - Recommended models: gpt-4o-mini, claude-3-5-haiku, gemini-1.5-flash,
# azure gpt-5.x-nano, groq llama3-8b — anything <200ms p50 on a 1-2k
# prompt. Frontier models here defeat the purpose of the flag.
#
# TOKEN QUOTA & ANONYMOUS ACCESS NOTES:
# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota.
# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog.

View file

@ -54,6 +54,7 @@ from .search_spaces_routes import router as search_spaces_router
from .slack_add_connector_route import router as slack_add_connector_router
from .stripe_routes import router as stripe_router
from .surfsense_docs_routes import router as surfsense_docs_router
from .team_memory_routes import router as team_memory_router
from .teams_add_connector_route import router as teams_add_connector_router
from .video_presentations_routes import router as video_presentations_router
from .vision_llm_routes import router as vision_llm_router
@ -117,3 +118,4 @@ router.include_router(stripe_router) # Stripe checkout for additional page pack
router.include_router(youtube_router) # YouTube playlist resolution
router.include_router(prompts_router)
router.include_router(memory_router) # User personal memory (memory.md style)
router.include_router(team_memory_router) # Search-space team memory

View file

@ -428,7 +428,7 @@ async def mcp_oauth_callback(
await session.commit()
await session.refresh(db_connector)
_invalidate_cache(space_id)
_refresh_mcp_cache(db_connector.id, space_id)
logger.info(
"Re-authenticated %s MCP connector %s for user %s",
@ -481,7 +481,7 @@ async def mcp_oauth_callback(
detail="A connector for this service already exists.",
) from e
_invalidate_cache(space_id)
_refresh_mcp_cache(new_connector.id, space_id)
logger.info(
"Created %s MCP connector %s for user %s in space %s",
@ -658,10 +658,17 @@ async def reauth_mcp_service(
# ---------------------------------------------------------------------------
def _invalidate_cache(space_id: int) -> None:
try:
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
def _refresh_mcp_cache(connector_id: int, space_id: int) -> None:
"""Evict the in-process MCP tool LRU and schedule background prefetch.
invalidate_mcp_tools_cache(space_id)
Wraps :func:`refresh_mcp_tools_cache_for_connector` so any failure is
isolated from the OAuth response flow.
"""
try:
from app.agents.new_chat.tools.mcp_tools_cache import (
refresh_mcp_tools_cache_for_connector,
)
refresh_mcp_tools_cache_for_connector(connector_id, space_id)
except Exception:
logger.debug("MCP cache invalidation skipped", exc_info=True)
logger.debug("MCP cache refresh skipped", exc_info=True)

View file

@ -1,75 +1,40 @@
"""Routes for user memory management (personal memory.md)."""
"""Routes for user memory management."""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import HumanMessage
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.llm_config import (
create_chat_litellm_from_agent_config,
load_agent_llm_config_for_search_space,
)
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
from app.db import User, get_async_session
from app.services.memory import (
MemoryRead,
MemoryScope,
memory_limits,
read_memory,
reset_memory,
save_memory,
)
from app.users import current_active_user
from app.utils.content_utils import extract_text_content
logger = logging.getLogger(__name__)
router = APIRouter()
class MemoryRead(BaseModel):
memory_md: str
class MemoryUpdate(BaseModel):
memory_md: str
class MemoryEditRequest(BaseModel):
query: str
search_space_id: int
_MEMORY_EDIT_PROMPT = """\
You are a memory editor. The user wants to modify their memory document. \
Apply the user's instruction to the existing memory document and output the \
FULL updated document.
RULES:
1. If the instruction asks to add something, add it with format: \
- (YYYY-MM-DD) [fact|pref|instr] text, under an existing or new ## heading. \
Heading names should be personal and descriptive, not generic categories.
2. If the instruction asks to remove something, remove the matching entry.
3. If the instruction asks to change something, update the matching entry.
4. Preserve existing ## headings and all other entries.
5. Every bullet must include a marker: [fact], [pref], or [instr].
6. Use the user's first name (from <user_name>) in entries instead of "the user".
7. Output ONLY the updated markdown no explanations, no wrapping.
<user_name>{user_name}</user_name>
<current_memory>
{current_memory}
</current_memory>
<user_instruction>
{instruction}
</user_instruction>"""
@router.get("/users/me/memory", response_model=MemoryRead)
async def get_user_memory(
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
await session.refresh(user, ["memory_md"])
return MemoryRead(memory_md=user.memory_md or "")
memory_md = await read_memory(
scope=MemoryScope.USER,
target_id=user.id,
session=session,
)
return MemoryRead(memory_md=memory_md, limits=memory_limits())
@router.put("/users/me/memory", response_model=MemoryRead)
@ -78,73 +43,27 @@ async def update_user_memory(
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
if len(body.memory_md) > MEMORY_HARD_LIMIT:
raise HTTPException(
status_code=400,
detail=f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit ({len(body.memory_md):,} chars).",
)
user.memory_md = body.memory_md
session.add(user)
await session.commit()
await session.refresh(user, ["memory_md"])
return MemoryRead(memory_md=user.memory_md or "")
result = await save_memory(
scope=MemoryScope.USER,
target_id=user.id,
content=body.memory_md,
session=session,
)
if result.status == "error":
raise HTTPException(status_code=400, detail=result.message)
return MemoryRead(memory_md=result.memory_md, limits=memory_limits())
@router.post("/users/me/memory/edit", response_model=MemoryRead)
async def edit_user_memory(
body: MemoryEditRequest,
@router.post("/users/me/memory/reset", response_model=MemoryRead)
async def reset_user_memory(
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
"""Apply a natural language edit to the user's personal memory via LLM."""
agent_config = await load_agent_llm_config_for_search_space(
session, body.search_space_id
result = await reset_memory(
scope=MemoryScope.USER,
target_id=user.id,
session=session,
)
if not agent_config:
raise HTTPException(status_code=500, detail="No LLM configuration available.")
llm = create_chat_litellm_from_agent_config(agent_config)
if not llm:
raise HTTPException(status_code=500, detail="Failed to create LLM instance.")
await session.refresh(user, ["memory_md", "display_name"])
current_memory = user.memory_md or ""
first_name = (
user.display_name.strip().split()[0]
if user.display_name and user.display_name.strip()
else "The user"
)
prompt = _MEMORY_EDIT_PROMPT.format(
current_memory=current_memory or "(empty)",
instruction=body.query,
user_name=first_name,
)
try:
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "memory-edit"]},
)
updated = extract_text_content(response.content).strip()
except Exception as e:
logger.exception("Memory edit LLM call failed: %s", e)
raise HTTPException(status_code=500, detail="Memory edit failed.") from e
if not updated:
raise HTTPException(status_code=400, detail="LLM returned empty result.")
result = await _save_memory(
updated_memory=updated,
old_memory=current_memory,
llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=session.commit,
rollback_fn=session.rollback,
label="memory",
scope="user",
)
if result.get("status") == "error":
raise HTTPException(status_code=400, detail=result["message"])
await session.refresh(user, ["memory_md"])
return MemoryRead(memory_md=user.memory_md or "")
if result.status == "error":
raise HTTPException(status_code=400, detail=result.message)
return MemoryRead(memory_md=result.memory_md, limits=memory_limits())

View file

@ -2650,9 +2650,11 @@ async def create_mcp_connector(
f"for user {user.id} in search space {search_space_id}"
)
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
from app.agents.new_chat.tools.mcp_tools_cache import (
refresh_mcp_tools_cache_for_connector,
)
invalidate_mcp_tools_cache(search_space_id)
refresh_mcp_tools_cache_for_connector(db_connector.id, search_space_id)
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
return MCPConnectorRead.from_connector(connector_read)
@ -2828,9 +2830,11 @@ async def update_mcp_connector(
logger.info(f"Updated MCP connector {connector_id}")
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
from app.agents.new_chat.tools.mcp_tools_cache import (
refresh_mcp_tools_cache_for_connector,
)
invalidate_mcp_tools_cache(connector.search_space_id)
refresh_mcp_tools_cache_for_connector(connector.id, connector.search_space_id)
connector_read = SearchSourceConnectorRead.model_validate(connector)
return MCPConnectorRead.from_connector(connector_read)

View file

@ -1,17 +1,10 @@
import logging
from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import HumanMessage
from pydantic import BaseModel as PydanticBaseModel
from sqlalchemy import func, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.llm_config import (
create_chat_litellm_from_agent_config,
load_agent_llm_config_for_search_space,
)
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
from app.config import config
from app.db import (
ImageGenerationConfig,
@ -35,7 +28,6 @@ from app.schemas import (
SearchSpaceWithStats,
)
from app.users import current_active_user
from app.utils.content_utils import extract_text_content
from app.utils.rbac import check_permission, check_search_space_access
logger = logging.getLogger(__name__)
@ -43,34 +35,6 @@ logger = logging.getLogger(__name__)
router = APIRouter()
class _TeamMemoryEditRequest(PydanticBaseModel):
query: str
_TEAM_MEMORY_EDIT_PROMPT = """\
You are a memory editor for a team workspace. The user wants to modify the \
team's shared memory document. Apply the user's instruction to the existing \
memory document and output the FULL updated document.
RULES:
1. If the instruction asks to add something, add it with format: \
- (YYYY-MM-DD) [fact] text, under an existing or new ## heading. \
Heading names should be descriptive, not generic categories.
2. If the instruction asks to remove something, remove the matching entry.
3. If the instruction asks to change something, update the matching entry.
4. Preserve existing ## headings and all other entries.
5. NEVER use [pref] or [instr] markers. Team memory uses [fact] only.
6. Output ONLY the updated markdown no explanations, no wrapping.
<current_memory>
{current_memory}
</current_memory>
<user_instruction>
{instruction}
</user_instruction>"""
async def create_default_roles_and_membership(
session: AsyncSession,
search_space_id: int,
@ -294,15 +258,6 @@ async def update_search_space(
update_data = search_space_update.model_dump(exclude_unset=True)
if (
"shared_memory_md" in update_data
and len(update_data["shared_memory_md"] or "") > MEMORY_HARD_LIMIT
):
raise HTTPException(
status_code=400,
detail=f"Team memory exceeds {MEMORY_HARD_LIMIT:,} character limit.",
)
for key, value in update_data.items():
setattr(db_search_space, key, value)
await session.commit()
@ -317,72 +272,6 @@ async def update_search_space(
) from e
@router.post(
"/searchspaces/{search_space_id}/memory/edit",
response_model=SearchSpaceRead,
)
async def edit_team_memory(
search_space_id: int,
body: _TeamMemoryEditRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Apply a natural language edit to the team memory via LLM."""
await check_search_space_access(session, user, search_space_id)
agent_config = await load_agent_llm_config_for_search_space(
session, search_space_id
)
if not agent_config:
raise HTTPException(status_code=500, detail="No LLM configuration available.")
llm = create_chat_litellm_from_agent_config(agent_config)
if not llm:
raise HTTPException(status_code=500, detail="Failed to create LLM instance.")
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
db_search_space = result.scalars().first()
if not db_search_space:
raise HTTPException(status_code=404, detail="Search space not found")
current_memory = db_search_space.shared_memory_md or ""
prompt = _TEAM_MEMORY_EDIT_PROMPT.format(
current_memory=current_memory or "(empty)",
instruction=body.query,
)
try:
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "memory-edit"]},
)
updated = extract_text_content(response.content).strip()
except Exception as e:
logger.exception("Team memory edit LLM call failed: %s", e)
raise HTTPException(status_code=500, detail="Team memory edit failed.") from e
if not updated:
raise HTTPException(status_code=400, detail="LLM returned empty result.")
save_result = await _save_memory(
updated_memory=updated,
old_memory=current_memory,
llm=llm,
apply_fn=lambda content: setattr(db_search_space, "shared_memory_md", content),
commit_fn=session.commit,
rollback_fn=session.rollback,
label="team memory",
scope="team",
)
if save_result.get("status") == "error":
raise HTTPException(status_code=400, detail=save_result["message"])
await session.refresh(db_search_space)
return db_search_space
@router.post("/searchspaces/{search_space_id}/ai-sort")
async def trigger_ai_sort(
search_space_id: int,

View file

@ -0,0 +1,76 @@
"""Routes for search-space team memory."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session
from app.services.memory import (
MemoryRead,
MemoryScope,
memory_limits,
read_memory,
reset_memory,
save_memory,
)
from app.users import current_active_user
from app.utils.rbac import check_search_space_access
router = APIRouter()
class TeamMemoryUpdate(BaseModel):
    """Request body for replacing a search space's team memory document."""

    # Full markdown text that should become the new team memory.
    memory_md: str
@router.get(
    "/searchspaces/{search_space_id}/memory",
    response_model=MemoryRead,
)
async def get_team_memory(
    search_space_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Return the team memory markdown of a search space plus the size limits.

    The access check runs first; the memory is only read once it has passed.
    """
    await check_search_space_access(session, user, search_space_id)
    return MemoryRead(
        memory_md=await read_memory(
            scope=MemoryScope.TEAM,
            target_id=search_space_id,
            session=session,
        ),
        limits=memory_limits(),
    )
@router.put(
    "/searchspaces/{search_space_id}/memory",
    response_model=MemoryRead,
)
async def update_team_memory(
    search_space_id: int,
    body: TeamMemoryUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Replace the team memory of a search space with ``body.memory_md``.

    Raises:
        HTTPException: 400 carrying the service's message when the save
            result has status ``"error"``.
    """
    await check_search_space_access(session, user, search_space_id)
    outcome = await save_memory(
        scope=MemoryScope.TEAM,
        target_id=search_space_id,
        content=body.memory_md,
        session=session,
    )
    if outcome.status != "error":
        return MemoryRead(memory_md=outcome.memory_md, limits=memory_limits())
    raise HTTPException(status_code=400, detail=outcome.message)
@router.post(
    "/searchspaces/{search_space_id}/memory/reset",
    response_model=MemoryRead,
)
async def reset_team_memory(
    search_space_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Reset the team memory of a search space via the memory service.

    Raises:
        HTTPException: 400 carrying the service's message when the reset
            result has status ``"error"``.
    """
    await check_search_space_access(session, user, search_space_id)
    outcome = await reset_memory(
        scope=MemoryScope.TEAM,
        target_id=search_space_id,
        session=session,
    )
    if outcome.status != "error":
        return MemoryRead(memory_md=outcome.memory_md, limits=memory_limits())
    raise HTTPException(status_code=400, detail=outcome.message)

View file

@ -21,7 +21,6 @@ class SearchSpaceUpdate(BaseModel):
description: str | None = None
citations_enabled: bool | None = None
qna_custom_instructions: str | None = None
shared_memory_md: str | None = None
ai_file_sort_enabled: bool | None = None

View file

@ -1,3 +1,4 @@
import asyncio
import logging
from datetime import datetime
@ -100,7 +101,9 @@ class GmailKBSyncService:
else:
logger.warning("No LLM configured -- using fallback summary")
summary_content = f"Gmail Message: {subject}\n\n{indexable_content}"
summary_embedding = embed_text(summary_content)
summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(indexable_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View file

@ -116,7 +116,9 @@ class GoogleCalendarKBSyncService:
summary_content = (
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
)
summary_embedding = embed_text(summary_content)
summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(indexable_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -295,7 +297,9 @@ class GoogleCalendarKBSyncService:
summary_content = (
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
)
summary_embedding = embed_text(summary_content)
summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(indexable_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View file

@ -98,7 +98,9 @@ class JiraKBSyncService:
summary_content = (
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
)
summary_embedding = embed_text(summary_content)
summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(issue_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -212,7 +214,9 @@ class JiraKBSyncService:
summary_content = (
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
)
summary_embedding = embed_text(summary_content)
summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(issue_content)

View file

@ -659,3 +659,36 @@ async def get_user_long_context_llm(
return await get_document_summary_llm(
session, search_space_id, disable_streaming=disable_streaming
)
def get_planner_llm() -> ChatLiteLLM | None:
    """Build the planner LLM from the first global config flagged ``is_planner``.

    The planner role covers short, structured internal tasks (KB search
    planning: query rewriting, date extraction, recency classification).
    Small, fast models (e.g. gpt-4o-mini, Claude Haiku, Azure gpt-5.x-nano)
    serve these well; using the user's chat LLM for them is unnecessarily
    expensive and slow.

    The lookup only scans ``config.GLOBAL_LLM_CONFIGS``, so it issues no
    database query and is safe to call synchronously from middleware or
    factory code.

    Returns:
        The instance produced by ``create_chat_litellm_from_config`` for the
        first config whose ``is_planner`` value is exactly ``True``, or
        ``None`` when no config carries the flag. Callers MUST fall back to
        their chat LLM on ``None`` so deployments without a planner config
        keep working unchanged.
    """
    from app.agents.new_chat.llm_config import create_chat_litellm_from_config

    for candidate in config.GLOBAL_LLM_CONFIGS:
        # Identity check: truthy non-bool values such as "true" or 1 do not count.
        if candidate.get("is_planner") is True:
            return create_chat_litellm_from_config(candidate)
    return None

View file

@ -0,0 +1,32 @@
"""First-class memory service for user and team markdown memory."""
from .schemas import MemoryLimits, MemoryRead
from .service import (
MemoryScope,
SaveResult,
memory_limits,
read_memory,
reset_memory,
save_memory,
)
from .validation import (
MEMORY_HARD_LIMIT,
MEMORY_SOFT_LIMIT,
validate_bullet_format,
validate_memory_scope,
)
__all__ = [
"MEMORY_HARD_LIMIT",
"MEMORY_SOFT_LIMIT",
"MemoryLimits",
"MemoryRead",
"MemoryScope",
"SaveResult",
"memory_limits",
"read_memory",
"reset_memory",
"save_memory",
"validate_bullet_format",
"validate_memory_scope",
]

View file

@ -0,0 +1,200 @@
"""Memory-specific markdown document model and canonical renderer.
This intentionally parses only SurfSense memory's small markdown contract:
``##`` sections with dated bullet items. Unknown lines are preserved so user
edits are not lost, while legacy marker bullets are normalized on render.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import date
DEFAULT_LEGACY_SECTION = "Memory"
LEGACY_MARKERS = frozenset({"fact", "pref", "instr"})
@dataclass(frozen=True)
class MemoryBullet:
    """One dated memory entry parsed from a ``- ...`` bullet line."""

    # Calendar date attached to the entry.
    entry_date: date
    # Entry text with the date (and any legacy marker) removed.
    text: str
@dataclass(frozen=True)
class MemoryRawLine:
    """A section line that is not a recognised bullet, kept verbatim."""

    # The line exactly as stored by the parser.
    text: str
MemoryLine = MemoryBullet | MemoryRawLine
@dataclass(frozen=True)
class MemorySection:
    """A ``##`` section of a memory document together with its lines."""

    # Heading text without the leading "## ".
    heading: str
    # Bullets and raw lines in document order.
    lines: list[MemoryLine] = field(default_factory=list)
    # False when the parser invented the heading for heading-less bullets.
    explicit_heading: bool = True
@dataclass(frozen=True)
class MemoryDocument:
    """Parsed memory document: an ordered list of sections."""

    # Sections in the order they appear in the markdown.
    sections: list[MemorySection] = field(default_factory=list)

    @property
    def has_explicit_heading(self) -> bool:
        """True when at least one section has ``explicit_heading`` set."""
        for section in self.sections:
            if section.explicit_heading:
                return True
        return False
def is_section_heading(line: str) -> bool:
    """Tell whether ``line`` is a ``## `` heading with non-blank heading text."""
    marker, title = line[:3], line[3:]
    return marker == "## " and title.strip() != ""
def heading_text(line: str) -> str:
    """Return the heading text: everything after the first three characters, trimmed."""
    remainder = line[3:]
    return remainder.strip()
def normalize_heading(heading: str) -> str:
    """Reduce a heading to lower-case alphanumeric words joined by single spaces.

    Every non-alphanumeric character acts as a separator; runs of separators
    collapse to one space, and leading/trailing separators are dropped.
    """
    lowered = heading.strip().lower()
    spaced = "".join(char if char.isalnum() else " " for char in lowered)
    # split() with no argument collapses runs of spaces and trims both ends.
    return " ".join(spaced.split())
def parse_bullet_line(line: str) -> MemoryBullet | None:
    """Parse a ``- `` bullet line into a ``MemoryBullet``.

    The canonical ``YYYY-MM-DD: text`` form is tried first, then the legacy
    ``(YYYY-MM-DD) [marker] text`` form. Returns ``None`` when the line is not
    a bullet or matches neither form.
    """
    candidate = line.strip()
    if candidate[:2] != "- ":
        return None
    body = candidate[2:]
    canonical = _parse_canonical_bullet(body)
    return canonical if canonical is not None else _parse_legacy_bullet(body)
def _parse_canonical_bullet(body: str) -> MemoryBullet | None:
if len(body) < 13 or body[10:12] != ": ":
return None
try:
entry_date = date.fromisoformat(body[:10])
except ValueError:
return None
text = body[12:].strip()
if not text:
return None
return MemoryBullet(entry_date=entry_date, text=text)
def _parse_legacy_bullet(body: str) -> MemoryBullet | None:
if len(body) < 20 or not body.startswith("("):
return None
if len(body) < 14 or body[11:14] != ") [":
return None
try:
entry_date = date.fromisoformat(body[1:11])
except ValueError:
return None
marker_end = body.find("] ", 14)
if marker_end == -1:
return None
marker = body[14:marker_end]
if marker not in LEGACY_MARKERS:
return None
text = body[marker_end + 2 :].strip()
if not text:
return None
return MemoryBullet(entry_date=entry_date, text=text)
def parse_memory_document(content: str | None) -> MemoryDocument:
    """Parse memory markdown into sections of bullets and raw lines.

    * A ``## `` heading opens a new explicit section.
    * Before any heading, non-bullet lines are discarded; the first parseable
      bullet opens an implicit section named ``DEFAULT_LEGACY_SECTION``.
    * Inside a section, parseable bullets become ``MemoryBullet`` and every
      other line, blank ones included, is kept as ``MemoryRawLine``.

    Empty or ``None`` input yields a document with no sections.
    """
    if not content:
        return MemoryDocument()

    # One (heading, explicit_heading, lines) entry per section, in order.
    collected: list[tuple[str, bool, list[MemoryLine]]] = []
    # Line list of the currently open section; None until one is opened.
    active: list[MemoryLine] | None = None

    for raw_line in content.strip().splitlines():
        line = raw_line.rstrip()

        if is_section_heading(line):
            active = []
            collected.append((heading_text(line), True, active))
            continue

        bullet = parse_bullet_line(line)
        if active is None:
            if bullet is not None:
                active = [bullet]
                collected.append((DEFAULT_LEGACY_SECTION, False, active))
            continue

        active.append(MemoryRawLine(text=line) if bullet is None else bullet)

    return MemoryDocument(
        sections=[
            MemorySection(heading=heading, lines=lines, explicit_heading=explicit)
            for heading, explicit, lines in collected
        ]
    )
def render_memory_document(document: MemoryDocument) -> str:
    """Serialise a parsed document back to canonical markdown.

    Bullets are written as ``- YYYY-MM-DD: text`` and raw lines are written
    unchanged. Each section is trimmed, and sections are separated by one
    blank line.
    """

    def _render_line(entry: MemoryLine) -> str:
        if isinstance(entry, MemoryBullet):
            return f"- {entry.entry_date.isoformat()}: {entry.text}"
        return entry.text

    blocks: list[str] = []
    for section in document.sections:
        rows = [f"## {section.heading}"]
        rows.extend(_render_line(entry) for entry in section.lines)
        blocks.append("\n".join(rows).strip())
    return "\n\n".join(filter(None, blocks)).strip()
def extract_headings(memory: str | None) -> set[str]:
    """Collect the normalised headings of all explicit sections in ``memory``."""
    found: set[str] = set()
    for section in parse_memory_document(memory).sections:
        # Implicit sections invented for heading-less bullets are not counted.
        if section.explicit_heading:
            found.add(normalize_heading(section.heading))
    return found
def has_explicit_heading(content: str) -> bool:
    """Tell whether the parsed ``content`` has at least one explicit section heading."""
    document = parse_memory_document(content)
    return document.has_explicit_heading
def nonstandard_bullets(content: str) -> list[str]:
    """List a warning for each ``- `` bullet that matches no known bullet format.

    Each warning quotes the bullet, cut to 80 characters with ``...`` appended
    when it is longer.
    """
    flagged: list[str] = []
    for candidate in (raw.strip() for raw in content.splitlines()):
        if candidate.startswith("- ") and parse_bullet_line(candidate) is None:
            preview = candidate[:80] + ("..." if len(candidate) > 80 else "")
            flagged.append(f"Non-standard memory bullet: {preview}")
    return flagged

View file

@ -0,0 +1,20 @@
"""Prompts used by the memory service."""
FORCED_REWRITE_PROMPT = """\
You are a memory curator. The following memory document exceeds the character \
limit and must be shortened.
RULES:
1. Rewrite the document to be under {target} characters.
2. Output Markdown only. Use clear `##` headings and concise bullet points.
3. New-format bullets should look like: `- YYYY-MM-DD: memory text`.
4. If the input contains legacy markers like `(YYYY-MM-DD) [fact]`, preserve the
information but remove the inline marker in the output.
5. Preserve durable instructions and preferences before generic facts when
compressing personal memory.
6. Preserve existing headings when useful; merge duplicate headings and bullets.
7. Output ONLY the consolidated markdown no explanations, no wrapping.
<memory_document>
{content}
</memory_document>"""

View file

@ -0,0 +1,35 @@
"""LLM-backed memory rewrite helpers."""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.messages import HumanMessage
from app.services.memory.prompts import FORCED_REWRITE_PROMPT
from app.services.memory.validation import MEMORY_HARD_LIMIT
from app.utils.content_utils import extract_text_content
logger = logging.getLogger(__name__)
async def forced_rewrite(content: str, llm: Any) -> str | None:
    """Ask ``llm`` to compress ``content`` below ``MEMORY_HARD_LIMIT`` characters.

    Returns the stripped model output, or ``None`` when the output is empty or
    any exception is raised. Failures are logged and never propagate to the
    caller.
    """
    try:
        request = HumanMessage(
            content=FORCED_REWRITE_PROMPT.format(
                target=MEMORY_HARD_LIMIT,
                content=content,
            )
        )
        reply = await llm.ainvoke(
            [request],
            config={"tags": ["surfsense:internal", "memory-rewrite"]},
        )
        compressed = extract_text_content(reply.content).strip()
        if compressed:
            return compressed
        logger.warning("Forced memory rewrite returned empty text")
    except Exception:
        logger.exception("Forced memory rewrite LLM call failed")
    return None

View file

@ -0,0 +1,19 @@
"""Schemas for memory API responses and structured extraction."""
from __future__ import annotations
from pydantic import BaseModel
class MemoryLimits(BaseModel):
    """Memory size limits reported to API clients."""

    # Soft limit value.
    soft: int
    # Hard limit value.
    hard: int
class MemoryRead(BaseModel):
    """Response payload of the user and team memory endpoints."""

    # Memory document as markdown.
    memory_md: str
    # Size limits that apply to the document.
    limits: MemoryLimits

View file

@ -0,0 +1,247 @@
"""Canonical read/write/reset/extract service for markdown memory."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any, Literal
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User
from app.services.memory.document import parse_memory_document, render_memory_document
from app.services.memory.rewrite import forced_rewrite
from app.services.memory.schemas import MemoryLimits
from app.services.memory.validation import (
MEMORY_HARD_LIMIT,
MEMORY_SOFT_LIMIT,
soft_limit_warning,
strip_preamble_to_first_heading,
validate_bullet_format,
validate_diff,
validate_heading_sanity,
validate_memory_scope,
validate_memory_size,
)
logger = logging.getLogger(__name__)
_NO_UPDATE_SENTINELS = frozenset(
{
"NO_UPDATE",
"NO UPDATE",
"NO_CHANGE",
"NO CHANGE",
}
)
class MemoryScope(StrEnum):
USER = "user"
TEAM = "team"
@dataclass(frozen=True)
class SaveResult:
    """Outcome of a memory write, together with any warnings it produced."""

    status: Literal["saved", "error", "no_op"]
    message: str
    # Memory text associated with the result; defaults to "".
    memory_md: str = ""
    warnings: list[str] = field(default_factory=list)
    diff_warnings: list[str] = field(default_factory=list)
    format_warnings: list[str] = field(default_factory=list)
    notice: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to a plain dict, leaving out empty optional fields.

        ``status``, ``message`` and ``memory_md`` are always present. A lone
        warning is additionally exposed under the singular ``warning`` key.
        """
        payload: dict[str, Any] = {
            "status": self.status,
            "message": self.message,
            "memory_md": self.memory_md,
        }
        if self.notice:
            payload["notice"] = self.notice
        if self.warnings:
            payload["warnings"] = self.warnings
            if len(self.warnings) == 1:
                payload["warning"] = self.warnings[0]
        for key in ("diff_warnings", "format_warnings"):
            values = getattr(self, key)
            if values:
                payload[key] = values
        return payload
def memory_limits() -> MemoryLimits:
    """Package the soft and hard memory limits into a ``MemoryLimits`` value."""
    limits = MemoryLimits(soft=MEMORY_SOFT_LIMIT, hard=MEMORY_HARD_LIMIT)
    return limits
def _normalize_scope(scope: MemoryScope | str) -> MemoryScope:
    """Coerce a scope given as enum member or raw string into ``MemoryScope``."""
    if isinstance(scope, MemoryScope):
        return scope
    # MemoryScope(...) raises ValueError for an unknown scope string.
    return MemoryScope(scope)
def _normalize_user_id(target_id: str | UUID) -> UUID:
return UUID(target_id) if isinstance(target_id, str) else target_id
async def _load_target(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    session: AsyncSession,
) -> User | SearchSpace | None:
    """Load the row that owns the memory for ``scope``.

    USER scope looks up a ``User`` by UUID; any other scope looks up a
    ``SearchSpace`` by integer id. Returns ``None`` when no row matches.
    """
    if _normalize_scope(scope) is MemoryScope.USER:
        statement = select(User).where(
            User.id == _normalize_user_id(target_id)  # type: ignore[arg-type]
        )
    else:
        statement = select(SearchSpace).where(SearchSpace.id == int(target_id))
    rows = await session.execute(statement)
    return rows.scalars().first()
def _get_memory(target: User | SearchSpace, scope: MemoryScope) -> str:
    """Read the memory markdown for ``scope`` from ``target``; missing or empty gives ""."""
    attribute = "memory_md" if scope is MemoryScope.USER else "shared_memory_md"
    return getattr(target, attribute, None) or ""
def _set_memory(target: User | SearchSpace, scope: MemoryScope, content: str) -> None:
    """Assign ``content`` to the memory attribute on ``target`` that matches ``scope``."""
    attribute = "memory_md" if scope is MemoryScope.USER else "shared_memory_md"
    setattr(target, attribute, content)
async def read_memory(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    session: AsyncSession,
) -> str:
    """Return the stored memory markdown for a user or search space.

    Returns "" when the target row does not exist or holds no memory.
    """
    resolved = _normalize_scope(scope)
    owner = await _load_target(scope=resolved, target_id=target_id, session=session)
    return "" if owner is None else _get_memory(owner, resolved)
async def save_memory(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    content: str,
    session: AsyncSession,
    llm: Any | None = None,
) -> SaveResult:
    """Validate, normalise and persist a memory document.

    Steps, in order:

    1. Load the target row (user or search space).
    2. Drop any text before the first ``##`` heading.
    3. Return ``no_op`` if the content is a "no update" sentinel.
    4. If the content exceeds the hard limit and ``llm`` is given, try an
       LLM rewrite and keep it only when it is shorter.
    5. Run the size, heading and scope validations.
    6. Re-render the document in canonical form and commit it.

    Args:
        scope: ``MemoryScope`` member or its string value.
        target_id: User UUID (or its string form) for USER scope, search
            space id for TEAM scope.
        content: Proposed full memory markdown.
        session: Database session used to load and commit the target.
        llm: Optional model used only for the over-limit rewrite.

    Returns:
        A ``SaveResult``. On ``error`` and ``no_op`` the ``memory_md`` field
        holds the previously stored memory, except for the early payload-type
        and target-not-found errors, which leave it empty. On ``saved`` it
        holds the canonical text that was written.
    """
    normalized = _normalize_scope(scope)
    if not isinstance(content, str):
        return SaveResult(
            status="error",
            message="Internal error: memory payload must be a string.",
        )

    target = await _load_target(scope=normalized, target_id=target_id, session=session)
    if target is None:
        return SaveResult(
            status="error",
            message="User not found."
            if normalized is MemoryScope.USER
            else "Search space not found.",
        )

    old_memory = _get_memory(target, normalized)
    next_content = strip_preamble_to_first_heading(content.strip())
    notice: str | None = None
    warnings: list[str] = []

    if next_content.upper() in _NO_UPDATE_SENTINELS:
        return SaveResult(
            status="no_op",
            message="No memory update requested.",
            memory_md=old_memory,
        )

    if len(next_content) > MEMORY_HARD_LIMIT and llm is not None:
        rewritten = await forced_rewrite(next_content, llm)
        # Only accept the rewrite when it actually shrank the document.
        if rewritten is not None and len(rewritten) < len(next_content):
            next_content = strip_preamble_to_first_heading(rewritten)
            notice = "Memory was automatically rewritten to fit within limits."

    for validation in (
        validate_memory_size(next_content),
        validate_heading_sanity(next_content),
    ):
        if validation:
            return SaveResult(
                status="error",
                message=validation["message"],
                memory_md=old_memory,
            )

    scope_error, scope_warnings = validate_memory_scope(
        next_content,
        normalized.value,
        old_memory=old_memory,
    )
    warnings.extend(scope_warnings)
    if scope_error:
        return SaveResult(
            status="error",
            message=scope_error["message"],
            memory_md=old_memory,
            warnings=warnings,
        )

    rendered = render_memory_document(parse_memory_document(next_content))
    if next_content and not rendered:
        # Short heading-less text passes validate_heading_sanity, but the
        # parser keeps nothing that precedes a heading or a valid bullet.
        # Committing the empty render would silently wipe the stored memory
        # while reporting success, so reject the write instead.
        return SaveResult(
            status="error",
            message="Memory must be markdown with at least one ## heading.",
            memory_md=old_memory,
            warnings=warnings,
        )
    next_content = rendered

    try:
        _set_memory(target, normalized, next_content)
        session.add(target)
        await session.commit()
    except Exception as e:
        logger.exception("Failed to update %s memory: %s", normalized.value, e)
        await session.rollback()
        return SaveResult(
            status="error",
            message=f"Failed to update {normalized.value} memory: {e}",
            memory_md=old_memory,
        )

    # These checks are advisory: they run after the commit and never block it.
    diff_warnings = validate_diff(old_memory, next_content)
    format_warnings = validate_bullet_format(next_content)
    warning = soft_limit_warning(next_content)
    if warning:
        warnings.append(warning)

    return SaveResult(
        status="saved",
        message=(
            "Memory updated."
            if normalized is MemoryScope.USER
            else "Team memory updated."
        ),
        memory_md=next_content,
        warnings=warnings,
        diff_warnings=diff_warnings,
        format_warnings=format_warnings,
        notice=notice,
    )
async def reset_memory(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    session: AsyncSession,
) -> SaveResult:
    """Reset the memory for ``target_id`` by saving empty content.

    This is a thin delegate: it calls :func:`save_memory` with ``content=""``
    and ``llm=None``, so the reset goes through the same save path as any
    other update.

    Args:
        scope: Memory scope (enum member or its string form).
        target_id: Identifier of the owner of the memory being reset.
        session: Async database session handed through to ``save_memory``.

    Returns:
        The ``SaveResult`` produced by ``save_memory``.
    """
    outcome = await save_memory(
        content="",
        llm=None,
        scope=scope,
        target_id=target_id,
        session=session,
    )
    return outcome

View file

@ -0,0 +1,140 @@
"""Validation helpers for markdown-backed memory."""
from __future__ import annotations
from typing import Literal
from app.services.memory.document import (
extract_headings,
has_explicit_heading,
nonstandard_bullets,
parse_memory_document,
)
# Length (``len(content)``) above which ``soft_limit_warning`` emits a warning.
MEMORY_SOFT_LIMIT = 18_000
# Length (``len(content)``) above which ``validate_memory_size`` returns an error.
MEMORY_HARD_LIMIT = 25_000
# Headings that ``validate_memory_scope`` will not let team memory newly
# introduce. Entries are lowercase; presumably ``extract_headings`` yields
# heading names in the same normalized form — TODO confirm.
_FORBIDDEN_TEAM_HEADINGS = {
    "preferences",
    "instructions",
    "personal notes",
    "personal instructions",
}
def has_markdown_heading(content: str) -> bool:
    """Report whether ``content`` has a heading, as judged by ``has_explicit_heading``."""
    found = has_explicit_heading(content)
    return found
def strip_preamble_to_first_heading(content: str) -> str:
    """Discard any text that precedes the first non-empty ``## `` heading.

    A line counts as a heading only when it starts with ``"## "`` and has
    non-blank text after that prefix. When no such line exists the input is
    returned unchanged apart from surrounding whitespace being stripped.

    Returns:
        The text from the first heading onward (lines re-joined with ``\\n``
        and stripped), or ``content.strip()`` if there is no heading.
    """
    rows = content.splitlines()
    start = next(
        (
            position
            for position, row in enumerate(rows)
            if row.startswith("## ") and row[3:].strip()
        ),
        None,
    )
    if start is None:
        return content.strip()
    return "\n".join(rows[start:]).strip()
def validate_memory_size(content: str) -> dict[str, str] | None:
    """Check ``content`` against the hard length limit.

    Returns:
        ``None`` when ``len(content)`` does not exceed ``MEMORY_HARD_LIMIT``;
        otherwise an error payload with ``status`` and ``message`` keys.
    """
    length = len(content)
    if length <= MEMORY_HARD_LIMIT:
        return None
    message = (
        f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
        f"({length:,} chars). Consolidate by merging related items, "
        "removing outdated entries, and shortening descriptions."
    )
    return {"status": "error", "message": message}
def validate_heading_sanity(content: str) -> dict[str, str] | None:
    """Reject long heading-less text that the document parser cannot section.

    The content is accepted (``None`` is returned) when, after stripping, it
    is empty, contains a markdown heading, is at most 40 characters long, or
    ``parse_memory_document`` still finds sections in it. Anything else gets
    an error payload.
    """
    text = content.strip()
    # Checks run left to right and stop at the first hit, so the parser is
    # only invoked for long, heading-less text.
    acceptable = (
        not text
        or has_markdown_heading(text)
        or len(text) <= 40
        or bool(parse_memory_document(text).sections)
    )
    if acceptable:
        return None
    return {
        "status": "error",
        "message": "Memory must be markdown with at least one ## heading.",
    }
def validate_memory_scope(
    content: str,
    scope: Literal["user", "team"],
    *,
    old_memory: str | None = None,
) -> tuple[dict[str, str] | None, list[str]]:
    """Enforce the personal-heading rule for team memory.

    Only the ``"team"`` scope is checked; any other scope passes with no
    warnings. Forbidden headings already present in ``old_memory`` are
    tolerated but reported as a warning, while forbidden headings that appear
    only in the new ``content`` produce an error.

    Returns:
        ``(error, warnings)`` where ``error`` is ``None`` or a payload with
        ``status``/``message`` keys, and ``warnings`` is a list of strings.
    """
    if scope != "team":
        return None, []

    # NOTE(review): ``old_memory`` may be None here; this assumes
    # ``extract_headings`` tolerates None — confirm.
    previous = extract_headings(old_memory) & _FORBIDDEN_TEAM_HEADINGS
    current = extract_headings(content) & _FORBIDDEN_TEAM_HEADINGS

    warnings: list[str] = []
    carried_over = sorted(current & previous)
    if carried_over:
        legacy_names = ", ".join(carried_over)
        warnings.append(
            "Team memory contains legacy personal headings: "
            + legacy_names
            + ". Please consolidate them into team-safe headings."
        )

    newly_added = sorted(current - previous)
    if not newly_added:
        return None, warnings

    error = {
        "status": "error",
        "message": (
            "Team memory cannot introduce personal headings: "
            + ", ".join(newly_added)
            + ". Use team-safe headings instead."
        ),
    }
    return error, warnings
def validate_bullet_format(content: str) -> list[str]:
    """Return whatever ``nonstandard_bullets`` reports for ``content``."""
    findings = nonstandard_bullets(content)
    return findings
def validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
    """Compare old and new memory and describe suspicious losses.

    Two situations are reported: headings present in ``old_memory`` but
    missing from ``new_memory``, and a new document shorter than 40% of the
    old one. With no previous memory there is nothing to compare, so the
    result is empty.

    Returns:
        A list of human-readable warning strings (possibly empty).
    """
    if not old_memory:
        return []

    findings: list[str] = []

    dropped = extract_headings(old_memory) - extract_headings(new_memory)
    if dropped:
        names = ", ".join(sorted(dropped))
        findings.append(
            f"Sections removed: {names}. If unintentional, restore them from the memory document."
        )

    old_len, new_len = len(old_memory), len(new_memory)
    shrank_a_lot = new_len < old_len * 0.4
    if old_len > 0 and shrank_a_lot:
        findings.append(
            f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). Possible data loss."
        )

    return findings
def soft_limit_warning(content: str) -> str | None:
    """Return a consolidation hint once ``content`` exceeds the soft limit.

    Returns:
        ``None`` while ``len(content)`` is at most ``MEMORY_SOFT_LIMIT``;
        otherwise a message showing the current length against
        ``MEMORY_HARD_LIMIT``.
    """
    length = len(content)
    if length <= MEMORY_SOFT_LIMIT:
        return None
    return (
        f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
        "Consolidate by merging related items and removing less important entries."
    )

View file

@ -1,3 +1,4 @@
import asyncio
import logging
from datetime import datetime
@ -95,7 +96,9 @@ class OneDriveKBSyncService:
else:
logger.warning("No LLM configured — using fallback summary")
summary_content = f"OneDrive File: {file_name}\n\n{indexable_content}"
summary_embedding = embed_text(summary_content)
summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(indexable_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View file

@ -29,6 +29,7 @@ same trap waiting to happen).
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
@ -234,7 +235,7 @@ async def _restore_in_place_document(
if isinstance(c, dict) and isinstance(c.get("content"), str)
]
if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts)
chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
session.add_all(
[
Chunk(document_id=doc.id, content=text, embedding=embedding)
@ -244,7 +245,9 @@ async def _restore_in_place_document(
]
)
if isinstance(revision.content_before, str):
doc.embedding = embed_texts([revision.content_before])[0]
doc.embedding = (
await asyncio.to_thread(embed_texts, [revision.content_before])
)[0]
doc.updated_at = datetime.now(UTC)
return RevertOutcome(status="ok", message="Document restored from snapshot.")
@ -320,7 +323,7 @@ async def _reinsert_document_from_revision(
session.add(new_doc)
await session.flush()
new_doc.embedding = embed_texts([content])[0]
new_doc.embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
chunk_texts = []
chunks_before = revision.chunks_before
if isinstance(chunks_before, list):
@ -330,7 +333,7 @@ async def _reinsert_document_from_revision(
if isinstance(c, dict) and isinstance(c.get("content"), str)
]
if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts)
chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
session.add_all(
[
Chunk(document_id=new_doc.id, content=text, embedding=embedding)

View file

@ -325,6 +325,24 @@ class TokenTrackingCallback(CustomLogger):
total_tokens = getattr(usage, "total_tokens", 0) or 0
call_kind = "chat"
# Prompt-cache accounting. LiteLLM normalizes every provider's cache
# fields onto ``usage.prompt_tokens_details``:
# - ``cached_tokens`` — cache reads (OpenAI/Azure native, DeepSeek
# mapped from ``prompt_cache_hit_tokens``,
# Anthropic mapped from ``cache_read_input_tokens``).
# - ``cache_creation_tokens`` — cache writes (Anthropic only; OpenAI/Azure
# do not expose a write count).
# See ``litellm.types.utils.Usage.__init__`` for the mapping.
cached_tokens = 0
cache_creation_tokens = 0
if not is_image:
prompt_details = getattr(usage, "prompt_tokens_details", None)
if prompt_details is not None:
cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0
cache_creation_tokens = (
getattr(prompt_details, "cache_creation_tokens", 0) or 0
)
model = kwargs.get("model", "unknown")
cost_usd = _extract_cost_usd(
@ -357,9 +375,23 @@ class TokenTrackingCallback(CustomLogger):
cost_micros=cost_micros,
call_kind=call_kind,
)
# Per-LLM-call wall-clock latency (LiteLLM passes datetime objects).
call_latency_s: float | None = None
try:
if start_time is not None and end_time is not None:
delta = end_time - start_time
call_latency_s = getattr(delta, "total_seconds", lambda: float(delta))()
except Exception:
call_latency_s = None
cache_hit_ratio: float | None = None
if prompt_tokens > 0 and (cached_tokens > 0 or cache_creation_tokens > 0):
cache_hit_ratio = cached_tokens / prompt_tokens
logger.info(
"[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
"cost=$%.6f (%d micros) (accumulator now has %d calls)",
"cost=$%.6f (%d micros) (accumulator now has %d calls)%s%s",
model,
call_kind,
prompt_tokens,
@ -368,6 +400,17 @@ class TokenTrackingCallback(CustomLogger):
cost_usd,
cost_micros,
len(acc.calls),
f" latency={call_latency_s:.3f}s" if call_latency_s is not None else "",
(
f" cache_read={cached_tokens} cache_write={cache_creation_tokens}"
f" hit_ratio={cache_hit_ratio:.1%}"
if cache_hit_ratio is not None
else (
f" cache_read={cached_tokens} cache_write={cache_creation_tokens}"
if (cached_tokens or cache_creation_tokens)
else ""
)
),
)

View file

@ -39,10 +39,6 @@ from app.agents.new_chat.llm_config import (
load_agent_config,
load_global_llm_config_by_id,
)
from app.agents.new_chat.memory_extraction import (
extract_and_save_memory,
extract_and_save_team_memory,
)
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
from app.agents.new_chat.middleware.busy_mutex import (
end_turn,
@ -64,8 +60,6 @@ from app.db import (
)
from app.prompts import TITLE_GENERATION_PROMPT
from app.services.auto_model_pin_service import (
is_recently_healthy,
mark_healthy,
mark_runtime_cooldown,
resolve_or_get_pinned_llm_config_id,
)
@ -283,7 +277,6 @@ class StreamResult:
accumulated_text: str = ""
is_interrupted: bool = False
sandbox_files: list[str] = field(default_factory=list)
agent_called_update_memory: bool = False
request_id: str | None = None
turn_id: str = ""
filesystem_mode: str = "cloud"
@ -506,54 +499,6 @@ def _is_provider_rate_limited(exc: BaseException) -> bool:
)
_PREFLIGHT_TIMEOUT_SEC: float = 2.5
_PREFLIGHT_MAX_TOKENS: int = 1
async def _preflight_llm(llm: Any) -> None:
"""Issue a minimal completion to confirm the pinned model isn't 429'ing.
Used before agent build / planner / classifier / title-gen so a known-bad
free OpenRouter deployment is detected and repinned before it cascades
into multiple wasted internal calls. The probe is intentionally cheap:
one token, low timeout, tagged ``surfsense:internal`` so token tracking
and SSE pipelines treat it as overhead rather than user output.
Raises the original exception when the provider responds with a
rate-limit-shaped error so the caller can drive the cooldown/repin
branch via :func:`_is_provider_rate_limited`. Other transient failures
are swallowed the caller continues to the normal stream path and the
in-stream recovery loop remains the safety net.
"""
from litellm import acompletion
model = getattr(llm, "model", None)
if not model or model == "auto":
# Auto-mode router doesn't have a single deployment to ping; the
# router itself handles per-deployment rate-limit accounting.
return
try:
await acompletion(
model=model,
messages=[{"role": "user", "content": "ping"}],
api_key=getattr(llm, "api_key", None),
api_base=getattr(llm, "api_base", None),
max_tokens=_PREFLIGHT_MAX_TOKENS,
timeout=_PREFLIGHT_TIMEOUT_SEC,
stream=False,
metadata={"tags": ["surfsense:internal", "auto-pin-preflight"]},
)
except Exception as exc:
if _is_provider_rate_limited(exc):
raise
logging.getLogger(__name__).debug(
"auto_pin_preflight non_rate_limit_error model=%s err=%s",
model,
exc,
)
async def _build_main_agent_for_thread(
agent_factory: Any,
*,
@ -571,9 +516,9 @@ async def _build_main_agent_for_thread(
disabled_tools: list[str] | None = None,
mentioned_document_ids: list[int] | None = None,
) -> Any:
"""Single (re)build path so the agent factory cannot drift across
initial build, preflight repin, and mid-stream 429 recovery for one
``thread_id``: a graph swap mid-turn would corrupt checkpointer state."""
"""Single (re)build path so the agent factory cannot drift across the
initial build and mid-stream 429 recovery for one ``thread_id``: a
graph swap mid-turn would corrupt checkpointer state."""
return await agent_factory(
llm=llm,
search_space_id=search_space_id,
@ -591,29 +536,6 @@ async def _build_main_agent_for_thread(
)
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
"""Wait for a discarded speculative agent build to release shared state.
Used by the parallel preflight + agent-build path. The speculative build
closes over the request-scoped ``AsyncSession`` (for the brief connector
discovery / tool-factory window before its CPU work moves into a worker
thread). If preflight reports a 429 we want to fall back to the original
repin reload rebuild path, but we MUST NOT touch ``session`` again
until any in-flight session work owned by the speculative build has
fully settled :class:`sqlalchemy.ext.asyncio.AsyncSession` is not
concurrency-safe and the same hazard cost us a hard ``InvalidRequestError``
earlier in this PR (see ``connector_service`` parallel-gather revert).
We simply ``await`` the task and swallow any exception: in this path the
build's outcome is irrelevant — success populates the agent cache (a free
side effect), failure is discarded. The wasted CPU is acceptable since
429 fallbacks are rare and the original sequential code also paid the
full build cost on the same path.
"""
with contextlib.suppress(BaseException):
await task
def _classify_stream_exception(
exc: Exception,
*,
@ -1241,39 +1163,6 @@ async def stream_new_chat(
yield streaming_service.format_done()
return
# Auto-mode preflight ping. Runs ONLY for thread-pinned auto cfgs
# (negative ids selected via ``resolve_or_get_pinned_llm_config_id``)
# whose health hasn't already been confirmed within the TTL window.
# Detecting a 429 here lets us repin BEFORE the planner/classifier/
# title-generation LLM calls fan out and each independently hit the
# same upstream rate limit.
#
# PERF: preflight is a network round-trip to the LLM provider (~1-5s)
# and is independent of the agent build (CPU-bound, ~5-7s). They used
# to run sequentially → ``preflight + build`` on cold cache = 11.5s.
# We now kick off preflight as a background task FIRST, then run the
# synchronous setup work and the agent build in parallel. In the
# success path (the common case) total wall time drops to roughly
# ``max(preflight, build)`` — the preflight finishes during the
# agent compile and we just consume its result. In the rare 429
# path the speculative build is awaited to completion (so its
# session usage is fully released) via
# :func:`_settle_speculative_agent_build`, then discarded, and
# we fall back to the original repin-and-rebuild flow.
preflight_needed = (
requested_llm_config_id == 0
and llm_config_id < 0
and not is_recently_healthy(llm_config_id)
)
preflight_task: asyncio.Task[None] | None = None
_t_preflight = 0.0
if preflight_needed:
_t_preflight = time.perf_counter()
preflight_task = asyncio.create_task(
_preflight_llm(llm),
name=f"auto_pin_preflight:{llm_config_id}",
)
# Create connector service
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
@ -1307,136 +1196,26 @@ async def stream_new_chat(
if use_multi_agent
else create_surfsense_deep_agent
)
# Speculative agent build — runs in parallel with the preflight
# task (if any). Built with the *current* ``llm`` / ``agent_config``;
# if preflight reports 429 we will discard this future and rebuild
# against the freshly pinned config below.
agent_build_task = asyncio.create_task(
_build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
),
name="agent_build:stream_new_chat",
# Build the agent inline. Provider 429s surface through the
# in-stream recovery loop below (``_is_provider_rate_limited``),
# which repins the thread to an eligible alternative config and
# rebuilds the agent before the user sees any output.
agent = await _build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
)
agent: Any = None
if preflight_task is not None:
try:
await preflight_task
mark_healthy(llm_config_id)
_perf_log.info(
"[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
llm_config_id,
time.perf_counter() - _t_preflight,
)
except Exception as preflight_exc:
# Both branches below need the session: the non-429 path
# may unwind via cleanup that uses ``session``, and the
# 429 path explicitly calls ``resolve_or_get_pinned_llm_config_id``
# against it. Wait for the speculative build to release its
# session usage before we proceed.
await _settle_speculative_agent_build(agent_build_task)
if not _is_provider_rate_limited(preflight_exc):
raise
# 429: speculative agent is discarded; run the original
# repin → reload → rebuild path against the freshly
# pinned config.
previous_config_id = llm_config_id
mark_runtime_cooldown(
previous_config_id, reason="preflight_rate_limited"
)
try:
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
requires_image_input=_requires_image_input,
)
).resolved_llm_config_id
except ValueError as pin_error:
yield _emit_stream_error(
message=str(pin_error),
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
llm, agent_config, llm_load_error = await _load_llm_bundle(
llm_config_id
)
if llm_load_error or not llm:
yield _emit_stream_error(
message=llm_load_error or "Failed to create LLM instance",
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
# Trust the freshly-resolved cfg for the remainder of this
# turn rather than recursing into another preflight; the
# in-stream 429 recovery loop is still in place as the
# safety net if even this fallback hits an upstream cap.
mark_healthy(llm_config_id)
_log_chat_stream_error(
flow=flow,
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model failed preflight; switched to another "
"eligible model and continuing."
),
extra={
"auto_runtime_recover": True,
"preflight": True,
"previous_config_id": previous_config_id,
"fallback_config_id": llm_config_id,
},
)
# Rebuild against the new llm/agent_config. Sequential
# here because we no longer have anything to overlap with.
agent = await agent_factory(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
)
if agent is None:
# Either no preflight was needed, or preflight succeeded —
# in both cases the speculative build is the agent we want.
agent = await agent_build_task
_perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
)
@ -2208,36 +1987,6 @@ async def stream_new_chat(
},
)
# Fire background memory extraction if the agent didn't handle it.
# Shared threads write to team memory; private threads write to user memory.
if not stream_result.agent_called_update_memory:
memory_seed = user_query.strip() or (
f"[{len(user_image_data_urls or [])} image(s)]"
if user_image_data_urls
else "(message)"
)
if visibility == ChatVisibility.SEARCH_SPACE:
task = asyncio.create_task(
extract_and_save_team_memory(
user_message=memory_seed,
search_space_id=search_space_id,
llm=llm,
author_display_name=current_user_display_name,
)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
elif user_id:
task = asyncio.create_task(
extract_and_save_memory(
user_message=memory_seed,
user_id=user_id,
llm=llm,
)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Finish the step and message
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
@ -2682,25 +2431,6 @@ async def stream_resume_chat(
yield streaming_service.format_done()
return
# Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``:
# one cheap probe before the agent is rebuilt so a 429'd pin gets
# repinned without burning planner/classifier/title calls first.
# See ``stream_new_chat`` for the full rationale on the speculative
# parallel build pattern below.
preflight_needed = (
requested_llm_config_id == 0
and llm_config_id < 0
and not is_recently_healthy(llm_config_id)
)
preflight_task: asyncio.Task[None] | None = None
_t_preflight = 0.0
if preflight_needed:
_t_preflight = time.perf_counter()
preflight_task = asyncio.create_task(
_preflight_llm(llm),
name=f"auto_pin_preflight_resume:{llm_config_id}",
)
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
@ -2730,115 +2460,25 @@ async def stream_resume_chat(
if _app_config.MULTI_AGENT_CHAT_ENABLED
else create_surfsense_deep_agent
)
agent_build_task = asyncio.create_task(
_build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
),
name="agent_build:stream_resume",
# Build the agent inline. Provider 429s are handled by the
# in-stream recovery loop, which repins to an eligible
# alternative config and rebuilds the agent before the user sees
# any output.
agent = await _build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
)
agent: Any = None
if preflight_task is not None:
try:
await preflight_task
mark_healthy(llm_config_id)
_perf_log.info(
"[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
llm_config_id,
time.perf_counter() - _t_preflight,
)
except Exception as preflight_exc:
# Same session-safety rationale as ``stream_new_chat``.
await _settle_speculative_agent_build(agent_build_task)
if not _is_provider_rate_limited(preflight_exc):
raise
previous_config_id = llm_config_id
mark_runtime_cooldown(
previous_config_id, reason="preflight_rate_limited"
)
try:
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
)
).resolved_llm_config_id
except ValueError as pin_error:
yield _emit_stream_error(
message=str(pin_error),
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
llm, agent_config, llm_load_error = await _load_llm_bundle(
llm_config_id
)
if llm_load_error or not llm:
yield _emit_stream_error(
message=llm_load_error or "Failed to create LLM instance",
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
mark_healthy(llm_config_id)
_log_chat_stream_error(
flow="resume",
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model failed preflight; switched to another "
"eligible model and continuing."
),
extra={
"auto_runtime_recover": True,
"preflight": True,
"previous_config_id": previous_config_id,
"fallback_config_id": llm_config_id,
},
)
agent = await _build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
)
if agent is None:
agent = await agent_build_task
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
)

View file

@ -48,4 +48,3 @@ async def stream_output(
yield frame
result.accumulated_text = state.accumulated_text
result.agent_called_update_memory = state.called_update_memory

View file

@ -11,7 +11,6 @@ class StreamingResult:
accumulated_text: str = ""
is_interrupted: bool = False
sandbox_files: list[str] = field(default_factory=list)
agent_called_update_memory: bool = False
request_id: str | None = None
turn_id: str = ""
filesystem_mode: str = "cloud"

View file

@ -36,9 +36,6 @@ def iter_tool_end_frames(
raw_output = event.get("data", {}).get("output", "")
staged_file_path = state.file_path_by_run.pop(run_id, None) if run_id else None
if tool_name == "update_memory":
state.called_update_memory = True
if hasattr(raw_output, "content"):
content = raw_output.content
if isinstance(content, str):

View file

@ -32,7 +32,6 @@ class AgentEventRelayState:
last_active_step_items: list[str] = field(default_factory=list)
just_finished_tool: bool = False
active_tool_depth: int = 0
called_update_memory: bool = False
current_reasoning_id: str | None = None
pending_tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
lc_tool_call_id_by_run: dict[str, str] = field(default_factory=dict)

View file

@ -670,7 +670,9 @@ async def index_discord_messages(
# Heavy processing (embeddings, chunks)
chunks = await create_document_chunks(item["combined_document_string"])
doc_embedding = embed_text(item["combined_document_string"])
doc_embedding = await asyncio.to_thread(
embed_text, item["combined_document_string"]
)
# Update document to READY with actual content
document.title = f"{item['guild_name']}#{item['channel_name']}"

View file

@ -6,6 +6,7 @@ Implements 2-phase document status updates for real-time UI feedback:
- Phase 2: Process each event: pending → processing → ready/failed
"""
import asyncio
import time
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
@ -465,7 +466,9 @@ async def index_luma_events(
summary_content = (
f"Luma Event: {item['event_name']}\n\n{item['event_markdown']}"
)
summary_embedding = embed_text(summary_content)
summary_embedding = await asyncio.to_thread(
embed_text, summary_content
)
chunks = await create_document_chunks(item["event_markdown"])

View file

@ -9,6 +9,7 @@ Uses 2-phase document status updates for real-time UI feedback:
- Phase 2: Process each document: pending → processing → ready/failed
"""
import asyncio
import time
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
@ -581,7 +582,9 @@ async def index_teams_messages(
# Heavy processing (embeddings, chunks)
chunks = await create_document_chunks(item["combined_document_string"])
doc_embedding = embed_text(item["combined_document_string"])
doc_embedding = await asyncio.to_thread(
embed_text, item["combined_document_string"]
)
# Update document to READY with actual content
document.title = f"{item['team_name']} - {item['channel_name']}"

View file

@ -2,6 +2,7 @@
Unified document save/update logic for file processors.
"""
import asyncio
import logging
from sqlalchemy.exc import SQLAlchemyError
@ -43,7 +44,7 @@ async def _generate_summary(
"""
if not enable_summary:
summary = f"File: {file_name}\n\n{markdown_content[:4000]}"
return summary, embed_text(summary)
return summary, await asyncio.to_thread(embed_text, summary)
if etl_service == "DOCLING":
from app.services.docling_service import create_docling_service
@ -65,7 +66,7 @@ async def _generate_summary(
parts.append(f"**{formatted_key}:** {value}")
enhanced = "\n".join(parts) + "\n\n# DOCUMENT SUMMARY\n\n" + summary_text
return enhanced, embed_text(enhanced)
return enhanced, await asyncio.to_thread(embed_text, enhanced)
# Standard summary (Unstructured / LlamaCloud / others)
meta = {

View file

@ -1,3 +1,4 @@
import asyncio
import hashlib
import logging
import threading
@ -221,7 +222,9 @@ async def generate_document_summary(
else:
enhanced_summary_content = summary_content
summary_embedding = embed_text(enhanced_summary_content)
summary_embedding = await asyncio.to_thread(
embed_text, enhanced_summary_content
)
return enhanced_summary_content, summary_embedding
@ -237,7 +240,7 @@ async def create_document_chunks(content: str) -> list[Chunk]:
List of Chunk objects with embeddings
"""
chunk_texts = [c.text for c in config.chunker_instance.chunk(content)]
chunk_embeddings = embed_texts(chunk_texts)
chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
return [
Chunk(content=text, embedding=emb)
for text, emb in zip(chunk_texts, chunk_embeddings, strict=False)

View file

@ -1,6 +1,6 @@
[project]
name = "surf-new-backend"
version = "0.0.24"
version = "0.0.25"
description = "SurfSense Backend"
requires-python = ">=3.12"
dependencies = [

View file

@ -2,28 +2,12 @@
import pytest
from app.agents.new_chat.tools.update_memory import _save_memory
from app.services.memory import MemoryScope, save_memory
from app.utils.content_utils import extract_text_content
pytestmark = pytest.mark.unit
class _Recorder:
def __init__(self) -> None:
self.applied_content: str | None = None
self.commit_calls = 0
self.rollback_calls = 0
def apply(self, content: str) -> None:
self.applied_content = content
async def commit(self) -> None:
self.commit_calls += 1
async def rollback(self) -> None:
self.rollback_calls += 1
def test_extract_text_content_keeps_no_update_bare_string_from_content_blocks() -> None:
content = [
{"type": "thinking", "thinking": "No"},
@ -69,21 +53,12 @@ def test_extract_text_content_preserves_plain_string_responses() -> None:
@pytest.mark.asyncio
async def test_save_memory_rejects_non_string_payload_before_commit() -> None:
recorder = _Recorder()
result = await _save_memory(
updated_memory=["NO_UPDATE"], # type: ignore[arg-type]
old_memory=None,
llm=None,
apply_fn=recorder.apply,
commit_fn=recorder.commit,
rollback_fn=recorder.rollback,
label="memory",
scope="user",
result = await save_memory(
scope=MemoryScope.USER,
target_id="00000000-0000-0000-0000-000000000000",
content=["NO_UPDATE"], # type: ignore[arg-type]
session=None, # type: ignore[arg-type]
)
assert result["status"] == "error"
assert "must be a string" in result["message"]
assert recorder.applied_content is None
assert recorder.commit_calls == 0
assert recorder.rollback_calls == 0
assert result.status == "error"
assert "must be a string" in result.message

View file

@ -12,13 +12,19 @@ prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
the deepagent stack accumulates multiple ``SystemMessage``\ s in
``state["messages"]`` and ``role: system`` would tag every one of
them, blowing past Anthropic's 4-block ``cache_control`` cap.
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
prompt-cache surface is available).
3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only no
OpenAI-only kwargs because the router fans out across providers.
4. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
5. Defensive: LLMs without a writable ``model_kwargs`` are silently
2. Adds ``prompt_cache_key`` for OPENAI/DEEPSEEK/XAI/AZURE/AZURE_OPENAI
configs (Microsoft's Azure transformer was added to LiteLLM in
https://github.com/BerriAI/litellm/pull/20989, Feb 2026).
3. Adds ``prompt_cache_retention="24h"`` ONLY for OPENAI/DEEPSEEK/XAI.
Azure's server-side support landed in Microsoft's docs on 2026-05-13
but LiteLLM 1.83.14 hasn't wired it through yet, so we let Azure use
its default in-memory retention rather than send a param that
``litellm.drop_params`` would silently strip.
4. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only: no
destination-specific kwargs, because the router fans out across
providers.
5. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
6. Defensive: LLMs without a writable ``model_kwargs`` are silently
skipped rather than raising.
"""
@ -191,9 +197,9 @@ def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None:
@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"])
def test_sets_openai_family_extras(provider: str) -> None:
"""OpenAI-style providers gain ``prompt_cache_key`` (raises hit rate
via routing affinity) and ``prompt_cache_retention="24h"`` (extends
cache TTL beyond the default 5-10 min)."""
"""Native OpenAI-style providers gain ``prompt_cache_key`` (raises
hit rate via routing affinity) and ``prompt_cache_retention="24h"``
(extends cache TTL beyond the default 5-10 min)."""
cfg = _make_cfg(provider=provider)
llm = _FakeLLM()
@ -203,6 +209,27 @@ def test_sets_openai_family_extras(provider: str) -> None:
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
@pytest.mark.parametrize("provider", ["AZURE", "AZURE_OPENAI"])
def test_azure_gets_prompt_cache_key_only(provider: str) -> None:
    """Azure providers receive ``prompt_cache_key`` but not
    ``prompt_cache_retention``.

    The key provides routing affinity (Microsoft auto-caches every
    GPT-4o+ deployment at 1024 tokens; the key clusters same-prefix
    requests on the same backend GPU pool so the hit rate climbs).
    Retention is deliberately withheld: LiteLLM 1.83.14's Azure
    transformer omits it from its supported params list, so
    ``drop_params`` would silently strip it. Azure's default in-memory
    retention (5-10 min, max 1 h) already covers intra-conversation
    turns; revisit when LiteLLM bumps Azure to match its OpenAI surface.
    """
    azure_cfg = _make_cfg(provider=provider, model_name="gpt-5.4")
    azure_llm = _FakeLLM(model="azure/gpt-5.4")

    apply_litellm_prompt_caching(azure_llm, agent_config=azure_cfg, thread_id=42)

    applied = azure_llm.model_kwargs
    assert applied["prompt_cache_key"] == "surfsense-thread-42"
    assert "prompt_cache_retention" not in applied
    # The universal injection points are provider-independent.
    assert "cache_control_injection_points" in applied
def test_skips_prompt_cache_key_when_no_thread_id() -> None:
"""Without a thread id we can't construct a per-thread key. Retention
is still useful so we set it (it's free)."""
@ -215,12 +242,26 @@ def test_skips_prompt_cache_key_when_no_thread_id() -> None:
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
def test_azure_skips_prompt_cache_key_when_no_thread_id() -> None:
    """Azure with no thread id gets neither extra kwarg.

    Retention is skipped for Azure and the key needs a thread id, so only
    the universal injection points are expected to land.
    """
    azure_cfg = _make_cfg(provider="AZURE", model_name="gpt-5.4")
    azure_llm = _FakeLLM(model="azure/gpt-5.4")

    apply_litellm_prompt_caching(azure_llm, agent_config=azure_cfg, thread_id=None)

    applied = azure_llm.model_kwargs
    for absent_key in ("prompt_cache_key", "prompt_cache_retention"):
        assert absent_key not in applied
    assert "cache_control_injection_points" in applied
@pytest.mark.parametrize(
"provider",
["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"],
)
def test_no_openai_extras_for_other_providers(provider: str) -> None:
"""Non-OpenAI-family providers don't expose ``prompt_cache_key`` —
"""Non-OpenAI-style providers don't expose ``prompt_cache_key`` —
skip it. ``cache_control_injection_points`` is still set (universal)."""
cfg = _make_cfg(provider=provider)
llm = _FakeLLM()

View file

@ -0,0 +1,130 @@
"""Unit tests for ``mcp_tools_cache``."""
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
import pytest
from app.agents.new_chat.tools.mcp_tools_cache import (
CachedMCPToolDef,
CachedMCPTools,
read_cached_tools,
)
pytestmark = pytest.mark.unit
def _make_connector(config: dict | None) -> SimpleNamespace:
return SimpleNamespace(id=42, config=config)
def test_read_returns_none_when_config_is_none() -> None:
    """A connector whose ``config`` is ``None`` yields no cached tools."""
    connector = _make_connector(None)
    assert read_cached_tools(connector) is None
def test_read_returns_none_when_cached_tools_missing() -> None:
    """A config without a ``cached_tools`` key yields no cached tools."""
    connector = _make_connector({"server_config": {}})
    assert read_cached_tools(connector) is None
def test_read_returns_none_when_cached_tools_is_not_a_dict() -> None:
    """Non-dict ``cached_tools`` values (a list, a string) are rejected."""
    for bad_value in ([], "stale"):
        connector = _make_connector({"cached_tools": bad_value})
        assert read_cached_tools(connector) is None
def test_read_parses_minimal_valid_payload() -> None:
    """A payload carrying only ``discovered_at`` and ``tools`` parses, and
    the optional server metadata fields come back as ``None``."""
    minimal_config = {
        "cached_tools": {
            "discovered_at": "2026-05-20T10:00:00+00:00",
            "tools": [
                {
                    "name": "list_issues",
                    "description": "List Linear issues",
                    "input_schema": {"type": "object"},
                }
            ],
        }
    }

    cached = read_cached_tools(_make_connector(minimal_config))

    assert cached is not None
    # None of the optional metadata was supplied in the payload.
    for optional_field in ("server_version", "server_name", "transport"):
        assert getattr(cached, optional_field) is None
    assert len(cached.tools) == 1
    assert cached.tools[0].name == "list_issues"
def test_read_parses_full_payload_with_serverinfo() -> None:
    """Server version, name and transport are surfaced when present, and
    the tool list keeps its payload order."""
    full_config = {
        "cached_tools": {
            "discovered_at": "2026-05-20T10:00:00+00:00",
            "server_version": "1.2.3",
            "server_name": "atlassian-mcp",
            "transport": "streamable-http",
            "tools": [
                {"name": "create_issue", "input_schema": {}},
                {"name": "list_issues", "input_schema": {}},
            ],
        }
    }

    cached = read_cached_tools(_make_connector(full_config))

    assert cached is not None
    assert cached.server_version == "1.2.3"
    assert cached.server_name == "atlassian-mcp"
    assert cached.transport == "streamable-http"
    tool_names = [tool.name for tool in cached.tools]
    assert tool_names == ["create_issue", "list_issues"]
def test_read_returns_none_for_corrupt_payload(caplog) -> None:
    """A structurally invalid payload returns ``None`` and produces a log
    record mentioning ``corrupt cached_tools``."""
    corrupt_config = {
        "cached_tools": {
            "discovered_at": "not-a-date",
            "tools": "should-be-a-list",
        }
    }

    outcome = read_cached_tools(_make_connector(corrupt_config))

    assert outcome is None
    logged_messages = (record.getMessage() for record in caplog.records)
    assert any("corrupt cached_tools" in message for message in logged_messages)
def test_read_returns_none_when_tools_missing() -> None:
    """A ``cached_tools`` dict without a ``tools`` entry is rejected."""
    config_without_tools = {
        "cached_tools": {"discovered_at": "2026-05-20T10:00:00+00:00"}
    }
    outcome = read_cached_tools(_make_connector(config_without_tools))
    assert outcome is None
def test_tool_def_defaults_description_and_schema() -> None:
    """Validating a tool definition with only a ``name`` leaves the
    description empty and the input schema an empty dict."""
    tool_def = CachedMCPToolDef.model_validate({"name": "ping"})
    assert tool_def.description == ""
    assert tool_def.input_schema == {}
def test_model_dump_json_mode_is_round_trippable() -> None:
    """``model_dump(mode="json")`` output can be fed back through
    ``model_validate`` without losing the timestamp or the tool list."""
    discovered = datetime(2026, 5, 20, 10, 0, 0, tzinfo=UTC)
    source_model = CachedMCPTools(
        discovered_at=discovered,
        server_version="1.2.3",
        server_name="atlassian-mcp",
        transport="streamable-http",
        tools=[CachedMCPToolDef(name="list_issues")],
    )

    dumped = source_model.model_dump(mode="json")
    # JSON mode renders the UTC timestamp as an ISO string with a "Z" suffix.
    assert dumped["discovered_at"] == "2026-05-20T10:00:00Z"
    assert dumped["tools"][0]["name"] == "list_issues"

    restored = CachedMCPTools.model_validate(dumped)
    assert restored.discovered_at == source_model.discovered_at
    assert restored.tools[0].name == "list_issues"

View file

@ -1,24 +1,24 @@
"""Unit tests for memory scope validation and bullet format validation."""
"""Unit tests for heading-based memory validation."""
import pytest
from app.agents.new_chat.tools.update_memory import (
_save_memory,
_validate_bullet_format,
_validate_memory_scope,
from app.services.memory import MemoryScope, save_memory
from app.services.memory.validation import (
validate_bullet_format,
validate_memory_scope,
)
pytestmark = pytest.mark.unit
class _Recorder:
class _FakeSession:
def __init__(self) -> None:
self.applied_content: str | None = None
self.added = []
self.commit_calls = 0
self.rollback_calls = 0
def apply(self, content: str) -> None:
self.applied_content = content
def add(self, obj) -> None:
self.added.append(obj)
async def commit(self) -> None:
self.commit_calls += 1
@ -27,172 +27,148 @@ class _Recorder:
self.rollback_calls += 1
# ---------------------------------------------------------------------------
# _validate_memory_scope — marker-based
# ---------------------------------------------------------------------------
def test_validate_memory_scope_rejects_pref_marker_in_team_scope() -> None:
content = "- (2026-04-10) [pref] Prefers dark mode\n"
result = _validate_memory_scope(content, "team")
def test_validate_memory_scope_rejects_new_personal_heading_in_team() -> None:
content = "## Preferences\n- 2026-04-10: Prefers dark mode\n"
result, _warnings = validate_memory_scope(content, "team")
assert result is not None
assert result["status"] == "error"
assert "[pref]" in result["message"]
assert "preferences" in result["message"]
def test_validate_memory_scope_rejects_instr_marker_in_team_scope() -> None:
content = "- (2026-04-10) [instr] Always respond in Spanish\n"
result = _validate_memory_scope(content, "team")
assert result is not None
assert result["status"] == "error"
assert "[instr]" in result["message"]
def test_validate_memory_scope_allows_old_marker_payload_in_team_scope() -> None:
content = "- (2026-04-10) [pref] Legacy personal marker remains readable\n"
result, _warnings = validate_memory_scope(content, "team")
assert result is None
def test_validate_memory_scope_rejects_both_personal_markers_in_team() -> None:
def test_validate_memory_scope_allows_team_headings() -> None:
content = "## Engineering Conventions\n- 2026-04-10: Uses PostgreSQL\n"
result, _warnings = validate_memory_scope(content, "team")
assert result is None
def test_validate_bullet_format_accepts_new_and_legacy_bullets() -> None:
content = (
"- (2026-04-10) [pref] Prefers dark mode\n"
"- (2026-04-10) [instr] Always respond in Spanish\n"
"## Facts\n"
"- 2026-04-10: Senior Python developer\n"
"- (2026-04-10) [fact] Legacy fact is preserved\n"
)
result = _validate_memory_scope(content, "team")
assert result is not None
assert result["status"] == "error"
assert "[instr]" in result["message"]
assert "[pref]" in result["message"]
def test_validate_memory_scope_allows_fact_in_team_scope() -> None:
content = "- (2026-04-10) [fact] Office is in downtown Seattle\n"
result = _validate_memory_scope(content, "team")
assert result is None
def test_validate_memory_scope_allows_all_markers_in_user_scope() -> None:
content = (
"- (2026-04-10) [fact] Python developer\n"
"- (2026-04-10) [pref] Prefers concise answers\n"
"- (2026-04-10) [instr] Always use bullet points\n"
)
result = _validate_memory_scope(content, "user")
assert result is None
def test_validate_memory_scope_allows_any_heading_in_team() -> None:
content = "## Architecture\n- (2026-04-10) [fact] Uses PostgreSQL for persistence\n"
result = _validate_memory_scope(content, "team")
assert result is None
def test_validate_memory_scope_allows_any_heading_in_user() -> None:
content = "## My Projects\n- (2026-04-10) [fact] Working on SurfSense\n"
result = _validate_memory_scope(content, "user")
assert result is None
# ---------------------------------------------------------------------------
# _validate_bullet_format
# ---------------------------------------------------------------------------
def test_validate_bullet_format_passes_valid_bullets() -> None:
content = (
"## Work\n"
"- (2026-04-10) [fact] Senior Python developer\n"
"- (2026-04-10) [pref] Prefers dark mode\n"
"- (2026-04-10) [instr] Always respond in bullet points\n"
)
warnings = _validate_bullet_format(content)
warnings = validate_bullet_format(content)
assert warnings == []
def test_validate_bullet_format_warns_on_missing_marker() -> None:
content = "- (2026-04-10) Senior Python developer\n"
warnings = _validate_bullet_format(content)
def test_validate_bullet_format_warns_on_nonstandard_bullet() -> None:
content = "## Facts\n- Senior Python developer\n"
warnings = validate_bullet_format(content)
assert len(warnings) == 1
assert "Malformed bullet" in warnings[0]
def test_validate_bullet_format_warns_on_missing_date() -> None:
content = "- [fact] Senior Python developer\n"
warnings = _validate_bullet_format(content)
assert len(warnings) == 1
assert "Malformed bullet" in warnings[0]
def test_validate_bullet_format_warns_on_unknown_marker() -> None:
content = "- (2026-04-10) [context] Working on project X\n"
warnings = _validate_bullet_format(content)
assert len(warnings) == 1
assert "Malformed bullet" in warnings[0]
def test_validate_bullet_format_ignores_non_bullet_lines() -> None:
content = "## Some Heading\nSome paragraph text\n"
warnings = _validate_bullet_format(content)
assert warnings == []
def test_validate_bullet_format_warns_on_old_format_without_marker() -> None:
content = "## About the user\n- (2026-04-10) Likes cats\n"
warnings = _validate_bullet_format(content)
assert len(warnings) == 1
# ---------------------------------------------------------------------------
# _save_memory — end-to-end with marker scope check
# ---------------------------------------------------------------------------
assert "Non-standard memory bullet" in warnings[0]
@pytest.mark.asyncio
async def test_save_memory_blocks_pref_in_team_before_commit() -> None:
recorder = _Recorder()
result = await _save_memory(
updated_memory="- (2026-04-10) [pref] Prefers dark mode\n",
old_memory=None,
llm=None,
apply_fn=recorder.apply,
commit_fn=recorder.commit,
rollback_fn=recorder.rollback,
label="team memory",
scope="team",
async def test_save_memory_normalizes_legacy_marker_bullets(monkeypatch) -> None:
target = type("Target", (), {"memory_md": ""})()
session = _FakeSession()
async def fake_load_target(**_kwargs):
return target
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
result = await save_memory(
scope=MemoryScope.USER,
target_id="00000000-0000-0000-0000-000000000000",
content="- (2026-04-10) [fact] Legacy fact is preserved\n",
session=session,
)
assert result["status"] == "error"
assert recorder.commit_calls == 0
assert recorder.applied_content is None
assert result.status == "saved"
assert target.memory_md == "## Memory\n- 2026-04-10: Legacy fact is preserved"
@pytest.mark.asyncio
async def test_save_memory_allows_fact_in_team_and_commits() -> None:
recorder = _Recorder()
content = "- (2026-04-10) [fact] Weekly standup on Mondays\n"
result = await _save_memory(
updated_memory=content,
old_memory=None,
llm=None,
apply_fn=recorder.apply,
commit_fn=recorder.commit,
rollback_fn=recorder.rollback,
label="team memory",
scope="team",
async def test_save_memory_blocks_new_personal_heading_in_team_before_commit(
monkeypatch,
) -> None:
target = type("Target", (), {"shared_memory_md": ""})()
session = _FakeSession()
async def fake_load_target(**_kwargs):
return target
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
result = await save_memory(
scope=MemoryScope.TEAM,
target_id=1,
content="## Preferences\n- 2026-04-10: Prefers dark mode\n",
session=session,
)
assert result["status"] == "saved"
assert recorder.commit_calls == 1
assert recorder.applied_content == content
assert result.status == "error"
assert session.commit_calls == 0
assert target.shared_memory_md == ""
@pytest.mark.asyncio
async def test_save_memory_includes_format_warnings() -> None:
recorder = _Recorder()
content = "- (2026-04-10) Missing marker text\n"
result = await _save_memory(
updated_memory=content,
old_memory=None,
llm=None,
apply_fn=recorder.apply,
commit_fn=recorder.commit,
rollback_fn=recorder.rollback,
label="memory",
scope="user",
async def test_save_memory_allows_grandfathered_personal_heading_in_team(
monkeypatch,
) -> None:
content = "## Preferences\n- 2026-04-10: Prefers dark mode\n"
target = type("Target", (), {"shared_memory_md": content})()
session = _FakeSession()
async def fake_load_target(**_kwargs):
return target
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
result = await save_memory(
scope=MemoryScope.TEAM,
target_id=1,
content=content,
session=session,
)
assert result["status"] == "saved"
assert "format_warnings" in result
assert len(result["format_warnings"]) == 1
assert result.status == "saved"
assert session.commit_calls == 1
assert target.shared_memory_md == content.strip()
assert result.warnings
@pytest.mark.asyncio
async def test_save_memory_strips_preamble_before_heading(monkeypatch) -> None:
    """Chatty text preceding the first ``##`` heading is dropped before
    the memory is stored."""
    user_target = type("Target", (), {"memory_md": ""})()
    fake_session = _FakeSession()

    async def _stub_load_target(**_ignored):
        return user_target

    monkeypatch.setattr("app.services.memory.service._load_target", _stub_load_target)

    outcome = await save_memory(
        scope=MemoryScope.USER,
        target_id="00000000-0000-0000-0000-000000000000",
        content="Sure, here is the update:\n\n## Facts\n- 2026-04-10: Likes cats\n",
        session=fake_session,
    )

    assert outcome.status == "saved"
    # Only the heading-based section survives, without the trailing newline.
    assert user_target.memory_md == "## Facts\n- 2026-04-10: Likes cats"
@pytest.mark.asyncio
async def test_save_memory_rejects_long_no_heading_payload(monkeypatch) -> None:
    """A free-text payload with no ``##`` heading is rejected with an
    error that mentions the heading requirement, and nothing is committed."""
    user_target = type("Target", (), {"memory_md": ""})()
    fake_session = _FakeSession()

    async def _stub_load_target(**_ignored):
        return user_target

    monkeypatch.setattr("app.services.memory.service._load_target", _stub_load_target)

    outcome = await save_memory(
        scope=MemoryScope.USER,
        target_id="00000000-0000-0000-0000-000000000000",
        content="NO_UPDATE because there is nothing durable to remember.",
        session=fake_session,
    )

    assert outcome.status == "error"
    assert "## heading" in outcome.message
    assert fake_session.commit_calls == 0

View file

@ -0,0 +1,187 @@
"""Unit tests for the first-class memory service."""
from types import SimpleNamespace
import pytest
from app.services.memory import (
MemoryScope,
reset_memory,
save_memory,
)
pytestmark = pytest.mark.unit
class _FakeSession:
def __init__(self) -> None:
self.commit_calls = 0
self.rollback_calls = 0
self.added = []
def add(self, obj) -> None:
self.added.append(obj)
async def commit(self) -> None:
self.commit_calls += 1
async def rollback(self) -> None:
self.rollback_calls += 1
@pytest.mark.asyncio
async def test_save_memory_saves_heading_based_memory(monkeypatch) -> None:
    """Heading-based content is written to the user target and committed
    exactly once."""
    user_target = SimpleNamespace(memory_md="")
    fake_session = _FakeSession()

    async def _stub_load_target(**_ignored):
        return user_target

    monkeypatch.setattr("app.services.memory.service._load_target", _stub_load_target)

    outcome = await save_memory(
        scope=MemoryScope.USER,
        target_id="00000000-0000-0000-0000-000000000000",
        content="## Facts\n- 2026-05-19: Anish works on SurfSense\n",
        session=fake_session,
    )

    assert outcome.status == "saved"
    assert user_target.memory_md.startswith("## Facts")
    assert fake_session.commit_calls == 1
@pytest.mark.asyncio
async def test_save_memory_accepts_legacy_marker_payload(monkeypatch) -> None:
    """A legacy ``- (date) [marker] text`` bullet is accepted and stored
    as a ``- date: text`` bullet under a ``## Memory`` heading."""
    user_target = SimpleNamespace(memory_md="")
    fake_session = _FakeSession()

    async def _stub_load_target(**_ignored):
        return user_target

    monkeypatch.setattr("app.services.memory.service._load_target", _stub_load_target)

    outcome = await save_memory(
        scope=MemoryScope.USER,
        target_id="00000000-0000-0000-0000-000000000000",
        content="- (2026-05-19) [fact] Legacy marker memory\n",
        session=fake_session,
    )

    assert outcome.status == "saved"
    assert user_target.memory_md == "## Memory\n- 2026-05-19: Legacy marker memory"
@pytest.mark.asyncio
async def test_save_memory_rejects_long_no_heading_payload(monkeypatch) -> None:
    """Free-text reasoning with no heading is rejected; the existing
    memory is left in place and nothing is committed."""
    user_target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
    fake_session = _FakeSession()

    async def _stub_load_target(**_ignored):
        return user_target

    monkeypatch.setattr("app.services.memory.service._load_target", _stub_load_target)

    outcome = await save_memory(
        scope=MemoryScope.USER,
        target_id="00000000-0000-0000-0000-000000000000",
        content="reasoning text before NO_UPDATE should not become saved memory",
        session=fake_session,
    )

    assert outcome.status == "error"
    assert fake_session.commit_calls == 0
    assert user_target.memory_md.startswith("## Facts")
@pytest.mark.asyncio
async def test_save_memory_no_update_sentinel_is_no_op(monkeypatch) -> None:
    """The bare ``NO_UPDATE`` sentinel returns ``no_op``, echoes the
    existing memory, and neither mutates the target nor commits."""
    stored_memory = "## Preferences\n- 2026-05-20: Existing preference\n"
    user_target = SimpleNamespace(memory_md=stored_memory)
    fake_session = _FakeSession()

    async def _stub_load_target(**_ignored):
        return user_target

    monkeypatch.setattr("app.services.memory.service._load_target", _stub_load_target)

    outcome = await save_memory(
        scope=MemoryScope.USER,
        target_id="00000000-0000-0000-0000-000000000000",
        content="NO_UPDATE",
        session=fake_session,
    )

    assert outcome.status == "no_op"
    assert outcome.memory_md == stored_memory
    assert user_target.memory_md == stored_memory
    assert fake_session.commit_calls == 0
@pytest.mark.asyncio
async def test_save_memory_no_update_sentinel_is_case_insensitive(monkeypatch) -> None:
    """A lowercase, space-separated, whitespace-padded variant of the
    sentinel (``" no update "``) is still treated as a no-op."""
    stored_memory = "## Preferences\n- 2026-05-20: Existing preference\n"
    user_target = SimpleNamespace(memory_md=stored_memory)
    fake_session = _FakeSession()

    async def _stub_load_target(**_ignored):
        return user_target

    monkeypatch.setattr("app.services.memory.service._load_target", _stub_load_target)

    outcome = await save_memory(
        scope=MemoryScope.USER,
        target_id="00000000-0000-0000-0000-000000000000",
        content=" no update ",
        session=fake_session,
    )

    assert outcome.status == "no_op"
    assert outcome.memory_md == stored_memory
    assert user_target.memory_md == stored_memory
    assert fake_session.commit_calls == 0
@pytest.mark.asyncio
async def test_save_memory_grandfathers_existing_team_personal_heading(
    monkeypatch,
) -> None:
    """Re-saving team memory that already holds a personal-style heading
    succeeds with warnings and commits once."""
    legacy_team_memory = "## Preferences\n- 2026-05-19: Existing legacy heading\n"
    team_target = SimpleNamespace(shared_memory_md=legacy_team_memory)
    fake_session = _FakeSession()

    async def _stub_load_target(**_ignored):
        return team_target

    monkeypatch.setattr("app.services.memory.service._load_target", _stub_load_target)

    outcome = await save_memory(
        scope=MemoryScope.TEAM,
        target_id=1,
        content=legacy_team_memory,
        session=fake_session,
    )

    assert outcome.status == "saved"
    # The save is allowed, but the caller is still warned about the heading.
    assert outcome.warnings
    assert fake_session.commit_calls == 1
@pytest.mark.asyncio
async def test_reset_memory_clears_memory(monkeypatch) -> None:
    """``reset_memory`` reports ``saved`` and leaves the user target with
    an empty memory document."""
    user_target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
    fake_session = _FakeSession()

    async def _stub_load_target(**_ignored):
        return user_target

    monkeypatch.setattr("app.services.memory.service._load_target", _stub_load_target)

    outcome = await reset_memory(
        scope=MemoryScope.USER,
        target_id="00000000-0000-0000-0000-000000000000",
        session=fake_session,
    )

    assert outcome.status == "saved"
    assert user_target.memory_md == ""

View file

@ -89,7 +89,6 @@ async def test_stream_output_emits_text_lifecycle_and_updates_result() -> None:
"text_end:text-1",
]
assert result.accumulated_text == "Hello world"
assert result.agent_called_update_memory is False
async def test_stream_output_passes_runtime_context_to_agent() -> None:

View file

@ -209,128 +209,6 @@ def test_stream_exception_classifies_openrouter_429_payload():
assert extra is None
@pytest.mark.asyncio
async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch):
    """``_preflight_llm`` is a best-effort probe.

    - Rate-limit shaped exceptions (provider 429) MUST propagate so the
      caller can drive the cooldown/repin branch.
    - Any other transient failure MUST be swallowed so the normal stream
      path continues without surfacing preflight noise to the user.
    """
    from types import SimpleNamespace

    from app.tasks.chat.stream_new_chat import _preflight_llm

    class _RateLimitedError(Exception):
        """Class-name carries 'RateLimit' so _is_provider_rate_limited triggers."""

    def _failing_acompletion(recorded: list[dict], make_error):
        # Build an ``acompletion`` stand-in that logs its kwargs, then raises.
        async def _fake_acompletion(**kwargs):
            recorded.append(kwargs)
            raise make_error()

        return _fake_acompletion

    rate_calls: list[dict] = []
    other_calls: list[dict] = []

    probe_llm = SimpleNamespace(
        model="openrouter/google/gemma-4-31b-it:free",
        api_key="test",
        api_base=None,
    )

    import litellm  # type: ignore[import-not-found]

    # Phase 1: a rate-limit error must escape the preflight helper.
    monkeypatch.setattr(
        litellm,
        "acompletion",
        _failing_acompletion(rate_calls, lambda: _RateLimitedError("simulated 429")),
    )
    with pytest.raises(_RateLimitedError):
        await _preflight_llm(probe_llm)
    assert len(rate_calls) == 1
    first_probe = rate_calls[0]
    assert first_probe["max_tokens"] == 1
    assert first_probe["stream"] is False

    # Phase 2: MUST NOT raise: non-rate-limit failures are swallowed.
    monkeypatch.setattr(
        litellm,
        "acompletion",
        _failing_acompletion(
            other_calls, lambda: RuntimeError("some unrelated transient failure")
        ),
    )
    await _preflight_llm(probe_llm)
    assert len(other_calls) == 1
@pytest.mark.asyncio
async def test_preflight_skipped_for_auto_router_model():
    """Router-mode ``model='auto'`` has no single deployment to ping.

    The LiteLLM router owns per-deployment rate-limit accounting, so the
    preflight helper must short-circuit rather than issue a probe.
    """
    from types import SimpleNamespace

    from app.tasks.chat.stream_new_chat import _preflight_llm

    router_llm = SimpleNamespace(model="auto", api_key="x", api_base=None)

    # Expected to return without raising or making any LiteLLM call.
    await _preflight_llm(router_llm)
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_swallows_exceptions():
    """``_settle_speculative_agent_build`` MUST always return cleanly so the
    caller can safely re-touch the request-scoped session afterwards.

    The helper guards the parallel preflight + agent-build path: when the
    speculative build is being discarded (429 or non-429 preflight failure)
    it is awaited solely to release any in-flight ``AsyncSession`` usage —
    the build's outcome is irrelevant. Any exception (including
    ``CancelledError``) leaking out would skip the caller's recovery flow
    and re-introduce the very session-concurrency hazard the helper exists
    to prevent.
    """
    import asyncio

    from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build

    async def _crashing_build() -> None:
        raise RuntimeError("speculative build crashed")

    async def _successful_build() -> str:
        return "agent"

    async def _slow_build() -> None:
        await asyncio.sleep(0.05)

    # One failing, one succeeding and one still-running build.
    pending_builds = (_crashing_build(), _successful_build(), _slow_build())
    for build in pending_builds:
        speculative = asyncio.create_task(build)
        await _settle_speculative_agent_build(speculative)
        assert speculative.done()
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_handles_already_done_task():
    """Tasks that already finished (successfully or not) are settled
    without raising."""
    import asyncio

    from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build

    async def _returns_ok() -> str:
        return "ok"

    async def _raises_value_error() -> None:
        raise ValueError("nope")

    succeeded = asyncio.create_task(_returns_ok())
    failed = asyncio.create_task(_raises_value_error())

    # Yield to the loop twice so both tasks finish before settling.
    for _ in range(2):
        await asyncio.sleep(0)

    await _settle_speculative_agent_build(succeeded)
    await _settle_speculative_agent_build(failed)

    assert succeeded.result() == "ok"
    # The settle helper observes the failed task's exception but does not
    # clear it, so ``.exception()`` must still return the original error.
    assert isinstance(failed.exception(), ValueError)
def test_stream_exception_classifies_thread_busy():
exc = BusyError(request_id="thread-123")
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(

View file

@ -7947,7 +7947,7 @@ wheels = [
[[package]]
name = "surf-new-backend"
version = "0.0.24"
version = "0.0.25"
source = { editable = "." }
dependencies = [
{ name = "alembic" },

View file

@ -1,7 +1,7 @@
{
"name": "surfsense_browser_extension",
"displayName": "Surfsense Browser Extension",
"version": "0.0.24",
"version": "0.0.25",
"description": "Extension to collect Browsing History for SurfSense.",
"author": "https://github.com/MODSetter",
"engines": {

View file

@ -1,6 +1,6 @@
{
"name": "surfsense-desktop",
"version": "0.0.24",
"version": "0.0.25",
"description": "SurfSense Desktop App",
"main": "dist/main.js",
"scripts": {

View file

@ -5,6 +5,7 @@ import { Logo } from "@/components/Logo";
import { Button } from "@/components/ui/button";
import { trackLoginAttempt } from "@/lib/posthog/events";
import { AmbientBackground } from "./AmbientBackground";
import { BACKEND_URL } from "@/lib/env-config";
function GoogleGLogo({ className }: { className?: string }) {
return (
@ -50,7 +51,7 @@ export function GoogleLoginButton() {
// cross-origin fetch requests may not be sent on subsequent redirects.
// The authorize-redirect endpoint does a server-side redirect to Google
// and sets the CSRF cookie properly for same-site context.
window.location.href = `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/auth/google/authorize-redirect`;
window.location.href = `${BACKEND_URL}/auth/google/authorize-redirect`;
};
return (
<div className="relative w-full overflow-hidden">

View file

@ -4,11 +4,9 @@ import { NextResponse } from "next/server";
import type { Context } from "@/types/zero";
import { queries } from "@/zero/queries";
import { schema } from "@/zero/schema";
import { BACKEND_URL } from "@/lib/env-config";
const backendURL =
process.env.FASTAPI_BACKEND_INTERNAL_URL ||
process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL ||
"http://localhost:8000";
const backendURL = BACKEND_URL;
async function authenticateRequest(
request: Request

View file

@ -118,7 +118,7 @@ import {
trackChatResponseReceived,
} from "@/lib/posthog/events";
import Loading from "../loading";
import { BACKEND_URL } from "@/lib/env-config";
const MobileEditorPanel = dynamic(
() =>
import("@/components/editor-panel/editor-panel").then((m) => ({
@ -777,7 +777,7 @@ export default function NewChatPage() {
if (threadId) {
const token = getBearerToken();
if (token) {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const backendUrl = BACKEND_URL;
try {
const response = await fetch(
`${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`,
@ -978,7 +978,7 @@ export default function NewChatPage() {
let streamBatcher: FrameBatchedUpdater | null = null;
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const backendUrl = BACKEND_URL;
const selection = await getAgentFilesystemSelection(searchSpaceId, {
localFilesystemEnabled,
});
@ -1520,7 +1520,7 @@ export default function NewChatPage() {
}
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const backendUrl = BACKEND_URL;
const selection = await getAgentFilesystemSelection(searchSpaceId, {
localFilesystemEnabled,
});

View file

@ -3,7 +3,6 @@
import {
BookText,
Bot,
Brain,
CircleUser,
Earth,
ImageIcon,
@ -27,7 +26,6 @@ export type SearchSpaceSettingsTab =
| "vision-models"
| "team-roles"
| "prompts"
| "team-memory"
| "public-links";
const DEFAULT_TAB: SearchSpaceSettingsTab = "general";
@ -89,11 +87,6 @@ export function SearchSpaceSettingsLayoutShell({
label: t("nav_system_instructions"),
icon: <BookText className="h-4 w-4" />,
},
{
value: "team-memory" as const,
label: "Team Memory",
icon: <Brain className="h-4 w-4" />,
},
{
value: "public-links" as const,
label: t("nav_public_links"),

View file

@ -1,6 +0,0 @@
import { TeamMemoryManager } from "@/components/settings/team-memory-manager";
/** Route page: awaits the route params and renders the team memory manager for that search space. */
export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) {
	const resolvedParams = await params;
	// The route segment arrives as a string; the manager takes a numeric id.
	const searchSpaceId = Number(resolvedParams.search_space_id);
	return <TeamMemoryManager searchSpaceId={searchSpaceId} />;
}

View file

@ -1,293 +0,0 @@
"use client";
import { useAtomValue } from "jotai";
import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pencil } from "lucide-react";
import { useCallback, useEffect, useRef, useState } from "react";
import { toast } from "sonner";
import { z } from "zod";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import { PlateEditor } from "@/components/editor/plate-editor";
import { Alert, AlertDescription } from "@/components/ui/alert";
import { Button } from "@/components/ui/button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { Spinner } from "@/components/ui/spinner";
import { baseApiService } from "@/lib/apis/base-api.service";
const MEMORY_HARD_LIMIT = 25_000;
const MemoryReadSchema = z.object({
memory_md: z.string(),
});
/**
 * Personal-memory settings panel.
 *
 * Loads the current user's memory markdown on mount and renders it read-only.
 * From there the user can:
 *  - send a natural-language instruction that edits the memory,
 *  - reset (clear) the memory,
 *  - export the raw markdown via clipboard or file download.
 *
 * While loading a spinner is shown; when the memory is empty an empty-state
 * message is rendered instead of the editor and its controls.
 */
export function MemoryContent() {
	const activeSearchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
	const [memory, setMemory] = useState("");
	const [loading, setLoading] = useState(true);
	const [saving, setSaving] = useState(false);
	const [editQuery, setEditQuery] = useState("");
	const [editing, setEditing] = useState(false);
	const [showInput, setShowInput] = useState(false);
	// Points at the single-line <input> used for edit instructions.
	const inputRef = useRef<HTMLInputElement>(null);
	const inputContainerRef = useRef<HTMLDivElement>(null);

	// Fetch the stored memory and mirror it into local state.
	const fetchMemory = useCallback(async () => {
		try {
			setLoading(true);
			const data = await baseApiService.get("/api/v1/users/me/memory", MemoryReadSchema);
			setMemory(data.memory_md);
		} catch {
			toast.error("Failed to load memory");
		} finally {
			setLoading(false);
		}
	}, []);

	useEffect(() => {
		fetchMemory();
	}, [fetchMemory]);

	// While the edit input is open, collapse it again on any press outside of it.
	useEffect(() => {
		if (!showInput) return;
		const handlePointerDownOutside = (event: MouseEvent | TouchEvent) => {
			const target = event.target;
			if (!(target instanceof Node)) return;
			if (inputContainerRef.current?.contains(target)) return;
			setShowInput(false);
		};
		document.addEventListener("mousedown", handlePointerDownOutside);
		document.addEventListener("touchstart", handlePointerDownOutside, { passive: true });
		return () => {
			document.removeEventListener("mousedown", handlePointerDownOutside);
			document.removeEventListener("touchstart", handlePointerDownOutside);
		};
	}, [showInput]);

	// Overwrite the stored memory with an empty document.
	const handleClear = async () => {
		try {
			setSaving(true);
			const data = await baseApiService.put("/api/v1/users/me/memory", MemoryReadSchema, {
				body: { memory_md: "" },
			});
			setMemory(data.memory_md);
			toast.success("Memory cleared");
		} catch {
			toast.error("Failed to clear memory");
		} finally {
			setSaving(false);
		}
	};

	// Send the typed instruction to the edit endpoint and show the returned memory.
	const handleEdit = async () => {
		const query = editQuery.trim();
		// Ignore empty input and re-entry while a previous edit is still in flight.
		if (!query || editing) return;
		try {
			setEditing(true);
			// NOTE(review): Number(null) is 0 — confirm activeSearchSpaceId is always set here.
			const data = await baseApiService.post("/api/v1/users/me/memory/edit", MemoryReadSchema, {
				body: { query, search_space_id: Number(activeSearchSpaceId) },
			});
			setMemory(data.memory_md);
			setEditQuery("");
			setShowInput(false);
			toast.success("Memory updated");
		} catch {
			toast.error("Failed to edit memory");
		} finally {
			setEditing(false);
		}
	};

	const openInput = () => {
		setShowInput(true);
		// The input only mounts after the state change, so focus on the next frame.
		requestAnimationFrame(() => inputRef.current?.focus());
	};

	// Save the raw memory markdown as a file through a temporary object URL.
	const handleDownload = () => {
		if (!memory) return;
		try {
			const blob = new Blob([memory], { type: "text/markdown;charset=utf-8" });
			const url = URL.createObjectURL(blob);
			const a = document.createElement("a");
			a.href = url;
			a.download = "personal-memory.md";
			document.body.appendChild(a);
			a.click();
			document.body.removeChild(a);
			URL.revokeObjectURL(url);
		} catch {
			toast.error("Failed to download memory");
		}
	};

	const handleCopyMarkdown = async () => {
		if (!memory) return;
		try {
			await navigator.clipboard.writeText(memory);
			toast.success("Copied to clipboard");
		} catch {
			toast.error("Failed to copy memory");
		}
	};

	const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
		// Enter pressed to confirm an IME candidate must not submit the query.
		if (e.nativeEvent.isComposing) return;
		if (e.key === "Enter" && !e.shiftKey) {
			e.preventDefault();
			handleEdit();
		}
	};

	// Hide the "(YYYY-MM-DD) [fact|pref|instr] " markers in the rendered view only;
	// the counter and the export actions keep working on the raw markdown.
	const displayMemory = memory.replace(/\(\d{4}-\d{2}-\d{2}\)\s*\[(fact|pref|instr)\]\s*/g, "");
	const charCount = memory.length;

	// Counter color escalates as the raw length approaches and passes the limit.
	const getCounterColor = () => {
		if (charCount > MEMORY_HARD_LIMIT) return "text-red-500";
		if (charCount > 15_000) return "text-orange-500";
		if (charCount > 10_000) return "text-yellow-500";
		return "text-muted-foreground";
	};

	if (loading) {
		return (
			<div className="flex items-center justify-center py-12">
				<Spinner size="md" className="text-muted-foreground" />
			</div>
		);
	}

	if (!memory) {
		return (
			<div className="flex flex-col items-center justify-center py-16 text-center">
				<h3 className="text-base font-medium text-foreground">What does SurfSense remember?</h3>
				<p className="mt-2 max-w-sm text-sm text-muted-foreground">
					Nothing yet. SurfSense picks up on your preferences and context as you chat.
				</p>
			</div>
		);
	}

	return (
		<div className="space-y-4">
			<Alert>
				<Info />
				<AlertDescription>
					<p>
						SurfSense uses this personal memory to personalize your responses across all
						conversations.
					</p>
				</AlertDescription>
			</Alert>
			<div className="relative h-[380px] rounded-lg border bg-background">
				<div className="h-full overflow-y-auto scrollbar-thin">
					<PlateEditor
						markdown={displayMemory}
						readOnly
						preset="readonly"
						variant="default"
						editorVariant="none"
						className="px-5 py-4 text-sm min-h-full"
					/>
				</div>
				{showInput ? (
					<div className="absolute bottom-3 inset-x-3 z-10">
						<div
							ref={inputContainerRef}
							className="relative flex h-[54px] items-center gap-2 rounded-[9999px] border bg-muted/60 backdrop-blur-sm pl-4 pr-1 shadow-sm"
						>
							<input
								ref={inputRef}
								type="text"
								value={editQuery}
								onChange={(e) => setEditQuery(e.target.value)}
								onKeyDown={handleKeyDown}
								placeholder="Tell SurfSense what to remember or forget"
								disabled={editing}
								className="flex-1 bg-transparent text-sm outline-none placeholder:text-muted-foreground/70"
							/>
							<Button
								type="button"
								size="icon"
								variant="ghost"
								aria-label="Submit memory edit"
								onClick={handleEdit}
								disabled={editing || !editQuery.trim()}
								className={`h-11 w-11 shrink-0 rounded-full ${
									editing
										? ""
										: "bg-muted-foreground/15 hover:bg-accent hover:text-accent-foreground"
								}`}
							>
								{editing ? (
									<Spinner size="sm" />
								) : (
									<ArrowUp className="!h-5 !w-5 text-foreground" strokeWidth={2.25} />
								)}
							</Button>
						</div>
					</div>
				) : (
					<Button
						type="button"
						size="icon"
						variant="secondary"
						aria-label="Edit memory"
						onClick={openInput}
						className="absolute bottom-3 right-3 z-10 h-[54px] w-[54px] rounded-full border bg-muted/60 backdrop-blur-sm shadow-sm"
					>
						<Pencil className="!h-5 !w-5" />
					</Button>
				)}
			</div>
			<div className="flex items-center justify-between gap-2">
				<span className={`text-xs shrink-0 ${getCounterColor()}`}>
					{charCount.toLocaleString()} / {MEMORY_HARD_LIMIT.toLocaleString()}
					<span className="hidden sm:inline"> characters</span>
					<span className="sm:hidden"> chars</span>
					{charCount > 15_000 && charCount <= MEMORY_HARD_LIMIT && " - Approaching limit"}
					{charCount > MEMORY_HARD_LIMIT && " - Exceeds limit"}
				</span>
				<div className="flex items-center gap-1.5 sm:gap-2">
					<Button
						type="button"
						variant="destructive"
						size="sm"
						className="text-xs sm:text-sm"
						onClick={handleClear}
						disabled={saving || editing || !memory}
					>
						<span className="hidden sm:inline">Reset Memory</span>
						<span className="sm:hidden">Reset</span>
					</Button>
					<DropdownMenu>
						<DropdownMenuTrigger asChild>
							<Button type="button" variant="secondary" size="sm" disabled={!memory}>
								Export
								<ChevronDown className="h-3 w-3 opacity-60" />
							</Button>
						</DropdownMenuTrigger>
						<DropdownMenuContent align="end">
							<DropdownMenuItem onClick={handleCopyMarkdown}>
								<ClipboardCopy className="h-4 w-4 mr-2" />
								Copy as Markdown
							</DropdownMenuItem>
							<DropdownMenuItem onClick={handleDownload}>
								<Download className="h-4 w-4 mr-2" />
								Download as Markdown
							</DropdownMenuItem>
						</DropdownMenuContent>
					</DropdownMenu>
				</div>
			</div>
		</div>
	);
}

View file

@ -1,7 +1,6 @@
"use client";
import {
Brain,
CircleUser,
Keyboard,
KeyRound,
@ -26,7 +25,6 @@ export type UserSettingsTab =
| "api-key"
| "prompts"
| "community-prompts"
| "memory"
| "agent-permissions"
| "agent-status"
| "purchases"
@ -75,11 +73,6 @@ export function UserSettingsLayoutShell({ searchSpaceId, children }: UserSetting
label: "Community Prompts",
icon: <Library className="h-4 w-4" />,
},
{
value: "memory" as const,
label: "Memory",
icon: <Brain className="h-4 w-4" />,
},
{
value: "agent-permissions" as const,
label: "Agent Permissions",

Some files were not shown because too many files have changed in this diff Show more