updated simulation runner to the new collections

This commit is contained in:
arkml 2025-03-01 12:56:08 +05:30
parent 33f30670f6
commit 2b1ef82c20
7 changed files with 311 additions and 159 deletions

View file

@ -1,84 +1,104 @@
import asyncio
import logging
from typing import List
from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key, mark_stale_jobs_as_failed, update_simulation_run_heartbeat
from scenario_types import SimulationRun, Scenario
from simulation import simulate_scenarios
from typing import List, Optional
# Updated imports from your new db module and scenario_types
from db import (
get_pending_run,
get_simulations_for_run,
set_run_to_completed,
get_api_key,
mark_stale_jobs_as_failed,
update_run_heartbeat
)
from scenario_types import TestRun, TestSimulation
# If you have a new simulation function, import it here.
# Otherwise, adapt the name as needed:
from simulation import simulate_simulations # or simulate_scenarios, if unchanged
logging.basicConfig(level=logging.INFO)
class JobService:
def __init__(self):
self.poll_interval = 5 # seconds
# Control concurrency of run processing
self.semaphore = asyncio.Semaphore(5)
async def poll_and_process_jobs(self, max_iterations: int = None):
async def poll_and_process_jobs(self, max_iterations: Optional[int] = None):
"""
Periodically checks for new jobs in MongoDB and processes them.
Periodically checks for new runs in MongoDB and processes them.
"""
# Start the stale-job check in the background
asyncio.create_task(self.fail_stale_jobs_loop())
# Start the stale-run check in the background
asyncio.create_task(self.fail_stale_runs_loop())
iterations = 0
while True:
job = get_pending_simulation_run()
if job:
logging.info(f"Found new job: {job}. Processing...")
asyncio.create_task(self.process_job(job))
run = get_pending_run() # <--- changed to match new DB function
if run:
logging.info(f"Found new run: {run}. Processing...")
asyncio.create_task(self.process_run(run))
iterations += 1
if max_iterations is not None and iterations >= max_iterations:
break
# Sleep for the polling interval
await asyncio.sleep(self.poll_interval)
async def process_job(self, job: SimulationRun):
async def process_run(self, run: TestRun):
"""
Calls the simulation function and updates job status upon completion.
Calls the simulation function and updates run status upon completion.
"""
async with self.semaphore:
# Start heartbeat in background
stop_heartbeat_event = asyncio.Event()
heartbeat_task = asyncio.create_task(self.heartbeat_loop(job.id, stop_heartbeat_event))
heartbeat_task = asyncio.create_task(self.heartbeat_loop(run.id, stop_heartbeat_event))
try:
scenarios = get_scenarios_for_run(job)
if not scenarios or len(scenarios) == 0:
logging.info(f"No scenarios found for job {job.id}")
# Fetch the simulations associated with this run
simulations = get_simulations_for_run(run)
if not simulations:
logging.info(f"No simulations found for run {run.id}")
return
api_key = get_api_key(job.projectId)
result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key)
# Fetch API key if needed
api_key = get_api_key(run.projectId)
# Perform your simulation logic
# adapt this call to your actual simulation functions signature
aggregate_result = await simulate_simulations(
simulations=simulations,
run_id=run.id,
workflow_id=run.workflowId,
api_key=api_key
)
set_simulation_run_to_completed(job, result)
logging.info(f"Job {job.id} completed.")
# Mark run as completed with the aggregated result
set_run_to_completed(run, aggregate_result)
logging.info(f"Run {run.id} completed.")
except Exception as exc:
logging.error(f"Job {job.id} failed: {exc}")
logging.error(f"Run {run.id} failed: {exc}")
finally:
stop_heartbeat_event.set()
await heartbeat_task
async def fail_stale_jobs_loop(self):
async def fail_stale_runs_loop(self):
"""
Periodically checks for stale jobs that haven't received a heartbeat in over 5 minutes,
and marks them as 'failed'.
Periodically checks for stale runs (no heartbeat) and marks them as 'failed'.
"""
while True:
count = mark_stale_jobs_as_failed()
if count > 0:
logging.warning(f"Marked {count} stale jobs as failed.")
logging.warning(f"Marked {count} stale runs as failed.")
await asyncio.sleep(60) # Check every 60 seconds
async def heartbeat_loop(self, job_id: str, stop_event: asyncio.Event):
async def heartbeat_loop(self, run_id: str, stop_event: asyncio.Event):
"""
Periodically updates 'last_heartbeat' for the given job until 'stop_event' is set.
Periodically updates 'lastHeartbeat' for the given run until 'stop_event' is set.
"""
try:
while not stop_event.is_set():
update_simulation_run_heartbeat(job_id)
update_run_heartbeat(run_id)
await asyncio.sleep(10) # Heartbeat interval in seconds
except asyncio.CancelledError:
pass