simulation_runner: added failed job cleanup

This commit is contained in:
arkml 2025-02-20 18:51:49 +05:30
parent ee20f7c6e3
commit 7bc3203ed2
4 changed files with 113 additions and 29 deletions

View file

@ -1,6 +1,7 @@
from pymongo import MongoClient from pymongo import MongoClient
from bson import ObjectId from bson import ObjectId
import os import os
from datetime import datetime, timedelta, timezone
from scenario_types import SimulationRun, Scenario, SimulationResult, SimulationAggregateResult from scenario_types import SimulationRun, Scenario, SimulationResult, SimulationAggregateResult
MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip() MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip()
@ -72,3 +73,33 @@ def get_scenarios_for_run(simulation_run: SimulationRun):
def write_simulation_result(result: SimulationResult): def write_simulation_result(result: SimulationResult):
collection = get_collection(SIMULATION_RESULT_COLLECTION_NAME) collection = get_collection(SIMULATION_RESULT_COLLECTION_NAME)
collection.insert_one(result.model_dump()) collection.insert_one(result.model_dump())
def update_simulation_run_heartbeat(simulation_run_id: str) -> None:
    """
    Update the 'lastHeartbeat' timestamp of a SimulationRun to the current UTC time.

    Called periodically by the job runner's heartbeat task so that the
    stale-job sweeper can tell live jobs apart from dead ones.

    Args:
        simulation_run_id: Hex string id of the SimulationRun document.
    """
    collection = get_collection(SIMULATIONS_COLLECTION_NAME)
    collection.update_one(
        {"_id": ObjectId(simulation_run_id)},
        # Timezone-aware UTC timestamp; compared against a UTC threshold in
        # mark_stale_jobs_as_failed.
        {"$set": {"lastHeartbeat": datetime.now(timezone.utc)}}
    )
def mark_stale_jobs_as_failed(stale_minutes: int = 20) -> int:
    """
    Mark 'running' jobs whose heartbeat has gone stale as 'failed'.

    A job is considered stale when its 'lastHeartbeat' is older than
    `stale_minutes` (default 20) relative to the current UTC time.
    (The original docstring said 5 minutes but the code always used 20;
    the threshold is now an explicit, overridable parameter.)

    Args:
        stale_minutes: Age threshold in minutes before a running job is
            considered dead. Defaults to 20, preserving prior behavior.

    Returns:
        The number of jobs transitioned to 'failed'.
    """
    collection = get_collection(SIMULATIONS_COLLECTION_NAME)
    stale_threshold = datetime.now(timezone.utc) - timedelta(minutes=stale_minutes)
    # NOTE(review): documents with no 'lastHeartbeat' field (or None) will not
    # match the $lt filter and thus never get reaped — confirm that running
    # jobs always write at least one heartbeat before this matters.
    result = collection.update_many(
        {
            "status": "running",
            "lastHeartbeat": {"$lt": stale_threshold}
        },
        {"$set": {"status": "failed"}}
    )
    return result.modified_count  # Number of jobs marked failed

View file

@ -21,6 +21,7 @@ class SimulationRun(BaseModel):
scenarioIds: List[str] scenarioIds: List[str]
workflowId: str workflowId: str
startedAt: datetime startedAt: datetime
lastHeartbeat: Optional[datetime] = None
completedAt: Optional[datetime] = None completedAt: Optional[datetime] = None
aggregateResults: Optional[dict] = None aggregateResults: Optional[dict] = None

View file

@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
from typing import List from typing import List
from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key, mark_stale_jobs_as_failed, update_simulation_run_heartbeat
from scenario_types import SimulationRun, Scenario from scenario_types import SimulationRun, Scenario
from simulation import simulate_scenarios from simulation import simulate_scenarios
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -15,14 +15,16 @@ class JobService:
""" """
Periodically checks for new jobs in MongoDB and processes them. Periodically checks for new jobs in MongoDB and processes them.
""" """
# Start the stale-job check in the background
asyncio.create_task(self.fail_stale_jobs_loop())
iterations = 0 iterations = 0
while True: while True:
job = get_pending_simulation_run() job = get_pending_simulation_run()
if job: if job:
logging.info(f"Found new job: {job}. Processing...") logging.info(f"Found new job: {job}. Processing...")
asyncio.create_task(self.process_job(job)) asyncio.create_task(self.process_job(job))
else:
logging.info("No new jobs found. Checking again in 5 seconds...")
iterations += 1 iterations += 1
if max_iterations is not None and iterations >= max_iterations: if max_iterations is not None and iterations >= max_iterations:
@ -34,18 +36,52 @@ class JobService:
""" """
Calls the simulation function and updates job status upon completion. Calls the simulation function and updates job status upon completion.
""" """
async with self.semaphore: async with self.semaphore:
scenarios = get_scenarios_for_run(job) # Start heartbeat in background
if not scenarios or len(scenarios) == 0: stop_heartbeat_event = asyncio.Event()
logging.info(f"No scenarios found for job {job.id}") heartbeat_task = asyncio.create_task(self.heartbeat_loop(job.id, stop_heartbeat_event))
return
api_key = get_api_key(job.projectId) try:
result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key) scenarios = get_scenarios_for_run(job)
if not scenarios or len(scenarios) == 0:
logging.info(f"No scenarios found for job {job.id}")
return
api_key = get_api_key(job.projectId)
result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key)
set_simulation_run_to_completed(job, result) set_simulation_run_to_completed(job, result)
logging.info(f"Job {job.id} completed.") logging.info(f"Job {job.id} completed.")
except Exception as exc:
logging.error(f"Job {job.id} failed: {exc}")
finally:
stop_heartbeat_event.set()
await heartbeat_task
async def fail_stale_jobs_loop(self):
    """
    Background task: periodically sweep for stale jobs and mark them 'failed'.

    Delegates the staleness test to mark_stale_jobs_as_failed() — which
    treats a 'running' job as dead when its lastHeartbeat is older than its
    threshold (20 minutes in the current implementation, despite older docs
    saying 5) — and logs a warning whenever any jobs were reaped.
    Runs forever; cancel the task to stop it.
    """
    while True:
        count = mark_stale_jobs_as_failed()
        if count > 0:
            logging.warning(f"Marked {count} stale jobs as failed.")
        await asyncio.sleep(60)  # Check every 60 seconds
async def heartbeat_loop(self, job_id: str, stop_event: asyncio.Event):
    """
    Background task: refresh the job's lastHeartbeat until stop_event is set.

    Writes a heartbeat immediately, then roughly every 10 seconds. Waiting on
    stop_event with a timeout (instead of a plain sleep, as before) lets the
    loop exit as soon as the caller signals completion — process_job awaits
    this task in its finally block, so a bare sleep delayed job completion by
    up to a full heartbeat interval.

    Args:
        job_id: Id of the SimulationRun being processed.
        stop_event: Set by the caller when the job finishes (success or failure).
    """
    try:
        while not stop_event.is_set():
            update_simulation_run_heartbeat(job_id)
            try:
                # Wake early if the job finishes; otherwise heartbeat again in 10s.
                await asyncio.wait_for(stop_event.wait(), timeout=10)
            except asyncio.TimeoutError:
                pass  # Interval elapsed; loop and heartbeat again.
    except asyncio.CancelledError:
        # Task cancelled externally; exit quietly, matching prior behavior.
        pass
def start(self): def start(self):
""" """

View file

@ -2,20 +2,22 @@ from rowboat import Client, StatefulChat
from typing import List from typing import List
import json import json
import os import os
import asyncio
from openai import OpenAI from openai import OpenAI
from scenario_types import Scenario, SimulationResult, SimulationAggregateResult from scenario_types import Scenario, SimulationResult, SimulationAggregateResult
from db import write_simulation_result from db import write_simulation_result
openai_client = OpenAI() openai_client = OpenAI()
MODEL_NAME = "gpt-4o" MODEL_NAME = "gpt-4o"
ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip() ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> str: async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> tuple[str, str, str]:
""" """
Runs a mock simulation for a given scenario. Runs a mock simulation for a given scenario asynchronously.
After simulating several turns of conversation, it evaluates the conversation. After simulating several turns of conversation, it evaluates the conversation.
Returns a tuple of (evaluation_result, details, transcript_str).
""" """
loop = asyncio.get_running_loop()
support_chat = StatefulChat( support_chat = StatefulChat(
rowboat_client, rowboat_client,
@ -36,18 +38,24 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
for i in range(max_iterations): for i in range(max_iterations):
openai_input = messages openai_input = messages
simulated_user_response = openai_client.chat.completions.create( # Run OpenAI API call in a separate thread
model=MODEL_NAME, simulated_user_response = await loop.run_in_executor(
messages=openai_input, None, # Use default thread pool
temperature=0.0, lambda: openai_client.chat.completions.create(
model=MODEL_NAME,
messages=openai_input,
temperature=0.0,
)
) )
simulated_content = simulated_user_response.choices[0].message.content simulated_content = simulated_user_response.choices[0].message.content
# Feed the model-generated content back into Rowboat's stateful chat # Run support_chat.run in a thread if it's synchronous
rowboat_response = support_chat.run(simulated_content) rowboat_response = await loop.run_in_executor(
None,
lambda: support_chat.run(simulated_content)
)
# Store the user message back into `messages` so the conversation continues
messages.append({"role": "assistant", "content": rowboat_response}) messages.append({"role": "assistant", "content": rowboat_response})
# ------------------------- # -------------------------
@ -76,11 +84,15 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
} }
] ]
eval_response = openai_client.chat.completions.create( # Run evaluation in a separate thread
model=MODEL_NAME, eval_response = await loop.run_in_executor(
messages=evaluation_prompt, None,
temperature=0.0, lambda: openai_client.chat.completions.create(
response_format={"type": "json_object"} model=MODEL_NAME,
messages=evaluation_prompt,
temperature=0.0,
response_format={"type": "json_object"}
)
) )
if not eval_response.choices: if not eval_response.choices:
@ -92,10 +104,12 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
if evaluation_result is None: if evaluation_result is None:
raise Exception("No verdict field found in evaluation response") raise Exception("No verdict field found in evaluation response")
return(evaluation_result, details, transcript_str) return (evaluation_result, details, transcript_str)
async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5): async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5):
"""
Simulates a list of scenarios asynchronously and aggregates the results.
"""
project_id = scenarios[0].projectId project_id = scenarios[0].projectId
client = Client( client = Client(
host=ROWBOAT_API_HOST, host=ROWBOAT_API_HOST,
@ -103,8 +117,10 @@ async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id:
api_key=api_key api_key=api_key
) )
results = [] results = []
for scenario in scenarios: for scenario in scenarios:
result, details, transcript = simulate_scenario(scenario, client, workflow_id, max_iterations) # Await the asynchronous simulate_scenario
result, details, transcript = await simulate_scenario(scenario, client, workflow_id, max_iterations)
simulation_result = SimulationResult( simulation_result = SimulationResult(
projectId=project_id, projectId=project_id,