added transcript and fixed simulation bugs

This commit is contained in:
arkml 2025-03-11 13:31:26 +05:30 committed by ramnique
parent 9f2854a22c
commit 933a28ac28
3 changed files with 39 additions and 44 deletions

View file

@ -5,12 +5,9 @@ import json
import os
from openai import OpenAI
# Updated imports from your new schema/types
from scenario_types import TestSimulation, TestResult, AggregateResults
# If your DB functions changed names, adapt here:
from db import write_test_result # replaced write_simulation_result
from scenario_types import TestSimulation, TestResult, AggregateResults, TestScenario
from db import write_test_result, get_scenario_by_id
from rowboat import Client, StatefulChat
openai_client = OpenAI()
@ -18,7 +15,9 @@ MODEL_NAME = "gpt-4o"
ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
async def simulate_simulation(
simulation: TestSimulation,
scenario: TestScenario,
profile_id: str,
pass_criteria: str,
rowboat_client: Client,
workflow_id: str,
max_iterations: int = 5
@ -30,25 +29,19 @@ async def simulate_simulation(
"""
loop = asyncio.get_running_loop()
pass_criteria = pass_criteria
# Optionally embed passCriteria in the system prompt, if it's relevant to context:
pass_criteria = simulation.passCriteria or ""
# Or place it separately below if you prefer.
# Prepare a Rowboat chat
# TODO: add profile_id
support_chat = StatefulChat(
rowboat_client,
system_prompt=f"Context: {pass_criteria}" if pass_criteria else "",
workflow_id=workflow_id
)
# You might want to describe the simulation or scenario more thoroughly.
# Here, we just embed simulation.name in the system message:
messages = [
{
"role": "system",
"content": (
f"Simulate the user based on this simulation:\n{simulation.name}"
f"You are a customer talking to a chatbot. Have the following chat with the chatbot. Scenario:\n{scenario.description}. You are provided no other information. If the chatbot asks you for information that is not in context, go ahead and provide one unless stated otherwise in the scenario. Directly have the chat with the chatbot. Start now."
)
}
]
@ -70,7 +63,7 @@ async def simulate_simulation(
)
simulated_content = simulated_user_response.choices[0].message.content.strip()
messages.append({"role": "user", "content": simulated_content})
# Run Rowboat chat in a thread if it's synchronous
rowboat_response = await loop.run_in_executor(
None,
@ -88,13 +81,16 @@ async def simulate_simulation(
content = m.get("content", "")
transcript_str += f"{role.upper()}: {content}\n"
# We use passCriteria as the evaluation “criteria.”
# Store the transcript as a JSON string
transcript = json.dumps(messages)
# We use passCriteria as the evaluation "criteria."
evaluation_prompt = [
{
"role": "system",
"content": (
f"You are a neutral evaluator. Evaluate based on these criteria:\n"
f"{simulation.passCriteria}\n\n"
f"{pass_criteria}\n\n"
"Return ONLY a JSON object in this format:\n"
'{"verdict": "pass", "details": <reason>} or '
'{"verdict": "fail", "details": <reason>}.'
@ -117,8 +113,6 @@ async def simulate_simulation(
model=MODEL_NAME,
messages=evaluation_prompt,
temperature=0.0,
# If your LLM supports a structured response format, you can specify it.
# Otherwise, remove or adapt 'response_format':
response_format={"type": "json_object"}
)
)
@ -135,7 +129,7 @@ async def simulate_simulation(
if evaluation_result is None:
raise Exception("No 'verdict' field found in evaluation response")
return (evaluation_result, details, transcript_str)
return (evaluation_result, details, transcript)
async def simulate_simulations(
simulations: List[TestSimulation],
@ -151,10 +145,8 @@ async def simulate_simulations(
# Return an empty result if there's nothing to simulate
return AggregateResults(total=0, pass_=0, fail=0)
# We assume all simulations belong to the same project
project_id = simulations[0].projectId
# Create a Rowboat client instance
client = Client(
host=ROWBOAT_API_HOST,
project_id=project_id,
@ -165,9 +157,10 @@ async def simulate_simulations(
results: List[TestResult] = []
for simulation in simulations:
# Run each simulation
verdict, details, transcript = await simulate_simulation(
simulation=simulation,
scenario=get_scenario_by_id(simulation.scenarioId),
profile_id=simulation.profileId,
pass_criteria=simulation.passCriteria,
rowboat_client=client,
workflow_id=workflow_id,
max_iterations=max_iterations
@ -179,7 +172,8 @@ async def simulate_simulations(
runId=run_id,
simulationId=simulation.id,
result=verdict,
details=details
details=details,
transcript=transcript
)
results.append(test_result)