Feature/general websocket (#199)

* Split API endpoint into endpoint and requestor
* Service/endpoint separation
* Call out to multiple services working
* Add ID field
* Add mux service on websocket, calls out to all services
This commit is contained in:
cybermaggedon 2024-12-06 23:56:10 +00:00 committed by GitHub
parent fd3db3c925
commit 656dcb22a9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 330 additions and 193 deletions

View file

@ -3,20 +3,19 @@ from .. schema import AgentRequest, AgentResponse
from .. schema import agent_request_queue
from .. schema import agent_response_queue
from . endpoint import MultiResponseServiceEndpoint
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class AgentEndpoint(MultiResponseServiceEndpoint):
class AgentRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
super(AgentEndpoint, self).__init__(
super(AgentRequestor, self).__init__(
pulsar_host=pulsar_host,
request_queue=agent_request_queue,
response_queue=agent_response_queue,
request_schema=AgentRequest,
response_schema=AgentResponse,
endpoint_path="/api/v1/agent",
timeout=timeout,
auth=auth,
)
def to_request(self, body):
@ -25,7 +24,19 @@ class AgentEndpoint(MultiResponseServiceEndpoint):
)
def from_response(self, message):
resp = {
}
if message.answer:
return { "answer": message.answer }, True
else:
return {}, False
resp["answer"] = message.answer
if message.thought:
resp["thought"] = message.thought
if message.observation:
resp["observation"] = message.observation
# The 2nd boolean expression indicates whether we're done responding
return resp, (message.answer is not None)

View file

@ -4,19 +4,18 @@ from .. schema import dbpedia_lookup_request_queue
from .. schema import dbpedia_lookup_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class DbpediaEndpoint(ServiceEndpoint):
class DbpediaRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
super(DbpediaEndpoint, self).__init__(
super(DbpediaRequestor, self).__init__(
pulsar_host=pulsar_host,
request_queue=dbpedia_lookup_request_queue,
response_queue=dbpedia_lookup_response_queue,
request_schema=LookupRequest,
response_schema=LookupResponse,
endpoint_path="/api/v1/dbpedia",
timeout=timeout,
auth=auth,
)
def to_request(self, body):
@ -26,5 +25,5 @@ class DbpediaEndpoint(ServiceEndpoint):
)
def from_response(self, message):
return { "text": message.text }
return { "text": message.text }, True

View file

@ -4,19 +4,18 @@ from .. schema import embeddings_request_queue
from .. schema import embeddings_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class EmbeddingsEndpoint(ServiceEndpoint):
class EmbeddingsRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
super(EmbeddingsEndpoint, self).__init__(
super(EmbeddingsRequestor, self).__init__(
pulsar_host=pulsar_host,
request_queue=embeddings_request_queue,
response_queue=embeddings_response_queue,
request_schema=EmbeddingsRequest,
response_schema=EmbeddingsResponse,
endpoint_path="/api/v1/embeddings",
timeout=timeout,
auth=auth,
)
def to_request(self, body):
@ -25,4 +24,6 @@ class EmbeddingsEndpoint(ServiceEndpoint):
)
def from_response(self, message):
return { "vectors": message.vectors }
return { "vectors": message.vectors }, True

View file

@ -4,19 +4,18 @@ from .. schema import encyclopedia_lookup_request_queue
from .. schema import encyclopedia_lookup_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class EncyclopediaEndpoint(ServiceEndpoint):
class EncyclopediaRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
super(EncyclopediaEndpoint, self).__init__(
super(EncyclopediaRequestor, self).__init__(
pulsar_host=pulsar_host,
request_queue=encyclopedia_lookup_request_queue,
response_queue=encyclopedia_lookup_response_queue,
request_schema=LookupRequest,
response_schema=LookupResponse,
endpoint_path="/api/v1/encyclopedia",
timeout=timeout,
auth=auth,
)
def to_request(self, body):
@ -26,5 +25,5 @@ class EncyclopediaEndpoint(ServiceEndpoint):
)
def from_response(self, message):
return { "text": message.text }
return { "text": message.text }, True

View file

@ -13,38 +13,17 @@ logger.setLevel(logging.INFO)
class ServiceEndpoint:
def __init__(
self,
pulsar_host,
request_queue, request_schema,
response_queue, response_schema,
endpoint_path,
auth,
subscription="api-gateway", consumer_name="api-gateway",
timeout=600,
):
self.pub = Publisher(
pulsar_host, request_queue,
schema=JsonSchema(request_schema)
)
self.sub = Subscriber(
pulsar_host, response_queue,
subscription, consumer_name,
JsonSchema(response_schema)
)
def __init__(self, endpoint_path, auth, requestor):
self.path = endpoint_path
self.timeout = timeout
self.auth = auth
self.auth = auth
self.operation = "service"
async def start(self):
self.requestor = requestor
self.pub.start()
self.sub.start()
async def start(self):
await self.requestor.start()
def add_routes(self, app):
@ -52,16 +31,8 @@ class ServiceEndpoint:
web.post(self.path, self.handle),
])
def to_request(self, request):
raise RuntimeError("Not defined")
def from_response(self, response):
raise RuntimeError("Not defined")
async def handle(self, request):
id = str(uuid.uuid4())
print(request.path, "...")
try:
@ -82,28 +53,12 @@ class ServiceEndpoint:
print(data)
q = self.sub.subscribe(id)
def responder(x, fin):
print(x)
await asyncio.to_thread(
self.pub.send, id, self.to_request(data)
)
resp, fin = await self.requestor.process(data, responder)
try:
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
except Exception as e:
raise RuntimeError("Timeout")
print(resp)
if resp.error:
print("Error")
return web.json_response(
{ "error": resp.error.message }
)
return web.json_response(
self.from_response(resp)
)
return web.json_response(resp)
except Exception as e:
logging.error(f"Exception: {e}")
@ -112,55 +67,3 @@ class ServiceEndpoint:
{ "error": str(e) }
)
finally:
self.sub.unsubscribe(id)
class MultiResponseServiceEndpoint(ServiceEndpoint):
async def handle(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = self.sub.subscribe(id)
await asyncio.to_thread(
self.pub.send, id, self.to_request(data)
)
# Keeps looking at responses...
while True:
try:
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
except Exception as e:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
# Until from_response says we have a finished answer
resp, fin = self.from_response(resp)
if fin:
return web.json_response(resp)
# Not finished, so loop round and continue
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
self.sub.unsubscribe(id)

View file

@ -4,19 +4,18 @@ from .. schema import graph_rag_request_queue
from .. schema import graph_rag_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class GraphRagEndpoint(ServiceEndpoint):
class GraphRagRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
super(GraphRagEndpoint, self).__init__(
super(GraphRagRequestor, self).__init__(
pulsar_host=pulsar_host,
request_queue=graph_rag_request_queue,
response_queue=graph_rag_response_queue,
request_schema=GraphRagQuery,
response_schema=GraphRagResponse,
endpoint_path="/api/v1/graph-rag",
timeout=timeout,
auth=auth,
)
def to_request(self, body):
@ -27,5 +26,5 @@ class GraphRagEndpoint(ServiceEndpoint):
)
def from_response(self, message):
return { "response": message.response }
return { "response": message.response }, True

View file

@ -4,19 +4,18 @@ from .. schema import internet_search_request_queue
from .. schema import internet_search_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class InternetSearchEndpoint(ServiceEndpoint):
class InternetSearchRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
super(InternetSearchEndpoint, self).__init__(
super(InternetSearchRequestor, self).__init__(
pulsar_host=pulsar_host,
request_queue=internet_search_request_queue,
response_queue=internet_search_response_queue,
request_schema=LookupRequest,
response_schema=LookupResponse,
endpoint_path="/api/v1/internet-search",
timeout=timeout,
auth=auth,
)
def to_request(self, body):
@ -26,5 +25,5 @@ class InternetSearchEndpoint(ServiceEndpoint):
)
def from_response(self, message):
return { "text": message.text }
return { "text": message.text }, True

View file

@ -0,0 +1,94 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from aiohttp import web, WSMsgType
from . socket import SocketEndpoint
from . text_completion import TextCompletionRequestor
class MuxEndpoint(SocketEndpoint):
def __init__(
self, pulsar_host, auth,
services,
path="/api/v1/mux",
):
super(MuxEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.q = asyncio.Queue(maxsize=10)
self.services = services
async def start(self):
pass
async def async_thread(self, ws, running):
while running.get():
try:
id, svc, request = await asyncio.wait_for(self.q.get(), 1)
except TimeoutError:
continue
except Exception as e:
await ws.send_json({"id": id, "error": str(e)})
try:
print(svc, request)
requestor = self.services[svc]
async def responder(resp, fin):
await ws.send_json({
"id": id,
"response": resp,
"complete": fin,
})
resp = await requestor.process(request, responder)
except Exception as e:
await ws.send_json({"error": str(e)})
running.stop()
async def listener(self, ws, running):
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.ERROR:
break
else:
try:
data = msg.json()
if data["service"] not in self.services:
raise RuntimeError("Bad service")
if "request" not in data:
raise RuntimeError("Bad message")
if "id" not in data:
raise RuntimeError("Bad message")
await self.q.put(
(data["id"], data["service"], data["request"])
)
except Exception as e:
await ws.send_json({"error": str(e)})
continue
running.stop()

View file

@ -6,19 +6,18 @@ from .. schema import prompt_request_queue
from .. schema import prompt_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class PromptEndpoint(ServiceEndpoint):
class PromptRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
super(PromptEndpoint, self).__init__(
super(PromptRequestor, self).__init__(
pulsar_host=pulsar_host,
request_queue=prompt_request_queue,
response_queue=prompt_response_queue,
request_schema=PromptRequest,
response_schema=PromptResponse,
endpoint_path="/api/v1/prompt",
timeout=timeout,
auth=auth,
)
def to_request(self, body):
@ -34,9 +33,9 @@ class PromptEndpoint(ServiceEndpoint):
if message.object:
return {
"object": message.object
}
}, True
else:
return {
"text": message.text
}
}, True

View file

@ -0,0 +1,88 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
import logging
from . publisher import Publisher
from . subscriber import Subscriber
logger = logging.getLogger("requestor")
logger.setLevel(logging.INFO)
class ServiceRequestor:
def __init__(
self,
pulsar_host,
request_queue, request_schema,
response_queue, response_schema,
subscription="api-gateway", consumer_name="api-gateway",
timeout=600,
):
self.pub = Publisher(
pulsar_host, request_queue,
schema=JsonSchema(request_schema)
)
self.sub = Subscriber(
pulsar_host, response_queue,
subscription, consumer_name,
JsonSchema(response_schema)
)
self.timeout = timeout
async def start(self):
self.pub.start()
self.sub.start()
def to_request(self, request):
raise RuntimeError("Not defined")
def from_response(self, response):
raise RuntimeError("Not defined")
async def process(self, request, responder=None):
id = str(uuid.uuid4())
try:
q = self.sub.subscribe(id)
await asyncio.to_thread(
self.pub.send, id, self.to_request(request)
)
while True:
try:
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
except Exception as e:
raise RuntimeError("Timeout")
if resp.error:
return { "error": resp.error.message }
resp, fin = self.from_response(resp)
print(resp, fin)
if responder:
await responder(resp, fin)
if fin:
return resp
except Exception as e:
logging.error(f"Exception: {e}")
return { "error": str(e) }
finally:
self.sub.unsubscribe(id)

View file

@ -31,20 +31,22 @@ from . serialize import to_subgraph
from . running import Running
from . publisher import Publisher
from . subscriber import Subscriber
from . endpoint import ServiceEndpoint, MultiResponseServiceEndpoint
from . text_completion import TextCompletionEndpoint
from . prompt import PromptEndpoint
from . graph_rag import GraphRagEndpoint
from . triples_query import TriplesQueryEndpoint
from . embeddings import EmbeddingsEndpoint
from . encyclopedia import EncyclopediaEndpoint
from . agent import AgentEndpoint
from . dbpedia import DbpediaEndpoint
from . internet_search import InternetSearchEndpoint
from . text_completion import TextCompletionRequestor
from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor
from . triples_query import TriplesQueryRequestor
from . embeddings import EmbeddingsRequestor
from . encyclopedia import EncyclopediaRequestor
from . agent import AgentRequestor
from . dbpedia import DbpediaRequestor
from . internet_search import InternetSearchRequestor
from . triples_stream import TriplesStreamEndpoint
from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
from . triples_load import TriplesLoadEndpoint
from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
from . mux import MuxEndpoint
from . endpoint import ServiceEndpoint
from . auth import Authenticator
logger = logging.getLogger("api")
@ -76,42 +78,81 @@ class Api:
else:
self.auth = Authenticator(allow_all=True)
self.services = {
"text-completion": TextCompletionRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"prompt": PromptRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"graph-rag": GraphRagRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"triples-query": TriplesQueryRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"embeddings": EmbeddingsRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"agent": AgentRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"encyclopedia": EncyclopediaRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"dbpedia": DbpediaRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"internet-search": InternetSearchRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
}
self.endpoints = [
TextCompletionEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
ServiceEndpoint(
endpoint_path = "/api/v1/text-completion", auth=self.auth,
requestor = self.services["text-completion"],
),
PromptEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
ServiceEndpoint(
endpoint_path = "/api/v1/prompt", auth=self.auth,
requestor = self.services["prompt"],
),
GraphRagEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
ServiceEndpoint(
endpoint_path = "/api/v1/graph-rag", auth=self.auth,
requestor = self.services["graph-rag"],
),
TriplesQueryEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
ServiceEndpoint(
endpoint_path = "/api/v1/triples-query", auth=self.auth,
requestor = self.services["triples-query"],
),
EmbeddingsEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
ServiceEndpoint(
endpoint_path = "/api/v1/embeddings", auth=self.auth,
requestor = self.services["embeddings"],
),
AgentEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
ServiceEndpoint(
endpoint_path = "/api/v1/agent", auth=self.auth,
requestor = self.services["agent"],
),
EncyclopediaEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
ServiceEndpoint(
endpoint_path = "/api/v1/encyclopedia", auth=self.auth,
requestor = self.services["encyclopedia"],
),
DbpediaEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
ServiceEndpoint(
endpoint_path = "/api/v1/dbpedia", auth=self.auth,
requestor = self.services["dbpedia"],
),
InternetSearchEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
ServiceEndpoint(
endpoint_path = "/api/v1/internet-search", auth=self.auth,
requestor = self.services["internet-search"],
),
TriplesStreamEndpoint(
pulsar_host=self.pulsar_host,
@ -129,6 +170,11 @@ class Api:
pulsar_host=self.pulsar_host,
auth = self.auth,
),
MuxEndpoint(
pulsar_host=self.pulsar_host,
auth = self.auth,
services = self.services,
),
]
self.document_out = Publisher(
@ -162,7 +208,7 @@ class Api:
else:
metadata = []
# Doing a base64 decode/encode here to make sure the
# Doing a base64 decoe/encode here to make sure the
# content is valid base64
doc = base64.b64decode(data["data"])

View file

@ -4,19 +4,18 @@ from .. schema import text_completion_request_queue
from .. schema import text_completion_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class TextCompletionEndpoint(ServiceEndpoint):
class TextCompletionRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
super(TextCompletionEndpoint, self).__init__(
super(TextCompletionRequestor, self).__init__(
pulsar_host=pulsar_host,
request_queue=text_completion_request_queue,
response_queue=text_completion_response_queue,
request_schema=TextCompletionRequest,
response_schema=TextCompletionResponse,
endpoint_path="/api/v1/text-completion",
timeout=timeout,
auth=auth,
)
def to_request(self, body):
@ -26,4 +25,5 @@ class TextCompletionEndpoint(ServiceEndpoint):
)
def from_response(self, message):
return { "response": message.response }
return { "response": message.response }, True

View file

@ -4,20 +4,19 @@ from .. schema import triples_request_queue
from .. schema import triples_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
from . serialize import to_value, serialize_subgraph
class TriplesQueryEndpoint(ServiceEndpoint):
class TriplesQueryRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
super(TriplesQueryEndpoint, self).__init__(
super(TriplesQueryRequestor, self).__init__(
pulsar_host=pulsar_host,
request_queue=triples_request_queue,
response_queue=triples_response_queue,
request_schema=TriplesQueryRequest,
response_schema=TriplesQueryResponse,
endpoint_path="/api/v1/triples-query",
timeout=timeout,
auth=auth,
)
def to_request(self, body):
@ -50,5 +49,5 @@ class TriplesQueryEndpoint(ServiceEndpoint):
print(message)
return {
"response": serialize_subgraph(message.triples)
}
}, True