Document chunks not stored in vector store (#665)

- Schema - ChunkEmbeddings now uses chunk_id: str instead of chunk: bytes
- Schema - DocumentEmbeddingsResponse now returns chunk_ids: list[str]
  instead of chunks
- Translators - Updated to serialize/deserialize chunk_id
- Clients - DocumentEmbeddingsClient.query() returns chunk_ids
- SDK/API - flow.py, socket_client.py, bulk_client.py updated
- Document embeddings service - Stores chunk_id (document ID) instead
  of chunk text
- Storage writers - Qdrant, Milvus, Pinecone store chunk_id in payload
- Query services - Return chunk_id from vector store searches
- Gateway dispatchers - Serialize chunk_id in API responses
- Document RAG - Added librarian client to fetch chunk content from
  Garage using chunk_ids
- CLI tools - Updated all three tools:
  - invoke_document_embeddings.py - displays chunk_ids, removed
    max_chunk_length
  - save_doc_embeds.py - exports chunk_id
  - load_doc_embeds.py - imports chunk_id
cybermaggedon 2026-03-07 23:10:45 +00:00 committed by GitHub
parent be358efe67
commit 24bbe94136
24 changed files with 331 additions and 91 deletions

# Document Embeddings Chunk ID
## Overview
The document embeddings pipeline currently stores chunk text directly in the vector store payload, duplicating content that already exists in Garage. This spec replaces stored chunk text with `chunk_id` references.
## Current State
```python
@dataclass
class ChunkEmbeddings:
    chunk: bytes = b""
    vectors: list[list[float]] = field(default_factory=list)

@dataclass
class DocumentEmbeddingsResponse:
    error: Error | None = None
    chunks: list[str] = field(default_factory=list)
```
Vector store payload:
```python
payload={"doc": chunk} # Duplicates Garage content
```
## Design
### Schema Changes
**ChunkEmbeddings** - replace chunk with chunk_id:
```python
@dataclass
class ChunkEmbeddings:
    chunk_id: str = ""
    vectors: list[list[float]] = field(default_factory=list)
```
**DocumentEmbeddingsResponse** - return chunk_ids instead of chunks:
```python
@dataclass
class DocumentEmbeddingsResponse:
    error: Error | None = None
    chunk_ids: list[str] = field(default_factory=list)
```
### Vector Store Payload
All stores (Qdrant, Milvus, Pinecone):
```python
payload={"chunk_id": chunk_id}
```
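Across all three stores, the write path reduces to attaching the same one-field payload to each point. A minimal, store-agnostic sketch of building such a point record (the `make_point` helper and its dict shape are illustrative, not an API from the codebase):

```python
import uuid

def make_point(chunk_id: str, vector: list[float]) -> dict:
    """Build a store-agnostic point record: the payload carries only
    the chunk_id reference, never the chunk text itself."""
    if not chunk_id:
        raise ValueError("chunk_id must be non-empty")
    return {
        "id": str(uuid.uuid4()),
        "vector": vector,
        "payload": {"chunk_id": chunk_id},
    }

point = make_point("doc1/p0/c0", [0.1, 0.2, 0.3])
```

The point of the shape is that nothing in the payload needs updating if the chunk text in Garage ever changes.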
### Document RAG Changes
The document RAG processor fetches chunk content from Garage:
```python
# Get chunk_ids from embeddings store
chunk_ids = await self.rag.doc_embeddings_client.query(...)

# Fetch chunk content from Garage
docs = []
for chunk_id in chunk_ids:
    content = await self.rag.librarian_client.get_document_content(
        chunk_id, self.user
    )
    docs.append(content)
```
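The loop above fetches chunks one at a time; if the librarian tolerates concurrent requests, the fetches could instead be issued in parallel. A self-contained sketch using a stub client (`StubLibrarianClient` and `fetch_chunks` are hypothetical names, not part of the codebase):

```python
import asyncio

class StubLibrarianClient:
    """Stand-in for the librarian client; returns canned chunk text."""
    def __init__(self, store: dict[str, str]):
        self.store = store

    async def get_document_content(self, chunk_id: str, user: str) -> str:
        return self.store[chunk_id]

async def fetch_chunks(client, chunk_ids: list[str], user: str) -> list[str]:
    # Issue all fetches concurrently; gather preserves input order
    return await asyncio.gather(
        *(client.get_document_content(cid, user) for cid in chunk_ids)
    )

client = StubLibrarianClient(
    {"doc1/p0/c0": "first chunk", "doc1/p0/c1": "second chunk"}
)
docs = asyncio.run(
    fetch_chunks(client, ["doc1/p0/c0", "doc1/p0/c1"], "alice")
)
```

Whether parallel fetching is safe depends on librarian throughput limits, so the sequential loop is a reasonable default.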
### API/SDK Changes
**DocumentEmbeddingsClient** returns chunk_ids:
```python
return resp.chunk_ids # Changed from resp.chunks
```
**Wire format** (DocumentEmbeddingsResponseTranslator):
```python
result["chunk_ids"] = obj.chunk_ids # Changed from chunks
```
### CLI Changes
The CLI tool displays chunk_ids; callers can fetch the chunk content separately if needed.
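A sketch of the display logic described above (the `format_chunk_ids` helper is illustrative, not a function from the codebase):

```python
def format_chunk_ids(chunk_ids: list[str]) -> str:
    """Render query results as a numbered list of chunk IDs,
    with a friendly message when nothing matched."""
    if not chunk_ids:
        return "No matching chunks found."
    return "\n".join(f"{i}. {cid}" for i, cid in enumerate(chunk_ids, 1))

out = format_chunk_ids(["doc1/p0/c0", "doc2/p1/c3"])
```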
## Files to Modify
### Schema
- `trustgraph-base/trustgraph/schema/knowledge/embeddings.py` - ChunkEmbeddings
- `trustgraph-base/trustgraph/schema/services/query.py` - DocumentEmbeddingsResponse
### Messaging/Translators
- `trustgraph-base/trustgraph/messaging/translators/embeddings_query.py` - DocumentEmbeddingsResponseTranslator
### Client
- `trustgraph-base/trustgraph/base/document_embeddings_client.py` - return chunk_ids
### Python SDK/API
- `trustgraph-base/trustgraph/api/flow.py` - document_embeddings_query
- `trustgraph-base/trustgraph/api/socket_client.py` - document_embeddings_query
- `trustgraph-base/trustgraph/api/async_flow.py` - if applicable
- `trustgraph-base/trustgraph/api/bulk_client.py` - import/export document embeddings
- `trustgraph-base/trustgraph/api/async_bulk_client.py` - import/export document embeddings
### Embeddings Service
- `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py` - pass chunk_id
### Storage Writers
- `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
### Query Services
- `trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py`
- `trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py`
- `trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py`
### Gateway
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py`
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py`
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py`
### Document RAG
- `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` - add librarian client
- `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` - fetch from Garage
### CLI
- `trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py`
- `trustgraph-cli/trustgraph/cli/save_doc_embeds.py`
- `trustgraph-cli/trustgraph/cli/load_doc_embeds.py`
## Benefits
1. Single source of truth - chunk text only in Garage
2. Reduced vector store storage
3. Enables query-time provenance via chunk_id

```diff
@@ -322,8 +322,8 @@ class BulkClient:
         # Generate document embeddings to import
         def doc_embedding_generator():
-            yield {"id": "doc1-chunk1", "embedding": [0.1, 0.2, ...]}
-            yield {"id": "doc1-chunk2", "embedding": [0.3, 0.4, ...]}
+            yield {"chunk_id": "doc1/p0/c0", "embedding": [0.1, 0.2, ...]}
+            yield {"chunk_id": "doc1/p0/c1", "embedding": [0.3, 0.4, ...]}
             # ... more embeddings

         bulk.import_document_embeddings(
@@ -363,9 +363,9 @@ class BulkClient:
         # Export and process document embeddings
         for embedding in bulk.export_document_embeddings(flow="default"):
-            doc_id = embedding.get("id")
+            chunk_id = embedding.get("chunk_id")
             vector = embedding.get("embedding")
-            print(f"{doc_id}: {len(vector)} dimensions")
+            print(f"{chunk_id}: {len(vector)} dimensions")
         ```
         """
         async_gen = self._export_document_embeddings_async(flow)
```

```diff
@@ -634,7 +634,7 @@ class FlowInstance:
             limit: Maximum number of results (default: 10)

         Returns:
-            dict: Query results with similar document chunks
+            dict: Query results with chunk_ids of matching document chunks

         Example:
             ```python
@@ -645,6 +645,7 @@ class FlowInstance:
                 collection="research-papers",
                 limit=5
             )
+            # results contains {"chunk_ids": ["doc1/p0/c0", "doc2/p1/c3", ...]}
             ```
         """
```

```diff
@@ -682,7 +682,7 @@ class SocketFlowInstance:
             **kwargs: Additional parameters passed to the service

         Returns:
-            dict: Query results with similar document chunks
+            dict: Query results with chunk_ids of matching document chunks

         Example:
             ```python
@@ -695,6 +695,7 @@ class SocketFlowInstance:
                 collection="research-papers",
                 limit=5
             )
+            # results contains {"chunk_ids": ["doc1/p0/c0", ...]}
             ```
         """
         # First convert text to embeddings vectors
```

```diff
@@ -27,7 +27,7 @@ class DocumentEmbeddingsClient(RequestResponse):
         if resp.error:
             raise RuntimeError(resp.error.message)

-        return resp.chunks
+        return resp.chunk_ids

 class DocumentEmbeddingsClientSpec(RequestResponseSpec):
     def __init__(
```

```diff
@@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
             docs = await self.query_document_embeddings(request)

             logger.debug("Sending document embeddings query response...")
-            r = DocumentEmbeddingsResponse(chunks=docs, error=None)
+            r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None)
             await flow("response").send(r, properties={"id": id})

             logger.debug("Document embeddings query request completed")
@@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
                     type = "document-embeddings-query-error",
                     message = str(e),
                 ),
-                chunks=None,
+                chunk_ids=[],
             )
             await flow("response").send(r, properties={"id": id})
```

```diff
@@ -144,15 +144,15 @@ class DocumentEmbeddingsTranslator(SendTranslator):
     def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddings:
         metadata = data.get("metadata", {})
         chunks = [
             ChunkEmbeddings(
-                chunk=chunk["chunk"].encode("utf-8") if isinstance(chunk["chunk"], str) else chunk["chunk"],
+                chunk_id=chunk["chunk_id"],
                 vectors=chunk["vectors"]
             )
             for chunk in data.get("chunks", [])
         ]
         from ...schema import Metadata
         return DocumentEmbeddings(
             metadata=Metadata(
@@ -168,7 +168,7 @@ class DocumentEmbeddingsTranslator(SendTranslator):
         result = {
             "chunks": [
                 {
-                    "chunk": chunk.chunk.decode("utf-8") if isinstance(chunk.chunk, bytes) else chunk.chunk,
+                    "chunk_id": chunk.chunk_id,
                     "vectors": chunk.vectors
                 }
                 for chunk in obj.chunks
```

```diff
@@ -36,13 +36,10 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
     def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
         result = {}
-        if obj.chunks is not None:
-            result["chunks"] = [
-                chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
-                for chunk in obj.chunks
-            ]
+        if obj.chunk_ids is not None:
+            result["chunk_ids"] = list(obj.chunk_ids)
         return result

     def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
```

```diff
@@ -27,7 +27,7 @@ class GraphEmbeddings:
 @dataclass
 class ChunkEmbeddings:
-    chunk: bytes = b""
+    chunk_id: str = ""
     vectors: list[list[float]] = field(default_factory=list)

 # This is a 'batching' mechanism for the above data
```

```diff
@@ -52,7 +52,7 @@ class DocumentEmbeddingsRequest:
 @dataclass
 class DocumentEmbeddingsResponse:
     error: Error | None = None
-    chunks: list[str] = field(default_factory=list)
+    chunk_ids: list[str] = field(default_factory=list)

 document_embeddings_request_queue = topic(
     "document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
```

```diff
@@ -1,6 +1,6 @@
 """
 Queries document chunks by text similarity using vector embeddings.
-Returns a list of matching document chunks, truncated to the specified length.
+Returns a list of matching chunk IDs.
 """

 import argparse
@@ -10,13 +10,7 @@ from trustgraph.api import Api
 default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
 default_token = os.getenv("TRUSTGRAPH_TOKEN", None)

-def truncate_chunk(chunk, max_length):
-    """Truncate a chunk to max_length characters, adding ellipsis if needed."""
-    if len(chunk) <= max_length:
-        return chunk
-    return chunk[:max_length] + "..."
-
-def query(url, flow_id, query_text, user, collection, limit, max_chunk_length, token=None):
+def query(url, flow_id, query_text, user, collection, limit, token=None):

     # Create API client
     api = Api(url=url, token=token)
@@ -32,10 +26,12 @@ def query(url, flow_id, query_text, user, collection, limit, max_chunk_length, t
             limit=limit
         )

-        chunks = result.get("chunks", [])
-        for i, chunk in enumerate(chunks, 1):
-            truncated = truncate_chunk(chunk, max_chunk_length)
-            print(f"{i}. {truncated}")
+        chunk_ids = result.get("chunk_ids", [])
+        if not chunk_ids:
+            print("No matching chunks found.")
+        else:
+            for i, chunk_id in enumerate(chunk_ids, 1):
+                print(f"{i}. {chunk_id}")

     finally:
         # Clean up socket connection
@@ -85,13 +81,6 @@ def main():
         help='Maximum number of results (default: 10)',
     )

-    parser.add_argument(
-        '--max-chunk-length',
-        type=int,
-        default=200,
-        help='Truncate chunks to N characters (default: 200)',
-    )
-
     parser.add_argument(
         'query',
         nargs=1,
@@ -109,7 +98,6 @@ def main():
         user=args.user,
         collection=args.collection,
         limit=args.limit,
-        max_chunk_length=args.max_chunk_length,
         token=args.token,
     )
```

```diff
@@ -44,14 +44,14 @@ async def load_de(running, queue, url):
         msg = {
             "metadata": {
                 "id": msg["m"]["i"],
                 "metadata": msg["m"]["m"],
                 "user": msg["m"]["u"],
                 "collection": msg["m"]["c"],
             },
             "chunks": [
                 {
-                    "chunk": chunk["c"],
+                    "chunk_id": chunk["c"],
                     "vectors": chunk["v"],
                 }
                 for chunk in msg["c"]
```

```diff
@@ -50,14 +50,14 @@ async def fetch_de(running, queue, user, collection, url):
             "de",
             {
                 "m": {
                     "i": data["metadata"]["id"],
                     "m": data["metadata"]["metadata"],
                     "u": data["metadata"]["user"],
                     "c": data["metadata"]["collection"],
                 },
                 "c": [
                     {
-                        "c": chunk["chunk"],
+                        "c": chunk["chunk_id"],
                         "v": chunk["vectors"],
                     }
                     for chunk in data["chunks"]
```

```diff
@@ -84,14 +84,14 @@ class DocVectors:
             dim=dimension,
         )

-        doc_field = FieldSchema(
-            name="doc",
+        chunk_id_field = FieldSchema(
+            name="chunk_id",
             dtype=DataType.VARCHAR,
             max_length=65535,
         )

         schema = CollectionSchema(
-            fields = [pkey_field, vec_field, doc_field],
+            fields = [pkey_field, vec_field, chunk_id_field],
             description = "Document embedding schema",
         )
@@ -119,17 +119,17 @@ class DocVectors:
         self.collections[(dimension, user, collection)] = collection_name
         logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")

-    def insert(self, embeds, doc, user, collection):
+    def insert(self, embeds, chunk_id, user, collection):

         dim = len(embeds)

         if (dim, user, collection) not in self.collections:
             self.init_collection(dim, user, collection)

         data = [
             {
                 "vector": embeds,
-                "doc": doc,
+                "chunk_id": chunk_id,
             }
         ]
@@ -138,7 +138,7 @@ class DocVectors:
             data=data
         )

-    def search(self, embeds, user, collection, fields=["doc"], limit=10):
+    def search(self, embeds, user, collection, fields=["chunk_id"], limit=10):

         dim = len(embeds)
```

```diff
@@ -70,7 +70,7 @@ class Processor(FlowProcessor):
             embeds = [
                 ChunkEmbeddings(
-                    chunk=v.chunk,
+                    chunk_id=v.document_id,
                     vectors=vectors,
                 )
             ]
```

```diff
@@ -89,7 +89,7 @@ def serialize_document_embeddings(message):
         "chunks": [
             {
                 "vectors": chunk.vectors,
-                "chunk": chunk.chunk.decode("utf-8"),
+                "chunk_id": chunk.chunk_id,
             }
             for chunk in message.chunks
         ],
```

```diff
@@ -1,7 +1,7 @@
 """
 Document embeddings query service. Input is vector, output is an array
-of chunks
+of chunk_ids
 """

 import logging
@@ -39,22 +39,22 @@ class Processor(DocumentEmbeddingsQueryService):
             if msg.limit <= 0:
                 return []

-            chunks = []
+            chunk_ids = []

             for vec in msg.vectors:

                 resp = self.vecstore.search(
                     vec,
                     msg.user,
                     msg.collection,
                     limit=msg.limit
                 )

                 for r in resp:
-                    chunk = r["entity"]["doc"]
-                    chunks.append(chunk)
+                    chunk_id = r["entity"]["chunk_id"]
+                    chunk_ids.append(chunk_id)

-            return chunks
+            return chunk_ids

         except Exception as e:
```

```diff
@@ -1,7 +1,7 @@
 """
 Document embeddings query service. Input is vector, output is an array
-of chunks. Pinecone implementation.
+of chunk_ids. Pinecone implementation.
 """

 import logging
@@ -55,7 +55,7 @@ class Processor(DocumentEmbeddingsQueryService):
             if msg.limit <= 0:
                 return []

-            chunks = []
+            chunk_ids = []

             for vec in msg.vectors:
@@ -79,10 +79,10 @@ class Processor(DocumentEmbeddingsQueryService):
                 )

                 for r in results.matches:
-                    doc = r.metadata["doc"]
-                    chunks.append(doc)
+                    chunk_id = r.metadata["chunk_id"]
+                    chunk_ids.append(chunk_id)

-            return chunks
+            return chunk_ids

         except Exception as e:
```

```diff
@@ -1,7 +1,7 @@
 """
 Document embeddings query service. Input is vector, output is an array
-of chunks
+of chunk_ids
 """

 import logging
@@ -69,7 +69,7 @@ class Processor(DocumentEmbeddingsQueryService):
         try:

-            chunks = []
+            chunk_ids = []

             for vec in msg.vectors:
@@ -90,10 +90,10 @@ class Processor(DocumentEmbeddingsQueryService):
                 ).points

                 for r in search_result:
-                    ent = r.payload["doc"]
-                    chunks.append(ent)
+                    chunk_id = r.payload["chunk_id"]
+                    chunk_ids.append(chunk_id)

-            return chunks
+            return chunk_ids

         except Exception as e:
```

```diff
@@ -24,7 +24,7 @@ class Query:
         if self.verbose:
             logger.debug("Computing embeddings...")

         qembeds = await self.rag.embeddings_client.embed(query)

         if self.verbose:
             logger.debug("Embeddings computed")
@@ -36,17 +36,31 @@ class Query:
         vectors = await self.get_vector(query)

         if self.verbose:
-            logger.debug("Getting documents...")
+            logger.debug("Getting chunk_ids from embeddings store...")

-        docs = await self.rag.doc_embeddings_client.query(
+        # Get chunk_ids from embeddings store
+        chunk_ids = await self.rag.doc_embeddings_client.query(
             vectors, limit=self.doc_limit,
             user=self.user, collection=self.collection,
         )

         if self.verbose:
-            logger.debug("Documents:")
+            logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...")

+        # Fetch chunk content from Garage
+        docs = []
+        for chunk_id in chunk_ids:
+            if chunk_id:
+                try:
+                    content = await self.rag.fetch_chunk(chunk_id, self.user)
+                    docs.append(content)
+                except Exception as e:
+                    logger.warning(f"Failed to fetch chunk {chunk_id}: {e}")
+
+        if self.verbose:
+            logger.debug("Documents fetched:")
             for doc in docs:
-                logger.debug(f"  {doc}")
+                logger.debug(f"  {doc[:100]}...")

         return docs
@@ -54,6 +68,7 @@ class DocumentRag:
     def __init__(
         self, prompt_client, embeddings_client, doc_embeddings_client,
+        fetch_chunk,
         verbose=False,
     ):
@@ -62,6 +77,7 @@ class DocumentRag:
         self.prompt_client = prompt_client
         self.embeddings_client = embeddings_client
         self.doc_embeddings_client = doc_embeddings_client
+        self.fetch_chunk = fetch_chunk

         if self.verbose:
             logger.debug("DocumentRag initialized")
```

```diff
@@ -4,17 +4,26 @@ Simple RAG service, performs query using document RAG an LLM.
 Input is query, output is response.
 """

+import asyncio
+import base64
 import logging

 from ... schema import DocumentRagQuery, DocumentRagResponse, Error
+from ... schema import LibrarianRequest, LibrarianResponse
+from ... schema import librarian_request_queue, librarian_response_queue
 from . document_rag import DocumentRag
 from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
 from ... base import PromptClientSpec, EmbeddingsClientSpec
 from ... base import DocumentEmbeddingsClientSpec
+from ... base import Consumer, Producer
+from ... base import ConsumerMetrics, ProducerMetrics

 # Module logger
 logger = logging.getLogger(__name__)

 default_ident = "document-rag"
+default_librarian_request_queue = librarian_request_queue
+default_librarian_response_queue = librarian_response_queue

 class Processor(FlowProcessor):
@@ -69,6 +78,98 @@ class Processor(FlowProcessor):
             )
         )

+        # Librarian client for fetching chunk content from Garage
+        librarian_request_q = params.get(
+            "librarian_request_queue", default_librarian_request_queue
+        )
+        librarian_response_q = params.get(
+            "librarian_response_queue", default_librarian_response_queue
+        )
+
+        librarian_request_metrics = ProducerMetrics(
+            processor=id, flow=None, name="librarian-request"
+        )
+        self.librarian_request_producer = Producer(
+            backend=self.pubsub,
+            topic=librarian_request_q,
+            schema=LibrarianRequest,
+            metrics=librarian_request_metrics,
+        )
+
+        librarian_response_metrics = ConsumerMetrics(
+            processor=id, flow=None, name="librarian-response"
+        )
+        self.librarian_response_consumer = Consumer(
+            taskgroup=self.taskgroup,
+            backend=self.pubsub,
+            flow=None,
+            topic=librarian_response_q,
+            subscriber=f"{id}-librarian",
+            schema=LibrarianResponse,
+            handler=self.on_librarian_response,
+            metrics=librarian_response_metrics,
+        )
+
+        # Pending librarian requests: request_id -> asyncio.Future
+        self.pending_requests = {}
+
+    async def start(self):
+        await super(Processor, self).start()
+        await self.librarian_request_producer.start()
+        await self.librarian_response_consumer.start()
+
+    async def on_librarian_response(self, msg, consumer, flow):
+        """Handle responses from the librarian service."""
+        response = msg.value()
+        request_id = msg.properties().get("id")
+
+        if request_id in self.pending_requests:
+            future = self.pending_requests.pop(request_id)
+            future.set_result(response)
+        else:
+            logger.warning(f"Received unexpected librarian response: {request_id}")
+
+    async def fetch_chunk_content(self, chunk_id, user, timeout=120):
+        """Fetch chunk content from librarian/Garage."""
+
+        import uuid
+        request_id = str(uuid.uuid4())
+
+        request = LibrarianRequest(
+            operation="get-document-content",
+            document_id=chunk_id,
+            user=user,
+        )
+
+        # Create future for response
+        future = asyncio.get_event_loop().create_future()
+        self.pending_requests[request_id] = future
+
+        try:
+            # Send request
+            await self.librarian_request_producer.send(
+                request, properties={"id": request_id}
+            )
+
+            # Wait for response
+            response = await asyncio.wait_for(future, timeout=timeout)
+
+            if response.error:
+                raise RuntimeError(
+                    f"Librarian error: {response.error.type}: {response.error.message}"
+                )
+
+            # Content is base64 encoded
+            content = response.content
+            if isinstance(content, str):
+                content = content.encode('utf-8')
+            return base64.b64decode(content).decode("utf-8")
+
+        except asyncio.TimeoutError:
+            self.pending_requests.pop(request_id, None)
+            raise RuntimeError(f"Timeout fetching chunk {chunk_id}")
+
     async def on_request(self, msg, consumer, flow):

         try:
@@ -77,6 +178,7 @@ class Processor(FlowProcessor):
                 embeddings_client = flow("embeddings-request"),
                 doc_embeddings_client = flow("document-embeddings-request"),
                 prompt_client = flow("prompt-request"),
+                fetch_chunk = self.fetch_chunk_content,
                 verbose=True,
             )
```
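The librarian round-trip above relies on request/response correlation: each request carries a UUID property, and the response consumer resolves the matching future. A self-contained sketch of just that pattern, with an in-process echo transport standing in for the message bus (all names here are illustrative, not from the codebase):

```python
import asyncio
import uuid

class Correlator:
    """Minimal request/response correlation over an async message bus:
    each request gets a UUID key; the matching response resolves its future."""
    def __init__(self):
        self.pending: dict[str, asyncio.Future] = {}

    async def request(self, send, payload, timeout=5.0):
        request_id = str(uuid.uuid4())
        future = asyncio.get_running_loop().create_future()
        self.pending[request_id] = future
        try:
            await send(request_id, payload)
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            # Drop the orphaned future so the pending map cannot grow unbounded
            self.pending.pop(request_id, None)
            raise

    def on_response(self, request_id, payload):
        future = self.pending.pop(request_id, None)
        if future is not None and not future.done():
            future.set_result(payload)

async def demo():
    corr = Correlator()

    async def send(request_id, payload):
        # Echo transport: "responds" on the next event-loop tick
        asyncio.get_running_loop().call_soon(
            corr.on_response, request_id, payload.upper()
        )

    return await corr.request(send, "chunk text")

result = asyncio.run(demo())
```

Cleaning up the pending entry on timeout, as the diff also does, is the important detail: without it a lost response leaks a future per request.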

```diff
@@ -37,14 +37,13 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
         for emb in message.chunks:

-            if emb.chunk is None or emb.chunk == b"": continue
-
-            chunk = emb.chunk.decode("utf-8")
-            if chunk == "": continue
+            chunk_id = emb.chunk_id
+            if chunk_id == "":
+                continue

             for vec in emb.vectors:
                 self.vecstore.insert(
-                    vec, chunk,
+                    vec, chunk_id,
                     message.metadata.user,
                     message.metadata.collection
                 )
```

```diff
@@ -101,10 +101,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
         for emb in message.chunks:

-            if emb.chunk is None or emb.chunk == b"": continue
-
-            chunk = emb.chunk.decode("utf-8")
-            if chunk == "": continue
+            chunk_id = emb.chunk_id
+            if chunk_id == "":
+                continue

             for vec in emb.vectors:
@@ -128,7 +127,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
                     {
                         "id": vector_id,
                         "values": vec,
-                        "metadata": { "doc": chunk },
+                        "metadata": { "chunk_id": chunk_id },
                     }
                 ]
```

```diff
@@ -52,8 +52,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
         for emb in message.chunks:

-            chunk = emb.chunk.decode("utf-8")
-            if chunk == "": return
+            chunk_id = emb.chunk_id
+            if chunk_id == "":
+                continue

             for vec in emb.vectors:
@@ -81,7 +82,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
                         id=str(uuid.uuid4()),
                         vector=vec,
                         payload={
-                            "doc": chunk,
+                            "chunk_id": chunk_id,
                         }
                     )
                 ]
```