mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-05-02 03:42:38 +02:00
simulation_runner: added failed job cleanup
This commit is contained in:
parent
ee20f7c6e3
commit
7bc3203ed2
4 changed files with 113 additions and 29 deletions
|
|
@ -1,6 +1,7 @@
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
from scenario_types import SimulationRun, Scenario, SimulationResult, SimulationAggregateResult
|
from scenario_types import SimulationRun, Scenario, SimulationResult, SimulationAggregateResult
|
||||||
|
|
||||||
MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip()
|
MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip()
|
||||||
|
|
@ -72,3 +73,33 @@ def get_scenarios_for_run(simulation_run: SimulationRun):
|
||||||
def write_simulation_result(result: SimulationResult):
    """Persist a single SimulationResult document to MongoDB."""
    results_collection = get_collection(SIMULATION_RESULT_COLLECTION_NAME)
    document = result.model_dump()
    results_collection.insert_one(document)
|
||||||
|
|
||||||
|
def update_simulation_run_heartbeat(simulation_run_id: str):
    """
    Refresh the 'lastHeartbeat' field of a SimulationRun to the current UTC time.

    Looks the run up by its Mongo ObjectId and overwrites only the
    'lastHeartbeat' field; all other fields are untouched.
    """
    now_utc = datetime.now(timezone.utc)
    runs = get_collection(SIMULATIONS_COLLECTION_NAME)
    query = {"_id": ObjectId(simulation_run_id)}
    update = {"$set": {"lastHeartbeat": now_utc}}
    runs.update_one(query, update)
|
||||||
|
|
||||||
|
def mark_stale_jobs_as_failed(stale_minutes: int = 20) -> int:
    """
    Mark 'running' jobs whose heartbeat has gone stale as 'failed'.

    A job is considered stale when its 'lastHeartbeat' is older than
    ``stale_minutes`` minutes (default 20). The original docstring claimed
    a 5-minute threshold while the code used 20; the code's 20-minute
    default is kept and now documented (and parameterized) correctly.

    NOTE: documents with no 'lastHeartbeat' field are not matched by the
    ``$lt`` comparison and are therefore left untouched.

    Args:
        stale_minutes: Age of the heartbeat, in minutes, beyond which a
            running job is considered dead.

    Returns:
        The number of jobs transitioned from 'running' to 'failed'.
    """
    collection = get_collection(SIMULATIONS_COLLECTION_NAME)
    stale_threshold = datetime.now(timezone.utc) - timedelta(minutes=stale_minutes)
    result = collection.update_many(
        {
            "status": "running",
            "lastHeartbeat": {"$lt": stale_threshold},
        },
        {"$set": {"status": "failed"}},
    )
    return result.modified_count  # Number of jobs marked failed
|
||||||
|
|
@ -21,6 +21,7 @@ class SimulationRun(BaseModel):
|
||||||
scenarioIds: List[str]
|
scenarioIds: List[str]
|
||||||
workflowId: str
|
workflowId: str
|
||||||
startedAt: datetime
|
startedAt: datetime
|
||||||
|
lastHeartbeat: Optional[datetime] = None
|
||||||
completedAt: Optional[datetime] = None
|
completedAt: Optional[datetime] = None
|
||||||
aggregateResults: Optional[dict] = None
|
aggregateResults: Optional[dict] = None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key
|
from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key, mark_stale_jobs_as_failed, update_simulation_run_heartbeat
|
||||||
from scenario_types import SimulationRun, Scenario
|
from scenario_types import SimulationRun, Scenario
|
||||||
from simulation import simulate_scenarios
|
from simulation import simulate_scenarios
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
@ -15,14 +15,16 @@ class JobService:
|
||||||
"""
|
"""
|
||||||
Periodically checks for new jobs in MongoDB and processes them.
|
Periodically checks for new jobs in MongoDB and processes them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Start the stale-job check in the background
|
||||||
|
asyncio.create_task(self.fail_stale_jobs_loop())
|
||||||
|
|
||||||
iterations = 0
|
iterations = 0
|
||||||
while True:
|
while True:
|
||||||
job = get_pending_simulation_run()
|
job = get_pending_simulation_run()
|
||||||
if job:
|
if job:
|
||||||
logging.info(f"Found new job: {job}. Processing...")
|
logging.info(f"Found new job: {job}. Processing...")
|
||||||
asyncio.create_task(self.process_job(job))
|
asyncio.create_task(self.process_job(job))
|
||||||
else:
|
|
||||||
logging.info("No new jobs found. Checking again in 5 seconds...")
|
|
||||||
|
|
||||||
iterations += 1
|
iterations += 1
|
||||||
if max_iterations is not None and iterations >= max_iterations:
|
if max_iterations is not None and iterations >= max_iterations:
|
||||||
|
|
@ -34,7 +36,13 @@ class JobService:
|
||||||
"""
|
"""
|
||||||
Calls the simulation function and updates job status upon completion.
|
Calls the simulation function and updates job status upon completion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async with self.semaphore:
|
async with self.semaphore:
|
||||||
|
# Start heartbeat in background
|
||||||
|
stop_heartbeat_event = asyncio.Event()
|
||||||
|
heartbeat_task = asyncio.create_task(self.heartbeat_loop(job.id, stop_heartbeat_event))
|
||||||
|
|
||||||
|
try:
|
||||||
scenarios = get_scenarios_for_run(job)
|
scenarios = get_scenarios_for_run(job)
|
||||||
if not scenarios or len(scenarios) == 0:
|
if not scenarios or len(scenarios) == 0:
|
||||||
logging.info(f"No scenarios found for job {job.id}")
|
logging.info(f"No scenarios found for job {job.id}")
|
||||||
|
|
@ -46,6 +54,34 @@ class JobService:
|
||||||
|
|
||||||
set_simulation_run_to_completed(job, result)
|
set_simulation_run_to_completed(job, result)
|
||||||
logging.info(f"Job {job.id} completed.")
|
logging.info(f"Job {job.id} completed.")
|
||||||
|
except Exception as exc:
|
||||||
|
logging.error(f"Job {job.id} failed: {exc}")
|
||||||
|
finally:
|
||||||
|
stop_heartbeat_event.set()
|
||||||
|
await heartbeat_task
|
||||||
|
|
||||||
|
async def fail_stale_jobs_loop(self):
    """
    Background watchdog: periodically mark stale jobs as 'failed'.

    Runs forever, sweeping once per minute. The staleness cutoff is owned
    by mark_stale_jobs_as_failed() — the original docstring here said
    "over 5 minutes", but the DB helper actually uses a 20-minute
    threshold, so this loop defers to the helper rather than restating it.

    A failed sweep is logged and retried on the next cycle instead of
    silently killing this create_task()-spawned loop (an uncaught
    exception would previously have ended the watchdog permanently).
    """
    while True:
        try:
            # NOTE(review): mark_stale_jobs_as_failed is a synchronous
            # (blocking) pymongo call running on the event loop; fine for
            # a cheap update_many, but consider run_in_executor if slow.
            count = mark_stale_jobs_as_failed()
            if count > 0:
                logging.warning(f"Marked {count} stale jobs as failed.")
        except Exception:
            # Keep the watchdog alive even if one sweep fails.
            logging.exception("Stale-job sweep failed; retrying next cycle.")
        await asyncio.sleep(60)  # Check every 60 seconds
|
||||||
|
|
||||||
|
async def heartbeat_loop(self, job_id: str, stop_event: asyncio.Event):
    """
    Periodically update 'last_heartbeat' for *job_id* until *stop_event* is set.

    Uses asyncio.wait_for(stop_event.wait(), ...) instead of a plain
    sleep so the loop wakes immediately when the event is set — the
    original version slept out its full 10-second interval, which could
    delay the caller's `await heartbeat_task` by up to 10 seconds after
    the job finished.

    Args:
        job_id: Mongo id of the SimulationRun to keep alive.
        stop_event: Set by the caller when heartbeats should stop.
    """
    try:
        while not stop_event.is_set():
            # NOTE(review): synchronous pymongo call on the event loop;
            # acceptable for a single cheap update_one — confirm.
            update_simulation_run_heartbeat(job_id)
            try:
                # Wait out the heartbeat interval, but return as soon as
                # the stop event fires.
                await asyncio.wait_for(stop_event.wait(), timeout=10)
            except asyncio.TimeoutError:
                pass  # Interval elapsed; send the next heartbeat.
    except asyncio.CancelledError:
        # Allow clean task cancellation without a noisy traceback.
        pass
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -2,20 +2,22 @@ from rowboat import Client, StatefulChat
|
||||||
from typing import List
|
from typing import List
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from scenario_types import Scenario, SimulationResult, SimulationAggregateResult
|
from scenario_types import Scenario, SimulationResult, SimulationAggregateResult
|
||||||
from db import write_simulation_result
|
from db import write_simulation_result
|
||||||
|
|
||||||
|
|
||||||
openai_client = OpenAI()
|
openai_client = OpenAI()
|
||||||
MODEL_NAME = "gpt-4o"
|
MODEL_NAME = "gpt-4o"
|
||||||
ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
|
ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
|
||||||
|
|
||||||
def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> str:
|
async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> tuple[str, str, str]:
|
||||||
"""
|
"""
|
||||||
Runs a mock simulation for a given scenario.
|
Runs a mock simulation for a given scenario asynchronously.
|
||||||
After simulating several turns of conversation, it evaluates the conversation.
|
After simulating several turns of conversation, it evaluates the conversation.
|
||||||
|
Returns a tuple of (evaluation_result, details, transcript_str).
|
||||||
"""
|
"""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
support_chat = StatefulChat(
|
support_chat = StatefulChat(
|
||||||
rowboat_client,
|
rowboat_client,
|
||||||
|
|
@ -36,18 +38,24 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
|
||||||
for i in range(max_iterations):
|
for i in range(max_iterations):
|
||||||
openai_input = messages
|
openai_input = messages
|
||||||
|
|
||||||
simulated_user_response = openai_client.chat.completions.create(
|
# Run OpenAI API call in a separate thread
|
||||||
|
simulated_user_response = await loop.run_in_executor(
|
||||||
|
None, # Use default thread pool
|
||||||
|
lambda: openai_client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=openai_input,
|
messages=openai_input,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
simulated_content = simulated_user_response.choices[0].message.content
|
simulated_content = simulated_user_response.choices[0].message.content
|
||||||
|
|
||||||
# Feed the model-generated content back into Rowboat's stateful chat
|
# Run support_chat.run in a thread if it's synchronous
|
||||||
rowboat_response = support_chat.run(simulated_content)
|
rowboat_response = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: support_chat.run(simulated_content)
|
||||||
|
)
|
||||||
|
|
||||||
# Store the user message back into `messages` so the conversation continues
|
|
||||||
messages.append({"role": "assistant", "content": rowboat_response})
|
messages.append({"role": "assistant", "content": rowboat_response})
|
||||||
|
|
||||||
# -------------------------
|
# -------------------------
|
||||||
|
|
@ -76,12 +84,16 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
eval_response = openai_client.chat.completions.create(
|
# Run evaluation in a separate thread
|
||||||
|
eval_response = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: openai_client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=evaluation_prompt,
|
messages=evaluation_prompt,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
response_format={"type": "json_object"}
|
response_format={"type": "json_object"}
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if not eval_response.choices:
|
if not eval_response.choices:
|
||||||
raise Exception("No evaluation response received from model")
|
raise Exception("No evaluation response received from model")
|
||||||
|
|
@ -92,10 +104,12 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
|
||||||
if evaluation_result is None:
|
if evaluation_result is None:
|
||||||
raise Exception("No verdict field found in evaluation response")
|
raise Exception("No verdict field found in evaluation response")
|
||||||
|
|
||||||
return(evaluation_result, details, transcript_str)
|
return (evaluation_result, details, transcript_str)
|
||||||
|
|
||||||
|
|
||||||
async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5):
|
async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5):
|
||||||
|
"""
|
||||||
|
Simulates a list of scenarios asynchronously and aggregates the results.
|
||||||
|
"""
|
||||||
project_id = scenarios[0].projectId
|
project_id = scenarios[0].projectId
|
||||||
client = Client(
|
client = Client(
|
||||||
host=ROWBOAT_API_HOST,
|
host=ROWBOAT_API_HOST,
|
||||||
|
|
@ -103,8 +117,10 @@ async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id:
|
||||||
api_key=api_key
|
api_key=api_key
|
||||||
)
|
)
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for scenario in scenarios:
|
for scenario in scenarios:
|
||||||
result, details, transcript = simulate_scenario(scenario, client, workflow_id, max_iterations)
|
# Await the asynchronous simulate_scenario
|
||||||
|
result, details, transcript = await simulate_scenario(scenario, client, workflow_id, max_iterations)
|
||||||
|
|
||||||
simulation_result = SimulationResult(
|
simulation_result = SimulationResult(
|
||||||
projectId=project_id,
|
projectId=project_id,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue