Reduc pulsar connections (#189)

This commit is contained in:
cybermaggedon 2024-12-03 14:13:40 +00:00 committed by GitHub
parent df23e29971
commit 7e78aa6d91
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 82 additions and 52 deletions

View file

@ -41,10 +41,15 @@ class ServiceEndpoint:
self.operation = "service" self.operation = "service"
async def start(self): async def start(self, client):
self.pub_task = asyncio.create_task(self.pub.run()) self.pub_task = asyncio.create_task(self.pub.run(client))
self.sub_task = asyncio.create_task(self.sub.run()) self.sub_task = asyncio.create_task(self.sub.run(client))
async def join(self):
await self.pub_task
await self.sub_task
def add_routes(self, app): def add_routes(self, app):

View file

@ -29,10 +29,10 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
schema=JsonSchema(GraphEmbeddings) schema=JsonSchema(GraphEmbeddings)
) )
async def start(self): async def start(self, client):
self.task = asyncio.create_task( self.task = asyncio.create_task(
self.publisher.run() self.publisher.run(client)
) )
async def listener(self, ws, running): async def listener(self, ws, running):

View file

@ -28,10 +28,10 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
schema=JsonSchema(GraphEmbeddings) schema=JsonSchema(GraphEmbeddings)
) )
async def start(self): async def start(self, client):
self.task = asyncio.create_task( self.task = asyncio.create_task(
self.subscriber.run() self.subscriber.run(client)
) )
async def async_thread(self, ws, running): async def async_thread(self, ws, running):

View file

@ -1,6 +1,5 @@
import asyncio import asyncio
import aiopulsar
class Publisher: class Publisher:
@ -12,12 +11,11 @@ class Publisher:
self.q = asyncio.Queue(maxsize=max_size) self.q = asyncio.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled self.chunking_enabled = chunking_enabled
async def run(self): async def run(self, client):
while True: while True:
try: try:
async with aiopulsar.connect(self.pulsar_host) as client:
async with client.create_producer( async with client.create_producer(
topic=self.topic, topic=self.topic,
schema=self.schema, schema=self.schema,

View file

@ -17,6 +17,7 @@ from aiohttp import web
import logging import logging
import os import os
import base64 import base64
import aiopulsar
import pulsar import pulsar
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
@ -237,13 +238,35 @@ class Api:
{ "error": str(e) } { "error": str(e) }
) )
async def app_factory(self): async def run_endpoints(self):
async with aiopulsar.connect(self.pulsar_host) as client:
for ep in self.endpoints: for ep in self.endpoints:
await ep.start() await ep.start(client)
self.doc_ingest_pub_task = asyncio.create_task(self.document_out.run()) self.doc_ingest_pub_task = asyncio.create_task(
self.text_ingest_pub_task = asyncio.create_task(self.text_out.run()) self.document_out.run(client)
)
self.text_ingest_pub_task = asyncio.create_task(
self.text_out.run(client)
)
print("Endpoints are running...")
# They never exit
for ep in self.endpoints:
await ep.join()
await self.doc_ingest_pub_task
await self.text_ingest_pub_task
print("Endpoints are stopped.")
async def app_factory(self):
self.endpoint_task = asyncio.create_task(self.run_endpoints())
return self.app return self.app

View file

@ -76,6 +76,12 @@ class SocketEndpoint:
async def start(self): async def start(self):
pass pass
async def join(self):
# Nothing to wait for
while True:
await asyncio.sleep(100)
def add_routes(self, app): def add_routes(self, app):
app.add_routes([ app.add_routes([

View file

@ -1,6 +1,5 @@
import asyncio import asyncio
import aiopulsar
class Subscriber: class Subscriber:
@ -14,10 +13,9 @@ class Subscriber:
self.q = {} self.q = {}
self.full = {} self.full = {}
async def run(self): async def run(self, client):
while True: while True:
try: try:
async with aiopulsar.connect(self.pulsar_host) as client:
async with client.subscribe( async with client.subscribe(
topic=self.topic, topic=self.topic,
subscription_name=self.subscription, subscription_name=self.subscription,

View file

@ -27,10 +27,10 @@ class TriplesLoadEndpoint(SocketEndpoint):
schema=JsonSchema(Triples) schema=JsonSchema(Triples)
) )
async def start(self): async def start(self, client):
self.task = asyncio.create_task( self.task = asyncio.create_task(
self.publisher.run() self.publisher.run(client)
) )
async def listener(self, ws, running): async def listener(self, ws, running):

View file

@ -26,10 +26,10 @@ class TriplesStreamEndpoint(SocketEndpoint):
schema=JsonSchema(Triples) schema=JsonSchema(Triples)
) )
async def start(self): async def start(self, client):
self.task = asyncio.create_task( self.task = asyncio.create_task(
self.subscriber.run() self.subscriber.run(client)
) )
async def async_thread(self, ws, running): async def async_thread(self, ws, running):