mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-29 02:23:44 +02:00
API supports doc & text load (#167)
This commit is contained in:
parent
a1e0edd96f
commit
dc0f54f236
2 changed files with 143 additions and 17 deletions
|
|
@ -6,7 +6,6 @@ Loads a text document into TrustGraph processing.
|
||||||
|
|
||||||
import pulsar
|
import pulsar
|
||||||
from pulsar.schema import JsonSchema
|
from pulsar.schema import JsonSchema
|
||||||
import base64
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
import os
|
import os
|
||||||
|
import base64
|
||||||
|
|
||||||
import pulsar
|
import pulsar
|
||||||
from pulsar.asyncio import Client
|
from pulsar.asyncio import Client
|
||||||
|
|
@ -32,6 +33,8 @@ from ... log_level import LogLevel
|
||||||
from trustgraph.clients.llm_client import LlmClient
|
from trustgraph.clients.llm_client import LlmClient
|
||||||
from trustgraph.clients.prompt_client import PromptClient
|
from trustgraph.clients.prompt_client import PromptClient
|
||||||
|
|
||||||
|
from ... schema import Value, Metadata, Document, TextDocument, Triple
|
||||||
|
|
||||||
from ... schema import TextCompletionRequest, TextCompletionResponse
|
from ... schema import TextCompletionRequest, TextCompletionResponse
|
||||||
from ... schema import text_completion_request_queue
|
from ... schema import text_completion_request_queue
|
||||||
from ... schema import text_completion_response_queue
|
from ... schema import text_completion_response_queue
|
||||||
|
|
@ -44,7 +47,7 @@ from ... schema import GraphRagQuery, GraphRagResponse
|
||||||
from ... schema import graph_rag_request_queue
|
from ... schema import graph_rag_request_queue
|
||||||
from ... schema import graph_rag_response_queue
|
from ... schema import graph_rag_response_queue
|
||||||
|
|
||||||
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Value
|
from ... schema import TriplesQueryRequest, TriplesQueryResponse
|
||||||
from ... schema import triples_request_queue
|
from ... schema import triples_request_queue
|
||||||
from ... schema import triples_response_queue
|
from ... schema import triples_response_queue
|
||||||
|
|
||||||
|
|
@ -56,6 +59,8 @@ from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
||||||
from ... schema import embeddings_request_queue
|
from ... schema import embeddings_request_queue
|
||||||
from ... schema import embeddings_response_queue
|
from ... schema import embeddings_response_queue
|
||||||
|
|
||||||
|
from ... schema import document_ingest_queue, text_ingest_queue
|
||||||
|
|
||||||
logger = logging.getLogger("api")
|
logger = logging.getLogger("api")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
@ -63,13 +68,31 @@ default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
|
||||||
default_timeout = 600
|
default_timeout = 600
|
||||||
default_port = 8088
|
default_port = 8088
|
||||||
|
|
||||||
|
def to_value(x):
|
||||||
|
if x.startswith("http:") or x.startswith("https:"):
|
||||||
|
return Value(value=x, is_uri=True)
|
||||||
|
else:
|
||||||
|
return Value(value=x, is_uri=True)
|
||||||
|
|
||||||
|
def to_subgraph(x):
|
||||||
|
return [
|
||||||
|
Triple(
|
||||||
|
s=to_value(t["s"]),
|
||||||
|
p=to_value(t["p"]),
|
||||||
|
o=to_value(t["o"])
|
||||||
|
)
|
||||||
|
for t in x
|
||||||
|
]
|
||||||
|
|
||||||
class Publisher:
|
class Publisher:
|
||||||
|
|
||||||
def __init__(self, pulsar_host, topic, schema=None, max_size=10):
|
def __init__(self, pulsar_host, topic, schema=None, max_size=10,
|
||||||
|
chunking_enabled=False):
|
||||||
self.pulsar_host = pulsar_host
|
self.pulsar_host = pulsar_host
|
||||||
self.topic = topic
|
self.topic = topic
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.q = asyncio.Queue(maxsize=max_size)
|
self.q = asyncio.Queue(maxsize=max_size)
|
||||||
|
self.chunking_enabled = chunking_enabled
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
|
||||||
|
|
@ -80,10 +103,16 @@ class Publisher:
|
||||||
async with client.create_producer(
|
async with client.create_producer(
|
||||||
topic=self.topic,
|
topic=self.topic,
|
||||||
schema=self.schema,
|
schema=self.schema,
|
||||||
|
chunking_enabled=self.chunking_enabled,
|
||||||
) as producer:
|
) as producer:
|
||||||
while True:
|
while True:
|
||||||
id, item = await self.q.get()
|
id, item = await self.q.get()
|
||||||
await producer.send(item, { "id": id })
|
|
||||||
|
if id:
|
||||||
|
await producer.send(item, { "id": id })
|
||||||
|
else:
|
||||||
|
await producer.send(item)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Exception:", e, flush=True)
|
print("Exception:", e, flush=True)
|
||||||
|
|
||||||
|
|
@ -139,7 +168,10 @@ class Api:
|
||||||
|
|
||||||
def __init__(self, **config):
|
def __init__(self, **config):
|
||||||
|
|
||||||
self.app = web.Application(middlewares=[])
|
self.app = web.Application(
|
||||||
|
middlewares=[],
|
||||||
|
client_max_size=256 * 1024 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
self.port = int(config.get("port", default_port))
|
self.port = int(config.get("port", default_port))
|
||||||
self.timeout = int(config.get("timeout", default_timeout))
|
self.timeout = int(config.get("timeout", default_timeout))
|
||||||
|
|
@ -211,6 +243,18 @@ class Api:
|
||||||
JsonSchema(EmbeddingsResponse)
|
JsonSchema(EmbeddingsResponse)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.document_out = Publisher(
|
||||||
|
self.pulsar_host, document_ingest_queue,
|
||||||
|
schema=JsonSchema(Document),
|
||||||
|
chunking_enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.text_out = Publisher(
|
||||||
|
self.pulsar_host, text_ingest_queue,
|
||||||
|
schema=JsonSchema(TextDocument),
|
||||||
|
chunking_enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
self.app.add_routes([
|
self.app.add_routes([
|
||||||
web.post("/api/v1/text-completion", self.llm),
|
web.post("/api/v1/text-completion", self.llm),
|
||||||
web.post("/api/v1/prompt", self.prompt),
|
web.post("/api/v1/prompt", self.prompt),
|
||||||
|
|
@ -218,6 +262,8 @@ class Api:
|
||||||
web.post("/api/v1/triples-query", self.triples_query),
|
web.post("/api/v1/triples-query", self.triples_query),
|
||||||
web.post("/api/v1/agent", self.agent),
|
web.post("/api/v1/agent", self.agent),
|
||||||
web.post("/api/v1/embeddings", self.embeddings),
|
web.post("/api/v1/embeddings", self.embeddings),
|
||||||
|
web.post("/api/v1/load/document", self.load_document),
|
||||||
|
web.post("/api/v1/load/text", self.load_text),
|
||||||
])
|
])
|
||||||
|
|
||||||
async def llm(self, request):
|
async def llm(self, request):
|
||||||
|
|
@ -368,26 +414,17 @@ class Api:
|
||||||
q = await self.triples_query_in.subscribe(id)
|
q = await self.triples_query_in.subscribe(id)
|
||||||
|
|
||||||
if "s" in data:
|
if "s" in data:
|
||||||
if data["s"].startswith("http:") or data["s"].startswith("https:"):
|
s = to_value(data["s"])
|
||||||
s = Value(value=data["s"], is_uri=True)
|
|
||||||
else:
|
|
||||||
s = Value(value=data["s"], is_uri=True)
|
|
||||||
else:
|
else:
|
||||||
s = None
|
s = None
|
||||||
|
|
||||||
if "p" in data:
|
if "p" in data:
|
||||||
if data["p"].startswith("http:") or data["p"].startswith("https:"):
|
p = to_value(data["p"])
|
||||||
p = Value(value=data["p"], is_uri=True)
|
|
||||||
else:
|
|
||||||
p = Value(value=data["p"], is_uri=True)
|
|
||||||
else:
|
else:
|
||||||
p = None
|
p = None
|
||||||
|
|
||||||
if "o" in data:
|
if "o" in data:
|
||||||
if data["o"].startswith("http:") or data["o"].startswith("https:"):
|
o = to_value(data["o"])
|
||||||
o = Value(value=data["o"], is_uri=True)
|
|
||||||
else:
|
|
||||||
o = Value(value=data["o"], is_uri=True)
|
|
||||||
else:
|
else:
|
||||||
o = None
|
o = None
|
||||||
|
|
||||||
|
|
@ -537,6 +574,92 @@ class Api:
|
||||||
finally:
|
finally:
|
||||||
await self.embeddings_in.unsubscribe(id)
|
await self.embeddings_in.unsubscribe(id)
|
||||||
|
|
||||||
|
async def load_document(self, request):
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
data = await request.json()
|
||||||
|
|
||||||
|
if "metadata" in data:
|
||||||
|
metadata = to_subgraph(data["metadata"])
|
||||||
|
else:
|
||||||
|
metadata = []
|
||||||
|
|
||||||
|
# Doing a base64 decode/encode here to make sure the
|
||||||
|
# content is valid base64
|
||||||
|
doc = base64.b64decode(data["data"])
|
||||||
|
|
||||||
|
resp = await self.document_out.send(
|
||||||
|
None,
|
||||||
|
Document(
|
||||||
|
metadata=Metadata(
|
||||||
|
id=data.get("id"),
|
||||||
|
metadata=metadata,
|
||||||
|
user=data.get("user", "trustgraph"),
|
||||||
|
collection=data.get("collection", "default"),
|
||||||
|
),
|
||||||
|
data=base64.b64encode(doc).decode("utf-8")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Document loaded.")
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{ }
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Exception: {e}")
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{ "error": str(e) }
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_text(self, request):
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
data = await request.json()
|
||||||
|
|
||||||
|
if "metadata" in data:
|
||||||
|
metadata = to_subgraph(data["metadata"])
|
||||||
|
else:
|
||||||
|
metadata = []
|
||||||
|
|
||||||
|
if "charset" in data:
|
||||||
|
charset = data["charset"]
|
||||||
|
else:
|
||||||
|
charset = "utf-8"
|
||||||
|
|
||||||
|
# Text is base64 encoded
|
||||||
|
text = base64.b64decode(data["text"]).decode(charset)
|
||||||
|
|
||||||
|
resp = await self.text_out.send(
|
||||||
|
None,
|
||||||
|
TextDocument(
|
||||||
|
metadata=Metadata(
|
||||||
|
id=data.get("id"),
|
||||||
|
metadata=metadata,
|
||||||
|
user=data.get("user", "trustgraph"),
|
||||||
|
collection=data.get("collection", "default"),
|
||||||
|
),
|
||||||
|
text=text,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Text document loaded.")
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{ }
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Exception: {e}")
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{ "error": str(e) }
|
||||||
|
)
|
||||||
|
|
||||||
async def app_factory(self):
|
async def app_factory(self):
|
||||||
|
|
||||||
self.llm_pub_task = asyncio.create_task(self.llm_in.run())
|
self.llm_pub_task = asyncio.create_task(self.llm_in.run())
|
||||||
|
|
@ -565,6 +688,10 @@ class Api:
|
||||||
self.embeddings_out.run()
|
self.embeddings_out.run()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.doc_ingest_pub_task = asyncio.create_task(self.document_out.run())
|
||||||
|
|
||||||
|
self.text_ingest_pub_task = asyncio.create_task(self.text_out.run())
|
||||||
|
|
||||||
return self.app
|
return self.app
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue