updated simulation runner to the new collections

This commit is contained in:
arkml 2025-03-01 12:56:08 +05:30
parent 33f30670f6
commit 2b1ef82c20
7 changed files with 311 additions and 159 deletions

View file

@ -1,46 +1,67 @@
from rowboat import Client, StatefulChat
import asyncio
import logging
from typing import List
import json
import os
import asyncio
from openai import OpenAI
from scenario_types import Scenario, SimulationResult, SimulationAggregateResult
from db import write_simulation_result
# Updated imports from your new schema/types
from scenario_types import TestSimulation, TestResult, AggregateResults
# If your DB functions changed names, adapt here:
from db import write_test_result # replaced write_simulation_result
from rowboat import Client, StatefulChat
openai_client = OpenAI()
MODEL_NAME = "gpt-4o"
ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> tuple[str, str, str]:
async def simulate_simulation(
simulation: TestSimulation,
rowboat_client: Client,
workflow_id: str,
max_iterations: int = 5
) -> tuple[str, str, str]:
"""
Runs a mock simulation for a given scenario asynchronously.
Runs a mock simulation for a given TestSimulation asynchronously.
After simulating several turns of conversation, it evaluates the conversation.
Returns a tuple of (evaluation_result, details, transcript_str).
"""
loop = asyncio.get_running_loop()
# Optionally embed passCriteria in the system prompt, if its relevant to context:
pass_criteria = simulation.passCriteria or ""
# Or place it separately below if you prefer.
# Prepare a Rowboat chat
support_chat = StatefulChat(
rowboat_client,
system_prompt=f"{f'Context: {scenario.context}' if scenario.context else ''}",
system_prompt=f"Context: {pass_criteria}" if pass_criteria else "",
workflow_id=workflow_id
)
# You might want to describe the simulation or scenario more thoroughly.
# Here, we just embed simulation.name in the system message:
messages = [
{
"role": "system",
"content": f"Simulate the user based on the scenario: \n {scenario.description}"
"content": (
f"Simulate the user based on this simulation:\n{simulation.name}"
)
}
]
# -------------------------
# 1) MAIN SIMULATION LOOP
# (1) MAIN SIMULATION LOOP
# -------------------------
for i in range(max_iterations):
for _ in range(max_iterations):
openai_input = messages
# Run OpenAI API call in a separate thread
# Run OpenAI API call in a separate thread (non-blocking)
simulated_user_response = await loop.run_in_executor(
None, # Use default thread pool
None, # default ThreadPool
lambda: openai_client.chat.completions.create(
model=MODEL_NAME,
messages=openai_input,
@ -48,9 +69,9 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow
)
)
simulated_content = simulated_user_response.choices[0].message.content
simulated_content = simulated_user_response.choices[0].message.content.strip()
# Run support_chat.run in a thread if it's synchronous
# Run Rowboat chat in a thread if it's synchronous
rowboat_response = await loop.run_in_executor(
None,
lambda: support_chat.run(simulated_content)
@ -59,7 +80,7 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow
messages.append({"role": "assistant", "content": rowboat_response})
# -------------------------
# 2) EVALUATION STEP
# (2) EVALUATION STEP
# -------------------------
transcript_str = ""
for m in messages:
@ -67,19 +88,24 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow
content = m.get("content", "")
transcript_str += f"{role.upper()}: {content}\n"
# We use passCriteria as the evaluation “criteria.”
evaluation_prompt = [
{
"role": "system",
"content": (
f"You are a neutral evaluator. Evaluate based on these criteria:\n{scenario.criteria}\n\nReturn ONLY a JSON object with format: "
'{"verdict": "pass", "details": <the reason for pass in 2 sentences>} if the support bot answered correctly, or {"verdict": "fail", "details": <the reason for fail in 2 sentences>} if not.'
f"You are a neutral evaluator. Evaluate based on these criteria:\n"
f"{simulation.passCriteria}\n\n"
"Return ONLY a JSON object in this format:\n"
'{"verdict": "pass", "details": <reason>} or '
'{"verdict": "fail", "details": <reason>}.'
)
},
{
"role": "user",
"content": (
f"Here is the conversation transcript:\n\n{transcript_str}\n\n"
"Did the support bot answer correctly or not? Return only 'pass' or 'fail' for verdict, and a brief 2 sentence explanation for details."
"Did the support bot answer correctly or not? "
"Return only 'pass' or 'fail' for verdict, and a brief explanation for details."
)
}
]
@ -91,51 +117,82 @@ async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow
model=MODEL_NAME,
messages=evaluation_prompt,
temperature=0.0,
# If your LLM supports a structured response format, you can specify it.
# Otherwise, remove or adapt 'response_format':
response_format={"type": "json_object"}
)
)
if not eval_response.choices:
raise Exception("No evaluation response received from model")
else:
response_json = json.loads(eval_response.choices[0].message.content)
evaluation_result = response_json.get("verdict")
details = response_json.get("details")
if evaluation_result is None:
raise Exception("No verdict field found in evaluation response")
response_json_str = eval_response.choices[0].message.content
# Attempt to parse the JSON
response_json = json.loads(response_json_str)
evaluation_result = response_json.get("verdict")
details = response_json.get("details")
if evaluation_result is None:
raise Exception("No 'verdict' field found in evaluation response")
return (evaluation_result, details, transcript_str)
async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5):
async def simulate_simulations(
simulations: List[TestSimulation],
run_id: str,
workflow_id: str,
api_key: str,
max_iterations: int = 5
) -> AggregateResults:
"""
Simulates a list of scenarios asynchronously and aggregates the results.
Simulates a list of TestSimulations asynchronously and aggregates the results.
"""
project_id = scenarios[0].projectId
if not simulations:
# Return an empty result if there's nothing to simulate
return AggregateResults(total=0, pass_=0, fail=0)
# We assume all simulations belong to the same project
project_id = simulations[0].projectId
# Create a Rowboat client instance
client = Client(
host=ROWBOAT_API_HOST,
project_id=project_id,
api_key=api_key
)
results = []
for scenario in scenarios:
# Await the asynchronous simulate_scenario
result, details, transcript = await simulate_scenario(scenario, client, workflow_id, max_iterations)
# Store results here
results: List[TestResult] = []
simulation_result = SimulationResult(
projectId=project_id,
runId=runId,
scenarioId=scenario.id,
result=result,
details=details,
transcript=transcript
for simulation in simulations:
# Run each simulation
verdict, details, transcript = await simulate_simulation(
simulation=simulation,
rowboat_client=client,
workflow_id=workflow_id,
max_iterations=max_iterations
)
results.append(simulation_result)
write_simulation_result(simulation_result)
aggregate_result = SimulationAggregateResult(**{
"total": len(scenarios),
"pass": sum(1 for result in results if result.result == "pass"),
"fail": sum(1 for result in results if result.result == "fail")
})
return aggregate_result
# Create a new TestResult
test_result = TestResult(
projectId=project_id,
runId=run_id,
simulationId=simulation.id,
result=verdict,
details=details
)
results.append(test_result)
# Persist the test result
write_test_result(test_result)
# Aggregate pass/fail
total_count = len(results)
pass_count = sum(1 for r in results if r.result == "pass")
fail_count = sum(1 for r in results if r.result == "fail")
return AggregateResults(
total=total_count,
passCount=pass_count,
failCount=fail_count
)