API supports doc & text load (#167)

This commit is contained in:
cybermaggedon 2024-11-21 14:53:53 +00:00 committed by GitHub
parent a1e0edd96f
commit dc0f54f236
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 143 additions and 17 deletions

View file

@ -19,6 +19,7 @@ import json
import logging
import uuid
import os
import base64
import pulsar
from pulsar.asyncio import Client
@ -32,6 +33,8 @@ from ... log_level import LogLevel
from trustgraph.clients.llm_client import LlmClient
from trustgraph.clients.prompt_client import PromptClient
from ... schema import Value, Metadata, Document, TextDocument, Triple
from ... schema import TextCompletionRequest, TextCompletionResponse
from ... schema import text_completion_request_queue
from ... schema import text_completion_response_queue
@ -44,7 +47,7 @@ from ... schema import GraphRagQuery, GraphRagResponse
from ... schema import graph_rag_request_queue
from ... schema import graph_rag_response_queue
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Value
from ... schema import TriplesQueryRequest, TriplesQueryResponse
from ... schema import triples_request_queue
from ... schema import triples_response_queue
@ -56,6 +59,8 @@ from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import embeddings_request_queue
from ... schema import embeddings_response_queue
from ... schema import document_ingest_queue, text_ingest_queue
logger = logging.getLogger("api")
logger.setLevel(logging.INFO)
@ -63,13 +68,31 @@ default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
default_timeout = 600
default_port = 8088
def to_value(x):
    """Wrap a string term in a schema ``Value``, flagging URIs.

    Args:
        x: The term as a string.

    Returns:
        A ``Value`` carrying ``x``. ``is_uri`` is True when ``x`` starts
        with ``http:`` or ``https:`` and False otherwise.
    """
    # The earlier form set is_uri=True in both branches, which made the
    # scheme test a no-op and marked plain literals as URIs.
    return Value(value=x, is_uri=x.startswith(("http:", "https:")))
def to_subgraph(x):
    """Build a list of ``Triple`` objects from an iterable of mappings.

    Args:
        x: Iterable of mappings, each with the keys ``"s"``, ``"p"`` and
            ``"o"``; every value is passed through ``to_value``.

    Returns:
        A list with one ``Triple`` per input entry, in input order.
    """
    triples = []
    for entry in x:
        # Convert in s, p, o order, one term at a time.
        subject = to_value(entry["s"])
        predicate = to_value(entry["p"])
        obj = to_value(entry["o"])
        triples.append(Triple(s=subject, p=predicate, o=obj))
    return triples
class Publisher:
def __init__(self, pulsar_host, topic, schema=None, max_size=10):
def __init__(self, pulsar_host, topic, schema=None, max_size=10,
chunking_enabled=False):
self.pulsar_host = pulsar_host
self.topic = topic
self.schema = schema
self.q = asyncio.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled
async def run(self):
@ -80,10 +103,16 @@ class Publisher:
async with client.create_producer(
topic=self.topic,
schema=self.schema,
chunking_enabled=self.chunking_enabled,
) as producer:
while True:
id, item = await self.q.get()
await producer.send(item, { "id": id })
if id:
await producer.send(item, { "id": id })
else:
await producer.send(item)
except Exception as e:
print("Exception:", e, flush=True)
@ -139,7 +168,10 @@ class Api:
def __init__(self, **config):
self.app = web.Application(middlewares=[])
self.app = web.Application(
middlewares=[],
client_max_size=256 * 1024 * 1024
)
self.port = int(config.get("port", default_port))
self.timeout = int(config.get("timeout", default_timeout))
@ -211,6 +243,18 @@ class Api:
JsonSchema(EmbeddingsResponse)
)
self.document_out = Publisher(
self.pulsar_host, document_ingest_queue,
schema=JsonSchema(Document),
chunking_enabled=True,
)
self.text_out = Publisher(
self.pulsar_host, text_ingest_queue,
schema=JsonSchema(TextDocument),
chunking_enabled=True,
)
self.app.add_routes([
web.post("/api/v1/text-completion", self.llm),
web.post("/api/v1/prompt", self.prompt),
@ -218,6 +262,8 @@ class Api:
web.post("/api/v1/triples-query", self.triples_query),
web.post("/api/v1/agent", self.agent),
web.post("/api/v1/embeddings", self.embeddings),
web.post("/api/v1/load/document", self.load_document),
web.post("/api/v1/load/text", self.load_text),
])
async def llm(self, request):
@ -368,26 +414,17 @@ class Api:
q = await self.triples_query_in.subscribe(id)
if "s" in data:
if data["s"].startswith("http:") or data["s"].startswith("https:"):
s = Value(value=data["s"], is_uri=True)
else:
s = Value(value=data["s"], is_uri=True)
s = to_value(data["s"])
else:
s = None
if "p" in data:
if data["p"].startswith("http:") or data["p"].startswith("https:"):
p = Value(value=data["p"], is_uri=True)
else:
p = Value(value=data["p"], is_uri=True)
p = to_value(data["p"])
else:
p = None
if "o" in data:
if data["o"].startswith("http:") or data["o"].startswith("https:"):
o = Value(value=data["o"], is_uri=True)
else:
o = Value(value=data["o"], is_uri=True)
o = to_value(data["o"])
else:
o = None
@ -537,6 +574,92 @@ class Api:
finally:
await self.embeddings_in.unsubscribe(id)
async def load_document(self, request):
    """Handler registered for POST ``/api/v1/load/document``.

    Reads a JSON body with a required base64 ``data`` field and optional
    ``metadata`` (list of s/p/o mappings), ``id``, ``user`` and
    ``collection`` fields, wraps it in a ``Document`` and hands it to
    ``self.document_out.send`` with a ``None`` id.

    Returns:
        A JSON response: an empty object on success, or
        ``{"error": <message>}`` if any exception was raised. No explicit
        HTTP status is set in either case.
    """
    try:

        body = await request.json()

        triples = to_subgraph(body["metadata"]) if "metadata" in body else []

        # Decode and later re-encode, so what is forwarded is
        # well-formed base64.
        raw = base64.b64decode(body["data"])

        meta = Metadata(
            id=body.get("id"),
            metadata=triples,
            user=body.get("user", "trustgraph"),
            collection=body.get("collection", "default"),
        )

        document = Document(
            metadata=meta,
            data=base64.b64encode(raw).decode("utf-8"),
        )

        await self.document_out.send(None, document)

        print("Document loaded.")

        return web.json_response({})

    except Exception as e:
        logging.error(f"Exception: {e}")
        return web.json_response({"error": str(e)})
async def load_text(self, request):
    """Handler registered for POST ``/api/v1/load/text``.

    Reads a JSON body with a required base64 ``text`` field and optional
    ``charset`` (defaults to ``"utf-8"``), ``metadata`` (list of s/p/o
    mappings), ``id``, ``user`` and ``collection`` fields. The text is
    base64-decoded, then decoded to ``str`` with the charset, wrapped in
    a ``TextDocument`` and handed to ``self.text_out.send`` with a
    ``None`` id.

    Returns:
        A JSON response: an empty object on success, or
        ``{"error": <message>}`` if any exception was raised. No explicit
        HTTP status is set in either case.
    """
    try:

        body = await request.json()

        triples = to_subgraph(body["metadata"]) if "metadata" in body else []

        encoding = body["charset"] if "charset" in body else "utf-8"

        # The payload arrives base64 encoded; unwrap it, then decode the
        # bytes using the requested character set.
        raw = base64.b64decode(body["text"])
        decoded = raw.decode(encoding)

        meta = Metadata(
            id=body.get("id"),
            metadata=triples,
            user=body.get("user", "trustgraph"),
            collection=body.get("collection", "default"),
        )

        await self.text_out.send(
            None,
            TextDocument(metadata=meta, text=decoded),
        )

        print("Text document loaded.")

        return web.json_response({})

    except Exception as e:
        logging.error(f"Exception: {e}")
        return web.json_response({"error": str(e)})
async def app_factory(self):
self.llm_pub_task = asyncio.create_task(self.llm_in.run())
@ -565,6 +688,10 @@ class Api:
self.embeddings_out.run()
)
self.doc_ingest_pub_task = asyncio.create_task(self.document_out.run())
self.text_ingest_pub_task = asyncio.create_task(self.text_out.run())
return self.app
def run(self):