mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-09 07:12:37 +02:00
Feature/gateway auth (#186)
* Added auth module, just a simple token at this stage * Pass auth token GATEWAY_SECRET through * Auth token not mandatory, can be provided in env var
This commit is contained in:
parent
6d200c79c5
commit
1b9c6be4fc
18 changed files with 126 additions and 33 deletions
|
|
@ -6,7 +6,7 @@ from ... schema import agent_response_queue
|
|||
from . endpoint import MultiResponseServiceEndpoint
|
||||
|
||||
class AgentEndpoint(MultiResponseServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(AgentEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
|
|
@ -16,6 +16,7 @@ class AgentEndpoint(MultiResponseServiceEndpoint):
|
|||
response_schema=AgentResponse,
|
||||
endpoint_path="/api/v1/agent",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
|
|
|
|||
22
trustgraph-flow/trustgraph/api/gateway/auth.py
Normal file
22
trustgraph-flow/trustgraph/api/gateway/auth.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
class Authenticator:
|
||||
|
||||
def __init__(self, token=None, allow_all=False):
|
||||
|
||||
if not allow_all and token is None:
|
||||
raise RuntimeError("Need a token")
|
||||
|
||||
if not allow_all and token == "":
|
||||
raise RuntimeError("Need a token")
|
||||
|
||||
self.token = token
|
||||
self.allow_all = allow_all
|
||||
|
||||
def permitted(self, token, roles):
|
||||
|
||||
if self.allow_all: return True
|
||||
|
||||
if self.token != token: return False
|
||||
|
||||
return True
|
||||
|
||||
|
|
@ -6,7 +6,7 @@ from ... schema import dbpedia_lookup_response_queue
|
|||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class DbpediaEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(DbpediaEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
|
|
@ -16,6 +16,7 @@ class DbpediaEndpoint(ServiceEndpoint):
|
|||
response_schema=LookupResponse,
|
||||
endpoint_path="/api/v1/dbpedia",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from ... schema import embeddings_response_queue
|
|||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class EmbeddingsEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(EmbeddingsEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
|
|
@ -16,6 +16,7 @@ class EmbeddingsEndpoint(ServiceEndpoint):
|
|||
response_schema=EmbeddingsResponse,
|
||||
endpoint_path="/api/v1/embeddings",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from ... schema import encyclopedia_lookup_response_queue
|
|||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class EncyclopediaEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(EncyclopediaEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
|
|
@ -16,6 +16,7 @@ class EncyclopediaEndpoint(ServiceEndpoint):
|
|||
response_schema=LookupResponse,
|
||||
endpoint_path="/api/v1/encyclopedia",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ class ServiceEndpoint:
|
|||
request_queue, request_schema,
|
||||
response_queue, response_schema,
|
||||
endpoint_path,
|
||||
auth,
|
||||
subscription="api-gateway", consumer_name="api-gateway",
|
||||
timeout=600,
|
||||
):
|
||||
|
|
@ -36,6 +37,9 @@ class ServiceEndpoint:
|
|||
|
||||
self.path = endpoint_path
|
||||
self.timeout = timeout
|
||||
self.auth = auth
|
||||
|
||||
self.operation = "service"
|
||||
|
||||
async def start(self):
|
||||
|
||||
|
|
@ -58,14 +62,24 @@ class ServiceEndpoint:
|
|||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
ht = request.headers["Authorization"]
|
||||
tokens = ht.split(" ", 2)
|
||||
if tokens[0] != "Bearer":
|
||||
return web.HTTPUnauthorized()
|
||||
token = tokens[1]
|
||||
except:
|
||||
token = ""
|
||||
|
||||
if not self.auth.permitted(token, self.operation):
|
||||
return web.HTTPUnauthorized()
|
||||
|
||||
try:
|
||||
|
||||
data = await request.json()
|
||||
|
||||
q = await self.sub.subscribe(id)
|
||||
|
||||
print(data)
|
||||
|
||||
await self.pub.send(
|
||||
id,
|
||||
self.to_request(data),
|
||||
|
|
@ -76,8 +90,6 @@ class ServiceEndpoint:
|
|||
except:
|
||||
raise RuntimeError("Timeout waiting for response")
|
||||
|
||||
print(resp)
|
||||
|
||||
if resp.error:
|
||||
return web.json_response(
|
||||
{ "error": resp.error.message }
|
||||
|
|
@ -110,8 +122,6 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
|
|||
|
||||
q = await self.sub.subscribe(id)
|
||||
|
||||
print(data)
|
||||
|
||||
await self.pub.send(
|
||||
id,
|
||||
self.to_request(data),
|
||||
|
|
@ -126,8 +136,6 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
|
|||
except:
|
||||
raise RuntimeError("Timeout waiting for response")
|
||||
|
||||
print(resp)
|
||||
|
||||
if resp.error:
|
||||
return web.json_response(
|
||||
{ "error": resp.error.message }
|
||||
|
|
|
|||
|
|
@ -14,10 +14,12 @@ from . serialize import to_subgraph, to_value
|
|||
|
||||
class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(self, pulsar_host, path="/api/v1/load/graph-embeddings"):
|
||||
def __init__(
|
||||
self, pulsar_host, auth, path="/api/v1/load/graph-embeddings",
|
||||
):
|
||||
|
||||
super(GraphEmbeddingsLoadEndpoint, self).__init__(
|
||||
endpoint_path=path
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_host=pulsar_host
|
||||
|
|
|
|||
|
|
@ -12,10 +12,12 @@ from . serialize import serialize_graph_embeddings
|
|||
|
||||
class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(self, pulsar_host, path="/api/v1/stream/graph-embeddings"):
|
||||
def __init__(
|
||||
self, pulsar_host, auth, path="/api/v1/stream/graph-embeddings"
|
||||
):
|
||||
|
||||
super(GraphEmbeddingsStreamEndpoint, self).__init__(
|
||||
endpoint_path=path
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_host=pulsar_host
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from ... schema import graph_rag_response_queue
|
|||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class GraphRagEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(GraphRagEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
|
|
@ -16,6 +16,7 @@ class GraphRagEndpoint(ServiceEndpoint):
|
|||
response_schema=GraphRagResponse,
|
||||
endpoint_path="/api/v1/graph-rag",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from ... schema import internet_search_response_queue
|
|||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class InternetSearchEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(InternetSearchEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
|
|
@ -16,6 +16,7 @@ class InternetSearchEndpoint(ServiceEndpoint):
|
|||
response_schema=LookupResponse,
|
||||
endpoint_path="/api/v1/internet-search",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from ... schema import prompt_response_queue
|
|||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class PromptEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(PromptEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
|
|
@ -18,6 +18,7 @@ class PromptEndpoint(ServiceEndpoint):
|
|||
response_schema=PromptResponse,
|
||||
endpoint_path="/api/v1/prompt",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ from . triples_stream import TriplesStreamEndpoint
|
|||
from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
|
||||
from . triples_load import TriplesLoadEndpoint
|
||||
from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
|
||||
from . auth import Authenticator
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -52,6 +53,7 @@ logger.setLevel(logging.INFO)
|
|||
default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
|
||||
default_timeout = 600
|
||||
default_port = 8088
|
||||
default_api_token = os.getenv("GATEWAY_SECRET", "")
|
||||
|
||||
class Api:
|
||||
|
||||
|
|
@ -66,45 +68,66 @@ class Api:
|
|||
self.timeout = int(config.get("timeout", default_timeout))
|
||||
self.pulsar_host = config.get("pulsar_host", default_pulsar_host)
|
||||
|
||||
api_token = config.get("api_token", default_api_token)
|
||||
|
||||
# Token not set, or token equal empty string means no auth
|
||||
if api_token:
|
||||
self.auth = Authenticator(token=api_token)
|
||||
else:
|
||||
self.auth = Authenticator(allow_all=True)
|
||||
|
||||
self.endpoints = [
|
||||
TextCompletionEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
PromptEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
GraphRagEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
TriplesQueryEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
EmbeddingsEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
AgentEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
EncyclopediaEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
DbpediaEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
InternetSearchEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
TriplesStreamEndpoint(
|
||||
pulsar_host=self.pulsar_host
|
||||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
),
|
||||
GraphEmbeddingsStreamEndpoint(
|
||||
pulsar_host=self.pulsar_host
|
||||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
),
|
||||
TriplesLoadEndpoint(
|
||||
pulsar_host=self.pulsar_host
|
||||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
),
|
||||
GraphEmbeddingsLoadEndpoint(
|
||||
pulsar_host=self.pulsar_host
|
||||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -254,6 +277,12 @@ def run():
|
|||
help=f'API request timeout in seconds (default: {default_timeout})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--api-token',
|
||||
default=default_api_token,
|
||||
help=f'Secret API token (default: no auth)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-l', '--log-level',
|
||||
type=LogLevel,
|
||||
|
|
|
|||
|
|
@ -11,11 +11,12 @@ logger.setLevel(logging.INFO)
|
|||
class SocketEndpoint:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint_path="/api/v1/socket",
|
||||
self, endpoint_path, auth,
|
||||
):
|
||||
|
||||
self.path = endpoint_path
|
||||
self.auth = auth
|
||||
self.operation = "socket"
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
|
|
@ -43,18 +44,33 @@ class SocketEndpoint:
|
|||
|
||||
async def handle(self, request):
|
||||
|
||||
try:
|
||||
token = request.query['token']
|
||||
except:
|
||||
token = ""
|
||||
|
||||
if not self.auth.permitted(token, self.operation):
|
||||
return web.HTTPUnauthorized()
|
||||
|
||||
running = Running()
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
task = asyncio.create_task(self.async_thread(ws, running))
|
||||
|
||||
await self.listener(ws, running)
|
||||
try:
|
||||
|
||||
await task
|
||||
await self.listener(ws, running)
|
||||
|
||||
except Exception as e:
|
||||
print(e, flush=True)
|
||||
|
||||
running.stop()
|
||||
|
||||
await ws.close()
|
||||
|
||||
await task
|
||||
|
||||
return ws
|
||||
|
||||
async def start(self):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from ... schema import text_completion_response_queue
|
|||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class TextCompletionEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(TextCompletionEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
|
|
@ -16,6 +16,7 @@ class TextCompletionEndpoint(ServiceEndpoint):
|
|||
response_schema=TextCompletionResponse,
|
||||
endpoint_path="/api/v1/text-completion",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
|
|
|
|||
|
|
@ -14,10 +14,10 @@ from . serialize import to_subgraph
|
|||
|
||||
class TriplesLoadEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(self, pulsar_host, path="/api/v1/load/triples"):
|
||||
def __init__(self, pulsar_host, auth, path="/api/v1/load/triples"):
|
||||
|
||||
super(TriplesLoadEndpoint, self).__init__(
|
||||
endpoint_path=path
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_host=pulsar_host
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from . endpoint import ServiceEndpoint
|
|||
from . serialize import to_value, serialize_subgraph
|
||||
|
||||
class TriplesQueryEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(TriplesQueryEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
|
|
@ -17,6 +17,7 @@ class TriplesQueryEndpoint(ServiceEndpoint):
|
|||
response_schema=TriplesQueryResponse,
|
||||
endpoint_path="/api/v1/triples-query",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ from . serialize import serialize_triples
|
|||
|
||||
class TriplesStreamEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(self, pulsar_host, path="/api/v1/stream/triples"):
|
||||
def __init__(self, pulsar_host, auth, path="/api/v1/stream/triples"):
|
||||
|
||||
super(TriplesStreamEndpoint, self).__init__(
|
||||
endpoint_path=path
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_host=pulsar_host
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue