- Added load/save API for document embeddings (#269)

- tg-load-doc-embeds and tg-save-doc-embeds command line utils
This commit is contained in:
cybermaggedon 2025-01-16 00:00:59 +00:00 committed by GitHub
parent acdd3efe51
commit bed7423c26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 587 additions and 0 deletions

View file

@ -0,0 +1,64 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import DocumentEmbeddings, ChunkEmbeddings
from .. schema import document_embeddings_store_queue
from . publisher import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph
class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
def __init__(
self, pulsar_host, auth, path="/api/v1/load/document-embeddings",
):
super(DocumentEmbeddingsLoadEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_host=pulsar_host
self.publisher = Publisher(
self.pulsar_host, document_embeddings_store_queue,
schema=JsonSchema(DocumentEmbeddings)
)
async def start(self):
self.publisher.start()
async def listener(self, ws, running):
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.ERROR:
break
else:
data = msg.json()
elt = DocumentEmbeddings(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"]["metadata"]),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
chunks=[
ChunkEmbeddings(
chunk=de["chunk"].encode("utf-8"),
vectors=de["vectors"],
)
for de in data["chunks"]
],
)
self.publisher.send(None, elt)
running.stop()

View file

@ -0,0 +1,72 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from .. schema import DocumentEmbeddings
from .. schema import document_embeddings_store_queue
from . subscriber import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_document_embeddings
class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
def __init__(
self, pulsar_host, auth, path="/api/v1/stream/document-embeddings"
):
super(DocumentEmbeddingsStreamEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_host=pulsar_host
self.subscriber = Subscriber(
self.pulsar_host, document_embeddings_store_queue,
"api-gateway", "api-gateway",
schema=JsonSchema(DocumentEmbeddings)
)
async def listener(self, ws, running):
worker = asyncio.create_task(
self.async_thread(ws, running)
)
await super(DocumentEmbeddingsStreamEndpoint, self).listener(
ws, running
)
await worker
async def start(self):
self.subscriber.start()
async def async_thread(self, ws, running):
id = str(uuid.uuid4())
q = self.subscriber.subscribe_all(id)
while running.get():
try:
resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_document_embeddings(resp))
except TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
print(f"Exception: {str(e)}", flush=True)
break
self.subscriber.unsubscribe_all(id)
running.stop()

View file

@ -60,3 +60,20 @@ def serialize_graph_embeddings(message):
],
}
def serialize_document_embeddings(message):
return {
"metadata": {
"id": message.metadata.id,
"metadata": serialize_subgraph(message.metadata.metadata),
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"chunks": [
{
"vectors": chunk.vectors,
"chunk": chunk.chunk.decode("utf-8"),
}
for chunk in message.chunks
],
}

View file

@ -41,8 +41,10 @@ from . dbpedia import DbpediaRequestor
from . internet_search import InternetSearchRequestor
from . triples_stream import TriplesStreamEndpoint
from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
from . document_embeddings_stream import DocumentEmbeddingsStreamEndpoint
from . triples_load import TriplesLoadEndpoint
from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
from . document_embeddings_load import DocumentEmbeddingsLoadEndpoint
from . mux import MuxEndpoint
from . document_load import DocumentLoadSender
from . text_load import TextLoadSender
@ -203,6 +205,10 @@ class Api:
pulsar_host=self.pulsar_host,
auth = self.auth,
),
DocumentEmbeddingsStreamEndpoint(
pulsar_host=self.pulsar_host,
auth = self.auth,
),
TriplesLoadEndpoint(
pulsar_host=self.pulsar_host,
auth = self.auth,
@ -211,6 +217,10 @@ class Api:
pulsar_host=self.pulsar_host,
auth = self.auth,
),
DocumentEmbeddingsLoadEndpoint(
pulsar_host=self.pulsar_host,
auth = self.auth,
),
MuxEndpoint(
pulsar_host=self.pulsar_host,
auth = self.auth,