mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-25 16:36:22 +02:00
simulation_runner: added failed job cleanup
This commit is contained in:
parent
ee20f7c6e3
commit
7bc3203ed2
4 changed files with 113 additions and 29 deletions
|
|
@ -1,6 +1,7 @@
|
|||
from pymongo import MongoClient
|
||||
from bson import ObjectId
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from scenario_types import SimulationRun, Scenario, SimulationResult, SimulationAggregateResult
|
||||
|
||||
MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip()
|
||||
|
|
@ -72,3 +73,33 @@ def get_scenarios_for_run(simulation_run: SimulationRun):
|
|||
def write_simulation_result(result: SimulationResult):
|
||||
collection = get_collection(SIMULATION_RESULT_COLLECTION_NAME)
|
||||
collection.insert_one(result.model_dump())
|
||||
|
||||
def update_simulation_run_heartbeat(simulation_run_id: str):
|
||||
"""
|
||||
Updates the 'last_heartbeat' timestamp for a SimulationRun.
|
||||
"""
|
||||
|
||||
collection = get_collection(SIMULATIONS_COLLECTION_NAME)
|
||||
collection.update_one(
|
||||
{"_id": ObjectId(simulation_run_id)},
|
||||
{"$set": {"lastHeartbeat": datetime.now(timezone.utc)}}
|
||||
)
|
||||
|
||||
def mark_stale_jobs_as_failed():
|
||||
"""
|
||||
Finds any job in 'running' status whose last_heartbeat is older than 5 minutes,
|
||||
and sets it to 'failed'.
|
||||
"""
|
||||
|
||||
collection = get_collection(SIMULATIONS_COLLECTION_NAME)
|
||||
stale_threshold = datetime.now(timezone.utc) - timedelta(minutes=20)
|
||||
result = collection.update_many(
|
||||
{
|
||||
"status": "running",
|
||||
"lastHeartbeat": {"$lt": stale_threshold}
|
||||
},
|
||||
{
|
||||
"$set": {"status": "failed"}
|
||||
}
|
||||
)
|
||||
return result.modified_count # Number of jobs marked failed
|
||||
|
|
@ -21,6 +21,7 @@ class SimulationRun(BaseModel):
|
|||
scenarioIds: List[str]
|
||||
workflowId: str
|
||||
startedAt: datetime
|
||||
lastHeartbeat: Optional[datetime] = None
|
||||
completedAt: Optional[datetime] = None
|
||||
aggregateResults: Optional[dict] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key
|
||||
from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key, mark_stale_jobs_as_failed, update_simulation_run_heartbeat
|
||||
from scenario_types import SimulationRun, Scenario
|
||||
from simulation import simulate_scenarios
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
|
@ -15,14 +15,16 @@ class JobService:
|
|||
"""
|
||||
Periodically checks for new jobs in MongoDB and processes them.
|
||||
"""
|
||||
|
||||
# Start the stale-job check in the background
|
||||
asyncio.create_task(self.fail_stale_jobs_loop())
|
||||
|
||||
iterations = 0
|
||||
while True:
|
||||
job = get_pending_simulation_run()
|
||||
if job:
|
||||
logging.info(f"Found new job: {job}. Processing...")
|
||||
asyncio.create_task(self.process_job(job))
|
||||
else:
|
||||
logging.info("No new jobs found. Checking again in 5 seconds...")
|
||||
|
||||
iterations += 1
|
||||
if max_iterations is not None and iterations >= max_iterations:
|
||||
|
|
@ -34,18 +36,52 @@ class JobService:
|
|||
"""
|
||||
Calls the simulation function and updates job status upon completion.
|
||||
"""
|
||||
|
||||
async with self.semaphore:
|
||||
scenarios = get_scenarios_for_run(job)
|
||||
if not scenarios or len(scenarios) == 0:
|
||||
logging.info(f"No scenarios found for job {job.id}")
|
||||
return
|
||||
# Start heartbeat in background
|
||||
stop_heartbeat_event = asyncio.Event()
|
||||
heartbeat_task = asyncio.create_task(self.heartbeat_loop(job.id, stop_heartbeat_event))
|
||||
|
||||
api_key = get_api_key(job.projectId)
|
||||
result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key)
|
||||
try:
|
||||
scenarios = get_scenarios_for_run(job)
|
||||
if not scenarios or len(scenarios) == 0:
|
||||
logging.info(f"No scenarios found for job {job.id}")
|
||||
return
|
||||
|
||||
api_key = get_api_key(job.projectId)
|
||||
result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key)
|
||||
|
||||
|
||||
set_simulation_run_to_completed(job, result)
|
||||
logging.info(f"Job {job.id} completed.")
|
||||
set_simulation_run_to_completed(job, result)
|
||||
logging.info(f"Job {job.id} completed.")
|
||||
except Exception as exc:
|
||||
logging.error(f"Job {job.id} failed: {exc}")
|
||||
finally:
|
||||
stop_heartbeat_event.set()
|
||||
await heartbeat_task
|
||||
|
||||
async def fail_stale_jobs_loop(self):
|
||||
"""
|
||||
Periodically checks for stale jobs that haven't received a heartbeat in over 5 minutes,
|
||||
and marks them as 'failed'.
|
||||
"""
|
||||
while True:
|
||||
count = mark_stale_jobs_as_failed()
|
||||
if count > 0:
|
||||
logging.warning(f"Marked {count} stale jobs as failed.")
|
||||
await asyncio.sleep(60) # Check every 60 seconds
|
||||
|
||||
async def heartbeat_loop(self, job_id: str, stop_event: asyncio.Event):
|
||||
"""
|
||||
Periodically updates 'last_heartbeat' for the given job until 'stop_event' is set.
|
||||
"""
|
||||
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
update_simulation_run_heartbeat(job_id)
|
||||
await asyncio.sleep(10) # Heartbeat interval in seconds
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -2,20 +2,22 @@ from rowboat import Client, StatefulChat
|
|||
from typing import List
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from openai import OpenAI
|
||||
from scenario_types import Scenario, SimulationResult, SimulationAggregateResult
|
||||
from db import write_simulation_result
|
||||
|
||||
|
||||
openai_client = OpenAI()
|
||||
MODEL_NAME = "gpt-4o"
|
||||
ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
|
||||
|
||||
def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> str:
|
||||
async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> tuple[str, str, str]:
|
||||
"""
|
||||
Runs a mock simulation for a given scenario.
|
||||
Runs a mock simulation for a given scenario asynchronously.
|
||||
After simulating several turns of conversation, it evaluates the conversation.
|
||||
Returns a tuple of (evaluation_result, details, transcript_str).
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
support_chat = StatefulChat(
|
||||
rowboat_client,
|
||||
|
|
@ -36,18 +38,24 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
|
|||
for i in range(max_iterations):
|
||||
openai_input = messages
|
||||
|
||||
simulated_user_response = openai_client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=openai_input,
|
||||
temperature=0.0,
|
||||
# Run OpenAI API call in a separate thread
|
||||
simulated_user_response = await loop.run_in_executor(
|
||||
None, # Use default thread pool
|
||||
lambda: openai_client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=openai_input,
|
||||
temperature=0.0,
|
||||
)
|
||||
)
|
||||
|
||||
simulated_content = simulated_user_response.choices[0].message.content
|
||||
|
||||
# Feed the model-generated content back into Rowboat's stateful chat
|
||||
rowboat_response = support_chat.run(simulated_content)
|
||||
# Run support_chat.run in a thread if it's synchronous
|
||||
rowboat_response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: support_chat.run(simulated_content)
|
||||
)
|
||||
|
||||
# Store the user message back into `messages` so the conversation continues
|
||||
messages.append({"role": "assistant", "content": rowboat_response})
|
||||
|
||||
# -------------------------
|
||||
|
|
@ -76,11 +84,15 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
|
|||
}
|
||||
]
|
||||
|
||||
eval_response = openai_client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=evaluation_prompt,
|
||||
temperature=0.0,
|
||||
response_format={"type": "json_object"}
|
||||
# Run evaluation in a separate thread
|
||||
eval_response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: openai_client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=evaluation_prompt,
|
||||
temperature=0.0,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
)
|
||||
|
||||
if not eval_response.choices:
|
||||
|
|
@ -92,10 +104,12 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
|
|||
if evaluation_result is None:
|
||||
raise Exception("No verdict field found in evaluation response")
|
||||
|
||||
return(evaluation_result, details, transcript_str)
|
||||
|
||||
return (evaluation_result, details, transcript_str)
|
||||
|
||||
async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5):
|
||||
"""
|
||||
Simulates a list of scenarios asynchronously and aggregates the results.
|
||||
"""
|
||||
project_id = scenarios[0].projectId
|
||||
client = Client(
|
||||
host=ROWBOAT_API_HOST,
|
||||
|
|
@ -103,8 +117,10 @@ async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id:
|
|||
api_key=api_key
|
||||
)
|
||||
results = []
|
||||
|
||||
for scenario in scenarios:
|
||||
result, details, transcript = simulate_scenario(scenario, client, workflow_id, max_iterations)
|
||||
# Await the asynchronous simulate_scenario
|
||||
result, details, transcript = await simulate_scenario(scenario, client, workflow_id, max_iterations)
|
||||
|
||||
simulation_result = SimulationResult(
|
||||
projectId=project_id,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue