mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
roadmap(1.3): Update citation prompt to use new whole document structure
- Modified the document extraction and citation formatting to accommodate a new structure that includes a `chunks` list for each document.
- Enhanced the citation format to reference `chunk_id` instead of `source_id`, ensuring accurate citations in the UI.
- Updated various components, including the connector service and reranker service, to handle the new document format and maintain compatibility with existing functionality.
- Improved documentation and comments to reflect changes in the data structure and citation requirements.
This commit is contained in:
parent
ed6fc10133
commit
fea1837186
9 changed files with 1054 additions and 1122 deletions
|
|
@ -30,51 +30,59 @@ def extract_sources_from_documents(
|
|||
all_documents: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract sources from all_documents and group them by document type.
|
||||
Extract sources from **document-grouped** results and group them by document type.
|
||||
|
||||
Args:
|
||||
all_documents: List of document chunks from user-selected documents and connector-fetched documents
|
||||
all_documents: List of document-grouped results from user-selected documents and connector-fetched documents
|
||||
|
||||
Returns:
|
||||
List of source objects grouped by type for streaming
|
||||
"""
|
||||
# Group documents by their source type
|
||||
# Group sources by their source type
|
||||
documents_by_type = {}
|
||||
|
||||
for doc in all_documents:
|
||||
# Get source type from the document
|
||||
document_info = doc.get("document", {}) or {}
|
||||
source_type = doc.get("source", "UNKNOWN")
|
||||
document_info = doc.get("document", {})
|
||||
document_type = document_info.get("document_type", source_type)
|
||||
|
||||
# Use document_type if available, otherwise use source
|
||||
document_type = document_info.get("document_type", source_type) or source_type
|
||||
group_type = document_type if document_type != "UNKNOWN" else source_type
|
||||
|
||||
if group_type not in documents_by_type:
|
||||
documents_by_type[group_type] = []
|
||||
documents_by_type[group_type].append(doc)
|
||||
|
||||
# Create source objects for each document type
|
||||
source_objects = []
|
||||
source_id_counter = 1
|
||||
|
||||
for doc_type, docs in documents_by_type.items():
|
||||
sources_list = []
|
||||
|
||||
for doc in docs:
|
||||
document_info = doc.get("document", {})
|
||||
metadata = document_info.get("metadata", {})
|
||||
url = (
|
||||
metadata.get("url")
|
||||
or metadata.get("source")
|
||||
or metadata.get("page_url")
|
||||
or metadata.get("VisitedWebPageURL")
|
||||
or ""
|
||||
)
|
||||
|
||||
# Create source entry based on document structure
|
||||
source = {
|
||||
"id": doc.get("chunk_id", source_id_counter),
|
||||
"title": document_info.get("title", "Untitled Document"),
|
||||
"description": doc.get("content", "").strip(),
|
||||
"url": metadata.get("url", metadata.get("page_url", "")),
|
||||
}
|
||||
|
||||
source_id_counter += 1
|
||||
sources_list.append(source)
|
||||
# Each chunk becomes a source entry so citations like [citation:<chunk_id>] resolve in UI.
|
||||
for chunk in doc.get("chunks", []) or []:
|
||||
chunk_id = chunk.get("chunk_id")
|
||||
chunk_content = (chunk.get("content") or "").strip()
|
||||
description = (
|
||||
chunk_content
|
||||
if len(chunk_content) <= 240
|
||||
else chunk_content[:240] + "..."
|
||||
)
|
||||
sources_list.append(
|
||||
{
|
||||
"id": chunk_id,
|
||||
"title": document_info.get("title", "Untitled Document"),
|
||||
"description": description,
|
||||
"url": url,
|
||||
}
|
||||
)
|
||||
|
||||
# Create group object
|
||||
group_name = (
|
||||
|
|
@ -127,50 +135,40 @@ async def fetch_documents_by_ids(
|
|||
documents = result.scalars().all()
|
||||
|
||||
# Group documents by type for source object creation
|
||||
documents_by_type = {}
|
||||
formatted_documents = []
|
||||
documents_by_type: dict[str, list[Document]] = {}
|
||||
formatted_documents: list[dict[str, Any]] = []
|
||||
|
||||
from app.db import Chunk
|
||||
|
||||
for doc in documents:
|
||||
# Fetch associated chunks for this document (similar to DocumentHybridSearchRetriever)
|
||||
from app.db import Chunk
|
||||
|
||||
# Fetch associated chunks for this document
|
||||
chunks_query = (
|
||||
select(Chunk).where(Chunk.document_id == doc.id).order_by(Chunk.id)
|
||||
)
|
||||
chunks_result = await db_session.execute(chunks_query)
|
||||
chunks = chunks_result.scalars().all()
|
||||
|
||||
# Return individual chunks instead of concatenated content
|
||||
if chunks:
|
||||
for chunk in chunks:
|
||||
# Format each chunk to match connector service return format
|
||||
formatted_chunk = {
|
||||
"chunk_id": chunk.id,
|
||||
"content": chunk.content, # Use individual chunk content
|
||||
"score": 0.5, # High score since user explicitly selected these
|
||||
"document": {
|
||||
"id": chunk.id,
|
||||
"title": doc.title,
|
||||
"document_type": (
|
||||
doc.document_type.value
|
||||
if doc.document_type
|
||||
else "UNKNOWN"
|
||||
),
|
||||
"metadata": doc.document_metadata or {},
|
||||
},
|
||||
"source": doc.document_type.value
|
||||
if doc.document_type
|
||||
else "UNKNOWN",
|
||||
}
|
||||
formatted_documents.append(formatted_chunk)
|
||||
doc_type = doc.document_type.value if doc.document_type else "UNKNOWN"
|
||||
documents_by_type.setdefault(doc_type, []).append(doc)
|
||||
|
||||
# Group by document type for source objects
|
||||
doc_type = (
|
||||
doc.document_type.value if doc.document_type else "UNKNOWN"
|
||||
)
|
||||
if doc_type not in documents_by_type:
|
||||
documents_by_type[doc_type] = []
|
||||
documents_by_type[doc_type].append(doc)
|
||||
doc_group = {
|
||||
"document_id": doc.id,
|
||||
"content": "\n\n".join(c.content for c in chunks)
|
||||
if chunks
|
||||
else (doc.content or ""),
|
||||
"score": 0.5, # High score since user explicitly selected these
|
||||
"chunks": [{"chunk_id": c.id, "content": c.content} for c in chunks]
|
||||
if chunks
|
||||
else [],
|
||||
"document": {
|
||||
"id": doc.id,
|
||||
"title": doc.title,
|
||||
"document_type": doc_type,
|
||||
"metadata": doc.document_metadata or {},
|
||||
},
|
||||
"source": doc_type,
|
||||
}
|
||||
formatted_documents.append(doc_group)
|
||||
|
||||
# Create source objects for each document type (similar to ConnectorService)
|
||||
source_objects = []
|
||||
|
|
@ -1265,25 +1263,22 @@ async def fetch_relevant_documents(
|
|||
}
|
||||
)
|
||||
|
||||
# Deduplicate raw documents based on chunk_id or content
|
||||
seen_chunk_ids = set()
|
||||
# Deduplicate raw documents based on document_id (preferred) or content hash
|
||||
seen_doc_ids = set()
|
||||
seen_content_hashes = set()
|
||||
deduplicated_docs = []
|
||||
deduplicated_docs: list[dict[str, Any]] = []
|
||||
|
||||
for doc in all_raw_documents:
|
||||
chunk_id = doc.get("chunk_id")
|
||||
content = doc.get("content", "")
|
||||
doc_id = (doc.get("document", {}) or {}).get("id")
|
||||
content = doc.get("content", "") or ""
|
||||
content_hash = hash(content)
|
||||
|
||||
# Skip if we've seen this chunk_id or content before
|
||||
if (
|
||||
chunk_id and chunk_id in seen_chunk_ids
|
||||
) or content_hash in seen_content_hashes:
|
||||
# Skip if we've seen this document_id or content before
|
||||
if (doc_id and doc_id in seen_doc_ids) or content_hash in seen_content_hashes:
|
||||
continue
|
||||
|
||||
# Add to our tracking sets and keep this document
|
||||
if chunk_id:
|
||||
seen_chunk_ids.add(chunk_id)
|
||||
if doc_id:
|
||||
seen_doc_ids.add(doc_id)
|
||||
seen_content_hashes.add(content_hash)
|
||||
deduplicated_docs.append(doc)
|
||||
|
||||
|
|
@ -1292,7 +1287,7 @@ async def fetch_relevant_documents(
|
|||
writer(
|
||||
{
|
||||
"yield_value": streaming_service.format_terminal_info_delta(
|
||||
f"🧹 Found {len(deduplicated_docs)} unique document chunks after removing duplicates"
|
||||
f"🧹 Found {len(deduplicated_docs)} unique documents after removing duplicates"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ You are an expert research assistant specializing in generating contextually rel
|
|||
|
||||
<input>
|
||||
- chat_history: Provided in XML format within <chat_history> tags, containing <user> and <assistant> message pairs that show the chronological conversation flow. This provides context about what has already been discussed.
|
||||
- available_documents: Provided in XML format within <documents> tags, containing individual <document> elements with <metadata> (source_id, source_type) and <content> sections. This helps understand what information is accessible for answering potential follow-up questions.
|
||||
- available_documents: Provided in XML format within <documents> tags, containing individual <document> elements with <document_metadata> and <document_content> sections. Each document contains multiple `<chunk id='...'>...</chunk>` blocks inside <document_content>. This helps understand what information is accessible for answering potential follow-up questions.
|
||||
</input>
|
||||
|
||||
<output_format>
|
||||
|
|
|
|||
|
|
@ -78,32 +78,53 @@ DEFAULT_QNA_CITATION_INSTRUCTIONS = """
|
|||
<citation_instructions>
|
||||
CRITICAL CITATION REQUIREMENTS:
|
||||
|
||||
1. For EVERY piece of information you include from the documents, add a citation in the format [citation:knowledge_source_id] where knowledge_source_id is the source_id from the document's metadata.
|
||||
1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `<chunk id='...'>` tag inside `<document_content>`.
|
||||
2. Make sure ALL factual statements from the documents have proper citations.
|
||||
3. If multiple documents support the same point, include all relevant citations [citation:source_id1], [citation:source_id2].
|
||||
4. You MUST use the exact source_id value from each document's metadata for citations. Do not create your own citation numbers.
|
||||
5. Every citation MUST be in the format [citation:knowledge_source_id] where knowledge_source_id is the exact source_id value.
|
||||
6. Never modify or change the source_id - always use the original values exactly as provided in the metadata.
|
||||
3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2].
|
||||
4. You MUST use the exact chunk_id values from the `<chunk id='...'>` attributes. Do not create your own citation numbers.
|
||||
5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value.
|
||||
6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags.
|
||||
7. Do not return citations as clickable links.
|
||||
8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only.
|
||||
9. Citations must ONLY appear as [citation:source_id] or [citation:source_id1], [citation:source_id2] format - never with parentheses, hyperlinks, or other formatting.
|
||||
10. Never make up source IDs. Only use source_id values that are explicitly provided in the document metadata.
|
||||
11. If you are unsure about a source_id, do not include a citation rather than guessing or making one up.
|
||||
9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting.
|
||||
10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `<chunk id='...'>` tags.
|
||||
11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up.
|
||||
|
||||
<document_structure_example>
|
||||
The documents you receive are structured like this:
|
||||
|
||||
<document>
|
||||
<document_metadata>
|
||||
<document_id>42</document_id>
|
||||
<document_type>GITHUB_CONNECTOR</document_type>
|
||||
<title><![CDATA[Some repo / file / issue title]]></title>
|
||||
<url><![CDATA[https://example.com]]></url>
|
||||
<metadata_json><![CDATA[{{"any":"other metadata"}}]]></metadata_json>
|
||||
</document_metadata>
|
||||
|
||||
<document_content>
|
||||
<chunk id='123'><![CDATA[First chunk text...]]></chunk>
|
||||
<chunk id='124'><![CDATA[Second chunk text...]]></chunk>
|
||||
</document_content>
|
||||
</document>
|
||||
|
||||
IMPORTANT: You MUST cite using the chunk ids (e.g. 123, 124). Do NOT cite document_id.
|
||||
</document_structure_example>
|
||||
|
||||
<citation_format>
|
||||
- Every fact from the documents must have a citation in the format [citation:knowledge_source_id] where knowledge_source_id is the EXACT source_id from the document's metadata
|
||||
- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `<chunk id='...'>` tag
|
||||
- Citations should appear at the end of the sentence containing the information they support
|
||||
- Multiple citations should be separated by commas: [citation:source_id1], [citation:source_id2], [citation:source_id3]
|
||||
- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
|
||||
- No need to return references section. Just citations in answer.
|
||||
- NEVER create your own citation format - use the exact source_id values from the documents in the [citation:source_id] format
|
||||
- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format
|
||||
- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only
|
||||
- NEVER make up source IDs if you are unsure about the source_id. It is better to omit the citation than to guess
|
||||
- NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess
|
||||
</citation_format>
|
||||
|
||||
<citation_examples>
|
||||
CORRECT citation formats:
|
||||
- [citation:5]
|
||||
- [citation:source_id1], [citation:source_id2], [citation:source_id3]
|
||||
- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
|
||||
|
||||
INCORRECT citation formats (DO NOT use):
|
||||
- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense))
|
||||
|
|
|
|||
|
|
@ -71,6 +71,10 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An
|
|||
reranks them using the reranker service based on the user's query,
|
||||
and updates the state with the reranked documents.
|
||||
|
||||
Documents are now document-grouped with a `chunks` list. Reranking is done
|
||||
using the concatenated `content` field, and the full structure (including
|
||||
`chunks`) is preserved for proper citation formatting.
|
||||
|
||||
If reranking is disabled, returns the original documents without processing.
|
||||
|
||||
Returns:
|
||||
|
|
@ -99,25 +103,12 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An
|
|||
|
||||
# Perform reranking
|
||||
try:
|
||||
# Convert documents to format expected by reranker if needed
|
||||
reranker_input_docs = [
|
||||
{
|
||||
"chunk_id": doc.get("chunk_id", f"chunk_{i}"),
|
||||
"content": doc.get("content", ""),
|
||||
"score": doc.get("score", 0.0),
|
||||
"document": {
|
||||
"id": doc.get("document", {}).get("id", ""),
|
||||
"title": doc.get("document", {}).get("title", ""),
|
||||
"document_type": doc.get("document", {}).get("document_type", ""),
|
||||
"metadata": doc.get("document", {}).get("metadata", {}),
|
||||
},
|
||||
}
|
||||
for i, doc in enumerate(documents)
|
||||
]
|
||||
|
||||
# Rerank documents using the user's query
|
||||
# Pass documents directly to reranker - it will use:
|
||||
# - "content" (concatenated chunk text) for scoring
|
||||
# - "document_id" for matching results back to the input (falls back to the original index)
|
||||
# The full document structure including "chunks" is preserved
|
||||
reranked_docs = reranker_service.rerank_documents(
|
||||
user_query + "\n" + reformulated_query, reranker_input_docs
|
||||
user_query + "\n" + reformulated_query, documents
|
||||
)
|
||||
|
||||
# Sort by score in descending order
|
||||
|
|
@ -141,8 +132,8 @@ async def answer_question(
|
|||
|
||||
This node takes the relevant documents provided in the configuration and uses
|
||||
an LLM to generate a comprehensive answer to the user's question with
|
||||
proper citations. The citations follow [citation:source_id] format using source IDs from the
|
||||
documents. If no documents are provided, it will use chat history to generate
|
||||
proper citations. The citations follow [citation:chunk_id] format using chunk IDs from the
|
||||
`<chunk id='...'>` tags in the provided documents. If no documents are provided, it will use chat history to generate
|
||||
an answer.
|
||||
|
||||
The response is streamed token-by-token for real-time updates to the frontend.
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
|
|
@ -78,21 +79,59 @@ def convert_langchain_messages_to_dict(
|
|||
|
||||
|
||||
def format_document_for_citation(document: dict[str, Any]) -> str:
|
||||
"""Format a single document for citation in the standard XML format."""
|
||||
content = document.get("content", "")
|
||||
doc_info = document.get("document", {})
|
||||
document_id = document.get("chunk_id", "")
|
||||
"""Format a single document for citation in the new document+chunks XML format.
|
||||
|
||||
IMPORTANT:
|
||||
- Citations must reference real DB chunk IDs: `[citation:<chunk_id>]`
|
||||
- Document metadata is included under <document_metadata>, but citations are NOT document_id-based.
|
||||
"""
|
||||
|
||||
def _to_cdata(value: Any) -> str:
|
||||
text = "" if value is None else str(value)
|
||||
# Safely nest CDATA even if the content includes "]]>"
|
||||
return "<![CDATA[" + text.replace("]]>", "]]]]><![CDATA[>") + "]]>"
|
||||
|
||||
doc_info = document.get("document", {}) or {}
|
||||
metadata = doc_info.get("metadata", {}) or {}
|
||||
|
||||
doc_id = doc_info.get("id", "")
|
||||
title = doc_info.get("title", "")
|
||||
document_type = doc_info.get("document_type", "CRAWLED_URL")
|
||||
url = (
|
||||
metadata.get("url")
|
||||
or metadata.get("source")
|
||||
or metadata.get("page_url")
|
||||
or metadata.get("VisitedWebPageURL")
|
||||
or ""
|
||||
)
|
||||
|
||||
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
||||
|
||||
chunks = document.get("chunks") or []
|
||||
if not chunks:
|
||||
# Fallback: treat `content` as a single chunk (no chunk_id available for citation)
|
||||
chunks = [{"chunk_id": "", "content": document.get("content", "")}]
|
||||
|
||||
chunks_xml = "\n".join(
|
||||
[
|
||||
f"<chunk id='{chunk.get('chunk_id', '')}'>{_to_cdata(chunk.get('content', ''))}</chunk>"
|
||||
for chunk in chunks
|
||||
]
|
||||
)
|
||||
|
||||
return f"""<document>
|
||||
<metadata>
|
||||
<source_id>{document_id}</source_id>
|
||||
<source_type>{document_type}</source_type>
|
||||
</metadata>
|
||||
<content>
|
||||
{content}
|
||||
</content>
|
||||
</document>"""
|
||||
<document_metadata>
|
||||
<document_id>{doc_id}</document_id>
|
||||
<document_type>{document_type}</document_type>
|
||||
<title>{_to_cdata(title)}</title>
|
||||
<url>{_to_cdata(url)}</url>
|
||||
<metadata_json>{_to_cdata(metadata_json)}</metadata_json>
|
||||
</document_metadata>
|
||||
|
||||
<document_content>
|
||||
{chunks_xml}
|
||||
</document_content>
|
||||
</document>"""
|
||||
|
||||
|
||||
def format_documents_section(
|
||||
|
|
|
|||
|
|
@ -131,18 +131,25 @@ class ChucksHybridSearchRetriever:
|
|||
end_date: datetime | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
|
||||
Hybrid search that returns **documents** (not individual chunks).
|
||||
|
||||
Each returned item is a document-grouped dict that preserves real DB chunk IDs so
|
||||
downstream agents can cite with `[citation:<chunk_id>]`.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
top_k: Number of documents to return
|
||||
search_space_id: The search space ID to search within
|
||||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
start_date: Optional start date for filtering documents by updated_at
|
||||
end_date: Optional end date for filtering documents by updated_at
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing chunk data and relevance scores
|
||||
List of dictionaries containing document data and relevance scores. Each dict contains:
|
||||
- document_id: id of the matched document (no top-level chunk_id is returned; chunk ids are listed under `chunks`)
|
||||
- content: concatenated chunk content (useful for reranking)
|
||||
- chunks: list[{chunk_id, content}] for citation-aware prompting
|
||||
- document: {id, title, document_type, metadata}
|
||||
"""
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
|
@ -154,9 +161,9 @@ class ChucksHybridSearchRetriever:
|
|||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
# Constants for RRF calculation
|
||||
k = 60 # Constant for RRF calculation
|
||||
n_results = top_k * 2 # Get more results for better fusion
|
||||
# RRF constants
|
||||
k = 60
|
||||
n_results = top_k * 5 # Fetch extra chunks for better document-level fusion
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector("english", Chunk.content)
|
||||
|
|
@ -255,10 +262,10 @@ class ChucksHybridSearchRetriever:
|
|||
if not chunks_with_scores:
|
||||
return []
|
||||
|
||||
# Convert to serializable dictionaries if no reranker is available or if reranking failed
|
||||
serialized_results = []
|
||||
# Convert to serializable dictionaries
|
||||
serialized_chunk_results: list[dict] = []
|
||||
for chunk, score in chunks_with_scores:
|
||||
serialized_results.append(
|
||||
serialized_chunk_results.append(
|
||||
{
|
||||
"chunk_id": chunk.id,
|
||||
"content": chunk.content,
|
||||
|
|
@ -274,4 +281,77 @@ class ChucksHybridSearchRetriever:
|
|||
}
|
||||
)
|
||||
|
||||
return serialized_results
|
||||
# Group by document, preserving ranking order by best chunk rank
|
||||
doc_scores: dict[int, float] = {}
|
||||
doc_order: list[int] = []
|
||||
for item in serialized_chunk_results:
|
||||
doc_id = item.get("document", {}).get("id")
|
||||
if doc_id is None:
|
||||
continue
|
||||
if doc_id not in doc_scores:
|
||||
doc_scores[doc_id] = item.get("score", 0.0)
|
||||
doc_order.append(doc_id)
|
||||
else:
|
||||
# Use the best score as doc score
|
||||
doc_scores[doc_id] = max(doc_scores[doc_id], item.get("score", 0.0))
|
||||
|
||||
# Keep only top_k documents by initial rank order.
|
||||
doc_ids = doc_order[:top_k]
|
||||
if not doc_ids:
|
||||
return []
|
||||
|
||||
# Fetch ALL chunks for selected documents in a single query so the final prompt can cite
|
||||
# any chunk from those documents.
|
||||
chunk_query = (
|
||||
select(Chunk)
|
||||
.options(joinedload(Chunk.document))
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.where(Document.id.in_(doc_ids))
|
||||
.where(*base_conditions)
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunks_result = await self.db_session.execute(chunk_query)
|
||||
all_chunks = chunks_result.scalars().all()
|
||||
|
||||
# Assemble final doc-grouped results in the same order as doc_ids
|
||||
doc_map: dict[int, dict] = {
|
||||
doc_id: {
|
||||
"document_id": doc_id,
|
||||
"content": "",
|
||||
"score": float(doc_scores.get(doc_id, 0.0)),
|
||||
"chunks": [],
|
||||
"document": {},
|
||||
"source": None,
|
||||
}
|
||||
for doc_id in doc_ids
|
||||
}
|
||||
|
||||
for chunk in all_chunks:
|
||||
doc = chunk.document
|
||||
doc_id = doc.id
|
||||
if doc_id not in doc_map:
|
||||
continue
|
||||
doc_entry = doc_map[doc_id]
|
||||
doc_entry["document"] = {
|
||||
"id": doc.id,
|
||||
"title": doc.title,
|
||||
"document_type": doc.document_type.value
|
||||
if getattr(doc, "document_type", None)
|
||||
else None,
|
||||
"metadata": doc.document_metadata or {},
|
||||
}
|
||||
doc_entry["source"] = (
|
||||
doc.document_type.value if getattr(doc, "document_type", None) else None
|
||||
)
|
||||
doc_entry["chunks"].append({"chunk_id": chunk.id, "content": chunk.content})
|
||||
|
||||
# Fill concatenated content (useful for reranking)
|
||||
final_docs: list[dict] = []
|
||||
for doc_id in doc_ids:
|
||||
entry = doc_map[doc_id]
|
||||
entry["content"] = "\n\n".join(
|
||||
c["content"] for c in entry.get("chunks", []) if c.get("content")
|
||||
)
|
||||
final_docs.append(entry)
|
||||
|
||||
return final_docs
|
||||
|
|
|
|||
|
|
@ -131,11 +131,14 @@ class DocumentHybridSearchRetriever:
|
|||
end_date: datetime | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
|
||||
Hybrid search that returns **documents** (not individual chunks).
|
||||
|
||||
Each returned item is a document-grouped dict that preserves real DB chunk IDs so
|
||||
downstream agents can cite with `[citation:<chunk_id>]`.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
top_k: Number of documents to return
|
||||
search_space_id: The search space ID to search within
|
||||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
start_date: Optional start date for filtering documents by updated_at
|
||||
|
|
@ -146,15 +149,15 @@ class DocumentHybridSearchRetriever:
|
|||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.config import config
|
||||
from app.db import Document, DocumentType
|
||||
from app.db import Chunk, Document, DocumentType
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
# Constants for RRF calculation
|
||||
k = 60 # Constant for RRF calculation
|
||||
n_results = top_k * 2 # Get more results for better fusion
|
||||
# RRF constants
|
||||
k = 60
|
||||
n_results = top_k * 2 # Fetch extra documents for better fusion
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector("english", Document.content)
|
||||
|
|
@ -248,50 +251,56 @@ class DocumentHybridSearchRetriever:
|
|||
if not documents_with_scores:
|
||||
return []
|
||||
|
||||
# Convert to serializable dictionaries - return individual chunks
|
||||
serialized_results = []
|
||||
for document, score in documents_with_scores:
|
||||
# Fetch associated chunks for this document
|
||||
from sqlalchemy import select
|
||||
# Collect document IDs for chunk fetching
|
||||
doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]
|
||||
|
||||
from app.db import Chunk
|
||||
# Fetch ALL chunks for these documents in a single query
|
||||
chunks_query = (
|
||||
select(Chunk)
|
||||
.options(joinedload(Chunk.document))
|
||||
.where(Chunk.document_id.in_(doc_ids))
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunks_result = await self.db_session.execute(chunks_query)
|
||||
chunks = chunks_result.scalars().all()
|
||||
|
||||
chunks_query = (
|
||||
select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
|
||||
# Assemble doc-grouped results
|
||||
doc_map: dict[int, dict] = {
|
||||
doc.id: {
|
||||
"document_id": doc.id,
|
||||
"content": "",
|
||||
"score": float(score),
|
||||
"chunks": [],
|
||||
"document": {
|
||||
"id": doc.id,
|
||||
"title": doc.title,
|
||||
"document_type": doc.document_type.value
|
||||
if getattr(doc, "document_type", None)
|
||||
else None,
|
||||
"metadata": doc.document_metadata or {},
|
||||
},
|
||||
"source": doc.document_type.value
|
||||
if getattr(doc, "document_type", None)
|
||||
else None,
|
||||
}
|
||||
for doc, score in documents_with_scores
|
||||
}
|
||||
|
||||
for chunk in chunks:
|
||||
doc_id = chunk.document_id
|
||||
if doc_id not in doc_map:
|
||||
continue
|
||||
doc_map[doc_id]["chunks"].append(
|
||||
{"chunk_id": chunk.id, "content": chunk.content}
|
||||
)
|
||||
chunks_result = await self.db_session.execute(chunks_query)
|
||||
chunks = chunks_result.scalars().all()
|
||||
|
||||
# Return individual chunks instead of concatenated content
|
||||
if chunks:
|
||||
for chunk in chunks:
|
||||
serialized_results.append(
|
||||
{
|
||||
"document_id": chunk.id,
|
||||
"title": document.title,
|
||||
"content": chunk.content, # Use chunk content instead of document content
|
||||
"document_type": document.document_type.value
|
||||
if hasattr(document, "document_type")
|
||||
else None,
|
||||
"metadata": document.document_metadata,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
"search_space_id": document.search_space_id,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# If no chunks exist, return the document content as a single result
|
||||
serialized_results.append(
|
||||
{
|
||||
"document_id": document.id,
|
||||
"title": document.title,
|
||||
"content": document.content,
|
||||
"document_type": document.document_type.value
|
||||
if hasattr(document, "document_type")
|
||||
else None,
|
||||
"metadata": document.document_metadata,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
"search_space_id": document.search_space_id,
|
||||
}
|
||||
)
|
||||
# Fill concatenated content (useful for reranking)
|
||||
final_docs: list[dict] = []
|
||||
for doc_id in doc_ids:
|
||||
entry = doc_map[doc_id]
|
||||
entry["content"] = "\n\n".join(
|
||||
c["content"] for c in entry.get("chunks", []) if c.get("content")
|
||||
)
|
||||
final_docs.append(entry)
|
||||
|
||||
return serialized_results
|
||||
return final_docs
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -22,14 +22,18 @@ class RerankerService:
|
|||
self, query_text: str, documents: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Rerank documents using the configured reranker
|
||||
Rerank documents using the configured reranker.
|
||||
|
||||
Documents can be either:
|
||||
- Document-grouped (new format): Has `document_id`, `chunks` list, and `content` (concatenated)
|
||||
- Chunk-based (legacy format): Individual chunks with `chunk_id` and `content`
|
||||
|
||||
Args:
|
||||
query_text: The query text to use for reranking
|
||||
documents: List of document dictionaries to rerank
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: Reranked documents
|
||||
List[Dict[str, Any]]: Reranked documents with preserved structure
|
||||
"""
|
||||
if not self.reranker_instance or not documents:
|
||||
return documents
|
||||
|
|
@ -38,7 +42,9 @@ class RerankerService:
|
|||
# Create Document objects for the rerankers library
|
||||
reranker_docs = []
|
||||
for i, doc in enumerate(documents):
|
||||
chunk_id = doc.get("chunk_id", f"chunk_{i}")
|
||||
# Use document_id for matching
|
||||
doc_id = doc.get("document_id") or f"doc_{i}"
|
||||
# Use concatenated content for reranking
|
||||
content = doc.get("content", "")
|
||||
score = doc.get("score", 0.0)
|
||||
document_info = doc.get("document", {})
|
||||
|
|
@ -46,12 +52,14 @@ class RerankerService:
|
|||
reranker_docs.append(
|
||||
RerankerDocument(
|
||||
text=content,
|
||||
doc_id=chunk_id,
|
||||
doc_id=doc_id,
|
||||
metadata={
|
||||
"document_id": document_info.get("id", ""),
|
||||
"document_title": document_info.get("title", ""),
|
||||
"document_type": document_info.get("document_type", ""),
|
||||
"rrf_score": score,
|
||||
# Track original index for fallback matching
|
||||
"original_index": i,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
@ -62,21 +70,33 @@ class RerankerService:
|
|||
)
|
||||
|
||||
# Process the results from the reranker
|
||||
# Convert to serializable dictionaries
|
||||
# Convert to serializable dictionaries while preserving full structure
|
||||
serialized_results = []
|
||||
for result in reranking_results.results:
|
||||
# Find the original document by id
|
||||
original_doc = next(
|
||||
(
|
||||
doc
|
||||
for doc in documents
|
||||
if doc.get("chunk_id") == result.document.doc_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
result_doc_id = result.document.doc_id
|
||||
original_index = result.document.metadata.get("original_index")
|
||||
|
||||
# Find the original document by document_id
|
||||
original_doc = None
|
||||
for doc in documents:
|
||||
if doc.get("document_id") == result_doc_id:
|
||||
original_doc = doc
|
||||
break
|
||||
|
||||
# Fallback to original index if ID matching fails
|
||||
if (
|
||||
original_doc is None
|
||||
and original_index is not None
|
||||
and 0 <= original_index < len(documents)
|
||||
):
|
||||
original_doc = documents[original_index]
|
||||
|
||||
if original_doc:
|
||||
# Create a new document with the reranked score
|
||||
# Create a shallow copy so score/rank can be set without mutating the original dict (nested values such as `chunks` are shared, not duplicated)
|
||||
reranked_doc = original_doc.copy()
|
||||
# Preserve chunks list if present (important for citation formatting)
|
||||
if "chunks" in original_doc:
|
||||
reranked_doc["chunks"] = original_doc["chunks"]
|
||||
reranked_doc["score"] = float(result.score)
|
||||
reranked_doc["rank"] = result.rank
|
||||
serialized_results.append(reranked_doc)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue