updated simulation runner to the new collections

This commit is contained in:
arkml 2025-03-01 12:56:08 +05:30
parent 33f30670f6
commit 2b1ef82c20
7 changed files with 311 additions and 159 deletions

View file

@ -436,8 +436,8 @@ export async function createRun(
startedAt: new Date().toISOString(), startedAt: new Date().toISOString(),
aggregateResults: { aggregateResults: {
total: 0, total: 0,
pass: 0, passCount: 0,
fail: 0, failCount: 0,
}, },
}; };
const insertResult = await testRunsCollection.insertOne(doc); const insertResult = await testRunsCollection.insertOne(doc);
@ -455,8 +455,8 @@ export async function updateRun(
completedAt?: string; completedAt?: string;
aggregateResults?: { aggregateResults?: {
total: number; total: number;
pass: number; passCount: number;
fail: number; failCount: number;
}; };
} }
): Promise<void> { ): Promise<void> {

View file

@ -38,8 +38,8 @@ export const TestRun = z.object({
completedAt: z.string().optional(), completedAt: z.string().optional(),
aggregateResults: z.object({ aggregateResults: z.object({
total: z.number(), total: z.number(),
pass: z.number(), passCount: z.number(),
fail: z.number(), failCount: z.number(),
}).optional(), }).optional(),
}); });

View file

@ -227,11 +227,11 @@ function ViewRun({
</div> </div>
<div className="p-4 rounded-lg bg-green-50 dark:bg-green-900/20"> <div className="p-4 rounded-lg bg-green-50 dark:bg-green-900/20">
<div className="text-sm text-green-600 dark:text-green-400">Passed</div> <div className="text-sm text-green-600 dark:text-green-400">Passed</div>
<div className="text-2xl font-semibold text-green-700 dark:text-green-400">{run.aggregateResults?.pass || 0}</div> <div className="text-2xl font-semibold text-green-700 dark:text-green-400">{run.aggregateResults?.passCount || 0}</div>
</div> </div>
<div className="p-4 rounded-lg bg-red-50 dark:bg-red-900/20"> <div className="p-4 rounded-lg bg-red-50 dark:bg-red-900/20">
<div className="text-sm text-red-600 dark:text-red-400">Failed</div> <div className="text-sm text-red-600 dark:text-red-400">Failed</div>
<div className="text-2xl font-semibold text-red-700 dark:text-red-400">{run.aggregateResults?.fail || 0}</div> <div className="text-2xl font-semibold text-red-700 dark:text-red-400">{run.aggregateResults?.failCount || 0}</div>
</div> </div>
</div> </div>
@ -389,10 +389,10 @@ function RunList({
Total: {run.aggregateResults.total} Total: {run.aggregateResults.total}
</div> </div>
<div className="text-green-600 dark:text-green-400"> <div className="text-green-600 dark:text-green-400">
Passed: {run.aggregateResults.pass} Passed: {run.aggregateResults.passCount}
</div> </div>
<div className="text-red-600 dark:text-red-400"> <div className="text-red-600 dark:text-red-400">
Failed: {run.aggregateResults.fail} Failed: {run.aggregateResults.failCount}
</div> </div>
</div> </div>
</div> </div>

View file

@ -2,15 +2,25 @@ from pymongo import MongoClient
from bson import ObjectId from bson import ObjectId
import os import os
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from scenario_types import SimulationRun, Scenario, SimulationResult, SimulationAggregateResult from typing import Optional
from scenario_types import (
TestRun,
TestScenario,
TestProfile,
TestSimulation,
TestResult,
AggregateResults
)
MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip() MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip()
SCENARIOS_COLLECTION_NAME = "scenarios" # New collection names
API_KEYS_COLLECTION = "api_keys" TEST_SCENARIOS_COLLECTION = "test_scenarios"
SIMULATIONS_COLLECTION_NAME = "simulation_runs" TEST_PROFILES_COLLECTION = "test_profiles"
SIMULATION_RESULT_COLLECTION_NAME = "simulation_result" TEST_SIMULATIONS_COLLECTION = "test_simulations"
SIMULATION_AGGREGATE_RESULT_COLLECTION_NAME = "simulation_aggregate_result" TEST_RUNS_COLLECTION = "test_runs"
TEST_RESULTS_COLLECTION = "test_results"
API_KEYS_COLLECTION = "api_keys" # If still needed
def get_db(): def get_db():
client = MongoClient(MONGO_URI) client = MongoClient(MONGO_URI)
@ -21,6 +31,9 @@ def get_collection(collection_name: str):
return db[collection_name] return db[collection_name]
def get_api_key(project_id: str): def get_api_key(project_id: str):
"""
If you still use an API key pattern, adapt as needed.
"""
collection = get_collection(API_KEYS_COLLECTION) collection = get_collection(API_KEYS_COLLECTION)
doc = collection.find_one({"projectId": project_id}) doc = collection.find_one({"projectId": project_id})
if doc: if doc:
@ -28,71 +41,68 @@ def get_api_key(project_id: str):
else: else:
return None return None
def get_pending_simulation_run(): #
collection = get_collection(SIMULATIONS_COLLECTION_NAME) # TestRun helpers
#
def get_pending_run() -> Optional[TestRun]:
"""
Finds a run with 'pending' status, marks it 'running', and returns it.
"""
collection = get_collection(TEST_RUNS_COLLECTION)
doc = collection.find_one_and_update( doc = collection.find_one_and_update(
{"status": "pending"}, {"status": "pending"},
{"$set": {"status": "running"}}, {"$set": {"status": "running"}},
return_document=True return_document=True
) )
if doc: if doc:
return SimulationRun( return TestRun(
id=str(doc["_id"]),
projectId=doc["projectId"],
status="running",
scenarioIds=doc["scenarioIds"],
workflowId=doc["workflowId"],
startedAt=doc["startedAt"],
completedAt=doc.get("completedAt")
)
return None
def set_simulation_run_to_completed(simulation_run: SimulationRun, aggregate_result: SimulationAggregateResult):
collection = get_collection(SIMULATIONS_COLLECTION_NAME)
collection.update_one({"_id": ObjectId(simulation_run.id)}, {"$set": {"status": "completed", "aggregateResults": aggregate_result.model_dump(by_alias=True)}})
def get_scenarios_for_run(simulation_run: SimulationRun):
if simulation_run is None:
return []
collection = get_collection(SCENARIOS_COLLECTION_NAME)
scenarios = []
for doc in collection.find():
if doc["_id"] in [ObjectId(sid) for sid in simulation_run.scenarioIds]:
scenarios.append(Scenario(
id=str(doc["_id"]), id=str(doc["_id"]),
projectId=doc["projectId"], projectId=doc["projectId"],
name=doc["name"], name=doc["name"],
description=doc["description"], simulationIds=doc["simulationIds"],
criteria=doc["criteria"], workflowId=doc["workflowId"],
context=doc["context"], status="running",
createdAt=doc["createdAt"], startedAt=doc["startedAt"],
lastUpdatedAt=doc["lastUpdatedAt"] completedAt=doc.get("completedAt"),
)) aggregateResults=doc.get("aggregateResults"),
return scenarios lastHeartbeat=doc.get("lastHeartbeat")
)
return None
def write_simulation_result(result: SimulationResult): def set_run_to_completed(test_run: TestRun, aggregate: AggregateResults):
collection = get_collection(SIMULATION_RESULT_COLLECTION_NAME)
collection.insert_one(result.model_dump())
def update_simulation_run_heartbeat(simulation_run_id: str):
""" """
Updates the 'last_heartbeat' timestamp for a SimulationRun. Marks a test run 'completed' and sets the aggregate results.
""" """
collection = get_collection(TEST_RUNS_COLLECTION)
collection = get_collection(SIMULATIONS_COLLECTION_NAME)
collection.update_one( collection.update_one(
{"_id": ObjectId(simulation_run_id)}, {"_id": ObjectId(test_run.id)},
{
"$set": {
"status": "completed",
"aggregateResults": aggregate.model_dump(by_alias=True),
"completedAt": datetime.now(timezone.utc)
}
}
)
def update_run_heartbeat(run_id: str):
"""
Updates the 'lastHeartbeat' timestamp for a TestRun.
"""
collection = get_collection(TEST_RUNS_COLLECTION)
collection.update_one(
{"_id": ObjectId(run_id)},
{"$set": {"lastHeartbeat": datetime.now(timezone.utc)}} {"$set": {"lastHeartbeat": datetime.now(timezone.utc)}}
) )
def mark_stale_jobs_as_failed(): def mark_stale_jobs_as_failed(threshold_minutes: int = 20) -> int:
""" """
Finds any job in 'running' status whose last_heartbeat is older than 5 minutes, Finds any run in 'running' status whose lastHeartbeat is older than
and sets it to 'failed'. `threshold_minutes`, and sets it to 'failed'. Returns the count.
""" """
collection = get_collection(TEST_RUNS_COLLECTION)
collection = get_collection(SIMULATIONS_COLLECTION_NAME) stale_threshold = datetime.now(timezone.utc) - timedelta(minutes=threshold_minutes)
stale_threshold = datetime.now(timezone.utc) - timedelta(minutes=20)
result = collection.update_many( result = collection.update_many(
{ {
"status": "running", "status": "running",
@ -102,4 +112,46 @@ def mark_stale_jobs_as_failed():
"$set": {"status": "failed"} "$set": {"status": "failed"}
} }
) )
return result.modified_count # Number of jobs marked failed return result.modified_count
#
# TestSimulation helpers
#
def get_simulations_for_run(test_run: TestRun) -> list[TestSimulation]:
"""
Returns all simulations specified by a particular run.
"""
if test_run is None:
return []
collection = get_collection(TEST_SIMULATIONS_COLLECTION)
simulation_docs = collection.find({
"_id": {"$in": [ObjectId(sim_id) for sim_id in test_run.simulationIds]}
})
simulations = []
for doc in simulation_docs:
simulations.append(
TestSimulation(
id=str(doc["_id"]),
projectId=doc["projectId"],
name=doc["name"],
scenarioId=doc["scenarioId"],
profileId=doc["profileId"],
passCriteria=doc["passCriteria"],
createdAt=doc["createdAt"],
lastUpdatedAt=doc["lastUpdatedAt"]
)
)
return simulations
#
# TestResult helpers
#
def write_test_result(result: TestResult):
"""
Writes a test result into the `test_results` collection.
"""
collection = get_collection(TEST_RESULTS_COLLECTION)
collection.insert_one(result.model_dump())

View file

@ -2,39 +2,62 @@ from datetime import datetime
from typing import Optional, List, Literal from typing import Optional, List, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
run_status = Literal["pending", "running", "completed", "cancelled", "failed"] # Define run statuses to include the new "error" status
RunStatus = Literal["pending", "running", "completed", "cancelled", "failed", "error"]
class Scenario(BaseModel): class TestScenario(BaseModel):
# `_id` in Mongo will be stored as ObjectId; we return it as a string
id: str id: str
projectId: str projectId: str
name: str = "" name: str
description: str = "" description: str
criteria: str = ""
context: str = ""
createdAt: datetime createdAt: datetime
lastUpdatedAt: datetime lastUpdatedAt: datetime
class SimulationRun(BaseModel): class TestProfile(BaseModel):
id: str id: str
projectId: str projectId: str
status: Literal["pending", "running", "completed", "cancelled", "failed"] name: str
scenarioIds: List[str] context: str
createdAt: datetime
lastUpdatedAt: datetime
mockTools: bool
mockPrompt: Optional[str] = None
class TestSimulation(BaseModel):
id: str
projectId: str
name: str
scenarioId: str
profileId: str
passCriteria: str
createdAt: datetime
lastUpdatedAt: datetime
class AggregateResults(BaseModel):
total: int
passCount: int
failCount: int
class TestRun(BaseModel):
id: str
projectId: str
name: str
simulationIds: List[str]
workflowId: str workflowId: str
status: RunStatus
startedAt: datetime startedAt: datetime
lastHeartbeat: Optional[datetime] = None
completedAt: Optional[datetime] = None completedAt: Optional[datetime] = None
aggregateResults: Optional[dict] = None # By default, store aggregate results as a dict or the typed AggregateResults
aggregateResults: Optional[AggregateResults] = None
# The new schema does not mention lastHeartbeat,
# but you can keep it if you still want to track stale runs
lastHeartbeat: Optional[datetime] = None
class SimulationResult(BaseModel): class TestResult(BaseModel):
projectId: str projectId: str
runId: str runId: str
scenarioId: str simulationId: str
result: Literal["pass", "fail"] result: Literal["pass", "fail"]
details: str details: str
transcript: str
class SimulationAggregateResult(BaseModel):
total: int
pass_: int = Field(..., alias='pass')
fail: int

View file

@ -1,84 +1,104 @@
import asyncio import asyncio
import logging import logging
from typing import List from typing import List, Optional
from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key, mark_stale_jobs_as_failed, update_simulation_run_heartbeat
from scenario_types import SimulationRun, Scenario # Updated imports from your new db module and scenario_types
from simulation import simulate_scenarios from db import (
get_pending_run,
get_simulations_for_run,
set_run_to_completed,
get_api_key,
mark_stale_jobs_as_failed,
update_run_heartbeat
)
from scenario_types import TestRun, TestSimulation
# If you have a new simulation function, import it here.
# Otherwise, adapt the name as needed:
from simulation import simulate_simulations # or simulate_scenarios, if unchanged
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
class JobService: class JobService:
def __init__(self): def __init__(self):
self.poll_interval = 5 # seconds self.poll_interval = 5 # seconds
# Control concurrency of run processing
self.semaphore = asyncio.Semaphore(5) self.semaphore = asyncio.Semaphore(5)
async def poll_and_process_jobs(self, max_iterations: int = None): async def poll_and_process_jobs(self, max_iterations: Optional[int] = None):
""" """
Periodically checks for new jobs in MongoDB and processes them. Periodically checks for new runs in MongoDB and processes them.
""" """
# Start the stale-run check in the background
# Start the stale-job check in the background asyncio.create_task(self.fail_stale_runs_loop())
asyncio.create_task(self.fail_stale_jobs_loop())
iterations = 0 iterations = 0
while True: while True:
job = get_pending_simulation_run() run = get_pending_run() # <--- changed to match new DB function
if job: if run:
logging.info(f"Found new job: {job}. Processing...") logging.info(f"Found new run: {run}. Processing...")
asyncio.create_task(self.process_job(job)) asyncio.create_task(self.process_run(run))
iterations += 1 iterations += 1
if max_iterations is not None and iterations >= max_iterations: if max_iterations is not None and iterations >= max_iterations:
break break
# Sleep for the polling interval # Sleep for the polling interval
await asyncio.sleep(self.poll_interval) await asyncio.sleep(self.poll_interval)
async def process_job(self, job: SimulationRun): async def process_run(self, run: TestRun):
""" """
Calls the simulation function and updates job status upon completion. Calls the simulation function and updates run status upon completion.
""" """
async with self.semaphore: async with self.semaphore:
# Start heartbeat in background # Start heartbeat in background
stop_heartbeat_event = asyncio.Event() stop_heartbeat_event = asyncio.Event()
heartbeat_task = asyncio.create_task(self.heartbeat_loop(job.id, stop_heartbeat_event)) heartbeat_task = asyncio.create_task(self.heartbeat_loop(run.id, stop_heartbeat_event))
try: try:
scenarios = get_scenarios_for_run(job) # Fetch the simulations associated with this run
if not scenarios or len(scenarios) == 0: simulations = get_simulations_for_run(run)
logging.info(f"No scenarios found for job {job.id}") if not simulations:
logging.info(f"No simulations found for run {run.id}")
return return
api_key = get_api_key(job.projectId) # Fetch API key if needed
result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key) api_key = get_api_key(run.projectId)
# Perform your simulation logic
# adapt this call to your actual simulation functions signature
aggregate_result = await simulate_simulations(
simulations=simulations,
run_id=run.id,
workflow_id=run.workflowId,
api_key=api_key
)
set_simulation_run_to_completed(job, result) # Mark run as completed with the aggregated result
logging.info(f"Job {job.id} completed.") set_run_to_completed(run, aggregate_result)
logging.info(f"Run {run.id} completed.")
except Exception as exc: except Exception as exc:
logging.error(f"Job {job.id} failed: {exc}") logging.error(f"Run {run.id} failed: {exc}")
finally: finally:
stop_heartbeat_event.set() stop_heartbeat_event.set()
await heartbeat_task await heartbeat_task
async def fail_stale_jobs_loop(self): async def fail_stale_runs_loop(self):
""" """
Periodically checks for stale jobs that haven't received a heartbeat in over 5 minutes, Periodically checks for stale runs (no heartbeat) and marks them as 'failed'.
and marks them as 'failed'.
""" """
while True: while True:
count = mark_stale_jobs_as_failed() count = mark_stale_jobs_as_failed()
if count > 0: if count > 0:
logging.warning(f"Marked {count} stale jobs as failed.") logging.warning(f"Marked {count} stale runs as failed.")
await asyncio.sleep(60) # Check every 60 seconds await asyncio.sleep(60) # Check every 60 seconds
async def heartbeat_loop(self, job_id: str, stop_event: asyncio.Event): async def heartbeat_loop(self, run_id: str, stop_event: asyncio.Event):
""" """
Periodically updates 'last_heartbeat' for the given job until 'stop_event' is set. Periodically updates 'lastHeartbeat' for the given run until 'stop_event' is set.
""" """
try: try:
while not stop_event.is_set(): while not stop_event.is_set():
update_simulation_run_heartbeat(job_id) update_run_heartbeat(run_id)
await asyncio.sleep(10) # Heartbeat interval in seconds await asyncio.sleep(10) # Heartbeat interval in seconds
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass

View file

@ -1,46 +1,67 @@
from rowboat import Client, StatefulChat import asyncio
import logging
from typing import List from typing import List
import json import json
import os import os
import asyncio
from openai import OpenAI from openai import OpenAI
from scenario_types import Scenario, SimulationResult, SimulationAggregateResult
from db import write_simulation_result # Updated imports from your new schema/types
from scenario_types import TestSimulation, TestResult, AggregateResults
# If your DB functions changed names, adapt here:
from db import write_test_result # replaced write_simulation_result
from rowboat import Client, StatefulChat
openai_client = OpenAI() openai_client = OpenAI()
MODEL_NAME = "gpt-4o" MODEL_NAME = "gpt-4o"
ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip() ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> tuple[str, str, str]: async def simulate_simulation(
simulation: TestSimulation,
rowboat_client: Client,
workflow_id: str,
max_iterations: int = 5
) -> tuple[str, str, str]:
""" """
Runs a mock simulation for a given scenario asynchronously. Runs a mock simulation for a given TestSimulation asynchronously.
After simulating several turns of conversation, it evaluates the conversation. After simulating several turns of conversation, it evaluates the conversation.
Returns a tuple of (evaluation_result, details, transcript_str). Returns a tuple of (evaluation_result, details, transcript_str).
""" """
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
# Optionally embed passCriteria in the system prompt, if its relevant to context:
pass_criteria = simulation.passCriteria or ""
# Or place it separately below if you prefer.
# Prepare a Rowboat chat
support_chat = StatefulChat( support_chat = StatefulChat(
rowboat_client, rowboat_client,
system_prompt=f"{f'Context: {scenario.context}' if scenario.context else ''}", system_prompt=f"Context: {pass_criteria}" if pass_criteria else "",
workflow_id=workflow_id workflow_id=workflow_id
) )
# You might want to describe the simulation or scenario more thoroughly.
# Here, we just embed simulation.name in the system message:
messages = [ messages = [
{ {
"role": "system", "role": "system",
"content": f"Simulate the user based on the scenario: \n {scenario.description}" "content": (
f"Simulate the user based on this simulation:\n{simulation.name}"
)
} }
] ]
# ------------------------- # -------------------------
# 1) MAIN SIMULATION LOOP # (1) MAIN SIMULATION LOOP
# ------------------------- # -------------------------
for i in range(max_iterations): for _ in range(max_iterations):
openai_input = messages openai_input = messages
# Run OpenAI API call in a separate thread # Run OpenAI API call in a separate thread (non-blocking)
simulated_user_response = await loop.run_in_executor( simulated_user_response = await loop.run_in_executor(
None, # Use default thread pool None, # default ThreadPool
lambda: openai_client.chat.completions.create( lambda: openai_client.chat.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
messages=openai_input, messages=openai_input,
@ -48,9 +69,9 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow
) )
) )
simulated_content = simulated_user_response.choices[0].message.content simulated_content = simulated_user_response.choices[0].message.content.strip()
# Run support_chat.run in a thread if it's synchronous # Run Rowboat chat in a thread if it's synchronous
rowboat_response = await loop.run_in_executor( rowboat_response = await loop.run_in_executor(
None, None,
lambda: support_chat.run(simulated_content) lambda: support_chat.run(simulated_content)
@ -59,7 +80,7 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow
messages.append({"role": "assistant", "content": rowboat_response}) messages.append({"role": "assistant", "content": rowboat_response})
# ------------------------- # -------------------------
# 2) EVALUATION STEP # (2) EVALUATION STEP
# ------------------------- # -------------------------
transcript_str = "" transcript_str = ""
for m in messages: for m in messages:
@ -67,19 +88,24 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow
content = m.get("content", "") content = m.get("content", "")
transcript_str += f"{role.upper()}: {content}\n" transcript_str += f"{role.upper()}: {content}\n"
# We use passCriteria as the evaluation “criteria.”
evaluation_prompt = [ evaluation_prompt = [
{ {
"role": "system", "role": "system",
"content": ( "content": (
f"You are a neutral evaluator. Evaluate based on these criteria:\n{scenario.criteria}\n\nReturn ONLY a JSON object with format: " f"You are a neutral evaluator. Evaluate based on these criteria:\n"
'{"verdict": "pass", "details": <the reason for pass in 2 sentences>} if the support bot answered correctly, or {"verdict": "fail", "details": <the reason for fail in 2 sentences>} if not.' f"{simulation.passCriteria}\n\n"
"Return ONLY a JSON object in this format:\n"
'{"verdict": "pass", "details": <reason>} or '
'{"verdict": "fail", "details": <reason>}.'
) )
}, },
{ {
"role": "user", "role": "user",
"content": ( "content": (
f"Here is the conversation transcript:\n\n{transcript_str}\n\n" f"Here is the conversation transcript:\n\n{transcript_str}\n\n"
"Did the support bot answer correctly or not? Return only 'pass' or 'fail' for verdict, and a brief 2 sentence explanation for details." "Did the support bot answer correctly or not? "
"Return only 'pass' or 'fail' for verdict, and a brief explanation for details."
) )
} }
] ]
@ -91,51 +117,82 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow
model=MODEL_NAME, model=MODEL_NAME,
messages=evaluation_prompt, messages=evaluation_prompt,
temperature=0.0, temperature=0.0,
# If your LLM supports a structured response format, you can specify it.
# Otherwise, remove or adapt 'response_format':
response_format={"type": "json_object"} response_format={"type": "json_object"}
) )
) )
if not eval_response.choices: if not eval_response.choices:
raise Exception("No evaluation response received from model") raise Exception("No evaluation response received from model")
else:
response_json = json.loads(eval_response.choices[0].message.content) response_json_str = eval_response.choices[0].message.content
# Attempt to parse the JSON
response_json = json.loads(response_json_str)
evaluation_result = response_json.get("verdict") evaluation_result = response_json.get("verdict")
details = response_json.get("details") details = response_json.get("details")
if evaluation_result is None: if evaluation_result is None:
raise Exception("No verdict field found in evaluation response") raise Exception("No 'verdict' field found in evaluation response")
return (evaluation_result, details, transcript_str) return (evaluation_result, details, transcript_str)
async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5): async def simulate_simulations(
simulations: List[TestSimulation],
run_id: str,
workflow_id: str,
api_key: str,
max_iterations: int = 5
) -> AggregateResults:
""" """
Simulates a list of scenarios asynchronously and aggregates the results. Simulates a list of TestSimulations asynchronously and aggregates the results.
""" """
project_id = scenarios[0].projectId if not simulations:
# Return an empty result if there's nothing to simulate
return AggregateResults(total=0, pass_=0, fail=0)
# We assume all simulations belong to the same project
project_id = simulations[0].projectId
# Create a Rowboat client instance
client = Client( client = Client(
host=ROWBOAT_API_HOST, host=ROWBOAT_API_HOST,
project_id=project_id, project_id=project_id,
api_key=api_key api_key=api_key
) )
results = []
for scenario in scenarios: # Store results here
# Await the asynchronous simulate_scenario results: List[TestResult] = []
result, details, transcript = await simulate_scenario(scenario, client, workflow_id, max_iterations)
simulation_result = SimulationResult( for simulation in simulations:
projectId=project_id, # Run each simulation
runId=runId, verdict, details, transcript = await simulate_simulation(
scenarioId=scenario.id, simulation=simulation,
result=result, rowboat_client=client,
details=details, workflow_id=workflow_id,
transcript=transcript max_iterations=max_iterations
) )
results.append(simulation_result)
write_simulation_result(simulation_result)
aggregate_result = SimulationAggregateResult(**{ # Create a new TestResult
"total": len(scenarios), test_result = TestResult(
"pass": sum(1 for result in results if result.result == "pass"), projectId=project_id,
"fail": sum(1 for result in results if result.result == "fail") runId=run_id,
}) simulationId=simulation.id,
return aggregate_result result=verdict,
details=details
)
results.append(test_result)
# Persist the test result
write_test_result(test_result)
# Aggregate pass/fail
total_count = len(results)
pass_count = sum(1 for r in results if r.result == "pass")
fail_count = sum(1 for r in results if r.result == "fail")
return AggregateResults(
total=total_count,
passCount=pass_count,
failCount=fail_count
)