diff --git a/trustgraph-flow/trustgraph/gateway/agent.py b/trustgraph-flow/trustgraph/gateway/agent.py index e8fd0e72..c7af947b 100644 --- a/trustgraph-flow/trustgraph/gateway/agent.py +++ b/trustgraph-flow/trustgraph/gateway/agent.py @@ -3,20 +3,19 @@ from .. schema import AgentRequest, AgentResponse from .. schema import agent_request_queue from .. schema import agent_response_queue -from . endpoint import MultiResponseServiceEndpoint +from . endpoint import ServiceEndpoint +from . requestor import ServiceRequestor -class AgentEndpoint(MultiResponseServiceEndpoint): +class AgentRequestor(ServiceRequestor): def __init__(self, pulsar_host, timeout, auth): - super(AgentEndpoint, self).__init__( + super(AgentRequestor, self).__init__( pulsar_host=pulsar_host, request_queue=agent_request_queue, response_queue=agent_response_queue, request_schema=AgentRequest, response_schema=AgentResponse, - endpoint_path="/api/v1/agent", timeout=timeout, - auth=auth, ) def to_request(self, body): @@ -25,7 +24,19 @@ class AgentEndpoint(MultiResponseServiceEndpoint): ) def from_response(self, message): + resp = { + } + if message.answer: - return { "answer": message.answer }, True - else: - return {}, False + resp["answer"] = message.answer + + if message.thought: + resp["thought"] = message.thought + + if message.observation: + resp["observation"] = message.observation + + # The 2nd boolean expression indicates whether we're done responding + return resp, (message.answer is not None) + + diff --git a/trustgraph-flow/trustgraph/gateway/dbpedia.py b/trustgraph-flow/trustgraph/gateway/dbpedia.py index a61292a6..8ae4f695 100644 --- a/trustgraph-flow/trustgraph/gateway/dbpedia.py +++ b/trustgraph-flow/trustgraph/gateway/dbpedia.py @@ -4,19 +4,18 @@ from .. schema import dbpedia_lookup_request_queue from .. schema import dbpedia_lookup_response_queue from . endpoint import ServiceEndpoint +from . 
requestor import ServiceRequestor -class DbpediaEndpoint(ServiceEndpoint): +class DbpediaRequestor(ServiceRequestor): def __init__(self, pulsar_host, timeout, auth): - super(DbpediaEndpoint, self).__init__( + super(DbpediaRequestor, self).__init__( pulsar_host=pulsar_host, request_queue=dbpedia_lookup_request_queue, response_queue=dbpedia_lookup_response_queue, request_schema=LookupRequest, response_schema=LookupResponse, - endpoint_path="/api/v1/dbpedia", timeout=timeout, - auth=auth, ) def to_request(self, body): @@ -26,5 +25,5 @@ class DbpediaEndpoint(ServiceEndpoint): ) def from_response(self, message): - return { "text": message.text } + return { "text": message.text }, True diff --git a/trustgraph-flow/trustgraph/gateway/embeddings.py b/trustgraph-flow/trustgraph/gateway/embeddings.py index 6d3a9fe6..d0f3e1ef 100644 --- a/trustgraph-flow/trustgraph/gateway/embeddings.py +++ b/trustgraph-flow/trustgraph/gateway/embeddings.py @@ -4,19 +4,18 @@ from .. schema import embeddings_request_queue from .. schema import embeddings_response_queue from . endpoint import ServiceEndpoint +from . 
requestor import ServiceRequestor -class EmbeddingsEndpoint(ServiceEndpoint): +class EmbeddingsRequestor(ServiceRequestor): def __init__(self, pulsar_host, timeout, auth): - super(EmbeddingsEndpoint, self).__init__( + super(EmbeddingsRequestor, self).__init__( pulsar_host=pulsar_host, request_queue=embeddings_request_queue, response_queue=embeddings_response_queue, request_schema=EmbeddingsRequest, response_schema=EmbeddingsResponse, - endpoint_path="/api/v1/embeddings", timeout=timeout, - auth=auth, ) def to_request(self, body): @@ -25,4 +24,6 @@ class EmbeddingsEndpoint(ServiceEndpoint): ) def from_response(self, message): - return { "vectors": message.vectors } + return { "vectors": message.vectors }, True + + diff --git a/trustgraph-flow/trustgraph/gateway/encyclopedia.py b/trustgraph-flow/trustgraph/gateway/encyclopedia.py index 32eb5cd1..3f4dad79 100644 --- a/trustgraph-flow/trustgraph/gateway/encyclopedia.py +++ b/trustgraph-flow/trustgraph/gateway/encyclopedia.py @@ -4,19 +4,18 @@ from .. schema import encyclopedia_lookup_request_queue from .. schema import encyclopedia_lookup_response_queue from . endpoint import ServiceEndpoint +from . 
requestor import ServiceRequestor -class EncyclopediaEndpoint(ServiceEndpoint): +class EncyclopediaRequestor(ServiceRequestor): def __init__(self, pulsar_host, timeout, auth): - super(EncyclopediaEndpoint, self).__init__( + super(EncyclopediaRequestor, self).__init__( pulsar_host=pulsar_host, request_queue=encyclopedia_lookup_request_queue, response_queue=encyclopedia_lookup_response_queue, request_schema=LookupRequest, response_schema=LookupResponse, - endpoint_path="/api/v1/encyclopedia", timeout=timeout, - auth=auth, ) def to_request(self, body): @@ -26,5 +25,5 @@ class EncyclopediaEndpoint(ServiceEndpoint): ) def from_response(self, message): - return { "text": message.text } + return { "text": message.text }, True diff --git a/trustgraph-flow/trustgraph/gateway/endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint.py index 2b246361..6d6ca8d5 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint.py @@ -13,38 +13,17 @@ logger.setLevel(logging.INFO) class ServiceEndpoint: - def __init__( - self, - pulsar_host, - request_queue, request_schema, - response_queue, response_schema, - endpoint_path, - auth, - subscription="api-gateway", consumer_name="api-gateway", - timeout=600, - ): - - self.pub = Publisher( - pulsar_host, request_queue, - schema=JsonSchema(request_schema) - ) - - self.sub = Subscriber( - pulsar_host, response_queue, - subscription, consumer_name, - JsonSchema(response_schema) - ) + def __init__(self, endpoint_path, auth, requestor): self.path = endpoint_path - self.timeout = timeout - self.auth = auth + self.auth = auth self.operation = "service" - async def start(self): + self.requestor = requestor - self.pub.start() - self.sub.start() + async def start(self): + await self.requestor.start() def add_routes(self, app): @@ -52,16 +31,8 @@ class ServiceEndpoint: web.post(self.path, self.handle), ]) - def to_request(self, request): - raise RuntimeError("Not defined") - - def from_response(self, 
response): - raise RuntimeError("Not defined") - async def handle(self, request): - id = str(uuid.uuid4()) - print(request.path, "...") try: @@ -82,28 +53,12 @@ class ServiceEndpoint: print(data) - q = self.sub.subscribe(id) + def responder(x, fin): + print(x) - await asyncio.to_thread( - self.pub.send, id, self.to_request(data) - ) + resp, fin = await self.requestor.process(data, responder) - try: - resp = await asyncio.to_thread(q.get, timeout=self.timeout) - except Exception as e: - raise RuntimeError("Timeout") - - print(resp) - - if resp.error: - print("Error") - return web.json_response( - { "error": resp.error.message } - ) - - return web.json_response( - self.from_response(resp) - ) + return web.json_response(resp) except Exception as e: logging.error(f"Exception: {e}") @@ -112,55 +67,3 @@ class ServiceEndpoint: { "error": str(e) } ) - finally: - self.sub.unsubscribe(id) - - -class MultiResponseServiceEndpoint(ServiceEndpoint): - - async def handle(self, request): - - id = str(uuid.uuid4()) - - try: - - data = await request.json() - - q = self.sub.subscribe(id) - - await asyncio.to_thread( - self.pub.send, id, self.to_request(data) - ) - - # Keeps looking at responses... 
- - while True: - - try: - resp = await asyncio.to_thread(q.get, timeout=self.timeout) - except Exception as e: - raise RuntimeError("Timeout waiting for response") - - if resp.error: - return web.json_response( - { "error": resp.error.message } - ) - - # Until from_response says we have a finished answer - resp, fin = self.from_response(resp) - - - if fin: - return web.json_response(resp) - - # Not finished, so loop round and continue - - except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - - finally: - self.sub.unsubscribe(id) diff --git a/trustgraph-flow/trustgraph/gateway/graph_rag.py b/trustgraph-flow/trustgraph/gateway/graph_rag.py index 58679004..55fd5d2f 100644 --- a/trustgraph-flow/trustgraph/gateway/graph_rag.py +++ b/trustgraph-flow/trustgraph/gateway/graph_rag.py @@ -4,19 +4,18 @@ from .. schema import graph_rag_request_queue from .. schema import graph_rag_response_queue from . endpoint import ServiceEndpoint +from . 
requestor import ServiceRequestor -class GraphRagEndpoint(ServiceEndpoint): +class GraphRagRequestor(ServiceRequestor): def __init__(self, pulsar_host, timeout, auth): - super(GraphRagEndpoint, self).__init__( + super(GraphRagRequestor, self).__init__( pulsar_host=pulsar_host, request_queue=graph_rag_request_queue, response_queue=graph_rag_response_queue, request_schema=GraphRagQuery, response_schema=GraphRagResponse, - endpoint_path="/api/v1/graph-rag", timeout=timeout, - auth=auth, ) def to_request(self, body): @@ -27,5 +26,5 @@ class GraphRagEndpoint(ServiceEndpoint): ) def from_response(self, message): - return { "response": message.response } + return { "response": message.response }, True diff --git a/trustgraph-flow/trustgraph/gateway/internet_search.py b/trustgraph-flow/trustgraph/gateway/internet_search.py index 5a5dc948..127cd5d1 100644 --- a/trustgraph-flow/trustgraph/gateway/internet_search.py +++ b/trustgraph-flow/trustgraph/gateway/internet_search.py @@ -4,19 +4,18 @@ from .. schema import internet_search_request_queue from .. schema import internet_search_response_queue from . endpoint import ServiceEndpoint +from . 
requestor import ServiceRequestor -class InternetSearchEndpoint(ServiceEndpoint): +class InternetSearchRequestor(ServiceRequestor): def __init__(self, pulsar_host, timeout, auth): - super(InternetSearchEndpoint, self).__init__( + super(InternetSearchRequestor, self).__init__( pulsar_host=pulsar_host, request_queue=internet_search_request_queue, response_queue=internet_search_response_queue, request_schema=LookupRequest, response_schema=LookupResponse, - endpoint_path="/api/v1/internet-search", timeout=timeout, - auth=auth, ) def to_request(self, body): @@ -26,5 +25,5 @@ class InternetSearchEndpoint(ServiceEndpoint): ) def from_response(self, message): - return { "text": message.text } + return { "text": message.text }, True diff --git a/trustgraph-flow/trustgraph/gateway/mux.py b/trustgraph-flow/trustgraph/gateway/mux.py new file mode 100644 index 00000000..cd5ddfba --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/mux.py @@ -0,0 +1,94 @@ + +import asyncio +import queue +from pulsar.schema import JsonSchema +import uuid +from aiohttp import web, WSMsgType + +from . socket import SocketEndpoint +from . 
text_completion import TextCompletionRequestor + +class MuxEndpoint(SocketEndpoint): + + def __init__( + self, pulsar_host, auth, + services, + path="/api/v1/mux", + ): + + super(MuxEndpoint, self).__init__( + endpoint_path=path, auth=auth, + ) + + self.q = asyncio.Queue(maxsize=10) + + self.services = services + + async def start(self): + pass + + async def async_thread(self, ws, running): + + while running.get(): + + try: + id, svc, request = await asyncio.wait_for(self.q.get(), 1) + except TimeoutError: + continue + except Exception as e: + await ws.send_json({"id": id, "error": str(e)}) + + try: + + print(svc, request) + + requestor = self.services[svc] + + async def responder(resp, fin): + await ws.send_json({ + "id": id, + "response": resp, + "complete": fin, + }) + + resp = await requestor.process(request, responder) + + except Exception as e: + + await ws.send_json({"error": str(e)}) + + running.stop() + + async def listener(self, ws, running): + + async for msg in ws: + + # On error, finish + if msg.type == WSMsgType.ERROR: + break + else: + + try: + + data = msg.json() + + if data["service"] not in self.services: + raise RuntimeError("Bad service") + + if "request" not in data: + raise RuntimeError("Bad message") + + if "id" not in data: + raise RuntimeError("Bad message") + + await self.q.put( + (data["id"], data["service"], data["request"]) + ) + + except Exception as e: + + await ws.send_json({"error": str(e)}) + continue + + running.stop() + diff --git a/trustgraph-flow/trustgraph/gateway/prompt.py b/trustgraph-flow/trustgraph/gateway/prompt.py index f09a0e0e..080d5618 100644 --- a/trustgraph-flow/trustgraph/gateway/prompt.py +++ b/trustgraph-flow/trustgraph/gateway/prompt.py @@ -6,19 +6,18 @@ from .. schema import prompt_request_queue from .. schema import prompt_response_queue from . endpoint import ServiceEndpoint +from . 
requestor import ServiceRequestor -class PromptEndpoint(ServiceEndpoint): +class PromptRequestor(ServiceRequestor): def __init__(self, pulsar_host, timeout, auth): - super(PromptEndpoint, self).__init__( + super(PromptRequestor, self).__init__( pulsar_host=pulsar_host, request_queue=prompt_request_queue, response_queue=prompt_response_queue, request_schema=PromptRequest, response_schema=PromptResponse, - endpoint_path="/api/v1/prompt", timeout=timeout, - auth=auth, ) def to_request(self, body): @@ -34,9 +33,9 @@ class PromptEndpoint(ServiceEndpoint): if message.object: return { "object": message.object - } + }, True else: return { "text": message.text - } + }, True diff --git a/trustgraph-flow/trustgraph/gateway/requestor.py b/trustgraph-flow/trustgraph/gateway/requestor.py new file mode 100644 index 00000000..5f6e2692 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/requestor.py @@ -0,0 +1,88 @@ + +import asyncio +from pulsar.schema import JsonSchema +import uuid +import logging + +from . publisher import Publisher +from . 
subscriber import Subscriber + +logger = logging.getLogger("requestor") +logger.setLevel(logging.INFO) + +class ServiceRequestor: + + def __init__( + self, + pulsar_host, + request_queue, request_schema, + response_queue, response_schema, + subscription="api-gateway", consumer_name="api-gateway", + timeout=600, + ): + + self.pub = Publisher( + pulsar_host, request_queue, + schema=JsonSchema(request_schema) + ) + + self.sub = Subscriber( + pulsar_host, response_queue, + subscription, consumer_name, + JsonSchema(response_schema) + ) + + self.timeout = timeout + + async def start(self): + + self.pub.start() + self.sub.start() + + def to_request(self, request): + raise RuntimeError("Not defined") + + def from_response(self, response): + raise RuntimeError("Not defined") + + async def process(self, request, responder=None): + + id = str(uuid.uuid4()) + + try: + + q = self.sub.subscribe(id) + + await asyncio.to_thread( + self.pub.send, id, self.to_request(request) + ) + + while True: + + try: + resp = await asyncio.to_thread(q.get, timeout=self.timeout) + except Exception as e: + raise RuntimeError("Timeout") + + if resp.error: + return { "error": resp.error.message } + + resp, fin = self.from_response(resp) + + print(resp, fin) + + if responder: + await responder(resp, fin) + + if fin: + return resp + + except Exception as e: + + logging.error(f"Exception: {e}") + + return { "error": str(e) } + + finally: + self.sub.unsubscribe(id) + diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index e927ecf6..6a8a62eb 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -31,20 +31,22 @@ from . serialize import to_subgraph from . running import Running from . publisher import Publisher from . subscriber import Subscriber -from . endpoint import ServiceEndpoint, MultiResponseServiceEndpoint -from . text_completion import TextCompletionEndpoint -from . 
prompt import PromptEndpoint -from . graph_rag import GraphRagEndpoint -from . triples_query import TriplesQueryEndpoint -from . embeddings import EmbeddingsEndpoint -from . encyclopedia import EncyclopediaEndpoint -from . agent import AgentEndpoint -from . dbpedia import DbpediaEndpoint -from . internet_search import InternetSearchEndpoint +from . text_completion import TextCompletionRequestor +from . prompt import PromptRequestor +from . graph_rag import GraphRagRequestor +from . triples_query import TriplesQueryRequestor +from . embeddings import EmbeddingsRequestor +from . encyclopedia import EncyclopediaRequestor +from . agent import AgentRequestor +from . dbpedia import DbpediaRequestor +from . internet_search import InternetSearchRequestor from . triples_stream import TriplesStreamEndpoint from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint from . triples_load import TriplesLoadEndpoint from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint +from . mux import MuxEndpoint + +from . endpoint import ServiceEndpoint from . 
auth import Authenticator logger = logging.getLogger("api") @@ -76,42 +78,81 @@ class Api: else: self.auth = Authenticator(allow_all=True) + self.services = { + "text-completion": TextCompletionRequestor( + pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, + ), + "prompt": PromptRequestor( + pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, + ), + "graph-rag": GraphRagRequestor( + pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, + ), + "triples-query": TriplesQueryRequestor( + pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, + ), + "embeddings": EmbeddingsRequestor( + pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, + ), + "agent": AgentRequestor( + pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, + ), + "encyclopedia": EncyclopediaRequestor( + pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, + ), + "dbpedia": DbpediaRequestor( + pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, + ), + "internet-search": InternetSearchRequestor( + pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, + ), + } + self.endpoints = [ - TextCompletionEndpoint( - pulsar_host=self.pulsar_host, timeout=self.timeout, - auth = self.auth, + ServiceEndpoint( + endpoint_path = "/api/v1/text-completion", auth=self.auth, + requestor = self.services["text-completion"], ), - PromptEndpoint( - pulsar_host=self.pulsar_host, timeout=self.timeout, - auth = self.auth, + ServiceEndpoint( + endpoint_path = "/api/v1/prompt", auth=self.auth, + requestor = self.services["prompt"], ), - GraphRagEndpoint( - pulsar_host=self.pulsar_host, timeout=self.timeout, - auth = self.auth, + ServiceEndpoint( + endpoint_path = "/api/v1/graph-rag", auth=self.auth, + requestor = self.services["graph-rag"], ), - TriplesQueryEndpoint( - pulsar_host=self.pulsar_host, timeout=self.timeout, - auth = self.auth, + ServiceEndpoint( + 
endpoint_path = "/api/v1/triples-query", auth=self.auth, + requestor = self.services["triples-query"], ), - EmbeddingsEndpoint( - pulsar_host=self.pulsar_host, timeout=self.timeout, - auth = self.auth, + ServiceEndpoint( + endpoint_path = "/api/v1/embeddings", auth=self.auth, + requestor = self.services["embeddings"], ), - AgentEndpoint( - pulsar_host=self.pulsar_host, timeout=self.timeout, - auth = self.auth, + ServiceEndpoint( + endpoint_path = "/api/v1/agent", auth=self.auth, + requestor = self.services["agent"], ), - EncyclopediaEndpoint( - pulsar_host=self.pulsar_host, timeout=self.timeout, - auth = self.auth, + ServiceEndpoint( + endpoint_path = "/api/v1/encyclopedia", auth=self.auth, + requestor = self.services["encyclopedia"], ), - DbpediaEndpoint( - pulsar_host=self.pulsar_host, timeout=self.timeout, - auth = self.auth, + ServiceEndpoint( + endpoint_path = "/api/v1/dbpedia", auth=self.auth, + requestor = self.services["dbpedia"], ), - InternetSearchEndpoint( - pulsar_host=self.pulsar_host, timeout=self.timeout, - auth = self.auth, + ServiceEndpoint( + endpoint_path = "/api/v1/internet-search", auth=self.auth, + requestor = self.services["internet-search"], ), TriplesStreamEndpoint( pulsar_host=self.pulsar_host, @@ -129,6 +170,11 @@ class Api: pulsar_host=self.pulsar_host, auth = self.auth, ), + MuxEndpoint( + pulsar_host=self.pulsar_host, + auth = self.auth, + services = self.services, + ), ] self.document_out = Publisher( @@ -162,7 +208,7 @@ class Api: else: metadata = [] - # Doing a base64 decode/encode here to make sure the + # Doing a base64 decode/encode here to make sure the # content is valid base64 doc = base64.b64decode(data["data"]) diff --git a/trustgraph-flow/trustgraph/gateway/text_completion.py b/trustgraph-flow/trustgraph/gateway/text_completion.py index d59737f0..7291fc88 100644 --- a/trustgraph-flow/trustgraph/gateway/text_completion.py +++ b/trustgraph-flow/trustgraph/gateway/text_completion.py @@ -4,19 +4,18 @@ from .. 
schema import text_completion_request_queue from .. schema import text_completion_response_queue from . endpoint import ServiceEndpoint +from . requestor import ServiceRequestor -class TextCompletionEndpoint(ServiceEndpoint): +class TextCompletionRequestor(ServiceRequestor): def __init__(self, pulsar_host, timeout, auth): - super(TextCompletionEndpoint, self).__init__( + super(TextCompletionRequestor, self).__init__( pulsar_host=pulsar_host, request_queue=text_completion_request_queue, response_queue=text_completion_response_queue, request_schema=TextCompletionRequest, response_schema=TextCompletionResponse, - endpoint_path="/api/v1/text-completion", timeout=timeout, - auth=auth, ) def to_request(self, body): @@ -26,4 +25,5 @@ class TextCompletionEndpoint(ServiceEndpoint): ) def from_response(self, message): - return { "response": message.response } + return { "response": message.response }, True + diff --git a/trustgraph-flow/trustgraph/gateway/triples_query.py b/trustgraph-flow/trustgraph/gateway/triples_query.py index 5a0cfff8..0ea7cd8d 100644 --- a/trustgraph-flow/trustgraph/gateway/triples_query.py +++ b/trustgraph-flow/trustgraph/gateway/triples_query.py @@ -4,20 +4,19 @@ from .. schema import triples_request_queue from .. schema import triples_response_queue from . endpoint import ServiceEndpoint +from . requestor import ServiceRequestor from . 
serialize import to_value, serialize_subgraph -class TriplesQueryEndpoint(ServiceEndpoint): +class TriplesQueryRequestor(ServiceRequestor): def __init__(self, pulsar_host, timeout, auth): - super(TriplesQueryEndpoint, self).__init__( + super(TriplesQueryRequestor, self).__init__( pulsar_host=pulsar_host, request_queue=triples_request_queue, response_queue=triples_response_queue, request_schema=TriplesQueryRequest, response_schema=TriplesQueryResponse, - endpoint_path="/api/v1/triples-query", timeout=timeout, - auth=auth, ) def to_request(self, body): @@ -50,5 +49,5 @@ class TriplesQueryEndpoint(ServiceEndpoint): print(message) return { "response": serialize_subgraph(message.triples) - } + }, True