simulation_runner: added failed job cleanup

This commit is contained in:
arkml 2025-02-20 18:51:49 +05:30
parent ee20f7c6e3
commit 7bc3203ed2
4 changed files with 113 additions and 29 deletions

View file

@ -1,6 +1,7 @@
from pymongo import MongoClient
from bson import ObjectId
import os
from datetime import datetime, timedelta, timezone
from scenario_types import SimulationRun, Scenario, SimulationResult, SimulationAggregateResult
MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip()
@ -72,3 +73,33 @@ def get_scenarios_for_run(simulation_run: SimulationRun):
def write_simulation_result(result: SimulationResult):
    """Persist a single simulation result document to MongoDB."""
    get_collection(SIMULATION_RESULT_COLLECTION_NAME).insert_one(result.model_dump())
def update_simulation_run_heartbeat(simulation_run_id: str):
    """Refresh the 'lastHeartbeat' field of a SimulationRun to the current UTC time."""
    now = datetime.now(timezone.utc)
    runs = get_collection(SIMULATIONS_COLLECTION_NAME)
    runs.update_one({"_id": ObjectId(simulation_run_id)}, {"$set": {"lastHeartbeat": now}})
def mark_stale_jobs_as_failed(stale_after_minutes: int = 20):
    """
    Mark stale 'running' jobs as 'failed'.

    A job is considered stale when its 'lastHeartbeat' is older than
    ``stale_after_minutes`` (default: 20 minutes; the original docstring
    claimed 5 minutes while the code used 20 — the default preserves the
    actual behavior). Jobs with no 'lastHeartbeat' field are left
    untouched, since $lt does not match missing fields.

    Args:
        stale_after_minutes: Age of the last heartbeat, in minutes, beyond
            which a running job is treated as dead.

    Returns:
        int: Number of jobs that were transitioned to 'failed'.
    """
    collection = get_collection(SIMULATIONS_COLLECTION_NAME)
    stale_threshold = datetime.now(timezone.utc) - timedelta(minutes=stale_after_minutes)
    result = collection.update_many(
        {
            "status": "running",
            "lastHeartbeat": {"$lt": stale_threshold}
        },
        {
            "$set": {"status": "failed"}
        }
    )
    return result.modified_count  # Number of jobs marked failed

View file

@ -21,6 +21,7 @@ class SimulationRun(BaseModel):
scenarioIds: List[str]
workflowId: str
startedAt: datetime
lastHeartbeat: Optional[datetime] = None
completedAt: Optional[datetime] = None
aggregateResults: Optional[dict] = None

View file

@ -1,7 +1,7 @@
import asyncio
import logging
from typing import List
from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key
from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key, mark_stale_jobs_as_failed, update_simulation_run_heartbeat
from scenario_types import SimulationRun, Scenario
from simulation import simulate_scenarios
logging.basicConfig(level=logging.INFO)
@ -15,14 +15,16 @@ class JobService:
"""
Periodically checks for new jobs in MongoDB and processes them.
"""
# Start the stale-job check in the background
asyncio.create_task(self.fail_stale_jobs_loop())
iterations = 0
while True:
job = get_pending_simulation_run()
if job:
logging.info(f"Found new job: {job}. Processing...")
asyncio.create_task(self.process_job(job))
else:
logging.info("No new jobs found. Checking again in 5 seconds...")
iterations += 1
if max_iterations is not None and iterations >= max_iterations:
@ -34,18 +36,52 @@ class JobService:
"""
Calls the simulation function and updates job status upon completion.
"""
async with self.semaphore:
scenarios = get_scenarios_for_run(job)
if not scenarios or len(scenarios) == 0:
logging.info(f"No scenarios found for job {job.id}")
return
# Start heartbeat in background
stop_heartbeat_event = asyncio.Event()
heartbeat_task = asyncio.create_task(self.heartbeat_loop(job.id, stop_heartbeat_event))
api_key = get_api_key(job.projectId)
result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key)
try:
scenarios = get_scenarios_for_run(job)
if not scenarios or len(scenarios) == 0:
logging.info(f"No scenarios found for job {job.id}")
return
api_key = get_api_key(job.projectId)
result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key)
set_simulation_run_to_completed(job, result)
logging.info(f"Job {job.id} completed.")
set_simulation_run_to_completed(job, result)
logging.info(f"Job {job.id} completed.")
except Exception as exc:
logging.error(f"Job {job.id} failed: {exc}")
finally:
stop_heartbeat_event.set()
await heartbeat_task
async def fail_stale_jobs_loop(self):
    """
    Background loop that periodically sweeps for stale jobs.

    Every 60 seconds, asks the DB layer to flip 'running' jobs whose
    heartbeat has gone silent (the threshold is defined by
    mark_stale_jobs_as_failed) over to 'failed', and logs how many were
    affected. DB errors are logged and swallowed: this coroutine is
    launched via asyncio.create_task, so an unhandled exception would
    kill the sweeper permanently and silently.
    """
    while True:
        try:
            count = mark_stale_jobs_as_failed()
            if count > 0:
                logging.warning(f"Marked {count} stale jobs as failed.")
        except Exception:
            # A crashed sweep would never be restarted; log and retry later.
            logging.exception("Stale-job sweep failed; retrying next cycle.")
        await asyncio.sleep(60)  # Check every 60 seconds
async def heartbeat_loop(self, job_id: str, stop_event: asyncio.Event):
    """
    Periodically update 'lastHeartbeat' for *job_id* until *stop_event* is set.

    Waits on the event with a timeout instead of a plain sleep, so the
    loop exits as soon as the event fires rather than up to a full
    heartbeat interval (10 s) later — the caller awaits this task after
    each job, and a plain sleep would stall job completion by that much.

    Args:
        job_id: Id of the SimulationRun whose heartbeat is refreshed.
        stop_event: Signals the loop to stop beating and return.
    """
    try:
        while not stop_event.is_set():
            update_simulation_run_heartbeat(job_id)
            try:
                # Wake immediately on stop; otherwise beat again after 10 s.
                await asyncio.wait_for(stop_event.wait(), timeout=10)
            except asyncio.TimeoutError:
                pass
    except asyncio.CancelledError:
        pass
def start(self):
"""

View file

@ -2,20 +2,22 @@ from rowboat import Client, StatefulChat
from typing import List
import json
import os
import asyncio
from openai import OpenAI
from scenario_types import Scenario, SimulationResult, SimulationAggregateResult
from db import write_simulation_result
openai_client = OpenAI()
MODEL_NAME = "gpt-4o"
ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> str:
async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> tuple[str, str, str]:
"""
Runs a mock simulation for a given scenario.
Runs a mock simulation for a given scenario asynchronously.
After simulating several turns of conversation, it evaluates the conversation.
Returns a tuple of (evaluation_result, details, transcript_str).
"""
loop = asyncio.get_running_loop()
support_chat = StatefulChat(
rowboat_client,
@ -36,18 +38,24 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
for i in range(max_iterations):
openai_input = messages
simulated_user_response = openai_client.chat.completions.create(
model=MODEL_NAME,
messages=openai_input,
temperature=0.0,
# Run OpenAI API call in a separate thread
simulated_user_response = await loop.run_in_executor(
None, # Use default thread pool
lambda: openai_client.chat.completions.create(
model=MODEL_NAME,
messages=openai_input,
temperature=0.0,
)
)
simulated_content = simulated_user_response.choices[0].message.content
# Feed the model-generated content back into Rowboat's stateful chat
rowboat_response = support_chat.run(simulated_content)
# Run support_chat.run in a thread if it's synchronous
rowboat_response = await loop.run_in_executor(
None,
lambda: support_chat.run(simulated_content)
)
# Store the user message back into `messages` so the conversation continues
messages.append({"role": "assistant", "content": rowboat_response})
# -------------------------
@ -76,11 +84,15 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
}
]
eval_response = openai_client.chat.completions.create(
model=MODEL_NAME,
messages=evaluation_prompt,
temperature=0.0,
response_format={"type": "json_object"}
# Run evaluation in a separate thread
eval_response = await loop.run_in_executor(
None,
lambda: openai_client.chat.completions.create(
model=MODEL_NAME,
messages=evaluation_prompt,
temperature=0.0,
response_format={"type": "json_object"}
)
)
if not eval_response.choices:
@ -92,10 +104,12 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
if evaluation_result is None:
raise Exception("No verdict field found in evaluation response")
return(evaluation_result, details, transcript_str)
return (evaluation_result, details, transcript_str)
async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5):
"""
Simulates a list of scenarios asynchronously and aggregates the results.
"""
project_id = scenarios[0].projectId
client = Client(
host=ROWBOAT_API_HOST,
@ -103,8 +117,10 @@ async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id:
api_key=api_key
)
results = []
for scenario in scenarios:
result, details, transcript = simulate_scenario(scenario, client, workflow_id, max_iterations)
# Await the asynchronous simulate_scenario
result, details, transcript = await simulate_scenario(scenario, client, workflow_id, max_iterations)
simulation_result = SimulationResult(
projectId=project_id,