diff --git a/docs/tech-specs/document-embeddings-chunk-id.md b/docs/tech-specs/document-embeddings-chunk-id.md new file mode 100644 index 00000000..157ecba0 --- /dev/null +++ b/docs/tech-specs/document-embeddings-chunk-id.md @@ -0,0 +1,136 @@ +# Document Embeddings Chunk ID + +## Overview + +Document embeddings storage currently stores chunk text directly in the vector store payload, duplicating data that exists in Garage. This spec replaces chunk text storage with `chunk_id` references. + +## Current State + +```python +@dataclass +class ChunkEmbeddings: + chunk: bytes = b"" + vectors: list[list[float]] = field(default_factory=list) + +@dataclass +class DocumentEmbeddingsResponse: + error: Error | None = None + chunks: list[str] = field(default_factory=list) +``` + +Vector store payload: +```python +payload={"doc": chunk} # Duplicates Garage content +``` + +## Design + +### Schema Changes + +**ChunkEmbeddings** - replace chunk with chunk_id: +```python +@dataclass +class ChunkEmbeddings: + chunk_id: str = "" + vectors: list[list[float]] = field(default_factory=list) +``` + +**DocumentEmbeddingsResponse** - return chunk_ids instead of chunks: +```python +@dataclass +class DocumentEmbeddingsResponse: + error: Error | None = None + chunk_ids: list[str] = field(default_factory=list) +``` + +### Vector Store Payload + +All stores (Qdrant, Milvus, Pinecone): +```python +payload={"chunk_id": chunk_id} +``` + +### Document RAG Changes + +The document RAG processor fetches chunk content from Garage: + +```python +# Get chunk_ids from embeddings store +chunk_ids = await self.rag.doc_embeddings_client.query(...) + +# Fetch chunk content from Garage +docs = [] +for chunk_id in chunk_ids: + content = await self.rag.librarian_client.get_document_content( + chunk_id, self.user + ) + docs.append(content) +``` + +### API/SDK Changes + +**DocumentEmbeddingsClient** returns chunk_ids: +```python +return resp.chunk_ids # Changed from resp.chunks +``` + +**Wire format** (DocumentEmbeddingsResponseTranslator): +```python +result["chunk_ids"] = obj.chunk_ids # Changed from chunks +``` + +### CLI Changes + +CLI tool displays chunk_ids (callers can fetch content separately if needed). + +## Files to Modify + +### Schema +- `trustgraph-base/trustgraph/schema/knowledge/embeddings.py` - ChunkEmbeddings +- `trustgraph-base/trustgraph/schema/services/query.py` - DocumentEmbeddingsResponse + +### Messaging/Translators +- `trustgraph-base/trustgraph/messaging/translators/embeddings_query.py` - DocumentEmbeddingsResponseTranslator + +### Client +- `trustgraph-base/trustgraph/base/document_embeddings_client.py` - return chunk_ids + +### Python SDK/API +- `trustgraph-base/trustgraph/api/flow.py` - document_embeddings_query +- `trustgraph-base/trustgraph/api/socket_client.py` - document_embeddings_query +- `trustgraph-base/trustgraph/api/async_flow.py` - if applicable +- `trustgraph-base/trustgraph/api/bulk_client.py` - import/export document embeddings +- `trustgraph-base/trustgraph/api/async_bulk_client.py` - import/export document embeddings + +### Embeddings Service +- `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py` - pass chunk_id + +### Storage Writers +- `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py` +- `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py` +- `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py` + +### Query Services +- `trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py` +- `trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py` +- `trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py` + +### Gateway +- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py` +- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py` +- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py` + +### Document RAG +- `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` - add librarian client +- `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` - fetch from Garage + +### CLI +- `trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py` +- `trustgraph-cli/trustgraph/cli/save_doc_embeds.py` +- `trustgraph-cli/trustgraph/cli/load_doc_embeds.py` + +## Benefits + +1. Single source of truth - chunk text only in Garage +2. Reduced vector store storage +3. Enables query-time provenance via chunk_id diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py index 3dfb0fba..75999550 100644 --- a/trustgraph-base/trustgraph/api/bulk_client.py +++ b/trustgraph-base/trustgraph/api/bulk_client.py @@ -322,8 +322,8 @@ class BulkClient: # Generate document embeddings to import def doc_embedding_generator(): - yield {"id": "doc1-chunk1", "embedding": [0.1, 0.2, ...]} - yield {"id": "doc1-chunk2", "embedding": [0.3, 0.4, ...]} + yield {"chunk_id": "doc1/p0/c0", "embedding": [0.1, 0.2, ...]} + yield {"chunk_id": "doc1/p0/c1", "embedding": [0.3, 0.4, ...]} # ... more embeddings bulk.import_document_embeddings( @@ -363,9 +363,9 @@ class BulkClient: # Export and process document embeddings for embedding in bulk.export_document_embeddings(flow="default"): - doc_id = embedding.get("id") + chunk_id = embedding.get("chunk_id") vector = embedding.get("embedding") - print(f"{doc_id}: {len(vector)} dimensions") + print(f"{chunk_id}: {len(vector)} dimensions") ``` """ async_gen = self._export_document_embeddings_async(flow) diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index 0d34104c..c50bf9c4 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -634,7 +634,7 @@ class FlowInstance: limit: Maximum number of results (default: 10) Returns: - dict: Query results with similar document chunks + dict: Query results with chunk_ids of matching document chunks Example: ```python @@ -645,6 +645,7 @@ class FlowInstance: collection="research-papers", limit=5 ) + # results contains {"chunk_ids": ["doc1/p0/c0", "doc2/p1/c3", ...]} ``` """ diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index e8de442a..b471b535 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -682,7 +682,7 @@ class SocketFlowInstance: **kwargs: Additional parameters passed to the service Returns: - dict: Query results with similar document chunks + dict: Query results with chunk_ids of matching document chunks Example: ```python @@ -695,6 +695,7 @@ class SocketFlowInstance: collection="research-papers", limit=5 ) + # results contains {"chunk_ids": ["doc1/p0/c0", ...]} ``` """ # First convert text to embeddings vectors diff --git a/trustgraph-base/trustgraph/base/document_embeddings_client.py b/trustgraph-base/trustgraph/base/document_embeddings_client.py index e76a6da6..d403ff21 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_client.py @@ -27,7 +27,7 @@ class DocumentEmbeddingsClient(RequestResponse): if resp.error: raise RuntimeError(resp.error.message) - return resp.chunks + return resp.chunk_ids class DocumentEmbeddingsClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py index f04f2c60..013847d4 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py @@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor): docs = await self.query_document_embeddings(request) logger.debug("Sending document embeddings query response...") - r = DocumentEmbeddingsResponse(chunks=docs, error=None) + r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None) await flow("response").send(r, properties={"id": id}) logger.debug("Document embeddings query request completed") @@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor): type = "document-embeddings-query-error", message = str(e), ), - chunks=None, + chunk_ids=[], ) await flow("response").send(r, properties={"id": id}) diff --git a/trustgraph-base/trustgraph/messaging/translators/document_loading.py b/trustgraph-base/trustgraph/messaging/translators/document_loading.py index 3dfef718..1aaea6ac 100644 --- a/trustgraph-base/trustgraph/messaging/translators/document_loading.py +++ b/trustgraph-base/trustgraph/messaging/translators/document_loading.py @@ -144,15 +144,15 @@ class DocumentEmbeddingsTranslator(SendTranslator): def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddings: metadata = data.get("metadata", {}) - + chunks = [ ChunkEmbeddings( - chunk=chunk["chunk"].encode("utf-8") if isinstance(chunk["chunk"], str) else chunk["chunk"], + chunk_id=chunk["chunk_id"], vectors=chunk["vectors"] ) for chunk in data.get("chunks", []) ] - + from ...schema import Metadata return DocumentEmbeddings( metadata=Metadata( @@ -168,7 +168,7 @@ class DocumentEmbeddingsTranslator(SendTranslator): result = { "chunks": [ { - "chunk": chunk.chunk.decode("utf-8") if isinstance(chunk.chunk, bytes) else chunk.chunk, + "chunk_id": chunk.chunk_id, "vectors": chunk.vectors } for chunk in obj.chunks diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py index 141a7330..cc5f1534 100644 --- a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py @@ -36,13 +36,10 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator): def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]: result = {} - - if obj.chunks is not None: - result["chunks"] = [ - chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk - for chunk in obj.chunks - ] - + + if obj.chunk_ids is not None: + result["chunk_ids"] = list(obj.chunk_ids) + return result def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: diff --git a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py index b39bf6ea..c7d5b775 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py +++ b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py @@ -27,7 +27,7 @@ class GraphEmbeddings: @dataclass class ChunkEmbeddings: - chunk: bytes = b"" + chunk_id: str = "" vectors: list[list[float]] = field(default_factory=list) # This is a 'batching' mechanism for the above data diff --git a/trustgraph-base/trustgraph/schema/services/query.py b/trustgraph-base/trustgraph/schema/services/query.py index 50ec416a..68857e07 100644 --- a/trustgraph-base/trustgraph/schema/services/query.py +++ b/trustgraph-base/trustgraph/schema/services/query.py @@ -52,7 +52,7 @@ class DocumentEmbeddingsRequest: @dataclass class DocumentEmbeddingsResponse: error: Error | None = None - chunks: list[str] = field(default_factory=list) + chunk_ids: list[str] = field(default_factory=list) document_embeddings_request_queue = topic( "document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow' diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py index b14397cb..b3eef8a6 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py @@ -1,6 +1,6 @@ """ Queries document chunks by text similarity using vector embeddings. -Returns a list of matching document chunks, truncated to the specified length. +Returns a list of matching chunk IDs. """ import argparse @@ -10,13 +10,7 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def truncate_chunk(chunk, max_length): - """Truncate a chunk to max_length characters, adding ellipsis if needed.""" - if len(chunk) <= max_length: - return chunk - return chunk[:max_length] + "..." - -def query(url, flow_id, query_text, user, collection, limit, max_chunk_length, token=None): +def query(url, flow_id, query_text, user, collection, limit, token=None): # Create API client api = Api(url=url, token=token) @@ -32,10 +26,12 @@ def query(url, flow_id, query_text, user, collection, limit, max_chunk_length, t limit=limit ) - chunks = result.get("chunks", []) - for i, chunk in enumerate(chunks, 1): - truncated = truncate_chunk(chunk, max_chunk_length) - print(f"{i}. {truncated}") + chunk_ids = result.get("chunk_ids", []) + if not chunk_ids: + print("No matching chunks found.") + else: + for i, chunk_id in enumerate(chunk_ids, 1): + print(f"{i}. {chunk_id}") finally: # Clean up socket connection @@ -85,13 +81,6 @@ def main(): help='Maximum number of results (default: 10)', ) - parser.add_argument( - '--max-chunk-length', - type=int, - default=200, - help='Truncate chunks to N characters (default: 200)', - ) - parser.add_argument( 'query', nargs=1, @@ -109,7 +98,6 @@ def main(): user=args.user, collection=args.collection, limit=args.limit, - max_chunk_length=args.max_chunk_length, token=args.token, ) diff --git a/trustgraph-cli/trustgraph/cli/load_doc_embeds.py b/trustgraph-cli/trustgraph/cli/load_doc_embeds.py index 7e7f4865..20c78515 100644 --- a/trustgraph-cli/trustgraph/cli/load_doc_embeds.py +++ b/trustgraph-cli/trustgraph/cli/load_doc_embeds.py @@ -44,14 +44,14 @@ async def load_de(running, queue, url): msg = { "metadata": { - "id": msg["m"]["i"], + "id": msg["m"]["i"], "metadata": msg["m"]["m"], "user": msg["m"]["u"], "collection": msg["m"]["c"], }, "chunks": [ { - "chunk": chunk["c"], + "chunk_id": chunk["c"], "vectors": chunk["v"], } for chunk in msg["c"] diff --git a/trustgraph-cli/trustgraph/cli/save_doc_embeds.py b/trustgraph-cli/trustgraph/cli/save_doc_embeds.py index 8fdd335d..ca8d25de 100644 --- a/trustgraph-cli/trustgraph/cli/save_doc_embeds.py +++ b/trustgraph-cli/trustgraph/cli/save_doc_embeds.py @@ -50,14 +50,14 @@ async def fetch_de(running, queue, user, collection, url): "de", { "m": { - "i": data["metadata"]["id"], + "i": data["metadata"]["id"], "m": data["metadata"]["metadata"], "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "c": [ { - "c": chunk["chunk"], + "c": chunk["chunk_id"], "v": chunk["vectors"], } for chunk in data["chunks"] diff --git a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py index 4047a9e3..66bfe31f 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py @@ -84,14 +84,14 @@ class DocVectors: dim=dimension, ) - doc_field = FieldSchema( - name="doc", + chunk_id_field = FieldSchema( + name="chunk_id", dtype=DataType.VARCHAR, max_length=65535, ) schema = CollectionSchema( - fields = [pkey_field, vec_field, doc_field], + fields = [pkey_field, vec_field, chunk_id_field], description = "Document embedding schema", ) @@ -119,17 +119,17 @@ class DocVectors: self.collections[(dimension, user, collection)] = collection_name logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") - def insert(self, embeds, doc, user, collection): + def insert(self, embeds, chunk_id, user, collection): dim = len(embeds) if (dim, user, collection) not in self.collections: self.init_collection(dim, user, collection) - + data = [ { "vector": embeds, - "doc": doc, + "chunk_id": chunk_id, } ] @@ -138,7 +138,7 @@ class DocVectors: data=data ) - def search(self, embeds, user, collection, fields=["doc"], limit=10): + def search(self, embeds, user, collection, fields=["chunk_id"], limit=10): dim = len(embeds) diff --git a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py index 602f7bb8..032e15c4 100755 --- a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py @@ -70,7 +70,7 @@ class Processor(FlowProcessor): embeds = [ ChunkEmbeddings( - chunk=v.chunk, + chunk_id=v.document_id, vectors=vectors, ) ] diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py index 8f1cdece..f6e7c79b 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py @@ -89,7 +89,7 @@ def serialize_document_embeddings(message): "chunks": [ { "vectors": chunk.vectors, - "chunk": chunk.chunk.decode("utf-8"), + "chunk_id": chunk.chunk_id, } for chunk in message.chunks ], diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index 03c98ad3..6d897b71 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -1,7 +1,7 @@ """ Document embeddings query service. Input is vector, output is an array -of chunks +of chunk_ids """ import logging @@ -39,22 +39,22 @@ class Processor(DocumentEmbeddingsQueryService): if msg.limit <= 0: return [] - chunks = [] + chunk_ids = [] for vec in msg.vectors: resp = self.vecstore.search( - vec, - msg.user, - msg.collection, + vec, + msg.user, + msg.collection, limit=msg.limit ) for r in resp: - chunk = r["entity"]["doc"] - chunks.append(chunk) + chunk_id = r["entity"]["chunk_id"] + chunk_ids.append(chunk_id) - return chunks + return chunk_ids except Exception as e: diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index 1c3f8d1b..41857ab0 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -1,7 +1,7 @@ """ Document embeddings query service. Input is vector, output is an array -of chunks. Pinecone implementation. +of chunk_ids. Pinecone implementation. """ import logging @@ -55,7 +55,7 @@ class Processor(DocumentEmbeddingsQueryService): if msg.limit <= 0: return [] - chunks = [] + chunk_ids = [] for vec in msg.vectors: @@ -79,10 +79,10 @@ class Processor(DocumentEmbeddingsQueryService): ) for r in results.matches: - doc = r.metadata["doc"] - chunks.append(doc) + chunk_id = r.metadata["chunk_id"] + chunk_ids.append(chunk_id) - return chunks + return chunk_ids except Exception as e: diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index e84372cb..562023c7 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -1,7 +1,7 @@ """ Document embeddings query service. Input is vector, output is an array -of chunks +of chunk_ids """ import logging @@ -69,7 +69,7 @@ class Processor(DocumentEmbeddingsQueryService): try: - chunks = [] + chunk_ids = [] for vec in msg.vectors: @@ -90,10 +90,10 @@ class Processor(DocumentEmbeddingsQueryService): ).points for r in search_result: - ent = r.payload["doc"] - chunks.append(ent) + chunk_id = r.payload["chunk_id"] + chunk_ids.append(chunk_id) - return chunks + return chunk_ids except Exception as e: diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index 9f4ad0ff..f192bcf3 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -24,7 +24,7 @@ class Query: if self.verbose: logger.debug("Computing embeddings...") - qembeds = await self.rag.embeddings_client.embed(query) + qembeds = await self.rag.embeddings_client.embed(query) if self.verbose: logger.debug("Embeddings computed") @@ -36,17 +36,31 @@ class Query: vectors = await self.get_vector(query) if self.verbose: - logger.debug("Getting documents...") + logger.debug("Getting chunk_ids from embeddings store...") - docs = await self.rag.doc_embeddings_client.query( + # Get chunk_ids from embeddings store + chunk_ids = await self.rag.doc_embeddings_client.query( vectors, limit=self.doc_limit, user=self.user, collection=self.collection, ) if self.verbose: - logger.debug("Documents:") + logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...") + + # Fetch chunk content from Garage + docs = [] + for chunk_id in chunk_ids: + if chunk_id: + try: + content = await self.rag.fetch_chunk(chunk_id, self.user) + docs.append(content) + except Exception as e: + logger.warning(f"Failed to fetch chunk {chunk_id}: {e}") + + if self.verbose: + logger.debug("Documents fetched:") for doc in docs: - logger.debug(f" {doc}") + logger.debug(f" {doc[:100]}...") return docs @@ -54,6 +68,7 @@ class DocumentRag: def __init__( self, prompt_client, embeddings_client, doc_embeddings_client, + fetch_chunk, verbose=False, ): @@ -62,6 +77,7 @@ class DocumentRag: self.prompt_client = prompt_client self.embeddings_client = embeddings_client self.doc_embeddings_client = doc_embeddings_client + self.fetch_chunk = fetch_chunk if self.verbose: logger.debug("DocumentRag initialized") diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index 6490562a..3bc7113a 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -4,17 +4,26 @@ Simple RAG service, performs query using document RAG an LLM. Input is query, output is response. """ +import asyncio +import base64 import logging + from ... schema import DocumentRagQuery, DocumentRagResponse, Error +from ... schema import LibrarianRequest, LibrarianResponse +from ... schema import librarian_request_queue, librarian_response_queue from . document_rag import DocumentRag from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import DocumentEmbeddingsClientSpec +from ... base import Consumer, Producer +from ... base import ConsumerMetrics, ProducerMetrics # Module logger logger = logging.getLogger(__name__) default_ident = "document-rag" +default_librarian_request_queue = librarian_request_queue +default_librarian_response_queue = librarian_response_queue class Processor(FlowProcessor): @@ -69,6 +78,98 @@ class Processor(FlowProcessor): ) ) + # Librarian client for fetching chunk content from Garage + librarian_request_q = params.get( + "librarian_request_queue", default_librarian_request_queue + ) + librarian_response_q = params.get( + "librarian_response_queue", default_librarian_response_queue + ) + + librarian_request_metrics = ProducerMetrics( + processor=id, flow=None, name="librarian-request" + ) + + self.librarian_request_producer = Producer( + backend=self.pubsub, + topic=librarian_request_q, + schema=LibrarianRequest, + metrics=librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( + processor=id, flow=None, name="librarian-response" + ) + + self.librarian_response_consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=librarian_response_q, + subscriber=f"{id}-librarian", + schema=LibrarianResponse, + handler=self.on_librarian_response, + metrics=librarian_response_metrics, + ) + + # Pending librarian requests: request_id -> asyncio.Future + self.pending_requests = {} + + async def start(self): + await super(Processor, self).start() + await self.librarian_request_producer.start() + await self.librarian_response_consumer.start() + + async def on_librarian_response(self, msg, consumer, flow): + """Handle responses from the librarian service.""" + response = msg.value() + request_id = msg.properties().get("id") + + if request_id in self.pending_requests: + future = self.pending_requests.pop(request_id) + future.set_result(response) + else: + logger.warning(f"Received unexpected librarian response: {request_id}") + + async def fetch_chunk_content(self, chunk_id, user, timeout=120): + """Fetch chunk content from librarian/Garage.""" + import uuid + request_id = str(uuid.uuid4()) + + request = LibrarianRequest( + operation="get-document-content", + document_id=chunk_id, + user=user, + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error: {response.error.type}: {response.error.message}" + ) + + # Content is base64 encoded + content = response.content + if isinstance(content, str): + content = content.encode('utf-8') + return base64.b64decode(content).decode("utf-8") + + except asyncio.TimeoutError: + self.pending_requests.pop(request_id, None) + raise RuntimeError(f"Timeout fetching chunk {chunk_id}") + async def on_request(self, msg, consumer, flow): try: @@ -77,6 +178,7 @@ class Processor(FlowProcessor): embeddings_client = flow("embeddings-request"), doc_embeddings_client = flow("document-embeddings-request"), prompt_client = flow("prompt-request"), + fetch_chunk = self.fetch_chunk_content, verbose=True, ) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index ae869413..a4ff0838 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -37,14 +37,13 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): for emb in message.chunks: - if emb.chunk is None or emb.chunk == b"": continue - - chunk = emb.chunk.decode("utf-8") - if chunk == "": continue + chunk_id = emb.chunk_id + if chunk_id == "": + continue for vec in emb.vectors: self.vecstore.insert( - vec, chunk, + vec, chunk_id, message.metadata.user, message.metadata.collection ) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index a0e52253..f6393053 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -101,10 +101,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): for emb in message.chunks: - if emb.chunk is None or emb.chunk == b"": continue - - chunk = emb.chunk.decode("utf-8") - if chunk == "": continue + chunk_id = emb.chunk_id + if chunk_id == "": + continue for vec in emb.vectors: @@ -128,7 +127,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): { "id": vector_id, "values": vec, - "metadata": { "doc": chunk }, + "metadata": { "chunk_id": chunk_id }, } ] diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index cb978048..21ea9a98 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -52,8 +52,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): for emb in message.chunks: - chunk = emb.chunk.decode("utf-8") - if chunk == "": return + chunk_id = emb.chunk_id + if chunk_id == "": + continue for vec in emb.vectors: @@ -81,7 +82,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): id=str(uuid.uuid4()), vector=vec, payload={ - "doc": chunk, + "chunk_id": chunk_id, } ) ]