mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
Merge pull request #1423 from MODSetter/dev
feat: improved agent speed and fixed it citations
This commit is contained in:
commit
49dd8409d8
133 changed files with 3249 additions and 2971 deletions
39
.github/workflows/backend-tests.yml
vendored
39
.github/workflows/backend-tests.yml
vendored
|
|
@ -4,6 +4,9 @@ on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [main, dev]
|
branches: [main, dev]
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
|
paths:
|
||||||
|
- 'surfsense_backend/**'
|
||||||
|
- '.github/workflows/backend-tests.yml'
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
|
@ -21,26 +24,15 @@ jobs:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Check if backend files changed
|
|
||||||
id: backend-changes
|
|
||||||
uses: dorny/paths-filter@v3
|
|
||||||
with:
|
|
||||||
filters: |
|
|
||||||
backend:
|
|
||||||
- 'surfsense_backend/**'
|
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
uses: actions/setup-python@v6
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
with:
|
||||||
python-version: '3.12'
|
python-version: '3.12'
|
||||||
|
|
||||||
- name: Install UV
|
- name: Install UV
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
uses: astral-sh/setup-uv@v8.1.0
|
||||||
uses: astral-sh/setup-uv@v7
|
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
|
|
@ -51,19 +43,16 @@ jobs:
|
||||||
python-deps-
|
python-deps-
|
||||||
|
|
||||||
- name: Cache HuggingFace models
|
- name: Cache HuggingFace models
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/huggingface
|
path: ~/.cache/huggingface
|
||||||
key: hf-models-${{ env.EMBEDDING_MODEL }}
|
key: hf-models-${{ env.EMBEDDING_MODEL }}
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
|
||||||
working-directory: surfsense_backend
|
working-directory: surfsense_backend
|
||||||
run: uv sync
|
run: uv sync
|
||||||
|
|
||||||
- name: Run unit tests
|
- name: Run unit tests
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
|
||||||
working-directory: surfsense_backend
|
working-directory: surfsense_backend
|
||||||
run: uv run pytest -m unit
|
run: uv run pytest -m unit
|
||||||
|
|
||||||
|
|
@ -93,26 +82,15 @@ jobs:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Check if backend files changed
|
|
||||||
id: backend-changes
|
|
||||||
uses: dorny/paths-filter@v3
|
|
||||||
with:
|
|
||||||
filters: |
|
|
||||||
backend:
|
|
||||||
- 'surfsense_backend/**'
|
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
uses: actions/setup-python@v6
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
with:
|
||||||
python-version: '3.12'
|
python-version: '3.12'
|
||||||
|
|
||||||
- name: Install UV
|
- name: Install UV
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
uses: astral-sh/setup-uv@v8.1.0
|
||||||
uses: astral-sh/setup-uv@v7
|
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
|
|
@ -123,19 +101,16 @@ jobs:
|
||||||
python-deps-
|
python-deps-
|
||||||
|
|
||||||
- name: Cache HuggingFace models
|
- name: Cache HuggingFace models
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/huggingface
|
path: ~/.cache/huggingface
|
||||||
key: hf-models-${{ env.EMBEDDING_MODEL }}
|
key: hf-models-${{ env.EMBEDDING_MODEL }}
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
|
||||||
working-directory: surfsense_backend
|
working-directory: surfsense_backend
|
||||||
run: uv sync
|
run: uv sync
|
||||||
|
|
||||||
- name: Run integration tests
|
- name: Run integration tests
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
|
||||||
working-directory: surfsense_backend
|
working-directory: surfsense_backend
|
||||||
env:
|
env:
|
||||||
TEST_DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test
|
TEST_DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test
|
||||||
|
|
|
||||||
47
.github/workflows/code-quality.yml
vendored
47
.github/workflows/code-quality.yml
vendored
|
|
@ -11,13 +11,13 @@ concurrency:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
file-quality:
|
file-quality:
|
||||||
name: File Quality Checks
|
name: File Quality
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: github.event.pull_request.draft == false
|
if: github.event.pull_request.draft == false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
|
|
@ -27,7 +27,7 @@ jobs:
|
||||||
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
|
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.12'
|
python-version: '3.12'
|
||||||
|
|
||||||
|
|
@ -35,7 +35,7 @@ jobs:
|
||||||
run: pip install pre-commit
|
run: pip install pre-commit
|
||||||
|
|
||||||
- name: Cache pre-commit hooks
|
- name: Cache pre-commit hooks
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pre-commit
|
path: ~/.cache/pre-commit
|
||||||
key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
|
key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||||
|
|
@ -74,7 +74,7 @@ jobs:
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
|
|
@ -83,7 +83,7 @@ jobs:
|
||||||
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
|
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.12'
|
python-version: '3.12'
|
||||||
|
|
||||||
|
|
@ -91,7 +91,7 @@ jobs:
|
||||||
run: pip install pre-commit
|
run: pip install pre-commit
|
||||||
|
|
||||||
- name: Cache pre-commit hooks
|
- name: Cache pre-commit hooks
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pre-commit
|
path: ~/.cache/pre-commit
|
||||||
key: pre-commit-security-${{ hashFiles('.pre-commit-config.yaml') }}
|
key: pre-commit-security-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||||
|
|
@ -125,35 +125,36 @@ jobs:
|
||||||
exit ${exit_code:-0}
|
exit ${exit_code:-0}
|
||||||
|
|
||||||
python-backend:
|
python-backend:
|
||||||
name: Python Backend Quality
|
name: Backend Quality
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: github.event.pull_request.draft == false
|
if: github.event.pull_request.draft == false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.12'
|
python-version: '3.12'
|
||||||
|
|
||||||
- name: Install UV
|
- name: Install UV
|
||||||
uses: astral-sh/setup-uv@v3
|
uses: astral-sh/setup-uv@v8.1.0
|
||||||
|
|
||||||
- name: Check if backend files changed
|
- name: Check if backend files changed
|
||||||
id: backend-changes
|
id: backend-changes
|
||||||
uses: dorny/paths-filter@v3
|
uses: dorny/paths-filter@v4
|
||||||
with:
|
with:
|
||||||
filters: |
|
filters: |
|
||||||
backend:
|
backend:
|
||||||
- 'surfsense_backend/**'
|
- 'surfsense_backend/**'
|
||||||
|
- '.github/workflows/code-quality.yml'
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
if: steps.backend-changes.outputs.backend == 'true'
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/.cache/uv
|
~/.cache/uv
|
||||||
|
|
@ -171,7 +172,7 @@ jobs:
|
||||||
|
|
||||||
- name: Cache pre-commit hooks
|
- name: Cache pre-commit hooks
|
||||||
if: steps.backend-changes.outputs.backend == 'true'
|
if: steps.backend-changes.outputs.backend == 'true'
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pre-commit
|
path: ~/.cache/pre-commit
|
||||||
key: pre-commit-backend-${{ hashFiles('.pre-commit-config.yaml') }}
|
key: pre-commit-backend-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||||
|
|
@ -206,13 +207,13 @@ jobs:
|
||||||
exit ${exit_code:-0}
|
exit ${exit_code:-0}
|
||||||
|
|
||||||
typescript-frontend:
|
typescript-frontend:
|
||||||
name: TypeScript/JavaScript Quality
|
name: Frontend Quality
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: github.event.pull_request.draft == false
|
if: github.event.pull_request.draft == false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
|
|
@ -221,24 +222,24 @@ jobs:
|
||||||
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
|
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
|
||||||
|
|
||||||
- name: Setup Node.js
|
- name: Setup Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: '18'
|
node-version: '20'
|
||||||
|
|
||||||
- name: Install pnpm
|
- name: Install pnpm
|
||||||
uses: pnpm/action-setup@v4
|
uses: pnpm/action-setup@v6
|
||||||
with:
|
|
||||||
version: latest
|
|
||||||
|
|
||||||
- name: Check if frontend files changed
|
- name: Check if frontend files changed
|
||||||
id: frontend-changes
|
id: frontend-changes
|
||||||
uses: dorny/paths-filter@v3
|
uses: dorny/paths-filter@v4
|
||||||
with:
|
with:
|
||||||
filters: |
|
filters: |
|
||||||
web:
|
web:
|
||||||
- 'surfsense_web/**'
|
- 'surfsense_web/**'
|
||||||
|
- '.github/workflows/code-quality.yml'
|
||||||
extension:
|
extension:
|
||||||
- 'surfsense_browser_extension/**'
|
- 'surfsense_browser_extension/**'
|
||||||
|
- '.github/workflows/code-quality.yml'
|
||||||
|
|
||||||
- name: Install dependencies for web
|
- name: Install dependencies for web
|
||||||
if: steps.frontend-changes.outputs.web == 'true'
|
if: steps.frontend-changes.outputs.web == 'true'
|
||||||
|
|
@ -254,7 +255,7 @@ jobs:
|
||||||
run: pip install pre-commit
|
run: pip install pre-commit
|
||||||
|
|
||||||
- name: Cache pre-commit hooks
|
- name: Cache pre-commit hooks
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pre-commit
|
path: ~/.cache/pre-commit
|
||||||
key: pre-commit-frontend-${{ hashFiles('.pre-commit-config.yaml') }}
|
key: pre-commit-frontend-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,7 @@ repos:
|
||||||
# Biome check for surfsense_web
|
# Biome check for surfsense_web
|
||||||
- id: biome-check-web
|
- id: biome-check-web
|
||||||
name: biome-check-web
|
name: biome-check-web
|
||||||
entry: bash -c 'cd surfsense_web && npx @biomejs/biome check --diagnostic-level=error .'
|
entry: bash -c 'cd surfsense_web && npx @biomejs/biome@2.4.6 check --diagnostic-level=error .'
|
||||||
language: system
|
language: system
|
||||||
files: ^surfsense_web/
|
files: ^surfsense_web/
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
|
|
||||||
2
VERSION
2
VERSION
|
|
@ -1 +1 @@
|
||||||
0.0.24
|
0.0.25
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,42 @@
|
||||||
<citations>
|
<citations>
|
||||||
Apply chunk citations only when the runtime injects `<document>` /
|
Citations reach the answer through two channels. Use whichever applies — and
|
||||||
`<chunk id='…'>` blocks.
|
never invent ids you didn't see. Citation ids are resolved by exact-match
|
||||||
|
lookup; a wrong id silently breaks the link, so when in doubt, omit.
|
||||||
|
|
||||||
|
### Channel A — chunk blocks injected this turn
|
||||||
|
When `search_surfsense_docs` or `web_search` returns `<document>` /
|
||||||
|
`<chunk id='…'>` blocks in this turn:
|
||||||
|
|
||||||
1. For each factual statement taken from those chunks, add
|
1. For each factual statement taken from those chunks, add
|
||||||
`[citation:chunk_id]` using the exact id from `<chunk id='…'>`.
|
`[citation:chunk_id]` using the **exact** id from a visible
|
||||||
2. Multiple chunks → `[citation:id1], [citation:id2]` (comma-separated).
|
`<chunk id='…'>` tag. Copy digit-for-digit (or the URL verbatim);
|
||||||
3. Never invent or normalise ids; if unsure, omit.
|
do not retype from memory.
|
||||||
4. Plain brackets only — no markdown links, no footnote numbering.
|
2. `<document_id>` is the parent doc id, **not** a citation source —
|
||||||
5. If no chunk-tagged documents appear this turn, do not fabricate citations.
|
only ids inside `<chunk id='…'>` count.
|
||||||
|
3. Multiple chunks → `[citation:id1], [citation:id2]` (comma-separated,
|
||||||
|
each id copied individually).
|
||||||
|
4. Never invent, normalise, or guess at adjacent ids; if unsure, omit.
|
||||||
|
5. Plain brackets only — no markdown links, no footnote numbering.
|
||||||
|
|
||||||
|
### Channel B — citations relayed by a `task` specialist
|
||||||
|
A `task(...)` tool message may contain `[citation:<chunk_id>]` markers
|
||||||
|
the specialist already attached to its prose. The specialist saw the
|
||||||
|
underlying `<chunk id='…'>` blocks; you didn't. So:
|
||||||
|
|
||||||
|
1. **Preserve those markers verbatim** in your final answer — do not
|
||||||
|
reformat, renumber, drop, or wrap them in markdown links. When you
|
||||||
|
paraphrase a specialist sentence, copy the marker character-for-
|
||||||
|
character; do not regenerate the id from memory (LLMs reliably
|
||||||
|
corrupt nearby digits).
|
||||||
|
2. Keep each marker attached to the sentence the specialist attached
|
||||||
|
it to.
|
||||||
|
3. Do **not** add new `[citation:…]` markers of your own to a
|
||||||
|
specialist's prose; if a fact has no marker, the specialist
|
||||||
|
couldn't tie it to a chunk and neither can you.
|
||||||
|
4. When a specialist returns JSON, the citation markers live inside
|
||||||
|
the prose-bearing fields (e.g. a summary or excerpt). Pull them
|
||||||
|
along with the surrounding sentence when you quote.
|
||||||
|
|
||||||
|
If neither channel surfaces citation markers this turn, do not fabricate
|
||||||
|
them.
|
||||||
</citations>
|
</citations>
|
||||||
|
|
|
||||||
|
|
@ -6,4 +6,10 @@ standing instructions?
|
||||||
If yes, call `update_memory` **alongside** your normal response — don't
|
If yes, call `update_memory` **alongside** your normal response — don't
|
||||||
defer it to a later turn. Skip ephemeral chat noise (one-off Q/A, greetings,
|
defer it to a later turn. Skip ephemeral chat noise (one-off Q/A, greetings,
|
||||||
session logistics). Stay within the budget shown in `<user_memory>`.
|
session logistics). Stay within the budget shown in `<user_memory>`.
|
||||||
|
|
||||||
|
Memory is heading-based markdown. New entries should be under `##` headings
|
||||||
|
such as `## Facts`, `## Preferences`, or `## Instructions`, with bullets like
|
||||||
|
`- YYYY-MM-DD: text`. If existing memory contains legacy
|
||||||
|
`(YYYY-MM-DD) [fact|pref|instr]` markers, preserve the information but write
|
||||||
|
new saves in the heading-based format.
|
||||||
</memory_protocol>
|
</memory_protocol>
|
||||||
|
|
|
||||||
|
|
@ -6,4 +6,12 @@ key facts?
|
||||||
If yes, call `update_memory` **alongside** your normal response — don't
|
If yes, call `update_memory` **alongside** your normal response — don't
|
||||||
defer it to a later turn. Skip ephemeral chat noise (one-off Q/A, greetings,
|
defer it to a later turn. Skip ephemeral chat noise (one-off Q/A, greetings,
|
||||||
session logistics). Stay within the budget shown in `<team_memory>`.
|
session logistics). Stay within the budget shown in `<team_memory>`.
|
||||||
|
|
||||||
|
Team memory is heading-based markdown. New entries should be under `##`
|
||||||
|
headings such as `## Product Decisions`, `## Engineering Conventions`,
|
||||||
|
`## Project Facts`, or `## Open Questions`, with bullets like
|
||||||
|
`- YYYY-MM-DD: text`. If existing memory contains legacy `(YYYY-MM-DD) [fact]`
|
||||||
|
markers, preserve the information but write new saves in the heading-based
|
||||||
|
format. Do not create personal headings such as `## Preferences` or
|
||||||
|
`## Instructions`.
|
||||||
</memory_protocol>
|
</memory_protocol>
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,9 @@
|
||||||
- Skip ephemeral chat noise (one-off Q/A, greetings, session logistics).
|
- Skip ephemeral chat noise (one-off Q/A, greetings, session logistics).
|
||||||
- Args: `updated_memory` — FULL replacement markdown (merge and curate,
|
- Args: `updated_memory` — FULL replacement markdown (merge and curate,
|
||||||
don't only append).
|
don't only append).
|
||||||
- Formatting: bullets `- (YYYY-MM-DD) [marker] text` with markers `[fact]`,
|
- Formatting: heading-based markdown with entries under `##` headings.
|
||||||
`[pref]`, `[instr]` (priority when trimming: `instr > pref > fact`).
|
Recommended headings are `## Facts`, `## Preferences`, `## Instructions`,
|
||||||
Group bullets under short `##` headings; stay under the limit shown in
|
though clearer natural headings are allowed. New bullets should look like
|
||||||
`<user_memory>`.
|
`- YYYY-MM-DD: text`; stay under the limit shown in `<user_memory>`.
|
||||||
|
- If existing memory uses legacy `(YYYY-MM-DD) [fact|pref|instr]` markers,
|
||||||
|
preserve the information but write the updated document in the new format.
|
||||||
|
|
|
||||||
|
|
@ -1,28 +1,28 @@
|
||||||
<example>
|
<example>
|
||||||
<user_name>Alex</user_name>, <user_memory> is empty.
|
<user_name>Alex</user_name>, <user_memory> is empty.
|
||||||
user: "I'm a space enthusiast, explain astrophage to me"
|
user: "I'm a space enthusiast, explain astrophage to me"
|
||||||
→ update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n")
|
→ update_memory(updated_memory="## Facts\n- 2025-03-15: Alex is a space enthusiast\n")
|
||||||
(Casual durable fact; use first name, neutral heading.)
|
(Casual durable fact; use first name, neutral heading.)
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
user: "Remember that I prefer concise answers over detailed explanations"
|
user: "Remember that I prefer concise answers over detailed explanations"
|
||||||
→ update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n\n## Response style\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\n")
|
→ update_memory(updated_memory="## Facts\n- 2025-03-15: Alex is a space enthusiast\n\n## Preferences\n- 2025-03-15: Alex prefers concise answers over detailed explanations\n")
|
||||||
(Durable preference; merge with existing memory.)
|
(Durable preference; merge with existing memory.)
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
user: "I actually moved to Tokyo last month"
|
user: "I actually moved to Tokyo last month"
|
||||||
→ update_memory(updated_memory="...\n\n## Personal context\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\n...")
|
→ update_memory(updated_memory="...\n\n## Facts\n- 2025-03-15: Alex lives in Tokyo (previously London)\n...")
|
||||||
(Updated fact; date reflects when recorded.)
|
(Updated fact; date reflects when recorded.)
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
user: "I'm a freelance photographer working on a nature documentary"
|
user: "I'm a freelance photographer working on a nature documentary"
|
||||||
→ update_memory(updated_memory="...\n\n## Current focus\n- (2025-03-15) [fact] Alex is a freelance photographer\n- (2025-03-15) [fact] Alex is working on a nature documentary\n")
|
→ update_memory(updated_memory="...\n\n## Current Focus\n- 2025-03-15: Alex is a freelance photographer\n- 2025-03-15: Alex is working on a nature documentary\n")
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
user: "Always respond in bullet points"
|
user: "Always respond in bullet points"
|
||||||
→ update_memory(updated_memory="...\n\n## Response style\n- (2025-03-15) [instr] Always respond to Alex in bullet points\n")
|
→ update_memory(updated_memory="...\n\n## Instructions\n- 2025-03-15: Always respond to Alex in bullet points\n")
|
||||||
</example>
|
</example>
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,14 @@
|
||||||
- Skip ephemeral chat noise (one-off Q/A, greetings, session logistics).
|
- Skip ephemeral chat noise (one-off Q/A, greetings, session logistics).
|
||||||
- Args: `updated_memory` — FULL replacement markdown (merge and curate,
|
- Args: `updated_memory` — FULL replacement markdown (merge and curate,
|
||||||
don't only append).
|
don't only append).
|
||||||
- Formatting: bullets `- (YYYY-MM-DD) [fact] text`. Team memory uses ONLY
|
- Formatting: heading-based markdown with entries under `##` headings.
|
||||||
the `[fact]` marker (never `[pref]` or `[instr]`). Group bullets under
|
Recommended headings are `## Product Decisions`,
|
||||||
short `##` headings (2-3 words each); stay under the limit shown in
|
`## Engineering Conventions`, `## Project Facts`, and `## Open Questions`.
|
||||||
`<team_memory>`. When trimming, prioritise: decisions/conventions > key
|
New bullets should look like `- YYYY-MM-DD: text`; stay under the limit
|
||||||
facts > current priorities.
|
shown in `<team_memory>`.
|
||||||
|
- If existing memory uses legacy `(YYYY-MM-DD) [fact]` markers, preserve the
|
||||||
|
information but write the updated document in the new format.
|
||||||
|
- Do not create personal headings such as `## Preferences`,
|
||||||
|
`## Instructions`, `## Personal Notes`, or `## Personal Instructions`.
|
||||||
|
When trimming, prioritise: decisions/conventions > key facts > current
|
||||||
|
priorities.
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
<example>
|
<example>
|
||||||
user: "Let's remember that we decided to do weekly standup meetings on Mondays"
|
user: "Let's remember that we decided to do weekly standup meetings on Mondays"
|
||||||
→ update_memory(updated_memory="...\n\n## Team rituals\n- (2025-03-15) [fact] Weekly standup meetings on Mondays\n...")
|
→ update_memory(updated_memory="...\n\n## Product Decisions\n- 2025-03-15: Weekly standup meetings happen on Mondays\n...")
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
user: "Our office is in downtown Seattle, 5th floor"
|
user: "Our office is in downtown Seattle, 5th floor"
|
||||||
→ update_memory(updated_memory="...\n\n## Workspace\n- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\n...")
|
→ update_memory(updated_memory="...\n\n## Project Facts\n- 2025-03-15: Office location is downtown Seattle, 5th floor\n...")
|
||||||
</example>
|
</example>
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from deepagents.backends.protocol import BackendFactory, BackendProtocol
|
from deepagents.backends.protocol import BackendFactory, BackendProtocol
|
||||||
|
|
@ -15,8 +16,12 @@ from langchain.agents import create_agent
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
from .task_tool import build_task_tool_with_parent_config
|
from .task_tool import build_task_tool_with_parent_config
|
||||||
|
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
||||||
"""``SubAgentMiddleware`` variant that compiles each subagent against the parent checkpointer."""
|
"""``SubAgentMiddleware`` variant that compiles each subagent against the parent checkpointer."""
|
||||||
|
|
@ -54,8 +59,11 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
||||||
def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
|
def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
|
||||||
"""Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer."""
|
"""Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer."""
|
||||||
specs: list[dict[str, Any]] = []
|
specs: list[dict[str, Any]] = []
|
||||||
|
loop_start = time.perf_counter()
|
||||||
|
timings: list[tuple[str, float, str]] = [] # (name, elapsed, source)
|
||||||
|
|
||||||
for spec in self._subagents:
|
for spec in self._subagents:
|
||||||
|
spec_start = time.perf_counter()
|
||||||
if "runnable" in spec:
|
if "runnable" in spec:
|
||||||
compiled = cast(CompiledSubAgent, spec)
|
compiled = cast(CompiledSubAgent, spec)
|
||||||
specs.append(
|
specs.append(
|
||||||
|
|
@ -65,6 +73,9 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
||||||
"runnable": compiled["runnable"],
|
"runnable": compiled["runnable"],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
timings.append(
|
||||||
|
(compiled["name"], time.perf_counter() - spec_start, "precompiled")
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if "model" not in spec:
|
if "model" not in spec:
|
||||||
|
|
@ -79,20 +90,44 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
||||||
model = init_chat_model(model)
|
model = init_chat_model(model)
|
||||||
|
|
||||||
middleware: list[Any] = list(spec.get("middleware", []))
|
middleware: list[Any] = list(spec.get("middleware", []))
|
||||||
|
tools_count = len(spec.get("tools") or [])
|
||||||
|
mw_count = len(middleware)
|
||||||
|
|
||||||
|
compile_start = time.perf_counter()
|
||||||
|
runnable = create_agent(
|
||||||
|
model,
|
||||||
|
system_prompt=spec["system_prompt"],
|
||||||
|
tools=spec["tools"],
|
||||||
|
middleware=middleware,
|
||||||
|
name=spec["name"],
|
||||||
|
checkpointer=self._surf_checkpointer,
|
||||||
|
)
|
||||||
|
compile_elapsed = time.perf_counter() - compile_start
|
||||||
specs.append(
|
specs.append(
|
||||||
{
|
{
|
||||||
"name": spec["name"],
|
"name": spec["name"],
|
||||||
"description": spec["description"],
|
"description": spec["description"],
|
||||||
"runnable": create_agent(
|
"runnable": runnable,
|
||||||
model,
|
|
||||||
system_prompt=spec["system_prompt"],
|
|
||||||
tools=spec["tools"],
|
|
||||||
middleware=middleware,
|
|
||||||
name=spec["name"],
|
|
||||||
checkpointer=self._surf_checkpointer,
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
timings.append(
|
||||||
|
(
|
||||||
|
spec["name"],
|
||||||
|
compile_elapsed,
|
||||||
|
f"compiled tools={tools_count} mw={mw_count}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
total_elapsed = time.perf_counter() - loop_start
|
||||||
|
per_subagent = ", ".join(
|
||||||
|
f"{name}={elapsed * 1000:.0f}ms[{source}]"
|
||||||
|
for name, elapsed, source in timings
|
||||||
|
)
|
||||||
|
_perf_log.info(
|
||||||
|
"[subagent_compile] total=%.3fs count=%d details=[%s]",
|
||||||
|
total_elapsed,
|
||||||
|
len(timings),
|
||||||
|
per_subagent,
|
||||||
|
)
|
||||||
|
|
||||||
return specs
|
return specs
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ re-raises any new pending interrupt back to the parent.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Annotated, Any, NoReturn
|
from typing import Annotated, Any, NoReturn
|
||||||
|
|
||||||
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
|
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
|
||||||
|
|
@ -19,6 +20,8 @@ from langchain_core.tools import StructuredTool
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
from langgraph.types import Command, Interrupt
|
from langgraph.types import Command, Interrupt
|
||||||
|
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
consume_surfsense_resume,
|
consume_surfsense_resume,
|
||||||
drain_parent_null_resume,
|
drain_parent_null_resume,
|
||||||
|
|
@ -35,6 +38,7 @@ from .resume import (
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
def _reraise_stamped_subagent_interrupt(
|
def _reraise_stamped_subagent_interrupt(
|
||||||
|
|
@ -209,6 +213,7 @@ def build_task_tool_with_parent_config(
|
||||||
],
|
],
|
||||||
runtime: ToolRuntime,
|
runtime: ToolRuntime,
|
||||||
) -> str | Command:
|
) -> str | Command:
|
||||||
|
atask_start = time.perf_counter()
|
||||||
logger.info(
|
logger.info(
|
||||||
"[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s",
|
"[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s",
|
||||||
subagent_type,
|
subagent_type,
|
||||||
|
|
@ -230,8 +235,10 @@ def build_task_tool_with_parent_config(
|
||||||
# Resume bridge — see ``task`` above.
|
# Resume bridge — see ``task`` above.
|
||||||
pending_id: str | None = None
|
pending_id: str | None = None
|
||||||
pending_value: Any = None
|
pending_value: Any = None
|
||||||
|
aget_state_elapsed = 0.0
|
||||||
aget_state = getattr(subagent, "aget_state", None)
|
aget_state = getattr(subagent, "aget_state", None)
|
||||||
if callable(aget_state):
|
if callable(aget_state):
|
||||||
|
aget_state_start = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
snapshot = await aget_state(sub_config)
|
snapshot = await aget_state(sub_config)
|
||||||
pending_id, pending_value = get_first_pending_subagent_interrupt(
|
pending_id, pending_value = get_first_pending_subagent_interrupt(
|
||||||
|
|
@ -248,32 +255,78 @@ def build_task_tool_with_parent_config(
|
||||||
"Subagent aget_state failed; falling back to fresh ainvoke",
|
"Subagent aget_state failed; falling back to fresh ainvoke",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
aget_state_elapsed = time.perf_counter() - aget_state_start
|
||||||
|
|
||||||
if pending_value is not None:
|
invoke_path = "resume" if pending_value is not None else "fresh"
|
||||||
resume_value = consume_surfsense_resume(runtime)
|
ainvoke_start = time.perf_counter()
|
||||||
if resume_value is None:
|
ainvoke_outcome = "ok"
|
||||||
raise RuntimeError(
|
try:
|
||||||
f"Subagent {subagent_type!r} has a pending interrupt but no "
|
if pending_value is not None:
|
||||||
"surfsense_resume_value on config; resume bridge is broken."
|
resume_value = consume_surfsense_resume(runtime)
|
||||||
)
|
if resume_value is None:
|
||||||
expected = hitlrequest_action_count(pending_value)
|
raise RuntimeError(
|
||||||
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
f"Subagent {subagent_type!r} has a pending interrupt but no "
|
||||||
# Prevent the parent's resume payload from leaking into subagent
|
"surfsense_resume_value on config; resume bridge is broken."
|
||||||
# interrupts via langgraph's parent_scratchpad fallback.
|
)
|
||||||
drain_parent_null_resume(runtime)
|
expected = hitlrequest_action_count(pending_value)
|
||||||
try:
|
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
||||||
result = await subagent.ainvoke(
|
# Prevent the parent's resume payload from leaking into subagent
|
||||||
build_resume_command(resume_value, pending_id),
|
# interrupts via langgraph's parent_scratchpad fallback.
|
||||||
config=sub_config,
|
drain_parent_null_resume(runtime)
|
||||||
)
|
try:
|
||||||
except GraphInterrupt as gi:
|
result = await subagent.ainvoke(
|
||||||
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
build_resume_command(resume_value, pending_id),
|
||||||
else:
|
config=sub_config,
|
||||||
try:
|
)
|
||||||
result = await subagent.ainvoke(subagent_state, config=sub_config)
|
except GraphInterrupt as gi:
|
||||||
except GraphInterrupt as gi:
|
ainvoke_outcome = "interrupted"
|
||||||
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
_perf_log.info(
|
||||||
return _return_command_with_state_update(result, runtime.tool_call_id)
|
"[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s "
|
||||||
|
"aget_state=%.3fs ainvoke=%.3fs total=%.3fs",
|
||||||
|
subagent_type,
|
||||||
|
invoke_path,
|
||||||
|
ainvoke_outcome,
|
||||||
|
aget_state_elapsed,
|
||||||
|
time.perf_counter() - ainvoke_start,
|
||||||
|
time.perf_counter() - atask_start,
|
||||||
|
)
|
||||||
|
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
result = await subagent.ainvoke(subagent_state, config=sub_config)
|
||||||
|
except GraphInterrupt as gi:
|
||||||
|
ainvoke_outcome = "interrupted"
|
||||||
|
_perf_log.info(
|
||||||
|
"[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s "
|
||||||
|
"aget_state=%.3fs ainvoke=%.3fs total=%.3fs",
|
||||||
|
subagent_type,
|
||||||
|
invoke_path,
|
||||||
|
ainvoke_outcome,
|
||||||
|
aget_state_elapsed,
|
||||||
|
time.perf_counter() - ainvoke_start,
|
||||||
|
time.perf_counter() - atask_start,
|
||||||
|
)
|
||||||
|
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
||||||
|
ainvoke_elapsed = time.perf_counter() - ainvoke_start
|
||||||
|
except GraphInterrupt:
|
||||||
|
raise
|
||||||
|
|
||||||
|
merge_start = time.perf_counter()
|
||||||
|
cmd = _return_command_with_state_update(result, runtime.tool_call_id)
|
||||||
|
merge_elapsed = time.perf_counter() - merge_start
|
||||||
|
_perf_log.info(
|
||||||
|
"[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s "
|
||||||
|
"aget_state=%.3fs ainvoke=%.3fs merge=%.3fs total=%.3fs",
|
||||||
|
subagent_type,
|
||||||
|
invoke_path,
|
||||||
|
ainvoke_outcome,
|
||||||
|
aget_state_elapsed,
|
||||||
|
ainvoke_elapsed,
|
||||||
|
merge_elapsed,
|
||||||
|
time.perf_counter() - atask_start,
|
||||||
|
)
|
||||||
|
return cmd
|
||||||
|
|
||||||
return StructuredTool.from_function(
|
return StructuredTool.from_function(
|
||||||
name="task",
|
name="task",
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
|
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
|
||||||
|
from app.services.llm_service import get_planner_llm
|
||||||
|
|
||||||
|
|
||||||
def build_knowledge_priority_mw(
|
def build_knowledge_priority_mw(
|
||||||
|
|
@ -19,6 +20,7 @@ def build_knowledge_priority_mw(
|
||||||
) -> KnowledgePriorityMiddleware:
|
) -> KnowledgePriorityMiddleware:
|
||||||
return KnowledgePriorityMiddleware(
|
return KnowledgePriorityMiddleware(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
|
planner_llm=get_planner_llm(),
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
filesystem_mode=filesystem_mode,
|
filesystem_mode=filesystem_mode,
|
||||||
available_connectors=available_connectors,
|
available_connectors=available_connectors,
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
|
@ -10,6 +11,9 @@ from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||||
from app.agents.new_chat.middleware.knowledge_search import _render_priority_message
|
from app.agents.new_chat.middleware.knowledge_search import _render_priority_message
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
@ -30,17 +34,34 @@ class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
runtime: Runtime[Any],
|
runtime: Runtime[Any],
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
del runtime
|
del runtime
|
||||||
|
start = time.perf_counter()
|
||||||
tree_text = state.get("workspace_tree_text")
|
tree_text = state.get("workspace_tree_text")
|
||||||
priority = state.get("kb_priority")
|
priority = state.get("kb_priority")
|
||||||
if not tree_text and not priority:
|
if not tree_text and not priority:
|
||||||
|
_perf_log.info(
|
||||||
|
"[kb_context_projection] tree=0 priority=0 elapsed=%.3fs",
|
||||||
|
time.perf_counter() - start,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
messages = list(state.get("messages") or [])
|
messages = list(state.get("messages") or [])
|
||||||
insert_at = max(len(messages) - 1, 0)
|
insert_at = max(len(messages) - 1, 0)
|
||||||
|
tree_chars = 0
|
||||||
if tree_text:
|
if tree_text:
|
||||||
|
tree_chars = len(tree_text)
|
||||||
messages.insert(insert_at, SystemMessage(content=tree_text))
|
messages.insert(insert_at, SystemMessage(content=tree_text))
|
||||||
|
priority_count = 0
|
||||||
if priority:
|
if priority:
|
||||||
|
priority_count = (
|
||||||
|
len(priority) if hasattr(priority, "__len__") else 1
|
||||||
|
)
|
||||||
messages.insert(insert_at, _render_priority_message(priority))
|
messages.insert(insert_at, _render_priority_message(priority))
|
||||||
|
_perf_log.info(
|
||||||
|
"[kb_context_projection] tree_chars=%d priority_items=%d elapsed=%.3fs",
|
||||||
|
tree_chars,
|
||||||
|
priority_count,
|
||||||
|
time.perf_counter() - start,
|
||||||
|
)
|
||||||
return {"messages": messages}
|
return {"messages": messages}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,4 +2,4 @@ Read-only specialist for the user's workspace (documents and folders). Use to fi
|
||||||
|
|
||||||
Pass your full question as one string. The specialist runs in isolation: it cannot see this thread, so include any path hints, filters, or constraints it needs.
|
Pass your full question as one string. The specialist runs in isolation: it cannot see this thread, so include any path hints, filters, or constraints it needs.
|
||||||
|
|
||||||
The specialist returns plain prose with absolute paths.
|
The specialist returns plain prose with absolute paths and `[citation:<chunk_id>]` markers when claims came from KB-indexed chunks. Preserve those markers verbatim if you forward the answer.
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,43 @@ Map outcomes to your `status`:
|
||||||
|
|
||||||
You construct the structured `evidence` fields from your own knowledge of what you called and what you observed — the tools do not return them. Never report values you did not actually see.
|
You construct the structured `evidence` fields from your own knowledge of what you called and what you observed — the tools do not return them. Never report values you did not actually see.
|
||||||
|
|
||||||
|
## Chunk citations in your prose
|
||||||
|
|
||||||
|
When `read_file` returns a KB-indexed document under `/documents/`, the response includes `<chunk id='…'>` blocks. Whenever a fact in your `action_summary` or `evidence.content_excerpt` came from a specific chunk, append `[citation:<chunk_id>]` to the sentence stating that fact, using the **exact** id from the `<chunk id='…'>` tag. The caller relays these markers to the end user verbatim, and the UI resolves each id by exact match against the database, so a wrong id silently breaks the citation.
|
||||||
|
|
||||||
|
### Where chunk ids live in `read_file` output
|
||||||
|
|
||||||
|
A KB document's XML has three numeric attributes — only **one** is a citation source:
|
||||||
|
|
||||||
|
```
|
||||||
|
<document>
|
||||||
|
<document_metadata>
|
||||||
|
<document_id>42</document_id> ← NOT a citation. Parent doc id; ignore for citations.
|
||||||
|
...
|
||||||
|
</document_metadata>
|
||||||
|
<chunk_index>
|
||||||
|
<entry chunk_id="128" lines="14-22"/> ← Index hint; the same id also appears below.
|
||||||
|
<entry chunk_id="129" lines="23-30" matched="true"/>
|
||||||
|
</chunk_index>
|
||||||
|
<document_content>
|
||||||
|
<chunk id='128'><![CDATA[…]]></chunk> ← This is the citation source.
|
||||||
|
<chunk id='129'><![CDATA[…]]></chunk>
|
||||||
|
</document_content>
|
||||||
|
</document>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Rules
|
||||||
|
|
||||||
|
- Use the **exact** id from a `<chunk id='…'>` tag whose content you actually quoted or paraphrased. Copy digit-for-digit; do **not** retype from memory.
|
||||||
|
- Before emitting `[citation:N]`, confirm the literal substring `<chunk id='N'>` (or its index twin `chunk_id="N"`) appears in the tool result you are summarising this turn. If you can't see it, omit the citation.
|
||||||
|
- Never cite `<document_id>` — that's the parent doc, not a chunk.
|
||||||
|
- Never invent, normalise, shorten, or guess at adjacent ids. If unsure between two candidates, omit rather than pick.
|
||||||
|
- Prefer **fewer accurate citations** over many speculative ones.
|
||||||
|
- Multiple chunks supporting the same point → comma-separated and copied individually: `[citation:128], [citation:129]`.
|
||||||
|
- Plain square brackets only — no markdown links, no parentheses, no footnote numbers.
|
||||||
|
- Tool results without `<chunk id='…'>` (write/edit/move confirmations, `ls` / `glob` / `grep` listings, error strings) carry no chunk id and need none.
|
||||||
|
- Populate `evidence.chunk_ids` with **only** ids you actually emitted in `[citation:…]` markers — same set, same digits.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
**Example 1 — happy path write (path discovered from existing convention):**
|
**Example 1 — happy path write (path discovered from existing convention):**
|
||||||
|
|
@ -118,5 +155,6 @@ Rules:
|
||||||
- `status=success` → `next_step=null`, `missing_fields=null`.
|
- `status=success` → `next_step=null`, `missing_fields=null`.
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
|
- `evidence.content_excerpt`: max ~500 characters. Surface a short excerpt or a one-sentence summary, not the full file body. The supervisor already sees the tool's raw output.
|
||||||
|
|
||||||
Infer before you call; map every tool outcome faithfully.
|
Infer before you call; map every tool outcome faithfully.
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,10 @@ Map outcomes to your `status`:
|
||||||
|
|
||||||
You construct the structured `evidence` fields from your own knowledge of what you called and what you observed — the tools do not return them. `chunk_ids` apply only to `<priority_documents>` hits; for local-file operations leave them `null`. Never report values you did not actually see.
|
You construct the structured `evidence` fields from your own knowledge of what you called and what you observed — the tools do not return them. `chunk_ids` apply only to `<priority_documents>` hits; for local-file operations leave them `null`. Never report values you did not actually see.
|
||||||
|
|
||||||
|
## Chunk citations in your prose
|
||||||
|
|
||||||
|
In desktop mode your filesystem tools read local files only, and local-file tool results do **not** carry `<chunk id='…'>` tags. Do not emit `[citation:…]` markers in `action_summary` or `evidence.content_excerpt`, and leave `evidence.chunk_ids` `null` — the absolute path is the only reference for local-file work.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
**Example 1 — happy path write (path discovered from existing convention):**
|
**Example 1 — happy path write (path discovered from existing convention):**
|
||||||
|
|
@ -118,5 +122,6 @@ Rules:
|
||||||
- `status=success` → `next_step=null`, `missing_fields=null`.
|
- `status=success` → `next_step=null`, `missing_fields=null`.
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
|
- `evidence.content_excerpt`: max ~500 characters. Surface a short excerpt or a one-sentence summary, not the full file body. The supervisor already sees the tool's raw output.
|
||||||
|
|
||||||
Infer before you call; map every tool outcome faithfully.
|
Infer before you call; map every tool outcome faithfully.
|
||||||
|
|
|
||||||
|
|
@ -27,3 +27,42 @@ Reply in plain prose:
|
||||||
- Cite every claim with an absolute path under `/documents/`.
|
- Cite every claim with an absolute path under `/documents/`.
|
||||||
- If the workspace does not contain the requested information, say so explicitly. Do not fabricate paths or content.
|
- If the workspace does not contain the requested information, say so explicitly. Do not fabricate paths or content.
|
||||||
- If the question is genuinely ambiguous after a thorough lookup, list the candidates with their paths and stop.
|
- If the question is genuinely ambiguous after a thorough lookup, list the candidates with their paths and stop.
|
||||||
|
|
||||||
|
## Chunk citations
|
||||||
|
|
||||||
|
When the evidence for a claim came from a `read_file` response that included `<chunk id='…'>` blocks (i.e. a KB-indexed document under `/documents/`), append `[citation:<chunk_id>]` to the sentence stating that claim. The caller passes these markers through to the end user verbatim, and the UI resolves each id by exact match against the database, so a wrong id silently breaks the citation.
|
||||||
|
|
||||||
|
### Where chunk ids live in `read_file` output
|
||||||
|
|
||||||
|
A KB document's XML has three numeric attributes — only **one** is a citation source:
|
||||||
|
|
||||||
|
```
|
||||||
|
<document>
|
||||||
|
<document_metadata>
|
||||||
|
<document_id>42</document_id> ← NOT a citation. Parent doc id; ignore for citations.
|
||||||
|
...
|
||||||
|
</document_metadata>
|
||||||
|
<chunk_index>
|
||||||
|
<entry chunk_id="128" lines="14-22"/> ← Index hint; the same id also appears below.
|
||||||
|
<entry chunk_id="129" lines="23-30" matched="true"/>
|
||||||
|
</chunk_index>
|
||||||
|
<document_content>
|
||||||
|
<chunk id='128'><![CDATA[…]]></chunk> ← This is the citation source.
|
||||||
|
<chunk id='129'><![CDATA[…]]></chunk>
|
||||||
|
</document_content>
|
||||||
|
</document>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Rules
|
||||||
|
|
||||||
|
- Use the **exact** id from a `<chunk id='…'>` tag whose content you actually quoted or paraphrased. Copy digit-for-digit; do **not** retype from memory.
|
||||||
|
- Before emitting `[citation:N]`, confirm the literal substring `<chunk id='N'>` (or its index twin `chunk_id="N"`) appears in the tool result you are summarising this turn. If you can't see it, omit the citation.
|
||||||
|
- Never cite `<document_id>` — that's the parent doc, not a chunk.
|
||||||
|
- Never invent, normalise, shorten, or guess at adjacent ids. If unsure between two candidates, omit rather than pick.
|
||||||
|
- Prefer **fewer accurate citations** over many speculative ones. One correct `[citation:128]` is more useful than a string of wrong ids.
|
||||||
|
- Multiple chunks supporting the same point → comma-separated and copied individually: `[citation:128], [citation:129]`.
|
||||||
|
- Plain square brackets only — no markdown links, no parentheses, no footnote numbers.
|
||||||
|
- If a claim came from a tool result that did **not** carry a chunk id (`ls`, `glob`, `grep` listings, error strings, or files without `<chunk id='…'>`), skip the citation.
|
||||||
|
- The absolute path under `/documents/` is always required; chunk citations are additive, they do not replace the path reference.
|
||||||
|
|
||||||
|
Example: `The Q2 roadmap lists three milestones (/documents/planning/q2-roadmap.md) [citation:128], [citation:129].`
|
||||||
|
|
|
||||||
|
|
@ -28,3 +28,7 @@ Reply in plain prose:
|
||||||
- Cite every claim with an absolute path.
|
- Cite every claim with an absolute path.
|
||||||
- If the workspace does not contain the requested information, say so explicitly. Do not fabricate paths or content.
|
- If the workspace does not contain the requested information, say so explicitly. Do not fabricate paths or content.
|
||||||
- If the question is genuinely ambiguous after a thorough lookup, list the candidates with their paths and stop.
|
- If the question is genuinely ambiguous after a thorough lookup, list the candidates with their paths and stop.
|
||||||
|
|
||||||
|
## Chunk citations
|
||||||
|
|
||||||
|
In desktop mode your filesystem tools read local files only, and local-file `read_file` responses do **not** carry `<chunk id='…'>` tags. Cite each claim with the absolute local path; do not emit `[citation:…]` markers — your caller has nothing to resolve them against.
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,10 @@ Persist durable preferences/facts/instructions with `update_memory` while avoidi
|
||||||
- Do not store transient chatter.
|
- Do not store transient chatter.
|
||||||
- Do not store secrets unless explicitly instructed.
|
- Do not store secrets unless explicitly instructed.
|
||||||
- If memory intent is unclear, return `status=blocked` with the missing intent signal.
|
- If memory intent is unclear, return `status=blocked` with the missing intent signal.
|
||||||
|
- Persisted memory is heading-based markdown. New saved bullets should look like
|
||||||
|
`- YYYY-MM-DD: text` under `##` headings. If existing memory has legacy
|
||||||
|
`(YYYY-MM-DD) [fact|pref|instr]` markers, preserve the information but write
|
||||||
|
the updated document in the heading-based format.
|
||||||
</tool_policy>
|
</tool_policy>
|
||||||
|
|
||||||
<out_of_scope>
|
<out_of_scope>
|
||||||
|
|
@ -53,4 +57,7 @@ Rules:
|
||||||
- `status=success` -> `next_step=null`, `missing_fields=null`.
|
- `status=success` -> `next_step=null`, `missing_fields=null`.
|
||||||
- `status=partial|blocked|error` -> `next_step` must be non-null.
|
- `status=partial|blocked|error` -> `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
|
||||||
|
- `evidence.memory_category` is a semantic classification for supervisor logs
|
||||||
|
only. It is not the persisted storage format and must not force inline
|
||||||
|
`[fact|preference|instruction]` markers into saved memory.
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
|
||||||
|
|
@ -1,280 +1,23 @@
|
||||||
"""Overwrite one markdown memory document per user or team, with size and shrink guards."""
|
"""Memory update tools backed by the canonical memory service."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
from typing import Any
|
||||||
from typing import Any, Literal
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import SearchSpace, User
|
from app.services.memory import (
|
||||||
|
MEMORY_HARD_LIMIT,
|
||||||
|
MEMORY_SOFT_LIMIT,
|
||||||
|
MemoryScope,
|
||||||
|
save_memory,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MEMORY_SOFT_LIMIT = 18_000
|
|
||||||
MEMORY_HARD_LIMIT = 25_000
|
|
||||||
|
|
||||||
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
|
||||||
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
|
|
||||||
|
|
||||||
_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
|
|
||||||
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
|
|
||||||
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Diff validation
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_headings(memory: str) -> set[str]:
|
|
||||||
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
|
|
||||||
return set(_SECTION_HEADING_RE.findall(memory))
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_heading(heading: str) -> str:
|
|
||||||
"""Normalize heading text for robust scope checks."""
|
|
||||||
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_memory_scope(
|
|
||||||
content: str, scope: Literal["user", "team"]
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Reject personal-only markers ([pref], [instr]) in team memory."""
|
|
||||||
if scope != "team":
|
|
||||||
return None
|
|
||||||
|
|
||||||
markers = set(_MARKER_RE.findall(content))
|
|
||||||
leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
|
|
||||||
if leaked:
|
|
||||||
tags = ", ".join(f"[{m}]" for m in leaked)
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": (
|
|
||||||
f"Team memory cannot include personal markers: {tags}. "
|
|
||||||
"Use [fact] only in team memory."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_bullet_format(content: str) -> list[str]:
|
|
||||||
"""Return warnings for bullet lines that don't match the required format.
|
|
||||||
|
|
||||||
Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
|
|
||||||
"""
|
|
||||||
warnings: list[str] = []
|
|
||||||
for line in content.splitlines():
|
|
||||||
stripped = line.strip()
|
|
||||||
if not stripped.startswith("- "):
|
|
||||||
continue
|
|
||||||
if not _BULLET_FORMAT_RE.match(stripped):
|
|
||||||
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
|
|
||||||
warnings.append(f"Malformed bullet: {short}")
|
|
||||||
return warnings
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
|
||||||
"""Return a list of warning strings about suspicious changes."""
|
|
||||||
if not old_memory:
|
|
||||||
return []
|
|
||||||
|
|
||||||
warnings: list[str] = []
|
|
||||||
old_headings = _extract_headings(old_memory)
|
|
||||||
new_headings = _extract_headings(new_memory)
|
|
||||||
dropped = old_headings - new_headings
|
|
||||||
if dropped:
|
|
||||||
names = ", ".join(sorted(dropped))
|
|
||||||
warnings.append(
|
|
||||||
f"Sections removed: {names}. "
|
|
||||||
"If unintentional, the user can restore from the settings page."
|
|
||||||
)
|
|
||||||
|
|
||||||
old_len = len(old_memory)
|
|
||||||
new_len = len(new_memory)
|
|
||||||
if old_len > 0 and new_len < old_len * 0.4:
|
|
||||||
warnings.append(
|
|
||||||
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
|
|
||||||
"Possible data loss."
|
|
||||||
)
|
|
||||||
return warnings
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Size validation & soft warning
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_memory_size(content: str) -> dict[str, Any] | None:
|
|
||||||
"""Return an error/warning dict if *content* is too large, else None."""
|
|
||||||
length = len(content)
|
|
||||||
if length > MEMORY_HARD_LIMIT:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": (
|
|
||||||
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
|
|
||||||
f"({length:,} chars). Consolidate by merging related items, "
|
|
||||||
"removing outdated entries, and shortening descriptions. "
|
|
||||||
"Then call update_memory again."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _soft_warning(content: str) -> str | None:
|
|
||||||
"""Return a warning string if content exceeds the soft limit."""
|
|
||||||
length = len(content)
|
|
||||||
if length > MEMORY_SOFT_LIMIT:
|
|
||||||
return (
|
|
||||||
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
|
|
||||||
"Consolidate by merging related items and removing less important "
|
|
||||||
"entries on your next update."
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Forced rewrite when memory exceeds the hard limit
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
_FORCED_REWRITE_PROMPT = """\
|
|
||||||
You are a memory curator. The following memory document exceeds the character \
|
|
||||||
limit and must be shortened.
|
|
||||||
|
|
||||||
RULES:
|
|
||||||
1. Rewrite the document to be under {target} characters.
|
|
||||||
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
|
|
||||||
or rename headings to consolidate, but keep names personal and descriptive.
|
|
||||||
3. Priority for keeping content: [instr] > [pref] > [fact].
|
|
||||||
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
|
|
||||||
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
|
|
||||||
6. Preserve the user's first name in entries — do not replace it with "the user".
|
|
||||||
7. Output ONLY the consolidated markdown — no explanations, no wrapping.
|
|
||||||
|
|
||||||
<memory_document>
|
|
||||||
{content}
|
|
||||||
</memory_document>"""
|
|
||||||
|
|
||||||
|
|
||||||
async def _forced_rewrite(content: str, llm: Any) -> str | None:
|
|
||||||
"""Use a focused LLM call to compress *content* under the hard limit.
|
|
||||||
|
|
||||||
Returns the rewritten string, or ``None`` if the call fails.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
prompt = _FORCED_REWRITE_PROMPT.format(
|
|
||||||
target=MEMORY_HARD_LIMIT, content=content
|
|
||||||
)
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[HumanMessage(content=prompt)],
|
|
||||||
config={"tags": ["surfsense:internal"]},
|
|
||||||
)
|
|
||||||
text = (
|
|
||||||
response.content
|
|
||||||
if isinstance(response.content, str)
|
|
||||||
else str(response.content)
|
|
||||||
)
|
|
||||||
return text.strip()
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Forced rewrite LLM call failed")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Shared save-and-respond logic
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
async def _save_memory(
|
|
||||||
*,
|
|
||||||
updated_memory: str,
|
|
||||||
old_memory: str | None,
|
|
||||||
llm: Any | None,
|
|
||||||
apply_fn,
|
|
||||||
commit_fn,
|
|
||||||
rollback_fn,
|
|
||||||
label: str,
|
|
||||||
scope: Literal["user", "team"],
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Validate, optionally force-rewrite if over the hard limit, save, and
|
|
||||||
return a response dict.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
updated_memory : str
|
|
||||||
The new document the agent submitted.
|
|
||||||
old_memory : str | None
|
|
||||||
The previously persisted document (for diff checks).
|
|
||||||
llm : Any | None
|
|
||||||
LLM instance for forced rewrite (may be ``None``).
|
|
||||||
apply_fn : callable(str) -> None
|
|
||||||
Callback that sets the new memory on the ORM object.
|
|
||||||
commit_fn : coroutine
|
|
||||||
``session.commit``.
|
|
||||||
rollback_fn : coroutine
|
|
||||||
``session.rollback``.
|
|
||||||
label : str
|
|
||||||
Human label for log messages (e.g. "user memory", "team memory").
|
|
||||||
"""
|
|
||||||
content = updated_memory
|
|
||||||
|
|
||||||
# --- forced rewrite if over the hard limit ---
|
|
||||||
if len(content) > MEMORY_HARD_LIMIT and llm is not None:
|
|
||||||
rewritten = await _forced_rewrite(content, llm)
|
|
||||||
if rewritten is not None and len(rewritten) < len(content):
|
|
||||||
content = rewritten
|
|
||||||
|
|
||||||
# --- hard-limit gate (reject if still too large after rewrite) ---
|
|
||||||
size_err = _validate_memory_size(content)
|
|
||||||
if size_err:
|
|
||||||
return size_err
|
|
||||||
|
|
||||||
scope_err = _validate_memory_scope(content, scope)
|
|
||||||
if scope_err:
|
|
||||||
return scope_err
|
|
||||||
|
|
||||||
# --- persist ---
|
|
||||||
try:
|
|
||||||
apply_fn(content)
|
|
||||||
await commit_fn()
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Failed to update %s: %s", label, e)
|
|
||||||
await rollback_fn()
|
|
||||||
return {"status": "error", "message": f"Failed to update {label}: {e}"}
|
|
||||||
|
|
||||||
# --- build response ---
|
|
||||||
resp: dict[str, Any] = {
|
|
||||||
"status": "saved",
|
|
||||||
"message": f"{label.capitalize()} updated.",
|
|
||||||
}
|
|
||||||
|
|
||||||
if content is not updated_memory:
|
|
||||||
resp["notice"] = "Memory was automatically rewritten to fit within limits."
|
|
||||||
|
|
||||||
diff_warnings = _validate_diff(old_memory, content)
|
|
||||||
if diff_warnings:
|
|
||||||
resp["diff_warnings"] = diff_warnings
|
|
||||||
|
|
||||||
format_warnings = _validate_bullet_format(content)
|
|
||||||
if format_warnings:
|
|
||||||
resp["format_warnings"] = format_warnings
|
|
||||||
|
|
||||||
warning = _soft_warning(content)
|
|
||||||
if warning:
|
|
||||||
resp["warning"] = warning
|
|
||||||
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Tool factories
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def create_update_memory_tool(
|
def create_update_memory_tool(
|
||||||
user_id: str | UUID,
|
user_id: str | UUID,
|
||||||
|
|
@ -287,40 +30,22 @@ def create_update_memory_tool(
|
||||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||||
"""Update the user's personal memory document.
|
"""Update the user's personal memory document.
|
||||||
|
|
||||||
Your current memory is shown in <user_memory> in the system prompt.
|
The current memory is shown in <user_memory>. Pass the FULL updated
|
||||||
When the user shares important long-term information (preferences,
|
markdown document, not a diff.
|
||||||
facts, instructions, context), rewrite the memory document to include
|
|
||||||
the new information. Merge new facts with existing ones, update
|
|
||||||
contradictions, remove outdated entries, and keep it concise.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
updated_memory: The FULL updated markdown document (not a diff).
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
result = await db_session.execute(select(User).where(User.id == uid))
|
result = await save_memory(
|
||||||
user = result.scalars().first()
|
scope=MemoryScope.USER,
|
||||||
if not user:
|
target_id=uid,
|
||||||
return {"status": "error", "message": "User not found."}
|
content=updated_memory,
|
||||||
|
session=db_session,
|
||||||
old_memory = user.memory_md
|
|
||||||
|
|
||||||
return await _save_memory(
|
|
||||||
updated_memory=updated_memory,
|
|
||||||
old_memory=old_memory,
|
|
||||||
llm=llm,
|
llm=llm,
|
||||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
|
||||||
commit_fn=db_session.commit,
|
|
||||||
rollback_fn=db_session.rollback,
|
|
||||||
label="memory",
|
|
||||||
scope="user",
|
|
||||||
)
|
)
|
||||||
|
return result.to_dict()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to update user memory: %s", e)
|
logger.exception("Failed to update user memory: %s", e)
|
||||||
await db_session.rollback()
|
await db_session.rollback()
|
||||||
return {
|
return {"status": "error", "message": f"Failed to update memory: {e}"}
|
||||||
"status": "error",
|
|
||||||
"message": f"Failed to update memory: {e}",
|
|
||||||
}
|
|
||||||
|
|
||||||
return update_memory
|
return update_memory
|
||||||
|
|
||||||
|
|
@ -334,36 +59,18 @@ def create_update_team_memory_tool(
|
||||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||||
"""Update the team's shared memory document for this search space.
|
"""Update the team's shared memory document for this search space.
|
||||||
|
|
||||||
Your current team memory is shown in <team_memory> in the system
|
The current team memory is shown in <team_memory>. Pass the FULL updated
|
||||||
prompt. When the team shares important long-term information
|
markdown document, not a diff.
|
||||||
(decisions, conventions, key facts, priorities), rewrite the memory
|
|
||||||
document to include the new information. Merge new facts with
|
|
||||||
existing ones, update contradictions, remove outdated entries, and
|
|
||||||
keep it concise.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
updated_memory: The FULL updated markdown document (not a diff).
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
result = await db_session.execute(
|
result = await save_memory(
|
||||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
scope=MemoryScope.TEAM,
|
||||||
)
|
target_id=search_space_id,
|
||||||
space = result.scalars().first()
|
content=updated_memory,
|
||||||
if not space:
|
session=db_session,
|
||||||
return {"status": "error", "message": "Search space not found."}
|
|
||||||
|
|
||||||
old_memory = space.shared_memory_md
|
|
||||||
|
|
||||||
return await _save_memory(
|
|
||||||
updated_memory=updated_memory,
|
|
||||||
old_memory=old_memory,
|
|
||||||
llm=llm,
|
llm=llm,
|
||||||
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
|
||||||
commit_fn=db_session.commit,
|
|
||||||
rollback_fn=db_session.rollback,
|
|
||||||
label="team memory",
|
|
||||||
scope="team",
|
|
||||||
)
|
)
|
||||||
|
return result.to_dict()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to update team memory: %s", e)
|
logger.exception("Failed to update team memory: %s", e)
|
||||||
await db_session.rollback()
|
await db_session.rollback()
|
||||||
|
|
@ -373,3 +80,11 @@ def create_update_team_memory_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
return update_memory
|
return update_memory
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MEMORY_HARD_LIMIT",
|
||||||
|
"MEMORY_SOFT_LIMIT",
|
||||||
|
"create_update_memory_tool",
|
||||||
|
"create_update_team_memory_tool",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -50,4 +50,6 @@ Rules:
|
||||||
- `status=success` -> `next_step=null`, `missing_fields=null`.
|
- `status=success` -> `next_step=null`, `missing_fields=null`.
|
||||||
- `status=partial|blocked|error` -> `next_step` must be non-null.
|
- `status=partial|blocked|error` -> `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
|
||||||
|
- `evidence.findings`: max 10 entries, each a single sentence stating one distinct fact. Do not paste raw paragraphs, scraped pages, or quote blocks.
|
||||||
|
- `evidence.sources`: max 10 URLs, one per finding when applicable. List each URL once.
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ Supervisor: "List open tasks in the Project Tracker base."
|
||||||
2. List tables in that base → identify the Tasks table; capture its table ID.
|
2. List tables in that base → identify the Tasks table; capture its table ID.
|
||||||
3. Get table schema → identify the status field and the choice IDs that represent "open" states.
|
3. Get table schema → identify the status field and the choice IDs that represent "open" states.
|
||||||
4. List records with a typed filter on the status field for those choice IDs.
|
4. List records with a typed filter on the status field for those choice IDs.
|
||||||
5. Return `status=success` with the matched records in `evidence.items`.
|
5. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched records listed in `action_summary` (record id, primary-field value, and 1-2 most relevant fields; one line per record; up to 10 entries, then `"...and N more"`).
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
|
|
@ -97,7 +97,7 @@ Rules:
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: base, table, field, choice, record, etc.).
|
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: base, table, field, choice, record, etc.).
|
||||||
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
|
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (record id, primary-field value, and 1-2 most relevant fields; up to 10 entries, then `"...and N more"`).
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
||||||
Discover before you mutate; never guess identifiers, choice IDs, or required fields.
|
Discover before you mutate; never guess identifiers, choice IDs, or required fields.
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ You are a Google Calendar specialist for the user's connected calendar.
|
||||||
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
||||||
| tool raises / unknown | `error` | `"Calendar tool failed unexpectedly. Ask the user to retry shortly."` |
|
| tool raises / unknown | `error` | `"Calendar tool failed unexpectedly. Ask the user to retry shortly."` |
|
||||||
|
|
||||||
Surface the tool's `event_id`, `title` / `summary`, `start_at`, `end_at`, and `html_link` inside `evidence` when the tool returned them. For `search_calendar_events`, place the raw `events` array inside `evidence.items`. Never invent a field the tool did not return.
|
Surface the tool's `event_id`, `title` / `summary`, `start_at`, `end_at`, and `html_link` inside `evidence` when the tool returned them. For `search_calendar_events`, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (title, date, start time; one line per event; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
|
@ -115,7 +115,7 @@ Rules:
|
||||||
- `status=success` → `next_step=null`, `missing_fields=null`.
|
- `status=success` → `next_step=null`, `missing_fields=null`.
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For `search_calendar_events` results, populate `evidence.items` with `{ "events": [...], "total": N }`.
|
- For `search_calendar_events` results, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (title, date, start time; up to 10 entries, then `"...and N more"`).
|
||||||
- For ambiguous matches across `update_calendar_event` / `delete_calendar_event`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`, where `label` should include the event title and start time for human readability).
|
- For ambiguous matches across `update_calendar_event` / `delete_calendar_event`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`, where `label` should include the event title and start time for human readability).
|
||||||
|
|
||||||
Infer before you call; map every tool outcome faithfully.
|
Infer before you call; map every tool outcome faithfully.
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ Failure handling:
|
||||||
<example>
|
<example>
|
||||||
Supervisor: "Find tasks about the homepage redesign."
|
Supervisor: "Find tasks about the homepage redesign."
|
||||||
1. Workspace search for "homepage redesign" → matched tasks.
|
1. Workspace search for "homepage redesign" → matched tasks.
|
||||||
2. Return `status=success` with the matched tasks in `evidence.items`.
|
2. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched tasks listed in `action_summary` (task id, title, status, assignees; one line per task; up to 10 entries, then `"...and N more"`).
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
|
|
@ -98,7 +98,7 @@ Rules:
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: task, list, member, status, custom-field choice, etc.).
|
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: task, list, member, status, custom-field choice, etc.).
|
||||||
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
|
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (task id, title, status, assignees; up to 10 entries, then `"...and N more"`).
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
||||||
Discover before you mutate; never guess identifiers, list statuses, or assignees.
|
Discover before you mutate; never guess identifiers, list statuses, or assignees.
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ You are a Discord specialist for the user's connected Discord server.
|
||||||
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
||||||
| tool raises / unknown | `error` | `"Discord tool failed unexpectedly. Ask the user to retry shortly."` |
|
| tool raises / unknown | `error` | `"Discord tool failed unexpectedly. Ask the user to retry shortly."` |
|
||||||
|
|
||||||
Surface the tool's `message`, `channel_id`, `message_id`, and the listed channels/messages payload inside `evidence` when the tool returned them. Never invent a field the tool did not return.
|
Surface the tool's `message`, `channel_id`, and `message_id` inside `evidence` when the tool returned them. For `list_discord_channels` and `read_discord_messages`, set `evidence.items` to `{ "total": N }` and list the matched entries in `action_summary` (channel name or sender + timestamp + short text snippet; one line per entry; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ You are a Gmail specialist for the user's connected Gmail mailbox.
|
||||||
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
||||||
| tool raises / unknown | `error` | `"Gmail tool failed unexpectedly. Ask the user to retry shortly."` |
|
| tool raises / unknown | `error` | `"Gmail tool failed unexpectedly. Ask the user to retry shortly."` |
|
||||||
|
|
||||||
Surface the tool's `message_id`, `thread_id`, `draft_id`, `subject`, and recipient fields inside `evidence` when the tool returned them. For `search_gmail`, place the raw `emails` array inside `evidence.items`. Never invent a field the tool did not return.
|
Surface the tool's `message_id`, `thread_id`, `draft_id`, `subject`, and recipient fields inside `evidence` when the tool returned them. For `search_gmail`, set `evidence.items` to `{ "total": N }` and list the matched emails in `action_summary` (sender, subject, date; one line per email; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
|
@ -114,7 +114,7 @@ Rules:
|
||||||
- `status=success` → `next_step=null`, `missing_fields=null`.
|
- `status=success` → `next_step=null`, `missing_fields=null`.
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For `search_gmail` results, populate `evidence.items` with `{ "emails": [...], "total": N }`.
|
- For `search_gmail` results, set `evidence.items` to `{ "total": N }` and list the matched emails in `action_summary` (sender, subject, date; up to 10 entries, then `"...and N more"`).
|
||||||
- For ambiguous matches across `update_gmail_draft` / `trash_gmail_email` / `read_gmail_email`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`).
|
- For ambiguous matches across `update_gmail_draft` / `trash_gmail_email` / `read_gmail_email`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`).
|
||||||
|
|
||||||
Infer before you call; verify before you send; map every tool outcome faithfully.
|
Infer before you call; verify before you send; map every tool outcome faithfully.
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ Failure handling:
|
||||||
<example>
|
<example>
|
||||||
Supervisor: "Find issues assigned to me with status 'In Progress'."
|
Supervisor: "Find issues assigned to me with status 'In Progress'."
|
||||||
1. JQL search with `assignee = currentUser() AND status = "In Progress"`.
|
1. JQL search with `assignee = currentUser() AND status = "In Progress"`.
|
||||||
2. Return `status=success` with the matched issues in `evidence.items`.
|
2. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched issues listed in `action_summary` (issue key, summary, status, assignee; one line per issue; up to 10 entries, then `"...and N more"`).
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
|
|
@ -116,7 +116,7 @@ Rules:
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: site, project, issue, user, transition, etc.).
|
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: site, project, issue, user, transition, etc.).
|
||||||
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
|
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (issue key, summary, status, assignee; up to 10 entries, then `"...and N more"`).
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
||||||
Discover before you mutate; never guess identifiers, transitions, or required fields.
|
Discover before you mutate; never guess identifiers, transitions, or required fields.
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ Failure handling:
|
||||||
<example>
|
<example>
|
||||||
Supervisor: "Find issues assigned to me with priority Urgent."
|
Supervisor: "Find issues assigned to me with priority Urgent."
|
||||||
1. Discovery: list issues with filters `{assignee: "me", priority: 1}`.
|
1. Discovery: list issues with filters `{assignee: "me", priority: 1}`.
|
||||||
2. Return `status=success` with the matched issues in `evidence.items`.
|
2. Return `status=success` with `evidence.items` set to `{ "total": N }` and the matched issues listed in `action_summary` (identifier, title, state, assignee; one line per issue; up to 10 entries, then `"...and N more"`).
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
|
|
@ -106,7 +106,7 @@ Rules:
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: issue, user, project, state, etc.).
|
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: issue, user, project, state, etc.).
|
||||||
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
|
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (identifier, title, state, assignee; up to 10 entries, then `"...and N more"`).
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
||||||
Discover before you mutate; never guess identifiers.
|
Discover before you mutate; never guess identifiers.
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ You are a Luma specialist for the user's connected Luma account.
|
||||||
| `error` | `error` | Relay the tool's `message` verbatim as `next_step` (this covers Luma Plus 403s and other API errors). |
|
| `error` | `error` | Relay the tool's `message` verbatim as `next_step` (this covers Luma Plus 403s and other API errors). |
|
||||||
| tool raises / unknown | `error` | `"Luma tool failed unexpectedly. Ask the user to retry shortly."` |
|
| tool raises / unknown | `error` | `"Luma tool failed unexpectedly. Ask the user to retry shortly."` |
|
||||||
|
|
||||||
Surface the tool's `message`, `event_id`, `name`, `start_at`, and `url` inside `evidence` when the tool returned them. Never invent a field the tool did not return.
|
Surface the tool's `message`, `event_id`, `name`, `start_at`, and `url` inside `evidence` when the tool returned them. For `list_luma_events`, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (event name, start date/time, location if present; one line per event; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ Failure handling:
|
||||||
Supervisor: "Summarize the latest discussion in #marketing."
|
Supervisor: "Summarize the latest discussion in #marketing."
|
||||||
1. Search channels for "marketing" → one strong match. Capture the channel ID.
|
1. Search channels for "marketing" → one strong match. Capture the channel ID.
|
||||||
2. Read that channel's recent message history.
|
2. Read that channel's recent message history.
|
||||||
3. Return `status=success` with the message list in `evidence.items`.
|
3. Return `status=success` with `evidence.items` set to `{ "total": N }` and the messages listed in `action_summary` (sender, timestamp, text snippet; one line per message; up to 10 entries, then `"...and N more"`).
|
||||||
</example>
|
</example>
|
||||||
|
|
||||||
<example>
|
<example>
|
||||||
|
|
@ -92,7 +92,7 @@ Rules:
|
||||||
- `status=partial|blocked|error` → `next_step` must be non-null.
|
- `status=partial|blocked|error` → `next_step` must be non-null.
|
||||||
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
- `status=blocked` due to missing required inputs → `missing_fields` must be non-null.
|
||||||
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: channel, user, message, thread).
|
- For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: channel, user, message, thread).
|
||||||
- For discovery-only queries (lists), populate `evidence.items` with the structured list.
|
- For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (channel/user, key identifier, timestamp, short snippet; up to 10 entries, then `"...and N more"`).
|
||||||
</output_contract>
|
</output_contract>
|
||||||
|
|
||||||
Discover before you post; never guess channel, user, or thread targets.
|
Discover before you post; never guess channel, user, or thread targets.
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ You are a Microsoft Teams specialist for the user's connected Teams account.
|
||||||
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
| `error` | `error` | Relay the tool's `message` verbatim as `next_step`. |
|
||||||
| tool raises / unknown | `error` | `"Teams tool failed unexpectedly. Ask the user to retry shortly."` |
|
| tool raises / unknown | `error` | `"Teams tool failed unexpectedly. Ask the user to retry shortly."` |
|
||||||
|
|
||||||
Surface the tool's `message`, `team_id`, `team_name`, `channel_id`, `channel_name`, and `message_id` inside `evidence` when the tool returned them. Never invent a field the tool did not return.
|
Surface the tool's `message`, `team_id`, `team_name`, `channel_id`, `channel_name`, and `message_id` inside `evidence` when the tool returned them. For `list_teams_channels` and `read_teams_messages`, set `evidence.items` to `{ "total": N }` and list the matched entries in `action_summary` (team › channel, or sender + timestamp + short text snippet; one line per entry; up to 10 entries, then `"...and N more"`). Never invent a field the tool did not return.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -102,6 +102,7 @@ from app.agents.new_chat.tools.registry import (
|
||||||
)
|
)
|
||||||
from app.db import ChatVisibility
|
from app.db import ChatVisibility
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
|
from app.services.llm_service import get_planner_llm
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
_perf_log = get_perf_logger()
|
_perf_log = get_perf_logger()
|
||||||
|
|
@ -1077,6 +1078,7 @@ def _build_compiled_agent_blocking(
|
||||||
else None,
|
else None,
|
||||||
KnowledgePriorityMiddleware(
|
KnowledgePriorityMiddleware(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
|
planner_llm=get_planner_llm(),
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
filesystem_mode=filesystem_mode,
|
filesystem_mode=filesystem_mode,
|
||||||
available_connectors=available_connectors,
|
available_connectors=available_connectors,
|
||||||
|
|
|
||||||
|
|
@ -1,232 +0,0 @@
|
||||||
"""Background memory extraction for the SurfSense agent.
|
|
||||||
|
|
||||||
After each agent response, if the agent did not call ``update_memory`` during
|
|
||||||
the turn, this module can run a lightweight LLM call to decide whether the
|
|
||||||
latest message contains long-term information worth persisting.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from app.agents.new_chat.tools.update_memory import _save_memory
|
|
||||||
from app.db import SearchSpace, User, shielded_async_session
|
|
||||||
from app.utils.content_utils import extract_text_content
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_MEMORY_EXTRACT_PROMPT = """\
|
|
||||||
You are a memory extraction assistant. Analyze the user's message and decide \
|
|
||||||
if it contains any long-term information worth persisting to memory.
|
|
||||||
|
|
||||||
Worth remembering: preferences, background/identity, goals, projects, \
|
|
||||||
instructions, tools/languages they use, decisions, expertise, workplace — \
|
|
||||||
durable facts that will matter in future conversations.
|
|
||||||
|
|
||||||
NOT worth remembering: greetings, one-off factual questions, session \
|
|
||||||
logistics, ephemeral requests, follow-up clarifications with no new personal \
|
|
||||||
info, things that only matter for the current task.
|
|
||||||
|
|
||||||
If the message contains memorizable information, output the FULL updated \
|
|
||||||
memory document with the new facts merged into the existing content. Follow \
|
|
||||||
these rules:
|
|
||||||
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
|
|
||||||
freely. Keep heading names short (2-3 words) and natural. Do NOT include the user's
|
|
||||||
name in headings.
|
|
||||||
- Keep entries as single bullet points. Be descriptive but concise — include relevant
|
|
||||||
details and context rather than just a few words.
|
|
||||||
- Every bullet MUST use format: - (YYYY-MM-DD) [fact|pref|instr] text
|
|
||||||
[fact] = durable facts, [pref] = preferences, [instr] = standing instructions.
|
|
||||||
- Use the user's first name (from <user_name>) in entry text, not "the user".
|
|
||||||
- If a new fact contradicts an existing entry, update the existing entry.
|
|
||||||
- Do not duplicate information that is already present.
|
|
||||||
|
|
||||||
If nothing is worth remembering, output exactly: NO_UPDATE
|
|
||||||
|
|
||||||
<user_name>{user_name}</user_name>
|
|
||||||
|
|
||||||
<current_memory>
|
|
||||||
{current_memory}
|
|
||||||
</current_memory>
|
|
||||||
|
|
||||||
<user_message>
|
|
||||||
{user_message}
|
|
||||||
</user_message>"""
|
|
||||||
|
|
||||||
_TEAM_MEMORY_EXTRACT_PROMPT = """\
|
|
||||||
You are a team-memory extraction assistant. Analyze the latest message and \
|
|
||||||
decide if it contains durable TEAM-level information worth persisting.
|
|
||||||
|
|
||||||
Decision policy:
|
|
||||||
- Prioritize recall for durable team context, while avoiding personal-only facts.
|
|
||||||
- Do NOT require explicit consensus language. A direct team-level statement can
|
|
||||||
be stored if it is stable and broadly useful for future team chats.
|
|
||||||
- If evidence is weak or clearly tentative, output NO_UPDATE.
|
|
||||||
|
|
||||||
Worth remembering (team-level only):
|
|
||||||
- Decisions and defaults that guide future team work
|
|
||||||
- Team conventions/standards (naming, review policy, coding norms)
|
|
||||||
- Stable org/project facts (locations, ownership, constraints)
|
|
||||||
- Long-lived architecture/process facts
|
|
||||||
- Ongoing priorities that are likely relevant beyond this turn
|
|
||||||
|
|
||||||
NOT worth remembering:
|
|
||||||
- Personal preferences or biography of one person
|
|
||||||
- Questions, brainstorming, tentative ideas, or speculation
|
|
||||||
- One-off requests, status updates, TODOs, logistics for this session
|
|
||||||
- Information scoped only to a single ephemeral task
|
|
||||||
|
|
||||||
If the message contains memorizable team information, output the FULL updated \
|
|
||||||
team memory document with new facts merged into existing content. Follow rules:
|
|
||||||
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
|
|
||||||
freely. Keep heading names short (2-3 words) and natural.
|
|
||||||
- Keep entries as single bullet points. Be descriptive but concise — include relevant
|
|
||||||
details and context rather than just a few words.
|
|
||||||
- Every bullet MUST use format: - (YYYY-MM-DD) [fact] text
|
|
||||||
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr].
|
|
||||||
- If a new fact contradicts an existing entry, update the existing entry.
|
|
||||||
- Do not duplicate existing information.
|
|
||||||
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
|
|
||||||
|
|
||||||
If nothing is worth remembering, output exactly: NO_UPDATE
|
|
||||||
|
|
||||||
<current_team_memory>
|
|
||||||
{current_memory}
|
|
||||||
</current_team_memory>
|
|
||||||
|
|
||||||
<latest_message_author>
|
|
||||||
{author}
|
|
||||||
</latest_message_author>
|
|
||||||
|
|
||||||
<latest_message>
|
|
||||||
{user_message}
|
|
||||||
</latest_message>"""
|
|
||||||
|
|
||||||
|
|
||||||
async def extract_and_save_memory(
|
|
||||||
*,
|
|
||||||
user_message: str,
|
|
||||||
user_id: str | None,
|
|
||||||
llm: Any,
|
|
||||||
) -> None:
|
|
||||||
"""Background task: extract memorizable info and persist it.
|
|
||||||
|
|
||||||
Designed to be fire-and-forget — catches all exceptions internally.
|
|
||||||
"""
|
|
||||||
if not user_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
|
||||||
|
|
||||||
async with shielded_async_session() as session:
|
|
||||||
result = await session.execute(select(User).where(User.id == uid))
|
|
||||||
user = result.scalars().first()
|
|
||||||
if not user:
|
|
||||||
return
|
|
||||||
|
|
||||||
old_memory = user.memory_md
|
|
||||||
first_name = (
|
|
||||||
user.display_name.strip().split()[0]
|
|
||||||
if user.display_name and user.display_name.strip()
|
|
||||||
else "The user"
|
|
||||||
)
|
|
||||||
prompt = _MEMORY_EXTRACT_PROMPT.format(
|
|
||||||
current_memory=old_memory or "(empty)",
|
|
||||||
user_message=user_message,
|
|
||||||
user_name=first_name,
|
|
||||||
)
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[HumanMessage(content=prompt)],
|
|
||||||
config={"tags": ["surfsense:internal", "memory-extraction"]},
|
|
||||||
)
|
|
||||||
text = extract_text_content(response.content).strip()
|
|
||||||
|
|
||||||
if text == "NO_UPDATE" or not text:
|
|
||||||
logger.debug("Memory extraction: no update needed (user %s)", uid)
|
|
||||||
return
|
|
||||||
|
|
||||||
save_result = await _save_memory(
|
|
||||||
updated_memory=text,
|
|
||||||
old_memory=old_memory,
|
|
||||||
llm=llm,
|
|
||||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
|
||||||
commit_fn=session.commit,
|
|
||||||
rollback_fn=session.rollback,
|
|
||||||
label="memory",
|
|
||||||
scope="user",
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"Background memory extraction for user %s: %s",
|
|
||||||
uid,
|
|
||||||
save_result.get("status"),
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Background user memory extraction failed")
|
|
||||||
|
|
||||||
|
|
||||||
async def extract_and_save_team_memory(
|
|
||||||
*,
|
|
||||||
user_message: str,
|
|
||||||
search_space_id: int | None,
|
|
||||||
llm: Any,
|
|
||||||
author_display_name: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Background task: extract team-level memory and persist it.
|
|
||||||
|
|
||||||
Runs only for shared threads. Designed to be fire-and-forget and catches
|
|
||||||
exceptions internally.
|
|
||||||
"""
|
|
||||||
if not search_space_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with shielded_async_session() as session:
|
|
||||||
result = await session.execute(
|
|
||||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
|
||||||
)
|
|
||||||
space = result.scalars().first()
|
|
||||||
if not space:
|
|
||||||
return
|
|
||||||
|
|
||||||
old_memory = space.shared_memory_md
|
|
||||||
prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
|
|
||||||
current_memory=old_memory or "(empty)",
|
|
||||||
author=author_display_name or "Unknown team member",
|
|
||||||
user_message=user_message,
|
|
||||||
)
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[HumanMessage(content=prompt)],
|
|
||||||
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
|
|
||||||
)
|
|
||||||
text = extract_text_content(response.content).strip()
|
|
||||||
|
|
||||||
if text == "NO_UPDATE" or not text:
|
|
||||||
logger.debug(
|
|
||||||
"Team memory extraction: no update needed (space %s)",
|
|
||||||
search_space_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
save_result = await _save_memory(
|
|
||||||
updated_memory=text,
|
|
||||||
old_memory=old_memory,
|
|
||||||
llm=llm,
|
|
||||||
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
|
||||||
commit_fn=session.commit,
|
|
||||||
rollback_fn=session.rollback,
|
|
||||||
label="team memory",
|
|
||||||
scope="team",
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"Background team memory extraction for space %s: %s",
|
|
||||||
search_space_id,
|
|
||||||
save_result.get("status"),
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Background team memory extraction failed")
|
|
||||||
|
|
@ -32,6 +32,7 @@ exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect).
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -249,11 +250,11 @@ async def _create_document(
|
||||||
session.add(doc)
|
session.add(doc)
|
||||||
await session.flush()
|
await session.flush()
|
||||||
|
|
||||||
summary_embedding = embed_texts([content])[0]
|
summary_embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
|
||||||
doc.embedding = summary_embedding
|
doc.embedding = summary_embedding
|
||||||
chunks = chunk_text(content)
|
chunks = chunk_text(content)
|
||||||
if chunks:
|
if chunks:
|
||||||
chunk_embeddings = embed_texts(chunks)
|
chunk_embeddings = await asyncio.to_thread(embed_texts, chunks)
|
||||||
session.add_all(
|
session.add_all(
|
||||||
[
|
[
|
||||||
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
||||||
|
|
@ -295,13 +296,13 @@ async def _update_document(
|
||||||
search_space_id,
|
search_space_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
summary_embedding = embed_texts([content])[0]
|
summary_embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
|
||||||
document.embedding = summary_embedding
|
document.embedding = summary_embedding
|
||||||
|
|
||||||
await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
|
await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
|
||||||
chunks = chunk_text(content)
|
chunks = chunk_text(content)
|
||||||
if chunks:
|
if chunks:
|
||||||
chunk_embeddings = embed_texts(chunks)
|
chunk_embeddings = await asyncio.to_thread(embed_texts, chunks)
|
||||||
session.add_all(
|
session.add_all(
|
||||||
[
|
[
|
||||||
Chunk(document_id=document.id, content=text, embedding=embedding)
|
Chunk(document_id=document.id, content=text, embedding=embedding)
|
||||||
|
|
|
||||||
|
|
@ -457,7 +457,7 @@ async def search_knowledge_base(
|
||||||
if not query:
|
if not query:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
[embedding] = embed_texts([query])
|
[embedding] = await asyncio.to_thread(embed_texts, [query])
|
||||||
doc_types = _resolve_search_types(available_connectors, available_document_types)
|
doc_types = _resolve_search_types(available_connectors, available_document_types)
|
||||||
retriever_top_k = min(top_k * 3, 30)
|
retriever_top_k = min(top_k * 3, 30)
|
||||||
|
|
||||||
|
|
@ -579,6 +579,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
llm: BaseChatModel | None = None,
|
llm: BaseChatModel | None = None,
|
||||||
|
planner_llm: BaseChatModel | None = None,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||||
available_connectors: list[str] | None = None,
|
available_connectors: list[str] | None = None,
|
||||||
|
|
@ -588,6 +589,15 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
inject_system_message: bool = True, # For backwards compatibility
|
inject_system_message: bool = True, # For backwards compatibility
|
||||||
) -> None:
|
) -> None:
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
|
# The planner LLM handles short, structured internal tasks (query
|
||||||
|
# rewriting, date extraction, recency classification). When an
|
||||||
|
# operator marks a global config ``is_planner: true`` we route
|
||||||
|
# those calls to a cheap/fast model (e.g. gpt-4o-mini, Haiku, Azure
|
||||||
|
# gpt-5.x-nano) instead of the user's chat LLM — those classification
|
||||||
|
# tasks don't need frontier-tier capability. Falls back to the chat
|
||||||
|
# LLM when no planner config is wired up so deployments without one
|
||||||
|
# keep working unchanged.
|
||||||
|
self.planner_llm = planner_llm or llm
|
||||||
self.search_space_id = search_space_id
|
self.search_space_id = search_space_id
|
||||||
self.filesystem_mode = filesystem_mode
|
self.filesystem_mode = filesystem_mode
|
||||||
self.available_connectors = available_connectors
|
self.available_connectors = available_connectors
|
||||||
|
|
@ -598,7 +608,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
# Build the kb-planner private Runnable ONCE here so we don't pay
|
# Build the kb-planner private Runnable ONCE here so we don't pay
|
||||||
# the ``create_agent`` compile cost (50-200ms) on every turn.
|
# the ``create_agent`` compile cost (50-200ms) on every turn.
|
||||||
# Disabled by default behind ``enable_kb_planner_runnable``; when
|
# Disabled by default behind ``enable_kb_planner_runnable``; when
|
||||||
# off the planner falls back to the legacy ``self.llm.ainvoke``
|
# off the planner falls back to the legacy ``planner_llm.ainvoke``
|
||||||
# path.
|
# path.
|
||||||
self._planner: Runnable | None = None
|
self._planner: Runnable | None = None
|
||||||
self._planner_compile_failed = False
|
self._planner_compile_failed = False
|
||||||
|
|
@ -608,7 +618,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
||||||
Returns ``None`` when the feature flag is disabled, when the LLM is
|
Returns ``None`` when the feature flag is disabled, when the LLM is
|
||||||
unavailable, or when ``create_agent`` raises (we fall back to the
|
unavailable, or when ``create_agent`` raises (we fall back to the
|
||||||
legacy ``self.llm.ainvoke`` path in that case). Compilation happens
|
legacy ``planner_llm.ainvoke`` path in that case). Compilation happens
|
||||||
lazily on first call, then memoized via ``self._planner``.
|
lazily on first call, then memoized via ``self._planner``.
|
||||||
|
|
||||||
The compiled agent is constructed without tools — the planner's
|
The compiled agent is constructed without tools — the planner's
|
||||||
|
|
@ -618,7 +628,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
"""
|
"""
|
||||||
if self._planner is not None or self._planner_compile_failed:
|
if self._planner is not None or self._planner_compile_failed:
|
||||||
return self._planner
|
return self._planner
|
||||||
if self.llm is None:
|
if self.planner_llm is None:
|
||||||
return None
|
return None
|
||||||
flags = get_flags()
|
flags = get_flags()
|
||||||
if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
|
if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
|
||||||
|
|
@ -628,13 +638,13 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._planner = create_agent(
|
self._planner = create_agent(
|
||||||
self.llm,
|
self.planner_llm,
|
||||||
tools=[],
|
tools=[],
|
||||||
middleware=[RetryAfterMiddleware(max_retries=2)],
|
middleware=[RetryAfterMiddleware(max_retries=2)],
|
||||||
)
|
)
|
||||||
except Exception as exc: # pragma: no cover - defensive
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"kb-planner Runnable compile failed; falling back to llm.ainvoke: %s",
|
"kb-planner Runnable compile failed; falling back to planner_llm.ainvoke: %s",
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
self._planner_compile_failed = True
|
self._planner_compile_failed = True
|
||||||
|
|
@ -647,12 +657,12 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
messages: Sequence[BaseMessage],
|
messages: Sequence[BaseMessage],
|
||||||
user_text: str,
|
user_text: str,
|
||||||
) -> tuple[str, datetime | None, datetime | None, bool]:
|
) -> tuple[str, datetime | None, datetime | None, bool]:
|
||||||
if self.llm is None:
|
if self.planner_llm is None:
|
||||||
return user_text, None, None, False
|
return user_text, None, None, False
|
||||||
|
|
||||||
recent_conversation = _render_recent_conversation(
|
recent_conversation = _render_recent_conversation(
|
||||||
messages,
|
messages,
|
||||||
llm=self.llm,
|
llm=self.planner_llm,
|
||||||
user_text=user_text,
|
user_text=user_text,
|
||||||
)
|
)
|
||||||
prompt = _build_kb_planner_prompt(
|
prompt = _build_kb_planner_prompt(
|
||||||
|
|
@ -663,8 +673,8 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
t0 = loop.time()
|
t0 = loop.time()
|
||||||
|
|
||||||
# Prefer the compiled-once planner Runnable when enabled; otherwise
|
# Prefer the compiled-once planner Runnable when enabled; otherwise
|
||||||
# fall back to ``self.llm.ainvoke``. The ``surfsense:internal`` tag
|
# fall back to ``planner_llm.ainvoke``. The ``surfsense:internal``
|
||||||
# is preserved on both paths so ``_stream_agent_events`` still
|
# tag is preserved on both paths so ``_stream_agent_events`` still
|
||||||
# suppresses the planner's intermediate events from the UI.
|
# suppresses the planner's intermediate events from the UI.
|
||||||
planner = self._build_kb_planner_runnable()
|
planner = self._build_kb_planner_runnable()
|
||||||
try:
|
try:
|
||||||
|
|
@ -684,7 +694,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
else AIMessage(content="")
|
else AIMessage(content="")
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = await self.llm.ainvoke(
|
response = await self.planner_llm.ainvoke(
|
||||||
[HumanMessage(content=prompt)],
|
[HumanMessage(content=prompt)],
|
||||||
config={"tags": ["surfsense:internal"]},
|
config={"tags": ["surfsense:internal"]},
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
|
@ -41,6 +42,9 @@ from app.agents.new_chat.path_resolver import (
|
||||||
doc_to_virtual_path,
|
doc_to_virtual_path,
|
||||||
)
|
)
|
||||||
from app.db import Document, shielded_async_session
|
from app.db import Document, shielded_async_session
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from litellm import token_counter
|
from litellm import token_counter
|
||||||
|
|
@ -124,6 +128,7 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
if self.filesystem_mode != FilesystemMode.CLOUD:
|
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
update: dict[str, Any] = {}
|
update: dict[str, Any] = {}
|
||||||
if not state.get("cwd"):
|
if not state.get("cwd"):
|
||||||
update["cwd"] = DOCUMENTS_ROOT
|
update["cwd"] = DOCUMENTS_ROOT
|
||||||
|
|
@ -131,7 +136,11 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
anon_doc = state.get("kb_anon_doc")
|
anon_doc = state.get("kb_anon_doc")
|
||||||
if anon_doc:
|
if anon_doc:
|
||||||
tree_msg = self._render_anon_tree(anon_doc)
|
tree_msg = self._render_anon_tree(anon_doc)
|
||||||
|
cache_outcome = "anon"
|
||||||
else:
|
else:
|
||||||
|
version = int(state.get("tree_version") or 0)
|
||||||
|
cache_key = (self.search_space_id, version, False)
|
||||||
|
cache_outcome = "hit" if cache_key in self._cache else "miss"
|
||||||
tree_msg = await self._render_kb_tree(state)
|
tree_msg = await self._render_kb_tree(state)
|
||||||
|
|
||||||
update["workspace_tree_text"] = tree_msg
|
update["workspace_tree_text"] = tree_msg
|
||||||
|
|
@ -141,6 +150,14 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
insert_at = max(len(messages) - 1, 0)
|
insert_at = max(len(messages) - 1, 0)
|
||||||
messages.insert(insert_at, SystemMessage(content=tree_msg))
|
messages.insert(insert_at, SystemMessage(content=tree_msg))
|
||||||
update["messages"] = messages
|
update["messages"] = messages
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[knowledge_tree] cache=%s chars=%d elapsed=%.3fs space=%d",
|
||||||
|
cache_outcome,
|
||||||
|
len(tree_msg),
|
||||||
|
time.perf_counter() - start,
|
||||||
|
self.search_space_id,
|
||||||
|
)
|
||||||
return update
|
return update
|
||||||
|
|
||||||
def before_agent( # type: ignore[override]
|
def before_agent( # type: ignore[override]
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ Injects memory markdown into the system prompt on every turn:
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
@ -17,10 +18,12 @@ from langgraph.runtime import Runtime
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
|
|
||||||
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
|
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
|
||||||
|
from app.services.memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
@ -53,9 +56,13 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
if not isinstance(last_message, HumanMessage):
|
if not isinstance(last_message, HumanMessage):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
db_elapsed = 0.0
|
||||||
memory_blocks: list[str] = []
|
memory_blocks: list[str] = []
|
||||||
|
scope = "team" if self.visibility == ChatVisibility.SEARCH_SPACE else "user"
|
||||||
|
|
||||||
async with shielded_async_session() as session:
|
async with shielded_async_session() as session:
|
||||||
|
db_start = time.perf_counter()
|
||||||
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
||||||
team_memory = await self._load_team_memory(session)
|
team_memory = await self._load_team_memory(session)
|
||||||
if team_memory:
|
if team_memory:
|
||||||
|
|
@ -96,7 +103,15 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
f"</memory_warning>"
|
f"</memory_warning>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
db_elapsed = time.perf_counter() - db_start
|
||||||
|
|
||||||
if not memory_blocks:
|
if not memory_blocks:
|
||||||
|
_perf_log.info(
|
||||||
|
"[memory_injection] scope=%s injected=0 db=%.3fs total=%.3fs",
|
||||||
|
scope,
|
||||||
|
db_elapsed,
|
||||||
|
time.perf_counter() - start,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
memory_text = "\n\n".join(memory_blocks)
|
memory_text = "\n\n".join(memory_blocks)
|
||||||
|
|
@ -106,6 +121,13 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
insert_idx = 1 if len(new_messages) > 1 else 0
|
insert_idx = 1 if len(new_messages) > 1 else 0
|
||||||
new_messages.insert(insert_idx, memory_msg)
|
new_messages.insert(insert_idx, memory_msg)
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[memory_injection] scope=%s injected=1 chars=%d db=%.3fs total=%.3fs",
|
||||||
|
scope,
|
||||||
|
len(memory_text),
|
||||||
|
db_elapsed,
|
||||||
|
time.perf_counter() - start,
|
||||||
|
)
|
||||||
return {"messages": new_messages}
|
return {"messages": new_messages}
|
||||||
|
|
||||||
async def _load_user_memory(
|
async def _load_user_memory(
|
||||||
|
|
|
||||||
|
|
@ -39,9 +39,19 @@ For OpenAI-family configs we additionally pass:
|
||||||
|
|
||||||
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
|
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
|
||||||
raises hit rate by sending requests with a shared prefix to the same
|
raises hit rate by sending requests with a shared prefix to the same
|
||||||
backend.
|
backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and
|
||||||
|
``azure/`` (added to LiteLLM's Azure transformer in
|
||||||
|
https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified
|
||||||
|
against ``AzureOpenAIConfig.get_supported_openai_params`` in our
|
||||||
|
installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``,
|
||||||
|
``azure/gpt-5.4``, ``azure/gpt-5.4-mini``).
|
||||||
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
|
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
|
||||||
5-10 min in-memory cache.
|
5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI: Azure's
|
||||||
|
server-side support landed in Microsoft's docs on 2026-05-13 but
|
||||||
|
LiteLLM 1.83.14's Azure transformer still omits it from its supported
|
||||||
|
params list, so it gets silently dropped by ``litellm.drop_params``.
|
||||||
|
Azure's default in-memory retention (5-10 min, max 1 h) already
|
||||||
|
bridges intra-conversation turns; revisit when LiteLLM bumps Azure.
|
||||||
|
|
||||||
Safety net: ``litellm.drop_params=True`` is set globally in
|
Safety net: ``litellm.drop_params=True`` is set globally in
|
||||||
``app.services.llm_service`` at module-load time. Any kwarg the destination
|
``app.services.llm_service`` at module-load time. Any kwarg the destination
|
||||||
|
|
@ -81,13 +91,31 @@ _DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
||||||
{"location": "message", "index": -1},
|
{"location": "message", "index": -1},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Providers (uppercase ``AgentConfig.provider`` values) that natively expose
|
# Providers (uppercase ``AgentConfig.provider`` values) that accept the
|
||||||
# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and
|
# OpenAI ``prompt_cache_key`` routing hint. Microsoft's Azure OpenAI docs
|
||||||
# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers
|
# (2026-05-13) confirm automatic prompt caching applies to every GPT-4o
|
||||||
# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without
|
# or newer Azure deployment at ≥1024 tokens with no configuration needed,
|
||||||
# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU,
|
# and that ``prompt_cache_key`` is combined with the prefix hash to
|
||||||
# MINIMAX), so we can't infer family from the litellm prefix alone.
|
# improve routing affinity and therefore cache hit rate. LiteLLM's Azure
|
||||||
_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"})
|
# transformer ships ``prompt_cache_key`` in its supported params as of
|
||||||
|
# https://github.com/BerriAI/litellm/pull/20989.
|
||||||
|
#
|
||||||
|
# Strict whitelist — many other providers in ``PROVIDER_MAP`` route
|
||||||
|
# through litellm's ``openai`` prefix without implementing the OpenAI
|
||||||
|
# prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so we can't infer
|
||||||
|
# family from the litellm prefix alone.
|
||||||
|
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset(
|
||||||
|
{"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accept
|
||||||
|
# ``prompt_cache_retention="24h"``. Azure is excluded: see module
|
||||||
|
# docstring — LiteLLM 1.83.14's Azure transformer omits the param so
|
||||||
|
# ``drop_params`` silently strips it. Re-add Azure once a future LiteLLM
|
||||||
|
# release wires it into ``AzureOpenAIConfig.get_supported_openai_params``.
|
||||||
|
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
|
||||||
|
{"OPENAI", "DEEPSEEK", "XAI"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_router_llm(llm: BaseChatModel) -> bool:
|
def _is_router_llm(llm: BaseChatModel) -> bool:
|
||||||
|
|
@ -101,13 +129,13 @@ def _is_router_llm(llm: BaseChatModel) -> bool:
|
||||||
return type(llm).__name__ == "ChatLiteLLMRouter"
|
return type(llm).__name__ == "ChatLiteLLMRouter"
|
||||||
|
|
||||||
|
|
||||||
def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
|
def _provider_supports_prompt_cache_key(agent_config: AgentConfig | None) -> bool:
|
||||||
"""Whether the config targets an OpenAI-style prompt-cache surface.
|
"""Whether the config targets a provider that accepts ``prompt_cache_key``.
|
||||||
|
|
||||||
Strict — only returns True when the user explicitly chose OPENAI,
|
Strict — only returns True for explicitly chosen OPENAI, DEEPSEEK,
|
||||||
DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` /
|
XAI, AZURE, or AZURE_OPENAI providers. Auto-mode and custom
|
||||||
``YAMLConfig``. Auto-mode and custom providers return False because
|
providers return False because we can't statically know the
|
||||||
we can't statically know the destination.
|
destination and the router fans out across mixed providers.
|
||||||
"""
|
"""
|
||||||
if agent_config is None or not agent_config.provider:
|
if agent_config is None or not agent_config.provider:
|
||||||
return False
|
return False
|
||||||
|
|
@ -115,7 +143,25 @@ def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
|
||||||
return False
|
return False
|
||||||
if agent_config.custom_provider:
|
if agent_config.custom_provider:
|
||||||
return False
|
return False
|
||||||
return agent_config.provider.upper() in _OPENAI_FAMILY_PROVIDERS
|
return agent_config.provider.upper() in _PROMPT_CACHE_KEY_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_supports_prompt_cache_retention(
|
||||||
|
agent_config: AgentConfig | None,
|
||||||
|
) -> bool:
|
||||||
|
"""Whether the config targets a provider that accepts ``prompt_cache_retention``.
|
||||||
|
|
||||||
|
Tighter than :func:`_provider_supports_prompt_cache_key` — Azure
|
||||||
|
deployments are excluded until LiteLLM ships the param in its Azure
|
||||||
|
transformer (see module docstring).
|
||||||
|
"""
|
||||||
|
if agent_config is None or not agent_config.provider:
|
||||||
|
return False
|
||||||
|
if agent_config.is_auto_mode:
|
||||||
|
return False
|
||||||
|
if agent_config.custom_provider:
|
||||||
|
return False
|
||||||
|
return agent_config.provider.upper() in _PROMPT_CACHE_RETENTION_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
|
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
|
||||||
|
|
@ -173,16 +219,23 @@ def apply_litellm_prompt_caching(
|
||||||
dict(point) for point in _DEFAULT_INJECTION_POINTS
|
dict(point) for point in _DEFAULT_INJECTION_POINTS
|
||||||
]
|
]
|
||||||
|
|
||||||
# OpenAI-family extras only when we statically know the destination is
|
# OpenAI-style extras only when we statically know the destination
|
||||||
# OpenAI / DeepSeek / xAI. Auto-mode router fans out across providers
|
# accepts them. Auto-mode router fans out across mixed providers so
|
||||||
# so we can't safely set OpenAI-only kwargs there (drop_params would
|
# we can't safely set destination-specific kwargs there (drop_params
|
||||||
# strip them but it's wasteful to set them in the first place).
|
# would strip them but it's wasteful to set them in the first
|
||||||
|
# place).
|
||||||
if _is_router_llm(llm):
|
if _is_router_llm(llm):
|
||||||
return
|
return
|
||||||
if not _is_openai_family_config(agent_config):
|
|
||||||
return
|
|
||||||
|
|
||||||
if thread_id is not None and "prompt_cache_key" not in model_kwargs:
|
if (
|
||||||
|
thread_id is not None
|
||||||
|
and "prompt_cache_key" not in model_kwargs
|
||||||
|
and _provider_supports_prompt_cache_key(agent_config)
|
||||||
|
):
|
||||||
model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
|
model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
|
||||||
if "prompt_cache_retention" not in model_kwargs:
|
|
||||||
|
if (
|
||||||
|
"prompt_cache_retention" not in model_kwargs
|
||||||
|
and _provider_supports_prompt_cache_retention(agent_config)
|
||||||
|
):
|
||||||
model_kwargs["prompt_cache_retention"] = "24h"
|
model_kwargs["prompt_cache_retention"] = "24h"
|
||||||
|
|
|
||||||
|
|
@ -3,4 +3,10 @@ IMPORTANT — After understanding each user message, ALWAYS check: does this mes
|
||||||
reveal durable facts about the user (role, interests, preferences, projects,
|
reveal durable facts about the user (role, interests, preferences, projects,
|
||||||
background, or standing instructions)? If yes, you MUST call update_memory
|
background, or standing instructions)? If yes, you MUST call update_memory
|
||||||
alongside your normal response — do not defer this to a later turn.
|
alongside your normal response — do not defer this to a later turn.
|
||||||
|
|
||||||
|
Memory is stored as a heading-based markdown document. New entries should be
|
||||||
|
under `##` headings such as `## Facts`, `## Preferences`, or `## Instructions`
|
||||||
|
with bullets like `- YYYY-MM-DD: text`. If existing memory contains legacy
|
||||||
|
`(YYYY-MM-DD) [fact|pref|instr]` markers, preserve the information but write
|
||||||
|
new saves in the heading-based format.
|
||||||
</memory_protocol>
|
</memory_protocol>
|
||||||
|
|
|
||||||
|
|
@ -3,4 +3,12 @@ IMPORTANT — After understanding each user message, ALWAYS check: does this mes
|
||||||
reveal durable facts about the team (decisions, conventions, architecture, processes,
|
reveal durable facts about the team (decisions, conventions, architecture, processes,
|
||||||
or key facts)? If yes, you MUST call update_memory alongside your normal response —
|
or key facts)? If yes, you MUST call update_memory alongside your normal response —
|
||||||
do not defer this to a later turn.
|
do not defer this to a later turn.
|
||||||
|
|
||||||
|
Team memory is stored as a heading-based markdown document. New entries should
|
||||||
|
be under `##` headings such as `## Product Decisions`,
|
||||||
|
`## Engineering Conventions`, `## Project Facts`, or `## Open Questions` with
|
||||||
|
bullets like `- YYYY-MM-DD: text`. If existing memory contains legacy
|
||||||
|
`(YYYY-MM-DD) [fact]` markers, preserve the information but write new saves in
|
||||||
|
the heading-based format. Do not create personal headings such as
|
||||||
|
`## Preferences` or `## Instructions`.
|
||||||
</memory_protocol>
|
</memory_protocol>
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,16 @@
|
||||||
|
|
||||||
- <user_name>Alex</user_name>, <user_memory> is empty. User: "I'm a space enthusiast, explain astrophage to me"
|
- <user_name>Alex</user_name>, <user_memory> is empty. User: "I'm a space enthusiast, explain astrophage to me"
|
||||||
- The user casually shared a durable fact. Use their first name in the entry, short neutral heading:
|
- The user casually shared a durable fact:
|
||||||
update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n")
|
update_memory(updated_memory="## Facts\n- 2025-03-15: Alex is a space enthusiast\n")
|
||||||
- User: "Remember that I prefer concise answers over detailed explanations"
|
- User: "Remember that I prefer concise answers over detailed explanations"
|
||||||
- Durable preference. Merge with existing memory, add a new heading:
|
- Durable preference. Merge with existing memory:
|
||||||
update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n\n## Response style\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\n")
|
update_memory(updated_memory="## Facts\n- 2025-03-15: Alex is a space enthusiast\n\n## Preferences\n- 2025-03-15: Alex prefers concise answers over detailed explanations\n")
|
||||||
- User: "I actually moved to Tokyo last month"
|
- User: "I actually moved to Tokyo last month"
|
||||||
- Updated fact, date prefix reflects when recorded:
|
- Updated fact, date prefix reflects when recorded:
|
||||||
update_memory(updated_memory="## Interests & background\n...\n\n## Personal context\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\n...")
|
update_memory(updated_memory="## Facts\n- 2025-03-15: Alex lives in Tokyo (previously London)\n...")
|
||||||
- User: "I'm a freelance photographer working on a nature documentary"
|
- User: "I'm a freelance photographer working on a nature documentary"
|
||||||
- Durable background info under a fitting heading:
|
- Durable background info under a fitting heading:
|
||||||
update_memory(updated_memory="...\n\n## Current focus\n- (2025-03-15) [fact] Alex is a freelance photographer\n- (2025-03-15) [fact] Alex is working on a nature documentary\n")
|
update_memory(updated_memory="...\n\n## Current Focus\n- 2025-03-15: Alex is a freelance photographer\n- 2025-03-15: Alex is working on a nature documentary\n")
|
||||||
- User: "Always respond in bullet points"
|
- User: "Always respond in bullet points"
|
||||||
- Standing instruction:
|
- Standing instruction:
|
||||||
update_memory(updated_memory="...\n\n## Response style\n- (2025-03-15) [instr] Always respond to Alex in bullet points\n")
|
update_memory(updated_memory="...\n\n## Instructions\n- 2025-03-15: Always respond to Alex in bullet points\n")
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
|
|
||||||
- User: "Let's remember that we decided to do weekly standup meetings on Mondays"
|
- User: "Let's remember that we decided to do weekly standup meetings on Mondays"
|
||||||
- Durable team decision:
|
- Durable team decision:
|
||||||
update_memory(updated_memory="- (2025-03-15) [fact] Weekly standup meetings on Mondays\n...")
|
update_memory(updated_memory="## Product Decisions\n- 2025-03-15: Weekly standup meetings happen on Mondays\n...")
|
||||||
- User: "Our office is in downtown Seattle, 5th floor"
|
- User: "Our office is in downtown Seattle, 5th floor"
|
||||||
- Durable team fact:
|
- Durable team fact:
|
||||||
update_memory(updated_memory="- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\n...")
|
update_memory(updated_memory="## Project Facts\n- 2025-03-15: Office location is downtown Seattle, 5th floor\n...")
|
||||||
|
|
|
||||||
|
|
@ -1,31 +1,26 @@
|
||||||
|
|
||||||
- update_memory: Update your personal memory document about the user.
|
- update_memory: Update your personal memory document about the user.
|
||||||
- Your current memory is already in <user_memory> in your context. The `chars` and
|
- Your current memory is already in <user_memory> in your context. The `chars`
|
||||||
`limit` attributes show your current usage and the maximum allowed size.
|
and `limit` attributes show current usage and the maximum allowed size.
|
||||||
- This is your curated long-term memory — the distilled essence of what you know about
|
- This is curated long-term memory, not raw conversation logs.
|
||||||
the user, not raw conversation logs.
|
- Call update_memory when the user explicitly asks to remember/forget
|
||||||
- Call update_memory when:
|
something or shares durable facts, preferences, or standing instructions.
|
||||||
* The user explicitly asks to remember or forget something
|
- The user's first name is provided in <user_name>. Use it in entries instead
|
||||||
* The user shares durable facts or preferences that will matter in future conversations
|
of "the user" when helpful. Do not store the name alone as a memory entry.
|
||||||
- The user's first name is provided in <user_name>. Use it in memory entries
|
- Do not store short-lived info: one-off questions, greetings, session
|
||||||
instead of "the user" (e.g. "{name} works at..." not "The user works at...").
|
logistics, or things that only matter for the current task.
|
||||||
Do not store the name itself as a separate memory entry.
|
|
||||||
- Do not store short-lived or ephemeral info: one-off questions, greetings,
|
|
||||||
session logistics, or things that only matter for the current task.
|
|
||||||
- Args:
|
- Args:
|
||||||
- updated_memory: The FULL updated markdown document (not a diff).
|
- updated_memory: The FULL updated markdown document, not a diff. Merge new
|
||||||
Merge new facts with existing ones, update contradictions, remove outdated entries.
|
facts with existing ones, update contradictions, remove outdated entries,
|
||||||
Treat every update as a curation pass — consolidate, don't just append.
|
and consolidate instead of only appending.
|
||||||
- Every bullet MUST use this format: - (YYYY-MM-DD) [marker] text
|
- Use heading-based Markdown:
|
||||||
Markers:
|
* Every entry must be under a `##` heading.
|
||||||
[fact] — durable facts (role, background, projects, tools, expertise)
|
* Recommended headings: `## Facts`, `## Preferences`, `## Instructions`.
|
||||||
[pref] — preferences (response style, languages, formats, tools)
|
Specific natural headings are allowed when clearer.
|
||||||
[instr] — standing instructions (always/never do, response rules)
|
* New bullets should use `- YYYY-MM-DD: text`.
|
||||||
- Keep it concise and well under the character limit shown in <user_memory>.
|
* Each entry should be one concise but descriptive bullet.
|
||||||
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
|
- If existing memory uses legacy `(YYYY-MM-DD) [fact|pref|instr]` markers,
|
||||||
natural. Do NOT include the user's name in headings. Organize by context — e.g.
|
preserve the information but write the updated document in the new
|
||||||
who they are, what they're focused on, how they prefer things. Create, split, or
|
heading-based format.
|
||||||
merge headings freely as the memory grows.
|
- During consolidation, prioritize durable instructions and preferences before
|
||||||
- Each entry MUST be a single bullet point. Be descriptive but concise — include relevant
|
generic facts.
|
||||||
details and context rather than just a few words.
|
|
||||||
- During consolidation, prioritize keeping: [instr] > [pref] > [fact].
|
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,28 @@
|
||||||
|
|
||||||
- update_memory: Update the team's shared memory document for this search space.
|
- update_memory: Update the team's shared memory document for this search space.
|
||||||
- Your current team memory is already in <team_memory> in your context. The `chars`
|
- Your current team memory is already in <team_memory> in your context. The
|
||||||
and `limit` attributes show current usage and the maximum allowed size.
|
`chars` and `limit` attributes show current usage and the maximum allowed size.
|
||||||
- This is the team's curated long-term memory — decisions, conventions, key facts.
|
- This is curated long-term team memory: decisions, conventions, architecture,
|
||||||
- NEVER store personal memory in team memory (e.g. personal bio, individual
|
processes, and key shared facts.
|
||||||
preferences, or user-only standing instructions).
|
- NEVER store personal memory in team memory: individual bios, personal
|
||||||
- Call update_memory when:
|
preferences, or user-only standing instructions.
|
||||||
* A team member explicitly asks to remember or forget something
|
- Call update_memory when a team member asks to remember/forget something, or
|
||||||
* The conversation surfaces durable team decisions, conventions, or facts
|
when the conversation surfaces durable team context that matters later.
|
||||||
that will matter in future conversations
|
- Do not store short-lived info: one-off questions, greetings, session
|
||||||
- Do not store short-lived or ephemeral info: one-off questions, greetings,
|
logistics, or things that only matter for the current task.
|
||||||
session logistics, or things that only matter for the current task.
|
|
||||||
- Args:
|
- Args:
|
||||||
- updated_memory: The FULL updated markdown document (not a diff).
|
- updated_memory: The FULL updated markdown document, not a diff. Merge new
|
||||||
Merge new facts with existing ones, update contradictions, remove outdated entries.
|
facts with existing ones, update contradictions, remove outdated entries,
|
||||||
Treat every update as a curation pass — consolidate, don't just append.
|
and consolidate instead of only appending.
|
||||||
- Every bullet MUST use this format: - (YYYY-MM-DD) [fact] text
|
- Use heading-based Markdown:
|
||||||
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr] in team memory.
|
* Every entry must be under a `##` heading.
|
||||||
- Keep it concise and well under the character limit shown in <team_memory>.
|
* Recommended headings: `## Product Decisions`, `## Engineering Conventions`,
|
||||||
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
|
`## Project Facts`, `## Open Questions`.
|
||||||
natural. Organize by context — e.g. what the team decided, current architecture,
|
* New bullets should use `- YYYY-MM-DD: text`.
|
||||||
active processes. Create, split, or merge headings freely as the memory grows.
|
* Each entry should be one concise but descriptive bullet.
|
||||||
- Each entry MUST be a single bullet point. Be descriptive but concise — include relevant
|
- If existing memory uses legacy `(YYYY-MM-DD) [fact]` markers, preserve the
|
||||||
details and context rather than just a few words.
|
information but write the updated document in the new heading-based format.
|
||||||
- During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities.
|
- Do not create personal headings such as `## Preferences`, `## Instructions`,
|
||||||
|
`## Personal Notes`, or `## Personal Instructions`.
|
||||||
|
- During consolidation, prioritize decisions/conventions, then key facts, then
|
||||||
|
current priorities.
|
||||||
|
|
|
||||||
|
|
@ -36,8 +36,16 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
|
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||||
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
CachedMCPTools,
|
||||||
|
read_cached_tools,
|
||||||
|
write_cached_tools,
|
||||||
|
)
|
||||||
from app.db import SearchSourceConnector
|
from app.db import SearchSourceConnector
|
||||||
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
|
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -293,15 +301,21 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
timeout: float = 60.0,
|
timeout: float = 60.0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Execute a single MCP HTTP call with the given headers."""
|
"""Execute a single MCP HTTP call with the given headers."""
|
||||||
|
call_start = time.perf_counter()
|
||||||
async with (
|
async with (
|
||||||
streamablehttp_client(url, headers=call_headers) as (read, write, _),
|
streamablehttp_client(url, headers=call_headers) as (read, write, _),
|
||||||
ClientSession(read, write) as session,
|
ClientSession(read, write) as session,
|
||||||
):
|
):
|
||||||
|
init_start = time.perf_counter()
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
init_elapsed = time.perf_counter() - init_start
|
||||||
|
|
||||||
|
tool_start = time.perf_counter()
|
||||||
response = await asyncio.wait_for(
|
response = await asyncio.wait_for(
|
||||||
session.call_tool(original_tool_name, arguments=call_kwargs),
|
session.call_tool(original_tool_name, arguments=call_kwargs),
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
tool_elapsed = time.perf_counter() - tool_start
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for content in response.content:
|
for content in response.content:
|
||||||
|
|
@ -312,7 +326,18 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
else:
|
else:
|
||||||
result.append(str(content))
|
result.append(str(content))
|
||||||
|
|
||||||
return "\n".join(result) if result else ""
|
payload = "\n".join(result) if result else ""
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_http_call] connector=%s tool=%s init=%.3fs call=%.3fs total=%.3fs out_chars=%d",
|
||||||
|
connector_id,
|
||||||
|
original_tool_name,
|
||||||
|
init_elapsed,
|
||||||
|
tool_elapsed,
|
||||||
|
time.perf_counter() - call_start,
|
||||||
|
len(payload),
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
|
||||||
async def mcp_http_tool_call(**kwargs) -> str:
|
async def mcp_http_tool_call(**kwargs) -> str:
|
||||||
"""Execute the MCP tool call via HTTP transport."""
|
"""Execute the MCP tool call via HTTP transport."""
|
||||||
|
|
@ -496,6 +521,7 @@ async def _load_http_mcp_tools(
|
||||||
is_generic_mcp: bool = False,
|
is_generic_mcp: bool = False,
|
||||||
*,
|
*,
|
||||||
bypass_internal_hitl: bool = False,
|
bypass_internal_hitl: bool = False,
|
||||||
|
cached_tools: CachedMCPTools | None = None,
|
||||||
) -> list[StructuredTool]:
|
) -> list[StructuredTool]:
|
||||||
"""Load tools from an HTTP-based MCP server.
|
"""Load tools from an HTTP-based MCP server.
|
||||||
|
|
||||||
|
|
@ -506,6 +532,8 @@ async def _load_http_mcp_tools(
|
||||||
readonly_tools: Tool names that skip HITL approval (read-only operations).
|
readonly_tools: Tool names that skip HITL approval (read-only operations).
|
||||||
tool_name_prefix: If set, each tool name is prefixed for multi-account
|
tool_name_prefix: If set, each tool name is prefixed for multi-account
|
||||||
disambiguation (e.g. ``linear_25``).
|
disambiguation (e.g. ``linear_25``).
|
||||||
|
cached_tools: If provided, skip live discovery and rebuild wrappers
|
||||||
|
from the persisted definitions.
|
||||||
"""
|
"""
|
||||||
tools: list[StructuredTool] = []
|
tools: list[StructuredTool] = []
|
||||||
|
|
||||||
|
|
@ -529,15 +557,23 @@ async def _load_http_mcp_tools(
|
||||||
|
|
||||||
allowed_set = set(allowed_tools) if allowed_tools else None
|
allowed_set = set(allowed_tools) if allowed_tools else None
|
||||||
|
|
||||||
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]:
|
async def _discover(
|
||||||
"""Connect, initialize, and list tools from the MCP server."""
|
disc_headers: dict[str, str],
|
||||||
|
) -> tuple[dict[str, str | None], list[dict[str, Any]]]:
|
||||||
|
"""Connect, initialize, and list tools — returns (serverInfo, tools)."""
|
||||||
async with (
|
async with (
|
||||||
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
|
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
|
||||||
ClientSession(read, write) as session,
|
ClientSession(read, write) as session,
|
||||||
):
|
):
|
||||||
await session.initialize()
|
init_result = await session.initialize()
|
||||||
|
server_info: dict[str, str | None] = {"name": None, "version": None}
|
||||||
|
si = getattr(init_result, "serverInfo", None)
|
||||||
|
if si is not None:
|
||||||
|
server_info["name"] = getattr(si, "name", None)
|
||||||
|
server_info["version"] = getattr(si, "version", None)
|
||||||
|
|
||||||
response = await session.list_tools()
|
response = await session.list_tools()
|
||||||
return [
|
return server_info, [
|
||||||
{
|
{
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description or "",
|
"description": tool.description or "",
|
||||||
|
|
@ -548,47 +584,65 @@ async def _load_http_mcp_tools(
|
||||||
for tool in response.tools
|
for tool in response.tools
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
if cached_tools is not None:
|
||||||
tool_definitions = await _discover(headers)
|
tool_definitions = [
|
||||||
except Exception as first_err:
|
{
|
||||||
if not _is_auth_error(first_err) or connector_id is None:
|
"name": td.name,
|
||||||
logger.exception(
|
"description": td.description,
|
||||||
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
"input_schema": td.input_schema,
|
||||||
url,
|
}
|
||||||
connector_id,
|
for td in cached_tools.tools
|
||||||
first_err,
|
]
|
||||||
)
|
else:
|
||||||
return tools
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
|
|
||||||
connector_id,
|
|
||||||
)
|
|
||||||
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
|
||||||
if fresh_headers is None:
|
|
||||||
await _mark_connector_auth_expired(connector_id)
|
|
||||||
logger.error(
|
|
||||||
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
|
|
||||||
connector_id,
|
|
||||||
)
|
|
||||||
return tools
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_definitions = await _discover(fresh_headers)
|
server_info, tool_definitions = await _discover(headers)
|
||||||
headers = fresh_headers
|
except Exception as first_err:
|
||||||
logger.info(
|
if not _is_auth_error(first_err) or connector_id is None:
|
||||||
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
|
logger.exception(
|
||||||
|
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
||||||
|
url,
|
||||||
|
connector_id,
|
||||||
|
first_err,
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
|
||||||
connector_id,
|
connector_id,
|
||||||
)
|
)
|
||||||
except Exception as retry_err:
|
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
||||||
logger.exception(
|
if fresh_headers is None:
|
||||||
"HTTP MCP discovery for connector %d still failing after refresh: %s",
|
|
||||||
connector_id,
|
|
||||||
retry_err,
|
|
||||||
)
|
|
||||||
if _is_auth_error(retry_err):
|
|
||||||
await _mark_connector_auth_expired(connector_id)
|
await _mark_connector_auth_expired(connector_id)
|
||||||
return tools
|
logger.error(
|
||||||
|
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
try:
|
||||||
|
server_info, tool_definitions = await _discover(fresh_headers)
|
||||||
|
headers = fresh_headers
|
||||||
|
logger.info(
|
||||||
|
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
except Exception as retry_err:
|
||||||
|
logger.exception(
|
||||||
|
"HTTP MCP discovery for connector %d still failing after refresh: %s",
|
||||||
|
connector_id,
|
||||||
|
retry_err,
|
||||||
|
)
|
||||||
|
if _is_auth_error(retry_err):
|
||||||
|
await _mark_connector_auth_expired(connector_id)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
await write_cached_tools(
|
||||||
|
connector_id,
|
||||||
|
tool_definitions,
|
||||||
|
server_name=server_info.get("name"),
|
||||||
|
server_version=server_info.get("version"),
|
||||||
|
transport=server_config.get("transport", "streamable-http"),
|
||||||
|
)
|
||||||
|
|
||||||
total_discovered = len(tool_definitions)
|
total_discovered = len(tool_definitions)
|
||||||
|
|
||||||
|
|
@ -792,14 +846,25 @@ async def _maybe_refresh_mcp_oauth_token(
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return server_config
|
return server_config
|
||||||
|
|
||||||
|
refresh_start = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
new_access = await _refresh_connector_token(session, connector)
|
new_access = await _refresh_connector_token(session, connector)
|
||||||
if not new_access:
|
if not new_access:
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=no_token",
|
||||||
|
connector.id,
|
||||||
|
time.perf_counter() - refresh_start,
|
||||||
|
)
|
||||||
return server_config
|
return server_config
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Proactively refreshed MCP OAuth token for connector %s", connector.id
|
"Proactively refreshed MCP OAuth token for connector %s", connector.id
|
||||||
)
|
)
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=refreshed",
|
||||||
|
connector.id,
|
||||||
|
time.perf_counter() - refresh_start,
|
||||||
|
)
|
||||||
|
|
||||||
refreshed_config = dict(server_config)
|
refreshed_config = dict(server_config)
|
||||||
refreshed_config["headers"] = {
|
refreshed_config["headers"] = {
|
||||||
|
|
@ -809,6 +874,11 @@ async def _maybe_refresh_mcp_oauth_token(
|
||||||
return refreshed_config
|
return refreshed_config
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=failed",
|
||||||
|
connector.id,
|
||||||
|
time.perf_counter() - refresh_start,
|
||||||
|
)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to refresh MCP OAuth token for connector %s",
|
"Failed to refresh MCP OAuth token for connector %s",
|
||||||
connector.id,
|
connector.id,
|
||||||
|
|
@ -937,6 +1007,94 @@ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||||
_mcp_tools_cache.clear()
|
_mcp_tools_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
async def discover_single_mcp_connector(connector_id: int) -> None:
|
||||||
|
"""Force live MCP discovery for one connector so its ``cached_tools`` row is fresh.
|
||||||
|
|
||||||
|
``_load_http_mcp_tools`` persists ``cached_tools`` as a side effect of any
|
||||||
|
live discovery; passing ``cached_tools=None`` here guarantees we go to the
|
||||||
|
network. The returned wrappers are discarded — the in-process LRU is
|
||||||
|
rebuilt lazily on the next user query. Stdio connectors are not cached and
|
||||||
|
are skipped.
|
||||||
|
"""
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
|
started = time.perf_counter()
|
||||||
|
try:
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
connector = await session.get(SearchSourceConnector, connector_id)
|
||||||
|
if connector is None:
|
||||||
|
logger.info(
|
||||||
|
"discover_single_mcp_connector: connector %d not found",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
cfg = connector.config or {}
|
||||||
|
server_config = cfg.get("server_config", {})
|
||||||
|
if not server_config or not isinstance(server_config, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
transport = server_config.get("transport", "stdio")
|
||||||
|
if transport not in ("streamable-http", "http", "sse"):
|
||||||
|
return
|
||||||
|
|
||||||
|
if cfg.get("mcp_oauth"):
|
||||||
|
server_config = await _maybe_refresh_mcp_oauth_token(
|
||||||
|
session, connector, cfg, server_config
|
||||||
|
)
|
||||||
|
cfg = connector.config or {}
|
||||||
|
server_config = _inject_oauth_headers(cfg, server_config)
|
||||||
|
if server_config is None:
|
||||||
|
logger.info(
|
||||||
|
"discover_single_mcp_connector: OAuth token unavailable for connector %d",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
ct = (
|
||||||
|
connector.connector_type.value
|
||||||
|
if hasattr(connector.connector_type, "value")
|
||||||
|
else str(connector.connector_type)
|
||||||
|
)
|
||||||
|
svc_cfg = get_service_by_connector_type(ct)
|
||||||
|
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
||||||
|
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
||||||
|
|
||||||
|
await asyncio.wait_for(
|
||||||
|
_load_http_mcp_tools(
|
||||||
|
connector.id,
|
||||||
|
connector.name,
|
||||||
|
server_config,
|
||||||
|
trusted_tools=cfg.get("trusted_tools", []),
|
||||||
|
allowed_tools=allowed_tools,
|
||||||
|
readonly_tools=readonly_tools,
|
||||||
|
tool_name_prefix=None,
|
||||||
|
is_generic_mcp=svc_cfg is None,
|
||||||
|
bypass_internal_hitl=True,
|
||||||
|
cached_tools=None,
|
||||||
|
),
|
||||||
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_prefetch] connector=%s elapsed=%.3fs",
|
||||||
|
connector_id,
|
||||||
|
time.perf_counter() - started,
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"discover_single_mcp_connector: connector %d timed out after %ds",
|
||||||
|
connector_id,
|
||||||
|
_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"discover_single_mcp_connector: failed for connector %d",
|
||||||
|
connector_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def load_mcp_tools(
|
async def load_mcp_tools(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
|
@ -1063,6 +1221,7 @@ async def load_mcp_tools(
|
||||||
"tool_name_prefix": tool_name_prefix,
|
"tool_name_prefix": tool_name_prefix,
|
||||||
"transport": server_config.get("transport", "stdio"),
|
"transport": server_config.get("transport", "stdio"),
|
||||||
"is_generic_mcp": svc_cfg is None,
|
"is_generic_mcp": svc_cfg is None,
|
||||||
|
"cached_tools": read_cached_tools(connector),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1074,9 +1233,12 @@ async def load_mcp_tools(
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
|
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
|
||||||
|
discover_start = time.perf_counter()
|
||||||
|
transport = task["transport"]
|
||||||
|
cached_tools = task.get("cached_tools")
|
||||||
try:
|
try:
|
||||||
if task["transport"] in ("streamable-http", "http", "sse"):
|
if transport in ("streamable-http", "http", "sse"):
|
||||||
return await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
_load_http_mcp_tools(
|
_load_http_mcp_tools(
|
||||||
task["connector_id"],
|
task["connector_id"],
|
||||||
task["connector_name"],
|
task["connector_name"],
|
||||||
|
|
@ -1087,11 +1249,12 @@ async def load_mcp_tools(
|
||||||
tool_name_prefix=task["tool_name_prefix"],
|
tool_name_prefix=task["tool_name_prefix"],
|
||||||
is_generic_mcp=task.get("is_generic_mcp", False),
|
is_generic_mcp=task.get("is_generic_mcp", False),
|
||||||
bypass_internal_hitl=bypass_internal_hitl,
|
bypass_internal_hitl=bypass_internal_hitl,
|
||||||
|
cached_tools=cached_tools,
|
||||||
),
|
),
|
||||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
_load_stdio_mcp_tools(
|
_load_stdio_mcp_tools(
|
||||||
task["connector_id"],
|
task["connector_id"],
|
||||||
task["connector_name"],
|
task["connector_name"],
|
||||||
|
|
@ -1101,7 +1264,24 @@ async def load_mcp_tools(
|
||||||
),
|
),
|
||||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs cache=%s",
|
||||||
|
task["connector_id"],
|
||||||
|
task["connector_name"],
|
||||||
|
transport,
|
||||||
|
len(result),
|
||||||
|
time.perf_counter() - discover_start,
|
||||||
|
"hit" if cached_tools is not None else "miss",
|
||||||
|
)
|
||||||
|
return result
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_discover] connector=%s name=%r transport=%s elapsed=%.3fs outcome=timeout",
|
||||||
|
task["connector_id"],
|
||||||
|
task["connector_name"],
|
||||||
|
transport,
|
||||||
|
time.perf_counter() - discover_start,
|
||||||
|
)
|
||||||
logger.error(
|
logger.error(
|
||||||
"MCP connector %d timed out after %ds during discovery",
|
"MCP connector %d timed out after %ds during discovery",
|
||||||
task["connector_id"],
|
task["connector_id"],
|
||||||
|
|
@ -1109,6 +1289,13 @@ async def load_mcp_tools(
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_discover] connector=%s name=%r transport=%s elapsed=%.3fs outcome=error",
|
||||||
|
task["connector_id"],
|
||||||
|
task["connector_name"],
|
||||||
|
transport,
|
||||||
|
time.perf_counter() - discover_start,
|
||||||
|
)
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to load tools from MCP connector %d: %s",
|
"Failed to load tools from MCP connector %d: %s",
|
||||||
task["connector_id"],
|
task["connector_id"],
|
||||||
|
|
@ -1116,7 +1303,14 @@ async def load_mcp_tools(
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
gather_start = time.perf_counter()
|
||||||
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
|
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
|
||||||
|
_perf_log.info(
|
||||||
|
"[mcp_discover] gather_wall=%.3fs connectors=%d total_tools=%d",
|
||||||
|
time.perf_counter() - gather_start,
|
||||||
|
len(discovery_tasks),
|
||||||
|
sum(len(r) for r in results),
|
||||||
|
)
|
||||||
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
|
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
|
||||||
|
|
||||||
_mcp_tools_cache[cache_key] = (now, tools)
|
_mcp_tools_cache[cache_key] = (now, tools)
|
||||||
|
|
|
||||||
145
surfsense_backend/app/agents/new_chat/tools/mcp_tools_cache.py
Normal file
145
surfsense_backend/app/agents/new_chat/tools/mcp_tools_cache.py
Normal file
|
|
@ -0,0 +1,145 @@
|
||||||
|
"""Persist MCP ``list_tools`` results in ``SearchSourceConnector.config.cached_tools``."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector, async_session_maker
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_pending_prefetch_tasks: set[asyncio.Task[None]] = set()
|
||||||
|
|
||||||
|
|
||||||
|
class CachedMCPToolDef(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str = ""
|
||||||
|
input_schema: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class CachedMCPTools(BaseModel):
|
||||||
|
discovered_at: datetime
|
||||||
|
server_version: str | None = None
|
||||||
|
server_name: str | None = None
|
||||||
|
transport: str | None = None
|
||||||
|
tools: list[CachedMCPToolDef]
|
||||||
|
|
||||||
|
|
||||||
|
def read_cached_tools(connector: SearchSourceConnector) -> CachedMCPTools | None:
|
||||||
|
"""Return parsed cached tools or ``None`` if missing / corrupt (caller falls back to live discovery)."""
|
||||||
|
cfg = connector.config or {}
|
||||||
|
raw = cfg.get("cached_tools")
|
||||||
|
if not raw or not isinstance(raw, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return CachedMCPTools.model_validate(raw)
|
||||||
|
except ValidationError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"MCP connector %d has corrupt cached_tools — falling back to live discovery: %s",
|
||||||
|
connector.id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def write_cached_tools(
|
||||||
|
connector_id: int,
|
||||||
|
tool_definitions: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
server_name: str | None = None,
|
||||||
|
server_version: str | None = None,
|
||||||
|
transport: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Best-effort persist; uses its own session so a write failure cannot poison the caller's transaction."""
|
||||||
|
payload = CachedMCPTools(
|
||||||
|
discovered_at=datetime.now(UTC),
|
||||||
|
server_version=server_version,
|
||||||
|
server_name=server_name,
|
||||||
|
transport=transport,
|
||||||
|
tools=[CachedMCPToolDef.model_validate(td) for td in tool_definitions],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == connector_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
if connector is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
cfg = dict(connector.config or {})
|
||||||
|
cfg["cached_tools"] = payload.model_dump(mode="json")
|
||||||
|
connector.config = cfg
|
||||||
|
flag_modified(connector, "config")
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Persisted cached_tools for MCP connector %d (%d tools)",
|
||||||
|
connector_id,
|
||||||
|
len(payload.tools),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to persist cached_tools for MCP connector %d",
|
||||||
|
connector_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_mcp_tools_cache_for_connector(
|
||||||
|
connector_id: int,
|
||||||
|
search_space_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Maintain the MCP tool cache after a single-connector lifecycle event.
|
||||||
|
|
||||||
|
Synchronously evicts the in-process LRU for the connector's search space
|
||||||
|
(LRU keys are per-space, so eviction cannot be scoped finer), then schedules
|
||||||
|
a background live discovery for this connector alone so its persisted
|
||||||
|
``cached_tools`` row is refreshed before the next user query.
|
||||||
|
|
||||||
|
Idempotent. Eviction is best-effort; prefetch is best-effort and only runs
|
||||||
|
when an event loop is available. Neither path raises.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||||
|
|
||||||
|
invalidate_mcp_tools_cache(search_space_id)
|
||||||
|
except Exception:
|
||||||
|
logger.debug(
|
||||||
|
"MCP in-process cache eviction skipped for space %d",
|
||||||
|
search_space_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
return
|
||||||
|
|
||||||
|
task = loop.create_task(_run_connector_prefetch(connector_id))
|
||||||
|
_pending_prefetch_tasks.add(task)
|
||||||
|
task.add_done_callback(_pending_prefetch_tasks.discard)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_connector_prefetch(connector_id: int) -> None:
|
||||||
|
from app.agents.new_chat.tools.mcp_tool import discover_single_mcp_connector
|
||||||
|
|
||||||
|
try:
|
||||||
|
await discover_single_mcp_connector(connector_id)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"MCP background prefetch failed for connector_id=%d",
|
||||||
|
connector_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
@ -1,369 +1,53 @@
|
||||||
"""Markdown-document memory tool for the SurfSense agent.
|
"""Memory update tools backed by the canonical memory service."""
|
||||||
|
|
||||||
Replaces the old row-per-fact save_memory / recall_memory tools with a single
|
|
||||||
update_memory tool that overwrites a freeform markdown TEXT column. The LLM
|
|
||||||
always sees the current memory in <user_memory> / <team_memory> tags injected
|
|
||||||
by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
|
|
||||||
|
|
||||||
Overflow handling:
|
|
||||||
- Soft limit (18K chars): a warning is returned telling the agent to
|
|
||||||
consolidate on the next update.
|
|
||||||
- Hard limit (25K chars): a forced LLM-driven rewrite compresses the document.
|
|
||||||
If it still exceeds the limit after rewriting, the save is rejected.
|
|
||||||
- Diff validation: warns when entire ``##`` sections are dropped or when the
|
|
||||||
document shrinks by more than 60%.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
from typing import Any
|
||||||
from typing import Any, Literal
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import SearchSpace, User, async_session_maker
|
from app.db import async_session_maker
|
||||||
from app.utils.content_utils import extract_text_content
|
from app.services.memory import MemoryScope, save_memory
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MEMORY_SOFT_LIMIT = 18_000
|
|
||||||
MEMORY_HARD_LIMIT = 25_000
|
|
||||||
|
|
||||||
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
|
||||||
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
|
|
||||||
|
|
||||||
_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
|
|
||||||
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
|
|
||||||
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Diff validation
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_headings(memory: str) -> set[str]:
|
|
||||||
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
|
|
||||||
return set(_SECTION_HEADING_RE.findall(memory))
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_heading(heading: str) -> str:
|
|
||||||
"""Normalize heading text for robust scope checks."""
|
|
||||||
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_memory_scope(
|
|
||||||
content: str, scope: Literal["user", "team"]
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Reject personal-only markers ([pref], [instr]) in team memory."""
|
|
||||||
if scope != "team":
|
|
||||||
return None
|
|
||||||
|
|
||||||
markers = set(_MARKER_RE.findall(content))
|
|
||||||
leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
|
|
||||||
if leaked:
|
|
||||||
tags = ", ".join(f"[{m}]" for m in leaked)
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": (
|
|
||||||
f"Team memory cannot include personal markers: {tags}. "
|
|
||||||
"Use [fact] only in team memory."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_bullet_format(content: str) -> list[str]:
|
|
||||||
"""Return warnings for bullet lines that don't match the required format.
|
|
||||||
|
|
||||||
Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
|
|
||||||
"""
|
|
||||||
warnings: list[str] = []
|
|
||||||
for line in content.splitlines():
|
|
||||||
stripped = line.strip()
|
|
||||||
if not stripped.startswith("- "):
|
|
||||||
continue
|
|
||||||
if not _BULLET_FORMAT_RE.match(stripped):
|
|
||||||
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
|
|
||||||
warnings.append(f"Malformed bullet: {short}")
|
|
||||||
return warnings
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
|
||||||
"""Return a list of warning strings about suspicious changes."""
|
|
||||||
if not old_memory:
|
|
||||||
return []
|
|
||||||
|
|
||||||
warnings: list[str] = []
|
|
||||||
old_headings = _extract_headings(old_memory)
|
|
||||||
new_headings = _extract_headings(new_memory)
|
|
||||||
dropped = old_headings - new_headings
|
|
||||||
if dropped:
|
|
||||||
names = ", ".join(sorted(dropped))
|
|
||||||
warnings.append(
|
|
||||||
f"Sections removed: {names}. "
|
|
||||||
"If unintentional, the user can restore from the settings page."
|
|
||||||
)
|
|
||||||
|
|
||||||
old_len = len(old_memory)
|
|
||||||
new_len = len(new_memory)
|
|
||||||
if old_len > 0 and new_len < old_len * 0.4:
|
|
||||||
warnings.append(
|
|
||||||
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
|
|
||||||
"Possible data loss."
|
|
||||||
)
|
|
||||||
return warnings
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Size validation & soft warning
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_memory_size(content: str) -> dict[str, Any] | None:
|
|
||||||
"""Return an error/warning dict if *content* is too large, else None."""
|
|
||||||
length = len(content)
|
|
||||||
if length > MEMORY_HARD_LIMIT:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": (
|
|
||||||
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
|
|
||||||
f"({length:,} chars). Consolidate by merging related items, "
|
|
||||||
"removing outdated entries, and shortening descriptions. "
|
|
||||||
"Then call update_memory again."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _soft_warning(content: str) -> str | None:
|
|
||||||
"""Return a warning string if content exceeds the soft limit."""
|
|
||||||
length = len(content)
|
|
||||||
if length > MEMORY_SOFT_LIMIT:
|
|
||||||
return (
|
|
||||||
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
|
|
||||||
"Consolidate by merging related items and removing less important "
|
|
||||||
"entries on your next update."
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Forced rewrite when memory exceeds the hard limit
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
_FORCED_REWRITE_PROMPT = """\
|
|
||||||
You are a memory curator. The following memory document exceeds the character \
|
|
||||||
limit and must be shortened.
|
|
||||||
|
|
||||||
RULES:
|
|
||||||
1. Rewrite the document to be under {target} characters.
|
|
||||||
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
|
|
||||||
or rename headings to consolidate, but keep names personal and descriptive.
|
|
||||||
3. Priority for keeping content: [instr] > [pref] > [fact].
|
|
||||||
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
|
|
||||||
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
|
|
||||||
6. Preserve the user's first name in entries — do not replace it with "the user".
|
|
||||||
7. Output ONLY the consolidated markdown — no explanations, no wrapping.
|
|
||||||
|
|
||||||
<memory_document>
|
|
||||||
{content}
|
|
||||||
</memory_document>"""
|
|
||||||
|
|
||||||
|
|
||||||
async def _forced_rewrite(content: str, llm: Any) -> str | None:
|
|
||||||
"""Use a focused LLM call to compress *content* under the hard limit.
|
|
||||||
|
|
||||||
Returns the rewritten string, or ``None`` if the call fails.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
prompt = _FORCED_REWRITE_PROMPT.format(
|
|
||||||
target=MEMORY_HARD_LIMIT, content=content
|
|
||||||
)
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[HumanMessage(content=prompt)],
|
|
||||||
config={"tags": ["surfsense:internal"]},
|
|
||||||
)
|
|
||||||
text = extract_text_content(response.content).strip()
|
|
||||||
if not text:
|
|
||||||
logger.warning("Forced rewrite returned empty text; aborting rewrite")
|
|
||||||
return None
|
|
||||||
return text
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Forced rewrite LLM call failed")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Shared save-and-respond logic
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
async def _save_memory(
|
|
||||||
*,
|
|
||||||
updated_memory: str,
|
|
||||||
old_memory: str | None,
|
|
||||||
llm: Any | None,
|
|
||||||
apply_fn,
|
|
||||||
commit_fn,
|
|
||||||
rollback_fn,
|
|
||||||
label: str,
|
|
||||||
scope: Literal["user", "team"],
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Validate, optionally force-rewrite if over the hard limit, save, and
|
|
||||||
return a response dict.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
updated_memory : str
|
|
||||||
The new document the agent submitted.
|
|
||||||
old_memory : str | None
|
|
||||||
The previously persisted document (for diff checks).
|
|
||||||
llm : Any | None
|
|
||||||
LLM instance for forced rewrite (may be ``None``).
|
|
||||||
apply_fn : callable(str) -> None
|
|
||||||
Callback that sets the new memory on the ORM object.
|
|
||||||
commit_fn : coroutine
|
|
||||||
``session.commit``.
|
|
||||||
rollback_fn : coroutine
|
|
||||||
``session.rollback``.
|
|
||||||
label : str
|
|
||||||
Human label for log messages (e.g. "user memory", "team memory").
|
|
||||||
"""
|
|
||||||
if not isinstance(updated_memory, str):
|
|
||||||
logger.warning(
|
|
||||||
"Refusing non-string memory payload (type=%s)",
|
|
||||||
type(updated_memory).__name__,
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "Internal error: memory payload must be a string.",
|
|
||||||
}
|
|
||||||
|
|
||||||
content = updated_memory
|
|
||||||
|
|
||||||
# --- forced rewrite if over the hard limit ---
|
|
||||||
if len(content) > MEMORY_HARD_LIMIT and llm is not None:
|
|
||||||
rewritten = await _forced_rewrite(content, llm)
|
|
||||||
if rewritten is not None and len(rewritten) < len(content):
|
|
||||||
content = rewritten
|
|
||||||
|
|
||||||
# --- hard-limit gate (reject if still too large after rewrite) ---
|
|
||||||
size_err = _validate_memory_size(content)
|
|
||||||
if size_err:
|
|
||||||
return size_err
|
|
||||||
|
|
||||||
scope_err = _validate_memory_scope(content, scope)
|
|
||||||
if scope_err:
|
|
||||||
return scope_err
|
|
||||||
|
|
||||||
# --- persist ---
|
|
||||||
try:
|
|
||||||
apply_fn(content)
|
|
||||||
await commit_fn()
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Failed to update %s: %s", label, e)
|
|
||||||
await rollback_fn()
|
|
||||||
return {"status": "error", "message": f"Failed to update {label}: {e}"}
|
|
||||||
|
|
||||||
# --- build response ---
|
|
||||||
resp: dict[str, Any] = {
|
|
||||||
"status": "saved",
|
|
||||||
"message": f"{label.capitalize()} updated.",
|
|
||||||
}
|
|
||||||
|
|
||||||
if content is not updated_memory:
|
|
||||||
resp["notice"] = "Memory was automatically rewritten to fit within limits."
|
|
||||||
|
|
||||||
diff_warnings = _validate_diff(old_memory, content)
|
|
||||||
if diff_warnings:
|
|
||||||
resp["diff_warnings"] = diff_warnings
|
|
||||||
|
|
||||||
format_warnings = _validate_bullet_format(content)
|
|
||||||
if format_warnings:
|
|
||||||
resp["format_warnings"] = format_warnings
|
|
||||||
|
|
||||||
warning = _soft_warning(content)
|
|
||||||
if warning:
|
|
||||||
resp["warning"] = warning
|
|
||||||
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Tool factories
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def create_update_memory_tool(
|
def create_update_memory_tool(
|
||||||
user_id: str | UUID,
|
user_id: str | UUID,
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
llm: Any | None = None,
|
llm: Any | None = None,
|
||||||
):
|
):
|
||||||
"""Factory function to create the user-memory update tool.
|
"""Factory for the user-memory update tool.
|
||||||
|
|
||||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
Uses a fresh short-lived session per call so compiled-agent caches never
|
||||||
:data:`async_session_maker` so the closure is safe to share across
|
retain a stale request-scoped session.
|
||||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
|
||||||
session here would surface stale/closed sessions on cache hits.
|
|
||||||
The session's bound ``commit``/``rollback`` methods are captured at
|
|
||||||
call time, after ``async with`` has bound ``db_session`` locally.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: ID of the user whose memory document is being updated.
|
|
||||||
db_session: Reserved for registry compatibility. Per-call sessions
|
|
||||||
are opened via :data:`async_session_maker` inside the tool body.
|
|
||||||
llm: Optional LLM for the forced-rewrite path.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured update_memory tool for the user-memory scope.
|
|
||||||
"""
|
"""
|
||||||
del db_session # per-call session — see docstring
|
del db_session
|
||||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||||
"""Update the user's personal memory document.
|
"""Update the user's personal memory document.
|
||||||
|
|
||||||
Your current memory is shown in <user_memory> in the system prompt.
|
The current memory is shown in <user_memory>. Pass the FULL updated
|
||||||
When the user shares important long-term information (preferences,
|
markdown document, not a diff.
|
||||||
facts, instructions, context), rewrite the memory document to include
|
|
||||||
the new information. Merge new facts with existing ones, update
|
|
||||||
contradictions, remove outdated entries, and keep it concise.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
updated_memory: The FULL updated markdown document (not a diff).
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with async_session_maker() as db_session:
|
async with async_session_maker() as db_session:
|
||||||
result = await db_session.execute(select(User).where(User.id == uid))
|
result = await save_memory(
|
||||||
user = result.scalars().first()
|
scope=MemoryScope.USER,
|
||||||
if not user:
|
target_id=uid,
|
||||||
return {"status": "error", "message": "User not found."}
|
content=updated_memory,
|
||||||
|
session=db_session,
|
||||||
old_memory = user.memory_md
|
|
||||||
|
|
||||||
return await _save_memory(
|
|
||||||
updated_memory=updated_memory,
|
|
||||||
old_memory=old_memory,
|
|
||||||
llm=llm,
|
llm=llm,
|
||||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
|
||||||
commit_fn=db_session.commit,
|
|
||||||
rollback_fn=db_session.rollback,
|
|
||||||
label="memory",
|
|
||||||
scope="user",
|
|
||||||
)
|
)
|
||||||
|
return result.to_dict()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to update user memory: %s", e)
|
logger.exception("Failed to update user memory: %s", e)
|
||||||
return {
|
return {"status": "error", "message": f"Failed to update memory: {e}"}
|
||||||
"status": "error",
|
|
||||||
"message": f"Failed to update memory: {e}",
|
|
||||||
}
|
|
||||||
|
|
||||||
return update_memory
|
return update_memory
|
||||||
|
|
||||||
|
|
@ -373,64 +57,26 @@ def create_update_team_memory_tool(
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
llm: Any | None = None,
|
llm: Any | None = None,
|
||||||
):
|
):
|
||||||
"""Factory function to create the team-memory update tool.
|
"""Factory for the team-memory update tool."""
|
||||||
|
del db_session
|
||||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
|
||||||
:data:`async_session_maker` so the closure is safe to share across
|
|
||||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
|
||||||
session here would surface stale/closed sessions on cache hits.
|
|
||||||
The session's bound ``commit``/``rollback`` methods are captured at
|
|
||||||
call time, after ``async with`` has bound ``db_session`` locally.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
search_space_id: ID of the search space whose team memory is being
|
|
||||||
updated.
|
|
||||||
db_session: Reserved for registry compatibility. Per-call sessions
|
|
||||||
are opened via :data:`async_session_maker` inside the tool body.
|
|
||||||
llm: Optional LLM for the forced-rewrite path.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured update_memory tool for the team-memory scope.
|
|
||||||
"""
|
|
||||||
del db_session # per-call session — see docstring
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||||
"""Update the team's shared memory document for this search space.
|
"""Update the team's shared memory document for this search space.
|
||||||
|
|
||||||
Your current team memory is shown in <team_memory> in the system
|
The current team memory is shown in <team_memory>. Pass the FULL updated
|
||||||
prompt. When the team shares important long-term information
|
markdown document, not a diff.
|
||||||
(decisions, conventions, key facts, priorities), rewrite the memory
|
|
||||||
document to include the new information. Merge new facts with
|
|
||||||
existing ones, update contradictions, remove outdated entries, and
|
|
||||||
keep it concise.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
updated_memory: The FULL updated markdown document (not a diff).
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with async_session_maker() as db_session:
|
async with async_session_maker() as db_session:
|
||||||
result = await db_session.execute(
|
result = await save_memory(
|
||||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
scope=MemoryScope.TEAM,
|
||||||
)
|
target_id=search_space_id,
|
||||||
space = result.scalars().first()
|
content=updated_memory,
|
||||||
if not space:
|
session=db_session,
|
||||||
return {"status": "error", "message": "Search space not found."}
|
|
||||||
|
|
||||||
old_memory = space.shared_memory_md
|
|
||||||
|
|
||||||
return await _save_memory(
|
|
||||||
updated_memory=updated_memory,
|
|
||||||
old_memory=old_memory,
|
|
||||||
llm=llm,
|
llm=llm,
|
||||||
apply_fn=lambda content: setattr(
|
|
||||||
space, "shared_memory_md", content
|
|
||||||
),
|
|
||||||
commit_fn=db_session.commit,
|
|
||||||
rollback_fn=db_session.rollback,
|
|
||||||
label="team memory",
|
|
||||||
scope="team",
|
|
||||||
)
|
)
|
||||||
|
return result.to_dict()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to update team memory: %s", e)
|
logger.exception("Failed to update team memory: %s", e)
|
||||||
return {
|
return {
|
||||||
|
|
@ -439,3 +85,9 @@ def create_update_team_memory_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
return update_memory
|
return update_memory
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"create_update_memory_tool",
|
||||||
|
"create_update_team_memory_tool",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -110,6 +110,19 @@ def load_global_llm_configs():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to score global LLM configs: {e}")
|
print(f"Warning: Failed to score global LLM configs: {e}")
|
||||||
|
|
||||||
|
# Planner LLM is a singleton role. If an operator accidentally
|
||||||
|
# marks multiple configs ``is_planner: true``, only the first one
|
||||||
|
# is used at runtime — surface the others at startup so the
|
||||||
|
# mistake is caught before traffic, not silently buried.
|
||||||
|
planner_cfgs = [c for c in configs if c.get("is_planner") is True]
|
||||||
|
if len(planner_cfgs) > 1:
|
||||||
|
extra_ids = [c.get("id") for c in planner_cfgs[1:]]
|
||||||
|
print(
|
||||||
|
"Warning: Multiple global LLM configs marked is_planner=true "
|
||||||
|
f"(ids {[c.get('id') for c in planner_cfgs]}); using id "
|
||||||
|
f"{planner_cfgs[0].get('id')} and ignoring {extra_ids}"
|
||||||
|
)
|
||||||
|
|
||||||
return configs
|
return configs
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to load global LLM configs: {e}")
|
print(f"Warning: Failed to load global LLM configs: {e}")
|
||||||
|
|
|
||||||
|
|
@ -258,6 +258,45 @@ global_llm_configs:
|
||||||
use_default_system_instructions: true
|
use_default_system_instructions: true
|
||||||
citations_enabled: true
|
citations_enabled: true
|
||||||
|
|
||||||
|
# Example: Planner LLM - small, fast model used for internal utility tasks
|
||||||
|
#
|
||||||
|
# The PLANNER role handles short, structured internal calls (KB query
|
||||||
|
# rewriting, date extraction, recency classification, etc.) that don't
|
||||||
|
# need frontier-tier capability. Pointing the planner at a cheap+fast
|
||||||
|
# model (gpt-4o-mini, Claude Haiku, Azure gpt-5.x-nano, Groq Llama, ...)
|
||||||
|
# typically saves 500ms-1.5s per turn vs. routing those same internal
|
||||||
|
# calls through the user's chat model.
|
||||||
|
#
|
||||||
|
# Activation:
|
||||||
|
# - Mark EXACTLY ONE global config with ``is_planner: true``.
|
||||||
|
# - If multiple are marked, the first one wins and a WARNING is logged.
|
||||||
|
# - If none is marked, every internal call falls back to the user's
|
||||||
|
# chat LLM (same behavior as before this flag existed).
|
||||||
|
#
|
||||||
|
# This config is operator-only — it is NOT exposed in the user-facing
|
||||||
|
# model selector, never billed against premium quota, and the
|
||||||
|
# billing_tier / anonymous_enabled fields below are ignored.
|
||||||
|
- id: -9
|
||||||
|
name: "Global Planner (GPT-4o mini)"
|
||||||
|
description: "Internal-only planner LLM for query rewriting and classification"
|
||||||
|
is_planner: true
|
||||||
|
billing_tier: "free"
|
||||||
|
anonymous_enabled: false
|
||||||
|
seo_enabled: false
|
||||||
|
quota_reserve_tokens: 1000
|
||||||
|
provider: "OPENAI"
|
||||||
|
model_name: "gpt-4o-mini"
|
||||||
|
api_key: "sk-your-openai-api-key-here"
|
||||||
|
api_base: ""
|
||||||
|
rpm: 3500
|
||||||
|
tpm: 200000
|
||||||
|
litellm_params:
|
||||||
|
temperature: 0
|
||||||
|
max_tokens: 1000
|
||||||
|
system_instructions: ""
|
||||||
|
use_default_system_instructions: true
|
||||||
|
citations_enabled: false
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# OpenRouter Integration
|
# OpenRouter Integration
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -493,6 +532,20 @@ global_vision_llm_configs:
|
||||||
# - Lower temperature (0.3) is recommended for accurate screenshot analysis
|
# - Lower temperature (0.3) is recommended for accurate screenshot analysis
|
||||||
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
|
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
|
||||||
#
|
#
|
||||||
|
# PLANNER LLM NOTES:
|
||||||
|
# - is_planner: true marks a config as the internal-only planner LLM (small,
|
||||||
|
# fast model used for KB query rewriting, date extraction, recency
|
||||||
|
# classification, etc.). Only one config may carry this flag — if
|
||||||
|
# multiple do, the first one wins and a startup WARNING is logged.
|
||||||
|
# - When no config is marked is_planner, every internal utility call falls
|
||||||
|
# back to the user's chat LLM (the historical behavior).
|
||||||
|
# - Planner configs are NOT shown in the user-facing model selector and
|
||||||
|
# are NOT billed against the user's premium quota. Their billing_tier,
|
||||||
|
# anonymous_enabled, seo_* fields are ignored.
|
||||||
|
# - Recommended models: gpt-4o-mini, claude-3-5-haiku, gemini-1.5-flash,
|
||||||
|
# azure gpt-5.x-nano, groq llama3-8b — anything <200ms p50 on a 1-2k
|
||||||
|
# prompt. Frontier models here defeat the purpose of the flag.
|
||||||
|
#
|
||||||
# TOKEN QUOTA & ANONYMOUS ACCESS NOTES:
|
# TOKEN QUOTA & ANONYMOUS ACCESS NOTES:
|
||||||
# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota.
|
# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota.
|
||||||
# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog.
|
# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog.
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,7 @@ from .search_spaces_routes import router as search_spaces_router
|
||||||
from .slack_add_connector_route import router as slack_add_connector_router
|
from .slack_add_connector_route import router as slack_add_connector_router
|
||||||
from .stripe_routes import router as stripe_router
|
from .stripe_routes import router as stripe_router
|
||||||
from .surfsense_docs_routes import router as surfsense_docs_router
|
from .surfsense_docs_routes import router as surfsense_docs_router
|
||||||
|
from .team_memory_routes import router as team_memory_router
|
||||||
from .teams_add_connector_route import router as teams_add_connector_router
|
from .teams_add_connector_route import router as teams_add_connector_router
|
||||||
from .video_presentations_routes import router as video_presentations_router
|
from .video_presentations_routes import router as video_presentations_router
|
||||||
from .vision_llm_routes import router as vision_llm_router
|
from .vision_llm_routes import router as vision_llm_router
|
||||||
|
|
@ -117,3 +118,4 @@ router.include_router(stripe_router) # Stripe checkout for additional page pack
|
||||||
router.include_router(youtube_router) # YouTube playlist resolution
|
router.include_router(youtube_router) # YouTube playlist resolution
|
||||||
router.include_router(prompts_router)
|
router.include_router(prompts_router)
|
||||||
router.include_router(memory_router) # User personal memory (memory.md style)
|
router.include_router(memory_router) # User personal memory (memory.md style)
|
||||||
|
router.include_router(team_memory_router) # Search-space team memory
|
||||||
|
|
|
||||||
|
|
@ -428,7 +428,7 @@ async def mcp_oauth_callback(
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(db_connector)
|
await session.refresh(db_connector)
|
||||||
|
|
||||||
_invalidate_cache(space_id)
|
_refresh_mcp_cache(db_connector.id, space_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Re-authenticated %s MCP connector %s for user %s",
|
"Re-authenticated %s MCP connector %s for user %s",
|
||||||
|
|
@ -481,7 +481,7 @@ async def mcp_oauth_callback(
|
||||||
detail="A connector for this service already exists.",
|
detail="A connector for this service already exists.",
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
_invalidate_cache(space_id)
|
_refresh_mcp_cache(new_connector.id, space_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Created %s MCP connector %s for user %s in space %s",
|
"Created %s MCP connector %s for user %s in space %s",
|
||||||
|
|
@ -658,10 +658,17 @@ async def reauth_mcp_service(
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _invalidate_cache(space_id: int) -> None:
|
def _refresh_mcp_cache(connector_id: int, space_id: int) -> None:
|
||||||
try:
|
"""Evict the in-process MCP tool LRU and schedule background prefetch.
|
||||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
|
||||||
|
|
||||||
invalidate_mcp_tools_cache(space_id)
|
Wraps :func:`refresh_mcp_tools_cache_for_connector` so any failure is
|
||||||
|
isolated from the OAuth response flow.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
refresh_mcp_tools_cache_for_connector,
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_mcp_tools_cache_for_connector(connector_id, space_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("MCP cache invalidation skipped", exc_info=True)
|
logger.debug("MCP cache refresh skipped", exc_info=True)
|
||||||
|
|
|
||||||
|
|
@ -1,75 +1,40 @@
|
||||||
"""Routes for user memory management (personal memory.md)."""
|
"""Routes for user memory management."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import (
|
|
||||||
create_chat_litellm_from_agent_config,
|
|
||||||
load_agent_llm_config_for_search_space,
|
|
||||||
)
|
|
||||||
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
|
|
||||||
from app.db import User, get_async_session
|
from app.db import User, get_async_session
|
||||||
|
from app.services.memory import (
|
||||||
|
MemoryRead,
|
||||||
|
MemoryScope,
|
||||||
|
memory_limits,
|
||||||
|
read_memory,
|
||||||
|
reset_memory,
|
||||||
|
save_memory,
|
||||||
|
)
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.content_utils import extract_text_content
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
class MemoryRead(BaseModel):
|
|
||||||
memory_md: str
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryUpdate(BaseModel):
|
class MemoryUpdate(BaseModel):
|
||||||
memory_md: str
|
memory_md: str
|
||||||
|
|
||||||
|
|
||||||
class MemoryEditRequest(BaseModel):
|
|
||||||
query: str
|
|
||||||
search_space_id: int
|
|
||||||
|
|
||||||
|
|
||||||
_MEMORY_EDIT_PROMPT = """\
|
|
||||||
You are a memory editor. The user wants to modify their memory document. \
|
|
||||||
Apply the user's instruction to the existing memory document and output the \
|
|
||||||
FULL updated document.
|
|
||||||
|
|
||||||
RULES:
|
|
||||||
1. If the instruction asks to add something, add it with format: \
|
|
||||||
- (YYYY-MM-DD) [fact|pref|instr] text, under an existing or new ## heading. \
|
|
||||||
Heading names should be personal and descriptive, not generic categories.
|
|
||||||
2. If the instruction asks to remove something, remove the matching entry.
|
|
||||||
3. If the instruction asks to change something, update the matching entry.
|
|
||||||
4. Preserve existing ## headings and all other entries.
|
|
||||||
5. Every bullet must include a marker: [fact], [pref], or [instr].
|
|
||||||
6. Use the user's first name (from <user_name>) in entries instead of "the user".
|
|
||||||
7. Output ONLY the updated markdown — no explanations, no wrapping.
|
|
||||||
|
|
||||||
<user_name>{user_name}</user_name>
|
|
||||||
|
|
||||||
<current_memory>
|
|
||||||
{current_memory}
|
|
||||||
</current_memory>
|
|
||||||
|
|
||||||
<user_instruction>
|
|
||||||
{instruction}
|
|
||||||
</user_instruction>"""
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/users/me/memory", response_model=MemoryRead)
|
@router.get("/users/me/memory", response_model=MemoryRead)
|
||||||
async def get_user_memory(
|
async def get_user_memory(
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
):
|
):
|
||||||
await session.refresh(user, ["memory_md"])
|
memory_md = await read_memory(
|
||||||
return MemoryRead(memory_md=user.memory_md or "")
|
scope=MemoryScope.USER,
|
||||||
|
target_id=user.id,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
return MemoryRead(memory_md=memory_md, limits=memory_limits())
|
||||||
|
|
||||||
|
|
||||||
@router.put("/users/me/memory", response_model=MemoryRead)
|
@router.put("/users/me/memory", response_model=MemoryRead)
|
||||||
|
|
@ -78,73 +43,27 @@ async def update_user_memory(
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
):
|
):
|
||||||
if len(body.memory_md) > MEMORY_HARD_LIMIT:
|
result = await save_memory(
|
||||||
raise HTTPException(
|
scope=MemoryScope.USER,
|
||||||
status_code=400,
|
target_id=user.id,
|
||||||
detail=f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit ({len(body.memory_md):,} chars).",
|
content=body.memory_md,
|
||||||
)
|
session=session,
|
||||||
user.memory_md = body.memory_md
|
)
|
||||||
session.add(user)
|
if result.status == "error":
|
||||||
await session.commit()
|
raise HTTPException(status_code=400, detail=result.message)
|
||||||
await session.refresh(user, ["memory_md"])
|
return MemoryRead(memory_md=result.memory_md, limits=memory_limits())
|
||||||
return MemoryRead(memory_md=user.memory_md or "")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/users/me/memory/edit", response_model=MemoryRead)
|
@router.post("/users/me/memory/reset", response_model=MemoryRead)
|
||||||
async def edit_user_memory(
|
async def reset_user_memory(
|
||||||
body: MemoryEditRequest,
|
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
):
|
):
|
||||||
"""Apply a natural language edit to the user's personal memory via LLM."""
|
result = await reset_memory(
|
||||||
agent_config = await load_agent_llm_config_for_search_space(
|
scope=MemoryScope.USER,
|
||||||
session, body.search_space_id
|
target_id=user.id,
|
||||||
|
session=session,
|
||||||
)
|
)
|
||||||
if not agent_config:
|
if result.status == "error":
|
||||||
raise HTTPException(status_code=500, detail="No LLM configuration available.")
|
raise HTTPException(status_code=400, detail=result.message)
|
||||||
llm = create_chat_litellm_from_agent_config(agent_config)
|
return MemoryRead(memory_md=result.memory_md, limits=memory_limits())
|
||||||
if not llm:
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to create LLM instance.")
|
|
||||||
|
|
||||||
await session.refresh(user, ["memory_md", "display_name"])
|
|
||||||
current_memory = user.memory_md or ""
|
|
||||||
first_name = (
|
|
||||||
user.display_name.strip().split()[0]
|
|
||||||
if user.display_name and user.display_name.strip()
|
|
||||||
else "The user"
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = _MEMORY_EDIT_PROMPT.format(
|
|
||||||
current_memory=current_memory or "(empty)",
|
|
||||||
instruction=body.query,
|
|
||||||
user_name=first_name,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[HumanMessage(content=prompt)],
|
|
||||||
config={"tags": ["surfsense:internal", "memory-edit"]},
|
|
||||||
)
|
|
||||||
updated = extract_text_content(response.content).strip()
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Memory edit LLM call failed: %s", e)
|
|
||||||
raise HTTPException(status_code=500, detail="Memory edit failed.") from e
|
|
||||||
|
|
||||||
if not updated:
|
|
||||||
raise HTTPException(status_code=400, detail="LLM returned empty result.")
|
|
||||||
|
|
||||||
result = await _save_memory(
|
|
||||||
updated_memory=updated,
|
|
||||||
old_memory=current_memory,
|
|
||||||
llm=llm,
|
|
||||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
|
||||||
commit_fn=session.commit,
|
|
||||||
rollback_fn=session.rollback,
|
|
||||||
label="memory",
|
|
||||||
scope="user",
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.get("status") == "error":
|
|
||||||
raise HTTPException(status_code=400, detail=result["message"])
|
|
||||||
|
|
||||||
await session.refresh(user, ["memory_md"])
|
|
||||||
return MemoryRead(memory_md=user.memory_md or "")
|
|
||||||
|
|
|
||||||
|
|
@ -2650,9 +2650,11 @@ async def create_mcp_connector(
|
||||||
f"for user {user.id} in search space {search_space_id}"
|
f"for user {user.id} in search space {search_space_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
refresh_mcp_tools_cache_for_connector,
|
||||||
|
)
|
||||||
|
|
||||||
invalidate_mcp_tools_cache(search_space_id)
|
refresh_mcp_tools_cache_for_connector(db_connector.id, search_space_id)
|
||||||
|
|
||||||
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
|
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
|
||||||
return MCPConnectorRead.from_connector(connector_read)
|
return MCPConnectorRead.from_connector(connector_read)
|
||||||
|
|
@ -2828,9 +2830,11 @@ async def update_mcp_connector(
|
||||||
|
|
||||||
logger.info(f"Updated MCP connector {connector_id}")
|
logger.info(f"Updated MCP connector {connector_id}")
|
||||||
|
|
||||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
refresh_mcp_tools_cache_for_connector,
|
||||||
|
)
|
||||||
|
|
||||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
refresh_mcp_tools_cache_for_connector(connector.id, connector.search_space_id)
|
||||||
|
|
||||||
connector_read = SearchSourceConnectorRead.model_validate(connector)
|
connector_read = SearchSourceConnectorRead.model_validate(connector)
|
||||||
return MCPConnectorRead.from_connector(connector_read)
|
return MCPConnectorRead.from_connector(connector_read)
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,10 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from pydantic import BaseModel as PydanticBaseModel
|
|
||||||
from sqlalchemy import func, update
|
from sqlalchemy import func, update
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import (
|
|
||||||
create_chat_litellm_from_agent_config,
|
|
||||||
load_agent_llm_config_for_search_space,
|
|
||||||
)
|
|
||||||
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ImageGenerationConfig,
|
ImageGenerationConfig,
|
||||||
|
|
@ -35,7 +28,6 @@ from app.schemas import (
|
||||||
SearchSpaceWithStats,
|
SearchSpaceWithStats,
|
||||||
)
|
)
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.content_utils import extract_text_content
|
|
||||||
from app.utils.rbac import check_permission, check_search_space_access
|
from app.utils.rbac import check_permission, check_search_space_access
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -43,34 +35,6 @@ logger = logging.getLogger(__name__)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
class _TeamMemoryEditRequest(PydanticBaseModel):
|
|
||||||
query: str
|
|
||||||
|
|
||||||
|
|
||||||
_TEAM_MEMORY_EDIT_PROMPT = """\
|
|
||||||
You are a memory editor for a team workspace. The user wants to modify the \
|
|
||||||
team's shared memory document. Apply the user's instruction to the existing \
|
|
||||||
memory document and output the FULL updated document.
|
|
||||||
|
|
||||||
RULES:
|
|
||||||
1. If the instruction asks to add something, add it with format: \
|
|
||||||
- (YYYY-MM-DD) [fact] text, under an existing or new ## heading. \
|
|
||||||
Heading names should be descriptive, not generic categories.
|
|
||||||
2. If the instruction asks to remove something, remove the matching entry.
|
|
||||||
3. If the instruction asks to change something, update the matching entry.
|
|
||||||
4. Preserve existing ## headings and all other entries.
|
|
||||||
5. NEVER use [pref] or [instr] markers. Team memory uses [fact] only.
|
|
||||||
6. Output ONLY the updated markdown — no explanations, no wrapping.
|
|
||||||
|
|
||||||
<current_memory>
|
|
||||||
{current_memory}
|
|
||||||
</current_memory>
|
|
||||||
|
|
||||||
<user_instruction>
|
|
||||||
{instruction}
|
|
||||||
</user_instruction>"""
|
|
||||||
|
|
||||||
|
|
||||||
async def create_default_roles_and_membership(
|
async def create_default_roles_and_membership(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
|
@ -294,15 +258,6 @@ async def update_search_space(
|
||||||
|
|
||||||
update_data = search_space_update.model_dump(exclude_unset=True)
|
update_data = search_space_update.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
if (
|
|
||||||
"shared_memory_md" in update_data
|
|
||||||
and len(update_data["shared_memory_md"] or "") > MEMORY_HARD_LIMIT
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"Team memory exceeds {MEMORY_HARD_LIMIT:,} character limit.",
|
|
||||||
)
|
|
||||||
|
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
setattr(db_search_space, key, value)
|
setattr(db_search_space, key, value)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
@ -317,72 +272,6 @@ async def update_search_space(
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/searchspaces/{search_space_id}/memory/edit",
|
|
||||||
response_model=SearchSpaceRead,
|
|
||||||
)
|
|
||||||
async def edit_team_memory(
|
|
||||||
search_space_id: int,
|
|
||||||
body: _TeamMemoryEditRequest,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""Apply a natural language edit to the team memory via LLM."""
|
|
||||||
await check_search_space_access(session, user, search_space_id)
|
|
||||||
|
|
||||||
agent_config = await load_agent_llm_config_for_search_space(
|
|
||||||
session, search_space_id
|
|
||||||
)
|
|
||||||
if not agent_config:
|
|
||||||
raise HTTPException(status_code=500, detail="No LLM configuration available.")
|
|
||||||
llm = create_chat_litellm_from_agent_config(agent_config)
|
|
||||||
if not llm:
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to create LLM instance.")
|
|
||||||
|
|
||||||
result = await session.execute(
|
|
||||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
|
||||||
)
|
|
||||||
db_search_space = result.scalars().first()
|
|
||||||
if not db_search_space:
|
|
||||||
raise HTTPException(status_code=404, detail="Search space not found")
|
|
||||||
|
|
||||||
current_memory = db_search_space.shared_memory_md or ""
|
|
||||||
|
|
||||||
prompt = _TEAM_MEMORY_EDIT_PROMPT.format(
|
|
||||||
current_memory=current_memory or "(empty)",
|
|
||||||
instruction=body.query,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[HumanMessage(content=prompt)],
|
|
||||||
config={"tags": ["surfsense:internal", "memory-edit"]},
|
|
||||||
)
|
|
||||||
updated = extract_text_content(response.content).strip()
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Team memory edit LLM call failed: %s", e)
|
|
||||||
raise HTTPException(status_code=500, detail="Team memory edit failed.") from e
|
|
||||||
|
|
||||||
if not updated:
|
|
||||||
raise HTTPException(status_code=400, detail="LLM returned empty result.")
|
|
||||||
|
|
||||||
save_result = await _save_memory(
|
|
||||||
updated_memory=updated,
|
|
||||||
old_memory=current_memory,
|
|
||||||
llm=llm,
|
|
||||||
apply_fn=lambda content: setattr(db_search_space, "shared_memory_md", content),
|
|
||||||
commit_fn=session.commit,
|
|
||||||
rollback_fn=session.rollback,
|
|
||||||
label="team memory",
|
|
||||||
scope="team",
|
|
||||||
)
|
|
||||||
|
|
||||||
if save_result.get("status") == "error":
|
|
||||||
raise HTTPException(status_code=400, detail=save_result["message"])
|
|
||||||
|
|
||||||
await session.refresh(db_search_space)
|
|
||||||
return db_search_space
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/searchspaces/{search_space_id}/ai-sort")
|
@router.post("/searchspaces/{search_space_id}/ai-sort")
|
||||||
async def trigger_ai_sort(
|
async def trigger_ai_sort(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
|
|
||||||
76
surfsense_backend/app/routes/team_memory_routes.py
Normal file
76
surfsense_backend/app/routes/team_memory_routes.py
Normal file
|
|
@ -0,0 +1,76 @@
|
||||||
|
"""Routes for search-space team memory."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import User, get_async_session
|
||||||
|
from app.services.memory import (
|
||||||
|
MemoryRead,
|
||||||
|
MemoryScope,
|
||||||
|
memory_limits,
|
||||||
|
read_memory,
|
||||||
|
reset_memory,
|
||||||
|
save_memory,
|
||||||
|
)
|
||||||
|
from app.users import current_active_user
|
||||||
|
from app.utils.rbac import check_search_space_access
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class TeamMemoryUpdate(BaseModel):
|
||||||
|
memory_md: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/searchspaces/{search_space_id}/memory", response_model=MemoryRead)
|
||||||
|
async def get_team_memory(
|
||||||
|
search_space_id: int,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
await check_search_space_access(session, user, search_space_id)
|
||||||
|
memory_md = await read_memory(
|
||||||
|
scope=MemoryScope.TEAM,
|
||||||
|
target_id=search_space_id,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
return MemoryRead(memory_md=memory_md, limits=memory_limits())
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/searchspaces/{search_space_id}/memory", response_model=MemoryRead)
|
||||||
|
async def update_team_memory(
|
||||||
|
search_space_id: int,
|
||||||
|
body: TeamMemoryUpdate,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
await check_search_space_access(session, user, search_space_id)
|
||||||
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.TEAM,
|
||||||
|
target_id=search_space_id,
|
||||||
|
content=body.memory_md,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
if result.status == "error":
|
||||||
|
raise HTTPException(status_code=400, detail=result.message)
|
||||||
|
return MemoryRead(memory_md=result.memory_md, limits=memory_limits())
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/searchspaces/{search_space_id}/memory/reset", response_model=MemoryRead)
|
||||||
|
async def reset_team_memory(
|
||||||
|
search_space_id: int,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
await check_search_space_access(session, user, search_space_id)
|
||||||
|
result = await reset_memory(
|
||||||
|
scope=MemoryScope.TEAM,
|
||||||
|
target_id=search_space_id,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
if result.status == "error":
|
||||||
|
raise HTTPException(status_code=400, detail=result.message)
|
||||||
|
return MemoryRead(memory_md=result.memory_md, limits=memory_limits())
|
||||||
|
|
@ -21,7 +21,6 @@ class SearchSpaceUpdate(BaseModel):
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
citations_enabled: bool | None = None
|
citations_enabled: bool | None = None
|
||||||
qna_custom_instructions: str | None = None
|
qna_custom_instructions: str | None = None
|
||||||
shared_memory_md: str | None = None
|
|
||||||
ai_file_sort_enabled: bool | None = None
|
ai_file_sort_enabled: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
@ -100,7 +101,9 @@ class GmailKBSyncService:
|
||||||
else:
|
else:
|
||||||
logger.warning("No LLM configured -- using fallback summary")
|
logger.warning("No LLM configured -- using fallback summary")
|
||||||
summary_content = f"Gmail Message: {subject}\n\n{indexable_content}"
|
summary_content = f"Gmail Message: {subject}\n\n{indexable_content}"
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(indexable_content)
|
chunks = await create_document_chunks(indexable_content)
|
||||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,9 @@ class GoogleCalendarKBSyncService:
|
||||||
summary_content = (
|
summary_content = (
|
||||||
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
|
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
|
||||||
)
|
)
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(indexable_content)
|
chunks = await create_document_chunks(indexable_content)
|
||||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
@ -295,7 +297,9 @@ class GoogleCalendarKBSyncService:
|
||||||
summary_content = (
|
summary_content = (
|
||||||
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
|
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
|
||||||
)
|
)
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(indexable_content)
|
chunks = await create_document_chunks(indexable_content)
|
||||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,9 @@ class JiraKBSyncService:
|
||||||
summary_content = (
|
summary_content = (
|
||||||
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
||||||
)
|
)
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(issue_content)
|
chunks = await create_document_chunks(issue_content)
|
||||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
@ -212,7 +214,9 @@ class JiraKBSyncService:
|
||||||
summary_content = (
|
summary_content = (
|
||||||
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
||||||
)
|
)
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(issue_content)
|
chunks = await create_document_chunks(issue_content)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -659,3 +659,36 @@ async def get_user_long_context_llm(
|
||||||
return await get_document_summary_llm(
|
return await get_document_summary_llm(
|
||||||
session, search_space_id, disable_streaming=disable_streaming
|
session, search_space_id, disable_streaming=disable_streaming
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_planner_llm() -> ChatLiteLLM | None:
|
||||||
|
"""Return a planner LLM instance from the first global config marked
|
||||||
|
``is_planner: true``, or ``None`` if no planner config is defined.
|
||||||
|
|
||||||
|
The planner role handles short, structured internal tasks (KB search
|
||||||
|
planning: query rewriting, date extraction, recency classification).
|
||||||
|
These tasks are well-served by small/fast models (e.g. gpt-4o-mini,
|
||||||
|
Claude Haiku, Azure gpt-5.x-nano) — using the user's chat LLM for them
|
||||||
|
is unnecessarily expensive and slow.
|
||||||
|
|
||||||
|
This helper reads from ``config.GLOBAL_LLM_CONFIGS`` (loaded at import
|
||||||
|
time from ``global_llm_config.yaml``) so it has no DB cost and can be
|
||||||
|
called synchronously from middleware/factory code. It returns the same
|
||||||
|
instance shape as the global path of ``get_search_space_llm_instance``.
|
||||||
|
|
||||||
|
Callers MUST fall back to their chat LLM when this returns ``None`` so
|
||||||
|
deployments without a planner config keep working unchanged.
|
||||||
|
"""
|
||||||
|
from app.agents.new_chat.llm_config import create_chat_litellm_from_config
|
||||||
|
|
||||||
|
planner_cfg = next(
|
||||||
|
(
|
||||||
|
cfg
|
||||||
|
for cfg in config.GLOBAL_LLM_CONFIGS
|
||||||
|
if cfg.get("is_planner") is True
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if not planner_cfg:
|
||||||
|
return None
|
||||||
|
return create_chat_litellm_from_config(planner_cfg)
|
||||||
|
|
|
||||||
32
surfsense_backend/app/services/memory/__init__.py
Normal file
32
surfsense_backend/app/services/memory/__init__.py
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
"""First-class memory service for user and team markdown memory."""
|
||||||
|
|
||||||
|
from .schemas import MemoryLimits, MemoryRead
|
||||||
|
from .service import (
|
||||||
|
MemoryScope,
|
||||||
|
SaveResult,
|
||||||
|
memory_limits,
|
||||||
|
read_memory,
|
||||||
|
reset_memory,
|
||||||
|
save_memory,
|
||||||
|
)
|
||||||
|
from .validation import (
|
||||||
|
MEMORY_HARD_LIMIT,
|
||||||
|
MEMORY_SOFT_LIMIT,
|
||||||
|
validate_bullet_format,
|
||||||
|
validate_memory_scope,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MEMORY_HARD_LIMIT",
|
||||||
|
"MEMORY_SOFT_LIMIT",
|
||||||
|
"MemoryLimits",
|
||||||
|
"MemoryRead",
|
||||||
|
"MemoryScope",
|
||||||
|
"SaveResult",
|
||||||
|
"memory_limits",
|
||||||
|
"read_memory",
|
||||||
|
"reset_memory",
|
||||||
|
"save_memory",
|
||||||
|
"validate_bullet_format",
|
||||||
|
"validate_memory_scope",
|
||||||
|
]
|
||||||
200
surfsense_backend/app/services/memory/document.py
Normal file
200
surfsense_backend/app/services/memory/document.py
Normal file
|
|
@ -0,0 +1,200 @@
|
||||||
|
"""Memory-specific markdown document model and canonical renderer.
|
||||||
|
|
||||||
|
This intentionally parses only SurfSense memory's small markdown contract:
|
||||||
|
``##`` sections with dated bullet items. Unknown lines are preserved so user
|
||||||
|
edits are not lost, while legacy marker bullets are normalized on render.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
DEFAULT_LEGACY_SECTION = "Memory"
|
||||||
|
LEGACY_MARKERS = frozenset({"fact", "pref", "instr"})
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MemoryBullet:
|
||||||
|
entry_date: date
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MemoryRawLine:
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
MemoryLine = MemoryBullet | MemoryRawLine
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MemorySection:
|
||||||
|
heading: str
|
||||||
|
lines: list[MemoryLine] = field(default_factory=list)
|
||||||
|
explicit_heading: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MemoryDocument:
|
||||||
|
sections: list[MemorySection] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_explicit_heading(self) -> bool:
|
||||||
|
return any(section.explicit_heading for section in self.sections)
|
||||||
|
|
||||||
|
|
||||||
|
def is_section_heading(line: str) -> bool:
|
||||||
|
return line.startswith("## ") and bool(line[3:].strip())
|
||||||
|
|
||||||
|
|
||||||
|
def heading_text(line: str) -> str:
|
||||||
|
return line[3:].strip()
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_heading(heading: str) -> str:
|
||||||
|
chars: list[str] = []
|
||||||
|
previous_was_space = True
|
||||||
|
for char in heading.strip().lower():
|
||||||
|
if char.isalnum():
|
||||||
|
chars.append(char)
|
||||||
|
previous_was_space = False
|
||||||
|
elif not previous_was_space:
|
||||||
|
chars.append(" ")
|
||||||
|
previous_was_space = True
|
||||||
|
return "".join(chars).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_bullet_line(line: str) -> MemoryBullet | None:
|
||||||
|
stripped = line.strip()
|
||||||
|
if not stripped.startswith("- "):
|
||||||
|
return None
|
||||||
|
|
||||||
|
body = stripped[2:]
|
||||||
|
parsed = _parse_canonical_bullet(body)
|
||||||
|
if parsed is not None:
|
||||||
|
return parsed
|
||||||
|
return _parse_legacy_bullet(body)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_canonical_bullet(body: str) -> MemoryBullet | None:
|
||||||
|
if len(body) < 13 or body[10:12] != ": ":
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
entry_date = date.fromisoformat(body[:10])
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
text = body[12:].strip()
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
return MemoryBullet(entry_date=entry_date, text=text)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_legacy_bullet(body: str) -> MemoryBullet | None:
|
||||||
|
if len(body) < 20 or not body.startswith("("):
|
||||||
|
return None
|
||||||
|
if len(body) < 14 or body[11:14] != ") [":
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
entry_date = date.fromisoformat(body[1:11])
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
marker_end = body.find("] ", 14)
|
||||||
|
if marker_end == -1:
|
||||||
|
return None
|
||||||
|
marker = body[14:marker_end]
|
||||||
|
if marker not in LEGACY_MARKERS:
|
||||||
|
return None
|
||||||
|
|
||||||
|
text = body[marker_end + 2 :].strip()
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
return MemoryBullet(entry_date=entry_date, text=text)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_memory_document(content: str | None) -> MemoryDocument:
|
||||||
|
if not content:
|
||||||
|
return MemoryDocument()
|
||||||
|
|
||||||
|
sections: list[MemorySection] = []
|
||||||
|
current_heading: str | None = None
|
||||||
|
current_explicit = True
|
||||||
|
current_lines: list[MemoryLine] = []
|
||||||
|
|
||||||
|
def flush_current() -> None:
|
||||||
|
nonlocal current_heading, current_explicit, current_lines
|
||||||
|
if current_heading is None:
|
||||||
|
return
|
||||||
|
sections.append(
|
||||||
|
MemorySection(
|
||||||
|
heading=current_heading,
|
||||||
|
lines=current_lines,
|
||||||
|
explicit_heading=current_explicit,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
current_heading = None
|
||||||
|
current_explicit = True
|
||||||
|
current_lines = []
|
||||||
|
|
||||||
|
for raw_line in content.strip().splitlines():
|
||||||
|
line = raw_line.rstrip()
|
||||||
|
if is_section_heading(line):
|
||||||
|
flush_current()
|
||||||
|
current_heading = heading_text(line)
|
||||||
|
current_explicit = True
|
||||||
|
current_lines = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
bullet = parse_bullet_line(line)
|
||||||
|
if current_heading is None:
|
||||||
|
if bullet is None:
|
||||||
|
continue
|
||||||
|
current_heading = DEFAULT_LEGACY_SECTION
|
||||||
|
current_explicit = False
|
||||||
|
current_lines = [bullet]
|
||||||
|
continue
|
||||||
|
|
||||||
|
current_lines.append(bullet if bullet is not None else MemoryRawLine(text=line))
|
||||||
|
|
||||||
|
flush_current()
|
||||||
|
return MemoryDocument(sections=sections)
|
||||||
|
|
||||||
|
|
||||||
|
def render_memory_document(document: MemoryDocument) -> str:
|
||||||
|
rendered_sections: list[str] = []
|
||||||
|
for section in document.sections:
|
||||||
|
section_lines = [f"## {section.heading}"]
|
||||||
|
for line in section.lines:
|
||||||
|
if isinstance(line, MemoryBullet):
|
||||||
|
section_lines.append(f"- {line.entry_date.isoformat()}: {line.text}")
|
||||||
|
else:
|
||||||
|
section_lines.append(line.text)
|
||||||
|
rendered_sections.append("\n".join(section_lines).strip())
|
||||||
|
return "\n\n".join(section for section in rendered_sections if section).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_headings(memory: str | None) -> set[str]:
|
||||||
|
document = parse_memory_document(memory)
|
||||||
|
return {
|
||||||
|
normalize_heading(section.heading)
|
||||||
|
for section in document.sections
|
||||||
|
if section.explicit_heading
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def has_explicit_heading(content: str) -> bool:
|
||||||
|
return parse_memory_document(content).has_explicit_heading
|
||||||
|
|
||||||
|
|
||||||
|
def nonstandard_bullets(content: str) -> list[str]:
|
||||||
|
warnings: list[str] = []
|
||||||
|
for line in content.splitlines():
|
||||||
|
stripped = line.strip()
|
||||||
|
if not stripped.startswith("- "):
|
||||||
|
continue
|
||||||
|
if parse_bullet_line(stripped) is not None:
|
||||||
|
continue
|
||||||
|
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
|
||||||
|
warnings.append(f"Non-standard memory bullet: {short}")
|
||||||
|
return warnings
|
||||||
20
surfsense_backend/app/services/memory/prompts.py
Normal file
20
surfsense_backend/app/services/memory/prompts.py
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
"""Prompts used by the memory service."""
|
||||||
|
|
||||||
|
FORCED_REWRITE_PROMPT = """\
|
||||||
|
You are a memory curator. The following memory document exceeds the character \
|
||||||
|
limit and must be shortened.
|
||||||
|
|
||||||
|
RULES:
|
||||||
|
1. Rewrite the document to be under {target} characters.
|
||||||
|
2. Output Markdown only. Use clear `##` headings and concise bullet points.
|
||||||
|
3. New-format bullets should look like: `- YYYY-MM-DD: memory text`.
|
||||||
|
4. If the input contains legacy markers like `(YYYY-MM-DD) [fact]`, preserve the
|
||||||
|
information but remove the inline marker in the output.
|
||||||
|
5. Preserve durable instructions and preferences before generic facts when
|
||||||
|
compressing personal memory.
|
||||||
|
6. Preserve existing headings when useful; merge duplicate headings and bullets.
|
||||||
|
7. Output ONLY the consolidated markdown — no explanations, no wrapping.
|
||||||
|
|
||||||
|
<memory_document>
|
||||||
|
{content}
|
||||||
|
</memory_document>"""
|
||||||
35
surfsense_backend/app/services/memory/rewrite.py
Normal file
35
surfsense_backend/app/services/memory/rewrite.py
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
"""LLM-backed memory rewrite helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from app.services.memory.prompts import FORCED_REWRITE_PROMPT
|
||||||
|
from app.services.memory.validation import MEMORY_HARD_LIMIT
|
||||||
|
from app.utils.content_utils import extract_text_content
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def forced_rewrite(content: str, llm: Any) -> str | None:
|
||||||
|
"""Use a focused LLM call to compress memory under the hard limit."""
|
||||||
|
try:
|
||||||
|
prompt = FORCED_REWRITE_PROMPT.format(
|
||||||
|
target=MEMORY_HARD_LIMIT,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[HumanMessage(content=prompt)],
|
||||||
|
config={"tags": ["surfsense:internal", "memory-rewrite"]},
|
||||||
|
)
|
||||||
|
text = extract_text_content(response.content).strip()
|
||||||
|
if not text:
|
||||||
|
logger.warning("Forced memory rewrite returned empty text")
|
||||||
|
return None
|
||||||
|
return text
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Forced memory rewrite LLM call failed")
|
||||||
|
return None
|
||||||
19
surfsense_backend/app/services/memory/schemas.py
Normal file
19
surfsense_backend/app/services/memory/schemas.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
"""Schemas for memory API responses and structured extraction."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryLimits(BaseModel):
|
||||||
|
"""Canonical memory size limits exposed to clients."""
|
||||||
|
|
||||||
|
soft: int
|
||||||
|
hard: int
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRead(BaseModel):
|
||||||
|
"""Memory document payload returned by user and team memory APIs."""
|
||||||
|
|
||||||
|
memory_md: str
|
||||||
|
limits: MemoryLimits
|
||||||
247
surfsense_backend/app/services/memory/service.py
Normal file
247
surfsense_backend/app/services/memory/service.py
Normal file
|
|
@ -0,0 +1,247 @@
|
||||||
|
"""Canonical read/write/reset/extract service for markdown memory."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any, Literal
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import SearchSpace, User
|
||||||
|
from app.services.memory.document import parse_memory_document, render_memory_document
|
||||||
|
from app.services.memory.rewrite import forced_rewrite
|
||||||
|
from app.services.memory.schemas import MemoryLimits
|
||||||
|
from app.services.memory.validation import (
|
||||||
|
MEMORY_HARD_LIMIT,
|
||||||
|
MEMORY_SOFT_LIMIT,
|
||||||
|
soft_limit_warning,
|
||||||
|
strip_preamble_to_first_heading,
|
||||||
|
validate_bullet_format,
|
||||||
|
validate_diff,
|
||||||
|
validate_heading_sanity,
|
||||||
|
validate_memory_scope,
|
||||||
|
validate_memory_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_NO_UPDATE_SENTINELS = frozenset(
|
||||||
|
{
|
||||||
|
"NO_UPDATE",
|
||||||
|
"NO UPDATE",
|
||||||
|
"NO_CHANGE",
|
||||||
|
"NO CHANGE",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryScope(StrEnum):
|
||||||
|
USER = "user"
|
||||||
|
TEAM = "team"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SaveResult:
|
||||||
|
status: Literal["saved", "error", "no_op"]
|
||||||
|
message: str
|
||||||
|
memory_md: str = ""
|
||||||
|
warnings: list[str] = field(default_factory=list)
|
||||||
|
diff_warnings: list[str] = field(default_factory=list)
|
||||||
|
format_warnings: list[str] = field(default_factory=list)
|
||||||
|
notice: str | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
data: dict[str, Any] = {
|
||||||
|
"status": self.status,
|
||||||
|
"message": self.message,
|
||||||
|
"memory_md": self.memory_md,
|
||||||
|
}
|
||||||
|
if self.notice:
|
||||||
|
data["notice"] = self.notice
|
||||||
|
if self.warnings:
|
||||||
|
data["warnings"] = self.warnings
|
||||||
|
if len(self.warnings) == 1:
|
||||||
|
data["warning"] = self.warnings[0]
|
||||||
|
if self.diff_warnings:
|
||||||
|
data["diff_warnings"] = self.diff_warnings
|
||||||
|
if self.format_warnings:
|
||||||
|
data["format_warnings"] = self.format_warnings
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def memory_limits() -> MemoryLimits:
|
||||||
|
return MemoryLimits(soft=MEMORY_SOFT_LIMIT, hard=MEMORY_HARD_LIMIT)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_scope(scope: MemoryScope | str) -> MemoryScope:
|
||||||
|
return scope if isinstance(scope, MemoryScope) else MemoryScope(scope)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_user_id(target_id: str | UUID) -> UUID:
|
||||||
|
return UUID(target_id) if isinstance(target_id, str) else target_id
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_target(
|
||||||
|
*,
|
||||||
|
scope: MemoryScope | str,
|
||||||
|
target_id: str | int | UUID,
|
||||||
|
session: AsyncSession,
|
||||||
|
) -> User | SearchSpace | None:
|
||||||
|
normalized = _normalize_scope(scope)
|
||||||
|
if normalized is MemoryScope.USER:
|
||||||
|
result = await session.execute(
|
||||||
|
select(User).where(User.id == _normalize_user_id(target_id)) # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
return result.scalars().first()
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSpace).where(SearchSpace.id == int(target_id))
|
||||||
|
)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_memory(target: User | SearchSpace, scope: MemoryScope) -> str:
|
||||||
|
if scope is MemoryScope.USER:
|
||||||
|
return getattr(target, "memory_md", None) or ""
|
||||||
|
return getattr(target, "shared_memory_md", None) or ""
|
||||||
|
|
||||||
|
|
||||||
|
def _set_memory(target: User | SearchSpace, scope: MemoryScope, content: str) -> None:
|
||||||
|
if scope is MemoryScope.USER:
|
||||||
|
target.memory_md = content
|
||||||
|
else:
|
||||||
|
target.shared_memory_md = content
|
||||||
|
|
||||||
|
|
||||||
|
async def read_memory(
|
||||||
|
*,
|
||||||
|
scope: MemoryScope | str,
|
||||||
|
target_id: str | int | UUID,
|
||||||
|
session: AsyncSession,
|
||||||
|
) -> str:
|
||||||
|
normalized = _normalize_scope(scope)
|
||||||
|
target = await _load_target(scope=normalized, target_id=target_id, session=session)
|
||||||
|
if target is None:
|
||||||
|
return ""
|
||||||
|
return _get_memory(target, normalized)
|
||||||
|
|
||||||
|
|
||||||
|
async def save_memory(
|
||||||
|
*,
|
||||||
|
scope: MemoryScope | str,
|
||||||
|
target_id: str | int | UUID,
|
||||||
|
content: str,
|
||||||
|
session: AsyncSession,
|
||||||
|
llm: Any | None = None,
|
||||||
|
) -> SaveResult:
|
||||||
|
normalized = _normalize_scope(scope)
|
||||||
|
if not isinstance(content, str):
|
||||||
|
return SaveResult(
|
||||||
|
status="error",
|
||||||
|
message="Internal error: memory payload must be a string.",
|
||||||
|
)
|
||||||
|
|
||||||
|
target = await _load_target(scope=normalized, target_id=target_id, session=session)
|
||||||
|
if target is None:
|
||||||
|
return SaveResult(
|
||||||
|
status="error",
|
||||||
|
message="User not found."
|
||||||
|
if normalized is MemoryScope.USER
|
||||||
|
else "Search space not found.",
|
||||||
|
)
|
||||||
|
|
||||||
|
old_memory = _get_memory(target, normalized)
|
||||||
|
next_content = strip_preamble_to_first_heading(content.strip())
|
||||||
|
notice: str | None = None
|
||||||
|
warnings: list[str] = []
|
||||||
|
|
||||||
|
if next_content.upper() in _NO_UPDATE_SENTINELS:
|
||||||
|
return SaveResult(
|
||||||
|
status="no_op",
|
||||||
|
message="No memory update requested.",
|
||||||
|
memory_md=old_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(next_content) > MEMORY_HARD_LIMIT and llm is not None:
|
||||||
|
rewritten = await forced_rewrite(next_content, llm)
|
||||||
|
if rewritten is not None and len(rewritten) < len(next_content):
|
||||||
|
next_content = strip_preamble_to_first_heading(rewritten)
|
||||||
|
notice = "Memory was automatically rewritten to fit within limits."
|
||||||
|
|
||||||
|
for validation in (
|
||||||
|
validate_memory_size(next_content),
|
||||||
|
validate_heading_sanity(next_content),
|
||||||
|
):
|
||||||
|
if validation:
|
||||||
|
return SaveResult(
|
||||||
|
status="error",
|
||||||
|
message=validation["message"],
|
||||||
|
memory_md=old_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
scope_error, scope_warnings = validate_memory_scope(
|
||||||
|
next_content,
|
||||||
|
normalized.value,
|
||||||
|
old_memory=old_memory,
|
||||||
|
)
|
||||||
|
warnings.extend(scope_warnings)
|
||||||
|
if scope_error:
|
||||||
|
return SaveResult(
|
||||||
|
status="error",
|
||||||
|
message=scope_error["message"],
|
||||||
|
memory_md=old_memory,
|
||||||
|
warnings=warnings,
|
||||||
|
)
|
||||||
|
|
||||||
|
next_content = render_memory_document(parse_memory_document(next_content))
|
||||||
|
|
||||||
|
try:
|
||||||
|
_set_memory(target, normalized, next_content)
|
||||||
|
session.add(target)
|
||||||
|
await session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update %s memory: %s", normalized.value, e)
|
||||||
|
await session.rollback()
|
||||||
|
return SaveResult(
|
||||||
|
status="error",
|
||||||
|
message=f"Failed to update {normalized.value} memory: {e}",
|
||||||
|
memory_md=old_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
diff_warnings = validate_diff(old_memory, next_content)
|
||||||
|
format_warnings = validate_bullet_format(next_content)
|
||||||
|
warning = soft_limit_warning(next_content)
|
||||||
|
if warning:
|
||||||
|
warnings.append(warning)
|
||||||
|
|
||||||
|
return SaveResult(
|
||||||
|
status="saved",
|
||||||
|
message=(
|
||||||
|
"Memory updated."
|
||||||
|
if normalized is MemoryScope.USER
|
||||||
|
else "Team memory updated."
|
||||||
|
),
|
||||||
|
memory_md=next_content,
|
||||||
|
warnings=warnings,
|
||||||
|
diff_warnings=diff_warnings,
|
||||||
|
format_warnings=format_warnings,
|
||||||
|
notice=notice,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def reset_memory(
|
||||||
|
*,
|
||||||
|
scope: MemoryScope | str,
|
||||||
|
target_id: str | int | UUID,
|
||||||
|
session: AsyncSession,
|
||||||
|
) -> SaveResult:
|
||||||
|
return await save_memory(
|
||||||
|
scope=scope,
|
||||||
|
target_id=target_id,
|
||||||
|
content="",
|
||||||
|
session=session,
|
||||||
|
llm=None,
|
||||||
|
)
|
||||||
140
surfsense_backend/app/services/memory/validation.py
Normal file
140
surfsense_backend/app/services/memory/validation.py
Normal file
|
|
@ -0,0 +1,140 @@
|
||||||
|
"""Validation helpers for markdown-backed memory."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from app.services.memory.document import (
|
||||||
|
extract_headings,
|
||||||
|
has_explicit_heading,
|
||||||
|
nonstandard_bullets,
|
||||||
|
parse_memory_document,
|
||||||
|
)
|
||||||
|
|
||||||
|
MEMORY_SOFT_LIMIT = 18_000
|
||||||
|
MEMORY_HARD_LIMIT = 25_000
|
||||||
|
|
||||||
|
_FORBIDDEN_TEAM_HEADINGS = {
|
||||||
|
"preferences",
|
||||||
|
"instructions",
|
||||||
|
"personal notes",
|
||||||
|
"personal instructions",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def has_markdown_heading(content: str) -> bool:
|
||||||
|
return has_explicit_heading(content)
|
||||||
|
|
||||||
|
|
||||||
|
def strip_preamble_to_first_heading(content: str) -> str:
|
||||||
|
"""Drop model preamble before the first ``##`` heading, if one exists."""
|
||||||
|
lines = content.splitlines()
|
||||||
|
for index, line in enumerate(lines):
|
||||||
|
if line.startswith("## ") and line[3:].strip():
|
||||||
|
return "\n".join(lines[index:]).strip()
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def validate_memory_size(content: str) -> dict[str, str] | None:
|
||||||
|
length = len(content)
|
||||||
|
if length > MEMORY_HARD_LIMIT:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": (
|
||||||
|
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
|
||||||
|
f"({length:,} chars). Consolidate by merging related items, "
|
||||||
|
"removing outdated entries, and shortening descriptions."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_heading_sanity(content: str) -> dict[str, str] | None:
|
||||||
|
"""Block long prose blobs without headings unless they are legacy bullets."""
|
||||||
|
stripped = content.strip()
|
||||||
|
if not stripped:
|
||||||
|
return None
|
||||||
|
if has_markdown_heading(stripped):
|
||||||
|
return None
|
||||||
|
if len(stripped) <= 40:
|
||||||
|
return None
|
||||||
|
if parse_memory_document(stripped).sections:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Memory must be markdown with at least one ## heading.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def validate_memory_scope(
|
||||||
|
content: str,
|
||||||
|
scope: Literal["user", "team"],
|
||||||
|
*,
|
||||||
|
old_memory: str | None = None,
|
||||||
|
) -> tuple[dict[str, str] | None, list[str]]:
|
||||||
|
"""Reject new personal headings in team memory, grandfather existing ones."""
|
||||||
|
if scope != "team":
|
||||||
|
return None, []
|
||||||
|
|
||||||
|
old_forbidden = extract_headings(old_memory) & _FORBIDDEN_TEAM_HEADINGS
|
||||||
|
new_forbidden = extract_headings(content) & _FORBIDDEN_TEAM_HEADINGS
|
||||||
|
introduced = sorted(new_forbidden - old_forbidden)
|
||||||
|
grandfathered = sorted(new_forbidden & old_forbidden)
|
||||||
|
|
||||||
|
warnings: list[str] = []
|
||||||
|
if grandfathered:
|
||||||
|
warnings.append(
|
||||||
|
"Team memory contains legacy personal headings: "
|
||||||
|
+ ", ".join(grandfathered)
|
||||||
|
+ ". Please consolidate them into team-safe headings."
|
||||||
|
)
|
||||||
|
if introduced:
|
||||||
|
return (
|
||||||
|
{
|
||||||
|
"status": "error",
|
||||||
|
"message": (
|
||||||
|
"Team memory cannot introduce personal headings: "
|
||||||
|
+ ", ".join(introduced)
|
||||||
|
+ ". Use team-safe headings instead."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
warnings,
|
||||||
|
)
|
||||||
|
return None, warnings
|
||||||
|
|
||||||
|
|
||||||
|
def validate_bullet_format(content: str) -> list[str]:
|
||||||
|
return nonstandard_bullets(content)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||||
|
if not old_memory:
|
||||||
|
return []
|
||||||
|
|
||||||
|
warnings: list[str] = []
|
||||||
|
old_headings = extract_headings(old_memory)
|
||||||
|
new_headings = extract_headings(new_memory)
|
||||||
|
dropped = old_headings - new_headings
|
||||||
|
if dropped:
|
||||||
|
names = ", ".join(sorted(dropped))
|
||||||
|
warnings.append(
|
||||||
|
f"Sections removed: {names}. If unintentional, restore them from the memory document."
|
||||||
|
)
|
||||||
|
|
||||||
|
old_len = len(old_memory)
|
||||||
|
new_len = len(new_memory)
|
||||||
|
if old_len > 0 and new_len < old_len * 0.4:
|
||||||
|
warnings.append(
|
||||||
|
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). Possible data loss."
|
||||||
|
)
|
||||||
|
return warnings
|
||||||
|
|
||||||
|
|
||||||
|
def soft_limit_warning(content: str) -> str | None:
|
||||||
|
length = len(content)
|
||||||
|
if length > MEMORY_SOFT_LIMIT:
|
||||||
|
return (
|
||||||
|
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
|
||||||
|
"Consolidate by merging related items and removing less important entries."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
@ -95,7 +96,9 @@ class OneDriveKBSyncService:
|
||||||
else:
|
else:
|
||||||
logger.warning("No LLM configured — using fallback summary")
|
logger.warning("No LLM configured — using fallback summary")
|
||||||
summary_content = f"OneDrive File: {file_name}\n\n{indexable_content}"
|
summary_content = f"OneDrive File: {file_name}\n\n{indexable_content}"
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(indexable_content)
|
chunks = await create_document_chunks(indexable_content)
|
||||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ same trap waiting to happen).
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
@ -234,7 +235,7 @@ async def _restore_in_place_document(
|
||||||
if isinstance(c, dict) and isinstance(c.get("content"), str)
|
if isinstance(c, dict) and isinstance(c.get("content"), str)
|
||||||
]
|
]
|
||||||
if chunk_texts:
|
if chunk_texts:
|
||||||
chunk_embeddings = embed_texts(chunk_texts)
|
chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
|
||||||
session.add_all(
|
session.add_all(
|
||||||
[
|
[
|
||||||
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
||||||
|
|
@ -244,7 +245,9 @@ async def _restore_in_place_document(
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
if isinstance(revision.content_before, str):
|
if isinstance(revision.content_before, str):
|
||||||
doc.embedding = embed_texts([revision.content_before])[0]
|
doc.embedding = (
|
||||||
|
await asyncio.to_thread(embed_texts, [revision.content_before])
|
||||||
|
)[0]
|
||||||
|
|
||||||
doc.updated_at = datetime.now(UTC)
|
doc.updated_at = datetime.now(UTC)
|
||||||
return RevertOutcome(status="ok", message="Document restored from snapshot.")
|
return RevertOutcome(status="ok", message="Document restored from snapshot.")
|
||||||
|
|
@ -320,7 +323,7 @@ async def _reinsert_document_from_revision(
|
||||||
session.add(new_doc)
|
session.add(new_doc)
|
||||||
await session.flush()
|
await session.flush()
|
||||||
|
|
||||||
new_doc.embedding = embed_texts([content])[0]
|
new_doc.embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
|
||||||
chunk_texts = []
|
chunk_texts = []
|
||||||
chunks_before = revision.chunks_before
|
chunks_before = revision.chunks_before
|
||||||
if isinstance(chunks_before, list):
|
if isinstance(chunks_before, list):
|
||||||
|
|
@ -330,7 +333,7 @@ async def _reinsert_document_from_revision(
|
||||||
if isinstance(c, dict) and isinstance(c.get("content"), str)
|
if isinstance(c, dict) and isinstance(c.get("content"), str)
|
||||||
]
|
]
|
||||||
if chunk_texts:
|
if chunk_texts:
|
||||||
chunk_embeddings = embed_texts(chunk_texts)
|
chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
|
||||||
session.add_all(
|
session.add_all(
|
||||||
[
|
[
|
||||||
Chunk(document_id=new_doc.id, content=text, embedding=embedding)
|
Chunk(document_id=new_doc.id, content=text, embedding=embedding)
|
||||||
|
|
|
||||||
|
|
@ -325,6 +325,24 @@ class TokenTrackingCallback(CustomLogger):
|
||||||
total_tokens = getattr(usage, "total_tokens", 0) or 0
|
total_tokens = getattr(usage, "total_tokens", 0) or 0
|
||||||
call_kind = "chat"
|
call_kind = "chat"
|
||||||
|
|
||||||
|
# Prompt-cache accounting. LiteLLM normalizes every provider's cache
|
||||||
|
# fields onto ``usage.prompt_tokens_details``:
|
||||||
|
# - ``cached_tokens`` — cache reads (OpenAI/Azure native, DeepSeek
|
||||||
|
# mapped from ``prompt_cache_hit_tokens``,
|
||||||
|
# Anthropic mapped from ``cache_read_input_tokens``).
|
||||||
|
# - ``cache_creation_tokens`` — cache writes (Anthropic only; OpenAI/Azure
|
||||||
|
# do not expose a write count).
|
||||||
|
# See ``litellm.types.utils.Usage.__init__`` for the mapping.
|
||||||
|
cached_tokens = 0
|
||||||
|
cache_creation_tokens = 0
|
||||||
|
if not is_image:
|
||||||
|
prompt_details = getattr(usage, "prompt_tokens_details", None)
|
||||||
|
if prompt_details is not None:
|
||||||
|
cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0
|
||||||
|
cache_creation_tokens = (
|
||||||
|
getattr(prompt_details, "cache_creation_tokens", 0) or 0
|
||||||
|
)
|
||||||
|
|
||||||
model = kwargs.get("model", "unknown")
|
model = kwargs.get("model", "unknown")
|
||||||
|
|
||||||
cost_usd = _extract_cost_usd(
|
cost_usd = _extract_cost_usd(
|
||||||
|
|
@ -357,9 +375,23 @@ class TokenTrackingCallback(CustomLogger):
|
||||||
cost_micros=cost_micros,
|
cost_micros=cost_micros,
|
||||||
call_kind=call_kind,
|
call_kind=call_kind,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Per-LLM-call wall-clock latency (LiteLLM passes datetime objects).
|
||||||
|
call_latency_s: float | None = None
|
||||||
|
try:
|
||||||
|
if start_time is not None and end_time is not None:
|
||||||
|
delta = end_time - start_time
|
||||||
|
call_latency_s = getattr(delta, "total_seconds", lambda: float(delta))()
|
||||||
|
except Exception:
|
||||||
|
call_latency_s = None
|
||||||
|
|
||||||
|
cache_hit_ratio: float | None = None
|
||||||
|
if prompt_tokens > 0 and (cached_tokens > 0 or cache_creation_tokens > 0):
|
||||||
|
cache_hit_ratio = cached_tokens / prompt_tokens
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
|
"[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
|
||||||
"cost=$%.6f (%d micros) (accumulator now has %d calls)",
|
"cost=$%.6f (%d micros) (accumulator now has %d calls)%s%s",
|
||||||
model,
|
model,
|
||||||
call_kind,
|
call_kind,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
|
|
@ -368,6 +400,17 @@ class TokenTrackingCallback(CustomLogger):
|
||||||
cost_usd,
|
cost_usd,
|
||||||
cost_micros,
|
cost_micros,
|
||||||
len(acc.calls),
|
len(acc.calls),
|
||||||
|
f" latency={call_latency_s:.3f}s" if call_latency_s is not None else "",
|
||||||
|
(
|
||||||
|
f" cache_read={cached_tokens} cache_write={cache_creation_tokens}"
|
||||||
|
f" hit_ratio={cache_hit_ratio:.1%}"
|
||||||
|
if cache_hit_ratio is not None
|
||||||
|
else (
|
||||||
|
f" cache_read={cached_tokens} cache_write={cache_creation_tokens}"
|
||||||
|
if (cached_tokens or cache_creation_tokens)
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,10 +39,6 @@ from app.agents.new_chat.llm_config import (
|
||||||
load_agent_config,
|
load_agent_config,
|
||||||
load_global_llm_config_by_id,
|
load_global_llm_config_by_id,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.memory_extraction import (
|
|
||||||
extract_and_save_memory,
|
|
||||||
extract_and_save_team_memory,
|
|
||||||
)
|
|
||||||
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
|
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
|
||||||
from app.agents.new_chat.middleware.busy_mutex import (
|
from app.agents.new_chat.middleware.busy_mutex import (
|
||||||
end_turn,
|
end_turn,
|
||||||
|
|
@ -64,8 +60,6 @@ from app.db import (
|
||||||
)
|
)
|
||||||
from app.prompts import TITLE_GENERATION_PROMPT
|
from app.prompts import TITLE_GENERATION_PROMPT
|
||||||
from app.services.auto_model_pin_service import (
|
from app.services.auto_model_pin_service import (
|
||||||
is_recently_healthy,
|
|
||||||
mark_healthy,
|
|
||||||
mark_runtime_cooldown,
|
mark_runtime_cooldown,
|
||||||
resolve_or_get_pinned_llm_config_id,
|
resolve_or_get_pinned_llm_config_id,
|
||||||
)
|
)
|
||||||
|
|
@ -283,7 +277,6 @@ class StreamResult:
|
||||||
accumulated_text: str = ""
|
accumulated_text: str = ""
|
||||||
is_interrupted: bool = False
|
is_interrupted: bool = False
|
||||||
sandbox_files: list[str] = field(default_factory=list)
|
sandbox_files: list[str] = field(default_factory=list)
|
||||||
agent_called_update_memory: bool = False
|
|
||||||
request_id: str | None = None
|
request_id: str | None = None
|
||||||
turn_id: str = ""
|
turn_id: str = ""
|
||||||
filesystem_mode: str = "cloud"
|
filesystem_mode: str = "cloud"
|
||||||
|
|
@ -506,54 +499,6 @@ def _is_provider_rate_limited(exc: BaseException) -> bool:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_PREFLIGHT_TIMEOUT_SEC: float = 2.5
|
|
||||||
_PREFLIGHT_MAX_TOKENS: int = 1
|
|
||||||
|
|
||||||
|
|
||||||
async def _preflight_llm(llm: Any) -> None:
|
|
||||||
"""Issue a minimal completion to confirm the pinned model isn't 429'ing.
|
|
||||||
|
|
||||||
Used before agent build / planner / classifier / title-gen so a known-bad
|
|
||||||
free OpenRouter deployment is detected and repinned before it cascades
|
|
||||||
into multiple wasted internal calls. The probe is intentionally cheap:
|
|
||||||
one token, low timeout, tagged ``surfsense:internal`` so token tracking
|
|
||||||
and SSE pipelines treat it as overhead rather than user output.
|
|
||||||
|
|
||||||
Raises the original exception when the provider responds with a
|
|
||||||
rate-limit-shaped error so the caller can drive the cooldown/repin
|
|
||||||
branch via :func:`_is_provider_rate_limited`. Other transient failures
|
|
||||||
are swallowed — the caller continues to the normal stream path and the
|
|
||||||
in-stream recovery loop remains the safety net.
|
|
||||||
"""
|
|
||||||
from litellm import acompletion
|
|
||||||
|
|
||||||
model = getattr(llm, "model", None)
|
|
||||||
if not model or model == "auto":
|
|
||||||
# Auto-mode router doesn't have a single deployment to ping; the
|
|
||||||
# router itself handles per-deployment rate-limit accounting.
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
await acompletion(
|
|
||||||
model=model,
|
|
||||||
messages=[{"role": "user", "content": "ping"}],
|
|
||||||
api_key=getattr(llm, "api_key", None),
|
|
||||||
api_base=getattr(llm, "api_base", None),
|
|
||||||
max_tokens=_PREFLIGHT_MAX_TOKENS,
|
|
||||||
timeout=_PREFLIGHT_TIMEOUT_SEC,
|
|
||||||
stream=False,
|
|
||||||
metadata={"tags": ["surfsense:internal", "auto-pin-preflight"]},
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
if _is_provider_rate_limited(exc):
|
|
||||||
raise
|
|
||||||
logging.getLogger(__name__).debug(
|
|
||||||
"auto_pin_preflight non_rate_limit_error model=%s err=%s",
|
|
||||||
model,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _build_main_agent_for_thread(
|
async def _build_main_agent_for_thread(
|
||||||
agent_factory: Any,
|
agent_factory: Any,
|
||||||
*,
|
*,
|
||||||
|
|
@ -571,9 +516,9 @@ async def _build_main_agent_for_thread(
|
||||||
disabled_tools: list[str] | None = None,
|
disabled_tools: list[str] | None = None,
|
||||||
mentioned_document_ids: list[int] | None = None,
|
mentioned_document_ids: list[int] | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Single (re)build path so the agent factory cannot drift across
|
"""Single (re)build path so the agent factory cannot drift across the
|
||||||
initial build, preflight repin, and mid-stream 429 recovery for one
|
initial build and mid-stream 429 recovery for one ``thread_id``: a
|
||||||
``thread_id``: a graph swap mid-turn would corrupt checkpointer state."""
|
graph swap mid-turn would corrupt checkpointer state."""
|
||||||
return await agent_factory(
|
return await agent_factory(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
|
|
@ -591,29 +536,6 @@ async def _build_main_agent_for_thread(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
|
|
||||||
"""Wait for a discarded speculative agent build to release shared state.
|
|
||||||
|
|
||||||
Used by the parallel preflight + agent-build path. The speculative build
|
|
||||||
closes over the request-scoped ``AsyncSession`` (for the brief connector
|
|
||||||
discovery / tool-factory window before its CPU work moves into a worker
|
|
||||||
thread). If preflight reports a 429 we want to fall back to the original
|
|
||||||
repin → reload → rebuild path, but we MUST NOT touch ``session`` again
|
|
||||||
until any in-flight session work owned by the speculative build has
|
|
||||||
fully settled — :class:`sqlalchemy.ext.asyncio.AsyncSession` is not
|
|
||||||
concurrency-safe and the same hazard cost us a hard ``InvalidRequestError``
|
|
||||||
earlier in this PR (see ``connector_service`` parallel-gather revert).
|
|
||||||
|
|
||||||
We simply ``await`` the task and swallow any exception: in this path the
|
|
||||||
build's outcome is irrelevant — success populates the agent cache (a free
|
|
||||||
side effect), failure is discarded. The wasted CPU is acceptable since
|
|
||||||
429 fallbacks are rare and the original sequential code also paid the
|
|
||||||
full build cost on the same path.
|
|
||||||
"""
|
|
||||||
with contextlib.suppress(BaseException):
|
|
||||||
await task
|
|
||||||
|
|
||||||
|
|
||||||
def _classify_stream_exception(
|
def _classify_stream_exception(
|
||||||
exc: Exception,
|
exc: Exception,
|
||||||
*,
|
*,
|
||||||
|
|
@ -1241,39 +1163,6 @@ async def stream_new_chat(
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Auto-mode preflight ping. Runs ONLY for thread-pinned auto cfgs
|
|
||||||
# (negative ids selected via ``resolve_or_get_pinned_llm_config_id``)
|
|
||||||
# whose health hasn't already been confirmed within the TTL window.
|
|
||||||
# Detecting a 429 here lets us repin BEFORE the planner/classifier/
|
|
||||||
# title-generation LLM calls fan out and each independently hit the
|
|
||||||
# same upstream rate limit.
|
|
||||||
#
|
|
||||||
# PERF: preflight is a network round-trip to the LLM provider (~1-5s)
|
|
||||||
# and is independent of the agent build (CPU-bound, ~5-7s). They used
|
|
||||||
# to run sequentially → ``preflight + build`` on cold cache = 11.5s.
|
|
||||||
# We now kick off preflight as a background task FIRST, then run the
|
|
||||||
# synchronous setup work and the agent build in parallel. In the
|
|
||||||
# success path (the common case) total wall time drops to roughly
|
|
||||||
# ``max(preflight, build)`` — the preflight finishes during the
|
|
||||||
# agent compile and we just consume its result. In the rare 429
|
|
||||||
# path the speculative build is awaited to completion (so its
|
|
||||||
# session usage is fully released) via
|
|
||||||
# :func:`_settle_speculative_agent_build`, then discarded, and
|
|
||||||
# we fall back to the original repin-and-rebuild flow.
|
|
||||||
preflight_needed = (
|
|
||||||
requested_llm_config_id == 0
|
|
||||||
and llm_config_id < 0
|
|
||||||
and not is_recently_healthy(llm_config_id)
|
|
||||||
)
|
|
||||||
preflight_task: asyncio.Task[None] | None = None
|
|
||||||
_t_preflight = 0.0
|
|
||||||
if preflight_needed:
|
|
||||||
_t_preflight = time.perf_counter()
|
|
||||||
preflight_task = asyncio.create_task(
|
|
||||||
_preflight_llm(llm),
|
|
||||||
name=f"auto_pin_preflight:{llm_config_id}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create connector service
|
# Create connector service
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||||
|
|
@ -1307,136 +1196,26 @@ async def stream_new_chat(
|
||||||
if use_multi_agent
|
if use_multi_agent
|
||||||
else create_surfsense_deep_agent
|
else create_surfsense_deep_agent
|
||||||
)
|
)
|
||||||
# Speculative agent build — runs in parallel with the preflight
|
# Build the agent inline. Provider 429s surface through the
|
||||||
# task (if any). Built with the *current* ``llm`` / ``agent_config``;
|
# in-stream recovery loop below (``_is_provider_rate_limited``),
|
||||||
# if preflight reports 429 we will discard this future and rebuild
|
# which repins the thread to an eligible alternative config and
|
||||||
# against the freshly pinned config below.
|
# rebuilds the agent before the user sees any output.
|
||||||
agent_build_task = asyncio.create_task(
|
agent = await _build_main_agent_for_thread(
|
||||||
_build_main_agent_for_thread(
|
agent_factory,
|
||||||
agent_factory,
|
llm=llm,
|
||||||
llm=llm,
|
search_space_id=search_space_id,
|
||||||
search_space_id=search_space_id,
|
db_session=session,
|
||||||
db_session=session,
|
connector_service=connector_service,
|
||||||
connector_service=connector_service,
|
checkpointer=checkpointer,
|
||||||
checkpointer=checkpointer,
|
user_id=user_id,
|
||||||
user_id=user_id,
|
thread_id=chat_id,
|
||||||
thread_id=chat_id,
|
agent_config=agent_config,
|
||||||
agent_config=agent_config,
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
thread_visibility=visibility,
|
||||||
thread_visibility=visibility,
|
filesystem_selection=filesystem_selection,
|
||||||
filesystem_selection=filesystem_selection,
|
disabled_tools=disabled_tools,
|
||||||
disabled_tools=disabled_tools,
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
|
||||||
),
|
|
||||||
name="agent_build:stream_new_chat",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
agent: Any = None
|
|
||||||
if preflight_task is not None:
|
|
||||||
try:
|
|
||||||
await preflight_task
|
|
||||||
mark_healthy(llm_config_id)
|
|
||||||
_perf_log.info(
|
|
||||||
"[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
|
|
||||||
llm_config_id,
|
|
||||||
time.perf_counter() - _t_preflight,
|
|
||||||
)
|
|
||||||
except Exception as preflight_exc:
|
|
||||||
# Both branches below need the session: the non-429 path
|
|
||||||
# may unwind via cleanup that uses ``session``, and the
|
|
||||||
# 429 path explicitly calls ``resolve_or_get_pinned_llm_config_id``
|
|
||||||
# against it. Wait for the speculative build to release its
|
|
||||||
# session usage before we proceed.
|
|
||||||
await _settle_speculative_agent_build(agent_build_task)
|
|
||||||
if not _is_provider_rate_limited(preflight_exc):
|
|
||||||
raise
|
|
||||||
# 429: speculative agent is discarded; run the original
|
|
||||||
# repin → reload → rebuild path against the freshly
|
|
||||||
# pinned config.
|
|
||||||
previous_config_id = llm_config_id
|
|
||||||
mark_runtime_cooldown(
|
|
||||||
previous_config_id, reason="preflight_rate_limited"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
llm_config_id = (
|
|
||||||
await resolve_or_get_pinned_llm_config_id(
|
|
||||||
session,
|
|
||||||
thread_id=chat_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
selected_llm_config_id=0,
|
|
||||||
exclude_config_ids={previous_config_id},
|
|
||||||
requires_image_input=_requires_image_input,
|
|
||||||
)
|
|
||||||
).resolved_llm_config_id
|
|
||||||
except ValueError as pin_error:
|
|
||||||
yield _emit_stream_error(
|
|
||||||
message=str(pin_error),
|
|
||||||
error_kind="server_error",
|
|
||||||
error_code="SERVER_ERROR",
|
|
||||||
)
|
|
||||||
yield streaming_service.format_done()
|
|
||||||
return
|
|
||||||
|
|
||||||
llm, agent_config, llm_load_error = await _load_llm_bundle(
|
|
||||||
llm_config_id
|
|
||||||
)
|
|
||||||
if llm_load_error or not llm:
|
|
||||||
yield _emit_stream_error(
|
|
||||||
message=llm_load_error or "Failed to create LLM instance",
|
|
||||||
error_kind="server_error",
|
|
||||||
error_code="SERVER_ERROR",
|
|
||||||
)
|
|
||||||
yield streaming_service.format_done()
|
|
||||||
return
|
|
||||||
# Trust the freshly-resolved cfg for the remainder of this
|
|
||||||
# turn rather than recursing into another preflight; the
|
|
||||||
# in-stream 429 recovery loop is still in place as the
|
|
||||||
# safety net if even this fallback hits an upstream cap.
|
|
||||||
mark_healthy(llm_config_id)
|
|
||||||
_log_chat_stream_error(
|
|
||||||
flow=flow,
|
|
||||||
error_kind="rate_limited",
|
|
||||||
error_code="RATE_LIMITED",
|
|
||||||
severity="info",
|
|
||||||
is_expected=True,
|
|
||||||
request_id=request_id,
|
|
||||||
thread_id=chat_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
message=(
|
|
||||||
"Auto-pinned model failed preflight; switched to another "
|
|
||||||
"eligible model and continuing."
|
|
||||||
),
|
|
||||||
extra={
|
|
||||||
"auto_runtime_recover": True,
|
|
||||||
"preflight": True,
|
|
||||||
"previous_config_id": previous_config_id,
|
|
||||||
"fallback_config_id": llm_config_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Rebuild against the new llm/agent_config. Sequential
|
|
||||||
# here because we no longer have anything to overlap with.
|
|
||||||
agent = await agent_factory(
|
|
||||||
llm=llm,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
db_session=session,
|
|
||||||
connector_service=connector_service,
|
|
||||||
checkpointer=checkpointer,
|
|
||||||
user_id=user_id,
|
|
||||||
thread_id=chat_id,
|
|
||||||
agent_config=agent_config,
|
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
|
||||||
thread_visibility=visibility,
|
|
||||||
disabled_tools=disabled_tools,
|
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
|
||||||
filesystem_selection=filesystem_selection,
|
|
||||||
)
|
|
||||||
|
|
||||||
if agent is None:
|
|
||||||
# Either no preflight was needed, or preflight succeeded —
|
|
||||||
# in both cases the speculative build is the agent we want.
|
|
||||||
agent = await agent_build_task
|
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||||
)
|
)
|
||||||
|
|
@ -2208,36 +1987,6 @@ async def stream_new_chat(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fire background memory extraction if the agent didn't handle it.
|
|
||||||
# Shared threads write to team memory; private threads write to user memory.
|
|
||||||
if not stream_result.agent_called_update_memory:
|
|
||||||
memory_seed = user_query.strip() or (
|
|
||||||
f"[{len(user_image_data_urls or [])} image(s)]"
|
|
||||||
if user_image_data_urls
|
|
||||||
else "(message)"
|
|
||||||
)
|
|
||||||
if visibility == ChatVisibility.SEARCH_SPACE:
|
|
||||||
task = asyncio.create_task(
|
|
||||||
extract_and_save_team_memory(
|
|
||||||
user_message=memory_seed,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
llm=llm,
|
|
||||||
author_display_name=current_user_display_name,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
_background_tasks.add(task)
|
|
||||||
task.add_done_callback(_background_tasks.discard)
|
|
||||||
elif user_id:
|
|
||||||
task = asyncio.create_task(
|
|
||||||
extract_and_save_memory(
|
|
||||||
user_message=memory_seed,
|
|
||||||
user_id=user_id,
|
|
||||||
llm=llm,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
_background_tasks.add(task)
|
|
||||||
task.add_done_callback(_background_tasks.discard)
|
|
||||||
|
|
||||||
# Finish the step and message
|
# Finish the step and message
|
||||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
|
|
@ -2682,25 +2431,6 @@ async def stream_resume_chat(
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``:
|
|
||||||
# one cheap probe before the agent is rebuilt so a 429'd pin gets
|
|
||||||
# repinned without burning planner/classifier/title calls first.
|
|
||||||
# See ``stream_new_chat`` for the full rationale on the speculative
|
|
||||||
# parallel build pattern below.
|
|
||||||
preflight_needed = (
|
|
||||||
requested_llm_config_id == 0
|
|
||||||
and llm_config_id < 0
|
|
||||||
and not is_recently_healthy(llm_config_id)
|
|
||||||
)
|
|
||||||
preflight_task: asyncio.Task[None] | None = None
|
|
||||||
_t_preflight = 0.0
|
|
||||||
if preflight_needed:
|
|
||||||
_t_preflight = time.perf_counter()
|
|
||||||
preflight_task = asyncio.create_task(
|
|
||||||
_preflight_llm(llm),
|
|
||||||
name=f"auto_pin_preflight_resume:{llm_config_id}",
|
|
||||||
)
|
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||||
|
|
||||||
|
|
@ -2730,115 +2460,25 @@ async def stream_resume_chat(
|
||||||
if _app_config.MULTI_AGENT_CHAT_ENABLED
|
if _app_config.MULTI_AGENT_CHAT_ENABLED
|
||||||
else create_surfsense_deep_agent
|
else create_surfsense_deep_agent
|
||||||
)
|
)
|
||||||
agent_build_task = asyncio.create_task(
|
# Build the agent inline. Provider 429s are handled by the
|
||||||
_build_main_agent_for_thread(
|
# in-stream recovery loop, which repins to an eligible
|
||||||
agent_factory,
|
# alternative config and rebuilds the agent before the user sees
|
||||||
llm=llm,
|
# any output.
|
||||||
search_space_id=search_space_id,
|
agent = await _build_main_agent_for_thread(
|
||||||
db_session=session,
|
agent_factory,
|
||||||
connector_service=connector_service,
|
llm=llm,
|
||||||
checkpointer=checkpointer,
|
search_space_id=search_space_id,
|
||||||
user_id=user_id,
|
db_session=session,
|
||||||
thread_id=chat_id,
|
connector_service=connector_service,
|
||||||
agent_config=agent_config,
|
checkpointer=checkpointer,
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
user_id=user_id,
|
||||||
thread_visibility=visibility,
|
thread_id=chat_id,
|
||||||
filesystem_selection=filesystem_selection,
|
agent_config=agent_config,
|
||||||
disabled_tools=disabled_tools,
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
),
|
thread_visibility=visibility,
|
||||||
name="agent_build:stream_resume",
|
filesystem_selection=filesystem_selection,
|
||||||
|
disabled_tools=disabled_tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
agent: Any = None
|
|
||||||
if preflight_task is not None:
|
|
||||||
try:
|
|
||||||
await preflight_task
|
|
||||||
mark_healthy(llm_config_id)
|
|
||||||
_perf_log.info(
|
|
||||||
"[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
|
|
||||||
llm_config_id,
|
|
||||||
time.perf_counter() - _t_preflight,
|
|
||||||
)
|
|
||||||
except Exception as preflight_exc:
|
|
||||||
# Same session-safety rationale as ``stream_new_chat``.
|
|
||||||
await _settle_speculative_agent_build(agent_build_task)
|
|
||||||
if not _is_provider_rate_limited(preflight_exc):
|
|
||||||
raise
|
|
||||||
previous_config_id = llm_config_id
|
|
||||||
mark_runtime_cooldown(
|
|
||||||
previous_config_id, reason="preflight_rate_limited"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
llm_config_id = (
|
|
||||||
await resolve_or_get_pinned_llm_config_id(
|
|
||||||
session,
|
|
||||||
thread_id=chat_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
selected_llm_config_id=0,
|
|
||||||
exclude_config_ids={previous_config_id},
|
|
||||||
)
|
|
||||||
).resolved_llm_config_id
|
|
||||||
except ValueError as pin_error:
|
|
||||||
yield _emit_stream_error(
|
|
||||||
message=str(pin_error),
|
|
||||||
error_kind="server_error",
|
|
||||||
error_code="SERVER_ERROR",
|
|
||||||
)
|
|
||||||
yield streaming_service.format_done()
|
|
||||||
return
|
|
||||||
|
|
||||||
llm, agent_config, llm_load_error = await _load_llm_bundle(
|
|
||||||
llm_config_id
|
|
||||||
)
|
|
||||||
if llm_load_error or not llm:
|
|
||||||
yield _emit_stream_error(
|
|
||||||
message=llm_load_error or "Failed to create LLM instance",
|
|
||||||
error_kind="server_error",
|
|
||||||
error_code="SERVER_ERROR",
|
|
||||||
)
|
|
||||||
yield streaming_service.format_done()
|
|
||||||
return
|
|
||||||
mark_healthy(llm_config_id)
|
|
||||||
_log_chat_stream_error(
|
|
||||||
flow="resume",
|
|
||||||
error_kind="rate_limited",
|
|
||||||
error_code="RATE_LIMITED",
|
|
||||||
severity="info",
|
|
||||||
is_expected=True,
|
|
||||||
request_id=request_id,
|
|
||||||
thread_id=chat_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
message=(
|
|
||||||
"Auto-pinned model failed preflight; switched to another "
|
|
||||||
"eligible model and continuing."
|
|
||||||
),
|
|
||||||
extra={
|
|
||||||
"auto_runtime_recover": True,
|
|
||||||
"preflight": True,
|
|
||||||
"previous_config_id": previous_config_id,
|
|
||||||
"fallback_config_id": llm_config_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
agent = await _build_main_agent_for_thread(
|
|
||||||
agent_factory,
|
|
||||||
llm=llm,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
db_session=session,
|
|
||||||
connector_service=connector_service,
|
|
||||||
checkpointer=checkpointer,
|
|
||||||
user_id=user_id,
|
|
||||||
thread_id=chat_id,
|
|
||||||
agent_config=agent_config,
|
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
|
||||||
thread_visibility=visibility,
|
|
||||||
filesystem_selection=filesystem_selection,
|
|
||||||
disabled_tools=disabled_tools,
|
|
||||||
)
|
|
||||||
|
|
||||||
if agent is None:
|
|
||||||
agent = await agent_build_task
|
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -48,4 +48,3 @@ async def stream_output(
|
||||||
yield frame
|
yield frame
|
||||||
|
|
||||||
result.accumulated_text = state.accumulated_text
|
result.accumulated_text = state.accumulated_text
|
||||||
result.agent_called_update_memory = state.called_update_memory
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ class StreamingResult:
|
||||||
accumulated_text: str = ""
|
accumulated_text: str = ""
|
||||||
is_interrupted: bool = False
|
is_interrupted: bool = False
|
||||||
sandbox_files: list[str] = field(default_factory=list)
|
sandbox_files: list[str] = field(default_factory=list)
|
||||||
agent_called_update_memory: bool = False
|
|
||||||
request_id: str | None = None
|
request_id: str | None = None
|
||||||
turn_id: str = ""
|
turn_id: str = ""
|
||||||
filesystem_mode: str = "cloud"
|
filesystem_mode: str = "cloud"
|
||||||
|
|
|
||||||
|
|
@ -36,9 +36,6 @@ def iter_tool_end_frames(
|
||||||
raw_output = event.get("data", {}).get("output", "")
|
raw_output = event.get("data", {}).get("output", "")
|
||||||
staged_file_path = state.file_path_by_run.pop(run_id, None) if run_id else None
|
staged_file_path = state.file_path_by_run.pop(run_id, None) if run_id else None
|
||||||
|
|
||||||
if tool_name == "update_memory":
|
|
||||||
state.called_update_memory = True
|
|
||||||
|
|
||||||
if hasattr(raw_output, "content"):
|
if hasattr(raw_output, "content"):
|
||||||
content = raw_output.content
|
content = raw_output.content
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,6 @@ class AgentEventRelayState:
|
||||||
last_active_step_items: list[str] = field(default_factory=list)
|
last_active_step_items: list[str] = field(default_factory=list)
|
||||||
just_finished_tool: bool = False
|
just_finished_tool: bool = False
|
||||||
active_tool_depth: int = 0
|
active_tool_depth: int = 0
|
||||||
called_update_memory: bool = False
|
|
||||||
current_reasoning_id: str | None = None
|
current_reasoning_id: str | None = None
|
||||||
pending_tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
pending_tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
||||||
lc_tool_call_id_by_run: dict[str, str] = field(default_factory=dict)
|
lc_tool_call_id_by_run: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
|
||||||
|
|
@ -670,7 +670,9 @@ async def index_discord_messages(
|
||||||
|
|
||||||
# Heavy processing (embeddings, chunks)
|
# Heavy processing (embeddings, chunks)
|
||||||
chunks = await create_document_chunks(item["combined_document_string"])
|
chunks = await create_document_chunks(item["combined_document_string"])
|
||||||
doc_embedding = embed_text(item["combined_document_string"])
|
doc_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, item["combined_document_string"]
|
||||||
|
)
|
||||||
|
|
||||||
# Update document to READY with actual content
|
# Update document to READY with actual content
|
||||||
document.title = f"{item['guild_name']}#{item['channel_name']}"
|
document.title = f"{item['guild_name']}#{item['channel_name']}"
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ Implements 2-phase document status updates for real-time UI feedback:
|
||||||
- Phase 2: Process each event: pending → processing → ready/failed
|
- Phase 2: Process each event: pending → processing → ready/failed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
@ -465,7 +466,9 @@ async def index_luma_events(
|
||||||
summary_content = (
|
summary_content = (
|
||||||
f"Luma Event: {item['event_name']}\n\n{item['event_markdown']}"
|
f"Luma Event: {item['event_name']}\n\n{item['event_markdown']}"
|
||||||
)
|
)
|
||||||
summary_embedding = embed_text(summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, summary_content
|
||||||
|
)
|
||||||
|
|
||||||
chunks = await create_document_chunks(item["event_markdown"])
|
chunks = await create_document_chunks(item["event_markdown"])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ Uses 2-phase document status updates for real-time UI feedback:
|
||||||
- Phase 2: Process each document: pending → processing → ready/failed
|
- Phase 2: Process each document: pending → processing → ready/failed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
@ -581,7 +582,9 @@ async def index_teams_messages(
|
||||||
|
|
||||||
# Heavy processing (embeddings, chunks)
|
# Heavy processing (embeddings, chunks)
|
||||||
chunks = await create_document_chunks(item["combined_document_string"])
|
chunks = await create_document_chunks(item["combined_document_string"])
|
||||||
doc_embedding = embed_text(item["combined_document_string"])
|
doc_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, item["combined_document_string"]
|
||||||
|
)
|
||||||
|
|
||||||
# Update document to READY with actual content
|
# Update document to READY with actual content
|
||||||
document.title = f"{item['team_name']} - {item['channel_name']}"
|
document.title = f"{item['team_name']} - {item['channel_name']}"
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
Unified document save/update logic for file processors.
|
Unified document save/update logic for file processors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
@ -43,7 +44,7 @@ async def _generate_summary(
|
||||||
"""
|
"""
|
||||||
if not enable_summary:
|
if not enable_summary:
|
||||||
summary = f"File: {file_name}\n\n{markdown_content[:4000]}"
|
summary = f"File: {file_name}\n\n{markdown_content[:4000]}"
|
||||||
return summary, embed_text(summary)
|
return summary, await asyncio.to_thread(embed_text, summary)
|
||||||
|
|
||||||
if etl_service == "DOCLING":
|
if etl_service == "DOCLING":
|
||||||
from app.services.docling_service import create_docling_service
|
from app.services.docling_service import create_docling_service
|
||||||
|
|
@ -65,7 +66,7 @@ async def _generate_summary(
|
||||||
parts.append(f"**{formatted_key}:** {value}")
|
parts.append(f"**{formatted_key}:** {value}")
|
||||||
|
|
||||||
enhanced = "\n".join(parts) + "\n\n# DOCUMENT SUMMARY\n\n" + summary_text
|
enhanced = "\n".join(parts) + "\n\n# DOCUMENT SUMMARY\n\n" + summary_text
|
||||||
return enhanced, embed_text(enhanced)
|
return enhanced, await asyncio.to_thread(embed_text, enhanced)
|
||||||
|
|
||||||
# Standard summary (Unstructured / LlamaCloud / others)
|
# Standard summary (Unstructured / LlamaCloud / others)
|
||||||
meta = {
|
meta = {
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
|
@ -221,7 +222,9 @@ async def generate_document_summary(
|
||||||
else:
|
else:
|
||||||
enhanced_summary_content = summary_content
|
enhanced_summary_content = summary_content
|
||||||
|
|
||||||
summary_embedding = embed_text(enhanced_summary_content)
|
summary_embedding = await asyncio.to_thread(
|
||||||
|
embed_text, enhanced_summary_content
|
||||||
|
)
|
||||||
|
|
||||||
return enhanced_summary_content, summary_embedding
|
return enhanced_summary_content, summary_embedding
|
||||||
|
|
||||||
|
|
@ -237,7 +240,7 @@ async def create_document_chunks(content: str) -> list[Chunk]:
|
||||||
List of Chunk objects with embeddings
|
List of Chunk objects with embeddings
|
||||||
"""
|
"""
|
||||||
chunk_texts = [c.text for c in config.chunker_instance.chunk(content)]
|
chunk_texts = [c.text for c in config.chunker_instance.chunk(content)]
|
||||||
chunk_embeddings = embed_texts(chunk_texts)
|
chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
|
||||||
return [
|
return [
|
||||||
Chunk(content=text, embedding=emb)
|
Chunk(content=text, embedding=emb)
|
||||||
for text, emb in zip(chunk_texts, chunk_embeddings, strict=False)
|
for text, emb in zip(chunk_texts, chunk_embeddings, strict=False)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "surf-new-backend"
|
name = "surf-new-backend"
|
||||||
version = "0.0.24"
|
version = "0.0.25"
|
||||||
description = "SurfSense Backend"
|
description = "SurfSense Backend"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
|
|
||||||
|
|
@ -2,28 +2,12 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.agents.new_chat.tools.update_memory import _save_memory
|
from app.services.memory import MemoryScope, save_memory
|
||||||
from app.utils.content_utils import extract_text_content
|
from app.utils.content_utils import extract_text_content
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
class _Recorder:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.applied_content: str | None = None
|
|
||||||
self.commit_calls = 0
|
|
||||||
self.rollback_calls = 0
|
|
||||||
|
|
||||||
def apply(self, content: str) -> None:
|
|
||||||
self.applied_content = content
|
|
||||||
|
|
||||||
async def commit(self) -> None:
|
|
||||||
self.commit_calls += 1
|
|
||||||
|
|
||||||
async def rollback(self) -> None:
|
|
||||||
self.rollback_calls += 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_text_content_keeps_no_update_bare_string_from_content_blocks() -> None:
|
def test_extract_text_content_keeps_no_update_bare_string_from_content_blocks() -> None:
|
||||||
content = [
|
content = [
|
||||||
{"type": "thinking", "thinking": "No"},
|
{"type": "thinking", "thinking": "No"},
|
||||||
|
|
@ -69,21 +53,12 @@ def test_extract_text_content_preserves_plain_string_responses() -> None:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_save_memory_rejects_non_string_payload_before_commit() -> None:
|
async def test_save_memory_rejects_non_string_payload_before_commit() -> None:
|
||||||
recorder = _Recorder()
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.USER,
|
||||||
result = await _save_memory(
|
target_id="00000000-0000-0000-0000-000000000000",
|
||||||
updated_memory=["NO_UPDATE"], # type: ignore[arg-type]
|
content=["NO_UPDATE"], # type: ignore[arg-type]
|
||||||
old_memory=None,
|
session=None, # type: ignore[arg-type]
|
||||||
llm=None,
|
|
||||||
apply_fn=recorder.apply,
|
|
||||||
commit_fn=recorder.commit,
|
|
||||||
rollback_fn=recorder.rollback,
|
|
||||||
label="memory",
|
|
||||||
scope="user",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["status"] == "error"
|
assert result.status == "error"
|
||||||
assert "must be a string" in result["message"]
|
assert "must be a string" in result.message
|
||||||
assert recorder.applied_content is None
|
|
||||||
assert recorder.commit_calls == 0
|
|
||||||
assert recorder.rollback_calls == 0
|
|
||||||
|
|
|
||||||
|
|
@ -12,13 +12,19 @@ prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
|
||||||
the deepagent stack accumulates multiple ``SystemMessage``\ s in
|
the deepagent stack accumulates multiple ``SystemMessage``\ s in
|
||||||
``state["messages"]`` and ``role: system`` would tag every one of
|
``state["messages"]`` and ``role: system`` would tag every one of
|
||||||
them, blowing past Anthropic's 4-block ``cache_control`` cap.
|
them, blowing past Anthropic's 4-block ``cache_control`` cap.
|
||||||
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
|
2. Adds ``prompt_cache_key`` for OPENAI/DEEPSEEK/XAI/AZURE/AZURE_OPENAI
|
||||||
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
|
configs (Microsoft's Azure transformer was added to LiteLLM in
|
||||||
prompt-cache surface is available).
|
https://github.com/BerriAI/litellm/pull/20989, Feb 2026).
|
||||||
3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only — no
|
3. Adds ``prompt_cache_retention="24h"`` ONLY for OPENAI/DEEPSEEK/XAI.
|
||||||
OpenAI-only kwargs because the router fans out across providers.
|
Azure's server-side support landed in Microsoft's docs on 2026-05-13
|
||||||
4. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
|
but LiteLLM 1.83.14 hasn't wired it through yet, so we let Azure use
|
||||||
5. Defensive: LLMs without a writable ``model_kwargs`` are silently
|
its default in-memory retention rather than send a param that
|
||||||
|
``litellm.drop_params`` would silently strip.
|
||||||
|
4. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only — no
|
||||||
|
destination-specific kwargs because the router fans out across
|
||||||
|
providers.
|
||||||
|
5. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
|
||||||
|
6. Defensive: LLMs without a writable ``model_kwargs`` are silently
|
||||||
skipped rather than raising.
|
skipped rather than raising.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -191,9 +197,9 @@ def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None:
|
||||||
|
|
||||||
@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"])
|
@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"])
|
||||||
def test_sets_openai_family_extras(provider: str) -> None:
|
def test_sets_openai_family_extras(provider: str) -> None:
|
||||||
"""OpenAI-style providers gain ``prompt_cache_key`` (raises hit rate
|
"""Native OpenAI-style providers gain ``prompt_cache_key`` (raises
|
||||||
via routing affinity) and ``prompt_cache_retention="24h"`` (extends
|
hit rate via routing affinity) and ``prompt_cache_retention="24h"``
|
||||||
cache TTL beyond the default 5-10 min)."""
|
(extends cache TTL beyond the default 5-10 min)."""
|
||||||
cfg = _make_cfg(provider=provider)
|
cfg = _make_cfg(provider=provider)
|
||||||
llm = _FakeLLM()
|
llm = _FakeLLM()
|
||||||
|
|
||||||
|
|
@ -203,6 +209,27 @@ def test_sets_openai_family_extras(provider: str) -> None:
|
||||||
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
|
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("provider", ["AZURE", "AZURE_OPENAI"])
|
||||||
|
def test_azure_gets_prompt_cache_key_only(provider: str) -> None:
|
||||||
|
"""Azure configs gain ``prompt_cache_key`` for routing affinity
|
||||||
|
(Microsoft auto-caches every GPT-4o+ deployment at ≥1024 tokens;
|
||||||
|
the key clusters same-prefix requests on the same backend GPU pool
|
||||||
|
so hit rate climbs). They DO NOT get ``prompt_cache_retention``
|
||||||
|
because LiteLLM 1.83.14's Azure transformer omits it from its
|
||||||
|
supported params list — ``drop_params`` would silently strip it.
|
||||||
|
Azure's default in-memory retention (5-10 min, max 1 h) is already
|
||||||
|
enough to cover intra-conversation turns; revisit when LiteLLM
|
||||||
|
bumps Azure to match its OpenAI surface."""
|
||||||
|
cfg = _make_cfg(provider=provider, model_name="gpt-5.4")
|
||||||
|
llm = _FakeLLM(model="azure/gpt-5.4")
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
|
||||||
|
|
||||||
|
assert llm.model_kwargs["prompt_cache_key"] == "surfsense-thread-42"
|
||||||
|
assert "prompt_cache_retention" not in llm.model_kwargs
|
||||||
|
assert "cache_control_injection_points" in llm.model_kwargs
|
||||||
|
|
||||||
|
|
||||||
def test_skips_prompt_cache_key_when_no_thread_id() -> None:
|
def test_skips_prompt_cache_key_when_no_thread_id() -> None:
|
||||||
"""Without a thread id we can't construct a per-thread key. Retention
|
"""Without a thread id we can't construct a per-thread key. Retention
|
||||||
is still useful so we set it (it's free)."""
|
is still useful so we set it (it's free)."""
|
||||||
|
|
@ -215,12 +242,26 @@ def test_skips_prompt_cache_key_when_no_thread_id() -> None:
|
||||||
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
|
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_skips_prompt_cache_key_when_no_thread_id() -> None:
|
||||||
|
"""Azure without a thread id ends up with no extras (retention is
|
||||||
|
Azure-skipped, key needs a thread id) — universal injection points
|
||||||
|
still land."""
|
||||||
|
cfg = _make_cfg(provider="AZURE", model_name="gpt-5.4")
|
||||||
|
llm = _FakeLLM(model="azure/gpt-5.4")
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=None)
|
||||||
|
|
||||||
|
assert "prompt_cache_key" not in llm.model_kwargs
|
||||||
|
assert "prompt_cache_retention" not in llm.model_kwargs
|
||||||
|
assert "cache_control_injection_points" in llm.model_kwargs
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"provider",
|
"provider",
|
||||||
["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"],
|
["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"],
|
||||||
)
|
)
|
||||||
def test_no_openai_extras_for_other_providers(provider: str) -> None:
|
def test_no_openai_extras_for_other_providers(provider: str) -> None:
|
||||||
"""Non-OpenAI-family providers don't expose ``prompt_cache_key`` —
|
"""Non-OpenAI-style providers don't expose ``prompt_cache_key`` —
|
||||||
skip it. ``cache_control_injection_points`` is still set (universal)."""
|
skip it. ``cache_control_injection_points`` is still set (universal)."""
|
||||||
cfg = _make_cfg(provider=provider)
|
cfg = _make_cfg(provider=provider)
|
||||||
llm = _FakeLLM()
|
llm = _FakeLLM()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,130 @@
|
||||||
|
"""Unit tests for ``mcp_tools_cache``."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
CachedMCPToolDef,
|
||||||
|
CachedMCPTools,
|
||||||
|
read_cached_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _make_connector(config: dict | None) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(id=42, config=config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_config_is_none() -> None:
|
||||||
|
assert read_cached_tools(_make_connector(None)) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_cached_tools_missing() -> None:
|
||||||
|
assert read_cached_tools(_make_connector({"server_config": {}})) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_cached_tools_is_not_a_dict() -> None:
|
||||||
|
assert read_cached_tools(_make_connector({"cached_tools": []})) is None
|
||||||
|
assert read_cached_tools(_make_connector({"cached_tools": "stale"})) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_parses_minimal_valid_payload() -> None:
|
||||||
|
parsed = read_cached_tools(
|
||||||
|
_make_connector(
|
||||||
|
{
|
||||||
|
"cached_tools": {
|
||||||
|
"discovered_at": "2026-05-20T10:00:00+00:00",
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"name": "list_issues",
|
||||||
|
"description": "List Linear issues",
|
||||||
|
"input_schema": {"type": "object"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert parsed is not None
|
||||||
|
assert parsed.server_version is None
|
||||||
|
assert parsed.server_name is None
|
||||||
|
assert parsed.transport is None
|
||||||
|
assert len(parsed.tools) == 1
|
||||||
|
assert parsed.tools[0].name == "list_issues"
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_parses_full_payload_with_serverinfo() -> None:
|
||||||
|
parsed = read_cached_tools(
|
||||||
|
_make_connector(
|
||||||
|
{
|
||||||
|
"cached_tools": {
|
||||||
|
"discovered_at": "2026-05-20T10:00:00+00:00",
|
||||||
|
"server_version": "1.2.3",
|
||||||
|
"server_name": "atlassian-mcp",
|
||||||
|
"transport": "streamable-http",
|
||||||
|
"tools": [
|
||||||
|
{"name": "create_issue", "input_schema": {}},
|
||||||
|
{"name": "list_issues", "input_schema": {}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert parsed is not None
|
||||||
|
assert parsed.server_version == "1.2.3"
|
||||||
|
assert parsed.server_name == "atlassian-mcp"
|
||||||
|
assert parsed.transport == "streamable-http"
|
||||||
|
assert [t.name for t in parsed.tools] == ["create_issue", "list_issues"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_for_corrupt_payload(caplog) -> None:
|
||||||
|
parsed = read_cached_tools(
|
||||||
|
_make_connector(
|
||||||
|
{
|
||||||
|
"cached_tools": {
|
||||||
|
"discovered_at": "not-a-date",
|
||||||
|
"tools": "should-be-a-list",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert parsed is None
|
||||||
|
assert any("corrupt cached_tools" in r.getMessage() for r in caplog.records)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_tools_missing() -> None:
|
||||||
|
parsed = read_cached_tools(
|
||||||
|
_make_connector(
|
||||||
|
{"cached_tools": {"discovered_at": "2026-05-20T10:00:00+00:00"}}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert parsed is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_def_defaults_description_and_schema() -> None:
|
||||||
|
td = CachedMCPToolDef.model_validate({"name": "ping"})
|
||||||
|
assert td.description == ""
|
||||||
|
assert td.input_schema == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_dump_json_mode_is_round_trippable() -> None:
|
||||||
|
original = CachedMCPTools(
|
||||||
|
discovered_at=datetime(2026, 5, 20, 10, 0, 0, tzinfo=UTC),
|
||||||
|
server_version="1.2.3",
|
||||||
|
server_name="atlassian-mcp",
|
||||||
|
transport="streamable-http",
|
||||||
|
tools=[CachedMCPToolDef(name="list_issues")],
|
||||||
|
)
|
||||||
|
payload = original.model_dump(mode="json")
|
||||||
|
|
||||||
|
assert payload["discovered_at"] == "2026-05-20T10:00:00Z"
|
||||||
|
assert payload["tools"][0]["name"] == "list_issues"
|
||||||
|
|
||||||
|
reparsed = CachedMCPTools.model_validate(payload)
|
||||||
|
assert reparsed.discovered_at == original.discovered_at
|
||||||
|
assert reparsed.tools[0].name == "list_issues"
|
||||||
|
|
@ -1,24 +1,24 @@
|
||||||
"""Unit tests for memory scope validation and bullet format validation."""
|
"""Unit tests for heading-based memory validation."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.agents.new_chat.tools.update_memory import (
|
from app.services.memory import MemoryScope, save_memory
|
||||||
_save_memory,
|
from app.services.memory.validation import (
|
||||||
_validate_bullet_format,
|
validate_bullet_format,
|
||||||
_validate_memory_scope,
|
validate_memory_scope,
|
||||||
)
|
)
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
class _Recorder:
|
class _FakeSession:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.applied_content: str | None = None
|
self.added = []
|
||||||
self.commit_calls = 0
|
self.commit_calls = 0
|
||||||
self.rollback_calls = 0
|
self.rollback_calls = 0
|
||||||
|
|
||||||
def apply(self, content: str) -> None:
|
def add(self, obj) -> None:
|
||||||
self.applied_content = content
|
self.added.append(obj)
|
||||||
|
|
||||||
async def commit(self) -> None:
|
async def commit(self) -> None:
|
||||||
self.commit_calls += 1
|
self.commit_calls += 1
|
||||||
|
|
@ -27,172 +27,148 @@ class _Recorder:
|
||||||
self.rollback_calls += 1
|
self.rollback_calls += 1
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
def test_validate_memory_scope_rejects_new_personal_heading_in_team() -> None:
|
||||||
# _validate_memory_scope — marker-based
|
content = "## Preferences\n- 2026-04-10: Prefers dark mode\n"
|
||||||
# ---------------------------------------------------------------------------
|
result, _warnings = validate_memory_scope(content, "team")
|
||||||
|
|
||||||
|
|
||||||
def test_validate_memory_scope_rejects_pref_marker_in_team_scope() -> None:
|
|
||||||
content = "- (2026-04-10) [pref] Prefers dark mode\n"
|
|
||||||
result = _validate_memory_scope(content, "team")
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["status"] == "error"
|
assert result["status"] == "error"
|
||||||
assert "[pref]" in result["message"]
|
assert "preferences" in result["message"]
|
||||||
|
|
||||||
|
|
||||||
def test_validate_memory_scope_rejects_instr_marker_in_team_scope() -> None:
|
def test_validate_memory_scope_allows_old_marker_payload_in_team_scope() -> None:
|
||||||
content = "- (2026-04-10) [instr] Always respond in Spanish\n"
|
content = "- (2026-04-10) [pref] Legacy personal marker remains readable\n"
|
||||||
result = _validate_memory_scope(content, "team")
|
result, _warnings = validate_memory_scope(content, "team")
|
||||||
assert result is not None
|
assert result is None
|
||||||
assert result["status"] == "error"
|
|
||||||
assert "[instr]" in result["message"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_memory_scope_rejects_both_personal_markers_in_team() -> None:
|
def test_validate_memory_scope_allows_team_headings() -> None:
|
||||||
|
content = "## Engineering Conventions\n- 2026-04-10: Uses PostgreSQL\n"
|
||||||
|
result, _warnings = validate_memory_scope(content, "team")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_bullet_format_accepts_new_and_legacy_bullets() -> None:
|
||||||
content = (
|
content = (
|
||||||
"- (2026-04-10) [pref] Prefers dark mode\n"
|
"## Facts\n"
|
||||||
"- (2026-04-10) [instr] Always respond in Spanish\n"
|
"- 2026-04-10: Senior Python developer\n"
|
||||||
|
"- (2026-04-10) [fact] Legacy fact is preserved\n"
|
||||||
)
|
)
|
||||||
result = _validate_memory_scope(content, "team")
|
warnings = validate_bullet_format(content)
|
||||||
assert result is not None
|
|
||||||
assert result["status"] == "error"
|
|
||||||
assert "[instr]" in result["message"]
|
|
||||||
assert "[pref]" in result["message"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_memory_scope_allows_fact_in_team_scope() -> None:
|
|
||||||
content = "- (2026-04-10) [fact] Office is in downtown Seattle\n"
|
|
||||||
result = _validate_memory_scope(content, "team")
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_memory_scope_allows_all_markers_in_user_scope() -> None:
|
|
||||||
content = (
|
|
||||||
"- (2026-04-10) [fact] Python developer\n"
|
|
||||||
"- (2026-04-10) [pref] Prefers concise answers\n"
|
|
||||||
"- (2026-04-10) [instr] Always use bullet points\n"
|
|
||||||
)
|
|
||||||
result = _validate_memory_scope(content, "user")
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_memory_scope_allows_any_heading_in_team() -> None:
|
|
||||||
content = "## Architecture\n- (2026-04-10) [fact] Uses PostgreSQL for persistence\n"
|
|
||||||
result = _validate_memory_scope(content, "team")
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_memory_scope_allows_any_heading_in_user() -> None:
|
|
||||||
content = "## My Projects\n- (2026-04-10) [fact] Working on SurfSense\n"
|
|
||||||
result = _validate_memory_scope(content, "user")
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _validate_bullet_format
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_bullet_format_passes_valid_bullets() -> None:
|
|
||||||
content = (
|
|
||||||
"## Work\n"
|
|
||||||
"- (2026-04-10) [fact] Senior Python developer\n"
|
|
||||||
"- (2026-04-10) [pref] Prefers dark mode\n"
|
|
||||||
"- (2026-04-10) [instr] Always respond in bullet points\n"
|
|
||||||
)
|
|
||||||
warnings = _validate_bullet_format(content)
|
|
||||||
assert warnings == []
|
assert warnings == []
|
||||||
|
|
||||||
|
|
||||||
def test_validate_bullet_format_warns_on_missing_marker() -> None:
|
def test_validate_bullet_format_warns_on_nonstandard_bullet() -> None:
|
||||||
content = "- (2026-04-10) Senior Python developer\n"
|
content = "## Facts\n- Senior Python developer\n"
|
||||||
warnings = _validate_bullet_format(content)
|
warnings = validate_bullet_format(content)
|
||||||
assert len(warnings) == 1
|
assert len(warnings) == 1
|
||||||
assert "Malformed bullet" in warnings[0]
|
assert "Non-standard memory bullet" in warnings[0]
|
||||||
|
|
||||||
|
|
||||||
def test_validate_bullet_format_warns_on_missing_date() -> None:
|
|
||||||
content = "- [fact] Senior Python developer\n"
|
|
||||||
warnings = _validate_bullet_format(content)
|
|
||||||
assert len(warnings) == 1
|
|
||||||
assert "Malformed bullet" in warnings[0]
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_bullet_format_warns_on_unknown_marker() -> None:
|
|
||||||
content = "- (2026-04-10) [context] Working on project X\n"
|
|
||||||
warnings = _validate_bullet_format(content)
|
|
||||||
assert len(warnings) == 1
|
|
||||||
assert "Malformed bullet" in warnings[0]
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_bullet_format_ignores_non_bullet_lines() -> None:
|
|
||||||
content = "## Some Heading\nSome paragraph text\n"
|
|
||||||
warnings = _validate_bullet_format(content)
|
|
||||||
assert warnings == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_bullet_format_warns_on_old_format_without_marker() -> None:
|
|
||||||
content = "## About the user\n- (2026-04-10) Likes cats\n"
|
|
||||||
warnings = _validate_bullet_format(content)
|
|
||||||
assert len(warnings) == 1
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _save_memory — end-to-end with marker scope check
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_save_memory_blocks_pref_in_team_before_commit() -> None:
|
async def test_save_memory_normalizes_legacy_marker_bullets(monkeypatch) -> None:
|
||||||
recorder = _Recorder()
|
target = type("Target", (), {"memory_md": ""})()
|
||||||
result = await _save_memory(
|
session = _FakeSession()
|
||||||
updated_memory="- (2026-04-10) [pref] Prefers dark mode\n",
|
|
||||||
old_memory=None,
|
async def fake_load_target(**_kwargs):
|
||||||
llm=None,
|
return target
|
||||||
apply_fn=recorder.apply,
|
|
||||||
commit_fn=recorder.commit,
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
rollback_fn=recorder.rollback,
|
|
||||||
label="team memory",
|
result = await save_memory(
|
||||||
scope="team",
|
scope=MemoryScope.USER,
|
||||||
|
target_id="00000000-0000-0000-0000-000000000000",
|
||||||
|
content="- (2026-04-10) [fact] Legacy fact is preserved\n",
|
||||||
|
session=session,
|
||||||
)
|
)
|
||||||
assert result["status"] == "error"
|
|
||||||
assert recorder.commit_calls == 0
|
assert result.status == "saved"
|
||||||
assert recorder.applied_content is None
|
assert target.memory_md == "## Memory\n- 2026-04-10: Legacy fact is preserved"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_save_memory_allows_fact_in_team_and_commits() -> None:
|
async def test_save_memory_blocks_new_personal_heading_in_team_before_commit(
|
||||||
recorder = _Recorder()
|
monkeypatch,
|
||||||
content = "- (2026-04-10) [fact] Weekly standup on Mondays\n"
|
) -> None:
|
||||||
result = await _save_memory(
|
target = type("Target", (), {"shared_memory_md": ""})()
|
||||||
updated_memory=content,
|
session = _FakeSession()
|
||||||
old_memory=None,
|
|
||||||
llm=None,
|
async def fake_load_target(**_kwargs):
|
||||||
apply_fn=recorder.apply,
|
return target
|
||||||
commit_fn=recorder.commit,
|
|
||||||
rollback_fn=recorder.rollback,
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
label="team memory",
|
|
||||||
scope="team",
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.TEAM,
|
||||||
|
target_id=1,
|
||||||
|
content="## Preferences\n- 2026-04-10: Prefers dark mode\n",
|
||||||
|
session=session,
|
||||||
)
|
)
|
||||||
assert result["status"] == "saved"
|
assert result.status == "error"
|
||||||
assert recorder.commit_calls == 1
|
assert session.commit_calls == 0
|
||||||
assert recorder.applied_content == content
|
assert target.shared_memory_md == ""
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_save_memory_includes_format_warnings() -> None:
|
async def test_save_memory_allows_grandfathered_personal_heading_in_team(
|
||||||
recorder = _Recorder()
|
monkeypatch,
|
||||||
content = "- (2026-04-10) Missing marker text\n"
|
) -> None:
|
||||||
result = await _save_memory(
|
content = "## Preferences\n- 2026-04-10: Prefers dark mode\n"
|
||||||
updated_memory=content,
|
target = type("Target", (), {"shared_memory_md": content})()
|
||||||
old_memory=None,
|
session = _FakeSession()
|
||||||
llm=None,
|
|
||||||
apply_fn=recorder.apply,
|
async def fake_load_target(**_kwargs):
|
||||||
commit_fn=recorder.commit,
|
return target
|
||||||
rollback_fn=recorder.rollback,
|
|
||||||
label="memory",
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
scope="user",
|
|
||||||
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.TEAM,
|
||||||
|
target_id=1,
|
||||||
|
content=content,
|
||||||
|
session=session,
|
||||||
)
|
)
|
||||||
assert result["status"] == "saved"
|
assert result.status == "saved"
|
||||||
assert "format_warnings" in result
|
assert session.commit_calls == 1
|
||||||
assert len(result["format_warnings"]) == 1
|
assert target.shared_memory_md == content.strip()
|
||||||
|
assert result.warnings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_strips_preamble_before_heading(monkeypatch) -> None:
|
||||||
|
target = type("Target", (), {"memory_md": ""})()
|
||||||
|
session = _FakeSession()
|
||||||
|
|
||||||
|
async def fake_load_target(**_kwargs):
|
||||||
|
return target
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
|
|
||||||
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.USER,
|
||||||
|
target_id="00000000-0000-0000-0000-000000000000",
|
||||||
|
content="Sure, here is the update:\n\n## Facts\n- 2026-04-10: Likes cats\n",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
assert result.status == "saved"
|
||||||
|
assert target.memory_md == "## Facts\n- 2026-04-10: Likes cats"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_rejects_long_no_heading_payload(monkeypatch) -> None:
|
||||||
|
target = type("Target", (), {"memory_md": ""})()
|
||||||
|
session = _FakeSession()
|
||||||
|
|
||||||
|
async def fake_load_target(**_kwargs):
|
||||||
|
return target
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
|
|
||||||
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.USER,
|
||||||
|
target_id="00000000-0000-0000-0000-000000000000",
|
||||||
|
content="NO_UPDATE because there is nothing durable to remember.",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
assert result.status == "error"
|
||||||
|
assert "## heading" in result.message
|
||||||
|
assert session.commit_calls == 0
|
||||||
|
|
|
||||||
187
surfsense_backend/tests/unit/services/test_memory_service.py
Normal file
187
surfsense_backend/tests/unit/services/test_memory_service.py
Normal file
|
|
@ -0,0 +1,187 @@
|
||||||
|
"""Unit tests for the first-class memory service."""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.memory import (
|
||||||
|
MemoryScope,
|
||||||
|
reset_memory,
|
||||||
|
save_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.commit_calls = 0
|
||||||
|
self.rollback_calls = 0
|
||||||
|
self.added = []
|
||||||
|
|
||||||
|
def add(self, obj) -> None:
|
||||||
|
self.added.append(obj)
|
||||||
|
|
||||||
|
async def commit(self) -> None:
|
||||||
|
self.commit_calls += 1
|
||||||
|
|
||||||
|
async def rollback(self) -> None:
|
||||||
|
self.rollback_calls += 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_saves_heading_based_memory(monkeypatch) -> None:
|
||||||
|
target = SimpleNamespace(memory_md="")
|
||||||
|
session = _FakeSession()
|
||||||
|
|
||||||
|
async def fake_load_target(**_kwargs):
|
||||||
|
return target
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
|
|
||||||
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.USER,
|
||||||
|
target_id="00000000-0000-0000-0000-000000000000",
|
||||||
|
content="## Facts\n- 2026-05-19: Anish works on SurfSense\n",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == "saved"
|
||||||
|
assert target.memory_md.startswith("## Facts")
|
||||||
|
assert session.commit_calls == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_accepts_legacy_marker_payload(monkeypatch) -> None:
|
||||||
|
target = SimpleNamespace(memory_md="")
|
||||||
|
session = _FakeSession()
|
||||||
|
|
||||||
|
async def fake_load_target(**_kwargs):
|
||||||
|
return target
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
|
|
||||||
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.USER,
|
||||||
|
target_id="00000000-0000-0000-0000-000000000000",
|
||||||
|
content="- (2026-05-19) [fact] Legacy marker memory\n",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == "saved"
|
||||||
|
assert target.memory_md == "## Memory\n- 2026-05-19: Legacy marker memory"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_rejects_long_no_heading_payload(monkeypatch) -> None:
|
||||||
|
target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
|
||||||
|
session = _FakeSession()
|
||||||
|
|
||||||
|
async def fake_load_target(**_kwargs):
|
||||||
|
return target
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
|
|
||||||
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.USER,
|
||||||
|
target_id="00000000-0000-0000-0000-000000000000",
|
||||||
|
content="reasoning text before NO_UPDATE should not become saved memory",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == "error"
|
||||||
|
assert session.commit_calls == 0
|
||||||
|
assert target.memory_md.startswith("## Facts")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_no_update_sentinel_is_no_op(monkeypatch) -> None:
|
||||||
|
existing = "## Preferences\n- 2026-05-20: Existing preference\n"
|
||||||
|
target = SimpleNamespace(memory_md=existing)
|
||||||
|
session = _FakeSession()
|
||||||
|
|
||||||
|
async def fake_load_target(**_kwargs):
|
||||||
|
return target
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
|
|
||||||
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.USER,
|
||||||
|
target_id="00000000-0000-0000-0000-000000000000",
|
||||||
|
content="NO_UPDATE",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == "no_op"
|
||||||
|
assert result.memory_md == existing
|
||||||
|
assert target.memory_md == existing
|
||||||
|
assert session.commit_calls == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_no_update_sentinel_is_case_insensitive(monkeypatch) -> None:
|
||||||
|
existing = "## Preferences\n- 2026-05-20: Existing preference\n"
|
||||||
|
target = SimpleNamespace(memory_md=existing)
|
||||||
|
session = _FakeSession()
|
||||||
|
|
||||||
|
async def fake_load_target(**_kwargs):
|
||||||
|
return target
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
|
|
||||||
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.USER,
|
||||||
|
target_id="00000000-0000-0000-0000-000000000000",
|
||||||
|
content=" no update ",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == "no_op"
|
||||||
|
assert result.memory_md == existing
|
||||||
|
assert target.memory_md == existing
|
||||||
|
assert session.commit_calls == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_grandfathers_existing_team_personal_heading(
|
||||||
|
monkeypatch,
|
||||||
|
) -> None:
|
||||||
|
content = "## Preferences\n- 2026-05-19: Existing legacy heading\n"
|
||||||
|
target = SimpleNamespace(shared_memory_md=content)
|
||||||
|
session = _FakeSession()
|
||||||
|
|
||||||
|
async def fake_load_target(**_kwargs):
|
||||||
|
return target
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
|
|
||||||
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.TEAM,
|
||||||
|
target_id=1,
|
||||||
|
content=content,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == "saved"
|
||||||
|
assert result.warnings
|
||||||
|
assert session.commit_calls == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset_memory_clears_memory(monkeypatch) -> None:
|
||||||
|
target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
|
||||||
|
session = _FakeSession()
|
||||||
|
|
||||||
|
async def fake_load_target(**_kwargs):
|
||||||
|
return target
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
|
|
||||||
|
result = await reset_memory(
|
||||||
|
scope=MemoryScope.USER,
|
||||||
|
target_id="00000000-0000-0000-0000-000000000000",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == "saved"
|
||||||
|
assert target.memory_md == ""
|
||||||
|
|
@ -89,7 +89,6 @@ async def test_stream_output_emits_text_lifecycle_and_updates_result() -> None:
|
||||||
"text_end:text-1",
|
"text_end:text-1",
|
||||||
]
|
]
|
||||||
assert result.accumulated_text == "Hello world"
|
assert result.accumulated_text == "Hello world"
|
||||||
assert result.agent_called_update_memory is False
|
|
||||||
|
|
||||||
|
|
||||||
async def test_stream_output_passes_runtime_context_to_agent() -> None:
|
async def test_stream_output_passes_runtime_context_to_agent() -> None:
|
||||||
|
|
|
||||||
|
|
@ -209,128 +209,6 @@ def test_stream_exception_classifies_openrouter_429_payload():
|
||||||
assert extra is None
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch):
|
|
||||||
"""``_preflight_llm`` is best-effort.
|
|
||||||
|
|
||||||
- On rate-limit shaped exceptions (provider 429) it MUST re-raise so the
|
|
||||||
caller can drive the cooldown/repin branch.
|
|
||||||
- On any other transient failure it MUST swallow the error so the normal
|
|
||||||
stream path continues without surfacing preflight noise to the user.
|
|
||||||
"""
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from app.tasks.chat.stream_new_chat import _preflight_llm
|
|
||||||
|
|
||||||
class _RateLimitedError(Exception):
|
|
||||||
"""Class-name carries 'RateLimit' so _is_provider_rate_limited triggers."""
|
|
||||||
|
|
||||||
rate_calls: list[dict] = []
|
|
||||||
other_calls: list[dict] = []
|
|
||||||
|
|
||||||
async def _fake_acompletion_429(**kwargs):
|
|
||||||
rate_calls.append(kwargs)
|
|
||||||
raise _RateLimitedError("simulated 429")
|
|
||||||
|
|
||||||
async def _fake_acompletion_other(**kwargs):
|
|
||||||
other_calls.append(kwargs)
|
|
||||||
raise RuntimeError("some unrelated transient failure")
|
|
||||||
|
|
||||||
fake_llm = SimpleNamespace(
|
|
||||||
model="openrouter/google/gemma-4-31b-it:free",
|
|
||||||
api_key="test",
|
|
||||||
api_base=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
import litellm # type: ignore[import-not-found]
|
|
||||||
|
|
||||||
monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429)
|
|
||||||
with pytest.raises(_RateLimitedError):
|
|
||||||
await _preflight_llm(fake_llm)
|
|
||||||
assert len(rate_calls) == 1
|
|
||||||
assert rate_calls[0]["max_tokens"] == 1
|
|
||||||
assert rate_calls[0]["stream"] is False
|
|
||||||
|
|
||||||
monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_other)
|
|
||||||
# MUST NOT raise: non-rate-limit failures are swallowed.
|
|
||||||
await _preflight_llm(fake_llm)
|
|
||||||
assert len(other_calls) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_preflight_skipped_for_auto_router_model():
|
|
||||||
"""Router-mode ``model='auto'`` has no single deployment to ping; the
|
|
||||||
LiteLLM router itself owns per-deployment rate-limit accounting, so the
|
|
||||||
preflight helper must short-circuit instead of issuing a probe."""
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from app.tasks.chat.stream_new_chat import _preflight_llm
|
|
||||||
|
|
||||||
fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None)
|
|
||||||
# Should return without raising or making any LiteLLM call.
|
|
||||||
await _preflight_llm(fake_llm)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_settle_speculative_agent_build_swallows_exceptions():
|
|
||||||
"""``_settle_speculative_agent_build`` MUST always return cleanly so the
|
|
||||||
caller can safely re-touch the request-scoped session afterwards.
|
|
||||||
|
|
||||||
The helper guards the parallel preflight + agent-build path: when the
|
|
||||||
speculative build is being discarded (429 or non-429 preflight failure)
|
|
||||||
we await it solely to release any in-flight ``AsyncSession`` usage —
|
|
||||||
the build's outcome is irrelevant. Any exception (including
|
|
||||||
``CancelledError``) leaking out would skip the caller's recovery flow
|
|
||||||
and re-introduce the very session-concurrency hazard the helper exists
|
|
||||||
to prevent.
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build
|
|
||||||
|
|
||||||
async def _raises() -> None:
|
|
||||||
raise RuntimeError("speculative build crashed")
|
|
||||||
|
|
||||||
async def _succeeds() -> str:
|
|
||||||
return "agent"
|
|
||||||
|
|
||||||
async def _slow() -> None:
|
|
||||||
await asyncio.sleep(0.05)
|
|
||||||
|
|
||||||
for coro in (_raises(), _succeeds(), _slow()):
|
|
||||||
task = asyncio.create_task(coro)
|
|
||||||
await _settle_speculative_agent_build(task)
|
|
||||||
assert task.done()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_settle_speculative_agent_build_handles_already_done_task():
|
|
||||||
"""Done tasks (success or failure) must still be settled without raising."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build
|
|
||||||
|
|
||||||
async def _ok() -> str:
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
async def _bad() -> None:
|
|
||||||
raise ValueError("nope")
|
|
||||||
|
|
||||||
ok_task = asyncio.create_task(_ok())
|
|
||||||
bad_task = asyncio.create_task(_bad())
|
|
||||||
# Drive both to completion before settling.
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
|
|
||||||
await _settle_speculative_agent_build(ok_task)
|
|
||||||
await _settle_speculative_agent_build(bad_task)
|
|
||||||
assert ok_task.result() == "ok"
|
|
||||||
# ``bad_task`` exception was consumed by the settle helper; calling
|
|
||||||
# ``.exception()`` after the fact must still return the original error
|
|
||||||
# (the helper observes it but doesn't clear it).
|
|
||||||
assert isinstance(bad_task.exception(), ValueError)
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_exception_classifies_thread_busy():
|
def test_stream_exception_classifies_thread_busy():
|
||||||
exc = BusyError(request_id="thread-123")
|
exc = BusyError(request_id="thread-123")
|
||||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
|
|
|
||||||
2
surfsense_backend/uv.lock
generated
2
surfsense_backend/uv.lock
generated
|
|
@ -7947,7 +7947,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "surf-new-backend"
|
name = "surf-new-backend"
|
||||||
version = "0.0.24"
|
version = "0.0.25"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "alembic" },
|
{ name = "alembic" },
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
{
|
{
|
||||||
"name": "surfsense_browser_extension",
|
"name": "surfsense_browser_extension",
|
||||||
"displayName": "Surfsense Browser Extension",
|
"displayName": "Surfsense Browser Extension",
|
||||||
"version": "0.0.24",
|
"version": "0.0.25",
|
||||||
"description": "Extension to collect Browsing History for SurfSense.",
|
"description": "Extension to collect Browsing History for SurfSense.",
|
||||||
"author": "https://github.com/MODSetter",
|
"author": "https://github.com/MODSetter",
|
||||||
"engines": {
|
"engines": {
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"name": "surfsense-desktop",
|
"name": "surfsense-desktop",
|
||||||
"version": "0.0.24",
|
"version": "0.0.25",
|
||||||
"description": "SurfSense Desktop App",
|
"description": "SurfSense Desktop App",
|
||||||
"main": "dist/main.js",
|
"main": "dist/main.js",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import { Logo } from "@/components/Logo";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { trackLoginAttempt } from "@/lib/posthog/events";
|
import { trackLoginAttempt } from "@/lib/posthog/events";
|
||||||
import { AmbientBackground } from "./AmbientBackground";
|
import { AmbientBackground } from "./AmbientBackground";
|
||||||
|
import { BACKEND_URL } from "@/lib/env-config";
|
||||||
|
|
||||||
function GoogleGLogo({ className }: { className?: string }) {
|
function GoogleGLogo({ className }: { className?: string }) {
|
||||||
return (
|
return (
|
||||||
|
|
@ -50,7 +51,7 @@ export function GoogleLoginButton() {
|
||||||
// cross-origin fetch requests may not be sent on subsequent redirects.
|
// cross-origin fetch requests may not be sent on subsequent redirects.
|
||||||
// The authorize-redirect endpoint does a server-side redirect to Google
|
// The authorize-redirect endpoint does a server-side redirect to Google
|
||||||
// and sets the CSRF cookie properly for same-site context.
|
// and sets the CSRF cookie properly for same-site context.
|
||||||
window.location.href = `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/auth/google/authorize-redirect`;
|
window.location.href = `${BACKEND_URL}/auth/google/authorize-redirect`;
|
||||||
};
|
};
|
||||||
return (
|
return (
|
||||||
<div className="relative w-full overflow-hidden">
|
<div className="relative w-full overflow-hidden">
|
||||||
|
|
|
||||||
|
|
@ -4,11 +4,9 @@ import { NextResponse } from "next/server";
|
||||||
import type { Context } from "@/types/zero";
|
import type { Context } from "@/types/zero";
|
||||||
import { queries } from "@/zero/queries";
|
import { queries } from "@/zero/queries";
|
||||||
import { schema } from "@/zero/schema";
|
import { schema } from "@/zero/schema";
|
||||||
|
import { BACKEND_URL } from "@/lib/env-config";
|
||||||
|
|
||||||
const backendURL =
|
const backendURL = BACKEND_URL;
|
||||||
process.env.FASTAPI_BACKEND_INTERNAL_URL ||
|
|
||||||
process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL ||
|
|
||||||
"http://localhost:8000";
|
|
||||||
|
|
||||||
async function authenticateRequest(
|
async function authenticateRequest(
|
||||||
request: Request
|
request: Request
|
||||||
|
|
|
||||||
|
|
@ -118,7 +118,7 @@ import {
|
||||||
trackChatResponseReceived,
|
trackChatResponseReceived,
|
||||||
} from "@/lib/posthog/events";
|
} from "@/lib/posthog/events";
|
||||||
import Loading from "../loading";
|
import Loading from "../loading";
|
||||||
|
import { BACKEND_URL } from "@/lib/env-config";
|
||||||
const MobileEditorPanel = dynamic(
|
const MobileEditorPanel = dynamic(
|
||||||
() =>
|
() =>
|
||||||
import("@/components/editor-panel/editor-panel").then((m) => ({
|
import("@/components/editor-panel/editor-panel").then((m) => ({
|
||||||
|
|
@ -777,7 +777,7 @@ export default function NewChatPage() {
|
||||||
if (threadId) {
|
if (threadId) {
|
||||||
const token = getBearerToken();
|
const token = getBearerToken();
|
||||||
if (token) {
|
if (token) {
|
||||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
const backendUrl = BACKEND_URL;
|
||||||
try {
|
try {
|
||||||
const response = await fetch(
|
const response = await fetch(
|
||||||
`${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`,
|
`${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`,
|
||||||
|
|
@ -978,7 +978,7 @@ export default function NewChatPage() {
|
||||||
let streamBatcher: FrameBatchedUpdater | null = null;
|
let streamBatcher: FrameBatchedUpdater | null = null;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
const backendUrl = BACKEND_URL;
|
||||||
const selection = await getAgentFilesystemSelection(searchSpaceId, {
|
const selection = await getAgentFilesystemSelection(searchSpaceId, {
|
||||||
localFilesystemEnabled,
|
localFilesystemEnabled,
|
||||||
});
|
});
|
||||||
|
|
@ -1520,7 +1520,7 @@ export default function NewChatPage() {
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
const backendUrl = BACKEND_URL;
|
||||||
const selection = await getAgentFilesystemSelection(searchSpaceId, {
|
const selection = await getAgentFilesystemSelection(searchSpaceId, {
|
||||||
localFilesystemEnabled,
|
localFilesystemEnabled,
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@
|
||||||
import {
|
import {
|
||||||
BookText,
|
BookText,
|
||||||
Bot,
|
Bot,
|
||||||
Brain,
|
|
||||||
CircleUser,
|
CircleUser,
|
||||||
Earth,
|
Earth,
|
||||||
ImageIcon,
|
ImageIcon,
|
||||||
|
|
@ -27,7 +26,6 @@ export type SearchSpaceSettingsTab =
|
||||||
| "vision-models"
|
| "vision-models"
|
||||||
| "team-roles"
|
| "team-roles"
|
||||||
| "prompts"
|
| "prompts"
|
||||||
| "team-memory"
|
|
||||||
| "public-links";
|
| "public-links";
|
||||||
|
|
||||||
const DEFAULT_TAB: SearchSpaceSettingsTab = "general";
|
const DEFAULT_TAB: SearchSpaceSettingsTab = "general";
|
||||||
|
|
@ -89,11 +87,6 @@ export function SearchSpaceSettingsLayoutShell({
|
||||||
label: t("nav_system_instructions"),
|
label: t("nav_system_instructions"),
|
||||||
icon: <BookText className="h-4 w-4" />,
|
icon: <BookText className="h-4 w-4" />,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
value: "team-memory" as const,
|
|
||||||
label: "Team Memory",
|
|
||||||
icon: <Brain className="h-4 w-4" />,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
value: "public-links" as const,
|
value: "public-links" as const,
|
||||||
label: t("nav_public_links"),
|
label: t("nav_public_links"),
|
||||||
|
|
|
||||||
|
|
@ -1,6 +0,0 @@
|
||||||
import { TeamMemoryManager } from "@/components/settings/team-memory-manager";
|
|
||||||
|
|
||||||
export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) {
|
|
||||||
const { search_space_id } = await params;
|
|
||||||
return <TeamMemoryManager searchSpaceId={Number(search_space_id)} />;
|
|
||||||
}
|
|
||||||
|
|
@ -1,293 +0,0 @@
|
||||||
"use client";
|
|
||||||
|
|
||||||
import { useAtomValue } from "jotai";
|
|
||||||
import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pencil } from "lucide-react";
|
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
|
||||||
import { toast } from "sonner";
|
|
||||||
import { z } from "zod";
|
|
||||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
|
||||||
import { PlateEditor } from "@/components/editor/plate-editor";
|
|
||||||
import { Alert, AlertDescription } from "@/components/ui/alert";
|
|
||||||
import { Button } from "@/components/ui/button";
|
|
||||||
import {
|
|
||||||
DropdownMenu,
|
|
||||||
DropdownMenuContent,
|
|
||||||
DropdownMenuItem,
|
|
||||||
DropdownMenuTrigger,
|
|
||||||
} from "@/components/ui/dropdown-menu";
|
|
||||||
import { Spinner } from "@/components/ui/spinner";
|
|
||||||
|
|
||||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
|
||||||
|
|
||||||
const MEMORY_HARD_LIMIT = 25_000;
|
|
||||||
|
|
||||||
const MemoryReadSchema = z.object({
|
|
||||||
memory_md: z.string(),
|
|
||||||
});
|
|
||||||
|
|
||||||
export function MemoryContent() {
|
|
||||||
const activeSearchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
|
|
||||||
const [memory, setMemory] = useState("");
|
|
||||||
const [loading, setLoading] = useState(true);
|
|
||||||
const [saving, setSaving] = useState(false);
|
|
||||||
const [editQuery, setEditQuery] = useState("");
|
|
||||||
const [editing, setEditing] = useState(false);
|
|
||||||
const [showInput, setShowInput] = useState(false);
|
|
||||||
const textareaRef = useRef<HTMLInputElement>(null);
|
|
||||||
const inputContainerRef = useRef<HTMLDivElement>(null);
|
|
||||||
|
|
||||||
const fetchMemory = useCallback(async () => {
|
|
||||||
try {
|
|
||||||
setLoading(true);
|
|
||||||
const data = await baseApiService.get("/api/v1/users/me/memory", MemoryReadSchema);
|
|
||||||
setMemory(data.memory_md);
|
|
||||||
} catch {
|
|
||||||
toast.error("Failed to load memory");
|
|
||||||
} finally {
|
|
||||||
setLoading(false);
|
|
||||||
}
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
fetchMemory();
|
|
||||||
}, [fetchMemory]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (!showInput) return;
|
|
||||||
|
|
||||||
const handlePointerDownOutside = (event: MouseEvent | TouchEvent) => {
|
|
||||||
const target = event.target;
|
|
||||||
if (!(target instanceof Node)) return;
|
|
||||||
if (inputContainerRef.current?.contains(target)) return;
|
|
||||||
|
|
||||||
setShowInput(false);
|
|
||||||
};
|
|
||||||
|
|
||||||
document.addEventListener("mousedown", handlePointerDownOutside);
|
|
||||||
document.addEventListener("touchstart", handlePointerDownOutside, { passive: true });
|
|
||||||
|
|
||||||
return () => {
|
|
||||||
document.removeEventListener("mousedown", handlePointerDownOutside);
|
|
||||||
document.removeEventListener("touchstart", handlePointerDownOutside);
|
|
||||||
};
|
|
||||||
}, [showInput]);
|
|
||||||
|
|
||||||
const handleClear = async () => {
|
|
||||||
try {
|
|
||||||
setSaving(true);
|
|
||||||
const data = await baseApiService.put("/api/v1/users/me/memory", MemoryReadSchema, {
|
|
||||||
body: { memory_md: "" },
|
|
||||||
});
|
|
||||||
setMemory(data.memory_md);
|
|
||||||
toast.success("Memory cleared");
|
|
||||||
} catch {
|
|
||||||
toast.error("Failed to clear memory");
|
|
||||||
} finally {
|
|
||||||
setSaving(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleEdit = async () => {
|
|
||||||
const query = editQuery.trim();
|
|
||||||
if (!query) return;
|
|
||||||
|
|
||||||
try {
|
|
||||||
setEditing(true);
|
|
||||||
const data = await baseApiService.post("/api/v1/users/me/memory/edit", MemoryReadSchema, {
|
|
||||||
body: { query, search_space_id: Number(activeSearchSpaceId) },
|
|
||||||
});
|
|
||||||
setMemory(data.memory_md);
|
|
||||||
setEditQuery("");
|
|
||||||
setShowInput(false);
|
|
||||||
toast.success("Memory updated");
|
|
||||||
} catch {
|
|
||||||
toast.error("Failed to edit memory");
|
|
||||||
} finally {
|
|
||||||
setEditing(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const openInput = () => {
|
|
||||||
setShowInput(true);
|
|
||||||
requestAnimationFrame(() => textareaRef.current?.focus());
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleDownload = () => {
|
|
||||||
if (!memory) return;
|
|
||||||
try {
|
|
||||||
const blob = new Blob([memory], { type: "text/markdown;charset=utf-8" });
|
|
||||||
const url = URL.createObjectURL(blob);
|
|
||||||
const a = document.createElement("a");
|
|
||||||
a.href = url;
|
|
||||||
a.download = "personal-memory.md";
|
|
||||||
document.body.appendChild(a);
|
|
||||||
a.click();
|
|
||||||
document.body.removeChild(a);
|
|
||||||
URL.revokeObjectURL(url);
|
|
||||||
} catch {
|
|
||||||
toast.error("Failed to download memory");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleCopyMarkdown = async () => {
|
|
||||||
if (!memory) return;
|
|
||||||
try {
|
|
||||||
await navigator.clipboard.writeText(memory);
|
|
||||||
toast.success("Copied to clipboard");
|
|
||||||
} catch {
|
|
||||||
toast.error("Failed to copy memory");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
|
||||||
if (e.key === "Enter" && !e.shiftKey) {
|
|
||||||
e.preventDefault();
|
|
||||||
handleEdit();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const displayMemory = memory.replace(/\(\d{4}-\d{2}-\d{2}\)\s*\[(fact|pref|instr)\]\s*/g, "");
|
|
||||||
const charCount = memory.length;
|
|
||||||
|
|
||||||
const getCounterColor = () => {
|
|
||||||
if (charCount > MEMORY_HARD_LIMIT) return "text-red-500";
|
|
||||||
if (charCount > 15_000) return "text-orange-500";
|
|
||||||
if (charCount > 10_000) return "text-yellow-500";
|
|
||||||
return "text-muted-foreground";
|
|
||||||
};
|
|
||||||
|
|
||||||
if (loading) {
|
|
||||||
return (
|
|
||||||
<div className="flex items-center justify-center py-12">
|
|
||||||
<Spinner size="md" className="text-muted-foreground" />
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!memory) {
|
|
||||||
return (
|
|
||||||
<div className="flex flex-col items-center justify-center py-16 text-center">
|
|
||||||
<h3 className="text-base font-medium text-foreground">What does SurfSense remember?</h3>
|
|
||||||
<p className="mt-2 max-w-sm text-sm text-muted-foreground">
|
|
||||||
Nothing yet. SurfSense picks up on your preferences and context as you chat.
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="space-y-4">
|
|
||||||
<Alert>
|
|
||||||
<Info />
|
|
||||||
<AlertDescription>
|
|
||||||
<p>
|
|
||||||
SurfSense uses this personal memory to personalize your responses across all
|
|
||||||
conversations.
|
|
||||||
</p>
|
|
||||||
</AlertDescription>
|
|
||||||
</Alert>
|
|
||||||
|
|
||||||
<div className="relative h-[380px] rounded-lg border bg-background">
|
|
||||||
<div className="h-full overflow-y-auto scrollbar-thin">
|
|
||||||
<PlateEditor
|
|
||||||
markdown={displayMemory}
|
|
||||||
readOnly
|
|
||||||
preset="readonly"
|
|
||||||
variant="default"
|
|
||||||
editorVariant="none"
|
|
||||||
className="px-5 py-4 text-sm min-h-full"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{showInput ? (
|
|
||||||
<div className="absolute bottom-3 inset-x-3 z-10">
|
|
||||||
<div
|
|
||||||
ref={inputContainerRef}
|
|
||||||
className="relative flex h-[54px] items-center gap-2 rounded-[9999px] border bg-muted/60 backdrop-blur-sm pl-4 pr-1 shadow-sm"
|
|
||||||
>
|
|
||||||
<input
|
|
||||||
ref={textareaRef}
|
|
||||||
type="text"
|
|
||||||
value={editQuery}
|
|
||||||
onChange={(e) => setEditQuery(e.target.value)}
|
|
||||||
onKeyDown={handleKeyDown}
|
|
||||||
placeholder="Tell SurfSense what to remember or forget"
|
|
||||||
disabled={editing}
|
|
||||||
className="flex-1 bg-transparent text-sm outline-none placeholder:text-muted-foreground/70"
|
|
||||||
/>
|
|
||||||
<Button
|
|
||||||
type="button"
|
|
||||||
size="icon"
|
|
||||||
variant="ghost"
|
|
||||||
onClick={handleEdit}
|
|
||||||
disabled={editing || !editQuery.trim()}
|
|
||||||
className={`h-11 w-11 shrink-0 rounded-full ${
|
|
||||||
editing
|
|
||||||
? ""
|
|
||||||
: "bg-muted-foreground/15 hover:bg-accent hover:text-accent-foreground"
|
|
||||||
}`}
|
|
||||||
>
|
|
||||||
{editing ? (
|
|
||||||
<Spinner size="sm" />
|
|
||||||
) : (
|
|
||||||
<ArrowUp className="!h-5 !w-5 text-foreground" strokeWidth={2.25} />
|
|
||||||
)}
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<Button
|
|
||||||
type="button"
|
|
||||||
size="icon"
|
|
||||||
variant="secondary"
|
|
||||||
onClick={openInput}
|
|
||||||
className="absolute bottom-3 right-3 z-10 h-[54px] w-[54px] rounded-full border bg-muted/60 backdrop-blur-sm shadow-sm"
|
|
||||||
>
|
|
||||||
<Pencil className="!h-5 !w-5" />
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="flex items-center justify-between gap-2">
|
|
||||||
<span className={`text-xs shrink-0 ${getCounterColor()}`}>
|
|
||||||
{charCount.toLocaleString()} / {MEMORY_HARD_LIMIT.toLocaleString()}
|
|
||||||
<span className="hidden sm:inline"> characters</span>
|
|
||||||
<span className="sm:hidden"> chars</span>
|
|
||||||
{charCount > 15_000 && charCount <= MEMORY_HARD_LIMIT && " - Approaching limit"}
|
|
||||||
{charCount > MEMORY_HARD_LIMIT && " - Exceeds limit"}
|
|
||||||
</span>
|
|
||||||
<div className="flex items-center gap-1.5 sm:gap-2">
|
|
||||||
<Button
|
|
||||||
type="button"
|
|
||||||
variant="destructive"
|
|
||||||
size="sm"
|
|
||||||
className="text-xs sm:text-sm"
|
|
||||||
onClick={handleClear}
|
|
||||||
disabled={saving || editing || !memory}
|
|
||||||
>
|
|
||||||
<span className="hidden sm:inline">Reset Memory</span>
|
|
||||||
<span className="sm:hidden">Reset</span>
|
|
||||||
</Button>
|
|
||||||
<DropdownMenu>
|
|
||||||
<DropdownMenuTrigger asChild>
|
|
||||||
<Button type="button" variant="secondary" size="sm" disabled={!memory}>
|
|
||||||
Export
|
|
||||||
<ChevronDown className="h-3 w-3 opacity-60" />
|
|
||||||
</Button>
|
|
||||||
</DropdownMenuTrigger>
|
|
||||||
<DropdownMenuContent align="end">
|
|
||||||
<DropdownMenuItem onClick={handleCopyMarkdown}>
|
|
||||||
<ClipboardCopy className="h-4 w-4 mr-2" />
|
|
||||||
Copy as Markdown
|
|
||||||
</DropdownMenuItem>
|
|
||||||
<DropdownMenuItem onClick={handleDownload}>
|
|
||||||
<Download className="h-4 w-4 mr-2" />
|
|
||||||
Download as Markdown
|
|
||||||
</DropdownMenuItem>
|
|
||||||
</DropdownMenuContent>
|
|
||||||
</DropdownMenu>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
Brain,
|
|
||||||
CircleUser,
|
CircleUser,
|
||||||
Keyboard,
|
Keyboard,
|
||||||
KeyRound,
|
KeyRound,
|
||||||
|
|
@ -26,7 +25,6 @@ export type UserSettingsTab =
|
||||||
| "api-key"
|
| "api-key"
|
||||||
| "prompts"
|
| "prompts"
|
||||||
| "community-prompts"
|
| "community-prompts"
|
||||||
| "memory"
|
|
||||||
| "agent-permissions"
|
| "agent-permissions"
|
||||||
| "agent-status"
|
| "agent-status"
|
||||||
| "purchases"
|
| "purchases"
|
||||||
|
|
@ -75,11 +73,6 @@ export function UserSettingsLayoutShell({ searchSpaceId, children }: UserSetting
|
||||||
label: "Community Prompts",
|
label: "Community Prompts",
|
||||||
icon: <Library className="h-4 w-4" />,
|
icon: <Library className="h-4 w-4" />,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
value: "memory" as const,
|
|
||||||
label: "Memory",
|
|
||||||
icon: <Brain className="h-4 w-4" />,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
value: "agent-permissions" as const,
|
value: "agent-permissions" as const,
|
||||||
label: "Agent Permissions",
|
label: "Agent Permissions",
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue