Document chunks not stored in vector store (#665)

- Schema - ChunkEmbeddings now uses chunk_id: str instead of chunk: bytes
- Schema - DocumentEmbeddingsResponse now returns chunk_ids: list[str]
  instead of chunks
- Translators - Updated to serialize/deserialize chunk_id
- Clients - DocumentEmbeddingsClient.query() returns chunk_ids
- SDK/API - flow.py, socket_client.py, bulk_client.py updated
- Document embeddings service - Stores chunk_id (document ID) instead
  of chunk text
- Storage writers - Qdrant, Milvus, Pinecone store chunk_id in payload
- Query services - Return chunk_id from vector store searches
- Gateway dispatchers - Serialize chunk_id in API responses
- Document RAG - Added librarian client to fetch chunk content from
  Garage using chunk_ids
- CLI tools - Updated all three tools:
  - invoke_document_embeddings.py - displays chunk_ids, removed
    max_chunk_length
  - save_doc_embeds.py - exports chunk_id
  - load_doc_embeds.py - imports chunk_id
cybermaggedon 2026-03-07 23:10:45 +00:00 committed by GitHub
parent be358efe67
commit 24bbe94136
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 331 additions and 91 deletions


@ -0,0 +1,136 @@
# Document Embeddings Chunk ID
## Overview
Document embeddings storage currently stores chunk text directly in the vector store payload, duplicating data that exists in Garage. This spec replaces chunk text storage with `chunk_id` references.
## Current State
```python
@dataclass
class ChunkEmbeddings:
    chunk: bytes = b""
    vectors: list[list[float]] = field(default_factory=list)

@dataclass
class DocumentEmbeddingsResponse:
    error: Error | None = None
    chunks: list[str] = field(default_factory=list)
```
Vector store payload:
```python
payload={"doc": chunk} # Duplicates Garage content
```
## Design
### Schema Changes
**ChunkEmbeddings** - replace chunk with chunk_id:
```python
@dataclass
class ChunkEmbeddings:
    chunk_id: str = ""
    vectors: list[list[float]] = field(default_factory=list)
```
**DocumentEmbeddingsResponse** - return chunk_ids instead of chunks:
```python
@dataclass
class DocumentEmbeddingsResponse:
    error: Error | None = None
    chunk_ids: list[str] = field(default_factory=list)
```
### Vector Store Payload
All stores (Qdrant, Milvus, Pinecone):
```python
payload={"chunk_id": chunk_id}
```
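The payload shape is the same across stores. A minimal sketch of the point each writer would emit (the `make_point` helper and UUID point IDs are illustrative, not the actual writer code):

```python
import uuid

def make_point(vec, chunk_id):
    # Illustrative only: Qdrant uses PointStruct, Milvus a row dict, and
    # Pinecone a metadata dict, but each carries just the chunk_id reference.
    return {
        "id": str(uuid.uuid4()),
        "vector": vec,
        "payload": {"chunk_id": chunk_id},
    }

point = make_point([0.1, 0.2, 0.3], "doc1/p0/c0")
```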
### Document RAG Changes
The document RAG processor fetches chunk content from Garage:
```python
# Get chunk_ids from embeddings store
chunk_ids = await self.rag.doc_embeddings_client.query(...)

# Fetch chunk content from Garage
docs = []
for chunk_id in chunk_ids:
    content = await self.rag.librarian_client.get_document_content(
        chunk_id, self.user
    )
    docs.append(content)
```
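The fetch loop above is sequential. The same step can be exercised with a stub, here issuing the per-chunk fetches concurrently via `asyncio.gather` (the `StubLibrarian` class is a hypothetical stand-in for the librarian client, and the concurrency is an optional refinement, not part of the spec):

```python
import asyncio

class StubLibrarian:
    # Hypothetical stand-in for the librarian client; the real client
    # fetches chunk content from Garage by chunk_id.
    async def get_document_content(self, chunk_id, user):
        return f"content of {chunk_id}"

async def fetch_all(librarian, chunk_ids, user):
    # Issue all fetches concurrently; results preserve chunk_ids order.
    return await asyncio.gather(
        *(librarian.get_document_content(cid, user) for cid in chunk_ids)
    )

docs = asyncio.run(fetch_all(StubLibrarian(), ["doc1/p0/c0", "doc2/p1/c3"], "alice"))
# docs == ["content of doc1/p0/c0", "content of doc2/p1/c3"]
```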
### API/SDK Changes
**DocumentEmbeddingsClient** returns chunk_ids:
```python
return resp.chunk_ids # Changed from resp.chunks
```
**Wire format** (DocumentEmbeddingsResponseTranslator):
```python
result["chunk_ids"] = obj.chunk_ids # Changed from chunks
```
### CLI Changes
The CLI tool displays chunk_ids; callers can fetch the chunk content from Garage separately if needed.
## Files to Modify
### Schema
- `trustgraph-base/trustgraph/schema/knowledge/embeddings.py` - ChunkEmbeddings
- `trustgraph-base/trustgraph/schema/services/query.py` - DocumentEmbeddingsResponse
### Messaging/Translators
- `trustgraph-base/trustgraph/messaging/translators/embeddings_query.py` - DocumentEmbeddingsResponseTranslator
### Client
- `trustgraph-base/trustgraph/base/document_embeddings_client.py` - return chunk_ids
### Python SDK/API
- `trustgraph-base/trustgraph/api/flow.py` - document_embeddings_query
- `trustgraph-base/trustgraph/api/socket_client.py` - document_embeddings_query
- `trustgraph-base/trustgraph/api/async_flow.py` - if applicable
- `trustgraph-base/trustgraph/api/bulk_client.py` - import/export document embeddings
- `trustgraph-base/trustgraph/api/async_bulk_client.py` - import/export document embeddings
### Embeddings Service
- `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py` - pass chunk_id
### Storage Writers
- `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
### Query Services
- `trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py`
- `trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py`
- `trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py`
### Gateway
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py`
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py`
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py`
### Document RAG
- `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` - add librarian client
- `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` - fetch from Garage
### CLI
- `trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py`
- `trustgraph-cli/trustgraph/cli/save_doc_embeds.py`
- `trustgraph-cli/trustgraph/cli/load_doc_embeds.py`
## Benefits
1. Single source of truth - chunk text only in Garage
2. Reduced vector store storage
3. Enables query-time provenance via chunk_id
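On point 3: if chunk IDs keep the structured form used in the examples (`doc1/p0/c0`), provenance can be recovered by splitting the ID. This helper is purely illustrative; the spec does not fix an ID format:

```python
def provenance(chunk_id):
    # Assumes a "<document>/<page>/<chunk>" layout as in the examples above;
    # the actual format is whatever the chunker emits.
    doc, page, chunk = chunk_id.split("/")
    return {"document": doc, "page": page, "chunk": chunk}

info = provenance("doc1/p0/c0")
# info == {"document": "doc1", "page": "p0", "chunk": "c0"}
```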


@ -322,8 +322,8 @@ class BulkClient:
# Generate document embeddings to import
def doc_embedding_generator():
    yield {"id": "doc1-chunk1", "embedding": [0.1, 0.2, ...]}
    yield {"id": "doc1-chunk2", "embedding": [0.3, 0.4, ...]}
    yield {"chunk_id": "doc1/p0/c0", "embedding": [0.1, 0.2, ...]}
    yield {"chunk_id": "doc1/p0/c1", "embedding": [0.3, 0.4, ...]}
    # ... more embeddings

bulk.import_document_embeddings(
@ -363,9 +363,9 @@ class BulkClient:
# Export and process document embeddings
for embedding in bulk.export_document_embeddings(flow="default"):
    doc_id = embedding.get("id")
    chunk_id = embedding.get("chunk_id")
    vector = embedding.get("embedding")
    print(f"{doc_id}: {len(vector)} dimensions")
    print(f"{chunk_id}: {len(vector)} dimensions")
```
"""
async_gen = self._export_document_embeddings_async(flow)


@ -634,7 +634,7 @@ class FlowInstance:
limit: Maximum number of results (default: 10)
Returns:
dict: Query results with similar document chunks
dict: Query results with chunk_ids of matching document chunks
Example:
```python
@ -645,6 +645,7 @@ class FlowInstance:
collection="research-papers",
limit=5
)
# results contains {"chunk_ids": ["doc1/p0/c0", "doc2/p1/c3", ...]}
```
"""


@ -682,7 +682,7 @@ class SocketFlowInstance:
**kwargs: Additional parameters passed to the service
Returns:
dict: Query results with similar document chunks
dict: Query results with chunk_ids of matching document chunks
Example:
```python
@ -695,6 +695,7 @@ class SocketFlowInstance:
collection="research-papers",
limit=5
)
# results contains {"chunk_ids": ["doc1/p0/c0", ...]}
```
"""
# First convert text to embeddings vectors


@ -27,7 +27,7 @@ class DocumentEmbeddingsClient(RequestResponse):
        if resp.error:
            raise RuntimeError(resp.error.message)
        return resp.chunks
        return resp.chunk_ids

class DocumentEmbeddingsClientSpec(RequestResponseSpec):
    def __init__(


@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
        docs = await self.query_document_embeddings(request)
        logger.debug("Sending document embeddings query response...")
        r = DocumentEmbeddingsResponse(chunks=docs, error=None)
        r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None)
        await flow("response").send(r, properties={"id": id})
        logger.debug("Document embeddings query request completed")
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
                type = "document-embeddings-query-error",
                message = str(e),
            ),
            chunks=None,
            chunk_ids=[],
        )
        await flow("response").send(r, properties={"id": id})


@ -144,15 +144,15 @@ class DocumentEmbeddingsTranslator(SendTranslator):
    def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddings:
        metadata = data.get("metadata", {})
        chunks = [
            ChunkEmbeddings(
                chunk=chunk["chunk"].encode("utf-8") if isinstance(chunk["chunk"], str) else chunk["chunk"],
                chunk_id=chunk["chunk_id"],
                vectors=chunk["vectors"]
            )
            for chunk in data.get("chunks", [])
        ]
        from ...schema import Metadata
        return DocumentEmbeddings(
            metadata=Metadata(
@ -168,7 +168,7 @@ class DocumentEmbeddingsTranslator(SendTranslator):
        result = {
            "chunks": [
                {
                    "chunk": chunk.chunk.decode("utf-8") if isinstance(chunk.chunk, bytes) else chunk.chunk,
                    "chunk_id": chunk.chunk_id,
                    "vectors": chunk.vectors
                }
                for chunk in obj.chunks


@ -36,13 +36,10 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
    def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
        result = {}
        if obj.chunks is not None:
            result["chunks"] = [
                chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
                for chunk in obj.chunks
            ]
        if obj.chunk_ids is not None:
            result["chunk_ids"] = list(obj.chunk_ids)
        return result

    def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:


@ -27,7 +27,7 @@ class GraphEmbeddings:
@dataclass
class ChunkEmbeddings:
    chunk: bytes = b""
    chunk_id: str = ""
    vectors: list[list[float]] = field(default_factory=list)

# This is a 'batching' mechanism for the above data


@ -52,7 +52,7 @@ class DocumentEmbeddingsRequest:
@dataclass
class DocumentEmbeddingsResponse:
    error: Error | None = None
    chunks: list[str] = field(default_factory=list)
    chunk_ids: list[str] = field(default_factory=list)

document_embeddings_request_queue = topic(
    "document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'


@ -1,6 +1,6 @@
"""
Queries document chunks by text similarity using vector embeddings.
Returns a list of matching document chunks, truncated to the specified length.
Returns a list of matching chunk IDs.
"""
import argparse
@ -10,13 +10,7 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
def truncate_chunk(chunk, max_length):
    """Truncate a chunk to max_length characters, adding ellipsis if needed."""
    if len(chunk) <= max_length:
        return chunk
    return chunk[:max_length] + "..."

def query(url, flow_id, query_text, user, collection, limit, max_chunk_length, token=None):
def query(url, flow_id, query_text, user, collection, limit, token=None):

    # Create API client
    api = Api(url=url, token=token)
@ -32,10 +26,12 @@ def query(url, flow_id, query_text, user, collection, limit, max_chunk_length, t
            limit=limit
        )

        chunks = result.get("chunks", [])
        for i, chunk in enumerate(chunks, 1):
            truncated = truncate_chunk(chunk, max_chunk_length)
            print(f"{i}. {truncated}")
        chunk_ids = result.get("chunk_ids", [])
        if not chunk_ids:
            print("No matching chunks found.")
        else:
            for i, chunk_id in enumerate(chunk_ids, 1):
                print(f"{i}. {chunk_id}")

    finally:
        # Clean up socket connection
@ -85,13 +81,6 @@ def main():
        help='Maximum number of results (default: 10)',
    )

    parser.add_argument(
        '--max-chunk-length',
        type=int,
        default=200,
        help='Truncate chunks to N characters (default: 200)',
    )

    parser.add_argument(
        'query',
        nargs=1,
@ -109,7 +98,6 @@ def main():
        user=args.user,
        collection=args.collection,
        limit=args.limit,
        max_chunk_length=args.max_chunk_length,
        token=args.token,
    )


@ -44,14 +44,14 @@ async def load_de(running, queue, url):
        msg = {
            "metadata": {
                "id": msg["m"]["i"],
                "metadata": msg["m"]["m"],
                "user": msg["m"]["u"],
                "collection": msg["m"]["c"],
            },
            "chunks": [
                {
                    "chunk": chunk["c"],
                    "chunk_id": chunk["c"],
                    "vectors": chunk["v"],
                }
                for chunk in msg["c"]


@ -50,14 +50,14 @@ async def fetch_de(running, queue, user, collection, url):
            "de",
            {
                "m": {
                    "i": data["metadata"]["id"],
                    "m": data["metadata"]["metadata"],
                    "u": data["metadata"]["user"],
                    "c": data["metadata"]["collection"],
                },
                "c": [
                    {
                        "c": chunk["chunk"],
                        "c": chunk["chunk_id"],
                        "v": chunk["vectors"],
                    }
                    for chunk in data["chunks"]


@ -84,14 +84,14 @@ class DocVectors:
            dim=dimension,
        )

        doc_field = FieldSchema(
            name="doc",
        chunk_id_field = FieldSchema(
            name="chunk_id",
            dtype=DataType.VARCHAR,
            max_length=65535,
        )

        schema = CollectionSchema(
            fields = [pkey_field, vec_field, doc_field],
            fields = [pkey_field, vec_field, chunk_id_field],
            description = "Document embedding schema",
        )
@ -119,17 +119,17 @@ class DocVectors:
        self.collections[(dimension, user, collection)] = collection_name
        logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")

    def insert(self, embeds, doc, user, collection):
    def insert(self, embeds, chunk_id, user, collection):

        dim = len(embeds)

        if (dim, user, collection) not in self.collections:
            self.init_collection(dim, user, collection)

        data = [
            {
                "vector": embeds,
                "doc": doc,
                "chunk_id": chunk_id,
            }
        ]
@ -138,7 +138,7 @@ class DocVectors:
            data=data
        )

    def search(self, embeds, user, collection, fields=["doc"], limit=10):
    def search(self, embeds, user, collection, fields=["chunk_id"], limit=10):

        dim = len(embeds)


@ -70,7 +70,7 @@ class Processor(FlowProcessor):
        embeds = [
            ChunkEmbeddings(
                chunk=v.chunk,
                chunk_id=v.document_id,
                vectors=vectors,
            )
        ]


@ -89,7 +89,7 @@ def serialize_document_embeddings(message):
        "chunks": [
            {
                "vectors": chunk.vectors,
                "chunk": chunk.chunk.decode("utf-8"),
                "chunk_id": chunk.chunk_id,
            }
            for chunk in message.chunks
        ],


@ -1,7 +1,7 @@
"""
Document embeddings query service. Input is vector, output is an array
of chunks
of chunk_ids
"""
import logging
@ -39,22 +39,22 @@ class Processor(DocumentEmbeddingsQueryService):
        if msg.limit <= 0:
            return []

        chunks = []
        chunk_ids = []

        for vec in msg.vectors:

            resp = self.vecstore.search(
                vec,
                msg.user,
                msg.collection,
                limit=msg.limit
            )

            for r in resp:
                chunk = r["entity"]["doc"]
                chunks.append(chunk)
                chunk_id = r["entity"]["chunk_id"]
                chunk_ids.append(chunk_id)

        return chunks
        return chunk_ids
except Exception as e:


@ -1,7 +1,7 @@
"""
Document embeddings query service. Input is vector, output is an array
of chunks. Pinecone implementation.
of chunk_ids. Pinecone implementation.
"""
import logging
@ -55,7 +55,7 @@ class Processor(DocumentEmbeddingsQueryService):
        if msg.limit <= 0:
            return []

        chunks = []
        chunk_ids = []

        for vec in msg.vectors:
@ -79,10 +79,10 @@ class Processor(DocumentEmbeddingsQueryService):
            )

            for r in results.matches:
                doc = r.metadata["doc"]
                chunks.append(doc)
                chunk_id = r.metadata["chunk_id"]
                chunk_ids.append(chunk_id)

        return chunks
        return chunk_ids
except Exception as e:


@ -1,7 +1,7 @@
"""
Document embeddings query service. Input is vector, output is an array
of chunks
of chunk_ids
"""
import logging
@ -69,7 +69,7 @@ class Processor(DocumentEmbeddingsQueryService):
        try:
            chunks = []
            chunk_ids = []

            for vec in msg.vectors:
@ -90,10 +90,10 @@ class Processor(DocumentEmbeddingsQueryService):
            ).points

            for r in search_result:
                ent = r.payload["doc"]
                chunks.append(ent)
                chunk_id = r.payload["chunk_id"]
                chunk_ids.append(chunk_id)

            return chunks
            return chunk_ids
except Exception as e:


@ -24,7 +24,7 @@ class Query:
        if self.verbose:
            logger.debug("Computing embeddings...")

        qembeds = await self.rag.embeddings_client.embed(query)

        if self.verbose:
            logger.debug("Embeddings computed")
@ -36,17 +36,31 @@ class Query:
        vectors = await self.get_vector(query)

        if self.verbose:
            logger.debug("Getting documents...")
            logger.debug("Getting chunk_ids from embeddings store...")

        docs = await self.rag.doc_embeddings_client.query(
        # Get chunk_ids from embeddings store
        chunk_ids = await self.rag.doc_embeddings_client.query(
            vectors, limit=self.doc_limit,
            user=self.user, collection=self.collection,
        )

        if self.verbose:
            logger.debug("Documents:")
            logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...")

        # Fetch chunk content from Garage
        docs = []
        for chunk_id in chunk_ids:
            if chunk_id:
                try:
                    content = await self.rag.fetch_chunk(chunk_id, self.user)
                    docs.append(content)
                except Exception as e:
                    logger.warning(f"Failed to fetch chunk {chunk_id}: {e}")

        if self.verbose:
            logger.debug("Documents fetched:")
            for doc in docs:
                logger.debug(f" {doc}")
                logger.debug(f" {doc[:100]}...")

        return docs
@ -54,6 +68,7 @@ class DocumentRag:
    def __init__(
        self, prompt_client, embeddings_client, doc_embeddings_client,
        fetch_chunk,
        verbose=False,
    ):
@ -62,6 +77,7 @@ class DocumentRag:
        self.prompt_client = prompt_client
        self.embeddings_client = embeddings_client
        self.doc_embeddings_client = doc_embeddings_client
        self.fetch_chunk = fetch_chunk

        if self.verbose:
            logger.debug("DocumentRag initialized")


@ -4,17 +4,26 @@ Simple RAG service, performs query using document RAG an LLM.
Input is query, output is response.
"""
import asyncio
import base64
import logging
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
from ... schema import LibrarianRequest, LibrarianResponse
from ... schema import librarian_request_queue, librarian_response_queue
from . document_rag import DocumentRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import DocumentEmbeddingsClientSpec
from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "document-rag"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor):
@ -69,6 +78,98 @@ class Processor(FlowProcessor):
            )
        )

        # Librarian client for fetching chunk content from Garage
        librarian_request_q = params.get(
            "librarian_request_queue", default_librarian_request_queue
        )
        librarian_response_q = params.get(
            "librarian_response_queue", default_librarian_response_queue
        )

        librarian_request_metrics = ProducerMetrics(
            processor=id, flow=None, name="librarian-request"
        )
        self.librarian_request_producer = Producer(
            backend=self.pubsub,
            topic=librarian_request_q,
            schema=LibrarianRequest,
            metrics=librarian_request_metrics,
        )

        librarian_response_metrics = ConsumerMetrics(
            processor=id, flow=None, name="librarian-response"
        )
        self.librarian_response_consumer = Consumer(
            taskgroup=self.taskgroup,
            backend=self.pubsub,
            flow=None,
            topic=librarian_response_q,
            subscriber=f"{id}-librarian",
            schema=LibrarianResponse,
            handler=self.on_librarian_response,
            metrics=librarian_response_metrics,
        )

        # Pending librarian requests: request_id -> asyncio.Future
        self.pending_requests = {}

    async def start(self):
        await super(Processor, self).start()
        await self.librarian_request_producer.start()
        await self.librarian_response_consumer.start()

    async def on_librarian_response(self, msg, consumer, flow):
        """Handle responses from the librarian service."""
        response = msg.value()
        request_id = msg.properties().get("id")
        if request_id in self.pending_requests:
            future = self.pending_requests.pop(request_id)
            future.set_result(response)
        else:
            logger.warning(f"Received unexpected librarian response: {request_id}")

    async def fetch_chunk_content(self, chunk_id, user, timeout=120):
        """Fetch chunk content from librarian/Garage."""
        import uuid
        request_id = str(uuid.uuid4())

        request = LibrarianRequest(
            operation="get-document-content",
            document_id=chunk_id,
            user=user,
        )

        # Create future for response
        future = asyncio.get_event_loop().create_future()
        self.pending_requests[request_id] = future

        try:
            # Send request
            await self.librarian_request_producer.send(
                request, properties={"id": request_id}
            )

            # Wait for response
            response = await asyncio.wait_for(future, timeout=timeout)

            if response.error:
                raise RuntimeError(
                    f"Librarian error: {response.error.type}: {response.error.message}"
                )

            # Content is base64 encoded
            content = response.content
            if isinstance(content, str):
                content = content.encode('utf-8')
            return base64.b64decode(content).decode("utf-8")

        except asyncio.TimeoutError:
            self.pending_requests.pop(request_id, None)
            raise RuntimeError(f"Timeout fetching chunk {chunk_id}")
    async def on_request(self, msg, consumer, flow):

        try:
@ -77,6 +178,7 @@ class Processor(FlowProcessor):
            embeddings_client = flow("embeddings-request"),
            doc_embeddings_client = flow("document-embeddings-request"),
            prompt_client = flow("prompt-request"),
            fetch_chunk = self.fetch_chunk_content,
            verbose=True,
        )


@ -37,14 +37,13 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
        for emb in message.chunks:

            if emb.chunk is None or emb.chunk == b"": continue
            chunk = emb.chunk.decode("utf-8")
            if chunk == "": continue
            chunk_id = emb.chunk_id
            if chunk_id == "":
                continue

            for vec in emb.vectors:
                self.vecstore.insert(
                    vec, chunk,
                    vec, chunk_id,
                    message.metadata.user,
                    message.metadata.collection
                )


@ -101,10 +101,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
        for emb in message.chunks:

            if emb.chunk is None or emb.chunk == b"": continue
            chunk = emb.chunk.decode("utf-8")
            if chunk == "": continue
            chunk_id = emb.chunk_id
            if chunk_id == "":
                continue

            for vec in emb.vectors:
@ -128,7 +127,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
                {
                    "id": vector_id,
                    "values": vec,
                    "metadata": { "doc": chunk },
                    "metadata": { "chunk_id": chunk_id },
                }
            ]


@ -52,8 +52,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
        for emb in message.chunks:

            chunk = emb.chunk.decode("utf-8")
            if chunk == "": return
            chunk_id = emb.chunk_id
            if chunk_id == "":
                continue

            for vec in emb.vectors:
@ -81,7 +82,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
                    id=str(uuid.uuid4()),
                    vector=vec,
                    payload={
                        "doc": chunk,
                        "chunk_id": chunk_id,
                    }
                )
            ]