diff --git a/apps/simulation_runner/db.py b/apps/simulation_runner/db.py index c0e26267..2dfd8f5e 100644 --- a/apps/simulation_runner/db.py +++ b/apps/simulation_runner/db.py @@ -1,6 +1,7 @@ from pymongo import MongoClient from bson import ObjectId import os +from datetime import datetime, timedelta, timezone from scenario_types import SimulationRun, Scenario, SimulationResult, SimulationAggregateResult MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip() @@ -72,3 +73,33 @@ def get_scenarios_for_run(simulation_run: SimulationRun): def write_simulation_result(result: SimulationResult): collection = get_collection(SIMULATION_RESULT_COLLECTION_NAME) collection.insert_one(result.model_dump()) + +def update_simulation_run_heartbeat(simulation_run_id: str): + """ + Updates the 'lastHeartbeat' timestamp for a SimulationRun. + """ + + collection = get_collection(SIMULATIONS_COLLECTION_NAME) + collection.update_one( + {"_id": ObjectId(simulation_run_id)}, + {"$set": {"lastHeartbeat": datetime.now(timezone.utc)}} + ) + +def mark_stale_jobs_as_failed(): + """ + Finds any job in 'running' status whose 'lastHeartbeat' is older than 20 minutes, + and sets it to 'failed'. 
+ """ + + collection = get_collection(SIMULATIONS_COLLECTION_NAME) + stale_threshold = datetime.now(timezone.utc) - timedelta(minutes=20) + result = collection.update_many( + { + "status": "running", + "lastHeartbeat": {"$lt": stale_threshold} + }, + { + "$set": {"status": "failed"} + } + ) + return result.modified_count # Number of jobs marked failed \ No newline at end of file diff --git a/apps/simulation_runner/scenario_types.py b/apps/simulation_runner/scenario_types.py index fdafb2d2..fa8dc14f 100644 --- a/apps/simulation_runner/scenario_types.py +++ b/apps/simulation_runner/scenario_types.py @@ -21,6 +21,7 @@ class SimulationRun(BaseModel): scenarioIds: List[str] workflowId: str startedAt: datetime + lastHeartbeat: Optional[datetime] = None completedAt: Optional[datetime] = None aggregateResults: Optional[dict] = None diff --git a/apps/simulation_runner/service.py b/apps/simulation_runner/service.py index 206a5eaa..7ec8bbe5 100644 --- a/apps/simulation_runner/service.py +++ b/apps/simulation_runner/service.py @@ -1,7 +1,7 @@ import asyncio import logging from typing import List -from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key +from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key, mark_stale_jobs_as_failed, update_simulation_run_heartbeat from scenario_types import SimulationRun, Scenario from simulation import simulate_scenarios logging.basicConfig(level=logging.INFO) @@ -15,14 +15,16 @@ class JobService: """ Periodically checks for new jobs in MongoDB and processes them. """ + + # Start the stale-job check in the background + asyncio.create_task(self.fail_stale_jobs_loop()) + iterations = 0 while True: job = get_pending_simulation_run() if job: logging.info(f"Found new job: {job}. Processing...") asyncio.create_task(self.process_job(job)) - else: - logging.info("No new jobs found. 
Checking again in 5 seconds...") iterations += 1 if max_iterations is not None and iterations >= max_iterations: @@ -34,18 +36,52 @@ class JobService: """ Calls the simulation function and updates job status upon completion. """ + async with self.semaphore: - scenarios = get_scenarios_for_run(job) - if not scenarios or len(scenarios) == 0: - logging.info(f"No scenarios found for job {job.id}") - return + # Start heartbeat in background + stop_heartbeat_event = asyncio.Event() + heartbeat_task = asyncio.create_task(self.heartbeat_loop(job.id, stop_heartbeat_event)) - api_key = get_api_key(job.projectId) - result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key) + try: + scenarios = get_scenarios_for_run(job) + if not scenarios or len(scenarios) == 0: + logging.info(f"No scenarios found for job {job.id}") + return + + api_key = get_api_key(job.projectId) + result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key) - set_simulation_run_to_completed(job, result) - logging.info(f"Job {job.id} completed.") + set_simulation_run_to_completed(job, result) + logging.info(f"Job {job.id} completed.") + except Exception as exc: + logging.error(f"Job {job.id} failed: {exc}") + finally: + stop_heartbeat_event.set() + await heartbeat_task + + async def fail_stale_jobs_loop(self): + """ + Periodically checks for stale jobs that haven't received a heartbeat in over 20 minutes, + and marks them as 'failed'. + """ + while True: + count = mark_stale_jobs_as_failed() + if count > 0: + logging.warning(f"Marked {count} stale jobs as failed.") + await asyncio.sleep(60) # Check every 60 seconds + + async def heartbeat_loop(self, job_id: str, stop_event: asyncio.Event): + """ + Periodically updates 'lastHeartbeat' for the given job until 'stop_event' is set. 
+ """ + + try: + while not stop_event.is_set(): + update_simulation_run_heartbeat(job_id) + await asyncio.sleep(10) # Heartbeat interval in seconds + except asyncio.CancelledError: + pass def start(self): """ diff --git a/apps/simulation_runner/simulation.py b/apps/simulation_runner/simulation.py index fad9d809..35897aaf 100644 --- a/apps/simulation_runner/simulation.py +++ b/apps/simulation_runner/simulation.py @@ -2,20 +2,22 @@ from rowboat import Client, StatefulChat from typing import List import json import os +import asyncio from openai import OpenAI from scenario_types import Scenario, SimulationResult, SimulationAggregateResult from db import write_simulation_result - openai_client = OpenAI() MODEL_NAME = "gpt-4o" ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip() -def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> str: +async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> tuple[str, str, str]: """ - Runs a mock simulation for a given scenario. + Runs a mock simulation for a given scenario asynchronously. After simulating several turns of conversation, it evaluates the conversation. + Returns a tuple of (evaluation_result, details, transcript_str). 
+ """ + loop = asyncio.get_running_loop() support_chat = StatefulChat( rowboat_client, @@ -36,18 +38,24 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s for i in range(max_iterations): openai_input = messages - simulated_user_response = openai_client.chat.completions.create( - model=MODEL_NAME, - messages=openai_input, - temperature=0.0, + # Run OpenAI API call in a separate thread + simulated_user_response = await loop.run_in_executor( + None, # Use default thread pool + lambda: openai_client.chat.completions.create( + model=MODEL_NAME, + messages=openai_input, + temperature=0.0, + ) ) simulated_content = simulated_user_response.choices[0].message.content - # Feed the model-generated content back into Rowboat's stateful chat - rowboat_response = support_chat.run(simulated_content) + # Run support_chat.run in a thread if it's synchronous + rowboat_response = await loop.run_in_executor( + None, + lambda: support_chat.run(simulated_content) + ) - # Store the user message back into `messages` so the conversation continues messages.append({"role": "assistant", "content": rowboat_response}) # ------------------------- @@ -76,11 +84,15 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s } ] - eval_response = openai_client.chat.completions.create( - model=MODEL_NAME, - messages=evaluation_prompt, - temperature=0.0, - response_format={"type": "json_object"} + # Run evaluation in a separate thread + eval_response = await loop.run_in_executor( + None, + lambda: openai_client.chat.completions.create( + model=MODEL_NAME, + messages=evaluation_prompt, + temperature=0.0, + response_format={"type": "json_object"} + ) ) if not eval_response.choices: @@ -92,10 +104,12 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s if evaluation_result is None: raise Exception("No verdict field found in evaluation response") - return(evaluation_result, details, transcript_str) - + return (evaluation_result, details, transcript_str) async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5): + """ + Simulates a list of scenarios asynchronously and aggregates the results. + """ project_id = scenarios[0].projectId client = Client( host=ROWBOAT_API_HOST, @@ -103,8 +117,10 @@ async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: api_key=api_key ) results = [] + for scenario in scenarios: - result, details, transcript = simulate_scenario(scenario, client, workflow_id, max_iterations) + # Await the asynchronous simulate_scenario + result, details, transcript = await simulate_scenario(scenario, client, workflow_id, max_iterations) simulation_result = SimulationResult( projectId=project_id,