diff --git a/trustgraph/kg/extract_definitions/extract.py b/trustgraph/kg/extract_definitions/extract.py index 334fb32d..d090b919 100755 --- a/trustgraph/kg/extract_definitions/extract.py +++ b/trustgraph/kg/extract_definitions/extract.py @@ -9,6 +9,8 @@ import json from ... schema import ChunkEmbeddings, Triple, Source, Value from ... schema import chunk_embeddings_ingest_queue, triples_store_queue +from ... schema import text_completion_request_queue +from ... schema import text_completion_response_queue from ... log_level import LogLevel from ... llm_client import LlmClient from ... prompts import to_definitions @@ -30,6 +32,12 @@ class Processor(ConsumerProducer): input_queue = params.get("input_queue", default_input_queue) output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) + tc_request_queue = params.get( + "text_completion_request_queue", text_completion_request_queue + ) + tc_response_queue = params.get( + "text_completion_response_queue", text_completion_response_queue + ) super(Processor, self).__init__( **params | { @@ -38,11 +46,15 @@ class Processor(ConsumerProducer): "subscriber": subscriber, "input_schema": ChunkEmbeddings, "output_schema": Triple, + "text_completion_request_queue": tc_request_queue, + "text_completion_response_queue": tc_response_queue, } ) self.llm = LlmClient( pulsar_host=self.pulsar_host, + input_queue=tc_request_queue, + output_queue=tc_response_queue, subscriber = module + "-llm", ) @@ -108,6 +120,18 @@ class Processor(ConsumerProducer): default_output_queue, ) + parser.add_argument( + '--text-completion-request-queue', + default=text_completion_request_queue, + help=f'Text completion request queue (default: {text_completion_request_queue})', + ) + + parser.add_argument( + '--text-completion-response-queue', + default=text_completion_response_queue, + help=f'Text completion response queue (default: {text_completion_response_queue})', + ) + def run(): Processor.start(module, __doc__) diff --git a/trustgraph/llm_client.py b/trustgraph/llm_client.py index 3ed9a29a..062fb323 100644 --- a/trustgraph/llm_client.py +++ b/trustgraph/llm_client.py @@ -19,10 +19,19 @@ DEBUG=_pulsar.LoggerLevel.Debug class LlmClient: def __init__( - self, log_level=ERROR, subscriber=None, + self, log_level=ERROR, + subscriber=None, + input_queue=None, + output_queue=None, pulsar_host="pulsar://pulsar:6650", ): + if input_queue == None: + input_queue = text_completion_request_queue + + if output_queue == None: + output_queue = text_completion_response_queue + if subscriber == None: subscriber = str(uuid.uuid4()) @@ -32,13 +41,13 @@ class LlmClient: ) self.producer = self.client.create_producer( - topic=text_completion_request_queue, + topic=input_queue, schema=JsonSchema(TextCompletionRequest), chunking_enabled=True, ) self.consumer = self.client.subscribe( - text_completion_response_queue, subscriber, + output_queue, subscriber, schema=JsonSchema(TextCompletionResponse), )