Feature/pulsar api key support (#308)

* Add pulsar API token check

* Added missing api_key references

---------

Co-authored-by: Tyler O <4535788+toliver38@users.noreply.github.com>
This commit is contained in:
cybermaggedon 2025-02-15 11:22:48 +00:00 committed by GitHub
parent f7df2df266
commit 617eb7efd5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
40 changed files with 173 additions and 21 deletions

View file

@ -156,14 +156,16 @@ class Processor(ConsumerProducer):
subscriber=subscriber,
input_queue=prompt_request_queue,
output_queue=prompt_response_queue,
pulsar_host = self.pulsar_host
pulsar_host = self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
)
self.graph_rag = GraphRagClient(
subscriber=subscriber,
input_queue=graph_rag_request_queue,
output_queue=graph_rag_response_queue,
pulsar_host = self.pulsar_host
pulsar_host = self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
)
# Need to be able to feed requests to myself

View file

@ -59,6 +59,7 @@ class DocumentRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
pulsar_api_key=None,
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
@ -100,6 +101,7 @@ class DocumentRag:
subscriber=module + "-de",
input_queue=de_request_queue,
output_queue=de_response_queue,
pulsar_api_key=pulsar_api_key,
)
self.embeddings = EmbeddingsClient(
@ -107,6 +109,7 @@ class DocumentRag:
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
pulsar_api_key=pulsar_api_key,
)
self.lang = PromptClient(
@ -114,6 +117,7 @@ class DocumentRag:
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber=module + "-de-prompt",
pulsar_api_key=pulsar_api_key,
)
if self.verbose:

View file

@ -47,6 +47,7 @@ class Processor(ConsumerProducer):
self.embeddings = EmbeddingsClient(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",

View file

@ -79,6 +79,7 @@ class Processor(ConsumerProducer):
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",

View file

@ -54,6 +54,7 @@ class Processor(ConsumerProducer):
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",

View file

@ -52,6 +52,7 @@ class Processor(ConsumerProducer):
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",

View file

@ -112,6 +112,7 @@ class Processor(ConsumerProducer):
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",

View file

@ -7,10 +7,18 @@ from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class AgentRequestor(ServiceRequestor):
<<<<<<< HEAD
def __init__(self, pulsar_client, timeout, auth):
super(AgentRequestor, self).__init__(
pulsar_client=pulsar_client,
=======
def __init__(self, pulsar_host, timeout, auth, pulsar_api_key=None):
super(AgentRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
>>>>>>> a5d5b4c (Add pulsar API token check)
request_queue=agent_request_queue,
response_queue=agent_response_queue,
request_schema=AgentRequest,

View file

@ -27,7 +27,8 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber(
self.pulsar_client, document_embeddings_store_queue,
"api-gateway", "api-gateway",
schema=JsonSchema(DocumentEmbeddings)
schema=JsonSchema(DocumentEmbeddings),
pulsar_api_key=self.pulsar_api_key
)
async def listener(self, ws, running):

View file

@ -26,6 +26,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber(
self.pulsar_client, graph_embeddings_store_queue,
"api-gateway", "api-gateway",
pulsar_api_key=self.pulsar_api_key,
schema=JsonSchema(GraphEmbeddings)
)

View file

@ -21,6 +21,7 @@ class MuxEndpoint(SocketEndpoint):
self, pulsar_client, auth,
services,
path="/api/v1/socket",
pulsar_api_key=None
):
super(MuxEndpoint, self).__init__(

View file

@ -19,6 +19,7 @@ class ServiceRequestor:
response_queue, response_schema,
subscription="api-gateway", consumer_name="api-gateway",
timeout=600,
pulsar_api_key=None,
):
self.pub = Publisher(
@ -29,6 +30,7 @@ class ServiceRequestor:
self.sub = Subscriber(
pulsar_client, response_queue,
subscription, consumer_name,
pulsar_api_key,
JsonSchema(response_schema)
)

View file

@ -17,6 +17,7 @@ class ServiceSender:
self,
pulsar_client,
request_queue, request_schema,
pulsar_api_key=None,
):
self.pub = Publisher(

View file

@ -57,6 +57,7 @@ logger.setLevel(logging.INFO)
default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
default_prometheus_url = os.getenv("PROMETHEUS_URL", "http://prometheus:9090")
default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
default_timeout = 600
default_port = 8088
default_api_token = os.getenv("GATEWAY_SECRET", "")
@ -73,11 +74,20 @@ class Api:
self.port = int(config.get("port", default_port))
self.timeout = int(config.get("timeout", default_timeout))
self.pulsar_host = config.get("pulsar_host", default_pulsar_host)
self.pulsar_api_key = config.get(
"pulsar_api_key", default_pulsar_api_key
)
self.pulsar_listener = config.get("pulsar_listener", None)
self.pulsar_client = pulsar.Client(
self.pulsar_host, listener_name=self.pulsar_listener
)
if self.pulsar_api_key:
self.pulsar_client = pulsar.Client(
self.pulsar_host, listener_name=self.pulsar_listener,
authentication=pulsar.AuthenticationToken(self.pulsar_api_key)
)
else:
self.pulsar_client = pulsar.Client(
self.pulsar_host, listener_name=self.pulsar_listener,
)
self.prometheus_url = config.get(
"prometheus_url", default_prometheus_url,
@ -224,6 +234,7 @@ class Api:
TriplesLoadEndpoint(
pulsar_client=self.pulsar_client,
auth = self.auth,
pulsar_api_key=self.pulsar_api_key,
),
GraphEmbeddingsLoadEndpoint(
pulsar_client=self.pulsar_client,
@ -237,6 +248,7 @@ class Api:
pulsar_client=self.pulsar_client,
auth = self.auth,
services = self.services,
pulsar_api_key=self.pulsar_api_key,
),
MetricsEndpoint(
endpoint_path = "/api/v1/metrics",
@ -270,6 +282,12 @@ def run():
default=default_pulsar_host,
help=f'Pulsar host (default: {default_pulsar_host})',
)
parser.add_argument(
'--pulsar-api-key',
default=default_pulsar_api_key,
help=f'Pulsar API key',
)
parser.add_argument(
'--pulsar-listener',

View file

@ -24,6 +24,7 @@ class TriplesStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber(
self.pulsar_client, triples_store_queue,
"api-gateway", "api-gateway",
pulsar_api_key=self.pulsar_api_key,
schema=JsonSchema(Triples)
)

View file

@ -161,6 +161,7 @@ class GraphRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
pulsar_api_key=None,
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
@ -207,6 +208,7 @@ class GraphRag:
self.ge_client = GraphEmbeddingsClient(
pulsar_host=pulsar_host,
pulsar_api_key=-pulsar_api_key,
subscriber=module + "-ge",
input_queue=ge_request_queue,
output_queue=ge_response_queue,
@ -214,6 +216,7 @@ class GraphRag:
self.triples_client = TriplesQueryClient(
pulsar_host=pulsar_host,
pulsar_api_key=-pulsar_api_key,
subscriber=module + "-tpl",
input_queue=tpl_request_queue,
output_queue=tpl_response_queue
@ -221,6 +224,7 @@ class GraphRag:
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
pulsar_api_key=-pulsar_api_key,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
@ -234,6 +238,7 @@ class GraphRag:
self.prompt = PromptClient(
pulsar_host=pulsar_host,
pulsar_api_key=-pulsar_api_key,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber=module + "-prompt",

View file

@ -63,7 +63,8 @@ class Processor(ConsumerProducer):
subscriber=subscriber,
input_queue=tc_request_queue,
output_queue=tc_response_queue,
pulsar_host = self.pulsar_host
pulsar_host = self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
)
def parse_json(self, text):

View file

@ -136,7 +136,8 @@ class Processor(ConsumerProducer):
subscriber=subscriber,
input_queue=tc_request_queue,
output_queue=tc_response_queue,
pulsar_host = self.pulsar_host
pulsar_host = self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
)
# System prompt hack

View file

@ -49,11 +49,12 @@ class Processing:
pulsar_host,
log_level,
file,
pulsar_api_key=None,
):
self.pulsar_host = pulsar_host
self.log_level = log_level
self.file = file
self.pulsar_api_key = pulsar_api_key
self.defs = load(open(file, "r"), Loader=Loader)
def run(self):
@ -68,6 +69,7 @@ class Processing:
params = {
"pulsar_host": self.pulsar_host,
"pulsar_api_key": self.pulsar_api_key,
"log_level": str(self.log_level),
}
@ -125,12 +127,19 @@ def run():
)
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
parser.add_argument(
'-p', '--pulsar-host',
default=default_pulsar_host,
help=f'Pulsar host (default: {default_pulsar_host})',
)
parser.add_argument(
'--pulsar-api-key',
default=default_pulsar_api_key,
help=f'Pulsar API key',
)
parser.add_argument(
'-l', '--log-level',
@ -153,6 +162,7 @@ def run():
try:
p = Processing(
pulsar_host=args.pulsar_host,
pulsar_api_key=args.pulsar_api_key,
file=args.file,
log_level=args.log_level,
)

View file

@ -68,6 +68,7 @@ class Processor(ConsumerProducer):
self.rag = DocumentRag(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
pr_request_queue=pr_request_queue,
pr_response_queue=pr_response_queue,
emb_request_queue=emb_request_queue,

View file

@ -82,6 +82,7 @@ class Processor(ConsumerProducer):
self.rag = GraphRag(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
pr_request_queue=pr_request_queue,
pr_response_queue=pr_response_queue,
emb_request_queue=emb_request_queue,