Lock in subscriber names

This commit is contained in:
Cyber MacGeddon 2024-07-23 23:02:09 +01:00
parent de3525371c
commit d75d2923fe
7 changed files with 36 additions and 20 deletions

View file

@ -34,7 +34,10 @@ class Processor(ConsumerProducer):
} }
) )
self.embeddings = EmbeddingsClient(pulsar_host=self.pulsar_host) self.embeddings = EmbeddingsClient(
pulsar_host=self.pulsar_host,
subscriber=module + "emb",
)
def emit(self, source, chunk, vectors): def emit(self, source, chunk, vectors):

View file

@ -17,14 +17,14 @@ DEBUG=_pulsar.LoggerLevel.Debug
class EmbeddingsClient: class EmbeddingsClient:
def __init__( def __init__(
self, log_level=ERROR, client_id=None, self, log_level=ERROR, subscriber=None,
pulsar_host="pulsar://pulsar:6650", pulsar_host="pulsar://pulsar:6650",
): ):
self.client = None self.client = None
if client_id == None: if subscriber == None:
client_id = str(uuid.uuid4()) subscriber = str(uuid.uuid4())
self.client = pulsar.Client( self.client = pulsar.Client(
pulsar_host, pulsar_host,
@ -38,7 +38,7 @@ class EmbeddingsClient:
) )
self.consumer = self.client.subscribe( self.consumer = self.client.subscribe(
embeddings_response_queue, client_id, embeddings_response_queue, subscriber,
schema=JsonSchema(EmbeddingsResponse), schema=JsonSchema(EmbeddingsResponse),
) )
@ -68,7 +68,7 @@ class EmbeddingsClient:
def __del__(self): def __del__(self):
if hasattr(self, "consumer"): if hasattr(self, "consumer"):
self.consumer.unsubscribe() # self.consumer.unsubscribe()
self.consumer.close() self.consumer.close()
if hasattr(self, "producer"): if hasattr(self, "producer"):

View file

@ -19,6 +19,7 @@ class GraphRag:
entity_limit=50, entity_limit=50,
triple_limit=30, triple_limit=30,
max_subgraph_size=3000, max_subgraph_size=3000,
module="test",
): ):
self.verbose=verbose self.verbose=verbose
@ -31,7 +32,10 @@ class GraphRag:
self.graph = TrustGraph(graph_hosts) self.graph = TrustGraph(graph_hosts)
self.embeddings = EmbeddingsClient(pulsar_host=pulsar_host) self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
subscriber=module + "-emb",
)
self.vecstore = TripleVectors(vector_store) self.vecstore = TripleVectors(vector_store)
@ -41,7 +45,10 @@ class GraphRag:
self.label_cache = {} self.label_cache = {}
self.llm = LlmClient(pulsar_host=pulsar_host) self.llm = LlmClient(
pulsar_host=pulsar_host,
subscriber=module + "-llm",
)
if self.verbose: if self.verbose:
print("Initialised", flush=True) print("Initialised", flush=True)

View file

@ -18,12 +18,12 @@ DEBUG=_pulsar.LoggerLevel.Debug
class GraphRagClient: class GraphRagClient:
def __init__( def __init__(
self, log_level=ERROR, client_id=None, self, log_level=ERROR, subscriber=None,
pulsar_host="pulsar://pulsar:6650", pulsar_host="pulsar://pulsar:6650",
): ):
if client_id == None: if subscriber == None:
client_id = str(uuid.uuid4()) subscriber = str(uuid.uuid4())
self.client = pulsar.Client( self.client = pulsar.Client(
pulsar_host, pulsar_host,
@ -37,7 +37,7 @@ class GraphRagClient:
) )
self.consumer = self.client.subscribe( self.consumer = self.client.subscribe(
graph_rag_response_queue, client_id, graph_rag_response_queue, subscriber,
schema=JsonSchema(GraphRagResponse), schema=JsonSchema(GraphRagResponse),
) )
@ -67,7 +67,7 @@ class GraphRagClient:
def __del__(self): def __del__(self):
if hasattr(self, "consumer"): if hasattr(self, "consumer"):
self.consumer.unsubscribe() # self.consumer.unsubscribe()
self.consumer.close() self.consumer.close()
if hasattr(self, "producer"): if hasattr(self, "producer"):

View file

@ -41,7 +41,10 @@ class Processor(ConsumerProducer):
} }
) )
self.llm = LlmClient(pulsar_host=self.pulsar_host) self.llm = LlmClient(
pulsar_host=self.pulsar_host,
subscriber = module + "-llm",
)
def to_uri(self, text): def to_uri(self, text):

View file

@ -61,7 +61,10 @@ class Processor(ConsumerProducer):
"vector_schema": GraphEmbeddings.__name__, "vector_schema": GraphEmbeddings.__name__,
}) })
self.llm = LlmClient(pulsar_host=self.pulsar_host) self.llm = LlmClient(
pulsar_host = self.pulsar_host,
subscriber = module + "-llm",
)
def to_uri(self, text): def to_uri(self, text):

View file

@ -19,12 +19,12 @@ DEBUG=_pulsar.LoggerLevel.Debug
class LlmClient: class LlmClient:
def __init__( def __init__(
self, log_level=ERROR, client_id=None, self, log_level=ERROR, subscriber=None,
pulsar_host="pulsar://pulsar:6650", pulsar_host="pulsar://pulsar:6650",
): ):
if client_id == None: if subscriber == None:
client_id = str(uuid.uuid4()) subscriber = str(uuid.uuid4())
self.client = pulsar.Client( self.client = pulsar.Client(
pulsar_host, pulsar_host,
@ -38,7 +38,7 @@ class LlmClient:
) )
self.consumer = self.client.subscribe( self.consumer = self.client.subscribe(
text_completion_response_queue, client_id, text_completion_response_queue, subscriber,
schema=JsonSchema(TextCompletionResponse), schema=JsonSchema(TextCompletionResponse),
) )
@ -69,7 +69,7 @@ class LlmClient:
def __del__(self): def __del__(self):
if hasattr(self, "consumer"): if hasattr(self, "consumer"):
self.consumer.unsubscribe() # self.consumer.unsubscribe()
self.consumer.close() self.consumer.close()
if hasattr(self, "producer"): if hasattr(self, "producer"):