diff --git a/apps/simulation_runner/db.py b/apps/simulation_runner/db.py
index c8a79937..06713720 100644
--- a/apps/simulation_runner/db.py
+++ b/apps/simulation_runner/db.py
@@ -6,7 +6,6 @@ from typing import Optional
 from scenario_types import (
     TestRun,
     TestScenario,
-    TestProfile,
     TestSimulation,
     TestResult,
     AggregateResults
@@ -14,13 +13,11 @@ from scenario_types import (
 
 MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip()
 
-# New collection names
 TEST_SCENARIOS_COLLECTION = "test_scenarios"
-TEST_PROFILES_COLLECTION = "test_profiles"
 TEST_SIMULATIONS_COLLECTION = "test_simulations"
 TEST_RUNS_COLLECTION = "test_runs"
 TEST_RESULTS_COLLECTION = "test_results"
-API_KEYS_COLLECTION = "api_keys" # If still needed
+API_KEYS_COLLECTION = "api_keys"
 
 def get_db():
     client = MongoClient(MONGO_URI)
@@ -145,6 +142,23 @@ def get_simulations_for_run(test_run: TestRun) -> list[TestSimulation]:
     )
     return simulations
 
+def get_scenario_by_id(scenario_id: str) -> TestScenario:
+    """
+    Returns a TestScenario by its ID.
+    """
+    collection = get_collection(TEST_SCENARIOS_COLLECTION)
+    doc = collection.find_one({"_id": ObjectId(scenario_id)})
+    if doc:
+        return TestScenario(
+            id=str(doc["_id"]),
+            projectId=doc["projectId"],
+            name=doc["name"],
+            description=doc["description"],
+            createdAt=doc["createdAt"],
+            lastUpdatedAt=doc["lastUpdatedAt"]
+        )
+    return None
+
 #
 # TestResult helpers
 #
diff --git a/apps/simulation_runner/scenario_types.py b/apps/simulation_runner/scenario_types.py
index 981342eb..3e61d894 100644
--- a/apps/simulation_runner/scenario_types.py
+++ b/apps/simulation_runner/scenario_types.py
@@ -14,16 +14,6 @@ class TestScenario(BaseModel):
     createdAt: datetime
     lastUpdatedAt: datetime
 
-class TestProfile(BaseModel):
-    id: str
-    projectId: str
-    name: str
-    context: str
-    createdAt: datetime
-    lastUpdatedAt: datetime
-    mockTools: bool
-    mockPrompt: Optional[str] = None
-
 class TestSimulation(BaseModel):
     id: str
     projectId: str
@@ -48,11 +38,7 @@ class TestRun(BaseModel):
     status: RunStatus
     startedAt: datetime
     completedAt: Optional[datetime] = None
-    # By default, store aggregate results as a dict or the typed AggregateResults
     aggregateResults: Optional[AggregateResults] = None
-
-    # The new schema does not mention lastHeartbeat,
-    # but you can keep it if you still want to track stale runs
     lastHeartbeat: Optional[datetime] = None
 
 class TestResult(BaseModel):
@@ -61,3 +47,4 @@ class TestResult(BaseModel):
     simulationId: str
     result: Literal["pass", "fail"]
     details: str
+    transcript: str
diff --git a/apps/simulation_runner/simulation.py b/apps/simulation_runner/simulation.py
index 67e6dfd2..3ae2b41e 100644
--- a/apps/simulation_runner/simulation.py
+++ b/apps/simulation_runner/simulation.py
@@ -5,12 +5,9 @@ import json
 import os
 from openai import OpenAI
 
-# Updated imports from your new schema/types
-from scenario_types import TestSimulation, TestResult, AggregateResults
-
-# If your DB functions changed names, adapt here:
-from db import write_test_result # replaced write_simulation_result
+from scenario_types import TestSimulation, TestResult, AggregateResults, TestScenario
+from db import write_test_result, get_scenario_by_id
 from rowboat import Client, StatefulChat
 
 openai_client = OpenAI()
 
@@ -18,7 +15,9 @@ MODEL_NAME = "gpt-4o"
 ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
 
 async def simulate_simulation(
-    simulation: TestSimulation,
+    scenario: TestScenario,
+    profile_id: str,
+    pass_criteria: str,
     rowboat_client: Client,
     workflow_id: str,
     max_iterations: int = 5
@@ -30,25 +29,19 @@
     """
 
     loop = asyncio.get_running_loop()
+    pass_criteria = pass_criteria
+
-
-    # Optionally embed passCriteria in the system prompt, if it’s relevant to context:
-    pass_criteria = simulation.passCriteria or ""
-    # Or place it separately below if you prefer.
-
-    # Prepare a Rowboat chat
+    # Todo: add profile_id
     support_chat = StatefulChat(
         rowboat_client,
-        system_prompt=f"Context: {pass_criteria}" if pass_criteria else "",
         workflow_id=workflow_id
     )
 
-    # You might want to describe the simulation or scenario more thoroughly.
-    # Here, we just embed simulation.name in the system message:
     messages = [
        {
            "role": "system",
            "content": (
-                f"Simulate the user based on this simulation:\n{simulation.name}"
+                f"You are a customer talking to a chatbot. Have the following chat with the chatbot. Scenario:\n{scenario.description}. You are provided no other information. If the chatbot asks you for information that is not in context, go ahead and provide one unless stated otherwise in the scenario. Directly have the chat with the chatbot. Start now."
            )
        }
    ]
@@ -70,7 +63,7 @@
        )
 
        simulated_content = simulated_user_response.choices[0].message.content.strip()
-
+        messages.append({"role": "user", "content": simulated_content})
        # Run Rowboat chat in a thread if it's synchronous
        rowboat_response = await loop.run_in_executor(
            None,
@@ -88,13 +81,16 @@
        content = m.get("content", "")
        transcript_str += f"{role.upper()}: {content}\n"
 
-    # We use passCriteria as the evaluation “criteria.”
+    # Store the transcript as a JSON string
+    transcript = json.dumps(messages)
+
+    # We use passCriteria as the evaluation "criteria."
    evaluation_prompt = [
        {
            "role": "system",
            "content": (
                f"You are a neutral evaluator. Evaluate based on these criteria:\n"
-                f"{simulation.passCriteria}\n\n"
+                f"{pass_criteria}\n\n"
                "Return ONLY a JSON object in this format:\n"
                '{"verdict": "pass", "details": } or '
                '{"verdict": "fail", "details": }.'
@@ -117,8 +113,6 @@
            model=MODEL_NAME,
            messages=evaluation_prompt,
            temperature=0.0,
-            # If your LLM supports a structured response format, you can specify it.
-            # Otherwise, remove or adapt 'response_format':
            response_format={"type": "json_object"}
        )
    )
@@ -135,7 +129,7 @@
    if evaluation_result is None:
        raise Exception("No 'verdict' field found in evaluation response")
 
-    return (evaluation_result, details, transcript_str)
+    return (evaluation_result, details, transcript)
 
 async def simulate_simulations(
    simulations: List[TestSimulation],
@@ -151,10 +145,8 @@
        # Return an empty result if there's nothing to simulate
        return AggregateResults(total=0, pass_=0, fail=0)
 
-    # We assume all simulations belong to the same project
    project_id = simulations[0].projectId
 
-    # Create a Rowboat client instance
    client = Client(
        host=ROWBOAT_API_HOST,
        project_id=project_id,
@@ -165,9 +157,10 @@
    results: List[TestResult] = []
 
    for simulation in simulations:
-        # Run each simulation
        verdict, details, transcript = await simulate_simulation(
-            simulation=simulation,
+            scenario=get_scenario_by_id(simulation.scenarioId),
+            profile_id=simulation.profileId,
+            pass_criteria=simulation.passCriteria,
            rowboat_client=client,
            workflow_id=workflow_id,
            max_iterations=max_iterations
@@ -179,7 +172,8 @@
            runId=run_id,
            simulationId=simulation.id,
            result=verdict,
-            details=details
+            details=details,
+            transcript=transcript
        )
        results.append(test_result)
 