refactor(memory): streamline memory extraction by utilizing extract_text_content utility

This commit is contained in:
Anish Sarkar 2026-05-02 16:10:30 +05:30
parent 451a98936e
commit 9975e085aa
5 changed files with 106 additions and 25 deletions

View file

@ -16,6 +16,7 @@ from sqlalchemy import select
from app.agents.new_chat.tools.update_memory import _save_memory
from app.db import SearchSpace, User, shielded_async_session
from app.utils.content_utils import extract_text_content
logger = logging.getLogger(__name__)
@ -144,11 +145,7 @@ async def extract_and_save_memory(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "memory-extraction"]},
)
text = (
response.content
if isinstance(response.content, str)
else str(response.content)
).strip()
text = extract_text_content(response.content).strip()
if text == "NO_UPDATE" or not text:
logger.debug("Memory extraction: no update needed (user %s)", uid)
@ -207,11 +204,7 @@ async def extract_and_save_team_memory(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
)
text = (
response.content
if isinstance(response.content, str)
else str(response.content)
).strip()
text = extract_text_content(response.content).strip()
if text == "NO_UPDATE" or not text:
logger.debug(

View file

@ -27,6 +27,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User
from app.utils.content_utils import extract_text_content
logger = logging.getLogger(__name__)
@ -188,11 +189,7 @@ async def _forced_rewrite(content: str, llm: Any) -> str | None:
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]},
)
text = (
response.content
if isinstance(response.content, str)
else str(response.content)
)
text = extract_text_content(response.content)
return text.strip()
except Exception:
logger.exception("Forced rewrite LLM call failed")
@ -235,6 +232,16 @@ async def _save_memory(
label : str
Human label for log messages (e.g. "user memory", "team memory").
"""
if not isinstance(updated_memory, str):
logger.warning(
"Refusing non-string memory payload (type=%s)",
type(updated_memory).__name__,
)
return {
"status": "error",
"message": "Internal error: memory payload must be a string.",
}
content = updated_memory
# --- forced rewrite if over the hard limit ---

View file

@ -16,6 +16,7 @@ from app.agents.new_chat.llm_config import (
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
from app.db import User, get_async_session
from app.users import current_active_user
from app.utils.content_utils import extract_text_content
logger = logging.getLogger(__name__)
@ -123,11 +124,7 @@ async def edit_user_memory(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "memory-edit"]},
)
updated = (
response.content
if isinstance(response.content, str)
else str(response.content)
).strip()
updated = extract_text_content(response.content).strip()
except Exception as e:
logger.exception("Memory edit LLM call failed: %s", e)
raise HTTPException(status_code=500, detail="Memory edit failed.") from e

View file

@ -35,6 +35,7 @@ from app.schemas import (
SearchSpaceWithStats,
)
from app.users import current_active_user
from app.utils.content_utils import extract_text_content
from app.utils.rbac import check_permission, check_search_space_access
logger = logging.getLogger(__name__)
@ -356,11 +357,7 @@ async def edit_team_memory(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "memory-edit"]},
)
updated = (
response.content
if isinstance(response.content, str)
else str(response.content)
).strip()
updated = extract_text_content(response.content).strip()
except Exception as e:
logger.exception("Team memory edit LLM call failed: %s", e)
raise HTTPException(status_code=500, detail="Team memory edit failed.") from e

View file

@ -0,0 +1,87 @@
"""Unit tests for extracting text from LLM memory responses."""
import pytest
from app.agents.new_chat.tools.update_memory import _save_memory
from app.utils.content_utils import extract_text_content
# Module-level pytest marker: tags every test collected from this file as `unit`.
pytestmark = pytest.mark.unit
class _Recorder:
    """Test double that records how its apply/commit/rollback hooks are used.

    Tests read ``applied_content``, ``commit_calls`` and ``rollback_calls``
    after exercising the code under test to see which hooks fired.
    """

    def __init__(self) -> None:
        # Fresh recorder: nothing applied, no commit or rollback seen yet.
        self.applied_content: str | None = None
        self.commit_calls = self.rollback_calls = 0

    def apply(self, content: str) -> None:
        """Store the content most recently passed to the apply hook."""
        self.applied_content = content

    async def commit(self) -> None:
        """Count one commit call."""
        self.commit_calls = self.commit_calls + 1

    async def rollback(self) -> None:
        """Count one rollback call."""
        self.rollback_calls = self.rollback_calls + 1
def test_extract_text_content_keeps_no_update_bare_string_from_content_blocks() -> None:
    """A bare "NO_UPDATE" string after thinking blocks is the extracted text."""
    thinking_blocks = [
        {"type": "thinking", "thinking": "No"},
        {"type": "thinking", "thinking": " memorizable info."},
    ]
    # The sentinel arrives as a plain string item, not as a typed block.
    extracted = extract_text_content([*thinking_blocks, "NO_UPDATE"])
    assert extracted.strip() == "NO_UPDATE"
def test_extract_text_content_ignores_thinking_blocks_and_keeps_markdown_text() -> None:
    """Only the text block's markdown survives; the thinking block is dropped."""
    expected_markdown = (
        "## Work Context\n"
        "- (2026-05-02) [fact] Anish is hardening SurfSense memory extraction.\n"
    )
    thinking_block = {"type": "thinking", "thinking": "This is durable context."}
    text_block = {"type": "text", "text": expected_markdown}

    extracted = extract_text_content([thinking_block, text_block])

    # Compare modulo surrounding whitespace, as the callers strip the result.
    assert extracted.strip() == expected_markdown.strip()
def test_extract_text_content_returns_empty_when_only_thinking_blocks_are_present() -> None:
    """Content made solely of thinking blocks yields the empty string."""
    thoughts = ("No durable fact.", "Return no update.")
    only_thinking = [{"type": "thinking", "thinking": thought} for thought in thoughts]
    assert extract_text_content(only_thinking) == ""
def test_extract_text_content_preserves_plain_string_responses() -> None:
    """A plain string response comes back unchanged, trailing newline included."""
    plain_response = (
        "## Preferences\n"
        "- (2026-05-02) [pref] Anish prefers no regex for memory validation.\n"
    )
    extracted = extract_text_content(plain_response)
    assert extracted == plain_response
@pytest.mark.asyncio
async def test_save_memory_rejects_non_string_payload_before_commit() -> None:
    """A list payload yields an error result and none of the hooks are called."""
    spy = _Recorder()

    outcome = await _save_memory(
        updated_memory=["NO_UPDATE"],  # type: ignore[arg-type]
        old_memory=None,
        llm=None,
        apply_fn=spy.apply,
        commit_fn=spy.commit,
        rollback_fn=spy.rollback,
        label="memory",
        scope="user",
    )

    # The call reports an error that names the string requirement...
    assert outcome["status"] == "error"
    assert "must be a string" in outcome["message"]
    # ...and nothing was applied, committed, or rolled back.
    assert spy.applied_content is None
    assert spy.commit_calls == 0
    assert spy.rollback_calls == 0