diff --git a/apps/rowboat/app/actions/testing_actions.ts b/apps/rowboat/app/actions/testing_actions.ts index 45b525a1..88f01142 100644 --- a/apps/rowboat/app/actions/testing_actions.ts +++ b/apps/rowboat/app/actions/testing_actions.ts @@ -436,8 +436,8 @@ export async function createRun( startedAt: new Date().toISOString(), aggregateResults: { total: 0, - pass: 0, - fail: 0, + passCount: 0, + failCount: 0, }, }; const insertResult = await testRunsCollection.insertOne(doc); @@ -455,8 +455,8 @@ export async function updateRun( completedAt?: string; aggregateResults?: { total: number; - pass: number; - fail: number; + passCount: number; + failCount: number; }; } ): Promise { diff --git a/apps/rowboat/app/lib/types/testing_types.ts b/apps/rowboat/app/lib/types/testing_types.ts index 1d3036bd..4a1440bb 100644 --- a/apps/rowboat/app/lib/types/testing_types.ts +++ b/apps/rowboat/app/lib/types/testing_types.ts @@ -38,8 +38,8 @@ export const TestRun = z.object({ completedAt: z.string().optional(), aggregateResults: z.object({ total: z.number(), - pass: z.number(), - fail: z.number(), + passCount: z.number(), + failCount: z.number(), }).optional(), }); diff --git a/apps/rowboat/app/projects/[projectId]/test/[[...slug]]/runs_app.tsx b/apps/rowboat/app/projects/[projectId]/test/[[...slug]]/runs_app.tsx index 7110cbf9..104975ad 100644 --- a/apps/rowboat/app/projects/[projectId]/test/[[...slug]]/runs_app.tsx +++ b/apps/rowboat/app/projects/[projectId]/test/[[...slug]]/runs_app.tsx @@ -227,11 +227,11 @@ function ViewRun({
Passed
-
{run.aggregateResults?.pass || 0}
+
{run.aggregateResults?.passCount || 0}
Failed
-
{run.aggregateResults?.fail || 0}
+
{run.aggregateResults?.failCount || 0}
@@ -389,10 +389,10 @@ function RunList({ Total: {run.aggregateResults.total}
- Passed: {run.aggregateResults.pass} + Passed: {run.aggregateResults.passCount}
- Failed: {run.aggregateResults.fail} + Failed: {run.aggregateResults.failCount}
diff --git a/apps/simulation_runner/db.py b/apps/simulation_runner/db.py index 2dfd8f5e..c8a79937 100644 --- a/apps/simulation_runner/db.py +++ b/apps/simulation_runner/db.py @@ -2,15 +2,25 @@ from pymongo import MongoClient from bson import ObjectId import os from datetime import datetime, timedelta, timezone -from scenario_types import SimulationRun, Scenario, SimulationResult, SimulationAggregateResult +from typing import Optional +from scenario_types import ( + TestRun, + TestScenario, + TestProfile, + TestSimulation, + TestResult, + AggregateResults +) MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip() -SCENARIOS_COLLECTION_NAME = "scenarios" -API_KEYS_COLLECTION = "api_keys" -SIMULATIONS_COLLECTION_NAME = "simulation_runs" -SIMULATION_RESULT_COLLECTION_NAME = "simulation_result" -SIMULATION_AGGREGATE_RESULT_COLLECTION_NAME = "simulation_aggregate_result" +# New collection names +TEST_SCENARIOS_COLLECTION = "test_scenarios" +TEST_PROFILES_COLLECTION = "test_profiles" +TEST_SIMULATIONS_COLLECTION = "test_simulations" +TEST_RUNS_COLLECTION = "test_runs" +TEST_RESULTS_COLLECTION = "test_results" +API_KEYS_COLLECTION = "api_keys" # If still needed def get_db(): client = MongoClient(MONGO_URI) @@ -21,6 +31,9 @@ def get_collection(collection_name: str): return db[collection_name] def get_api_key(project_id: str): + """ + If you still use an API key pattern, adapt as needed. + """ collection = get_collection(API_KEYS_COLLECTION) doc = collection.find_one({"projectId": project_id}) if doc: @@ -28,71 +41,68 @@ def get_api_key(project_id: str): else: return None -def get_pending_simulation_run(): - collection = get_collection(SIMULATIONS_COLLECTION_NAME) +# +# TestRun helpers +# + +def get_pending_run() -> Optional[TestRun]: + """ + Finds a run with 'pending' status, marks it 'running', and returns it. 
+ """ + collection = get_collection(TEST_RUNS_COLLECTION) doc = collection.find_one_and_update( {"status": "pending"}, {"$set": {"status": "running"}}, return_document=True ) if doc: - return SimulationRun( + return TestRun( id=str(doc["_id"]), projectId=doc["projectId"], - status="running", - scenarioIds=doc["scenarioIds"], + name=doc["name"], + simulationIds=doc["simulationIds"], workflowId=doc["workflowId"], + status="running", startedAt=doc["startedAt"], - completedAt=doc.get("completedAt") + completedAt=doc.get("completedAt"), + aggregateResults=doc.get("aggregateResults"), + lastHeartbeat=doc.get("lastHeartbeat") ) return None -def set_simulation_run_to_completed(simulation_run: SimulationRun, aggregate_result: SimulationAggregateResult): - collection = get_collection(SIMULATIONS_COLLECTION_NAME) - collection.update_one({"_id": ObjectId(simulation_run.id)}, {"$set": {"status": "completed", "aggregateResults": aggregate_result.model_dump(by_alias=True)}}) - -def get_scenarios_for_run(simulation_run: SimulationRun): - if simulation_run is None: - return [] - collection = get_collection(SCENARIOS_COLLECTION_NAME) - scenarios = [] - for doc in collection.find(): - if doc["_id"] in [ObjectId(sid) for sid in simulation_run.scenarioIds]: - scenarios.append(Scenario( - id=str(doc["_id"]), - projectId=doc["projectId"], - name=doc["name"], - description=doc["description"], - criteria=doc["criteria"], - context=doc["context"], - createdAt=doc["createdAt"], - lastUpdatedAt=doc["lastUpdatedAt"] - )) - return scenarios - -def write_simulation_result(result: SimulationResult): - collection = get_collection(SIMULATION_RESULT_COLLECTION_NAME) - collection.insert_one(result.model_dump()) - -def update_simulation_run_heartbeat(simulation_run_id: str): +def set_run_to_completed(test_run: TestRun, aggregate: AggregateResults): """ - Updates the 'last_heartbeat' timestamp for a SimulationRun. + Marks a test run 'completed', stores the aggregate results, and stamps completedAt as an ISO-8601 UTC string (the TestRun schema in testing_types.ts types completedAt as a string). 
""" - - collection = get_collection(SIMULATIONS_COLLECTION_NAME) + collection = get_collection(TEST_RUNS_COLLECTION) collection.update_one( - {"_id": ObjectId(simulation_run_id)}, + {"_id": ObjectId(test_run.id)}, + { + "$set": { + "status": "completed", + "aggregateResults": aggregate.model_dump(by_alias=True), + "completedAt": datetime.now(timezone.utc).isoformat(timespec="milliseconds").replace("+00:00", "Z") + } + } + ) + +def update_run_heartbeat(run_id: str): + """ + Updates the 'lastHeartbeat' timestamp for a TestRun. + """ + collection = get_collection(TEST_RUNS_COLLECTION) + collection.update_one( + {"_id": ObjectId(run_id)}, {"$set": {"lastHeartbeat": datetime.now(timezone.utc)}} ) -def mark_stale_jobs_as_failed(): +def mark_stale_jobs_as_failed(threshold_minutes: int = 20) -> int: """ - Finds any job in 'running' status whose last_heartbeat is older than 5 minutes, - and sets it to 'failed'. + Finds any run in 'running' status whose lastHeartbeat is older than + `threshold_minutes`, and sets it to 'failed'. Returns the count. """ - - collection = get_collection(SIMULATIONS_COLLECTION_NAME) - stale_threshold = datetime.now(timezone.utc) - timedelta(minutes=20) + collection = get_collection(TEST_RUNS_COLLECTION) + stale_threshold = datetime.now(timezone.utc) - timedelta(minutes=threshold_minutes) result = collection.update_many( { "status": "running", @@ -102,4 +112,46 @@ def mark_stale_jobs_as_failed(): "$set": {"status": "failed"} } ) - return result.modified_count # Number of jobs marked failed \ No newline at end of file + return result.modified_count + +# +# TestSimulation helpers +# + +def get_simulations_for_run(test_run: TestRun) -> list[TestSimulation]: + """ + Returns all simulations specified by a particular run. 
+ """ + if test_run is None: + return [] + collection = get_collection(TEST_SIMULATIONS_COLLECTION) + simulation_docs = collection.find({ + "_id": {"$in": [ObjectId(sim_id) for sim_id in test_run.simulationIds]} + }) + + simulations = [] + for doc in simulation_docs: + simulations.append( + TestSimulation( + id=str(doc["_id"]), + projectId=doc["projectId"], + name=doc["name"], + scenarioId=doc["scenarioId"], + profileId=doc["profileId"], + passCriteria=doc["passCriteria"], + createdAt=doc["createdAt"], + lastUpdatedAt=doc["lastUpdatedAt"] + ) + ) + return simulations + +# +# TestResult helpers +# + +def write_test_result(result: TestResult): + """ + Writes a test result into the `test_results` collection. + """ + collection = get_collection(TEST_RESULTS_COLLECTION) + collection.insert_one(result.model_dump()) diff --git a/apps/simulation_runner/scenario_types.py b/apps/simulation_runner/scenario_types.py index fa8dc14f..981342eb 100644 --- a/apps/simulation_runner/scenario_types.py +++ b/apps/simulation_runner/scenario_types.py @@ -2,39 +2,62 @@ from datetime import datetime from typing import Optional, List, Literal from pydantic import BaseModel, Field -run_status = Literal["pending", "running", "completed", "cancelled", "failed"] +# Define run statuses to include the new "error" status +RunStatus = Literal["pending", "running", "completed", "cancelled", "failed", "error"] -class Scenario(BaseModel): +class TestScenario(BaseModel): + # `_id` in Mongo will be stored as ObjectId; we return it as a string id: str projectId: str - name: str = "" - description: str = "" - criteria: str = "" - context: str = "" + name: str + description: str createdAt: datetime lastUpdatedAt: datetime -class SimulationRun(BaseModel): +class TestProfile(BaseModel): id: str projectId: str - status: Literal["pending", "running", "completed", "cancelled", "failed"] - scenarioIds: List[str] + name: str + context: str + createdAt: datetime + lastUpdatedAt: datetime + mockTools: bool + 
mockPrompt: Optional[str] = None + +class TestSimulation(BaseModel): + id: str + projectId: str + name: str + scenarioId: str + profileId: str + passCriteria: str + createdAt: datetime + lastUpdatedAt: datetime + +class AggregateResults(BaseModel): + total: int + passCount: int + failCount: int + +class TestRun(BaseModel): + id: str + projectId: str + name: str + simulationIds: List[str] workflowId: str + status: RunStatus startedAt: datetime - lastHeartbeat: Optional[datetime] = None completedAt: Optional[datetime] = None - aggregateResults: Optional[dict] = None + # By default, store aggregate results as a dict or the typed AggregateResults + aggregateResults: Optional[AggregateResults] = None + # The new schema does not mention lastHeartbeat, + # but you can keep it if you still want to track stale runs + lastHeartbeat: Optional[datetime] = None -class SimulationResult(BaseModel): +class TestResult(BaseModel): projectId: str runId: str - scenarioId: str + simulationId: str result: Literal["pass", "fail"] details: str - transcript: str - -class SimulationAggregateResult(BaseModel): - total: int - pass_: int = Field(..., alias='pass') - fail: int \ No newline at end of file diff --git a/apps/simulation_runner/service.py b/apps/simulation_runner/service.py index 7ec8bbe5..385cfbaa 100644 --- a/apps/simulation_runner/service.py +++ b/apps/simulation_runner/service.py @@ -1,84 +1,104 @@ import asyncio import logging -from typing import List -from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key, mark_stale_jobs_as_failed, update_simulation_run_heartbeat -from scenario_types import SimulationRun, Scenario -from simulation import simulate_scenarios +from typing import List, Optional + +# Updated imports from your new db module and scenario_types +from db import ( + get_pending_run, + get_simulations_for_run, + set_run_to_completed, + get_api_key, + mark_stale_jobs_as_failed, + update_run_heartbeat +) +from 
scenario_types import TestRun, TestSimulation +# If you have a new simulation function, import it here. +# Otherwise, adapt the name as needed: +from simulation import simulate_simulations # or simulate_scenarios, if unchanged + logging.basicConfig(level=logging.INFO) class JobService: def __init__(self): self.poll_interval = 5 # seconds + # Control concurrency of run processing self.semaphore = asyncio.Semaphore(5) - async def poll_and_process_jobs(self, max_iterations: int = None): + async def poll_and_process_jobs(self, max_iterations: Optional[int] = None): """ - Periodically checks for new jobs in MongoDB and processes them. + Periodically checks for new runs in MongoDB and processes them. """ - - # Start the stale-job check in the background - asyncio.create_task(self.fail_stale_jobs_loop()) + # Start the stale-run check in the background + asyncio.create_task(self.fail_stale_runs_loop()) iterations = 0 while True: - job = get_pending_simulation_run() - if job: - logging.info(f"Found new job: {job}. Processing...") - asyncio.create_task(self.process_job(job)) + run = get_pending_run() # <--- changed to match new DB function + if run: + logging.info(f"Found new run: {run}. Processing...") + asyncio.create_task(self.process_run(run)) iterations += 1 if max_iterations is not None and iterations >= max_iterations: break + # Sleep for the polling interval await asyncio.sleep(self.poll_interval) - async def process_job(self, job: SimulationRun): + async def process_run(self, run: TestRun): """ - Calls the simulation function and updates job status upon completion. + Calls the simulation function and updates run status upon completion. 
""" - async with self.semaphore: # Start heartbeat in background stop_heartbeat_event = asyncio.Event() - heartbeat_task = asyncio.create_task(self.heartbeat_loop(job.id, stop_heartbeat_event)) + heartbeat_task = asyncio.create_task(self.heartbeat_loop(run.id, stop_heartbeat_event)) try: - scenarios = get_scenarios_for_run(job) - if not scenarios or len(scenarios) == 0: - logging.info(f"No scenarios found for job {job.id}") + # Fetch the simulations associated with this run + simulations = get_simulations_for_run(run) + if not simulations: + logging.info(f"No simulations found for run {run.id}") return - api_key = get_api_key(job.projectId) - result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key) + # Fetch API key if needed + api_key = get_api_key(run.projectId) + # Perform your simulation logic + # adapt this call to your actual simulation function’s signature + aggregate_result = await simulate_simulations( + simulations=simulations, + run_id=run.id, + workflow_id=run.workflowId, + api_key=api_key + ) - set_simulation_run_to_completed(job, result) - logging.info(f"Job {job.id} completed.") + # Mark run as completed with the aggregated result + set_run_to_completed(run, aggregate_result) + logging.info(f"Run {run.id} completed.") except Exception as exc: - logging.error(f"Job {job.id} failed: {exc}") + logging.error(f"Run {run.id} failed: {exc}") finally: stop_heartbeat_event.set() await heartbeat_task - async def fail_stale_jobs_loop(self): + async def fail_stale_runs_loop(self): """ - Periodically checks for stale jobs that haven't received a heartbeat in over 5 minutes, - and marks them as 'failed'. + Periodically checks for stale runs (no heartbeat) and marks them as 'failed'. 
""" while True: count = mark_stale_jobs_as_failed() if count > 0: - logging.warning(f"Marked {count} stale jobs as failed.") + logging.warning(f"Marked {count} stale runs as failed.") await asyncio.sleep(60) # Check every 60 seconds - async def heartbeat_loop(self, job_id: str, stop_event: asyncio.Event): + async def heartbeat_loop(self, run_id: str, stop_event: asyncio.Event): """ - Periodically updates 'last_heartbeat' for the given job until 'stop_event' is set. + Periodically updates 'lastHeartbeat' for the given run until 'stop_event' is set. """ - try: while not stop_event.is_set(): - update_simulation_run_heartbeat(job_id) + update_run_heartbeat(run_id) await asyncio.sleep(10) # Heartbeat interval in seconds except asyncio.CancelledError: pass diff --git a/apps/simulation_runner/simulation.py b/apps/simulation_runner/simulation.py index 35897aaf..67e6dfd2 100644 --- a/apps/simulation_runner/simulation.py +++ b/apps/simulation_runner/simulation.py @@ -1,46 +1,67 @@ -from rowboat import Client, StatefulChat +import asyncio +import logging from typing import List import json import os -import asyncio from openai import OpenAI -from scenario_types import Scenario, SimulationResult, SimulationAggregateResult -from db import write_simulation_result + +# Updated imports from your new schema/types +from scenario_types import TestSimulation, TestResult, AggregateResults + +# If your DB functions changed names, adapt here: +from db import write_test_result # replaced write_simulation_result + +from rowboat import Client, StatefulChat openai_client = OpenAI() MODEL_NAME = "gpt-4o" ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip() -async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> tuple[str, str, str]: +async def simulate_simulation( + simulation: TestSimulation, + rowboat_client: Client, + workflow_id: str, + max_iterations: int = 5 +) -> tuple[str, str, str]: """ - 
Runs a mock simulation for a given scenario asynchronously. + Runs a mock simulation for a given TestSimulation asynchronously. After simulating several turns of conversation, it evaluates the conversation. Returns a tuple of (evaluation_result, details, transcript_str). """ + loop = asyncio.get_running_loop() + # Optionally embed passCriteria in the system prompt, if it’s relevant to context: + pass_criteria = simulation.passCriteria or "" + # Or place it separately below if you prefer. + + # Prepare a Rowboat chat support_chat = StatefulChat( rowboat_client, - system_prompt=f"{f'Context: {scenario.context}' if scenario.context else ''}", + system_prompt=f"Context: {pass_criteria}" if pass_criteria else "", workflow_id=workflow_id ) + # You might want to describe the simulation or scenario more thoroughly. + # Here, we just embed simulation.name in the system message: messages = [ { "role": "system", - "content": f"Simulate the user based on the scenario: \n {scenario.description}" + "content": ( + f"Simulate the user based on this simulation:\n{simulation.name}" + ) } ] # ------------------------- - # 1) MAIN SIMULATION LOOP + # (1) MAIN SIMULATION LOOP # ------------------------- - for i in range(max_iterations): + for _ in range(max_iterations): openai_input = messages - # Run OpenAI API call in a separate thread + # Run OpenAI API call in a separate thread (non-blocking) simulated_user_response = await loop.run_in_executor( - None, # Use default thread pool + None, # default ThreadPool lambda: openai_client.chat.completions.create( model=MODEL_NAME, messages=openai_input, @@ -48,9 +69,9 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow ) ) - simulated_content = simulated_user_response.choices[0].message.content + simulated_content = simulated_user_response.choices[0].message.content.strip() - # Run support_chat.run in a thread if it's synchronous + # Run Rowboat chat in a thread if it's synchronous rowboat_response = await 
loop.run_in_executor( None, lambda: support_chat.run(simulated_content) @@ -59,7 +80,7 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow messages.append({"role": "assistant", "content": rowboat_response}) # ------------------------- - # 2) EVALUATION STEP + # (2) EVALUATION STEP # ------------------------- transcript_str = "" for m in messages: @@ -67,19 +88,24 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow content = m.get("content", "") transcript_str += f"{role.upper()}: {content}\n" + # We use passCriteria as the evaluation “criteria.” evaluation_prompt = [ { "role": "system", "content": ( - f"You are a neutral evaluator. Evaluate based on these criteria:\n{scenario.criteria}\n\nReturn ONLY a JSON object with format: " - '{"verdict": "pass", "details": } if the support bot answered correctly, or {"verdict": "fail", "details": } if not.' + f"You are a neutral evaluator. Evaluate based on these criteria:\n" + f"{simulation.passCriteria}\n\n" + "Return ONLY a JSON object in this format:\n" + '{"verdict": "pass", "details": } or ' + '{"verdict": "fail", "details": }.' ) }, { "role": "user", "content": ( f"Here is the conversation transcript:\n\n{transcript_str}\n\n" - "Did the support bot answer correctly or not? Return only 'pass' or 'fail' for verdict, and a brief 2 sentence explanation for details." + "Did the support bot answer correctly or not? " + "Return only 'pass' or 'fail' for verdict, and a brief explanation for details." ) } ] @@ -91,51 +117,82 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow model=MODEL_NAME, messages=evaluation_prompt, temperature=0.0, + # If your LLM supports a structured response format, you can specify it. 
+ # Otherwise, remove or adapt 'response_format': response_format={"type": "json_object"} ) ) if not eval_response.choices: raise Exception("No evaluation response received from model") - else: - response_json = json.loads(eval_response.choices[0].message.content) - evaluation_result = response_json.get("verdict") - details = response_json.get("details") - if evaluation_result is None: - raise Exception("No verdict field found in evaluation response") + + response_json_str = eval_response.choices[0].message.content + # Attempt to parse the JSON + response_json = json.loads(response_json_str) + evaluation_result = response_json.get("verdict") + details = response_json.get("details") + + if evaluation_result is None: + raise Exception("No 'verdict' field found in evaluation response") return (evaluation_result, details, transcript_str) -async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5): +async def simulate_simulations( + simulations: List[TestSimulation], + run_id: str, + workflow_id: str, + api_key: str, + max_iterations: int = 5 +) -> AggregateResults: """ - Simulates a list of scenarios asynchronously and aggregates the results. + Simulates a list of TestSimulations asynchronously and aggregates the results. 
""" - project_id = scenarios[0].projectId + if not simulations: + # Nothing to simulate: return zeroed counts using the AggregateResults field names (total/passCount/failCount) + return AggregateResults(total=0, passCount=0, failCount=0) + + # We assume all simulations belong to the same project + project_id = simulations[0].projectId + + # Create a Rowboat client instance client = Client( host=ROWBOAT_API_HOST, project_id=project_id, api_key=api_key ) - results = [] - for scenario in scenarios: - # Await the asynchronous simulate_scenario - result, details, transcript = await simulate_scenario(scenario, client, workflow_id, max_iterations) + # Store results here + results: List[TestResult] = [] - simulation_result = SimulationResult( - projectId=project_id, - runId=runId, - scenarioId=scenario.id, - result=result, - details=details, - transcript=transcript + for simulation in simulations: + # Run each simulation + verdict, details, transcript = await simulate_simulation( + simulation=simulation, + rowboat_client=client, + workflow_id=workflow_id, + max_iterations=max_iterations ) - results.append(simulation_result) - write_simulation_result(simulation_result) - aggregate_result = SimulationAggregateResult(**{ - "total": len(scenarios), - "pass": sum(1 for result in results if result.result == "pass"), - "fail": sum(1 for result in results if result.result == "fail") - }) - return aggregate_result \ No newline at end of file + # Create a new TestResult + test_result = TestResult( + projectId=project_id, + runId=run_id, + simulationId=simulation.id, + result=verdict, + details=details + ) + results.append(test_result) + + # Persist the test result + write_test_result(test_result) + + # Aggregate pass/fail + total_count = len(results) + pass_count = sum(1 for r in results if r.result == "pass") + fail_count = sum(1 for r in results if r.result == "fail") + + return AggregateResults( + total=total_count, + passCount=pass_count, + failCount=fail_count + )