mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-24 21:38:09 +02:00
Merge remote-tracking branch 'upstream/dev' into feat/api-key
This commit is contained in:
commit
3695e1d5c5
64 changed files with 1043 additions and 1852 deletions
|
|
@ -1,87 +0,0 @@
|
|||
"""Unit tests for search_knowledge_base hit rendering.
|
||||
|
||||
The tool must surface the passage that actually matched (the RRF-ranked
|
||||
chunk), not the top of the document, and annotate it with its line range
|
||||
when the chunk carries a char span.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.main_agent.tools.search_knowledge_base import (
|
||||
_format_hits,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_BODY = "Intro paragraph.\n\nMatched passage here.\n\nClosing paragraph."
|
||||
|
||||
|
||||
def _hit() -> dict:
    """Build a search hit for document 7 with two chunks, one of them matched.

    Both chunks carry char spans into ``_BODY``. Only chunk 102
    ("Matched passage here.") is listed in ``matched_chunk_ids``; chunk 101
    is the unmatched intro at the top of the document.

    Returns:
        A fresh dict on every call, so tests may mutate it freely.
    """
    intro = "Intro paragraph."
    matched = "Matched passage here."
    matched_start = _BODY.index(matched)
    return {
        "document": {"id": 7, "title": "note.md", "document_type": "NOTE"},
        "score": 0.42,
        # The hit content is the full document body, unchanged.
        "content": _BODY,
        "matched_chunk_ids": [102],
        "chunks": [
            {
                "chunk_id": 101,
                "content": intro,
                "start_char": 0,
                "end_char": len(intro),
            },
            {
                "chunk_id": 102,
                "content": matched,
                "start_char": matched_start,
                "end_char": matched_start + len(matched),
            },
        ],
    }
|
||||
|
||||
|
||||
def test_renders_matched_passage_not_top_of_doc() -> None:
    """The rendered snippet is the matched chunk, not the document's opening text."""
    rendered = _format_hits(
        [_hit()],
        paths={7: "/documents/note.md"},
        bodies={7: _BODY},
        query="q",
    )

    assert "Matched passage here." in rendered
    # The intro chunk is not among the matched ids, so it must not be the snippet.
    assert "Intro paragraph." not in rendered
|
||||
|
||||
|
||||
def test_emits_copyable_line_citation_token_when_spans_present() -> None:
    """A hit whose chunks carry char spans yields a ready-to-copy citation token."""
    rendered = _format_hits(
        [_hit()],
        paths={7: "/documents/note.md"},
        bodies={7: _BODY},
        query="q",
    )

    # "Matched passage here." is line 3 of the body; the token lets the agent
    # cite the passage without issuing a separate read.
    expected_token = "[citation:d7#L3-3]"
    assert expected_token in rendered
|
||||
|
||||
|
||||
def test_header_includes_document_id() -> None:
    """The rendered output exposes the document id as ``id=7``."""
    rendered = _format_hits(
        [_hit()],
        paths={7: "/documents/note.md"},
        bodies={7: _BODY},
        query="q",
    )

    assert "id=7" in rendered
|
||||
|
||||
|
||||
def test_omits_citation_token_when_spans_absent() -> None:
    """Without char spans the passage is still shown, but no concrete token is emitted."""
    spanless_hit = _hit()
    for piece in spanless_hit["chunks"]:
        piece.update(start_char=None, end_char=None)

    rendered = _format_hits(
        [spanless_hit],
        paths={7: "/documents/note.md"},
        bodies={7: _BODY},
        query="q",
    )

    assert "Matched passage here." in rendered
    # No concrete, copyable token for this document without spans (the closing
    # instruction's placeholder template doesn't count).
    assert "[citation:d7#L" not in rendered
|
||||
|
||||
|
||||
def test_falls_back_to_content_when_no_matched_ids() -> None:
    """With an empty ``matched_chunk_ids`` list the intro text appears in the output."""
    unmatched_hit = _hit()
    unmatched_hit["matched_chunk_ids"] = []

    rendered = _format_hits(
        [unmatched_hit],
        paths={7: "/documents/note.md"},
        bodies={7: _BODY},
        query="q",
    )

    assert "Intro paragraph." in rendered
|
||||
|
||||
|
||||
def test_no_results_message() -> None:
    """An empty hit list renders the "no matches" message."""
    rendered = _format_hits([], paths={}, bodies={}, query="missing")

    assert "No knowledge-base matches" in rendered
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
"""Span-aware chunking contract: slices form a lossless, contiguous partition
|
||||
of the markdown, and every slice's char span addresses its own text."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.indexing_pipeline.document_chunker import chunk_markdown_with_spans
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _assert_lossless_partition(md: str, slices) -> None:
    """Assert that ``slices`` tile ``md`` exactly, in order, with no gaps.

    Each slice must expose ``text``, ``start_char`` and ``end_char``. The
    check passes when the slice texts concatenate back to ``md``, every slice
    starts where the previous one ended, each span addresses its own text in
    ``md``, and the final span ends at ``len(md)``.
    """
    reassembled = "".join(piece.text for piece in slices)
    assert reassembled == md

    expected_start = 0
    for piece in slices:
        begin = piece.start_char
        assert begin == expected_start, "slices must be contiguous"
        end = piece.end_char
        assert end >= begin
        assert md[begin:end] == piece.text, "span must address slice text"
        expected_start = end
    assert expected_start == len(md)
|
||||
|
||||
|
||||
def test_prose_partition_and_spans():
    """Long prose is split into several slices that still partition the input exactly."""
    first_section = "First paragraph with several words here. " * 20
    second_section = (
        "\n\nSecond section with more prose to force multiple chunks. " * 20
    )
    md = "# Title\n\n" + first_section + second_section

    pieces = chunk_markdown_with_spans(md)

    assert len(pieces) > 1
    _assert_lossless_partition(md, pieces)
|
||||
|
||||
|
||||
def test_table_kept_whole_with_exact_span():
    """A Markdown table between prose is never split across slices."""
    header_row = "| a | b |"
    data_row = "| 1 | 2 |"
    table = "| a | b |\n| - | - |\n| 1 | 2 |\n"
    md = "Intro prose before the table.\n" + table + "\nClosing prose after."

    pieces = chunk_markdown_with_spans(md)

    _assert_lossless_partition(md, pieces)
    table_pieces = [p for p in pieces if p.text.lstrip().startswith("|")]
    assert any(data_row in p.text for p in table_pieces)
    # Every slice that begins with a table row must hold the whole table.
    for p in table_pieces:
        assert header_row in p.text and data_row in p.text
|
||||
|
||||
|
||||
def test_table_at_eof_without_trailing_newline_stays_whole():
    """A table ending the document without a final newline lands in one slice."""
    md = "Intro.\n| a | b |\n| - | - |\n| 1 | 2 |"

    pieces = chunk_markdown_with_spans(md)

    _assert_lossless_partition(md, pieces)
    holding_last_row = [p for p in pieces if "| 1 | 2 |" in p.text]
    assert len(holding_last_row) == 1
    assert "| a | b |" in holding_last_row[0].text
|
||||
|
||||
|
||||
def test_code_chunker_partition_and_spans():
    """The code-chunker path also produces a lossless, contiguous partition."""
    functions = [
        f"def func_{i}(x):\n total = x + {i}\n return total" for i in range(40)
    ]
    code = "\n\n".join(functions)

    pieces = chunk_markdown_with_spans(code, use_code_chunker=True)

    assert len(pieces) >= 1
    _assert_lossless_partition(code, pieces)
|
||||
|
||||
|
||||
def test_empty_markdown_yields_no_slices():
    """Chunking the empty string returns an empty list."""
    pieces = chunk_markdown_with_spans("")
    assert pieces == []
|
||||
|
|
@ -37,9 +37,12 @@ def _make_orm_doc(connector_doc, doc_id):
|
|||
async def test_index_calls_embed_and_chunk_via_to_thread(
|
||||
pipeline, make_connector_document, monkeypatch
|
||||
):
|
||||
"""index() runs the chunker and embed_texts via asyncio.to_thread, not blocking the loop."""
|
||||
from app.indexing_pipeline.document_chunker import ChunkSlice
|
||||
"""index() runs the chunker and embed_texts via asyncio.to_thread, not blocking the loop.
|
||||
|
||||
Routing between ``chunk_text`` (code path) and ``chunk_text_hybrid`` (default
|
||||
path, see issue #1334) is verified separately in
|
||||
``test_non_code_documents_use_hybrid_chunker``.
|
||||
"""
|
||||
to_thread_calls = []
|
||||
original_to_thread = asyncio.to_thread
|
||||
|
||||
|
|
@ -48,11 +51,11 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
|
|||
return await original_to_thread(func, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(asyncio, "to_thread", tracking_to_thread)
|
||||
mock_chunker = MagicMock(return_value=[ChunkSlice("chunk1", 0, 6)])
|
||||
mock_chunker.__name__ = "chunk_markdown_with_spans"
|
||||
mock_chunk_hybrid = MagicMock(return_value=["chunk1"])
|
||||
mock_chunk_hybrid.__name__ = "chunk_text_hybrid"
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.cache.cached_indexing.chunk_markdown_with_spans",
|
||||
mock_chunker,
|
||||
"app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid",
|
||||
mock_chunk_hybrid,
|
||||
)
|
||||
mock_embed = MagicMock(
|
||||
side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]
|
||||
|
|
@ -87,25 +90,34 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
|
|||
|
||||
await pipeline.index(document, connector_doc)
|
||||
|
||||
assert "chunk_markdown_with_spans" in to_thread_calls
|
||||
# Either chunker entry point satisfies the "chunking runs off the event
|
||||
# loop" contract this test guards. Routing between the two is verified
|
||||
# in test_non_code_documents_use_hybrid_chunker.
|
||||
assert {"chunk_text", "chunk_text_hybrid"} & set(to_thread_calls)
|
||||
assert "embed_texts" in to_thread_calls
|
||||
assert document.status == DocumentStatus.ready()
|
||||
|
||||
|
||||
async def test_non_code_documents_use_prose_chunker(
|
||||
async def test_non_code_documents_use_hybrid_chunker(
|
||||
pipeline, make_connector_document, monkeypatch
|
||||
):
|
||||
"""Non-code documents chunk with use_code_chunker=False (issue #1334).
|
||||
"""Non-code documents route through ``chunk_text_hybrid`` (issue #1334).
|
||||
|
||||
The table-aware prose path keeps Markdown tables intact; only documents
|
||||
flagged with ``should_use_code_chunker=True`` request the code chunker.
|
||||
The hybrid chunker preserves Markdown table integrity by avoiding splits
|
||||
mid-row. Only documents flagged with ``should_use_code_chunker=True``
|
||||
should take the ``chunk_text`` path.
|
||||
"""
|
||||
from app.indexing_pipeline.document_chunker import ChunkSlice
|
||||
|
||||
mock_chunker = MagicMock(return_value=[ChunkSlice("chunk1", 0, 6)])
|
||||
mock_chunk_hybrid = MagicMock(return_value=["chunk1"])
|
||||
mock_chunk_hybrid.__name__ = "chunk_text_hybrid"
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.cache.cached_indexing.chunk_markdown_with_spans",
|
||||
mock_chunker,
|
||||
"app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid",
|
||||
mock_chunk_hybrid,
|
||||
)
|
||||
mock_chunk_code = MagicMock(return_value=["chunk1"])
|
||||
mock_chunk_code.__name__ = "chunk_text"
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.cache.cached_indexing.chunk_text",
|
||||
mock_chunk_code,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.cache.cached_indexing.embed_texts",
|
||||
|
|
@ -137,49 +149,8 @@ async def test_non_code_documents_use_prose_chunker(
|
|||
|
||||
await pipeline.index(document, connector_doc)
|
||||
|
||||
mock_chunker.assert_called_once()
|
||||
assert mock_chunker.call_args.args[1] is False
|
||||
|
||||
|
||||
async def test_code_documents_request_code_chunker(
    pipeline, make_connector_document, monkeypatch
):
    """Code-flagged documents forward use_code_chunker=True to the chunker."""
    from app.indexing_pipeline.document_chunker import ChunkSlice

    # Replace the chunker so the flag it receives can be inspected afterwards.
    mock_chunker = MagicMock(return_value=[ChunkSlice("chunk1", 0, 6)])
    monkeypatch.setattr(
        "app.indexing_pipeline.cache.cached_indexing.chunk_markdown_with_spans",
        mock_chunker,
    )

    # Stub embeddings with one fixed-size vector per input text.
    fake_embed = MagicMock(
        side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]
    )
    monkeypatch.setattr(
        "app.indexing_pipeline.cache.cached_indexing.embed_texts",
        fake_embed,
    )
    monkeypatch.setattr(pipeline, "_load_existing_chunks", AsyncMock(return_value=[]))

    async def _noop_persist(_session, doc, *_args, **_kwargs):
        # Stand-in for persistence: only flips the document to ready.
        doc.status = DocumentStatus.ready()

    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.persist_scratch_index",
        _noop_persist,
    )

    connector_doc = make_connector_document(
        document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
        unique_id="repo-1",
        search_space_id=1,
        should_use_code_chunker=True,
    )
    document = MagicMock(spec=Document)
    document.id = 1
    document.status = DocumentStatus.pending()

    await pipeline.index(document, connector_doc)

    mock_chunker.assert_called_once()
    # The second positional argument is the use_code_chunker flag.
    assert mock_chunker.call_args.args[1] is True
|
||||
mock_chunk_hybrid.assert_called_once()
|
||||
mock_chunk_code.assert_not_called()
|
||||
|
||||
|
||||
def _mock_session_factory(orm_docs_by_id):
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class _KBBackendStub(KBPostgresBackend):
|
|||
def __init__(self, *, children=None, file_data=None) -> None:
|
||||
self.als_info = AsyncMock(return_value=children or [])
|
||||
self._load_file_data = AsyncMock(
|
||||
return_value=(file_data, 17, None) if file_data is not None else None
|
||||
return_value=(file_data, 17) if file_data is not None else None
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -69,25 +69,13 @@ class _FakeSession:
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_embeddings_and_chunks(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Avoid loading the embedding model in unit tests.
|
||||
|
||||
Mirrors the legacy stub: one chunk spanning the whole content, with a
|
||||
zero summary/chunk vector, routed through the shared span builder.
|
||||
"""
|
||||
from app.indexing_pipeline.document_chunker import ChunkSlice
|
||||
|
||||
async def _fake_build_chunk_embeddings(content: str, *, use_code_chunker: bool):
|
||||
summary = np.zeros(8, dtype=np.float32)
|
||||
pairs = (
|
||||
[(ChunkSlice(content, 0, len(content)), np.zeros(8, dtype=np.float32))]
|
||||
if content
|
||||
else []
|
||||
)
|
||||
return summary, pairs
|
||||
|
||||
"""Avoid loading the embedding model in unit tests."""
|
||||
monkeypatch.setattr(
|
||||
kb_persistence, "build_chunk_embeddings", _fake_build_chunk_embeddings
|
||||
kb_persistence,
|
||||
"embed_texts",
|
||||
lambda texts: [np.zeros(8, dtype=np.float32) for _ in texts],
|
||||
)
|
||||
monkeypatch.setattr(kb_persistence, "chunk_text", lambda content: [content])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -1,92 +0,0 @@
|
|||
"""Unit tests for the numbered-document read preamble."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.numbered_document import (
|
||||
build_read_preamble,
|
||||
compute_matched_line_ranges,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_BODY = "alpha\nbravo\ncharlie\ndelta"
|
||||
|
||||
|
||||
class TestComputeMatchedLineRanges:
    """Matched chunk char spans over ``_BODY`` are mapped to 1-based line ranges."""

    def test_maps_matched_chunk_spans_to_line_ranges(self):
        """A span covering "charlie\\ndelta" maps to lines 3-4."""
        spans = [(1, 0, 12), (2, 12, len(_BODY))]
        assert compute_matched_line_ranges(_BODY, spans, {2}) == [(3, 4)]

    def test_includes_only_matched_chunks(self):
        """Chunks whose id is not in the matched set contribute nothing."""
        spans = [(1, 0, 5), (2, 6, 11)]
        assert compute_matched_line_ranges(_BODY, spans, {1}) == [(1, 1)]

    def test_skips_chunks_without_spans(self):
        """A matched chunk with ``None`` offsets yields no range."""
        spans = [(1, None, None)]
        assert compute_matched_line_ranges(_BODY, spans, {1}) == []

    def test_sorted_and_deduplicated(self):
        """Ranges come back in ascending order with duplicates collapsed."""
        spans = [(1, 12, len(_BODY)), (2, 0, 5), (3, 0, 5)]
        result = compute_matched_line_ranges(_BODY, spans, {1, 2, 3})
        assert result == [(1, 1), (3, 4)]
|
||||
|
||||
|
||||
class TestBuildReadPreamble:
    """Checks on the text returned by ``build_read_preamble``."""

    def test_contains_document_metadata(self):
        """Id, type, title and URL all appear in the preamble."""
        header = build_read_preamble(
            document_id=42,
            document_type="FILE",
            title="Test Doc",
            url="https://example.com",
            matched_line_ranges=[],
        )
        expected_fragments = (
            "<document_id>42</document_id>",
            "<document_type>FILE</document_type>",
            "Test Doc",
            "https://example.com",
        )
        for fragment in expected_fragments:
            assert fragment in header

    def test_citation_hint_uses_document_id(self):
        """The citation hint is keyed by the document id."""
        header = build_read_preamble(
            document_id=42,
            document_type="FILE",
            title="Test Doc",
            url="",
            matched_line_ranges=[],
        )
        assert "[citation:d42#L" in header

    def test_lists_matched_line_ranges(self):
        """Matched ranges are listed inside a ``<matched_lines>`` block."""
        header = build_read_preamble(
            document_id=7,
            document_type="NOTE",
            title="Notes",
            url="",
            matched_line_ranges=[(12, 18), (40, 40)],
        )
        assert "<matched_lines>" in header
        assert "12-18" in header
        assert "40" in header

    def test_omits_matched_lines_block_when_empty(self):
        """No ``<matched_lines>`` block is written when there are no ranges."""
        header = build_read_preamble(
            document_id=7,
            document_type="NOTE",
            title="Notes",
            url="",
            matched_line_ranges=[],
        )
        assert "<matched_lines>" not in header

    def test_ends_with_trailing_newline_so_body_follows_cleanly(self):
        """The preamble ends with a newline so the body starts on its own line."""
        header = build_read_preamble(
            document_id=1,
            document_type="FILE",
            title="t",
            url="",
            matched_line_ranges=[],
        )
        assert header.endswith("\n")
|
||||
162
surfsense_backend/tests/unit/utils/test_async_retry.py
Normal file
162
surfsense_backend/tests/unit/utils/test_async_retry.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""Tests for async_retry utilities."""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.connectors.exceptions import (
|
||||
ConnectorAPIError,
|
||||
ConnectorAuthError,
|
||||
ConnectorError,
|
||||
ConnectorRateLimitError,
|
||||
ConnectorTimeoutError,
|
||||
)
|
||||
from app.utils.async_retry import _is_retryable, raise_for_status
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def make_response(
    status_code: int,
    *,
    headers: dict[str, str] | None = None,
    json_body=None,
    text_body: str = "",
):
    """Build an ``httpx.Response`` attached to a GET request for ``https://x``.

    Args:
        status_code: HTTP status of the response.
        headers: Optional response headers.
        json_body: When not ``None``, sent as the JSON payload.
        text_body: Plain-text body, used only when ``json_body`` is ``None``.
    """
    if json_body is None:
        body_kwarg = {"text": text_body}
    else:
        body_kwarg = {"json": json_body}

    return httpx.Response(
        status_code=status_code,
        headers=headers,
        request=httpx.Request("GET", "https://x"),
        **body_kwarg,
    )
|
||||
|
||||
|
||||
def test_raise_for_status_does_not_raise_for_success():
    """A 200 response passes through without raising."""
    ok_response = make_response(200)
    raise_for_status(ok_response)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("retry_after_header", "expected"),
    [
        ("5", 5.0),
        (None, None),
        ("abc", None),
    ],
)
def test_raise_for_status_429(retry_after_header, expected):
    """A 429 raises ``ConnectorRateLimitError`` carrying the parsed Retry-After.

    A missing or non-numeric header results in ``retry_after`` being ``None``.
    """
    if retry_after_header is None:
        headers = {}
    else:
        headers = {"Retry-After": retry_after_header}

    response = make_response(
        429,
        headers=headers,
        json_body={"detail": "rate limited"},
    )

    with pytest.raises(ConnectorRateLimitError) as exc_info:
        raise_for_status(response)

    raised = exc_info.value
    assert raised.retry_after == expected
    assert raised.response_body == {"detail": "rate limited"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status_code", [401, 403])
def test_raise_for_status_auth_errors(status_code):
    """401 and 403 raise ``ConnectorAuthError`` with status and parsed body."""
    payload = {"error": "unauthorized"}
    response = make_response(status_code, json_body=payload)

    with pytest.raises(ConnectorAuthError) as exc_info:
        raise_for_status(response)

    raised = exc_info.value
    assert raised.status_code == status_code
    assert raised.response_body == {"error": "unauthorized"}
|
||||
|
||||
|
||||
def test_raise_for_status_gateway_timeout():
    """A 504 is reported as ``ConnectorTimeoutError``."""
    response = make_response(504, json_body={"error": "timeout"})

    with pytest.raises(ConnectorTimeoutError):
        raise_for_status(response)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status_code", [500, 502])
def test_raise_for_status_server_errors(status_code):
    """500 and 502 raise ``ConnectorAPIError`` carrying the status code."""
    response = make_response(status_code, json_body={"error": "server"})

    with pytest.raises(ConnectorAPIError) as exc_info:
        raise_for_status(response)

    raised = exc_info.value
    assert raised.status_code == status_code
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status_code", [400, 404])
def test_raise_for_status_client_errors(status_code):
    """400 and 404 raise ``ConnectorAPIError`` carrying the status code."""
    response = make_response(status_code, json_body={"error": "client"})

    with pytest.raises(ConnectorAPIError) as exc_info:
        raise_for_status(response)

    raised = exc_info.value
    assert raised.status_code == status_code
|
||||
|
||||
|
||||
def test_raise_for_status_uses_text_when_json_parsing_fails():
    """A non-JSON error body is surfaced as the raw response text."""
    response = make_response(500, text_body="Internal server error")

    with pytest.raises(ConnectorAPIError) as exc_info:
        raise_for_status(response)

    raised = exc_info.value
    assert raised.response_body == "Internal server error"
|
||||
|
||||
|
||||
def test_connector_error_retryable_false():
    """A plain ``ConnectorError`` is not retryable."""
    verdict = _is_retryable(ConnectorError("boom"))
    assert verdict is False
|
||||
|
||||
|
||||
def test_rate_limit_error_is_retryable():
    """``ConnectorRateLimitError`` is retryable."""
    verdict = _is_retryable(ConnectorRateLimitError())
    assert verdict is True
|
||||
|
||||
|
||||
def test_timeout_exception_is_retryable():
    """``httpx.TimeoutException`` is retryable."""
    verdict = _is_retryable(httpx.TimeoutException("timeout"))
    assert verdict is True
|
||||
|
||||
|
||||
def test_connect_error_is_retryable():
    """``httpx.ConnectError`` is retryable."""
    verdict = _is_retryable(httpx.ConnectError("connection failed"))
    assert verdict is True
|
||||
|
||||
|
||||
def test_unrelated_exception_is_not_retryable():
    """An exception outside the connector/httpx families is not retryable."""
    verdict = _is_retryable(ValueError("boom"))
    assert verdict is False
|
||||
293
surfsense_backend/tests/unit/utils/test_content_utils.py
Normal file
293
surfsense_backend/tests/unit/utils/test_content_utils.py
Normal file
|
|
@ -0,0 +1,293 @@
|
|||
"""Tests for strip_markdown_fences() and extract_text_content() in
|
||||
app/utils/content_utils.py.
|
||||
|
||||
Out of scope: bootstrap_history_from_db() — async + DB, belongs in
|
||||
integration tests.
|
||||
|
||||
Run:
|
||||
uv run pytest -m unit tests/unit/utils/test_content_utils.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# strip_markdown_fences()
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestStripMarkdownFences:
    """Tests for strip_markdown_fences(text: str) -> str.

    Regex: r"^```(?:\\w+)?\\s*\\n(.*?)```\\s*$" (re.DOTALL)
    The function matches against text.strip(), so whitespace around the
    fence is gone before the regex runs, and the captured group is
    .strip()-ped again before it is returned.
    """

    # --- Fenced with a language tag -----------------------------------

    def test_json_fence_returns_inner_content(self):
        from app.utils.content_utils import strip_markdown_fences

        fenced = '```json\n{"key": "value"}\n```'
        assert strip_markdown_fences(fenced) == '{"key": "value"}'

    def test_python_fence_returns_inner_content(self):
        from app.utils.content_utils import strip_markdown_fences

        fenced = "```python\ndef hello():\n return 'hi'\n```"
        assert strip_markdown_fences(fenced) == "def hello():\n return 'hi'"

    def test_yaml_fence_returns_inner_content(self):
        from app.utils.content_utils import strip_markdown_fences

        fenced = "```yaml\nkey: value\n```"
        assert strip_markdown_fences(fenced) == "key: value"

    def test_sql_multiline_fence_returns_inner_content(self):
        from app.utils.content_utils import strip_markdown_fences

        fenced = "```sql\nSELECT *\nFROM users\nWHERE id = 1;\n```"
        inner = "SELECT *\nFROM users\nWHERE id = 1;"
        assert strip_markdown_fences(fenced) == inner

    # --- Fenced without a language tag --------------------------------

    def test_no_lang_tag_single_line_returns_inner_content(self):
        from app.utils.content_utils import strip_markdown_fences

        fenced = "```\nhello world\n```"
        assert strip_markdown_fences(fenced) == "hello world"

    def test_no_lang_tag_multiline_returns_inner_content(self):
        from app.utils.content_utils import strip_markdown_fences

        fenced = "```\nline one\nline two\n```"
        assert strip_markdown_fences(fenced) == "line one\nline two"

    # --- Plain text: no fences, returned unchanged --------------------

    def test_plain_text_returned_unchanged(self):
        from app.utils.content_utils import strip_markdown_fences

        plain = "just plain text with no fences"
        assert strip_markdown_fences(plain) == plain

    def test_plain_text_with_newlines_returned_unchanged(self):
        from app.utils.content_utils import strip_markdown_fences

        plain = "line one\nline two\nline three"
        assert strip_markdown_fences(plain) == plain

    def test_empty_string_returned_unchanged(self):
        from app.utils.content_utils import strip_markdown_fences

        assert strip_markdown_fences("") == ""

    # --- Surrounding whitespace handling ------------------------------
    # text.strip() runs before matching, so whitespace outside the fence is
    # consumed; the captured group is stripped too, so whitespace between
    # the fence markers and the content also disappears.

    def test_leading_whitespace_around_fence_stripped(self):
        from app.utils.content_utils import strip_markdown_fences

        fenced = " ```json\n{}\n```"
        assert strip_markdown_fences(fenced) == "{}"

    def test_trailing_whitespace_around_fence_stripped(self):
        from app.utils.content_utils import strip_markdown_fences

        fenced = "```json\n{}\n``` "
        assert strip_markdown_fences(fenced) == "{}"

    def test_surrounding_newlines_stripped(self):
        from app.utils.content_utils import strip_markdown_fences

        fenced = '\n\n```json\n{"a": 1}\n```\n\n'
        assert strip_markdown_fences(fenced) == '{"a": 1}'

    def test_inner_indentation_preserved(self):
        """Only the first captured line loses its leading whitespace.

        The captured group is .strip()-ped as a whole, so indentation on
        later lines survives.
        """
        from app.utils.content_utils import strip_markdown_fences

        fenced = "```\n indented line\n deeper indent\n```"
        stripped = strip_markdown_fences(fenced)
        # .strip() removes the leading spaces from the first captured line
        assert "indented line" in stripped
        # indentation on the second line is preserved
        assert " deeper indent" in stripped
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# extract_text_content()
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestExtractTextContent:
    """Tests for extract_text_content(content: str | dict | list) -> str."""

    # --- str input: returned as-is ------------------------------------

    def test_str_input_returned_as_is(self):
        from app.utils.content_utils import extract_text_content

        assert extract_text_content("hello world") == "hello world"

    def test_str_empty_returned_as_is(self):
        from app.utils.content_utils import extract_text_content

        assert extract_text_content("") == ""

    def test_str_with_internal_whitespace_returned_as_is(self):
        from app.utils.content_utils import extract_text_content

        padded = " spaced "
        assert extract_text_content(padded) == " spaced "

    # --- dict with a "text" key: content["text"] is returned ----------

    def test_dict_with_text_key_returns_its_value(self):
        from app.utils.content_utils import extract_text_content

        assert extract_text_content({"text": "from dict"}) == "from dict"

    def test_dict_with_text_key_empty_value(self):
        from app.utils.content_utils import extract_text_content

        assert extract_text_content({"text": ""}) == ""

    def test_dict_with_text_key_ignores_other_keys(self):
        from app.utils.content_utils import extract_text_content

        message = {"text": "important", "role": "assistant", "extra": 99}
        assert extract_text_content(message) == "important"

    # --- dict without a "text" key: str(dict) -------------------------

    def test_dict_without_text_key_returns_str_repr(self):
        from app.utils.content_utils import extract_text_content

        message = {"role": "assistant", "value": 42}
        assert extract_text_content(message) == str(message)

    def test_empty_dict_returns_str_repr(self):
        from app.utils.content_utils import extract_text_content

        assert extract_text_content({}) == str({})

    # --- list of parts: text dicts and plain strings -------------------
    # Parts are joined with "\n" (per implementation: "\n".join(texts))

    def test_list_text_type_parts_joined_with_newline(self):
        from app.utils.content_utils import extract_text_content

        blocks = [
            {"type": "text", "text": "Hello"},
            {"type": "text", "text": "world"},
        ]
        assert extract_text_content(blocks) == "Hello\nworld"

    def test_list_plain_strings_joined_with_newline(self):
        from app.utils.content_utils import extract_text_content

        assert extract_text_content(["foo", "bar"]) == "foo\nbar"

    def test_list_mixed_text_dicts_and_plain_strings(self):
        from app.utils.content_utils import extract_text_content

        blocks = [
            {"type": "text", "text": "Hello"},
            "plain",
            {"type": "text", "text": "world"},
        ]
        extracted = extract_text_content(blocks)
        for fragment in ("Hello", "plain", "world"):
            assert fragment in extracted

    def test_list_non_text_type_parts_ignored(self):
        """tool_use, image, and other non-text blocks must not leak into output."""
        from app.utils.content_utils import extract_text_content

        blocks = [
            {"type": "tool_use", "id": "abc", "name": "search_kb"},
            {"type": "text", "text": "visible text"},
            {"type": "image", "source": {"url": "https://example.com/img.png"}},
        ]
        extracted = extract_text_content(blocks)
        assert extracted == "visible text"
        for leaked in ("tool_use", "search_kb", "image"):
            assert leaked not in extracted

    def test_list_only_non_text_parts_returns_empty_string(self):
        from app.utils.content_utils import extract_text_content

        blocks = [
            {"type": "tool_use", "id": "x"},
            {"type": "image", "source": {}},
        ]
        assert extract_text_content(blocks) == ""

    def test_list_single_text_part(self):
        from app.utils.content_utils import extract_text_content

        assert extract_text_content([{"type": "text", "text": "only me"}]) == "only me"

    def test_list_text_part_missing_text_key_contributes_empty_string(self):
        """part.get("text", "") — a text-typed dict with no "text" key gives ""."""
        from app.utils.content_utils import extract_text_content

        blocks = [{"type": "text"}, {"type": "text", "text": "after"}]
        extracted = extract_text_content(blocks)
        # both parts collected; joined → "\nafter" or "after" depending on strip
        assert "after" in extracted

    # --- Empty list: empty string --------------------------------------

    def test_empty_list_returns_empty_string(self):
        from app.utils.content_utils import extract_text_content

        assert extract_text_content([]) == ""

    # --- Unsupported types: empty string (the final bare `return ""`) --

    def test_none_returns_empty_string(self):
        from app.utils.content_utils import extract_text_content

        assert extract_text_content(None) == ""

    def test_integer_returns_empty_string(self):
        from app.utils.content_utils import extract_text_content

        assert extract_text_content(42) == ""

    def test_boolean_returns_empty_string(self):
        from app.utils.content_utils import extract_text_content

        assert extract_text_content(True) == ""
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
"""Unit tests for char-span -> line-range conversion."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.utils.text_spans import char_span_to_line_range
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_TEXT = "line1\nline2\nline3"
|
||||
|
||||
|
||||
def test_single_line_span() -> None:
    """A span covering exactly "line2" maps to line 2 only."""
    needle = "line2"
    offset = _TEXT.index(needle)
    assert char_span_to_line_range(_TEXT, offset, offset + len(needle)) == (2, 2)
|
||||
|
||||
|
||||
def test_first_line_span() -> None:
    """A span over the first line's text maps to line 1."""
    end = len("line1")
    assert char_span_to_line_range(_TEXT, 0, end) == (1, 1)
|
||||
|
||||
|
||||
def test_last_line_span() -> None:
    """A span running from "line3" to the end of the text maps to line 3."""
    offset = _TEXT.index("line3")
    assert char_span_to_line_range(_TEXT, offset, len(_TEXT)) == (3, 3)
|
||||
|
||||
|
||||
def test_multi_line_span() -> None:
    """A span covering "line1\\nline2" maps to lines 1-2."""
    end = _TEXT.index("line2") + len("line2")
    assert char_span_to_line_range(_TEXT, 0, end) == (1, 2)
|
||||
|
||||
|
||||
def test_empty_span_resolves_to_its_line() -> None:
    """A zero-length span at the start of "line2" still resolves to line 2."""
    offset = _TEXT.index("line2")
    assert char_span_to_line_range(_TEXT, offset, offset) == (2, 2)
|
||||
|
||||
|
||||
def test_offsets_clamped_to_text_bounds() -> None:
    """Out-of-range offsets are clamped, so the span covers the whole text."""
    below_start, past_end = -5, 10_000
    assert char_span_to_line_range(_TEXT, below_start, past_end) == (1, 3)
|
||||
Loading…
Add table
Add a link
Reference in a new issue