mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 22:32:39 +02:00
Merge pull request #1335 from AnishSarkar22/fix/memory-extraction
refactor(memory): streamline memory extraction
This commit is contained in:
commit
ce6d9233bc
5 changed files with 110 additions and 26 deletions
|
|
@ -16,6 +16,7 @@ from sqlalchemy import select
|
||||||
|
|
||||||
from app.agents.new_chat.tools.update_memory import _save_memory
|
from app.agents.new_chat.tools.update_memory import _save_memory
|
||||||
from app.db import SearchSpace, User, shielded_async_session
|
from app.db import SearchSpace, User, shielded_async_session
|
||||||
|
from app.utils.content_utils import extract_text_content
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -144,11 +145,7 @@ async def extract_and_save_memory(
|
||||||
[HumanMessage(content=prompt)],
|
[HumanMessage(content=prompt)],
|
||||||
config={"tags": ["surfsense:internal", "memory-extraction"]},
|
config={"tags": ["surfsense:internal", "memory-extraction"]},
|
||||||
)
|
)
|
||||||
text = (
|
text = extract_text_content(response.content).strip()
|
||||||
response.content
|
|
||||||
if isinstance(response.content, str)
|
|
||||||
else str(response.content)
|
|
||||||
).strip()
|
|
||||||
|
|
||||||
if text == "NO_UPDATE" or not text:
|
if text == "NO_UPDATE" or not text:
|
||||||
logger.debug("Memory extraction: no update needed (user %s)", uid)
|
logger.debug("Memory extraction: no update needed (user %s)", uid)
|
||||||
|
|
@ -207,11 +204,7 @@ async def extract_and_save_team_memory(
|
||||||
[HumanMessage(content=prompt)],
|
[HumanMessage(content=prompt)],
|
||||||
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
|
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
|
||||||
)
|
)
|
||||||
text = (
|
text = extract_text_content(response.content).strip()
|
||||||
response.content
|
|
||||||
if isinstance(response.content, str)
|
|
||||||
else str(response.content)
|
|
||||||
).strip()
|
|
||||||
|
|
||||||
if text == "NO_UPDATE" or not text:
|
if text == "NO_UPDATE" or not text:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import SearchSpace, User, async_session_maker
|
from app.db import SearchSpace, User, async_session_maker
|
||||||
|
from app.utils.content_utils import extract_text_content
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -188,12 +189,11 @@ async def _forced_rewrite(content: str, llm: Any) -> str | None:
|
||||||
[HumanMessage(content=prompt)],
|
[HumanMessage(content=prompt)],
|
||||||
config={"tags": ["surfsense:internal"]},
|
config={"tags": ["surfsense:internal"]},
|
||||||
)
|
)
|
||||||
text = (
|
text = extract_text_content(response.content).strip()
|
||||||
response.content
|
if not text:
|
||||||
if isinstance(response.content, str)
|
logger.warning("Forced rewrite returned empty text; aborting rewrite")
|
||||||
else str(response.content)
|
return None
|
||||||
)
|
return text
|
||||||
return text.strip()
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Forced rewrite LLM call failed")
|
logger.exception("Forced rewrite LLM call failed")
|
||||||
return None
|
return None
|
||||||
|
|
@ -235,6 +235,16 @@ async def _save_memory(
|
||||||
label : str
|
label : str
|
||||||
Human label for log messages (e.g. "user memory", "team memory").
|
Human label for log messages (e.g. "user memory", "team memory").
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(updated_memory, str):
|
||||||
|
logger.warning(
|
||||||
|
"Refusing non-string memory payload (type=%s)",
|
||||||
|
type(updated_memory).__name__,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Internal error: memory payload must be a string.",
|
||||||
|
}
|
||||||
|
|
||||||
content = updated_memory
|
content = updated_memory
|
||||||
|
|
||||||
# --- forced rewrite if over the hard limit ---
|
# --- forced rewrite if over the hard limit ---
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from app.agents.new_chat.llm_config import (
|
||||||
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
|
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
|
||||||
from app.db import User, get_async_session
|
from app.db import User, get_async_session
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
|
from app.utils.content_utils import extract_text_content
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -123,11 +124,7 @@ async def edit_user_memory(
|
||||||
[HumanMessage(content=prompt)],
|
[HumanMessage(content=prompt)],
|
||||||
config={"tags": ["surfsense:internal", "memory-edit"]},
|
config={"tags": ["surfsense:internal", "memory-edit"]},
|
||||||
)
|
)
|
||||||
updated = (
|
updated = extract_text_content(response.content).strip()
|
||||||
response.content
|
|
||||||
if isinstance(response.content, str)
|
|
||||||
else str(response.content)
|
|
||||||
).strip()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Memory edit LLM call failed: %s", e)
|
logger.exception("Memory edit LLM call failed: %s", e)
|
||||||
raise HTTPException(status_code=500, detail="Memory edit failed.") from e
|
raise HTTPException(status_code=500, detail="Memory edit failed.") from e
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ from app.schemas import (
|
||||||
SearchSpaceWithStats,
|
SearchSpaceWithStats,
|
||||||
)
|
)
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
|
from app.utils.content_utils import extract_text_content
|
||||||
from app.utils.rbac import check_permission, check_search_space_access
|
from app.utils.rbac import check_permission, check_search_space_access
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -356,11 +357,7 @@ async def edit_team_memory(
|
||||||
[HumanMessage(content=prompt)],
|
[HumanMessage(content=prompt)],
|
||||||
config={"tags": ["surfsense:internal", "memory-edit"]},
|
config={"tags": ["surfsense:internal", "memory-edit"]},
|
||||||
)
|
)
|
||||||
updated = (
|
updated = extract_text_content(response.content).strip()
|
||||||
response.content
|
|
||||||
if isinstance(response.content, str)
|
|
||||||
else str(response.content)
|
|
||||||
).strip()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Team memory edit LLM call failed: %s", e)
|
logger.exception("Team memory edit LLM call failed: %s", e)
|
||||||
raise HTTPException(status_code=500, detail="Team memory edit failed.") from e
|
raise HTTPException(status_code=500, detail="Team memory edit failed.") from e
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,87 @@
|
||||||
|
"""Unit tests for extracting text from LLM memory responses."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.update_memory import _save_memory
|
||||||
|
from app.utils.content_utils import extract_text_content
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _Recorder:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.applied_content: str | None = None
|
||||||
|
self.commit_calls = 0
|
||||||
|
self.rollback_calls = 0
|
||||||
|
|
||||||
|
def apply(self, content: str) -> None:
|
||||||
|
self.applied_content = content
|
||||||
|
|
||||||
|
async def commit(self) -> None:
|
||||||
|
self.commit_calls += 1
|
||||||
|
|
||||||
|
async def rollback(self) -> None:
|
||||||
|
self.rollback_calls += 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_text_content_keeps_no_update_bare_string_from_content_blocks() -> None:
|
||||||
|
content = [
|
||||||
|
{"type": "thinking", "thinking": "No"},
|
||||||
|
{"type": "thinking", "thinking": " memorizable info."},
|
||||||
|
"NO_UPDATE",
|
||||||
|
]
|
||||||
|
|
||||||
|
assert extract_text_content(content).strip() == "NO_UPDATE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_text_content_ignores_thinking_blocks_and_keeps_markdown_text() -> None:
|
||||||
|
markdown = (
|
||||||
|
"## Work Context\n"
|
||||||
|
"- (2026-05-02) [fact] Anish is hardening SurfSense memory extraction.\n"
|
||||||
|
)
|
||||||
|
content = [
|
||||||
|
{"type": "thinking", "thinking": "This is durable context."},
|
||||||
|
{"type": "text", "text": markdown},
|
||||||
|
]
|
||||||
|
|
||||||
|
assert extract_text_content(content).strip() == markdown.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_text_content_returns_empty_when_only_thinking_blocks_are_present() -> None:
|
||||||
|
content = [
|
||||||
|
{"type": "thinking", "thinking": "No durable fact."},
|
||||||
|
{"type": "thinking", "thinking": "Return no update."},
|
||||||
|
]
|
||||||
|
|
||||||
|
assert extract_text_content(content) == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_text_content_preserves_plain_string_responses() -> None:
|
||||||
|
markdown = (
|
||||||
|
"## Preferences\n"
|
||||||
|
"- (2026-05-02) [pref] Anish prefers no regex for memory validation.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert extract_text_content(markdown) == markdown
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_rejects_non_string_payload_before_commit() -> None:
|
||||||
|
recorder = _Recorder()
|
||||||
|
|
||||||
|
result = await _save_memory(
|
||||||
|
updated_memory=["NO_UPDATE"], # type: ignore[arg-type]
|
||||||
|
old_memory=None,
|
||||||
|
llm=None,
|
||||||
|
apply_fn=recorder.apply,
|
||||||
|
commit_fn=recorder.commit,
|
||||||
|
rollback_fn=recorder.rollback,
|
||||||
|
label="memory",
|
||||||
|
scope="user",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["status"] == "error"
|
||||||
|
assert "must be a string" in result["message"]
|
||||||
|
assert recorder.applied_content is None
|
||||||
|
assert recorder.commit_calls == 0
|
||||||
|
assert recorder.rollback_calls == 0
|
||||||
Loading…
Add table
Add a link
Reference in a new issue