mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 09:56:22 +02:00
API supports doc & text load (#167)
This commit is contained in:
parent
a1e0edd96f
commit
dc0f54f236
2 changed files with 143 additions and 17 deletions
|
|
@ -19,6 +19,7 @@ import json
|
|||
import logging
|
||||
import uuid
|
||||
import os
|
||||
import base64
|
||||
|
||||
import pulsar
|
||||
from pulsar.asyncio import Client
|
||||
|
|
@ -32,6 +33,8 @@ from ... log_level import LogLevel
|
|||
from trustgraph.clients.llm_client import LlmClient
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
from ... schema import Value, Metadata, Document, TextDocument, Triple
|
||||
|
||||
from ... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from ... schema import text_completion_request_queue
|
||||
from ... schema import text_completion_response_queue
|
||||
|
|
@ -44,7 +47,7 @@ from ... schema import GraphRagQuery, GraphRagResponse
|
|||
from ... schema import graph_rag_request_queue
|
||||
from ... schema import graph_rag_response_queue
|
||||
|
||||
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Value
|
||||
from ... schema import TriplesQueryRequest, TriplesQueryResponse
|
||||
from ... schema import triples_request_queue
|
||||
from ... schema import triples_response_queue
|
||||
|
||||
|
|
@ -56,6 +59,8 @@ from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
|||
from ... schema import embeddings_request_queue
|
||||
from ... schema import embeddings_response_queue
|
||||
|
||||
from ... schema import document_ingest_queue, text_ingest_queue
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
|
@ -63,13 +68,31 @@ default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
|
|||
default_timeout = 600
|
||||
default_port = 8088
|
||||
|
||||
def to_value(x):
    """Wrap a string from an API request in a schema ``Value``.

    Strings that start with ``http:`` or ``https:`` are flagged as URIs;
    any other string is treated as a plain (non-URI) value.

    Args:
        x: The string taken from the request payload.

    Returns:
        A ``Value`` whose ``value`` is ``x`` and whose ``is_uri`` flag
        records whether ``x`` starts with an HTTP(S) scheme prefix.
    """
    # Both branches used to pass is_uri=True, which made the prefix test
    # meaningless and marked every literal as a URI.  Derive the flag from
    # the test itself so non-URI strings are reported as such.
    return Value(value=x, is_uri=x.startswith(("http:", "https:")))
|
||||
|
||||
def to_subgraph(x):
    """Convert an iterable of ``{"s", "p", "o"}`` mappings into triples.

    Each element's ``"s"``, ``"p"`` and ``"o"`` entries are passed through
    ``to_value`` and combined into a ``Triple``.

    Args:
        x: Iterable of mappings, each providing ``"s"``, ``"p"`` and
            ``"o"`` keys.

    Returns:
        A list of ``Triple`` objects, in the same order as the input.
    """
    triples = []

    for entry in x:
        # Convert in s, p, o order so a missing key surfaces on the same
        # field it would have in a direct Triple(...) construction.
        subject = to_value(entry["s"])
        predicate = to_value(entry["p"])
        obj = to_value(entry["o"])
        triples.append(Triple(s=subject, p=predicate, o=obj))

    return triples
|
||||
|
||||
class Publisher:
|
||||
|
||||
def __init__(self, pulsar_host, topic, schema=None, max_size=10):
|
||||
def __init__(self, pulsar_host, topic, schema=None, max_size=10,
|
||||
chunking_enabled=False):
|
||||
self.pulsar_host = pulsar_host
|
||||
self.topic = topic
|
||||
self.schema = schema
|
||||
self.q = asyncio.Queue(maxsize=max_size)
|
||||
self.chunking_enabled = chunking_enabled
|
||||
|
||||
async def run(self):
|
||||
|
||||
|
|
@ -80,10 +103,16 @@ class Publisher:
|
|||
async with client.create_producer(
|
||||
topic=self.topic,
|
||||
schema=self.schema,
|
||||
chunking_enabled=self.chunking_enabled,
|
||||
) as producer:
|
||||
while True:
|
||||
id, item = await self.q.get()
|
||||
await producer.send(item, { "id": id })
|
||||
|
||||
if id:
|
||||
await producer.send(item, { "id": id })
|
||||
else:
|
||||
await producer.send(item)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
|
|
@ -139,7 +168,10 @@ class Api:
|
|||
|
||||
def __init__(self, **config):
|
||||
|
||||
self.app = web.Application(middlewares=[])
|
||||
self.app = web.Application(
|
||||
middlewares=[],
|
||||
client_max_size=256 * 1024 * 1024
|
||||
)
|
||||
|
||||
self.port = int(config.get("port", default_port))
|
||||
self.timeout = int(config.get("timeout", default_timeout))
|
||||
|
|
@ -211,6 +243,18 @@ class Api:
|
|||
JsonSchema(EmbeddingsResponse)
|
||||
)
|
||||
|
||||
self.document_out = Publisher(
|
||||
self.pulsar_host, document_ingest_queue,
|
||||
schema=JsonSchema(Document),
|
||||
chunking_enabled=True,
|
||||
)
|
||||
|
||||
self.text_out = Publisher(
|
||||
self.pulsar_host, text_ingest_queue,
|
||||
schema=JsonSchema(TextDocument),
|
||||
chunking_enabled=True,
|
||||
)
|
||||
|
||||
self.app.add_routes([
|
||||
web.post("/api/v1/text-completion", self.llm),
|
||||
web.post("/api/v1/prompt", self.prompt),
|
||||
|
|
@ -218,6 +262,8 @@ class Api:
|
|||
web.post("/api/v1/triples-query", self.triples_query),
|
||||
web.post("/api/v1/agent", self.agent),
|
||||
web.post("/api/v1/embeddings", self.embeddings),
|
||||
web.post("/api/v1/load/document", self.load_document),
|
||||
web.post("/api/v1/load/text", self.load_text),
|
||||
])
|
||||
|
||||
async def llm(self, request):
|
||||
|
|
@ -368,26 +414,17 @@ class Api:
|
|||
q = await self.triples_query_in.subscribe(id)
|
||||
|
||||
if "s" in data:
|
||||
if data["s"].startswith("http:") or data["s"].startswith("https:"):
|
||||
s = Value(value=data["s"], is_uri=True)
|
||||
else:
|
||||
s = Value(value=data["s"], is_uri=True)
|
||||
s = to_value(data["s"])
|
||||
else:
|
||||
s = None
|
||||
|
||||
if "p" in data:
|
||||
if data["p"].startswith("http:") or data["p"].startswith("https:"):
|
||||
p = Value(value=data["p"], is_uri=True)
|
||||
else:
|
||||
p = Value(value=data["p"], is_uri=True)
|
||||
p = to_value(data["p"])
|
||||
else:
|
||||
p = None
|
||||
|
||||
if "o" in data:
|
||||
if data["o"].startswith("http:") or data["o"].startswith("https:"):
|
||||
o = Value(value=data["o"], is_uri=True)
|
||||
else:
|
||||
o = Value(value=data["o"], is_uri=True)
|
||||
o = to_value(data["o"])
|
||||
else:
|
||||
o = None
|
||||
|
||||
|
|
@ -537,6 +574,92 @@ class Api:
|
|||
finally:
|
||||
await self.embeddings_in.unsubscribe(id)
|
||||
|
||||
async def load_document(self, request):
    """Handle ``POST /api/v1/load/document``.

    Reads a JSON body, builds a ``Document`` from it and passes it to
    ``self.document_out.send``.  Body fields read here:

    * ``data`` (required): base64-encoded document content.
    * ``metadata`` (optional): list of ``s``/``p``/``o`` mappings,
      converted with ``to_subgraph``; defaults to an empty list.
    * ``id`` (optional), ``user`` (default ``"trustgraph"``) and
      ``collection`` (default ``"default"``).

    Returns:
        An empty JSON object on success, or ``{"error": <message>}`` if
        any step raises.
    """

    try:

        body = await request.json()

        triples = to_subgraph(body["metadata"]) if "metadata" in body else []

        # Round-trip the payload through decode/encode so that what gets
        # forwarded is freshly produced base64 rather than the client's
        # raw string.
        raw = base64.b64decode(body["data"])

        doc_metadata = Metadata(
            id=body.get("id"),
            metadata=triples,
            user=body.get("user", "trustgraph"),
            collection=body.get("collection", "default"),
        )

        message = Document(
            metadata=doc_metadata,
            data=base64.b64encode(raw).decode("utf-8"),
        )

        # No message id is associated with ingest submissions.
        await self.document_out.send(None, message)

        print("Document loaded.")

        return web.json_response({ })

    except Exception as e:
        logging.error(f"Exception: {e}")

        return web.json_response({ "error": str(e) })
|
||||
|
||||
async def load_text(self, request):
    """Handle ``POST /api/v1/load/text``.

    Reads a JSON body, builds a ``TextDocument`` from it and passes it to
    ``self.text_out.send``.  Body fields read here:

    * ``text`` (required): base64-encoded text content.
    * ``charset`` (optional): codec used to decode the bytes; defaults
      to ``"utf-8"``.
    * ``metadata`` (optional): list of ``s``/``p``/``o`` mappings,
      converted with ``to_subgraph``; defaults to an empty list.
    * ``id`` (optional), ``user`` (default ``"trustgraph"``) and
      ``collection`` (default ``"default"``).

    Returns:
        An empty JSON object on success, or ``{"error": <message>}`` if
        any step raises.
    """

    try:

        body = await request.json()

        triples = to_subgraph(body["metadata"]) if "metadata" in body else []

        encoding = body.get("charset", "utf-8")

        # The text arrives base64 encoded; decode to bytes, then to str
        # using the requested character set.
        decoded = base64.b64decode(body["text"]).decode(encoding)

        doc_metadata = Metadata(
            id=body.get("id"),
            metadata=triples,
            user=body.get("user", "trustgraph"),
            collection=body.get("collection", "default"),
        )

        message = TextDocument(metadata=doc_metadata, text=decoded)

        # No message id is associated with ingest submissions.
        await self.text_out.send(None, message)

        print("Text document loaded.")

        return web.json_response({ })

    except Exception as e:
        logging.error(f"Exception: {e}")

        return web.json_response({ "error": str(e) })
|
||||
|
||||
async def app_factory(self):
|
||||
|
||||
self.llm_pub_task = asyncio.create_task(self.llm_in.run())
|
||||
|
|
@ -565,6 +688,10 @@ class Api:
|
|||
self.embeddings_out.run()
|
||||
)
|
||||
|
||||
self.doc_ingest_pub_task = asyncio.create_task(self.document_out.run())
|
||||
|
||||
self.text_ingest_pub_task = asyncio.create_task(self.text_out.run())
|
||||
|
||||
return self.app
|
||||
|
||||
def run(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue