mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-25 00:16:29 +02:00
Add simulation runner
This commit is contained in:
parent
1edf2e0e18
commit
893f215f4c
8 changed files with 358 additions and 0 deletions
20
apps/simulation_runner/Dockerfile
Normal file
20
apps/simulation_runner/Dockerfile
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# Use official Python runtime as base image
|
||||
FROM python:3.11-slim
|
||||
|
||||
# Set working directory in container
|
||||
WORKDIR /app
|
||||
|
||||
# Copy requirements file
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy project files
|
||||
COPY . .
|
||||
|
||||
# Disable Python output buffering so service logs show up immediately
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# Command to run the simulation service
|
||||
CMD ["python", "service.py"]
|
||||
0
apps/simulation_runner/__init__.py
Normal file
0
apps/simulation_runner/__init__.py
Normal file
74
apps/simulation_runner/db.py
Normal file
74
apps/simulation_runner/db.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
from pymongo import MongoClient
|
||||
from bson import ObjectId
|
||||
import os
|
||||
from scenario_types import SimulationRun, Scenario, SimulationResult, SimulationAggregateResult
|
||||
|
||||
# MongoDB connection string, taken from MONGODB_URI; surrounding whitespace is
# stripped. The fallback points at a local "rowboat" database.
MONGO_URI = os.getenv("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip()

# Collection names used by the helpers in this module.
SCENARIOS_COLLECTION_NAME = "scenarios"
API_KEYS_COLLECTION = "api_keys"
SIMULATIONS_COLLECTION_NAME = "simulation_runs"
SIMULATION_RESULT_COLLECTION_NAME = "simulation_result"
SIMULATION_AGGREGATE_RESULT_COLLECTION_NAME = "simulation_aggregate_result"
|
||||
|
||||
# Process-wide MongoClient, created on first use. A MongoClient owns a
# connection pool, so building a new one for every query (as the previous
# implementation did) leaks pools and reconnects on each call.
_client = None


def get_db():
    """Return the default database named in ``MONGO_URI``.

    The underlying ``MongoClient`` is created lazily and then reused for the
    lifetime of the process.
    """
    global _client
    if _client is None:
        _client = MongoClient(MONGO_URI)
    return _client.get_default_database()
|
||||
|
||||
def get_collection(collection_name: str):
    """Look up ``collection_name`` in the database returned by ``get_db()``."""
    return get_db()[collection_name]
|
||||
|
||||
def get_api_key(project_id: str):
    """Return the ``key`` field of the API-key document for ``project_id``.

    Returns ``None`` when no document with that ``projectId`` exists.
    """
    record = get_collection(API_KEYS_COLLECTION).find_one({"projectId": project_id})
    return record["key"] if record else None
|
||||
|
||||
def get_pending_simulation_run():
    """Claim one pending simulation run and return it as a ``SimulationRun``.

    A single ``find_one_and_update`` call both selects a document whose status
    is ``"pending"`` and flips that status to ``"running"``. Returns ``None``
    when no pending run is found.
    """
    runs = get_collection(SIMULATIONS_COLLECTION_NAME)
    claimed = runs.find_one_and_update(
        {"status": "pending"},
        {"$set": {"status": "running"}},
        # True requests the post-update document (same value as
        # pymongo.ReturnDocument.AFTER).
        return_document=True,
    )
    if not claimed:
        return None

    return SimulationRun(
        id=str(claimed["_id"]),
        projectId=claimed["projectId"],
        status="running",
        scenarioIds=claimed["scenarioIds"],
        workflowId=claimed["workflowId"],
        startedAt=claimed["startedAt"],
        completedAt=claimed.get("completedAt"),
    )
|
||||
|
||||
def set_simulation_run_to_completed(simulation_run: SimulationRun, aggregate_result: SimulationAggregateResult):
    """Mark ``simulation_run`` as completed and store its aggregate result.

    The aggregate is dumped with ``by_alias=True`` so the stored document uses
    the field aliases declared on the model.
    """
    run_filter = {"_id": ObjectId(simulation_run.id)}
    changes = {
        "$set": {
            "status": "completed",
            "aggregateResults": aggregate_result.model_dump(by_alias=True),
        }
    }
    get_collection(SIMULATIONS_COLLECTION_NAME).update_one(run_filter, changes)
|
||||
|
||||
def get_scenarios_for_run(simulation_run: SimulationRun):
    """Load the scenarios referenced by ``simulation_run.scenarioIds``.

    Args:
        simulation_run: The run whose scenarios should be fetched; may be
            ``None``.

    Returns:
        A list of ``Scenario`` objects, empty when ``simulation_run`` is
        ``None`` or references no scenarios.
    """
    if simulation_run is None or not simulation_run.scenarioIds:
        return []

    # Convert the ids once and let MongoDB filter with $in, instead of pulling
    # the whole collection and rebuilding the id list for every document.
    wanted_ids = [ObjectId(sid) for sid in simulation_run.scenarioIds]
    collection = get_collection(SCENARIOS_COLLECTION_NAME)

    return [
        Scenario(
            id=str(doc["_id"]),
            projectId=doc["projectId"],
            name=doc["name"],
            description=doc["description"],
            criteria=doc["criteria"],
            context=doc["context"],
            createdAt=doc["createdAt"],
            lastUpdatedAt=doc["lastUpdatedAt"],
        )
        for doc in collection.find({"_id": {"$in": wanted_ids}})
    ]
|
||||
|
||||
def write_simulation_result(result: SimulationResult):
    """Insert ``result`` into the simulation-result collection as one document."""
    payload = result.model_dump()
    get_collection(SIMULATION_RESULT_COLLECTION_NAME).insert_one(payload)
|
||||
29
apps/simulation_runner/requirements.txt
Normal file
29
apps/simulation_runner/requirements.txt
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
annotated-types==0.7.0
|
||||
anyio==4.8.0
|
||||
certifi==2025.1.31
|
||||
charset-normalizer==3.4.1
|
||||
distro==1.9.0
|
||||
dnspython==2.7.0
|
||||
h11==0.14.0
|
||||
httpcore==1.0.7
|
||||
httpx==0.28.1
|
||||
idna==3.10
|
||||
iniconfig==2.0.0
|
||||
jiter==0.8.2
|
||||
motor==3.7.0
|
||||
openai==1.63.0
|
||||
packaging==24.2
|
||||
pluggy==1.5.0
|
||||
pydantic==2.10.6
|
||||
pydantic_core==2.27.2
|
||||
pymongo==4.11.1
|
||||
pytest==8.3.4
|
||||
pytest-asyncio==0.25.3
|
||||
python-dateutil==2.9.0.post0
|
||||
requests==2.32.3
|
||||
rowboat==1.0.4
|
||||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
tqdm==4.67.1
|
||||
typing_extensions==4.12.2
|
||||
urllib3==2.3.0
|
||||
38
apps/simulation_runner/scenario_types.py
Normal file
38
apps/simulation_runner/scenario_types.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional, List, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
run_status = Literal["pending", "running", "completed", "cancelled", "failed"]
|
||||
|
||||
class Scenario(BaseModel):
    # One simulation scenario. The text fields default to an empty string;
    # identifiers and timestamps are required.
    id: str
    projectId: str
    name: str = ""
    description: str = ""
    criteria: str = ""
    context: str = ""
    createdAt: datetime
    lastUpdatedAt: datetime
|
||||
|
||||
class SimulationRun(BaseModel):
    # One simulation job: which scenarios to run against which workflow, plus
    # its lifecycle status and timestamps.
    id: str
    projectId: str
    # Uses the module-level ``run_status`` alias so the set of allowed
    # statuses is declared in exactly one place.
    status: run_status
    scenarioIds: List[str]
    workflowId: str
    startedAt: datetime
    completedAt: Optional[datetime] = None
    aggregateResults: Optional[dict] = None
|
||||
|
||||
|
||||
class SimulationResult(BaseModel):
    # Outcome of a single scenario within a run.
    projectId: str
    runId: str
    scenarioId: str
    result: Literal["pass", "fail"]
    details: str
|
||||
|
||||
class SimulationAggregateResult(BaseModel):
    # Pass/fail totals for a whole run.
    total: int
    # ``pass`` is a Python keyword, so the attribute is ``pass_`` and the
    # external name is supplied through the alias.
    pass_: int = Field(..., alias='pass')
    fail: int
|
||||
64
apps/simulation_runner/service.py
Normal file
64
apps/simulation_runner/service.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
from db import get_pending_simulation_run, get_scenarios_for_run, set_simulation_run_to_completed, get_api_key
|
||||
from scenario_types import SimulationRun, Scenario
|
||||
from simulation import simulate_scenarios
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
class JobService:
    """Polls for pending simulation runs and processes them as asyncio tasks.

    ``self.semaphore`` caps the number of jobs inside ``process_job`` at five.
    """

    def __init__(self):
        self.poll_interval = 5  # seconds
        self.semaphore = asyncio.Semaphore(5)
        # Strong references to in-flight job tasks. The event loop only keeps
        # weak references to tasks, so a task nobody else references may be
        # garbage-collected before it finishes.
        self._tasks = set()

    async def poll_and_process_jobs(self, max_iterations: int = None):
        """
        Periodically checks for new jobs in MongoDB and processes them.

        Args:
            max_iterations: Stop after this many polls; ``None`` polls forever.

        Jobs that were already started are awaited before this coroutine
        returns, so a job claimed on the final poll is not abandoned.
        """
        iterations = 0
        while True:
            job = get_pending_simulation_run()
            if job:
                logging.info(f"Found new job: {job}. Processing...")
                task = asyncio.create_task(self.process_job(job))
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)
            else:
                logging.info(f"No new jobs found. Checking again in {self.poll_interval} seconds...")

            iterations += 1
            if max_iterations is not None and iterations >= max_iterations:
                break
            # Sleep for the polling interval
            await asyncio.sleep(self.poll_interval)

        # Breaking out above skips the sleep, so a task created on the last
        # iteration has not had a chance to run yet; wait for everything that
        # is still in flight.
        if self._tasks:
            await asyncio.gather(*self._tasks, return_exceptions=True)

    async def process_job(self, job: SimulationRun):
        """
        Calls the simulation function and updates job status upon completion.

        Failures are logged rather than raised, because this coroutine runs as
        a background task that nobody awaits for its result.
        """
        async with self.semaphore:
            try:
                scenarios = get_scenarios_for_run(job)
                if not scenarios:
                    # NOTE: the run keeps its "running" status in this case.
                    logging.info(f"No scenarios found for job {job.id}")
                    return

                api_key = get_api_key(job.projectId)
                result = await simulate_scenarios(scenarios, job.id, job.workflowId, api_key)

                set_simulation_run_to_completed(job, result)
                logging.info(f"Job {job.id} completed.")
            except Exception:
                # Task boundary: without this the traceback would only show up
                # (if at all) when the task object is garbage-collected.
                # NOTE: the run keeps its "running" status after a failure.
                logging.exception(f"Job {job.id} failed.")

    def start(self):
        """
        Entry point to start the service event loop.

        Uses ``asyncio.run`` so the loop is created, cleaned up and closed in
        one place. The semaphore is created in ``__init__`` before the loop
        exists, which relies on Python 3.10+ (the Dockerfile pins 3.11).
        """
        try:
            asyncio.run(self.poll_and_process_jobs())
        except KeyboardInterrupt:
            logging.info("Service stopped by user.")
|
||||
|
||||
if __name__ == "__main__":
    # Run the polling service until it is interrupted.
    JobService().start()
|
||||
123
apps/simulation_runner/simulation.py
Normal file
123
apps/simulation_runner/simulation.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
from rowboat import Client, StatefulChat
|
||||
from typing import List
|
||||
import json
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from scenario_types import Scenario, SimulationResult, SimulationAggregateResult
|
||||
from db import write_simulation_result, set_simulation_run_to_completed
|
||||
|
||||
|
||||
# Shared OpenAI client used for both the simulated user and the evaluator.
openai_client = OpenAI()
# Chat model used for every completion request in this module.
MODEL_NAME = "gpt-4o"
# Base URL of the Rowboat API, taken from ROWBOAT_API_HOST with whitespace stripped.
ROWBOAT_API_HOST = os.getenv("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
|
||||
|
||||
def simulate_scenario(scenario: Scenario, rowboat_client: Client, workflow_id: str, max_iterations: int = 5) -> tuple:
    """
    Runs a mock simulation for a given scenario.
    After simulating several turns of conversation, it evaluates the conversation.

    Args:
        scenario: Supplies the context, the description the simulated user
            follows, and the criteria the evaluator applies.
        rowboat_client: Client passed to ``StatefulChat``.
        workflow_id: Workflow passed to ``StatefulChat``.
        max_iterations: Number of user/assistant exchanges to simulate.

    Returns:
        A ``(verdict, transcript)`` tuple: the value of the ``"verdict"`` field
        in the evaluator's JSON reply, and the conversation rendered as
        ``ROLE: content`` lines.

    Raises:
        Exception: If the evaluator returns no choices or no ``verdict`` field.
    """
    support_chat = StatefulChat(
        rowboat_client,
        system_prompt=f"Context: {scenario.context}" if scenario.context else "",
        workflow_id=workflow_id,
    )

    messages = [
        {
            "role": "system",
            "content": f"Simulate the user based on the scenario: \n {scenario.description}"
        }
    ]

    # -------------------------
    # 1) MAIN SIMULATION LOOP
    # -------------------------
    for _ in range(max_iterations):
        simulated_user_response = openai_client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=0.0,
        )
        simulated_content = simulated_user_response.choices[0].message.content

        # Record the simulated user turn. Previously it was dropped, so neither
        # the next simulation step nor the evaluator transcript could see what
        # the user had actually said.
        messages.append({"role": "user", "content": simulated_content})

        # Feed the model-generated content into Rowboat's stateful chat and
        # record its reply so the conversation continues.
        rowboat_response = support_chat.run(simulated_content)
        messages.append({"role": "assistant", "content": rowboat_response})

    # -------------------------
    # 2) EVALUATION STEP
    # -------------------------
    transcript_str = "".join(
        f"{m.get('role', 'unknown').upper()}: {m.get('content', '')}\n"
        for m in messages
    )

    evaluation_prompt = [
        {
            "role": "system",
            "content": (
                f"You are a neutral evaluator. Evaluate based on these criteria:\n{scenario.criteria}\n\nReturn ONLY a JSON object with format: "
                '{"verdict": "pass"} if the support bot answered correctly, or {"verdict": "fail"} if not.'
            )
        },
        {
            "role": "user",
            "content": (
                f"Here is the conversation transcript:\n\n{transcript_str}\n\n"
                "Did the support bot answer correctly or not? Return only 'pass' or 'fail'."
            )
        }
    ]

    eval_response = openai_client.chat.completions.create(
        model=MODEL_NAME,
        messages=evaluation_prompt,
        temperature=0.0,
        response_format={"type": "json_object"}
    )

    if not eval_response.choices:
        raise Exception("No evaluation response received from model")

    response_json = json.loads(eval_response.choices[0].message.content)
    evaluation_result = response_json.get("verdict")
    if evaluation_result is None:
        raise Exception("No verdict field found in evaluation response")

    return evaluation_result, transcript_str
|
||||
|
||||
|
||||
async def simulate_scenarios(scenarios: List[Scenario], runId: str, workflow_id: str, api_key: str, max_iterations: int = 5):
    """Simulate every scenario of a run and return the pass/fail totals.

    Each scenario is simulated in turn; its ``SimulationResult`` is written to
    the database as soon as it is available.

    Args:
        scenarios: Scenarios to simulate. The project id is taken from the
            first one.
        runId: Identifier stored on each ``SimulationResult``.
        workflow_id: Workflow forwarded to ``simulate_scenario``.
        api_key: API key used to build the Rowboat client.
        max_iterations: Forwarded to ``simulate_scenario``.

    Returns:
        A ``SimulationAggregateResult``; all counts are zero when
        ``scenarios`` is empty.
    """
    import asyncio  # local import: this module does not import asyncio at top level

    if not scenarios:
        # Nothing to run; avoid indexing into an empty list below.
        return SimulationAggregateResult(**{"total": 0, "pass": 0, "fail": 0})

    project_id = scenarios[0].projectId
    client = Client(
        host=ROWBOAT_API_HOST,
        project_id=project_id,
        api_key=api_key
    )

    results = []
    for scenario in scenarios:
        # simulate_scenario makes blocking network calls. Running it in a
        # worker thread keeps the event loop free, so other jobs and the
        # poller can make progress in the meantime.
        result, transcript = await asyncio.to_thread(
            simulate_scenario, scenario, client, workflow_id, max_iterations
        )

        simulation_result = SimulationResult(
            projectId=project_id,
            runId=runId,
            scenarioId=scenario.id,
            result=result,
            details=transcript
        )
        results.append(simulation_result)
        write_simulation_result(simulation_result)

    # Built from a dict because "pass" is a keyword and cannot be used as a
    # keyword argument name.
    return SimulationAggregateResult(**{
        "total": len(scenarios),
        "pass": sum(1 for r in results if r.result == "pass"),
        "fail": sum(1 for r in results if r.result == "fail"),
    })
|
||||
|
|
@ -60,6 +60,16 @@ services:
|
|||
- SIGNING_SECRET=${SIGNING_SECRET}
|
||||
restart: unless-stopped
|
||||
|
||||
simulation_runner:
|
||||
build:
|
||||
context: ./apps/simulation_runner
|
||||
dockerfile: Dockerfile
|
||||
environment:
|
||||
- MONGODB_URI=${MONGODB_CONNECTION_STRING}
|
||||
- ROWBOAT_API_HOST=http://rowboat:3000
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
restart: unless-stopped
|
||||
|
||||
docs:
|
||||
build:
|
||||
context: ./apps/docs
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue