mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-18 03:45:12 +02:00
- Added load/save API for document embeddings (#269)
- tg-load-doc-embeds and tg-save-doc-embeds command line utils
This commit is contained in:
parent
acdd3efe51
commit
bed7423c26
7 changed files with 587 additions and 0 deletions
|
|
@ -0,0 +1,64 @@
|
|||
|
||||
import asyncio
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from .. schema import Metadata
|
||||
from .. schema import DocumentEmbeddings, ChunkEmbeddings
|
||||
from .. schema import document_embeddings_store_queue
|
||||
|
||||
from . publisher import Publisher
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import to_subgraph
|
||||
|
||||
class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
    """Websocket endpoint that publishes incoming document embeddings.

    Every message received on the socket is decoded as JSON, converted to
    a ``DocumentEmbeddings`` record and sent through a ``Publisher`` bound
    to ``document_embeddings_store_queue``.
    """

    def __init__(
            self, pulsar_host, auth, path="/api/v1/load/document-embeddings",
    ):
        """Set up the endpoint and its publisher.

        Args:
            pulsar_host: Passed to ``Publisher`` as its first argument.
            auth: Passed through to ``SocketEndpoint``.
            path: Endpoint path passed through to ``SocketEndpoint``.
        """
        super().__init__(endpoint_path=path, auth=auth)

        self.pulsar_host = pulsar_host

        self.publisher = Publisher(
            self.pulsar_host, document_embeddings_store_queue,
            schema=JsonSchema(DocumentEmbeddings),
        )

    async def start(self):
        """Start the publisher."""
        self.publisher.start()

    async def listener(self, ws, running):
        """Publish each message read from ``ws`` until it ends or errors.

        Args:
            ws: Websocket iterated with ``async for``.
            running: Object whose ``stop()`` is called once the loop ends.

        Raises:
            Whatever ``msg.json()`` or the message conversion raises for a
            malformed message (for example ``KeyError`` on a missing key).
        """
        # The finally clause guarantees running.stop() runs even when a
        # malformed message raises part-way through the loop.
        try:
            async for msg in ws:

                # On error, finish
                if msg.type == WSMsgType.ERROR:
                    break

                elt = self._to_document_embeddings(msg.json())
                self.publisher.send(None, elt)
        finally:
            running.stop()

    @staticmethod
    def _to_document_embeddings(data):
        """Build a ``DocumentEmbeddings`` record from a decoded message."""
        meta = data["metadata"]

        return DocumentEmbeddings(
            metadata=Metadata(
                id=meta["id"],
                metadata=to_subgraph(meta["metadata"]),
                user=meta["user"],
                collection=meta["collection"],
            ),
            chunks=[
                ChunkEmbeddings(
                    # Chunk text arrives as str; the record is given bytes.
                    chunk=de["chunk"].encode("utf-8"),
                    vectors=de["vectors"],
                )
                for de in data["chunks"]
            ],
        )
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
|
||||
from .. schema import DocumentEmbeddings
|
||||
from .. schema import document_embeddings_store_queue
|
||||
|
||||
from . subscriber import Subscriber
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import serialize_document_embeddings
|
||||
|
||||
class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
    """Websocket endpoint that streams document embeddings to the client.

    A ``Subscriber`` bound to ``document_embeddings_store_queue`` feeds a
    per-connection queue; each item taken from that queue is serialized
    with ``serialize_document_embeddings`` and sent on the socket as JSON.
    """

    def __init__(
            self, pulsar_host, auth, path="/api/v1/stream/document-embeddings"
    ):
        """Set up the endpoint and its subscriber.

        Args:
            pulsar_host: Passed to ``Subscriber`` as its first argument.
            auth: Passed through to ``SocketEndpoint``.
            path: Endpoint path passed through to ``SocketEndpoint``.
        """
        super().__init__(endpoint_path=path, auth=auth)

        self.pulsar_host = pulsar_host

        self.subscriber = Subscriber(
            self.pulsar_host, document_embeddings_store_queue,
            "api-gateway", "api-gateway",
            schema=JsonSchema(DocumentEmbeddings),
        )

    async def listener(self, ws, running):
        """Run the sender task alongside the base-class listener.

        The sender task is created first so it runs while the inherited
        ``listener`` is awaited; it is awaited afterwards.
        """
        worker = asyncio.create_task(
            self.async_thread(ws, running)
        )

        await super().listener(ws, running)

        await worker

    async def start(self):
        """Start the subscriber."""
        self.subscriber.start()

    async def async_thread(self, ws, running):
        """Forward queued embeddings to ``ws`` while ``running.get()`` holds.

        Args:
            ws: Websocket on which ``send_json`` is called.
            running: Object polled with ``get()`` and stopped on exit.
        """
        # Named sub_id rather than id so the builtin is not shadowed.
        sub_id = str(uuid.uuid4())

        q = self.subscriber.subscribe_all(sub_id)

        # The finally clause releases the subscription even when this task
        # is cancelled; CancelledError is not caught by `except Exception`.
        try:
            while running.get():
                try:
                    # The blocking get runs in a thread; the short timeout
                    # lets the loop re-check running.get() regularly.
                    resp = await asyncio.to_thread(q.get, timeout=0.5)
                    await ws.send_json(serialize_document_embeddings(resp))

                except (TimeoutError, queue.Empty):
                    # Nothing arrived within the timeout; poll again.
                    continue

                except Exception as e:
                    print(f"Exception: {str(e)}", flush=True)
                    break
        finally:
            self.subscriber.unsubscribe_all(sub_id)

            running.stop()
|
||||
|
||||
|
|
@ -60,3 +60,20 @@ def serialize_graph_embeddings(message):
|
|||
],
|
||||
}
|
||||
|
||||
def serialize_document_embeddings(message):
    """Convert a document-embeddings message into a plain dict.

    The result has a ``"metadata"`` dict (``id``, ``metadata``, ``user``,
    ``collection``) and a ``"chunks"`` list, one dict per chunk holding
    its ``vectors`` and its ``chunk`` bytes decoded as UTF-8.
    """
    meta = message.metadata

    metadata_part = {
        "id": meta.id,
        "metadata": serialize_subgraph(meta.metadata),
        "user": meta.user,
        "collection": meta.collection,
    }

    chunk_parts = []
    for entry in message.chunks:
        chunk_parts.append({
            "vectors": entry.vectors,
            "chunk": entry.chunk.decode("utf-8"),
        })

    return {
        "metadata": metadata_part,
        "chunks": chunk_parts,
    }
|
||||
|
||||
|
|
|
|||
|
|
@ -41,8 +41,10 @@ from . dbpedia import DbpediaRequestor
|
|||
from . internet_search import InternetSearchRequestor
|
||||
from . triples_stream import TriplesStreamEndpoint
|
||||
from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
|
||||
from . document_embeddings_stream import DocumentEmbeddingsStreamEndpoint
|
||||
from . triples_load import TriplesLoadEndpoint
|
||||
from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
|
||||
from . document_embeddings_load import DocumentEmbeddingsLoadEndpoint
|
||||
from . mux import MuxEndpoint
|
||||
from . document_load import DocumentLoadSender
|
||||
from . text_load import TextLoadSender
|
||||
|
|
@ -203,6 +205,10 @@ class Api:
|
|||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
),
|
||||
DocumentEmbeddingsStreamEndpoint(
|
||||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
),
|
||||
TriplesLoadEndpoint(
|
||||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
|
|
@ -211,6 +217,10 @@ class Api:
|
|||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
),
|
||||
DocumentEmbeddingsLoadEndpoint(
|
||||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
),
|
||||
MuxEndpoint(
|
||||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue