API supports doc & text load (#167)

This commit is contained in:
cybermaggedon 2024-11-21 14:53:53 +00:00 committed by GitHub
parent a1e0edd96f
commit dc0f54f236
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 143 additions and 17 deletions

View file

@ -6,7 +6,6 @@ Loads a text document into TrustGraph processing.
import pulsar import pulsar
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
import base64
import hashlib import hashlib
import argparse import argparse
import os import os

View file

@ -19,6 +19,7 @@ import json
import logging import logging
import uuid import uuid
import os import os
import base64
import pulsar import pulsar
from pulsar.asyncio import Client from pulsar.asyncio import Client
@ -32,6 +33,8 @@ from ... log_level import LogLevel
from trustgraph.clients.llm_client import LlmClient from trustgraph.clients.llm_client import LlmClient
from trustgraph.clients.prompt_client import PromptClient from trustgraph.clients.prompt_client import PromptClient
from ... schema import Value, Metadata, Document, TextDocument, Triple
from ... schema import TextCompletionRequest, TextCompletionResponse from ... schema import TextCompletionRequest, TextCompletionResponse
from ... schema import text_completion_request_queue from ... schema import text_completion_request_queue
from ... schema import text_completion_response_queue from ... schema import text_completion_response_queue
@ -44,7 +47,7 @@ from ... schema import GraphRagQuery, GraphRagResponse
from ... schema import graph_rag_request_queue from ... schema import graph_rag_request_queue
from ... schema import graph_rag_response_queue from ... schema import graph_rag_response_queue
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Value from ... schema import TriplesQueryRequest, TriplesQueryResponse
from ... schema import triples_request_queue from ... schema import triples_request_queue
from ... schema import triples_response_queue from ... schema import triples_response_queue
@ -56,6 +59,8 @@ from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import embeddings_request_queue from ... schema import embeddings_request_queue
from ... schema import embeddings_response_queue from ... schema import embeddings_response_queue
from ... schema import document_ingest_queue, text_ingest_queue
logger = logging.getLogger("api") logger = logging.getLogger("api")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -63,13 +68,31 @@ default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
default_timeout = 600 default_timeout = 600
default_port = 8088 default_port = 8088
def to_value(x):
if x.startswith("http:") or x.startswith("https:"):
return Value(value=x, is_uri=True)
else:
return Value(value=x, is_uri=True)
def to_subgraph(x):
return [
Triple(
s=to_value(t["s"]),
p=to_value(t["p"]),
o=to_value(t["o"])
)
for t in x
]
class Publisher: class Publisher:
def __init__(self, pulsar_host, topic, schema=None, max_size=10): def __init__(self, pulsar_host, topic, schema=None, max_size=10,
chunking_enabled=False):
self.pulsar_host = pulsar_host self.pulsar_host = pulsar_host
self.topic = topic self.topic = topic
self.schema = schema self.schema = schema
self.q = asyncio.Queue(maxsize=max_size) self.q = asyncio.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled
async def run(self): async def run(self):
@ -80,10 +103,16 @@ class Publisher:
async with client.create_producer( async with client.create_producer(
topic=self.topic, topic=self.topic,
schema=self.schema, schema=self.schema,
chunking_enabled=self.chunking_enabled,
) as producer: ) as producer:
while True: while True:
id, item = await self.q.get() id, item = await self.q.get()
await producer.send(item, { "id": id })
if id:
await producer.send(item, { "id": id })
else:
await producer.send(item)
except Exception as e: except Exception as e:
print("Exception:", e, flush=True) print("Exception:", e, flush=True)
@ -139,7 +168,10 @@ class Api:
def __init__(self, **config): def __init__(self, **config):
self.app = web.Application(middlewares=[]) self.app = web.Application(
middlewares=[],
client_max_size=256 * 1024 * 1024
)
self.port = int(config.get("port", default_port)) self.port = int(config.get("port", default_port))
self.timeout = int(config.get("timeout", default_timeout)) self.timeout = int(config.get("timeout", default_timeout))
@ -211,6 +243,18 @@ class Api:
JsonSchema(EmbeddingsResponse) JsonSchema(EmbeddingsResponse)
) )
self.document_out = Publisher(
self.pulsar_host, document_ingest_queue,
schema=JsonSchema(Document),
chunking_enabled=True,
)
self.text_out = Publisher(
self.pulsar_host, text_ingest_queue,
schema=JsonSchema(TextDocument),
chunking_enabled=True,
)
self.app.add_routes([ self.app.add_routes([
web.post("/api/v1/text-completion", self.llm), web.post("/api/v1/text-completion", self.llm),
web.post("/api/v1/prompt", self.prompt), web.post("/api/v1/prompt", self.prompt),
@ -218,6 +262,8 @@ class Api:
web.post("/api/v1/triples-query", self.triples_query), web.post("/api/v1/triples-query", self.triples_query),
web.post("/api/v1/agent", self.agent), web.post("/api/v1/agent", self.agent),
web.post("/api/v1/embeddings", self.embeddings), web.post("/api/v1/embeddings", self.embeddings),
web.post("/api/v1/load/document", self.load_document),
web.post("/api/v1/load/text", self.load_text),
]) ])
async def llm(self, request): async def llm(self, request):
@ -368,26 +414,17 @@ class Api:
q = await self.triples_query_in.subscribe(id) q = await self.triples_query_in.subscribe(id)
if "s" in data: if "s" in data:
if data["s"].startswith("http:") or data["s"].startswith("https:"): s = to_value(data["s"])
s = Value(value=data["s"], is_uri=True)
else:
s = Value(value=data["s"], is_uri=True)
else: else:
s = None s = None
if "p" in data: if "p" in data:
if data["p"].startswith("http:") or data["p"].startswith("https:"): p = to_value(data["p"])
p = Value(value=data["p"], is_uri=True)
else:
p = Value(value=data["p"], is_uri=True)
else: else:
p = None p = None
if "o" in data: if "o" in data:
if data["o"].startswith("http:") or data["o"].startswith("https:"): o = to_value(data["o"])
o = Value(value=data["o"], is_uri=True)
else:
o = Value(value=data["o"], is_uri=True)
else: else:
o = None o = None
@ -537,6 +574,92 @@ class Api:
finally: finally:
await self.embeddings_in.unsubscribe(id) await self.embeddings_in.unsubscribe(id)
async def load_document(self, request):
try:
data = await request.json()
if "metadata" in data:
metadata = to_subgraph(data["metadata"])
else:
metadata = []
# Doing a base64 decode/encode here to make sure the
# content is valid base64
doc = base64.b64decode(data["data"])
resp = await self.document_out.send(
None,
Document(
metadata=Metadata(
id=data.get("id"),
metadata=metadata,
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
),
data=base64.b64encode(doc).decode("utf-8")
)
)
print("Document loaded.")
return web.json_response(
{ }
)
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
async def load_text(self, request):
try:
data = await request.json()
if "metadata" in data:
metadata = to_subgraph(data["metadata"])
else:
metadata = []
if "charset" in data:
charset = data["charset"]
else:
charset = "utf-8"
# Text is base64 encoded
text = base64.b64decode(data["text"]).decode(charset)
resp = await self.text_out.send(
None,
TextDocument(
metadata=Metadata(
id=data.get("id"),
metadata=metadata,
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
),
text=text,
)
)
print("Text document loaded.")
return web.json_response(
{ }
)
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
async def app_factory(self): async def app_factory(self):
self.llm_pub_task = asyncio.create_task(self.llm_in.run()) self.llm_pub_task = asyncio.create_task(self.llm_in.run())
@ -565,6 +688,10 @@ class Api:
self.embeddings_out.run() self.embeddings_out.run()
) )
self.doc_ingest_pub_task = asyncio.create_task(self.document_out.run())
self.text_ingest_pub_task = asyncio.create_task(self.text_out.run())
return self.app return self.app
def run(self): def run(self):