mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Document chunks not stored in vector store (#665)
- Schema - ChunkEmbeddings now uses chunk_id: str instead of chunk: bytes
- Schema - DocumentEmbeddingsResponse now returns chunk_ids: list[str]
instead of chunks
- Translators - Updated to serialize/deserialize chunk_id
- Clients - DocumentEmbeddingsClient.query() returns chunk_ids
- SDK/API - flow.py, socket_client.py, bulk_client.py updated
- Document embeddings service - Stores chunk_id (document ID) instead
of chunk text
- Storage writers - Qdrant, Milvus, Pinecone store chunk_id in payload
- Query services - Return chunk_id from vector store searches
- Gateway dispatchers - Serialize chunk_id in API responses
- Document RAG - Added librarian client to fetch chunk content from
Garage using chunk_ids
- CLI tools - Updated all three tools:
- invoke_document_embeddings.py - displays chunk_ids, removed
max_chunk_length
- save_doc_embeds.py - exports chunk_id
- load_doc_embeds.py - imports chunk_id
This commit is contained in:
parent
be358efe67
commit
24bbe94136
24 changed files with 331 additions and 91 deletions
136
docs/tech-specs/document-embeddings-chunk-id.md
Normal file
136
docs/tech-specs/document-embeddings-chunk-id.md
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
# Document Embeddings Chunk ID
|
||||
|
||||
## Overview
|
||||
|
||||
Document embeddings storage currently stores chunk text directly in the vector store payload, duplicating data that exists in Garage. This spec replaces chunk text storage with `chunk_id` references.
|
||||
|
||||
## Current State
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class ChunkEmbeddings:
|
||||
chunk: bytes = b""
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class DocumentEmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
chunks: list[str] = field(default_factory=list)
|
||||
```
|
||||
|
||||
Vector store payload:
|
||||
```python
|
||||
payload={"doc": chunk} # Duplicates Garage content
|
||||
```
|
||||
|
||||
## Design
|
||||
|
||||
### Schema Changes
|
||||
|
||||
**ChunkEmbeddings** - replace chunk with chunk_id:
|
||||
```python
|
||||
@dataclass
|
||||
class ChunkEmbeddings:
|
||||
chunk_id: str = ""
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
```
|
||||
|
||||
**DocumentEmbeddingsResponse** - return chunk_ids instead of chunks:
|
||||
```python
|
||||
@dataclass
|
||||
class DocumentEmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
chunk_ids: list[str] = field(default_factory=list)
|
||||
```
|
||||
|
||||
### Vector Store Payload
|
||||
|
||||
All stores (Qdrant, Milvus, Pinecone):
|
||||
```python
|
||||
payload={"chunk_id": chunk_id}
|
||||
```
|
||||
|
||||
### Document RAG Changes
|
||||
|
||||
The document RAG processor fetches chunk content from Garage:
|
||||
|
||||
```python
|
||||
# Get chunk_ids from embeddings store
|
||||
chunk_ids = await self.rag.doc_embeddings_client.query(...)
|
||||
|
||||
# Fetch chunk content from Garage
|
||||
docs = []
|
||||
for chunk_id in chunk_ids:
|
||||
content = await self.rag.fetch_chunk(
|
||||
chunk_id, self.user
|
||||
)
|
||||
docs.append(content)
|
||||
```
|
||||
|
||||
### API/SDK Changes
|
||||
|
||||
**DocumentEmbeddingsClient** returns chunk_ids:
|
||||
```python
|
||||
return resp.chunk_ids # Changed from resp.chunks
|
||||
```
|
||||
|
||||
**Wire format** (DocumentEmbeddingsResponseTranslator):
|
||||
```python
|
||||
result["chunk_ids"] = obj.chunk_ids # Changed from chunks
|
||||
```
|
||||
|
||||
### CLI Changes
|
||||
|
||||
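The `invoke_document_embeddings` CLI tool displays chunk_ids instead of chunk text (callers can fetch content separately if needed), and its `--max-chunk-length` option is removed. The `save_doc_embeds` and `load_doc_embeds` tools export and import `chunk_id` in place of chunk text.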
|
||||
|
||||
## Files to Modify
|
||||
|
||||
### Schema
|
||||
- `trustgraph-base/trustgraph/schema/knowledge/embeddings.py` - ChunkEmbeddings
|
||||
- `trustgraph-base/trustgraph/schema/services/query.py` - DocumentEmbeddingsResponse
|
||||
|
||||
### Messaging/Translators
|
||||
- `trustgraph-base/trustgraph/messaging/translators/embeddings_query.py` - DocumentEmbeddingsResponseTranslator
|
||||
|
||||
### Client
|
||||
- `trustgraph-base/trustgraph/base/document_embeddings_client.py` - return chunk_ids
|
||||
|
||||
### Python SDK/API
|
||||
- `trustgraph-base/trustgraph/api/flow.py` - document_embeddings_query
|
||||
- `trustgraph-base/trustgraph/api/socket_client.py` - document_embeddings_query
|
||||
- `trustgraph-base/trustgraph/api/async_flow.py` - if applicable
|
||||
- `trustgraph-base/trustgraph/api/bulk_client.py` - import/export document embeddings
|
||||
- `trustgraph-base/trustgraph/api/async_bulk_client.py` - import/export document embeddings
|
||||
|
||||
### Embeddings Service
|
||||
- `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py` - pass chunk_id
|
||||
|
||||
### Storage Writers
|
||||
- `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
|
||||
|
||||
### Query Services
|
||||
- `trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py`
|
||||
- `trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py`
|
||||
- `trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py`
|
||||
|
||||
### Gateway
|
||||
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py`
|
||||
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py`
|
||||
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py`
|
||||
|
||||
### Document RAG
|
||||
- `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` - add librarian client
|
||||
- `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` - fetch from Garage
|
||||
|
||||
### CLI
|
||||
- `trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py`
|
||||
- `trustgraph-cli/trustgraph/cli/save_doc_embeds.py`
|
||||
- `trustgraph-cli/trustgraph/cli/load_doc_embeds.py`
|
||||
|
||||
## Benefits
|
||||
|
||||
1. Single source of truth - chunk text only in Garage
|
||||
2. Reduced vector store storage
|
||||
3. Enables query-time provenance via chunk_id
|
||||
|
|
@ -322,8 +322,8 @@ class BulkClient:
|
|||
|
||||
# Generate document embeddings to import
|
||||
def doc_embedding_generator():
|
||||
yield {"id": "doc1-chunk1", "embedding": [0.1, 0.2, ...]}
|
||||
yield {"id": "doc1-chunk2", "embedding": [0.3, 0.4, ...]}
|
||||
yield {"chunk_id": "doc1/p0/c0", "embedding": [0.1, 0.2, ...]}
|
||||
yield {"chunk_id": "doc1/p0/c1", "embedding": [0.3, 0.4, ...]}
|
||||
# ... more embeddings
|
||||
|
||||
bulk.import_document_embeddings(
|
||||
|
|
@ -363,9 +363,9 @@ class BulkClient:
|
|||
|
||||
# Export and process document embeddings
|
||||
for embedding in bulk.export_document_embeddings(flow="default"):
|
||||
doc_id = embedding.get("id")
|
||||
chunk_id = embedding.get("chunk_id")
|
||||
vector = embedding.get("embedding")
|
||||
print(f"{doc_id}: {len(vector)} dimensions")
|
||||
print(f"{chunk_id}: {len(vector)} dimensions")
|
||||
```
|
||||
"""
|
||||
async_gen = self._export_document_embeddings_async(flow)
|
||||
|
|
|
|||
|
|
@ -634,7 +634,7 @@ class FlowInstance:
|
|||
limit: Maximum number of results (default: 10)
|
||||
|
||||
Returns:
|
||||
dict: Query results with similar document chunks
|
||||
dict: Query results with chunk_ids of matching document chunks
|
||||
|
||||
Example:
|
||||
```python
|
||||
|
|
@ -645,6 +645,7 @@ class FlowInstance:
|
|||
collection="research-papers",
|
||||
limit=5
|
||||
)
|
||||
# results contains {"chunk_ids": ["doc1/p0/c0", "doc2/p1/c3", ...]}
|
||||
```
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -682,7 +682,7 @@ class SocketFlowInstance:
|
|||
**kwargs: Additional parameters passed to the service
|
||||
|
||||
Returns:
|
||||
dict: Query results with similar document chunks
|
||||
dict: Query results with chunk_ids of matching document chunks
|
||||
|
||||
Example:
|
||||
```python
|
||||
|
|
@ -695,6 +695,7 @@ class SocketFlowInstance:
|
|||
collection="research-papers",
|
||||
limit=5
|
||||
)
|
||||
# results contains {"chunk_ids": ["doc1/p0/c0", ...]}
|
||||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class DocumentEmbeddingsClient(RequestResponse):
|
|||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return resp.chunks
|
||||
return resp.chunk_ids
|
||||
|
||||
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
|||
docs = await self.query_document_embeddings(request)
|
||||
|
||||
logger.debug("Sending document embeddings query response...")
|
||||
r = DocumentEmbeddingsResponse(chunks=docs, error=None)
|
||||
r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
logger.debug("Document embeddings query request completed")
|
||||
|
|
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
|||
type = "document-embeddings-query-error",
|
||||
message = str(e),
|
||||
),
|
||||
chunks=None,
|
||||
chunk_ids=[],
|
||||
)
|
||||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
|
|
|||
|
|
@ -144,15 +144,15 @@ class DocumentEmbeddingsTranslator(SendTranslator):
|
|||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddings:
|
||||
metadata = data.get("metadata", {})
|
||||
|
||||
|
||||
chunks = [
|
||||
ChunkEmbeddings(
|
||||
chunk=chunk["chunk"].encode("utf-8") if isinstance(chunk["chunk"], str) else chunk["chunk"],
|
||||
chunk_id=chunk["chunk_id"],
|
||||
vectors=chunk["vectors"]
|
||||
)
|
||||
for chunk in data.get("chunks", [])
|
||||
]
|
||||
|
||||
|
||||
from ...schema import Metadata
|
||||
return DocumentEmbeddings(
|
||||
metadata=Metadata(
|
||||
|
|
@ -168,7 +168,7 @@ class DocumentEmbeddingsTranslator(SendTranslator):
|
|||
result = {
|
||||
"chunks": [
|
||||
{
|
||||
"chunk": chunk.chunk.decode("utf-8") if isinstance(chunk.chunk, bytes) else chunk.chunk,
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"vectors": chunk.vectors
|
||||
}
|
||||
for chunk in obj.chunks
|
||||
|
|
|
|||
|
|
@ -36,13 +36,10 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
|||
|
||||
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
if obj.chunks is not None:
|
||||
result["chunks"] = [
|
||||
chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
|
||||
for chunk in obj.chunks
|
||||
]
|
||||
|
||||
|
||||
if obj.chunk_ids is not None:
|
||||
result["chunk_ids"] = list(obj.chunk_ids)
|
||||
|
||||
return result
|
||||
|
||||
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class GraphEmbeddings:
|
|||
|
||||
@dataclass
|
||||
class ChunkEmbeddings:
|
||||
chunk: bytes = b""
|
||||
chunk_id: str = ""
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
|
||||
# This is a 'batching' mechanism for the above data
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class DocumentEmbeddingsRequest:
|
|||
@dataclass
|
||||
class DocumentEmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
chunks: list[str] = field(default_factory=list)
|
||||
chunk_ids: list[str] = field(default_factory=list)
|
||||
|
||||
document_embeddings_request_queue = topic(
|
||||
"document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Queries document chunks by text similarity using vector embeddings.
|
||||
Returns a list of matching document chunks, truncated to the specified length.
|
||||
Returns a list of matching chunk IDs.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
@ -10,13 +10,7 @@ from trustgraph.api import Api
|
|||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
||||
def truncate_chunk(chunk, max_length):
|
||||
"""Truncate a chunk to max_length characters, adding ellipsis if needed."""
|
||||
if len(chunk) <= max_length:
|
||||
return chunk
|
||||
return chunk[:max_length] + "..."
|
||||
|
||||
def query(url, flow_id, query_text, user, collection, limit, max_chunk_length, token=None):
|
||||
def query(url, flow_id, query_text, user, collection, limit, token=None):
|
||||
|
||||
# Create API client
|
||||
api = Api(url=url, token=token)
|
||||
|
|
@ -32,10 +26,12 @@ def query(url, flow_id, query_text, user, collection, limit, max_chunk_length, t
|
|||
limit=limit
|
||||
)
|
||||
|
||||
chunks = result.get("chunks", [])
|
||||
for i, chunk in enumerate(chunks, 1):
|
||||
truncated = truncate_chunk(chunk, max_chunk_length)
|
||||
print(f"{i}. {truncated}")
|
||||
chunk_ids = result.get("chunk_ids", [])
|
||||
if not chunk_ids:
|
||||
print("No matching chunks found.")
|
||||
else:
|
||||
for i, chunk_id in enumerate(chunk_ids, 1):
|
||||
print(f"{i}. {chunk_id}")
|
||||
|
||||
finally:
|
||||
# Clean up socket connection
|
||||
|
|
@ -85,13 +81,6 @@ def main():
|
|||
help='Maximum number of results (default: 10)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--max-chunk-length',
|
||||
type=int,
|
||||
default=200,
|
||||
help='Truncate chunks to N characters (default: 200)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'query',
|
||||
nargs=1,
|
||||
|
|
@ -109,7 +98,6 @@ def main():
|
|||
user=args.user,
|
||||
collection=args.collection,
|
||||
limit=args.limit,
|
||||
max_chunk_length=args.max_chunk_length,
|
||||
token=args.token,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -44,14 +44,14 @@ async def load_de(running, queue, url):
|
|||
|
||||
msg = {
|
||||
"metadata": {
|
||||
"id": msg["m"]["i"],
|
||||
"id": msg["m"]["i"],
|
||||
"metadata": msg["m"]["m"],
|
||||
"user": msg["m"]["u"],
|
||||
"collection": msg["m"]["c"],
|
||||
},
|
||||
"chunks": [
|
||||
{
|
||||
"chunk": chunk["c"],
|
||||
"chunk_id": chunk["c"],
|
||||
"vectors": chunk["v"],
|
||||
}
|
||||
for chunk in msg["c"]
|
||||
|
|
|
|||
|
|
@ -50,14 +50,14 @@ async def fetch_de(running, queue, user, collection, url):
|
|||
"de",
|
||||
{
|
||||
"m": {
|
||||
"i": data["metadata"]["id"],
|
||||
"i": data["metadata"]["id"],
|
||||
"m": data["metadata"]["metadata"],
|
||||
"u": data["metadata"]["user"],
|
||||
"c": data["metadata"]["collection"],
|
||||
},
|
||||
"c": [
|
||||
{
|
||||
"c": chunk["chunk"],
|
||||
"c": chunk["chunk_id"],
|
||||
"v": chunk["vectors"],
|
||||
}
|
||||
for chunk in data["chunks"]
|
||||
|
|
|
|||
|
|
@ -84,14 +84,14 @@ class DocVectors:
|
|||
dim=dimension,
|
||||
)
|
||||
|
||||
doc_field = FieldSchema(
|
||||
name="doc",
|
||||
chunk_id_field = FieldSchema(
|
||||
name="chunk_id",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields = [pkey_field, vec_field, doc_field],
|
||||
fields = [pkey_field, vec_field, chunk_id_field],
|
||||
description = "Document embedding schema",
|
||||
)
|
||||
|
||||
|
|
@ -119,17 +119,17 @@ class DocVectors:
|
|||
self.collections[(dimension, user, collection)] = collection_name
|
||||
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
|
||||
|
||||
def insert(self, embeds, doc, user, collection):
|
||||
def insert(self, embeds, chunk_id, user, collection):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if (dim, user, collection) not in self.collections:
|
||||
self.init_collection(dim, user, collection)
|
||||
|
||||
|
||||
data = [
|
||||
{
|
||||
"vector": embeds,
|
||||
"doc": doc,
|
||||
"chunk_id": chunk_id,
|
||||
}
|
||||
]
|
||||
|
||||
|
|
@ -138,7 +138,7 @@ class DocVectors:
|
|||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, user, collection, fields=["doc"], limit=10):
|
||||
def search(self, embeds, user, collection, fields=["chunk_id"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class Processor(FlowProcessor):
|
|||
|
||||
embeds = [
|
||||
ChunkEmbeddings(
|
||||
chunk=v.chunk,
|
||||
chunk_id=v.document_id,
|
||||
vectors=vectors,
|
||||
)
|
||||
]
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ def serialize_document_embeddings(message):
|
|||
"chunks": [
|
||||
{
|
||||
"vectors": chunk.vectors,
|
||||
"chunk": chunk.chunk.decode("utf-8"),
|
||||
"chunk_id": chunk.chunk_id,
|
||||
}
|
||||
for chunk in message.chunks
|
||||
],
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
"""
|
||||
Document embeddings query service. Input is vector, output is an array
|
||||
of chunks
|
||||
of chunk_ids
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -39,22 +39,22 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
chunk_ids = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
resp = self.vecstore.search(
|
||||
vec,
|
||||
msg.user,
|
||||
msg.collection,
|
||||
vec,
|
||||
msg.user,
|
||||
msg.collection,
|
||||
limit=msg.limit
|
||||
)
|
||||
|
||||
for r in resp:
|
||||
chunk = r["entity"]["doc"]
|
||||
chunks.append(chunk)
|
||||
chunk_id = r["entity"]["chunk_id"]
|
||||
chunk_ids.append(chunk_id)
|
||||
|
||||
return chunks
|
||||
return chunk_ids
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
"""
|
||||
Document embeddings query service. Input is vector, output is an array
|
||||
of chunks. Pinecone implementation.
|
||||
of chunk_ids. Pinecone implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -55,7 +55,7 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
chunk_ids = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
|
|
@ -79,10 +79,10 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
)
|
||||
|
||||
for r in results.matches:
|
||||
doc = r.metadata["doc"]
|
||||
chunks.append(doc)
|
||||
chunk_id = r.metadata["chunk_id"]
|
||||
chunk_ids.append(chunk_id)
|
||||
|
||||
return chunks
|
||||
return chunk_ids
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
"""
|
||||
Document embeddings query service. Input is vector, output is an array
|
||||
of chunks
|
||||
of chunk_ids
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -69,7 +69,7 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
|
||||
try:
|
||||
|
||||
chunks = []
|
||||
chunk_ids = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
|
|
@ -90,10 +90,10 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
).points
|
||||
|
||||
for r in search_result:
|
||||
ent = r.payload["doc"]
|
||||
chunks.append(ent)
|
||||
chunk_id = r.payload["chunk_id"]
|
||||
chunk_ids.append(chunk_id)
|
||||
|
||||
return chunks
|
||||
return chunk_ids
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class Query:
|
|||
if self.verbose:
|
||||
logger.debug("Computing embeddings...")
|
||||
|
||||
qembeds = await self.rag.embeddings_client.embed(query)
|
||||
qembeds = await self.rag.embeddings_client.embed(query)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Embeddings computed")
|
||||
|
|
@ -36,17 +36,31 @@ class Query:
|
|||
vectors = await self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Getting documents...")
|
||||
logger.debug("Getting chunk_ids from embeddings store...")
|
||||
|
||||
docs = await self.rag.doc_embeddings_client.query(
|
||||
# Get chunk_ids from embeddings store
|
||||
chunk_ids = await self.rag.doc_embeddings_client.query(
|
||||
vectors, limit=self.doc_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Documents:")
|
||||
logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...")
|
||||
|
||||
# Fetch chunk content from Garage
|
||||
docs = []
|
||||
for chunk_id in chunk_ids:
|
||||
if chunk_id:
|
||||
try:
|
||||
content = await self.rag.fetch_chunk(chunk_id, self.user)
|
||||
docs.append(content)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch chunk {chunk_id}: {e}")
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Documents fetched:")
|
||||
for doc in docs:
|
||||
logger.debug(f" {doc}")
|
||||
logger.debug(f" {doc[:100]}...")
|
||||
|
||||
return docs
|
||||
|
||||
|
|
@ -54,6 +68,7 @@ class DocumentRag:
|
|||
|
||||
def __init__(
|
||||
self, prompt_client, embeddings_client, doc_embeddings_client,
|
||||
fetch_chunk,
|
||||
verbose=False,
|
||||
):
|
||||
|
||||
|
|
@ -62,6 +77,7 @@ class DocumentRag:
|
|||
self.prompt_client = prompt_client
|
||||
self.embeddings_client = embeddings_client
|
||||
self.doc_embeddings_client = doc_embeddings_client
|
||||
self.fetch_chunk = fetch_chunk
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("DocumentRag initialized")
|
||||
|
|
|
|||
|
|
@ -4,17 +4,26 @@ Simple RAG service, performs query using document RAG and an LLM.
|
|||
Input is query, output is response.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
|
||||
from ... schema import LibrarianRequest, LibrarianResponse
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
from . document_rag import DocumentRag
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import DocumentEmbeddingsClientSpec
|
||||
from ... base import Consumer, Producer
|
||||
from ... base import ConsumerMetrics, ProducerMetrics
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "document-rag"
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
|
|
@ -69,6 +78,98 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
# Librarian client for fetching chunk content from Garage
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor=id, flow=None, name="librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
backend=self.pubsub,
|
||||
topic=librarian_request_q,
|
||||
schema=LibrarianRequest,
|
||||
metrics=librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor=id, flow=None, name="librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
backend=self.pubsub,
|
||||
flow=None,
|
||||
topic=librarian_response_q,
|
||||
subscriber=f"{id}-librarian",
|
||||
schema=LibrarianResponse,
|
||||
handler=self.on_librarian_response,
|
||||
metrics=librarian_response_metrics,
|
||||
)
|
||||
|
||||
# Pending librarian requests: request_id -> asyncio.Future
|
||||
self.pending_requests = {}
|
||||
|
||||
async def start(self):
|
||||
await super(Processor, self).start()
|
||||
await self.librarian_request_producer.start()
|
||||
await self.librarian_response_consumer.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
|
||||
"""Handle responses from the librarian service."""
|
||||
response = msg.value()
|
||||
request_id = msg.properties().get("id")
|
||||
|
||||
if request_id in self.pending_requests:
|
||||
future = self.pending_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
else:
|
||||
logger.warning(f"Received unexpected librarian response: {request_id}")
|
||||
|
||||
async def fetch_chunk_content(self, chunk_id, user, timeout=120):
|
||||
"""Fetch chunk content from librarian/Garage."""
|
||||
import uuid
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="get-document-content",
|
||||
document_id=chunk_id,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
# Content is base64 encoded
|
||||
content = response.content
|
||||
if isinstance(content, str):
|
||||
content = content.encode('utf-8')
|
||||
return base64.b64decode(content).decode("utf-8")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout fetching chunk {chunk_id}")
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
|
||||
try:
|
||||
|
|
@ -77,6 +178,7 @@ class Processor(FlowProcessor):
|
|||
embeddings_client = flow("embeddings-request"),
|
||||
doc_embeddings_client = flow("document-embeddings-request"),
|
||||
prompt_client = flow("prompt-request"),
|
||||
fetch_chunk = self.fetch_chunk_content,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,14 +37,13 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
|
||||
for emb in message.chunks:
|
||||
|
||||
if emb.chunk is None or emb.chunk == b"": continue
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": continue
|
||||
chunk_id = emb.chunk_id
|
||||
if chunk_id == "":
|
||||
continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
self.vecstore.insert(
|
||||
vec, chunk,
|
||||
vec, chunk_id,
|
||||
message.metadata.user,
|
||||
message.metadata.collection
|
||||
)
|
||||
|
|
|
|||
|
|
@ -101,10 +101,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
|
||||
for emb in message.chunks:
|
||||
|
||||
if emb.chunk is None or emb.chunk == b"": continue
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": continue
|
||||
chunk_id = emb.chunk_id
|
||||
if chunk_id == "":
|
||||
continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
|
||||
|
|
@ -128,7 +127,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
{
|
||||
"id": vector_id,
|
||||
"values": vec,
|
||||
"metadata": { "doc": chunk },
|
||||
"metadata": { "chunk_id": chunk_id },
|
||||
}
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -52,8 +52,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
|
||||
for emb in message.chunks:
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": return
|
||||
chunk_id = emb.chunk_id
|
||||
if chunk_id == "":
|
||||
continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
|
||||
|
|
@ -81,7 +82,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
id=str(uuid.uuid4()),
|
||||
vector=vec,
|
||||
payload={
|
||||
"doc": chunk,
|
||||
"chunk_id": chunk_id,
|
||||
}
|
||||
)
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue