mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-25 00:16:29 +02:00
mv experimental apps
This commit is contained in:
parent
7f6ece90f8
commit
f722591ccd
53 changed files with 31 additions and 31 deletions
20
apps/experimental/simulation_runner/Dockerfile
Normal file
20
apps/experimental/simulation_runner/Dockerfile
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# Use official Python runtime as base image
FROM python:3.11-slim

# Set working directory in container
WORKDIR /app

# Copy requirements file first so dependency install is cached by Docker
COPY requirements.txt .

# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy project files
COPY . .

# Send Python output straight to the container log (no stdout buffering)
ENV PYTHONUNBUFFERED=1

# Command to run the simulation service
CMD ["python", "service.py"]
|
||||
0
apps/experimental/simulation_runner/__init__.py
Normal file
0
apps/experimental/simulation_runner/__init__.py
Normal file
171
apps/experimental/simulation_runner/db.py
Normal file
171
apps/experimental/simulation_runner/db.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
from pymongo import MongoClient
|
||||
from bson import ObjectId
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from scenario_types import (
|
||||
TestRun,
|
||||
TestScenario,
|
||||
TestSimulation,
|
||||
TestResult,
|
||||
AggregateResults
|
||||
)
|
||||
|
||||
# Mongo connection string; defaults to a local instance when MONGODB_URI is unset.
MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").strip()

# Collection names used by the simulation runner.
TEST_SCENARIOS_COLLECTION = "test_scenarios"
TEST_SIMULATIONS_COLLECTION = "test_simulations"
TEST_RUNS_COLLECTION = "test_runs"
TEST_RESULTS_COLLECTION = "test_results"
API_KEYS_COLLECTION = "api_keys"
|
||||
|
||||
def get_db():
    """Open a MongoClient and return a handle to the 'rowboat' database."""
    # NOTE(review): this creates a fresh client on every call; callers rely
    # on the current behavior, so it is preserved as-is.
    return MongoClient(MONGO_URI)["rowboat"]
|
||||
|
||||
def get_collection(collection_name: str):
    """Return a handle to `collection_name` in the rowboat database."""
    return get_db()[collection_name]
|
||||
|
||||
def get_api_key(project_id: str):
    """
    Look up the API key stored for a project.

    If you still use an API key pattern, adapt as needed.
    Returns the key string, or None when no document matches `project_id`.
    """
    keys = get_collection(API_KEYS_COLLECTION)
    doc = keys.find_one({"projectId": project_id})
    return doc["key"] if doc else None
|
||||
|
||||
#
|
||||
# TestRun helpers
|
||||
#
|
||||
|
||||
def get_pending_run() -> Optional[TestRun]:
    """
    Atomically claim one 'pending' run: flip its status to 'running' and
    return it. Returns None when no pending run exists.
    """
    runs = get_collection(TEST_RUNS_COLLECTION)
    # return_document=True equals pymongo's ReturnDocument.AFTER, so the
    # returned document already reflects the status flip.
    doc = runs.find_one_and_update(
        {"status": "pending"},
        {"$set": {"status": "running"}},
        return_document=True,
    )
    if doc is None:
        return None
    return TestRun(
        id=str(doc["_id"]),
        projectId=doc["projectId"],
        name=doc["name"],
        simulationIds=doc["simulationIds"],
        workflowId=doc["workflowId"],
        status="running",
        startedAt=doc["startedAt"],
        completedAt=doc.get("completedAt"),
        aggregateResults=doc.get("aggregateResults"),
        lastHeartbeat=doc.get("lastHeartbeat"),
    )
|
||||
|
||||
def set_run_to_completed(test_run: TestRun, aggregate: AggregateResults):
    """Mark `test_run` as 'completed' and persist its aggregate results."""
    completion_fields = {
        "status": "completed",
        "aggregateResults": aggregate.model_dump(by_alias=True),
        "completedAt": datetime.now(timezone.utc),
    }
    runs = get_collection(TEST_RUNS_COLLECTION)
    runs.update_one({"_id": ObjectId(test_run.id)}, {"$set": completion_fields})
|
||||
|
||||
def update_run_heartbeat(run_id: str):
    """Refresh a TestRun's 'lastHeartbeat' to the current UTC time."""
    heartbeat = {"lastHeartbeat": datetime.now(timezone.utc)}
    get_collection(TEST_RUNS_COLLECTION).update_one(
        {"_id": ObjectId(run_id)},
        {"$set": heartbeat},
    )
|
||||
|
||||
def mark_stale_jobs_as_failed(threshold_minutes: int = 20) -> int:
    """
    Mark every 'running' run whose lastHeartbeat is older than
    `threshold_minutes` as 'failed'.

    Returns the number of runs updated.
    """
    cutoff = datetime.now(timezone.utc) - timedelta(minutes=threshold_minutes)
    stale_filter = {
        "status": "running",
        "lastHeartbeat": {"$lt": cutoff},
    }
    outcome = get_collection(TEST_RUNS_COLLECTION).update_many(
        stale_filter,
        {"$set": {"status": "failed"}},
    )
    return outcome.modified_count
|
||||
|
||||
#
|
||||
# TestSimulation helpers
|
||||
#
|
||||
|
||||
def get_simulations_for_run(test_run: TestRun) -> list[TestSimulation]:
    """Fetch every TestSimulation referenced by `test_run.simulationIds`."""
    if test_run is None:
        return []
    object_ids = [ObjectId(sim_id) for sim_id in test_run.simulationIds]
    cursor = get_collection(TEST_SIMULATIONS_COLLECTION).find(
        {"_id": {"$in": object_ids}}
    )
    return [
        TestSimulation(
            id=str(doc["_id"]),
            projectId=doc["projectId"],
            name=doc["name"],
            scenarioId=doc["scenarioId"],
            profileId=doc["profileId"],
            passCriteria=doc["passCriteria"],
            createdAt=doc["createdAt"],
            lastUpdatedAt=doc["lastUpdatedAt"],
        )
        for doc in cursor
    ]
|
||||
|
||||
def get_scenario_by_id(scenario_id: str) -> Optional[TestScenario]:
    """
    Look up a TestScenario by its ObjectId string.

    Returns:
        The scenario, or None when no document matches. (The original
        annotation claimed a non-optional TestScenario yet returned None
        on a miss; the annotation is corrected to Optional.)
    """
    collection = get_collection(TEST_SCENARIOS_COLLECTION)
    doc = collection.find_one({"_id": ObjectId(scenario_id)})
    if doc:
        return TestScenario(
            id=str(doc["_id"]),
            projectId=doc["projectId"],
            name=doc["name"],
            description=doc["description"],
            createdAt=doc["createdAt"],
            lastUpdatedAt=doc["lastUpdatedAt"]
        )
    return None
|
||||
|
||||
#
|
||||
# TestResult helpers
|
||||
#
|
||||
|
||||
def write_test_result(result: TestResult):
    """Persist a single TestResult document into `test_results`."""
    payload = result.model_dump()
    get_collection(TEST_RESULTS_COLLECTION).insert_one(payload)
|
||||
29
apps/experimental/simulation_runner/requirements.txt
Normal file
29
apps/experimental/simulation_runner/requirements.txt
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
annotated-types==0.7.0
|
||||
anyio==4.8.0
|
||||
certifi==2025.1.31
|
||||
charset-normalizer==3.4.1
|
||||
distro==1.9.0
|
||||
dnspython==2.7.0
|
||||
h11==0.14.0
|
||||
httpcore==1.0.7
|
||||
httpx==0.28.1
|
||||
idna==3.10
|
||||
iniconfig==2.0.0
|
||||
jiter==0.8.2
|
||||
motor==3.7.0
|
||||
openai==1.63.0
|
||||
packaging==24.2
|
||||
pluggy==1.5.0
|
||||
pydantic==2.10.6
|
||||
pydantic_core==2.27.2
|
||||
pymongo==4.11.1
|
||||
pytest==8.3.4
|
||||
pytest-asyncio==0.25.3
|
||||
python-dateutil==2.9.0.post0
|
||||
requests==2.32.3
|
||||
rowboat==2.1.0
|
||||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
tqdm==4.67.1
|
||||
typing_extensions==4.12.2
|
||||
urllib3==2.3.0
|
||||
50
apps/experimental/simulation_runner/scenario_types.py
Normal file
50
apps/experimental/simulation_runner/scenario_types.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional, List, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Closed set of run lifecycle states; includes the new "error" status.
RunStatus = Literal["pending", "running", "completed", "cancelled", "failed", "error"]
|
||||
|
||||
class TestScenario(BaseModel):
    """A conversation scenario, mirrored from the `test_scenarios` collection."""
    # `_id` in Mongo will be stored as ObjectId; we return it as a string
    id: str
    projectId: str
    name: str
    # Free-text description used to seed the simulated customer's role-play.
    description: str
    createdAt: datetime
    lastUpdatedAt: datetime
|
||||
|
||||
class TestSimulation(BaseModel):
    """One scenario+profile pairing to execute, from `test_simulations`."""
    id: str
    projectId: str
    name: str
    # References a TestScenario document by its string ObjectId.
    scenarioId: str
    # Test profile forwarded to the Rowboat chat session.
    profileId: str
    # Criteria text handed verbatim to the evaluator model.
    passCriteria: str
    createdAt: datetime
    lastUpdatedAt: datetime
|
||||
|
||||
class AggregateResults(BaseModel):
    """Pass/fail tallies for a completed test run."""
    total: int
    passCount: int
    failCount: int
|
||||
|
||||
class TestRun(BaseModel):
    """A batch execution of simulations, mirrored from `test_runs`."""
    id: str
    projectId: str
    name: str
    # String ObjectIds of the TestSimulations this run executes.
    simulationIds: List[str]
    workflowId: str
    status: RunStatus
    startedAt: datetime
    completedAt: Optional[datetime] = None
    aggregateResults: Optional[AggregateResults] = None
    # Refreshed periodically while running; stale values get the run failed.
    lastHeartbeat: Optional[datetime] = None
|
||||
|
||||
class TestResult(BaseModel):
    """Outcome of one simulation within a run, stored in `test_results`."""
    projectId: str
    runId: str
    simulationId: str
    # Evaluator verdict.
    result: Literal["pass", "fail"]
    # Evaluator's brief explanation of the verdict.
    details: str
    # JSON-encoded message list of the simulated conversation.
    transcript: str
|
||||
120
apps/experimental/simulation_runner/service.py
Normal file
120
apps/experimental/simulation_runner/service.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
# Updated imports from your new db module and scenario_types
|
||||
from db import (
|
||||
get_pending_run,
|
||||
get_simulations_for_run,
|
||||
set_run_to_completed,
|
||||
get_api_key,
|
||||
mark_stale_jobs_as_failed,
|
||||
update_run_heartbeat
|
||||
)
|
||||
from scenario_types import TestRun, TestSimulation
|
||||
# If you have a new simulation function, import it here.
|
||||
# Otherwise, adapt the name as needed:
|
||||
from simulation import simulate_simulations # or simulate_scenarios, if unchanged
|
||||
|
||||
# Root-logger configuration for the whole service.
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
class JobService:
|
||||
def __init__(self):
|
||||
self.poll_interval = 5 # seconds
|
||||
# Control concurrency of run processing
|
||||
self.semaphore = asyncio.Semaphore(5)
|
||||
|
||||
async def poll_and_process_jobs(self, max_iterations: Optional[int] = None):
|
||||
"""
|
||||
Periodically checks for new runs in MongoDB and processes them.
|
||||
"""
|
||||
# Start the stale-run check in the background
|
||||
asyncio.create_task(self.fail_stale_runs_loop())
|
||||
|
||||
iterations = 0
|
||||
while True:
|
||||
run = get_pending_run() # <--- changed to match new DB function
|
||||
if run:
|
||||
logging.info(f"Found new run: {run}. Processing...")
|
||||
asyncio.create_task(self.process_run(run))
|
||||
|
||||
iterations += 1
|
||||
if max_iterations is not None and iterations >= max_iterations:
|
||||
break
|
||||
|
||||
# Sleep for the polling interval
|
||||
await asyncio.sleep(self.poll_interval)
|
||||
|
||||
async def process_run(self, run: TestRun):
|
||||
"""
|
||||
Calls the simulation function and updates run status upon completion.
|
||||
"""
|
||||
async with self.semaphore:
|
||||
# Start heartbeat in background
|
||||
stop_heartbeat_event = asyncio.Event()
|
||||
heartbeat_task = asyncio.create_task(self.heartbeat_loop(run.id, stop_heartbeat_event))
|
||||
|
||||
try:
|
||||
# Fetch the simulations associated with this run
|
||||
simulations = get_simulations_for_run(run)
|
||||
if not simulations:
|
||||
logging.info(f"No simulations found for run {run.id}")
|
||||
return
|
||||
|
||||
# Fetch API key if needed
|
||||
api_key = get_api_key(run.projectId)
|
||||
|
||||
# Perform your simulation logic
|
||||
# adapt this call to your actual simulation function’s signature
|
||||
aggregate_result = await simulate_simulations(
|
||||
simulations=simulations,
|
||||
run_id=run.id,
|
||||
workflow_id=run.workflowId,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
# Mark run as completed with the aggregated result
|
||||
set_run_to_completed(run, aggregate_result)
|
||||
logging.info(f"Run {run.id} completed.")
|
||||
except Exception as exc:
|
||||
logging.error(f"Run {run.id} failed: {exc}")
|
||||
finally:
|
||||
stop_heartbeat_event.set()
|
||||
await heartbeat_task
|
||||
|
||||
async def fail_stale_runs_loop(self):
|
||||
"""
|
||||
Periodically checks for stale runs (no heartbeat) and marks them as 'failed'.
|
||||
"""
|
||||
while True:
|
||||
count = mark_stale_jobs_as_failed()
|
||||
if count > 0:
|
||||
logging.warning(f"Marked {count} stale runs as failed.")
|
||||
await asyncio.sleep(60) # Check every 60 seconds
|
||||
|
||||
async def heartbeat_loop(self, run_id: str, stop_event: asyncio.Event):
|
||||
"""
|
||||
Periodically updates 'lastHeartbeat' for the given run until 'stop_event' is set.
|
||||
"""
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
update_run_heartbeat(run_id)
|
||||
await asyncio.sleep(10) # Heartbeat interval in seconds
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Entry point to start the service event loop.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
loop.run_until_complete(self.poll_and_process_jobs())
|
||||
except KeyboardInterrupt:
|
||||
logging.info("Service stopped by user.")
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
if __name__ == "__main__":
    # Run the polling service until interrupted.
    JobService().start()
|
||||
198
apps/experimental/simulation_runner/simulation.py
Normal file
198
apps/experimental/simulation_runner/simulation.py
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
import json
|
||||
import os
|
||||
from openai import OpenAI
|
||||
|
||||
from scenario_types import TestSimulation, TestResult, AggregateResults, TestScenario
|
||||
|
||||
from db import write_test_result, get_scenario_by_id
|
||||
from rowboat import Client, StatefulChat
|
||||
|
||||
# OpenAI client; presumably reads OPENAI_API_KEY from the environment — TODO confirm.
openai_client = OpenAI()
# Model used for both the simulated customer and the evaluator.
MODEL_NAME = "gpt-4o"
# Base URL of the Rowboat API server.
ROWBOAT_API_HOST = os.environ.get("ROWBOAT_API_HOST", "http://127.0.0.1:3000").strip()
|
||||
|
||||
async def simulate_simulation(
    scenario: TestScenario,
    profile_id: str,
    pass_criteria: str,
    rowboat_client: Client,
    workflow_id: str,
    max_iterations: int = 5
) -> tuple[str, str, str]:
    """
    Simulates a customer conversation against a Rowboat workflow, then has an
    LLM judge the transcript against `pass_criteria`.

    Args:
        scenario: scenario whose description seeds the role-play prompt.
        profile_id: test profile id forwarded to the Rowboat chat session.
        pass_criteria: criteria text handed verbatim to the evaluator model.
        rowboat_client: configured Rowboat API client.
        workflow_id: id of the workflow under test.
        max_iterations: number of customer/bot turns to simulate.

    Returns:
        (verdict, details, transcript) where verdict is "pass" or "fail",
        details is the evaluator's explanation, and transcript is a JSON
        string of the (role-swapped) message list.

    Raises:
        Exception: if the evaluator returns no choices or no 'verdict' field.
        json.JSONDecodeError: if the evaluator's reply is not valid JSON.
    """
    # (Removed the original's no-op `pass_criteria = pass_criteria` and the
    # redundant `openai_input = messages` alias.)
    loop = asyncio.get_running_loop()

    # Todo: add profile_id
    support_chat = StatefulChat(
        rowboat_client,
        workflow_id=workflow_id,
        test_profile_id=profile_id
    )

    messages = [
        {
            "role": "system",
            "content": (
                f"You are role playing a customer talking to a chatbot (the user is role playing the chatbot). Have the following chat with the chatbot. Scenario:\n{scenario.description}. You are provided no other information. If the chatbot asks you for information that is not in context, go ahead and provide one unless stated otherwise in the scenario. Directly have the chat with the chatbot. Start now with your first message."
            )
        }
    ]

    # -------------------------
    # (1) MAIN SIMULATION LOOP
    # -------------------------
    for _ in range(max_iterations):
        # Generate the simulated customer's next message off the event loop.
        simulated_user_response = await loop.run_in_executor(
            None,  # default ThreadPool
            lambda: openai_client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=0.0,
            )
        )

        simulated_content = simulated_user_response.choices[0].message.content.strip()
        messages.append({"role": "assistant", "content": simulated_content})

        # Rowboat's chat client is synchronous; run it in a thread too.
        rowboat_response = await loop.run_in_executor(
            None,
            lambda: support_chat.run(simulated_content)
        )

        messages.append({"role": "user", "content": rowboat_response})

    # -------------------------
    # (2) EVALUATION STEP
    # -------------------------
    # Swap assistant/user roles IN PLACE so the transcript reads from the
    # support bot's perspective for the evaluator; json.dumps below therefore
    # stores the swapped roles, matching the original behavior.
    transcript_str = ""
    for m in messages:
        if m.get("role") == "assistant":
            m["role"] = "user"
        elif m.get("role") == "user":
            m["role"] = "assistant"
        role = m.get("role", "unknown")
        content = m.get("content", "")
        transcript_str += f"{role.upper()}: {content}\n"

    # Store the (role-swapped) transcript as a JSON string
    transcript = json.dumps(messages)

    # We use passCriteria as the evaluation "criteria."
    evaluation_prompt = [
        {
            "role": "system",
            "content": (
                f"You are a neutral evaluator. Evaluate based on these criteria:\n"
                f"{pass_criteria}\n\n"
                "Return ONLY a JSON object in this format:\n"
                '{"verdict": "pass", "details": <reason>} or '
                '{"verdict": "fail", "details": <reason>}.'
            )
        },
        {
            "role": "user",
            "content": (
                f"Here is the conversation transcript:\n\n{transcript_str}\n\n"
                "Did the support bot answer correctly or not? "
                "Return only 'pass' or 'fail' for verdict, and a brief explanation for details."
            )
        }
    ]

    # Run evaluation in a separate thread
    eval_response = await loop.run_in_executor(
        None,
        lambda: openai_client.chat.completions.create(
            model=MODEL_NAME,
            messages=evaluation_prompt,
            temperature=0.0,
            response_format={"type": "json_object"}
        )
    )

    if not eval_response.choices:
        raise Exception("No evaluation response received from model")

    response_json_str = eval_response.choices[0].message.content
    # Attempt to parse the JSON
    response_json = json.loads(response_json_str)
    evaluation_result = response_json.get("verdict")
    details = response_json.get("details")

    if evaluation_result is None:
        raise Exception("No 'verdict' field found in evaluation response")

    return (evaluation_result, details, transcript)
|
||||
|
||||
async def simulate_simulations(
    simulations: List[TestSimulation],
    run_id: str,
    workflow_id: str,
    api_key: str,
    max_iterations: int = 5
) -> AggregateResults:
    """
    Simulates a list of TestSimulations sequentially and aggregates the results.

    Each simulation's TestResult is persisted via write_test_result as it
    completes, then pass/fail counts are tallied into an AggregateResults.

    Args:
        simulations: simulations to run; the first one's projectId is used
            to construct the Rowboat client.
        run_id: id of the owning TestRun (stored on each result).
        workflow_id: workflow under test.
        api_key: Rowboat API key for the project.
        max_iterations: conversation turns per simulation.
    """
    if not simulations:
        # BUG FIX: the original called AggregateResults(total=0, pass_=0,
        # fail=0), but the model's fields are passCount/failCount, so every
        # empty run raised a pydantic ValidationError.
        return AggregateResults(total=0, passCount=0, failCount=0)

    project_id = simulations[0].projectId

    client = Client(
        host=ROWBOAT_API_HOST,
        project_id=project_id,
        api_key=api_key
    )

    # Store results here
    results: List[TestResult] = []

    for simulation in simulations:
        verdict, details, transcript = await simulate_simulation(
            scenario=get_scenario_by_id(simulation.scenarioId),
            profile_id=simulation.profileId,
            pass_criteria=simulation.passCriteria,
            rowboat_client=client,
            workflow_id=workflow_id,
            max_iterations=max_iterations
        )

        # Create a new TestResult
        test_result = TestResult(
            projectId=project_id,
            runId=run_id,
            simulationId=simulation.id,
            result=verdict,
            details=details,
            transcript=transcript
        )
        results.append(test_result)

        # Persist each result immediately so partial progress survives crashes.
        write_test_result(test_result)

    # Aggregate pass/fail
    pass_count = sum(1 for r in results if r.result == "pass")
    fail_count = sum(1 for r in results if r.result == "fail")

    return AggregateResults(
        total=len(results),
        passCount=pass_count,
        failCount=fail_count
    )
|
||||
Loading…
Add table
Add a link
Reference in a new issue