From 939bfb2c1821e86d21260b6680e0317967b8aa03 Mon Sep 17 00:00:00 2001
From: CREDO23
Date: Wed, 24 Jun 2026 22:07:54 +0200
Subject: [PATCH] references: discriminated per-kind reference types

---
 .../chat/runtime/references/__init__.py       | 19 ++++--
 .../agents/chat/runtime/references/chats.py   | 11 +---
 .../chat/runtime/references/connectors.py     |  9 ++-
 .../chat/runtime/references/documents.py      |  9 ++-
 .../agents/chat/runtime/references/folders.py |  9 ++-
 .../agents/chat/runtime/references/models.py  | 58 ++++++++++++++++---
 .../runtime/references/reference_pointers.py  | 30 ++++++----
 .../references/test_reference_pointers.py     | 49 ++++------------
 8 files changed, 112 insertions(+), 82 deletions(-)

diff --git a/surfsense_backend/app/agents/chat/runtime/references/__init__.py b/surfsense_backend/app/agents/chat/runtime/references/__init__.py
index ad8b7cc0e..dfbb22da3 100644
--- a/surfsense_backend/app/agents/chat/runtime/references/__init__.py
+++ b/surfsense_backend/app/agents/chat/runtime/references/__init__.py
@@ -15,7 +15,14 @@ from .chats import resolve_chat_references
 from .connectors import resolve_connector_references
 from .documents import resolve_document_references
 from .folders import resolve_folder_references
-from .models import ReferenceKind, ResolvedReference
+from .models import (
+    ChatReference,
+    ConnectorReference,
+    DocumentReference,
+    FolderReference,
+    Reference,
+    ReferenceKind,
+)
 from .reference_pointers import render_reference_pointers
 
 
@@ -30,13 +37,13 @@ async def resolve_references(
     connector_ids: list[int] | None = None,
     connector_chips: list[MentionedDocumentInfo] | None = None,
     thread_ids: list[int] | None = None,
-) -> list[ResolvedReference]:
+) -> list[Reference]:
     """Resolve a turn's ``@``-references into one ordered pointer list.
 
     Order is documents, folders, connectors, chats. The path index is built once
     and shared by the document and folder resolvers.
     """
-    references: list[ResolvedReference] = []
+    references: list[Reference] = []
 
     if document_ids or folder_ids:
         index = await build_path_index(session, search_space_id)
@@ -76,8 +83,12 @@ async def resolve_references(
 
 
 __all__ = [
+    "ChatReference",
+    "ConnectorReference",
+    "DocumentReference",
+    "FolderReference",
+    "Reference",
     "ReferenceKind",
-    "ResolvedReference",
     "render_reference_pointers",
     "resolve_references",
 ]
diff --git a/surfsense_backend/app/agents/chat/runtime/references/chats.py b/surfsense_backend/app/agents/chat/runtime/references/chats.py
index d19e7d4a1..be9d1025c 100644
--- a/surfsense_backend/app/agents/chat/runtime/references/chats.py
+++ b/surfsense_backend/app/agents/chat/runtime/references/chats.py
@@ -14,7 +14,7 @@ from app.agents.chat.runtime.referenced_chat_context.resolver import (
     resolve_referenced_chats,
 )
 
-from .models import ReferenceKind, ResolvedReference
+from .models import ChatReference
 
 
 async def resolve_chat_references(
@@ -24,7 +24,7 @@ async def resolve_chat_references(
     requesting_user_id: str | None,
     current_chat_id: int,
     thread_ids: list[int],
-) -> list[ResolvedReference]:
+) -> list[ChatReference]:
     """Map ``@chat`` thread ids to access-checked pointers (titles only)."""
     if not thread_ids:
         return []
@@ -37,12 +37,7 @@ async def resolve_chat_references(
         mentioned_thread_ids=thread_ids,
     )
     return [
-        ResolvedReference(
-            kind=ReferenceKind.CHAT,
-            entity_id=chat.thread_id,
-            label=chat.title,
-        )
-        for chat in chats
+        ChatReference(entity_id=chat.thread_id, label=chat.title) for chat in chats
     ]
 
 
diff --git a/surfsense_backend/app/agents/chat/runtime/references/connectors.py b/surfsense_backend/app/agents/chat/runtime/references/connectors.py
index efa13dcf1..8d5f36133 100644
--- a/surfsense_backend/app/agents/chat/runtime/references/connectors.py
+++ b/surfsense_backend/app/agents/chat/runtime/references/connectors.py
@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.db import SearchSourceConnector
 from app.schemas.new_chat import MentionedDocumentInfo
 
-from .models import ReferenceKind, ResolvedReference
+from .models import ConnectorReference
 
 
 def connector_pointer_fields(
@@ -32,7 +32,7 @@ async def resolve_connector_references(
     search_space_id: int,
     connector_ids: list[int],
     chips: list[MentionedDocumentInfo] | None = None,
-) -> list[ResolvedReference]:
+) -> list[ConnectorReference]:
     """Map ``@connector`` ids to references; ids outside the space are dropped.
 
     The DB check only confirms the connector belongs to this search space;
@@ -57,7 +57,7 @@ async def resolve_connector_references(
         chip.id: chip for chip in (chips or []) if chip.kind == "connector"
     }
 
-    references: list[ResolvedReference] = []
+    references: list[ConnectorReference] = []
     for connector_id in dict.fromkeys(connector_ids):
         row = accessible.get(connector_id)
         if row is None:
@@ -71,8 +71,7 @@ async def resolve_connector_references(
             fallback_name=str(row.name or ""),
         )
         references.append(
-            ResolvedReference(
-                kind=ReferenceKind.CONNECTOR,
+            ConnectorReference(
                 entity_id=connector_id,
                 label=label,
                 provider=provider,
diff --git a/surfsense_backend/app/agents/chat/runtime/references/documents.py b/surfsense_backend/app/agents/chat/runtime/references/documents.py
index 03765b086..b2a3b1fe4 100644
--- a/surfsense_backend/app/agents/chat/runtime/references/documents.py
+++ b/surfsense_backend/app/agents/chat/runtime/references/documents.py
@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.agents.chat.runtime.path_resolver import PathIndex, doc_to_virtual_path
 from app.db import Document
 
-from .models import ReferenceKind, ResolvedReference
+from .models import DocumentReference
 
 
 async def resolve_document_references(
@@ -17,7 +17,7 @@ async def resolve_document_references(
     search_space_id: int,
     document_ids: list[int],
     index: PathIndex,
-) -> list[ResolvedReference]:
+) -> list[DocumentReference]:
     """Map document ids to references in input order; unknown ids are dropped.
 
     Best-effort and fail-closed: an id outside ``search_space_id`` (deleted or
@@ -34,15 +34,14 @@ async def resolve_document_references(
     )
     documents_by_id = {row.id: row for row in rows.scalars().all()}
 
-    references: list[ResolvedReference] = []
+    references: list[DocumentReference] = []
     for document_id in dict.fromkeys(document_ids):
         document = documents_by_id.get(document_id)
         if document is None:
             continue
         title = str(document.title or "untitled")
         references.append(
-            ResolvedReference(
-                kind=ReferenceKind.DOCUMENT,
+            DocumentReference(
                 entity_id=document.id,
                 label=title,
                 path=doc_to_virtual_path(
diff --git a/surfsense_backend/app/agents/chat/runtime/references/folders.py b/surfsense_backend/app/agents/chat/runtime/references/folders.py
index 475f52d56..df0ec457b 100644
--- a/surfsense_backend/app/agents/chat/runtime/references/folders.py
+++ b/surfsense_backend/app/agents/chat/runtime/references/folders.py
@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.agents.chat.runtime.path_resolver import DOCUMENTS_ROOT, PathIndex
 from app.db import Folder
 
-from .models import ReferenceKind, ResolvedReference
+from .models import FolderReference
 
 
 def folder_pointer_path(folder_id: int, folder_paths: dict[int, str]) -> str:
@@ -23,7 +23,7 @@ async def resolve_folder_references(
     search_space_id: int,
     folder_ids: list[int],
     index: PathIndex,
-) -> list[ResolvedReference]:
+) -> list[FolderReference]:
     """Map folder ids to references in input order; unknown ids are dropped."""
     if not folder_ids:
         return []
@@ -36,14 +36,13 @@ async def resolve_folder_references(
     )
     folders_by_id = {row.id: row for row in rows.scalars().all()}
 
-    references: list[ResolvedReference] = []
+    references: list[FolderReference] = []
     for folder_id in dict.fromkeys(folder_ids):
         folder = folders_by_id.get(folder_id)
         if folder is None:
             continue
         references.append(
-            ResolvedReference(
-                kind=ReferenceKind.FOLDER,
+            FolderReference(
                 entity_id=folder.id,
                 label=str(folder.name or "untitled"),
                 path=folder_pointer_path(folder.id, index.folder_paths),
diff --git a/surfsense_backend/app/agents/chat/runtime/references/models.py b/surfsense_backend/app/agents/chat/runtime/references/models.py
index c61198eca..8ae151772 100644
--- a/surfsense_backend/app/agents/chat/runtime/references/models.py
+++ b/surfsense_backend/app/agents/chat/runtime/references/models.py
@@ -1,9 +1,15 @@
-"""Data shapes for a resolved ``@``-reference."""
+"""Data shapes for resolved ``@``-references.
+
+One type per kind so each carries exactly the fields it needs: documents and
+folders have a path, connectors have a provider, chats have neither. ``kind`` is
+a class-level discriminator used by the renderer and scope builder.
+"""
 
 from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import Enum
+from typing import ClassVar
 
 
 class ReferenceKind(str, Enum):
@@ -16,14 +22,52 @@ class ReferenceKind(str, Enum):
 
 
 @dataclass(frozen=True)
-class ResolvedReference:
-    """A resolved reference: identity plus the bits a pointer line needs."""
+class _Reference:
+    """Identity shared by every reference kind."""
 
-    kind: ReferenceKind
     entity_id: int
     label: str
-    path: str | None = None  # document/folder virtual path
-    provider: str | None = None  # connector provider, e.g. "Gmail"
 
 
-__all__ = ["ReferenceKind", "ResolvedReference"]
+@dataclass(frozen=True)
+class DocumentReference(_Reference):
+    """A referenced document, reachable by its virtual path."""
+
+    path: str
+    kind: ClassVar[ReferenceKind] = ReferenceKind.DOCUMENT
+
+
+@dataclass(frozen=True)
+class FolderReference(_Reference):
+    """A referenced folder, reachable by its virtual path."""
+
+    path: str
+    kind: ClassVar[ReferenceKind] = ReferenceKind.FOLDER
+
+
+@dataclass(frozen=True)
+class ConnectorReference(_Reference):
+    """A referenced connector account; ``provider`` is its type label."""
+
+    provider: str | None = None
+    kind: ClassVar[ReferenceKind] = ReferenceKind.CONNECTOR
+
+
+@dataclass(frozen=True)
+class ChatReference(_Reference):
+    """A referenced chat thread; its turns are read on demand, not here."""
+
+    kind: ClassVar[ReferenceKind] = ReferenceKind.CHAT
+
+
+Reference = DocumentReference | FolderReference | ConnectorReference | ChatReference
+
+
+__all__ = [
+    "ChatReference",
+    "ConnectorReference",
+    "DocumentReference",
+    "FolderReference",
+    "Reference",
+    "ReferenceKind",
+]
diff --git a/surfsense_backend/app/agents/chat/runtime/references/reference_pointers.py b/surfsense_backend/app/agents/chat/runtime/references/reference_pointers.py
index ce7966275..894d844b1 100644
--- a/surfsense_backend/app/agents/chat/runtime/references/reference_pointers.py
+++ b/surfsense_backend/app/agents/chat/runtime/references/reference_pointers.py
@@ -7,7 +7,13 @@ retrieve from. Actual content is pulled later via tools, never injected here.
 
 from __future__ import annotations
 
-from .models import ReferenceKind, ResolvedReference
+from .models import (
+    ChatReference,
+    ConnectorReference,
+    DocumentReference,
+    FolderReference,
+    Reference,
+)
 
 _HEADER = (
     "The user pointed at these with @ this turn. They are scope, not content "
@@ -15,7 +21,7 @@ _HEADER = (
 )
 
 
-def render_reference_pointers(references: list[ResolvedReference]) -> str | None:
+def render_reference_pointers(references: list[Reference]) -> str | None:
     """Render references as one read-only pointer block.
 
     Returns ``None`` when there is nothing to render so callers can skip the
@@ -33,21 +39,23 @@ def render_reference_pointers(references: list[ResolvedReference]) -> str | None
     )
 
 
-def _render_pointer(reference: ResolvedReference) -> str:
+def _render_pointer(reference: Reference) -> str:
     """One ``- {kind} {id} — {handle}`` line, shaped per kind."""
     head = f"- {reference.kind.value} {reference.entity_id} — "
     return head + _handle(reference)
 
 
-def _handle(reference: ResolvedReference) -> str:
-    """The human-reachable handle: connector provider, a path, or a title."""
+def _handle(reference: Reference) -> str:
+    """The human-reachable handle: a path, a connector provider, or a title."""
     label = _clean(reference.label)
-    if reference.kind is ReferenceKind.CONNECTOR:
-        provider = _clean(reference.provider) if reference.provider else ""
-        return f"{provider} ({label})" if provider else label
-    if reference.path:
-        return f'"{label}" ({reference.path})'
-    return f'"{label}"'
+    match reference:
+        case DocumentReference() | FolderReference():
+            return f'"{label}" ({reference.path})'
+        case ConnectorReference():
+            provider = _clean(reference.provider) if reference.provider else ""
+            return f"{provider} ({label})" if provider else label
+        case ChatReference():
+            return f'"{label}"'
 
 
 def _clean(text: str) -> str:
diff --git a/surfsense_backend/tests/unit/agents/chat/runtime/references/test_reference_pointers.py b/surfsense_backend/tests/unit/agents/chat/runtime/references/test_reference_pointers.py
index 5cbdd5f88..4ac23b616 100644
--- a/surfsense_backend/tests/unit/agents/chat/runtime/references/test_reference_pointers.py
+++ b/surfsense_backend/tests/unit/agents/chat/runtime/references/test_reference_pointers.py
@@ -5,8 +5,10 @@ from __future__ import annotations
 import pytest
 
 from app.agents.chat.runtime.references import (
-    ReferenceKind,
-    ResolvedReference,
+    ChatReference,
+    ConnectorReference,
+    DocumentReference,
+    FolderReference,
     render_reference_pointers,
 )
 
@@ -20,10 +22,8 @@ def test_returns_none_when_no_references() -> None:
 def test_wraps_block_and_keeps_reference_order() -> None:
     block = render_reference_pointers(
         [
-            ResolvedReference(
-                kind=ReferenceKind.DOCUMENT, entity_id=42, label="Q3 Notes"
-            ),
-            ResolvedReference(kind=ReferenceKind.CHAT, entity_id=5, label="Pricing"),
+            DocumentReference(entity_id=42, label="Q3 Notes", path="/documents/q3.xml"),
+            ChatReference(entity_id=5, label="Pricing"),
         ]
     )
 
@@ -36,8 +36,7 @@ def test_wraps_block_and_keeps_reference_order() -> None:
 def test_document_with_path_shows_title_and_path() -> None:
     block = render_reference_pointers(
         [
-            ResolvedReference(
-                kind=ReferenceKind.DOCUMENT,
+            DocumentReference(
                 entity_id=42,
                 label="Q3 Launch Notes",
                 path="/documents/Launch/Q3.xml",
@@ -51,14 +50,7 @@ def test_document_with_path_shows_title_and_path() -> None:
 
 def test_folder_with_path_renders_with_folder_kind() -> None:
     block = render_reference_pointers(
-        [
-            ResolvedReference(
-                kind=ReferenceKind.FOLDER,
-                entity_id=7,
-                label="Specs",
-                path="/documents/Specs/",
-            )
-        ]
+        [FolderReference(entity_id=7, label="Specs", path="/documents/Specs/")]
     )
 
     assert block is not None
@@ -67,14 +59,7 @@ def test_folder_with_path_renders_with_folder_kind() -> None:
 
 def test_connector_shows_provider_and_account() -> None:
     block = render_reference_pointers(
-        [
-            ResolvedReference(
-                kind=ReferenceKind.CONNECTOR,
-                entity_id=12,
-                label="work@acme.com",
-                provider="Gmail",
-            )
-        ]
+        [ConnectorReference(entity_id=12, label="work@acme.com", provider="Gmail")]
     )
 
     assert block is not None
@@ -83,11 +68,7 @@ def test_connector_shows_provider_and_account() -> None:
 
 def test_connector_without_provider_falls_back_to_label() -> None:
     block = render_reference_pointers(
-        [
-            ResolvedReference(
-                kind=ReferenceKind.CONNECTOR, entity_id=12, label="work@acme.com"
-            )
-        ]
+        [ConnectorReference(entity_id=12, label="work@acme.com")]
     )
 
     assert block is not None
@@ -96,7 +77,7 @@ def test_connector_without_provider_falls_back_to_label() -> None:
 
 def test_chat_shows_quoted_title() -> None:
     block = render_reference_pointers(
-        [ResolvedReference(kind=ReferenceKind.CHAT, entity_id=5, label="Pricing debate")]
+        [ChatReference(entity_id=5, label="Pricing debate")]
     )
 
     assert block is not None
@@ -105,13 +86,7 @@ def test_chat_shows_quoted_title() -> None:
 
 def test_label_whitespace_is_collapsed_to_one_line() -> None:
     block = render_reference_pointers(
-        [
-            ResolvedReference(
-                kind=ReferenceKind.DOCUMENT,
-                entity_id=1,
-                label="line one\nline two",
-            )
-        ]
+        [DocumentReference(entity_id=1, label="line one\nline two", path="/d.xml")]
     )
 
     assert block is not None