mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-26 17:26:23 +02:00
chore: ran linting
This commit is contained in:
parent
b8091114b5
commit
f38ea77940
14 changed files with 137 additions and 111 deletions
|
|
@ -228,7 +228,13 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
llm=deps.get("llm"),
|
||||
)
|
||||
),
|
||||
requires=["user_id", "search_space_id", "db_session", "thread_visibility", "llm"],
|
||||
requires=[
|
||||
"user_id",
|
||||
"search_space_id",
|
||||
"db_session",
|
||||
"thread_visibility",
|
||||
"llm",
|
||||
],
|
||||
),
|
||||
# =========================================================================
|
||||
# LINEAR TOOLS - create, update, delete issues
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ _SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
|||
# Pinned-section helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_pinned_headings(memory: str) -> set[str]:
|
||||
"""Return the set of ``## …`` headings that contain ``(pinned)``."""
|
||||
return set(_PINNED_RE.findall(memory))
|
||||
|
|
@ -59,9 +60,7 @@ def _extract_section_map(memory: str) -> dict[str, str]:
|
|||
return sections
|
||||
|
||||
|
||||
def _validate_pinned_preserved(
|
||||
old_memory: str | None, new_memory: str
|
||||
) -> str | None:
|
||||
def _validate_pinned_preserved(old_memory: str | None, new_memory: str) -> str | None:
|
||||
"""Return an error message if pinned headings from *old_memory* are missing
|
||||
in *new_memory*, else ``None``."""
|
||||
if not old_memory:
|
||||
|
|
@ -81,9 +80,7 @@ def _validate_pinned_preserved(
|
|||
return None
|
||||
|
||||
|
||||
def _restore_missing_pinned(
|
||||
old_memory: str, consolidated: str
|
||||
) -> str:
|
||||
def _restore_missing_pinned(old_memory: str, consolidated: str) -> str:
|
||||
"""Prepend any pinned sections from *old_memory* that are absent in
|
||||
*consolidated*."""
|
||||
old_pinned = _extract_pinned_headings(old_memory)
|
||||
|
|
@ -109,14 +106,13 @@ def _restore_missing_pinned(
|
|||
# Diff validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_headings(memory: str) -> set[str]:
|
||||
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
|
||||
return set(_SECTION_HEADING_RE.findall(memory))
|
||||
|
||||
|
||||
def _validate_diff(
|
||||
old_memory: str | None, new_memory: str
|
||||
) -> list[str]:
|
||||
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||
"""Return a list of warning strings about suspicious changes."""
|
||||
if not old_memory:
|
||||
return []
|
||||
|
|
@ -146,6 +142,7 @@ def _validate_diff(
|
|||
# Size validation & soft warning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _validate_memory_size(content: str) -> dict[str, Any] | None:
|
||||
"""Return an error/warning dict if *content* is too large, else None."""
|
||||
length = len(content)
|
||||
|
|
@ -199,17 +196,13 @@ RULES:
|
|||
</memory_document>"""
|
||||
|
||||
|
||||
async def _auto_consolidate(
|
||||
content: str, llm: Any
|
||||
) -> str | None:
|
||||
async def _auto_consolidate(content: str, llm: Any) -> str | None:
|
||||
"""Use a focused LLM call to consolidate *content* under the soft limit.
|
||||
|
||||
Returns the consolidated string, or ``None`` if consolidation fails.
|
||||
"""
|
||||
try:
|
||||
prompt = _CONSOLIDATION_PROMPT.format(
|
||||
target=MEMORY_SOFT_LIMIT, content=content
|
||||
)
|
||||
prompt = _CONSOLIDATION_PROMPT.format(target=MEMORY_SOFT_LIMIT, content=content)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal"]},
|
||||
|
|
@ -229,6 +222,7 @@ async def _auto_consolidate(
|
|||
# Shared save-and-respond logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _save_memory(
|
||||
*,
|
||||
updated_memory: str,
|
||||
|
|
@ -295,12 +289,13 @@ async def _save_memory(
|
|||
return {"status": "error", "message": f"Failed to update {label}: {e}"}
|
||||
|
||||
# --- build response ---
|
||||
resp: dict[str, Any] = {"status": "saved", "message": f"{label.capitalize()} updated."}
|
||||
resp: dict[str, Any] = {
|
||||
"status": "saved",
|
||||
"message": f"{label.capitalize()} updated.",
|
||||
}
|
||||
|
||||
if content is not updated_memory:
|
||||
resp["notice"] = (
|
||||
"Memory was automatically consolidated to fit within limits."
|
||||
)
|
||||
resp["notice"] = "Memory was automatically consolidated to fit within limits."
|
||||
|
||||
diff_warnings = _validate_diff(old_memory, content)
|
||||
if diff_warnings:
|
||||
|
|
@ -317,6 +312,7 @@ async def _save_memory(
|
|||
# Tool factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_update_memory_tool(
|
||||
user_id: str | UUID,
|
||||
db_session: AsyncSession,
|
||||
|
|
@ -338,9 +334,7 @@ def create_update_memory_tool(
|
|||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
"""
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(User).where(User.id == uid)
|
||||
)
|
||||
result = await db_session.execute(select(User).where(User.id == uid))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return {"status": "error", "message": "User not found."}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue