mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-26 00:46:23 +02:00
simulation_runner: added failed job cleanup
This commit is contained in:
parent
ee20f7c6e3
commit
7bc3203ed2
4 changed files with 113 additions and 29 deletions
|
|
@ -2,20 +2,22 @@ from rowboat import Client, StatefulChat
|
|||
from typing import List
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from openai import OpenAI
|
||||
from scenario_types import Scenario, SimulationResult, SimulationAggregateResult
|
||||
from db import write_simulation_result
|
||||
|
||||
|
||||
openai_client = OpenAI()
|
||||
MODEL_NAME = "gpt-4o"
|
||||
ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
|
||||
|
||||
def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> str:
|
||||
async def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> tuple[str, str, str]:
|
||||
"""
|
||||
Runs a mock simulation for a given scenario.
|
||||
Runs a mock simulation for a given scenario asynchronously.
|
||||
After simulating several turns of conversation, it evaluates the conversation.
|
||||
Returns a tuple of (evaluation_result, details, transcript_str).
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
support_chat = StatefulChat(
|
||||
rowboat_client,
|
||||
|
|
@ -36,18 +38,24 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
|
|||
for i in range(max_iterations):
|
||||
openai_input = messages
|
||||
|
||||
simulated_user_response = openai_client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=openai_input,
|
||||
temperature=0.0,
|
||||
# Run OpenAI API call in a separate thread
|
||||
simulated_user_response = await loop.run_in_executor(
|
||||
None, # Use default thread pool
|
||||
lambda: openai_client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=openai_input,
|
||||
temperature=0.0,
|
||||
)
|
||||
)
|
||||
|
||||
simulated_content = simulated_user_response.choices[0].message.content
|
||||
|
||||
# Feed the model-generated content back into Rowboat's stateful chat
|
||||
rowboat_response = support_chat.run(simulated_content)
|
||||
# support_chat.run is a blocking call, so run it in the default thread pool to keep the event loop free
|
||||
rowboat_response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: support_chat.run(simulated_content)
|
||||
)
|
||||
|
||||
# Store Rowboat's reply back into `messages` so the conversation continues
|
||||
messages.append({"role": "assistant", "content": rowboat_response})
|
||||
|
||||
# -------------------------
|
||||
|
|
@ -76,11 +84,15 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
|
|||
}
|
||||
]
|
||||
|
||||
eval_response = openai_client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=evaluation_prompt,
|
||||
temperature=0.0,
|
||||
response_format={"type": "json_object"}
|
||||
# Run evaluation in a separate thread
|
||||
eval_response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: openai_client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=evaluation_prompt,
|
||||
temperature=0.0,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
)
|
||||
|
||||
if not eval_response.choices:
|
||||
|
|
@ -92,10 +104,12 @@ def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: s
|
|||
if evaluation_result is None:
|
||||
raise Exception("No verdict field found in evaluation response")
|
||||
|
||||
return(evaluation_result, details, transcript_str)
|
||||
|
||||
return (evaluation_result, details, transcript_str)
|
||||
|
||||
async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5):
|
||||
"""
|
||||
Simulates a list of scenarios asynchronously and aggregates the results.
|
||||
"""
|
||||
project_id = scenarios[0].projectId
|
||||
client = Client(
|
||||
host=ROWBOAT_API_HOST,
|
||||
|
|
@ -103,8 +117,10 @@ async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id:
|
|||
api_key=api_key
|
||||
)
|
||||
results = []
|
||||
|
||||
for scenario in scenarios:
|
||||
result, details, transcript = simulate_scenario(scenario, client, workflow_id, max_iterations)
|
||||
# Await the asynchronous simulate_scenario
|
||||
result, details, transcript = await simulate_scenario(scenario, client, workflow_id, max_iterations)
|
||||
|
||||
simulation_result = SimulationResult(
|
||||
projectId=project_id,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue