mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-26 08:56:22 +02:00
Add support for other providers - litellm, openrouter
This commit is contained in:
parent
8c2c21a239
commit
14eee3e0c3
24 changed files with 398 additions and 95 deletions
|
|
@ -91,6 +91,7 @@ async def chat():
|
|||
start_agent_name=data.get("startAgent", ""),
|
||||
agent_configs=data.get("agents", []),
|
||||
tool_configs=data.get("tools", []),
|
||||
prompt_configs=data.get("prompts", []),
|
||||
start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
|
||||
state=data.get("state", {}),
|
||||
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
|
||||
|
|
@ -157,6 +158,7 @@ async def chat_stream():
|
|||
start_agent_name=request_data.get("startAgent", ""),
|
||||
agent_configs=request_data.get("agents", []),
|
||||
tool_configs=request_data.get("tools", []),
|
||||
prompt_configs=request_data.get("prompts", []),
|
||||
start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
|
||||
state=request_data.get("state", {}),
|
||||
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
|
||||
|
|
@ -168,6 +170,9 @@ async def chat_stream():
|
|||
elif event_type == 'done':
|
||||
print("Yielding done:")
|
||||
yield format_sse(event_data, "done")
|
||||
elif event_type == 'error':
|
||||
print("Yielding error:")
|
||||
yield format_sse(event_data, "stream_error")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming error: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import logging
|
|||
from .helpers.access import (
|
||||
get_agent_by_name,
|
||||
get_external_tools,
|
||||
get_prompt_by_type
|
||||
)
|
||||
from .helpers.state import (
|
||||
construct_state_from_response
|
||||
|
|
@ -14,7 +15,8 @@ from .helpers.state import (
|
|||
from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name
|
||||
from .swarm_wrapper import run as swarm_run, run_streamed as swarm_run_streamed, create_response, get_agents
|
||||
from src.utils.common import common_logger as logger
|
||||
import asyncio
|
||||
|
||||
from .types import PromptType
|
||||
|
||||
# Create a dedicated logger for swarm wrapper
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -43,6 +45,26 @@ def order_messages(messages):
|
|||
ordered_messages.append(ordered)
|
||||
return ordered_messages
|
||||
|
||||
def set_sys_message(messages):
|
||||
"""
|
||||
If the system message is empty, set it to the default message: "You are a helplful assistant."
|
||||
"""
|
||||
if not any(msg.get("role") == "system" for msg in messages):
|
||||
messages.insert(0, {
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
})
|
||||
print("Inserted system message: ", messages[0])
|
||||
logger.info("Inserted system message: ", messages[0])
|
||||
|
||||
elif messages[0].get("role") == "system" and messages[0].get("content") == "":
|
||||
messages[0]["content"] = "You are a helpful assistant."
|
||||
print("Updated system message: ", messages[0])
|
||||
logger.info("Updated system message: ", messages[0])
|
||||
print("Messages: ", messages)
|
||||
# logger.info("Messages: ", messages)
|
||||
|
||||
return messages
|
||||
|
||||
def clean_up_history(agent_data):
|
||||
"""
|
||||
|
|
@ -197,7 +219,6 @@ async def run_turn(
|
|||
logger.info(f"Completed run of agent: {last_new_agent.name}")
|
||||
print(f"Completed run of agent: {last_new_agent.name}")
|
||||
|
||||
|
||||
# Otherwise, duplicate the last response as external
|
||||
logger.info("No post-processing agent found. Duplicating last response and setting to external.")
|
||||
print("No post-processing agent found. Duplicating last response and setting to external.")
|
||||
|
|
@ -236,13 +257,41 @@ async def run_turn_streamed(
|
|||
start_agent_name,
|
||||
agent_configs,
|
||||
tool_configs,
|
||||
prompt_configs,
|
||||
start_turn_with_start_agent,
|
||||
state={},
|
||||
additional_tool_configs=[],
|
||||
complete_request={}
|
||||
):
|
||||
messages = set_sys_message(messages)
|
||||
is_greeting_turn = not any(msg.get("role") != "system" for msg in messages)
|
||||
final_state = None # Initialize outside try block
|
||||
try:
|
||||
greeting_prompt = get_prompt_by_type(prompt_configs, PromptType.GREETING)
|
||||
if is_greeting_turn:
|
||||
if not greeting_prompt:
|
||||
greeting_prompt = "How can I help you today?"
|
||||
print("Greeting prompt not found. Using default: ", greeting_prompt)
|
||||
message = {
|
||||
'content': greeting_prompt,
|
||||
'role': 'assistant',
|
||||
'sender': start_agent_name,
|
||||
'tool_calls': None,
|
||||
'tool_call_id': None,
|
||||
'tool_name': None,
|
||||
'response_type': 'external'
|
||||
}
|
||||
print("Yielding greeting message: ", message)
|
||||
yield ('message', message)
|
||||
|
||||
final_state = {
|
||||
"last_agent_name": start_agent_name if start_agent_name else None,
|
||||
"tokens": {"total": 0, "prompt": 0, "completion": 0}
|
||||
}
|
||||
print("Yielding done message")
|
||||
yield ('done', {'state': final_state})
|
||||
return
|
||||
|
||||
# Initialize agents and get external tools
|
||||
new_agents = get_agents(agent_configs=agent_configs, tool_configs=tool_configs, complete_request=complete_request)
|
||||
last_agent_name = get_last_agent_name(
|
||||
|
|
@ -274,7 +323,7 @@ async def run_turn_streamed(
|
|||
|
||||
# Handle raw response events and accumulate tokens
|
||||
if event.type == "raw_response_event":
|
||||
if hasattr(event.data, 'type') and event.data.type == "response.completed":
|
||||
if hasattr(event.data, 'type') and event.data.type == "response.completed" and event.data.response.usage:
|
||||
if hasattr(event.data.response, 'usage'):
|
||||
tokens_used["total"] += event.data.response.usage.total_tokens
|
||||
tokens_used["prompt"] += event.data.response.usage.input_tokens
|
||||
|
|
@ -616,4 +665,5 @@ async def run_turn_streamed(
|
|||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
print(f"Error in stream processing: {str(e)}")
|
||||
print("Yielding error event:", {'error': str(e), 'state': final_state})
|
||||
yield ('error', {'error': str(e), 'state': final_state}) # Include final_state in error response
|
||||
|
|
@ -3,6 +3,7 @@ import json
|
|||
import aiohttp
|
||||
import jwt
|
||||
import hashlib
|
||||
from agents import OpenAIChatCompletionsModel
|
||||
|
||||
# Import helper functions needed for get_agents
|
||||
from .helpers.access import (
|
||||
|
|
@ -31,6 +32,8 @@ MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").s
|
|||
mongo_client = MongoClient(MONGO_URI)
|
||||
db = mongo_client["rowboat"]
|
||||
|
||||
from src.utils.client import client, PROVIDER_DEFAULT_MODEL
|
||||
|
||||
class NewResponse(BaseModel):
|
||||
messages: List[Dict]
|
||||
agent: Optional[Any] = None
|
||||
|
|
@ -47,7 +50,9 @@ async def mock_tool(tool_name: str, args: str, description: str, mock_instructio
|
|||
]
|
||||
|
||||
print(f"Generating simulated response for tool: {tool_name}")
|
||||
response_content = generate_openai_output(messages, output_type='text', model="gpt-4o")
|
||||
response_content = None
|
||||
response_content = generate_openai_output(messages, output_type='text', model=PROVIDER_DEFAULT_MODEL)
|
||||
print("Custom provider client not found, using default model: gpt-4o")
|
||||
return response_content
|
||||
except Exception as e:
|
||||
logger.error(f"Error in mock_tool: {str(e)}")
|
||||
|
|
@ -173,8 +178,6 @@ def get_rag_tool(config: dict, complete_request: dict) -> FunctionTool:
|
|||
else:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def get_agents(agent_configs, tool_configs, complete_request):
|
||||
"""
|
||||
Creates and initializes Agent objects based on their configurations and connections.
|
||||
|
|
@ -246,12 +249,15 @@ def get_agents(agent_configs, tool_configs, complete_request):
|
|||
# add the name and description to the agent instructions
|
||||
agent_instructions = f"## Your Name\n{agent_config['name']}\n\n## Description\n{agent_config['description']}\n\n## Instructions\n{agent_config['instructions']}"
|
||||
try:
|
||||
model_name = agent_config["model"] if agent_config["model"] else PROVIDER_DEFAULT_MODEL
|
||||
print(f"Using model: {model_name}")
|
||||
model=OpenAIChatCompletionsModel(model=model_name, openai_client=client) if client else agent_config["model"]
|
||||
new_agent = NewAgent(
|
||||
name=agent_config["name"],
|
||||
instructions=agent_instructions,
|
||||
handoff_description=agent_config["description"],
|
||||
tools=new_tools,
|
||||
model=agent_config["model"],
|
||||
model = model,
|
||||
model_settings=ModelSettings(temperature=0.0)
|
||||
)
|
||||
|
||||
|
|
|
|||
32
apps/rowboat_agents/src/utils/client.py
Normal file
32
apps/rowboat_agents/src/utils/client.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import os
|
||||
import logging
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
import dotenv
|
||||
dotenv.load_dotenv()
|
||||
|
||||
PROVIDER_BASE_URL = os.getenv('PROVIDER_BASE_URL', '')
|
||||
PROVIDER_API_KEY = os.getenv('PROVIDER_API_KEY', os.getenv('OPENAI_API_KEY', ''))
|
||||
PROVIDER_DEFAULT_MODEL = os.getenv('PROVIDER_DEFAULT_MODEL', 'gpt-4.1')
|
||||
|
||||
client = None
|
||||
if not PROVIDER_API_KEY:
|
||||
raise ValueError("No LLM Provider API key found")
|
||||
|
||||
if PROVIDER_BASE_URL:
|
||||
print(f"Using provider {PROVIDER_BASE_URL} with API key {PROVIDER_API_KEY}")
|
||||
client = AsyncOpenAI(base_url=PROVIDER_BASE_URL, api_key=PROVIDER_API_KEY)
|
||||
else:
|
||||
print("No provider base URL configured, using OpenAI directly")
|
||||
|
||||
completions_client = None
|
||||
if PROVIDER_BASE_URL:
|
||||
print(f"Using provider {PROVIDER_BASE_URL} for completions")
|
||||
completions_client = OpenAI(
|
||||
base_url=PROVIDER_BASE_URL,
|
||||
api_key=PROVIDER_API_KEY
|
||||
)
|
||||
else:
|
||||
print(f"Using OpenAI directly for completions")
|
||||
completions_client = OpenAI(
|
||||
api_key=PROVIDER_API_KEY
|
||||
)
|
||||
|
|
@ -7,6 +7,7 @@ import time
|
|||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
|
||||
from src.utils.client import completions_client
|
||||
load_dotenv()
|
||||
|
||||
def setup_logger(name, log_file='./run.log', level=logging.INFO, log_to_file=False):
|
||||
|
|
@ -53,31 +54,28 @@ def get_api_key(key_name):
|
|||
raise ValueError(f"{key_name} not found. Did you set it in the .env file?")
|
||||
return api_key
|
||||
|
||||
openai_client = OpenAI(
|
||||
api_key=get_api_key("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
def generate_gpt4o_output_from_multi_turn_conv(messages, output_type='json', model="gpt-4o"):
|
||||
return generate_openai_output(messages, output_type, model)
|
||||
|
||||
def generate_openai_output(messages, output_type='not_json', model="gpt-4o", return_completion=False):
|
||||
print(f"In generate_openai_output, using client: {completions_client} and model: {model}")
|
||||
try:
|
||||
if output_type == 'json':
|
||||
chat_completion = openai_client.chat.completions.create(
|
||||
messages=messages,
|
||||
chat_completion = completions_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
else:
|
||||
chat_completion = openai_client.chat.completions.create(
|
||||
messages=messages,
|
||||
chat_completion = completions_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
if return_completion:
|
||||
return chat_completion
|
||||
return chat_completion.choices[0].message.content
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue