Feature/gateway auth (#186)

* Added auth module, just a simple token at this stage
* Pass auth token GATEWAY_SECRET through
* Auth token not mandatory, can be provided in env var
This commit is contained in:
cybermaggedon 2024-12-02 19:57:21 +00:00 committed by GitHub
parent 6d200c79c5
commit 1b9c6be4fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 126 additions and 33 deletions

View file

@ -6,7 +6,7 @@ from ... schema import agent_response_queue
from . endpoint import MultiResponseServiceEndpoint
class AgentEndpoint(MultiResponseServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):
super(AgentEndpoint, self).__init__(
pulsar_host=pulsar_host,
@ -16,6 +16,7 @@ class AgentEndpoint(MultiResponseServiceEndpoint):
response_schema=AgentResponse,
endpoint_path="/api/v1/agent",
timeout=timeout,
auth=auth,
)
def to_request(self, body):

View file

@ -0,0 +1,22 @@
class Authenticator:
    """Gate requests with a single shared secret token.

    An instance is either open (``allow_all=True``: every request is
    permitted) or holds one secret token that callers must present.
    """

    def __init__(self, token=None, allow_all=False):
        """Configure the authenticator.

        Args:
            token: The shared secret callers must present.  Required
                (and must not be the empty string) unless ``allow_all``
                is true.
            allow_all: When true, authentication is disabled and every
                request is permitted.

        Raises:
            RuntimeError: If ``allow_all`` is false and ``token`` is
                ``None`` or the empty string.
        """
        if not allow_all and (token is None or token == ""):
            raise RuntimeError("Need a token")

        self.token = token
        self.allow_all = allow_all

    def permitted(self, token, roles):
        """Return True if a request presenting ``token`` is allowed.

        Args:
            token: The token supplied by the caller.
            roles: Accepted for interface compatibility; not consulted
                by this implementation.

        Returns:
            bool: True when authentication is disabled or the supplied
            token matches the configured one, otherwise False.
        """
        # Function-scope import keeps this block self-contained.
        import hmac

        if self.allow_all:
            return True

        expected = self.token

        if isinstance(expected, str) and isinstance(token, str):
            # Constant-time comparison so response timing does not leak
            # how much of the secret a guess got right.  Compare bytes,
            # since compare_digest rejects non-ASCII str; surrogatepass
            # keeps oddly-decoded input from raising.
            return hmac.compare_digest(
                expected.encode("utf-8", "surrogatepass"),
                token.encode("utf-8", "surrogatepass"),
            )

        # Non-string values (e.g. None): plain equality, as before.
        return expected == token

View file

@ -6,7 +6,7 @@ from ... schema import dbpedia_lookup_response_queue
from . endpoint import ServiceEndpoint
class DbpediaEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):
super(DbpediaEndpoint, self).__init__(
pulsar_host=pulsar_host,
@ -16,6 +16,7 @@ class DbpediaEndpoint(ServiceEndpoint):
response_schema=LookupResponse,
endpoint_path="/api/v1/dbpedia",
timeout=timeout,
auth=auth,
)
def to_request(self, body):

View file

@ -6,7 +6,7 @@ from ... schema import embeddings_response_queue
from . endpoint import ServiceEndpoint
class EmbeddingsEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):
super(EmbeddingsEndpoint, self).__init__(
pulsar_host=pulsar_host,
@ -16,6 +16,7 @@ class EmbeddingsEndpoint(ServiceEndpoint):
response_schema=EmbeddingsResponse,
endpoint_path="/api/v1/embeddings",
timeout=timeout,
auth=auth,
)
def to_request(self, body):

View file

@ -6,7 +6,7 @@ from ... schema import encyclopedia_lookup_response_queue
from . endpoint import ServiceEndpoint
class EncyclopediaEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):
super(EncyclopediaEndpoint, self).__init__(
pulsar_host=pulsar_host,
@ -16,6 +16,7 @@ class EncyclopediaEndpoint(ServiceEndpoint):
response_schema=LookupResponse,
endpoint_path="/api/v1/encyclopedia",
timeout=timeout,
auth=auth,
)
def to_request(self, body):

View file

@ -19,6 +19,7 @@ class ServiceEndpoint:
request_queue, request_schema,
response_queue, response_schema,
endpoint_path,
auth,
subscription="api-gateway", consumer_name="api-gateway",
timeout=600,
):
@ -36,6 +37,9 @@ class ServiceEndpoint:
self.path = endpoint_path
self.timeout = timeout
self.auth = auth
self.operation = "service"
async def start(self):
@ -58,14 +62,24 @@ class ServiceEndpoint:
id = str(uuid.uuid4())
try:
ht = request.headers["Authorization"]
tokens = ht.split(" ", 2)
if tokens[0] != "Bearer":
return web.HTTPUnauthorized()
token = tokens[1]
except:
token = ""
if not self.auth.permitted(token, self.operation):
return web.HTTPUnauthorized()
try:
data = await request.json()
q = await self.sub.subscribe(id)
print(data)
await self.pub.send(
id,
self.to_request(data),
@ -76,8 +90,6 @@ class ServiceEndpoint:
except:
raise RuntimeError("Timeout waiting for response")
print(resp)
if resp.error:
return web.json_response(
{ "error": resp.error.message }
@ -110,8 +122,6 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
q = await self.sub.subscribe(id)
print(data)
await self.pub.send(
id,
self.to_request(data),
@ -126,8 +136,6 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
except:
raise RuntimeError("Timeout waiting for response")
print(resp)
if resp.error:
return web.json_response(
{ "error": resp.error.message }

View file

@ -14,10 +14,12 @@ from . serialize import to_subgraph, to_value
class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
def __init__(self, pulsar_host, path="/api/v1/load/graph-embeddings"):
def __init__(
self, pulsar_host, auth, path="/api/v1/load/graph-embeddings",
):
super(GraphEmbeddingsLoadEndpoint, self).__init__(
endpoint_path=path
endpoint_path=path, auth=auth,
)
self.pulsar_host=pulsar_host

View file

@ -12,10 +12,12 @@ from . serialize import serialize_graph_embeddings
class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
def __init__(self, pulsar_host, path="/api/v1/stream/graph-embeddings"):
def __init__(
self, pulsar_host, auth, path="/api/v1/stream/graph-embeddings"
):
super(GraphEmbeddingsStreamEndpoint, self).__init__(
endpoint_path=path
endpoint_path=path, auth=auth,
)
self.pulsar_host=pulsar_host

View file

@ -6,7 +6,7 @@ from ... schema import graph_rag_response_queue
from . endpoint import ServiceEndpoint
class GraphRagEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):
super(GraphRagEndpoint, self).__init__(
pulsar_host=pulsar_host,
@ -16,6 +16,7 @@ class GraphRagEndpoint(ServiceEndpoint):
response_schema=GraphRagResponse,
endpoint_path="/api/v1/graph-rag",
timeout=timeout,
auth=auth,
)
def to_request(self, body):

View file

@ -6,7 +6,7 @@ from ... schema import internet_search_response_queue
from . endpoint import ServiceEndpoint
class InternetSearchEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):
super(InternetSearchEndpoint, self).__init__(
pulsar_host=pulsar_host,
@ -16,6 +16,7 @@ class InternetSearchEndpoint(ServiceEndpoint):
response_schema=LookupResponse,
endpoint_path="/api/v1/internet-search",
timeout=timeout,
auth=auth,
)
def to_request(self, body):

View file

@ -8,7 +8,7 @@ from ... schema import prompt_response_queue
from . endpoint import ServiceEndpoint
class PromptEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):
super(PromptEndpoint, self).__init__(
pulsar_host=pulsar_host,
@ -18,6 +18,7 @@ class PromptEndpoint(ServiceEndpoint):
response_schema=PromptResponse,
endpoint_path="/api/v1/prompt",
timeout=timeout,
auth=auth,
)
def to_request(self, body):

View file

@ -45,6 +45,7 @@ from . triples_stream import TriplesStreamEndpoint
from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
from . triples_load import TriplesLoadEndpoint
from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
from . auth import Authenticator
logger = logging.getLogger("api")
logger.setLevel(logging.INFO)
@ -52,6 +53,7 @@ logger.setLevel(logging.INFO)
default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
default_timeout = 600
default_port = 8088
default_api_token = os.getenv("GATEWAY_SECRET", "")
class Api:
@ -66,45 +68,66 @@ class Api:
self.timeout = int(config.get("timeout", default_timeout))
self.pulsar_host = config.get("pulsar_host", default_pulsar_host)
api_token = config.get("api_token", default_api_token)
# A token that is unset or equal to the empty string means no auth
if api_token:
self.auth = Authenticator(token=api_token)
else:
self.auth = Authenticator(allow_all=True)
self.endpoints = [
TextCompletionEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
PromptEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
GraphRagEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
TriplesQueryEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
EmbeddingsEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
AgentEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
EncyclopediaEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
DbpediaEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
InternetSearchEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
TriplesStreamEndpoint(
pulsar_host=self.pulsar_host
pulsar_host=self.pulsar_host,
auth = self.auth,
),
GraphEmbeddingsStreamEndpoint(
pulsar_host=self.pulsar_host
pulsar_host=self.pulsar_host,
auth = self.auth,
),
TriplesLoadEndpoint(
pulsar_host=self.pulsar_host
pulsar_host=self.pulsar_host,
auth = self.auth,
),
GraphEmbeddingsLoadEndpoint(
pulsar_host=self.pulsar_host
pulsar_host=self.pulsar_host,
auth = self.auth,
),
]
@ -254,6 +277,12 @@ def run():
help=f'API request timeout in seconds (default: {default_timeout})',
)
parser.add_argument(
'--api-token',
default=default_api_token,
help=f'Secret API token (default: no auth)',
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,

View file

@ -11,11 +11,12 @@ logger.setLevel(logging.INFO)
class SocketEndpoint:
def __init__(
self,
endpoint_path="/api/v1/socket",
self, endpoint_path, auth,
):
self.path = endpoint_path
self.auth = auth
self.operation = "socket"
async def listener(self, ws, running):
@ -43,18 +44,33 @@ class SocketEndpoint:
async def handle(self, request):
try:
token = request.query['token']
except:
token = ""
if not self.auth.permitted(token, self.operation):
return web.HTTPUnauthorized()
running = Running()
ws = web.WebSocketResponse()
await ws.prepare(request)
task = asyncio.create_task(self.async_thread(ws, running))
await self.listener(ws, running)
try:
await task
await self.listener(ws, running)
except Exception as e:
print(e, flush=True)
running.stop()
await ws.close()
await task
return ws
async def start(self):

View file

@ -6,7 +6,7 @@ from ... schema import text_completion_response_queue
from . endpoint import ServiceEndpoint
class TextCompletionEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):
super(TextCompletionEndpoint, self).__init__(
pulsar_host=pulsar_host,
@ -16,6 +16,7 @@ class TextCompletionEndpoint(ServiceEndpoint):
response_schema=TextCompletionResponse,
endpoint_path="/api/v1/text-completion",
timeout=timeout,
auth=auth,
)
def to_request(self, body):

View file

@ -14,10 +14,10 @@ from . serialize import to_subgraph
class TriplesLoadEndpoint(SocketEndpoint):
def __init__(self, pulsar_host, path="/api/v1/load/triples"):
def __init__(self, pulsar_host, auth, path="/api/v1/load/triples"):
super(TriplesLoadEndpoint, self).__init__(
endpoint_path=path
endpoint_path=path, auth=auth,
)
self.pulsar_host=pulsar_host

View file

@ -7,7 +7,7 @@ from . endpoint import ServiceEndpoint
from . serialize import to_value, serialize_subgraph
class TriplesQueryEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):
super(TriplesQueryEndpoint, self).__init__(
pulsar_host=pulsar_host,
@ -17,6 +17,7 @@ class TriplesQueryEndpoint(ServiceEndpoint):
response_schema=TriplesQueryResponse,
endpoint_path="/api/v1/triples-query",
timeout=timeout,
auth=auth,
)
def to_request(self, body):

View file

@ -12,10 +12,10 @@ from . serialize import serialize_triples
class TriplesStreamEndpoint(SocketEndpoint):
def __init__(self, pulsar_host, path="/api/v1/stream/triples"):
def __init__(self, pulsar_host, auth, path="/api/v1/stream/triples"):
super(TriplesStreamEndpoint, self).__init__(
endpoint_path=path
endpoint_path=path, auth=auth,
)
self.pulsar_host=pulsar_host