Adds the /api/v1/load/document and /api/v1/load/text endpoints to the API
gateway, backed by chunking-enabled publishers on the document and text
ingest queues, and factors the repeated string-to-Value conversion out of
the triples-query handler into to_value() / to_subgraph().

Note: to_value() flags only http:/https: strings as URIs; any other string
is emitted with is_uri=False (both branches previously set is_uri=True).

diff --git a/trustgraph-cli/scripts/tg-load-text b/trustgraph-cli/scripts/tg-load-text
index 88dc8e17..e49ee7a9 100755
--- a/trustgraph-cli/scripts/tg-load-text
+++ b/trustgraph-cli/scripts/tg-load-text
@@ -6,7 +6,6 @@ Loads a text document into TrustGraph processing.
 
 import pulsar
 from pulsar.schema import JsonSchema
-import base64
 import hashlib
 import argparse
 import os
diff --git a/trustgraph-flow/trustgraph/api/gateway/service.py b/trustgraph-flow/trustgraph/api/gateway/service.py
index b955af1e..2ac22892 100755
--- a/trustgraph-flow/trustgraph/api/gateway/service.py
+++ b/trustgraph-flow/trustgraph/api/gateway/service.py
@@ -19,6 +19,7 @@ import json
 import logging
 import uuid
 import os
+import base64
 
 import pulsar
 from pulsar.asyncio import Client
@@ -32,6 +33,8 @@ from ... log_level import LogLevel
 from trustgraph.clients.llm_client import LlmClient
 from trustgraph.clients.prompt_client import PromptClient
 
+from ... schema import Value, Metadata, Document, TextDocument, Triple
+
 from ... schema import TextCompletionRequest, TextCompletionResponse
 from ... schema import text_completion_request_queue
 from ... schema import text_completion_response_queue
@@ -44,7 +47,7 @@ from ... schema import GraphRagQuery, GraphRagResponse
 from ... schema import graph_rag_request_queue
 from ... schema import graph_rag_response_queue
 
-from ... schema import TriplesQueryRequest, TriplesQueryResponse, Value
+from ... schema import TriplesQueryRequest, TriplesQueryResponse
 from ... schema import triples_request_queue
 from ... schema import triples_response_queue
 
@@ -56,6 +59,8 @@ from ... schema import EmbeddingsRequest, EmbeddingsResponse
 from ... schema import embeddings_request_queue
 from ... schema import embeddings_response_queue
 
+from ... schema import document_ingest_queue, text_ingest_queue
+
 logger = logging.getLogger("api")
 logger.setLevel(logging.INFO)
 
@@ -63,13 +68,31 @@ default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
 default_timeout = 600
 default_port = 8088
 
+def to_value(x):
+    if x.startswith("http:") or x.startswith("https:"):
+        return Value(value=x, is_uri=True)
+    else:
+        return Value(value=x, is_uri=False)
+
+def to_subgraph(x):
+    return [
+        Triple(
+            s=to_value(t["s"]),
+            p=to_value(t["p"]),
+            o=to_value(t["o"])
+        )
+        for t in x
+    ]
+
 class Publisher:
 
-    def __init__(self, pulsar_host, topic, schema=None, max_size=10):
+    def __init__(self, pulsar_host, topic, schema=None, max_size=10,
+                 chunking_enabled=False):
         self.pulsar_host = pulsar_host
         self.topic = topic
         self.schema = schema
         self.q = asyncio.Queue(maxsize=max_size)
+        self.chunking_enabled = chunking_enabled
 
     async def run(self):
 
@@ -80,10 +103,16 @@ class Publisher:
                     async with client.create_producer(
                             topic=self.topic,
                             schema=self.schema,
+                            chunking_enabled=self.chunking_enabled,
                     ) as producer:
                         while True:
                             id, item = await self.q.get()
-                            await producer.send(item, { "id": id })
+
+                            if id:
+                                await producer.send(item, { "id": id })
+                            else:
+                                await producer.send(item)
+
 
             except Exception as e:
                 print("Exception:", e, flush=True)
@@ -139,7 +168,10 @@ class Api:
 
     def __init__(self, **config):
 
-        self.app = web.Application(middlewares=[])
+        self.app = web.Application(
+            middlewares=[],
+            client_max_size=256 * 1024 * 1024
+        )
 
         self.port = int(config.get("port", default_port))
         self.timeout = int(config.get("timeout", default_timeout))
@@ -211,6 +243,18 @@ class Api:
             JsonSchema(EmbeddingsResponse)
         )
 
+        self.document_out = Publisher(
+            self.pulsar_host, document_ingest_queue,
+            schema=JsonSchema(Document),
+            chunking_enabled=True,
+        )
+
+        self.text_out = Publisher(
+            self.pulsar_host, text_ingest_queue,
+            schema=JsonSchema(TextDocument),
+            chunking_enabled=True,
+        )
+
         self.app.add_routes([
             web.post("/api/v1/text-completion", self.llm),
             web.post("/api/v1/prompt", self.prompt),
@@ -218,6 +262,8 @@ class Api:
             web.post("/api/v1/triples-query", self.triples_query),
             web.post("/api/v1/agent", self.agent),
             web.post("/api/v1/embeddings", self.embeddings),
+            web.post("/api/v1/load/document", self.load_document),
+            web.post("/api/v1/load/text", self.load_text),
         ])
 
     async def llm(self, request):
@@ -368,26 +414,17 @@ class Api:
             q = await self.triples_query_in.subscribe(id)
 
             if "s" in data:
-                if data["s"].startswith("http:") or data["s"].startswith("https:"):
-                    s = Value(value=data["s"], is_uri=True)
-                else:
-                    s = Value(value=data["s"], is_uri=True)
+                s = to_value(data["s"])
             else:
                 s = None
 
             if "p" in data:
-                if data["p"].startswith("http:") or data["p"].startswith("https:"):
-                    p = Value(value=data["p"], is_uri=True)
-                else:
-                    p = Value(value=data["p"], is_uri=True)
+                p = to_value(data["p"])
             else:
                 p = None
 
             if "o" in data:
-                if data["o"].startswith("http:") or data["o"].startswith("https:"):
-                    o = Value(value=data["o"], is_uri=True)
-                else:
-                    o = Value(value=data["o"], is_uri=True)
+                o = to_value(data["o"])
             else:
                 o = None
 
@@ -537,6 +574,92 @@ class Api:
         finally:
             await self.embeddings_in.unsubscribe(id)
 
+    async def load_document(self, request):
+
+        try:
+
+            data = await request.json()
+
+            if "metadata" in data:
+                metadata = to_subgraph(data["metadata"])
+            else:
+                metadata = []
+
+            # Doing a base64 decode/encode here to make sure the
+            # content is valid base64
+            doc = base64.b64decode(data["data"])
+
+            resp = await self.document_out.send(
+                None,
+                Document(
+                    metadata=Metadata(
+                        id=data.get("id"),
+                        metadata=metadata,
+                        user=data.get("user", "trustgraph"),
+                        collection=data.get("collection", "default"),
+                    ),
+                    data=base64.b64encode(doc).decode("utf-8")
+                )
+            )
+
+            print("Document loaded.")
+
+            return web.json_response(
+                { }
+            )
+
+        except Exception as e:
+            logging.error(f"Exception: {e}")
+
+            return web.json_response(
+                { "error": str(e) }
+            )
+
+    async def load_text(self, request):
+
+        try:
+
+            data = await request.json()
+
+            if "metadata" in data:
+                metadata = to_subgraph(data["metadata"])
+            else:
+                metadata = []
+
+            if "charset" in data:
+                charset = data["charset"]
+            else:
+                charset = "utf-8"
+
+            # Text is base64 encoded
+            text = base64.b64decode(data["text"]).decode(charset)
+
+            resp = await self.text_out.send(
+                None,
+                TextDocument(
+                    metadata=Metadata(
+                        id=data.get("id"),
+                        metadata=metadata,
+                        user=data.get("user", "trustgraph"),
+                        collection=data.get("collection", "default"),
+                    ),
+                    text=text,
+                )
+            )
+
+            print("Text document loaded.")
+
+            return web.json_response(
+                { }
+            )
+
+        except Exception as e:
+            logging.error(f"Exception: {e}")
+
+            return web.json_response(
+                { "error": str(e) }
+            )
+
     async def app_factory(self):
 
         self.llm_pub_task = asyncio.create_task(self.llm_in.run())
@@ -565,6 +688,10 @@ class Api:
             self.embeddings_out.run()
         )
 
+        self.doc_ingest_pub_task = asyncio.create_task(self.document_out.run())
+
+        self.text_ingest_pub_task = asyncio.create_task(self.text_out.run())
+
         return self.app
 
     def run(self):