mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-30 02:46:25 +02:00
updated simulation runner to the new collections
This commit is contained in:
parent
33f30670f6
commit
2b1ef82c20
7 changed files with 311 additions and 159 deletions
|
|
@ -436,8 +436,8 @@ export async function createRun(
|
||||||
startedAt: new Date().toISOString(),
|
startedAt: new Date().toISOString(),
|
||||||
aggregateResults: {
|
aggregateResults: {
|
||||||
total: 0,
|
total: 0,
|
||||||
pass: 0,
|
passCount: 0,
|
||||||
fail: 0,
|
failCount: 0,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
const insertResult = await testRunsCollection.insertOne(doc);
|
const insertResult = await testRunsCollection.insertOne(doc);
|
||||||
|
|
@ -455,8 +455,8 @@ export async function updateRun(
|
||||||
completedAt?: string;
|
completedAt?: string;
|
||||||
aggregateResults?: {
|
aggregateResults?: {
|
||||||
total: number;
|
total: number;
|
||||||
pass: number;
|
passCount: number;
|
||||||
fail: number;
|
failCount: number;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
|
|
|
||||||
|
|
@ -38,8 +38,8 @@ export const TestRun = z.object({
|
||||||
completedAt: z.string().optional(),
|
completedAt: z.string().optional(),
|
||||||
aggregateResults: z.object({
|
aggregateResults: z.object({
|
||||||
total: z.number(),
|
total: z.number(),
|
||||||
pass: z.number(),
|
passCount: z.number(),
|
||||||
fail: z.number(),
|
failCount: z.number(),
|
||||||
}).optional(),
|
}).optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -227,11 +227,11 @@ function ViewRun({
|
||||||
</div>
|
</div>
|
||||||
<div className="p-4 rounded-lg bg-green-50 dark:bg-green-900/20">
|
<div className="p-4 rounded-lg bg-green-50 dark:bg-green-900/20">
|
||||||
<div className="text-sm text-green-600 dark:text-green-400">Passed</div>
|
<div className="text-sm text-green-600 dark:text-green-400">Passed</div>
|
||||||
<div className="text-2xl font-semibold text-green-700 dark:text-green-400">{run.aggregateResults?.pass || 0}</div>
|
<div className="text-2xl font-semibold text-green-700 dark:text-green-400">{run.aggregateResults?.passCount || 0}</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="p-4 rounded-lg bg-red-50 dark:bg-red-900/20">
|
<div className="p-4 rounded-lg bg-red-50 dark:bg-red-900/20">
|
||||||
<div className="text-sm text-red-600 dark:text-red-400">Failed</div>
|
<div className="text-sm text-red-600 dark:text-red-400">Failed</div>
|
||||||
<div className="text-2xl font-semibold text-red-700 dark:text-red-400">{run.aggregateResults?.fail || 0}</div>
|
<div className="text-2xl font-semibold text-red-700 dark:text-red-400">{run.aggregateResults?.failCount || 0}</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
@ -389,10 +389,10 @@ function RunList({
|
||||||
Total: {run.aggregateResults.total}
|
Total: {run.aggregateResults.total}
|
||||||
</div>
|
</div>
|
||||||
<div className="text-green-600 dark:text-green-400">
|
<div className="text-green-600 dark:text-green-400">
|
||||||
Passed: {run.aggregateResults.pass}
|
Passed: {run.aggregateResults.passCount}
|
||||||
</div>
|
</div>
|
||||||
<div className="text-red-600 dark:text-red-400">
|
<div className="text-red-600 dark:text-red-400">
|
||||||
Failed: {run.aggregateResults.fail}
|
Failed: {run.aggregateResults.failCount}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,25 @@ from pymongo import MongoClient
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from scenario_types import SimulationRun, Scenario, SimulationResult, SimulationAggregateResult
|
from typing import Optional
|
||||||
|
from scenario_types import (
|
||||||
|
TestRun,
|
||||||
|
TestScenario,
|
||||||
|
TestProfile,
|
||||||
|
TestSimulation,
|
||||||
|
TestResult,
|
||||||
|
AggregateResults
|
||||||
|
)
|
||||||
|
|
||||||
MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip()
|
MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip()
|
||||||
|
|
||||||
SCENARIOS_COLLECTION_NAME = "scenarios"
|
# New collection names
|
||||||
API_KEYS_COLLECTION = "api_keys"
|
TEST_SCENARIOS_COLLECTION = "test_scenarios"
|
||||||
SIMULATIONS_COLLECTION_NAME = "simulation_runs"
|
TEST_PROFILES_COLLECTION = "test_profiles"
|
||||||
SIMULATION_RESULT_COLLECTION_NAME = "simulation_result"
|
TEST_SIMULATIONS_COLLECTION = "test_simulations"
|
||||||
SIMULATION_AGGREGATE_RESULT_COLLECTION_NAME = "simulation_aggregate_result"
|
TEST_RUNS_COLLECTION = "test_runs"
|
||||||
|
TEST_RESULTS_COLLECTION = "test_results"
|
||||||
|
API_KEYS_COLLECTION = "api_keys" # If still needed
|
||||||
|
|
||||||
def get_db():
|
def get_db():
|
||||||
client = MongoClient(MONGO_URI)
|
client = MongoClient(MONGO_URI)
|
||||||
|
|
@ -21,6 +31,9 @@ def get_collection(collection_name: str):
|
||||||
return db[collection_name]
|
return db[collection_name]
|
||||||
|
|
||||||
def get_api_key(project_id: str):
|
def get_api_key(project_id: str):
|
||||||
|
"""
|
||||||
|
If you still use an API key pattern, adapt as needed.
|
||||||
|
"""
|
||||||
collection = get_collection(API_KEYS_COLLECTION)
|
collection = get_collection(API_KEYS_COLLECTION)
|
||||||
doc = collection.find_one({"projectId": project_id})
|
doc = collection.find_one({"projectId": project_id})
|
||||||
if doc:
|
if doc:
|
||||||
|
|
@ -28,71 +41,68 @@ def get_api_key(project_id: str):
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_pending_simulation_run():
|
#
|
||||||
collection = get_collection(SIMULATIONS_COLLECTION_NAME)
|
# TestRun helpers
|
||||||
|
#
|
||||||
|
|
||||||
|
def get_pending_run() -> Optional[TestRun]:
|
||||||
|
"""
|
||||||
|
Finds a run with 'pending' status, marks it 'running', and returns it.
|
||||||
|
"""
|
||||||
|
collection = get_collection(TEST_RUNS_COLLECTION)
|
||||||
doc = collection.find_one_and_update(
|
doc = collection.find_one_and_update(
|
||||||
{"status": "pending"},
|
{"status": "pending"},
|
||||||
{"$set": {"status": "running"}},
|
{"$set": {"status": "running"}},
|
||||||
return_document=True
|
return_document=True
|
||||||
)
|
)
|
||||||
if doc:
|
if doc:
|
||||||
return SimulationRun(
|
return TestRun(
|
||||||
id=str(doc["_id"]),
|
id=str(doc["_id"]),
|
||||||
projectId=doc["projectId"],
|
projectId=doc["projectId"],
|
||||||
status="running",
|
name=doc["name"],
|
||||||
scenarioIds=doc["scenarioIds"],
|
simulationIds=doc["simulationIds"],
|
||||||
workflowId=doc["workflowId"],
|
workflowId=doc["workflowId"],
|
||||||
|
status="running",
|
||||||
startedAt=doc["startedAt"],
|
startedAt=doc["startedAt"],
|
||||||
completedAt=doc.get("completedAt")
|
completedAt=doc.get("completedAt"),
|
||||||
|
aggregateResults=doc.get("aggregateResults"),
|
||||||
|
lastHeartbeat=doc.get("lastHeartbeat")
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def set_simulation_run_to_completed(simulation_run: SimulationRun, aggregate_result: SimulationAggregateResult):
|
def set_run_to_completed(test_run: TestRun, aggregate: AggregateResults):
|
||||||
collection = get_collection(SIMULATIONS_COLLECTION_NAME)
|
|
||||||
collection.update_one({"_id": ObjectId(simulation_run.id)}, {"$set": {"status": "completed", "aggregateResults": aggregate_result.model_dump(by_alias=True)}})
|
|
||||||
|
|
||||||
def get_scenarios_for_run(simulation_run: SimulationRun):
|
|
||||||
if simulation_run is None:
|
|
||||||
return []
|
|
||||||
collection = get_collection(SCENARIOS_COLLECTION_NAME)
|
|
||||||
scenarios = []
|
|
||||||
for doc in collection.find():
|
|
||||||
if doc["_id"] in [ObjectId(sid) for sid in simulation_run.scenarioIds]:
|
|
||||||
scenarios.append(Scenario(
|
|
||||||
id=str(doc["_id"]),
|
|
||||||
projectId=doc["projectId"],
|
|
||||||
name=doc["name"],
|
|
||||||
description=doc["description"],
|
|
||||||
criteria=doc["criteria"],
|
|
||||||
context=doc["context"],
|
|
||||||
createdAt=doc["createdAt"],
|
|
||||||
lastUpdatedAt=doc["lastUpdatedAt"]
|
|
||||||
))
|
|
||||||
return scenarios
|
|
||||||
|
|
||||||
def write_simulation_result(result: SimulationResult):
|
|
||||||
collection = get_collection(SIMULATION_RESULT_COLLECTION_NAME)
|
|
||||||
collection.insert_one(result.model_dump())
|
|
||||||
|
|
||||||
def update_simulation_run_heartbeat(simulation_run_id: str):
|
|
||||||
"""
|
"""
|
||||||
Updates the 'last_heartbeat' timestamp for a SimulationRun.
|
Marks a test run 'completed' and sets the aggregate results.
|
||||||
"""
|
"""
|
||||||
|
collection = get_collection(TEST_RUNS_COLLECTION)
|
||||||
collection = get_collection(SIMULATIONS_COLLECTION_NAME)
|
|
||||||
collection.update_one(
|
collection.update_one(
|
||||||
{"_id": ObjectId(simulation_run_id)},
|
{"_id": ObjectId(test_run.id)},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
"status": "completed",
|
||||||
|
"aggregateResults": aggregate.model_dump(by_alias=True),
|
||||||
|
"completedAt": datetime.now(timezone.utc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_run_heartbeat(run_id: str):
|
||||||
|
"""
|
||||||
|
Updates the 'lastHeartbeat' timestamp for a TestRun.
|
||||||
|
"""
|
||||||
|
collection = get_collection(TEST_RUNS_COLLECTION)
|
||||||
|
collection.update_one(
|
||||||
|
{"_id": ObjectId(run_id)},
|
||||||
{"$set": {"lastHeartbeat": datetime.now(timezone.utc)}}
|
{"$set": {"lastHeartbeat": datetime.now(timezone.utc)}}
|
||||||
)
|
)
|
||||||
|
|
||||||
def mark_stale_jobs_as_failed():
|
def mark_stale_jobs_as_failed(threshold_minutes: int = 20) -> int:
|
||||||
"""
|
"""
|
||||||
Finds any job in 'running' status whose last_heartbeat is older than 5 minutes,
|
Finds any run in 'running' status whose lastHeartbeat is older than
|
||||||
and sets it to 'failed'.
|
`threshold_minutes`, and sets it to 'failed'. Returns the count.
|
||||||
"""
|
"""
|
||||||
|
collection = get_collection(TEST_RUNS_COLLECTION)
|
||||||
collection = get_collection(SIMULATIONS_COLLECTION_NAME)
|
stale_threshold = datetime.now(timezone.utc) - timedelta(minutes=threshold_minutes)
|
||||||
stale_threshold = datetime.now(timezone.utc) - timedelta(minutes=20)
|
|
||||||
result = collection.update_many(
|
result = collection.update_many(
|
||||||
{
|
{
|
||||||
"status": "running",
|
"status": "running",
|
||||||
|
|
@ -102,4 +112,46 @@ def mark_stale_jobs_as_failed():
|
||||||
"$set": {"status": "failed"}
|
"$set": {"status": "failed"}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return result.modified_count # Number of jobs marked failed
|
return result.modified_count
|
||||||
|
|
||||||
|
#
|
||||||
|
# TestSimulation helpers
|
||||||
|
#
|
||||||
|
|
||||||
|
def get_simulations_for_run(test_run: TestRun) -> list[TestSimulation]:
    """
    Load every TestSimulation referenced by ``test_run.simulationIds``.

    Returns an empty list when ``test_run`` is None. Otherwise the collection
    named by TEST_SIMULATIONS_COLLECTION is queried for the listed ids and
    each matching document is mapped onto a TestSimulation.
    """
    if test_run is None:
        return []

    # Resolve the collection first, then convert the string ids to ObjectId
    # for the `$in` lookup on `_id`.
    sims_collection = get_collection(TEST_SIMULATIONS_COLLECTION)
    wanted_ids = [ObjectId(sim_id) for sim_id in test_run.simulationIds]
    cursor = sims_collection.find({"_id": {"$in": wanted_ids}})

    return [
        TestSimulation(
            id=str(doc["_id"]),  # expose the Mongo _id as a plain string
            projectId=doc["projectId"],
            name=doc["name"],
            scenarioId=doc["scenarioId"],
            profileId=doc["profileId"],
            passCriteria=doc["passCriteria"],
            createdAt=doc["createdAt"],
            lastUpdatedAt=doc["lastUpdatedAt"],
        )
        for doc in cursor
    ]
|
||||||
|
|
||||||
|
#
|
||||||
|
# TestResult helpers
|
||||||
|
#
|
||||||
|
|
||||||
|
def write_test_result(result: TestResult):
    """
    Insert one TestResult as a new document.

    The result is serialised with ``model_dump()`` and written to the
    collection named by TEST_RESULTS_COLLECTION.
    """
    get_collection(TEST_RESULTS_COLLECTION).insert_one(result.model_dump())
|
||||||
|
|
|
||||||
|
|
@ -2,39 +2,62 @@ from datetime import datetime
|
||||||
from typing import Optional, List, Literal
|
from typing import Optional, List, Literal
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
run_status = Literal["pending", "running", "completed", "cancelled", "failed"]
|
# Define run statuses to include the new "error" status
|
||||||
|
RunStatus = Literal["pending", "running", "completed", "cancelled", "failed", "error"]
|
||||||
|
|
||||||
class Scenario(BaseModel):
|
class TestScenario(BaseModel):
|
||||||
|
# `_id` in Mongo will be stored as ObjectId; we return it as a string
|
||||||
id: str
|
id: str
|
||||||
projectId: str
|
projectId: str
|
||||||
name: str = ""
|
name: str
|
||||||
description: str = ""
|
description: str
|
||||||
criteria: str = ""
|
|
||||||
context: str = ""
|
|
||||||
createdAt: datetime
|
createdAt: datetime
|
||||||
lastUpdatedAt: datetime
|
lastUpdatedAt: datetime
|
||||||
|
|
||||||
class SimulationRun(BaseModel):
|
class TestProfile(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
projectId: str
|
projectId: str
|
||||||
status: Literal["pending", "running", "completed", "cancelled", "failed"]
|
name: str
|
||||||
scenarioIds: List[str]
|
context: str
|
||||||
|
createdAt: datetime
|
||||||
|
lastUpdatedAt: datetime
|
||||||
|
mockTools: bool
|
||||||
|
mockPrompt: Optional[str] = None
|
||||||
|
|
||||||
|
class TestSimulation(BaseModel):
|
||||||
|
id: str
|
||||||
|
projectId: str
|
||||||
|
name: str
|
||||||
|
scenarioId: str
|
||||||
|
profileId: str
|
||||||
|
passCriteria: str
|
||||||
|
createdAt: datetime
|
||||||
|
lastUpdatedAt: datetime
|
||||||
|
|
||||||
|
class AggregateResults(BaseModel):
|
||||||
|
total: int
|
||||||
|
passCount: int
|
||||||
|
failCount: int
|
||||||
|
|
||||||
|
class TestRun(BaseModel):
|
||||||
|
id: str
|
||||||
|
projectId: str
|
||||||
|
name: str
|
||||||
|
simulationIds: List[str]
|
||||||
workflowId: str
|
workflowId: str
|
||||||
|
status: RunStatus
|
||||||
startedAt: datetime
|
startedAt: datetime
|
||||||
lastHeartbeat: Optional[datetime] = None
|
|
||||||
completedAt: Optional[datetime] = None
|
completedAt: Optional[datetime] = None
|
||||||
aggregateResults: Optional[dict] = None
|
# Aggregate results are held as the typed AggregateResults model (or None until the run completes)
|
||||||
|
aggregateResults: Optional[AggregateResults] = None
|
||||||
|
|
||||||
|
# The new schema does not mention lastHeartbeat,
|
||||||
|
# but you can keep it if you still want to track stale runs
|
||||||
|
lastHeartbeat: Optional[datetime] = None
|
||||||
|
|
||||||
class SimulationResult(BaseModel):
|
class TestResult(BaseModel):
|
||||||
projectId: str
|
projectId: str
|
||||||
runId: str
|
runId: str
|
||||||
scenarioId: str
|
simulationId: str
|
||||||
result: Literal["pass", "fail"]
|
result: Literal["pass", "fail"]
|
||||||
details: str
|
details: str
|
||||||
transcript: str
|
|
||||||
|
|
||||||
class SimulationAggregateResult(BaseModel):
|
|
||||||
total: int
|
|
||||||
pass_: int = Field(..., alias='pass')
|
|
||||||
fail: int
|
|
||||||
|
|
@ -1,84 +1,104 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key, mark_stale_jobs_as_failed, update_simulation_run_heartbeat
|
|
||||||
from scenario_types import SimulationRun, Scenario
|
# Updated imports from your new db module and scenario_types
|
||||||
from simulation import simulate_scenarios
|
from db import (
|
||||||
|
get_pending_run,
|
||||||
|
get_simulations_for_run,
|
||||||
|
set_run_to_completed,
|
||||||
|
get_api_key,
|
||||||
|
mark_stale_jobs_as_failed,
|
||||||
|
update_run_heartbeat
|
||||||
|
)
|
||||||
|
from scenario_types import TestRun, TestSimulation
|
||||||
|
# If you have a new simulation function, import it here.
|
||||||
|
# Otherwise, adapt the name as needed:
|
||||||
|
from simulation import simulate_simulations # or simulate_scenarios, if unchanged
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
class JobService:
|
class JobService:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.poll_interval = 5 # seconds
|
self.poll_interval = 5 # seconds
|
||||||
|
# Control concurrency of run processing
|
||||||
self.semaphore = asyncio.Semaphore(5)
|
self.semaphore = asyncio.Semaphore(5)
|
||||||
|
|
||||||
async def poll_and_process_jobs(self, max_iterations: int = None):
|
async def poll_and_process_jobs(self, max_iterations: Optional[int] = None):
|
||||||
"""
|
"""
|
||||||
Periodically checks for new jobs in MongoDB and processes them.
|
Periodically checks for new runs in MongoDB and processes them.
|
||||||
"""
|
"""
|
||||||
|
# Start the stale-run check in the background
|
||||||
# Start the stale-job check in the background
|
asyncio.create_task(self.fail_stale_runs_loop())
|
||||||
asyncio.create_task(self.fail_stale_jobs_loop())
|
|
||||||
|
|
||||||
iterations = 0
|
iterations = 0
|
||||||
while True:
|
while True:
|
||||||
job = get_pending_simulation_run()
|
run = get_pending_run() # <--- changed to match new DB function
|
||||||
if job:
|
if run:
|
||||||
logging.info(f"Found new job: {job}. Processing...")
|
logging.info(f"Found new run: {run}. Processing...")
|
||||||
asyncio.create_task(self.process_job(job))
|
asyncio.create_task(self.process_run(run))
|
||||||
|
|
||||||
iterations += 1
|
iterations += 1
|
||||||
if max_iterations is not None and iterations >= max_iterations:
|
if max_iterations is not None and iterations >= max_iterations:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Sleep for the polling interval
|
# Sleep for the polling interval
|
||||||
await asyncio.sleep(self.poll_interval)
|
await asyncio.sleep(self.poll_interval)
|
||||||
|
|
||||||
async def process_job(self, job: SimulationRun):
|
async def process_run(self, run: TestRun):
|
||||||
"""
|
"""
|
||||||
Calls the simulation function and updates job status upon completion.
|
Calls the simulation function and updates run status upon completion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async with self.semaphore:
|
async with self.semaphore:
|
||||||
# Start heartbeat in background
|
# Start heartbeat in background
|
||||||
stop_heartbeat_event = asyncio.Event()
|
stop_heartbeat_event = asyncio.Event()
|
||||||
heartbeat_task = asyncio.create_task(self.heartbeat_loop(job.id, stop_heartbeat_event))
|
heartbeat_task = asyncio.create_task(self.heartbeat_loop(run.id, stop_heartbeat_event))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
scenarios = get_scenarios_for_run(job)
|
# Fetch the simulations associated with this run
|
||||||
if not scenarios or len(scenarios) == 0:
|
simulations = get_simulations_for_run(run)
|
||||||
logging.info(f"No scenarios found for job {job.id}")
|
if not simulations:
|
||||||
|
logging.info(f"No simulations found for run {run.id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
api_key = get_api_key(job.projectId)
|
# Fetch API key if needed
|
||||||
result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key)
|
api_key = get_api_key(run.projectId)
|
||||||
|
|
||||||
|
# Perform your simulation logic
|
||||||
|
# adapt this call to your actual simulation function’s signature
|
||||||
|
aggregate_result = await simulate_simulations(
|
||||||
|
simulations=simulations,
|
||||||
|
run_id=run.id,
|
||||||
|
workflow_id=run.workflowId,
|
||||||
|
api_key=api_key
|
||||||
|
)
|
||||||
|
|
||||||
set_simulation_run_to_completed(job, result)
|
# Mark run as completed with the aggregated result
|
||||||
logging.info(f"Job {job.id} completed.")
|
set_run_to_completed(run, aggregate_result)
|
||||||
|
logging.info(f"Run {run.id} completed.")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logging.error(f"Job {job.id} failed: {exc}")
|
logging.error(f"Run {run.id} failed: {exc}")
|
||||||
finally:
|
finally:
|
||||||
stop_heartbeat_event.set()
|
stop_heartbeat_event.set()
|
||||||
await heartbeat_task
|
await heartbeat_task
|
||||||
|
|
||||||
async def fail_stale_jobs_loop(self):
|
async def fail_stale_runs_loop(self):
|
||||||
"""
|
"""
|
||||||
Periodically checks for stale jobs that haven't received a heartbeat in over 5 minutes,
|
Periodically checks for stale runs (no heartbeat) and marks them as 'failed'.
|
||||||
and marks them as 'failed'.
|
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
count = mark_stale_jobs_as_failed()
|
count = mark_stale_jobs_as_failed()
|
||||||
if count > 0:
|
if count > 0:
|
||||||
logging.warning(f"Marked {count} stale jobs as failed.")
|
logging.warning(f"Marked {count} stale runs as failed.")
|
||||||
await asyncio.sleep(60) # Check every 60 seconds
|
await asyncio.sleep(60) # Check every 60 seconds
|
||||||
|
|
||||||
async def heartbeat_loop(self, job_id: str, stop_event: asyncio.Event):
|
async def heartbeat_loop(self, run_id: str, stop_event: asyncio.Event):
|
||||||
"""
|
"""
|
||||||
Periodically updates 'last_heartbeat' for the given job until 'stop_event' is set.
|
Periodically updates 'lastHeartbeat' for the given run until 'stop_event' is set.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
update_simulation_run_heartbeat(job_id)
|
update_run_heartbeat(run_id)
|
||||||
await asyncio.sleep(10) # Heartbeat interval in seconds
|
await asyncio.sleep(10) # Heartbeat interval in seconds
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -1,46 +1,67 @@
|
||||||
from rowboat import Client, StatefulChat
|
import asyncio
|
||||||
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import asyncio
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from scenario_types import Scenario, SimulationResult, SimulationAggregateResult
|
|
||||||
from db import write_simulation_result
|
# Updated imports from your new schema/types
|
||||||
|
from scenario_types import TestSimulation, TestResult, AggregateResults
|
||||||
|
|
||||||
|
# If your DB functions changed names, adapt here:
|
||||||
|
from db import write_test_result # replaced write_simulation_result
|
||||||
|
|
||||||
|
from rowboat import Client, StatefulChat
|
||||||
|
|
||||||
openai_client = OpenAI()
|
openai_client = OpenAI()
|
||||||
MODEL_NAME = "gpt-4o"
|
MODEL_NAME = "gpt-4o"
|
||||||
ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
|
ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
|
||||||
|
|
||||||
async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> tuple[str, str, str]:
|
async def simulate_simulation(
|
||||||
|
simulation: TestSimulation,
|
||||||
|
rowboat_client: Client,
|
||||||
|
workflow_id: str,
|
||||||
|
max_iterations: int = 5
|
||||||
|
) -> tuple[str, str, str]:
|
||||||
"""
|
"""
|
||||||
Runs a mock simulation for a given scenario asynchronously.
|
Runs a mock simulation for a given TestSimulation asynchronously.
|
||||||
After simulating several turns of conversation, it evaluates the conversation.
|
After simulating several turns of conversation, it evaluates the conversation.
|
||||||
Returns a tuple of (evaluation_result, details, transcript_str).
|
Returns a tuple of (evaluation_result, details, transcript_str).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
# Optionally embed passCriteria in the system prompt, if it’s relevant to context:
|
||||||
|
pass_criteria = simulation.passCriteria or ""
|
||||||
|
# Or place it separately below if you prefer.
|
||||||
|
|
||||||
|
# Prepare a Rowboat chat
|
||||||
support_chat = StatefulChat(
|
support_chat = StatefulChat(
|
||||||
rowboat_client,
|
rowboat_client,
|
||||||
system_prompt=f"{f'Context: {scenario.context}' if scenario.context else ''}",
|
system_prompt=f"Context: {pass_criteria}" if pass_criteria else "",
|
||||||
workflow_id=workflow_id
|
workflow_id=workflow_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# You might want to describe the simulation or scenario more thoroughly.
|
||||||
|
# Here, we just embed simulation.name in the system message:
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": f"Simulate the user based on the scenario: \n {scenario.description}"
|
"content": (
|
||||||
|
f"Simulate the user based on this simulation:\n{simulation.name}"
|
||||||
|
)
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
# -------------------------
|
# -------------------------
|
||||||
# 1) MAIN SIMULATION LOOP
|
# (1) MAIN SIMULATION LOOP
|
||||||
# -------------------------
|
# -------------------------
|
||||||
for i in range(max_iterations):
|
for _ in range(max_iterations):
|
||||||
openai_input = messages
|
openai_input = messages
|
||||||
|
|
||||||
# Run OpenAI API call in a separate thread
|
# Run OpenAI API call in a separate thread (non-blocking)
|
||||||
simulated_user_response = await loop.run_in_executor(
|
simulated_user_response = await loop.run_in_executor(
|
||||||
None, # Use default thread pool
|
None, # default ThreadPool
|
||||||
lambda: openai_client.chat.completions.create(
|
lambda: openai_client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=openai_input,
|
messages=openai_input,
|
||||||
|
|
@ -48,9 +69,9 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
simulated_content = simulated_user_response.choices[0].message.content
|
simulated_content = simulated_user_response.choices[0].message.content.strip()
|
||||||
|
|
||||||
# Run support_chat.run in a thread if it's synchronous
|
# Run Rowboat chat in a thread if it's synchronous
|
||||||
rowboat_response = await loop.run_in_executor(
|
rowboat_response = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: support_chat.run(simulated_content)
|
lambda: support_chat.run(simulated_content)
|
||||||
|
|
@ -59,7 +80,7 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow
|
||||||
messages.append({"role": "assistant", "content": rowboat_response})
|
messages.append({"role": "assistant", "content": rowboat_response})
|
||||||
|
|
||||||
# -------------------------
|
# -------------------------
|
||||||
# 2) EVALUATION STEP
|
# (2) EVALUATION STEP
|
||||||
# -------------------------
|
# -------------------------
|
||||||
transcript_str = ""
|
transcript_str = ""
|
||||||
for m in messages:
|
for m in messages:
|
||||||
|
|
@ -67,19 +88,24 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow
|
||||||
content = m.get("content", "")
|
content = m.get("content", "")
|
||||||
transcript_str += f"{role.upper()}: {content}\n"
|
transcript_str += f"{role.upper()}: {content}\n"
|
||||||
|
|
||||||
|
# We use passCriteria as the evaluation “criteria.”
|
||||||
evaluation_prompt = [
|
evaluation_prompt = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": (
|
"content": (
|
||||||
f"You are a neutral evaluator. Evaluate based on these criteria:\n{scenario.criteria}\n\nReturn ONLY a JSON object with format: "
|
f"You are a neutral evaluator. Evaluate based on these criteria:\n"
|
||||||
'{"verdict": "pass", "details": <the reason for pass in 2 sentences>} if the support bot answered correctly, or {"verdict": "fail", "details": <the reason for fail in 2 sentences>} if not.'
|
f"{simulation.passCriteria}\n\n"
|
||||||
|
"Return ONLY a JSON object in this format:\n"
|
||||||
|
'{"verdict": "pass", "details": <reason>} or '
|
||||||
|
'{"verdict": "fail", "details": <reason>}.'
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": (
|
"content": (
|
||||||
f"Here is the conversation transcript:\n\n{transcript_str}\n\n"
|
f"Here is the conversation transcript:\n\n{transcript_str}\n\n"
|
||||||
"Did the support bot answer correctly or not? Return only 'pass' or 'fail' for verdict, and a brief 2 sentence explanation for details."
|
"Did the support bot answer correctly or not? "
|
||||||
|
"Return only 'pass' or 'fail' for verdict, and a brief explanation for details."
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
@ -91,51 +117,82 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=evaluation_prompt,
|
messages=evaluation_prompt,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
|
# If your LLM supports a structured response format, you can specify it.
|
||||||
|
# Otherwise, remove or adapt 'response_format':
|
||||||
response_format={"type": "json_object"}
|
response_format={"type": "json_object"}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not eval_response.choices:
|
if not eval_response.choices:
|
||||||
raise Exception("No evaluation response received from model")
|
raise Exception("No evaluation response received from model")
|
||||||
else:
|
|
||||||
response_json = json.loads(eval_response.choices[0].message.content)
|
response_json_str = eval_response.choices[0].message.content
|
||||||
evaluation_result = response_json.get("verdict")
|
# Attempt to parse the JSON
|
||||||
details = response_json.get("details")
|
response_json = json.loads(response_json_str)
|
||||||
if evaluation_result is None:
|
evaluation_result = response_json.get("verdict")
|
||||||
raise Exception("No verdict field found in evaluation response")
|
details = response_json.get("details")
|
||||||
|
|
||||||
|
if evaluation_result is None:
|
||||||
|
raise Exception("No 'verdict' field found in evaluation response")
|
||||||
|
|
||||||
return (evaluation_result, details, transcript_str)
|
return (evaluation_result, details, transcript_str)
|
||||||
|
|
||||||
async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5):
|
async def simulate_simulations(
|
||||||
|
simulations: List[TestSimulation],
|
||||||
|
run_id: str,
|
||||||
|
workflow_id: str,
|
||||||
|
api_key: str,
|
||||||
|
max_iterations: int = 5
|
||||||
|
) -> AggregateResults:
|
||||||
"""
|
"""
|
||||||
Simulates a list of scenarios asynchronously and aggregates the results.
|
Simulates a list of TestSimulations asynchronously and aggregates the results.
|
||||||
"""
|
"""
|
||||||
project_id = scenarios[0].projectId
|
if not simulations:
|
||||||
|
# Return an empty result if there's nothing to simulate
|
||||||
|
return AggregateResults(total=0, pass_=0, fail=0)
|
||||||
|
|
||||||
|
# We assume all simulations belong to the same project
|
||||||
|
project_id = simulations[0].projectId
|
||||||
|
|
||||||
|
# Create a Rowboat client instance
|
||||||
client = Client(
|
client = Client(
|
||||||
host=ROWBOAT_API_HOST,
|
host=ROWBOAT_API_HOST,
|
||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
api_key=api_key
|
api_key=api_key
|
||||||
)
|
)
|
||||||
results = []
|
|
||||||
|
|
||||||
for scenario in scenarios:
|
# Store results here
|
||||||
# Await the asynchronous simulate_scenario
|
results: List[TestResult] = []
|
||||||
result, details, transcript = await simulate_scenario(scenario, client, workflow_id, max_iterations)
|
|
||||||
|
|
||||||
simulation_result = SimulationResult(
|
for simulation in simulations:
|
||||||
projectId=project_id,
|
# Run each simulation
|
||||||
runId=runId,
|
verdict, details, transcript = await simulate_simulation(
|
||||||
scenarioId=scenario.id,
|
simulation=simulation,
|
||||||
result=result,
|
rowboat_client=client,
|
||||||
details=details,
|
workflow_id=workflow_id,
|
||||||
transcript=transcript
|
max_iterations=max_iterations
|
||||||
)
|
)
|
||||||
results.append(simulation_result)
|
|
||||||
write_simulation_result(simulation_result)
|
|
||||||
|
|
||||||
aggregate_result = SimulationAggregateResult(**{
|
# Create a new TestResult
|
||||||
"total": len(scenarios),
|
test_result = TestResult(
|
||||||
"pass": sum(1 for result in results if result.result == "pass"),
|
projectId=project_id,
|
||||||
"fail": sum(1 for result in results if result.result == "fail")
|
runId=run_id,
|
||||||
})
|
simulationId=simulation.id,
|
||||||
return aggregate_result
|
result=verdict,
|
||||||
|
details=details
|
||||||
|
)
|
||||||
|
results.append(test_result)
|
||||||
|
|
||||||
|
# Persist the test result
|
||||||
|
write_test_result(test_result)
|
||||||
|
|
||||||
|
# Aggregate pass/fail
|
||||||
|
total_count = len(results)
|
||||||
|
pass_count = sum(1 for r in results if r.result == "pass")
|
||||||
|
fail_count = sum(1 for r in results if r.result == "fail")
|
||||||
|
|
||||||
|
return AggregateResults(
|
||||||
|
total=total_count,
|
||||||
|
passCount=pass_count,
|
||||||
|
failCount=fail_count
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue