mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-21 20:58:06 +02:00
Document chunks not stored in vector store (#665)
- Schema - ChunkEmbeddings now uses chunk_id: str instead of chunk: bytes
- Schema - DocumentEmbeddingsResponse now returns chunk_ids: list[str]
instead of chunks
- Translators - Updated to serialize/deserialize chunk_id
- Clients - DocumentEmbeddingsClient.query() returns chunk_ids
- SDK/API - flow.py, socket_client.py, bulk_client.py updated
- Document embeddings service - Stores chunk_id (document ID) instead
of chunk text
- Storage writers - Qdrant, Milvus, Pinecone store chunk_id in payload
- Query services - Return chunk_id from vector store searches
- Gateway dispatchers - Serialize chunk_id in API responses
- Document RAG - Added librarian client to fetch chunk content from
Garage using chunk_ids
- CLI tools - Updated all three tools:
- invoke_document_embeddings.py - displays chunk_ids, removed
max_chunk_length
- save_doc_embeds.py - exports chunk_id
- load_doc_embeds.py - imports chunk_id
This commit is contained in:
parent
be358efe67
commit
24bbe94136
24 changed files with 331 additions and 91 deletions
|
|
@ -84,14 +84,14 @@ class DocVectors:
|
|||
dim=dimension,
|
||||
)
|
||||
|
||||
doc_field = FieldSchema(
|
||||
name="doc",
|
||||
chunk_id_field = FieldSchema(
|
||||
name="chunk_id",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields = [pkey_field, vec_field, doc_field],
|
||||
fields = [pkey_field, vec_field, chunk_id_field],
|
||||
description = "Document embedding schema",
|
||||
)
|
||||
|
||||
|
|
@ -119,17 +119,17 @@ class DocVectors:
|
|||
self.collections[(dimension, user, collection)] = collection_name
|
||||
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
|
||||
|
||||
def insert(self, embeds, doc, user, collection):
|
||||
def insert(self, embeds, chunk_id, user, collection):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if (dim, user, collection) not in self.collections:
|
||||
self.init_collection(dim, user, collection)
|
||||
|
||||
|
||||
data = [
|
||||
{
|
||||
"vector": embeds,
|
||||
"doc": doc,
|
||||
"chunk_id": chunk_id,
|
||||
}
|
||||
]
|
||||
|
||||
|
|
@ -138,7 +138,7 @@ class DocVectors:
|
|||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, user, collection, fields=["doc"], limit=10):
|
||||
def search(self, embeds, user, collection, fields=["chunk_id"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class Processor(FlowProcessor):
|
|||
|
||||
embeds = [
|
||||
ChunkEmbeddings(
|
||||
chunk=v.chunk,
|
||||
chunk_id=v.document_id,
|
||||
vectors=vectors,
|
||||
)
|
||||
]
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ def serialize_document_embeddings(message):
|
|||
"chunks": [
|
||||
{
|
||||
"vectors": chunk.vectors,
|
||||
"chunk": chunk.chunk.decode("utf-8"),
|
||||
"chunk_id": chunk.chunk_id,
|
||||
}
|
||||
for chunk in message.chunks
|
||||
],
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
"""
|
||||
Document embeddings query service. Input is vector, output is an array
|
||||
of chunks
|
||||
of chunk_ids
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -39,22 +39,22 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
chunk_ids = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
resp = self.vecstore.search(
|
||||
vec,
|
||||
msg.user,
|
||||
msg.collection,
|
||||
vec,
|
||||
msg.user,
|
||||
msg.collection,
|
||||
limit=msg.limit
|
||||
)
|
||||
|
||||
for r in resp:
|
||||
chunk = r["entity"]["doc"]
|
||||
chunks.append(chunk)
|
||||
chunk_id = r["entity"]["chunk_id"]
|
||||
chunk_ids.append(chunk_id)
|
||||
|
||||
return chunks
|
||||
return chunk_ids
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
"""
|
||||
Document embeddings query service. Input is vector, output is an array
|
||||
of chunks. Pinecone implementation.
|
||||
of chunk_ids. Pinecone implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -55,7 +55,7 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
chunk_ids = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
|
|
@ -79,10 +79,10 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
)
|
||||
|
||||
for r in results.matches:
|
||||
doc = r.metadata["doc"]
|
||||
chunks.append(doc)
|
||||
chunk_id = r.metadata["chunk_id"]
|
||||
chunk_ids.append(chunk_id)
|
||||
|
||||
return chunks
|
||||
return chunk_ids
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
"""
|
||||
Document embeddings query service. Input is vector, output is an array
|
||||
of chunks
|
||||
of chunk_ids
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -69,7 +69,7 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
|
||||
try:
|
||||
|
||||
chunks = []
|
||||
chunk_ids = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
|
|
@ -90,10 +90,10 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
).points
|
||||
|
||||
for r in search_result:
|
||||
ent = r.payload["doc"]
|
||||
chunks.append(ent)
|
||||
chunk_id = r.payload["chunk_id"]
|
||||
chunk_ids.append(chunk_id)
|
||||
|
||||
return chunks
|
||||
return chunk_ids
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class Query:
|
|||
if self.verbose:
|
||||
logger.debug("Computing embeddings...")
|
||||
|
||||
qembeds = await self.rag.embeddings_client.embed(query)
|
||||
qembeds = await self.rag.embeddings_client.embed(query)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Embeddings computed")
|
||||
|
|
@ -36,17 +36,31 @@ class Query:
|
|||
vectors = await self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Getting documents...")
|
||||
logger.debug("Getting chunk_ids from embeddings store...")
|
||||
|
||||
docs = await self.rag.doc_embeddings_client.query(
|
||||
# Get chunk_ids from embeddings store
|
||||
chunk_ids = await self.rag.doc_embeddings_client.query(
|
||||
vectors, limit=self.doc_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Documents:")
|
||||
logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...")
|
||||
|
||||
# Fetch chunk content from Garage
|
||||
docs = []
|
||||
for chunk_id in chunk_ids:
|
||||
if chunk_id:
|
||||
try:
|
||||
content = await self.rag.fetch_chunk(chunk_id, self.user)
|
||||
docs.append(content)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch chunk {chunk_id}: {e}")
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Documents fetched:")
|
||||
for doc in docs:
|
||||
logger.debug(f" {doc}")
|
||||
logger.debug(f" {doc[:100]}...")
|
||||
|
||||
return docs
|
||||
|
||||
|
|
@ -54,6 +68,7 @@ class DocumentRag:
|
|||
|
||||
def __init__(
|
||||
self, prompt_client, embeddings_client, doc_embeddings_client,
|
||||
fetch_chunk,
|
||||
verbose=False,
|
||||
):
|
||||
|
||||
|
|
@ -62,6 +77,7 @@ class DocumentRag:
|
|||
self.prompt_client = prompt_client
|
||||
self.embeddings_client = embeddings_client
|
||||
self.doc_embeddings_client = doc_embeddings_client
|
||||
self.fetch_chunk = fetch_chunk
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("DocumentRag initialized")
|
||||
|
|
|
|||
|
|
@ -4,17 +4,26 @@ Simple RAG service, performs query using document RAG and an LLM.
|
|||
Input is query, output is response.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
|
||||
from ... schema import LibrarianRequest, LibrarianResponse
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
from . document_rag import DocumentRag
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import DocumentEmbeddingsClientSpec
|
||||
from ... base import Consumer, Producer
|
||||
from ... base import ConsumerMetrics, ProducerMetrics
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "document-rag"
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
|
|
@ -69,6 +78,98 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
# Librarian client for fetching chunk content from Garage
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor=id, flow=None, name="librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
backend=self.pubsub,
|
||||
topic=librarian_request_q,
|
||||
schema=LibrarianRequest,
|
||||
metrics=librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor=id, flow=None, name="librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
backend=self.pubsub,
|
||||
flow=None,
|
||||
topic=librarian_response_q,
|
||||
subscriber=f"{id}-librarian",
|
||||
schema=LibrarianResponse,
|
||||
handler=self.on_librarian_response,
|
||||
metrics=librarian_response_metrics,
|
||||
)
|
||||
|
||||
# Pending librarian requests: request_id -> asyncio.Future
|
||||
self.pending_requests = {}
|
||||
|
||||
async def start(self):
    """Start the base processor, then the librarian producer and consumer.

    The request producer is started before the response consumer, in the
    same order as the attributes are listed below.
    """
    await super(Processor, self).start()

    # Bring up the librarian request/response plumbing in a fixed order.
    for attr in ("librarian_request_producer", "librarian_response_consumer"):
        await getattr(self, attr).start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
    """Handle a response from the librarian service.

    The message's ``id`` property is used to look up the future that
    ``fetch_chunk_content`` registered in ``self.pending_requests``; the
    entry is removed and the future is resolved with the response payload.
    Responses whose id has no pending entry are logged and dropped.

    Args:
        msg: Incoming message; ``msg.value()`` is the response and
            ``msg.properties()`` carries the correlation ``id``.
        consumer: Unused here.
        flow: Unused here.
    """
    response = msg.value()
    request_id = msg.properties().get("id")

    # Single lookup: remove the entry and learn whether it existed.
    future = self.pending_requests.pop(request_id, None)

    if future is None:
        logger.warning(f"Received unexpected librarian response: {request_id}")
        return

    # The waiter may already have been cancelled (e.g. asyncio.wait_for
    # cancels the future on timeout).  Calling set_result() on a finished
    # future raises InvalidStateError, so late responses are dropped.
    if not future.done():
        future.set_result(response)
|
||||
|
||||
async def fetch_chunk_content(self, chunk_id, user, timeout=120):
    """Fetch chunk content from librarian/Garage.

    Sends a ``get-document-content`` request for ``chunk_id`` and waits
    for the matching response, which ``on_librarian_response`` delivers
    through a future stored in ``self.pending_requests``.

    Args:
        chunk_id: Identifier passed as the request's ``document_id``.
        user: User passed through on the request.
        timeout: Seconds to wait for the response (default 120).

    Returns:
        The response content, base64-decoded and then decoded as UTF-8.

    Raises:
        RuntimeError: If the librarian reports an error or no response
            arrives within ``timeout`` seconds.
    """
    import uuid

    request_id = str(uuid.uuid4())

    request = LibrarianRequest(
        operation="get-document-content",
        document_id=chunk_id,
        user=user,
    )

    # Register the future before sending so a fast response cannot arrive
    # ahead of its pending entry.  We are inside a coroutine, so the
    # running loop is the right one to create it on.
    future = asyncio.get_running_loop().create_future()
    self.pending_requests[request_id] = future

    try:
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )
        response = await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        raise RuntimeError(f"Timeout fetching chunk {chunk_id}")
    finally:
        # Always drop the pending entry: a failed send, a timeout or a
        # cancelled caller would otherwise leave it behind forever.  After
        # a normal response this is a no-op (the handler already popped it).
        self.pending_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error: {response.error.type}: {response.error.message}"
        )

    # Content is base64 encoded
    content = response.content
    if isinstance(content, str):
        content = content.encode('utf-8')
    return base64.b64decode(content).decode("utf-8")
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
|
||||
try:
|
||||
|
|
@ -77,6 +178,7 @@ class Processor(FlowProcessor):
|
|||
embeddings_client = flow("embeddings-request"),
|
||||
doc_embeddings_client = flow("document-embeddings-request"),
|
||||
prompt_client = flow("prompt-request"),
|
||||
fetch_chunk = self.fetch_chunk_content,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,14 +37,13 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
|
||||
for emb in message.chunks:
|
||||
|
||||
if emb.chunk is None or emb.chunk == b"": continue
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": continue
|
||||
chunk_id = emb.chunk_id
|
||||
if chunk_id == "":
|
||||
continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
self.vecstore.insert(
|
||||
vec, chunk,
|
||||
vec, chunk_id,
|
||||
message.metadata.user,
|
||||
message.metadata.collection
|
||||
)
|
||||
|
|
|
|||
|
|
@ -101,10 +101,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
|
||||
for emb in message.chunks:
|
||||
|
||||
if emb.chunk is None or emb.chunk == b"": continue
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": continue
|
||||
chunk_id = emb.chunk_id
|
||||
if chunk_id == "":
|
||||
continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
|
||||
|
|
@ -128,7 +127,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
{
|
||||
"id": vector_id,
|
||||
"values": vec,
|
||||
"metadata": { "doc": chunk },
|
||||
"metadata": { "chunk_id": chunk_id },
|
||||
}
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -52,8 +52,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
|
||||
for emb in message.chunks:
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": return
|
||||
chunk_id = emb.chunk_id
|
||||
if chunk_id == "":
|
||||
continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
|
||||
|
|
@ -81,7 +82,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
id=str(uuid.uuid4()),
|
||||
vector=vec,
|
||||
payload={
|
||||
"doc": chunk,
|
||||
"chunk_id": chunk_id,
|
||||
}
|
||||
)
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue