chore: ran linting

This commit is contained in:
Anish Sarkar 2026-04-09 18:10:34 +05:30
parent b8091114b5
commit f38ea77940
14 changed files with 137 additions and 111 deletions

View file

@ -106,9 +106,7 @@ async def _call_extraction_llm(
config={"tags": ["surfsense:internal", "memory-extraction"]},
)
text = (
response.content
if isinstance(response.content, str)
else str(response.content)
response.content if isinstance(response.content, str) else str(response.content)
).strip()
if text == "NO_UPDATE" or not text:
@ -155,9 +153,7 @@ async def _extract_user_memory(
uid = UUID(user_id) if isinstance(user_id, str) else user_id
async with shielded_async_session() as session:
result = await session.execute(
select(User).where(User.id == uid)
)
result = await session.execute(select(User).where(User.id == uid))
user = result.scalars().first()
if not user:
return

View file

@ -91,9 +91,7 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
return {"messages": new_messages}
async def _load_user_memory(
self, session: AsyncSession
) -> tuple[str | None, bool]:
async def _load_user_memory(self, session: AsyncSession) -> tuple[str | None, bool]:
"""Return (memory_content, is_persisted).
When the user has no saved memory but has a display name, a seed
@ -102,9 +100,7 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""
try:
result = await session.execute(
select(User.memory_md, User.display_name).where(
User.id == self.user_id
)
select(User.memory_md, User.display_name).where(User.id == self.user_id)
)
row = result.one_or_none()
if row is None:

View file

@ -228,7 +228,13 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
llm=deps.get("llm"),
)
),
requires=["user_id", "search_space_id", "db_session", "thread_visibility", "llm"],
requires=[
"user_id",
"search_space_id",
"db_session",
"thread_visibility",
"llm",
],
),
# =========================================================================
# LINEAR TOOLS - create, update, delete issues

View file

@ -42,6 +42,7 @@ _SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
# Pinned-section helpers
# ---------------------------------------------------------------------------
def _extract_pinned_headings(memory: str) -> set[str]:
"""Return the set of ``## …`` headings that contain ``(pinned)``."""
return set(_PINNED_RE.findall(memory))
@ -59,9 +60,7 @@ def _extract_section_map(memory: str) -> dict[str, str]:
return sections
def _validate_pinned_preserved(
old_memory: str | None, new_memory: str
) -> str | None:
def _validate_pinned_preserved(old_memory: str | None, new_memory: str) -> str | None:
"""Return an error message if pinned headings from *old_memory* are missing
in *new_memory*, else ``None``."""
if not old_memory:
@ -81,9 +80,7 @@ def _validate_pinned_preserved(
return None
def _restore_missing_pinned(
old_memory: str, consolidated: str
) -> str:
def _restore_missing_pinned(old_memory: str, consolidated: str) -> str:
"""Prepend any pinned sections from *old_memory* that are absent in
*consolidated*."""
old_pinned = _extract_pinned_headings(old_memory)
@ -109,14 +106,13 @@ def _restore_missing_pinned(
# Diff validation
# ---------------------------------------------------------------------------
def _extract_headings(memory: str) -> set[str]:
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
return set(_SECTION_HEADING_RE.findall(memory))
def _validate_diff(
old_memory: str | None, new_memory: str
) -> list[str]:
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
"""Return a list of warning strings about suspicious changes."""
if not old_memory:
return []
@ -146,6 +142,7 @@ def _validate_diff(
# Size validation & soft warning
# ---------------------------------------------------------------------------
def _validate_memory_size(content: str) -> dict[str, Any] | None:
"""Return an error/warning dict if *content* is too large, else None."""
length = len(content)
@ -199,17 +196,13 @@ RULES:
</memory_document>"""
async def _auto_consolidate(
content: str, llm: Any
) -> str | None:
async def _auto_consolidate(content: str, llm: Any) -> str | None:
"""Use a focused LLM call to consolidate *content* under the soft limit.
Returns the consolidated string, or ``None`` if consolidation fails.
"""
try:
prompt = _CONSOLIDATION_PROMPT.format(
target=MEMORY_SOFT_LIMIT, content=content
)
prompt = _CONSOLIDATION_PROMPT.format(target=MEMORY_SOFT_LIMIT, content=content)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]},
@ -229,6 +222,7 @@ async def _auto_consolidate(
# Shared save-and-respond logic
# ---------------------------------------------------------------------------
async def _save_memory(
*,
updated_memory: str,
@ -295,12 +289,13 @@ async def _save_memory(
return {"status": "error", "message": f"Failed to update {label}: {e}"}
# --- build response ---
resp: dict[str, Any] = {"status": "saved", "message": f"{label.capitalize()} updated."}
resp: dict[str, Any] = {
"status": "saved",
"message": f"{label.capitalize()} updated.",
}
if content is not updated_memory:
resp["notice"] = (
"Memory was automatically consolidated to fit within limits."
)
resp["notice"] = "Memory was automatically consolidated to fit within limits."
diff_warnings = _validate_diff(old_memory, content)
if diff_warnings:
@ -317,6 +312,7 @@ async def _save_memory(
# Tool factories
# ---------------------------------------------------------------------------
def create_update_memory_tool(
user_id: str | UUID,
db_session: AsyncSession,
@ -338,9 +334,7 @@ def create_update_memory_tool(
updated_memory: The FULL updated markdown document (not a diff).
"""
try:
result = await db_session.execute(
select(User).where(User.id == uid)
)
result = await db_session.execute(select(User).where(User.id == uid))
user = result.scalars().first()
if not user:
return {"status": "error", "message": "User not found."}