From d75d2923fec2a9291c1621fe9aec4c84838209f4 Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Tue, 23 Jul 2024 23:02:09 +0100 Subject: [PATCH] Lock in subscriber names --- trustgraph/embeddings/vectorize/vectorize.py | 5 ++++- trustgraph/embeddings_client.py | 10 +++++----- trustgraph/graph_rag.py | 11 +++++++++-- trustgraph/graph_rag_client.py | 10 +++++----- trustgraph/kg/extract_definitions/extract.py | 5 ++++- trustgraph/kg/extract_relationships/extract.py | 5 ++++- trustgraph/llm_client.py | 10 +++++----- 7 files changed, 36 insertions(+), 20 deletions(-) diff --git a/trustgraph/embeddings/vectorize/vectorize.py b/trustgraph/embeddings/vectorize/vectorize.py index eaf9b05e..963017b6 100755 --- a/trustgraph/embeddings/vectorize/vectorize.py +++ b/trustgraph/embeddings/vectorize/vectorize.py @@ -34,7 +34,10 @@ class Processor(ConsumerProducer): } ) - self.embeddings = EmbeddingsClient(pulsar_host=self.pulsar_host) + self.embeddings = EmbeddingsClient( + pulsar_host=self.pulsar_host, + subscriber=module + "emb", + ) def emit(self, source, chunk, vectors): diff --git a/trustgraph/embeddings_client.py b/trustgraph/embeddings_client.py index fd69d602..9b7fa81a 100644 --- a/trustgraph/embeddings_client.py +++ b/trustgraph/embeddings_client.py @@ -17,14 +17,14 @@ DEBUG=_pulsar.LoggerLevel.Debug class EmbeddingsClient: def __init__( - self, log_level=ERROR, client_id=None, + self, log_level=ERROR, subscriber=None, pulsar_host="pulsar://pulsar:6650", ): self.client = None - if client_id == None: - client_id = str(uuid.uuid4()) + if subscriber == None: + subscriber = str(uuid.uuid4()) self.client = pulsar.Client( pulsar_host, @@ -38,7 +38,7 @@ class EmbeddingsClient: ) self.consumer = self.client.subscribe( - embeddings_response_queue, client_id, + embeddings_response_queue, subscriber, schema=JsonSchema(EmbeddingsResponse), ) @@ -68,7 +68,7 @@ class EmbeddingsClient: def __del__(self): if hasattr(self, "consumer"): - self.consumer.unsubscribe() +# self.consumer.unsubscribe() self.consumer.close() if hasattr(self, "producer"): diff --git a/trustgraph/graph_rag.py b/trustgraph/graph_rag.py index 220231f6..0b7ccbdc 100644 --- a/trustgraph/graph_rag.py +++ b/trustgraph/graph_rag.py @@ -19,6 +19,7 @@ class GraphRag: entity_limit=50, triple_limit=30, max_subgraph_size=3000, + module="test", ): self.verbose=verbose @@ -31,7 +32,10 @@ class GraphRag: self.graph = TrustGraph(graph_hosts) - self.embeddings = EmbeddingsClient(pulsar_host=pulsar_host) + self.embeddings = EmbeddingsClient( + pulsar_host=pulsar_host, + subscriber=module + "-emb", + ) self.vecstore = TripleVectors(vector_store) @@ -41,7 +45,10 @@ class GraphRag: self.label_cache = {} - self.llm = LlmClient(pulsar_host=pulsar_host) + self.llm = LlmClient( + pulsar_host=pulsar_host, + subscriber=module + "-llm", + ) if self.verbose: print("Initialised", flush=True) diff --git a/trustgraph/graph_rag_client.py b/trustgraph/graph_rag_client.py index 4a6fbc2f..fba37c5b 100644 --- a/trustgraph/graph_rag_client.py +++ b/trustgraph/graph_rag_client.py @@ -18,12 +18,12 @@ DEBUG=_pulsar.LoggerLevel.Debug class GraphRagClient: def __init__( - self, log_level=ERROR, client_id=None, + self, log_level=ERROR, subscriber=None, pulsar_host="pulsar://pulsar:6650", ): - if client_id == None: - client_id = str(uuid.uuid4()) + if subscriber == None: + subscriber = str(uuid.uuid4()) self.client = pulsar.Client( pulsar_host, @@ -37,7 +37,7 @@ class GraphRagClient: ) self.consumer = self.client.subscribe( - graph_rag_response_queue, client_id, + graph_rag_response_queue, subscriber, schema=JsonSchema(GraphRagResponse), ) @@ -67,7 +67,7 @@ class GraphRagClient: def __del__(self): if hasattr(self, "consumer"): - self.consumer.unsubscribe() +# self.consumer.unsubscribe() self.consumer.close() if hasattr(self, "producer"): diff --git a/trustgraph/kg/extract_definitions/extract.py b/trustgraph/kg/extract_definitions/extract.py index 22b428b9..46f18b62 100755 --- a/trustgraph/kg/extract_definitions/extract.py +++ b/trustgraph/kg/extract_definitions/extract.py @@ -41,7 +41,10 @@ class Processor(ConsumerProducer): } ) - self.llm = LlmClient(pulsar_host=self.pulsar_host) + self.llm = LlmClient( + pulsar_host=self.pulsar_host, + subscriber = module + "-llm", + ) def to_uri(self, text): diff --git a/trustgraph/kg/extract_relationships/extract.py b/trustgraph/kg/extract_relationships/extract.py index 45f63e8a..55f72491 100755 --- a/trustgraph/kg/extract_relationships/extract.py +++ b/trustgraph/kg/extract_relationships/extract.py @@ -61,7 +61,10 @@ class Processor(ConsumerProducer): "vector_schema": GraphEmbeddings.__name__, }) - self.llm = LlmClient(pulsar_host=self.pulsar_host) + self.llm = LlmClient( + pulsar_host = self.pulsar_host, + subscriber = module + "-llm", + ) def to_uri(self, text): diff --git a/trustgraph/llm_client.py b/trustgraph/llm_client.py index c6252bf0..3ed9a29a 100644 --- a/trustgraph/llm_client.py +++ b/trustgraph/llm_client.py @@ -19,12 +19,12 @@ DEBUG=_pulsar.LoggerLevel.Debug class LlmClient: def __init__( - self, log_level=ERROR, client_id=None, + self, log_level=ERROR, subscriber=None, pulsar_host="pulsar://pulsar:6650", ): - if client_id == None: - client_id = str(uuid.uuid4()) + if subscriber == None: + subscriber = str(uuid.uuid4()) self.client = pulsar.Client( pulsar_host, @@ -38,7 +38,7 @@ class LlmClient: ) self.consumer = self.client.subscribe( - text_completion_response_queue, client_id, + text_completion_response_queue, subscriber, schema=JsonSchema(TextCompletionResponse), ) @@ -69,7 +69,7 @@ class LlmClient: def __del__(self): if hasattr(self, "consumer"): - self.consumer.unsubscribe() +# self.consumer.unsubscribe() self.consumer.close() if hasattr(self, "producer"):