Add support for other providers - litellm, openrouter

This commit is contained in:
akhisud3195 2025-04-25 23:50:26 +05:30 committed by Ramnique Singh
parent 8c2c21a239
commit 14eee3e0c3
24 changed files with 398 additions and 95 deletions

View file

@ -7,6 +7,7 @@ import logging
from .helpers.access import (
get_agent_by_name,
get_external_tools,
get_prompt_by_type
)
from .helpers.state import (
construct_state_from_response
@ -14,7 +15,8 @@ from .helpers.state import (
from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name
from .swarm_wrapper import run as swarm_run, run_streamed as swarm_run_streamed, create_response, get_agents
from src.utils.common import common_logger as logger
import asyncio
from .types import PromptType
# Create a dedicated logger for swarm wrapper
logger.setLevel(logging.INFO)
@ -43,6 +45,26 @@ def order_messages(messages):
ordered_messages.append(ordered)
return ordered_messages
def set_sys_message(messages):
"""
If the system message is empty, set it to the default message: "You are a helplful assistant."
"""
if not any(msg.get("role") == "system" for msg in messages):
messages.insert(0, {
"role": "system",
"content": "You are a helpful assistant."
})
print("Inserted system message: ", messages[0])
logger.info("Inserted system message: ", messages[0])
elif messages[0].get("role") == "system" and messages[0].get("content") == "":
messages[0]["content"] = "You are a helpful assistant."
print("Updated system message: ", messages[0])
logger.info("Updated system message: ", messages[0])
print("Messages: ", messages)
# logger.info("Messages: ", messages)
return messages
def clean_up_history(agent_data):
"""
@ -197,7 +219,6 @@ async def run_turn(
logger.info(f"Completed run of agent: {last_new_agent.name}")
print(f"Completed run of agent: {last_new_agent.name}")
# Otherwise, duplicate the last response as external
logger.info("No post-processing agent found. Duplicating last response and setting to external.")
print("No post-processing agent found. Duplicating last response and setting to external.")
@ -236,13 +257,41 @@ async def run_turn_streamed(
start_agent_name,
agent_configs,
tool_configs,
prompt_configs,
start_turn_with_start_agent,
state={},
additional_tool_configs=[],
complete_request={}
):
messages = set_sys_message(messages)
is_greeting_turn = not any(msg.get("role") != "system" for msg in messages)
final_state = None # Initialize outside try block
try:
greeting_prompt = get_prompt_by_type(prompt_configs, PromptType.GREETING)
if is_greeting_turn:
if not greeting_prompt:
greeting_prompt = "How can I help you today?"
print("Greeting prompt not found. Using default: ", greeting_prompt)
message = {
'content': greeting_prompt,
'role': 'assistant',
'sender': start_agent_name,
'tool_calls': None,
'tool_call_id': None,
'tool_name': None,
'response_type': 'external'
}
print("Yielding greeting message: ", message)
yield ('message', message)
final_state = {
"last_agent_name": start_agent_name if start_agent_name else None,
"tokens": {"total": 0, "prompt": 0, "completion": 0}
}
print("Yielding done message")
yield ('done', {'state': final_state})
return
# Initialize agents and get external tools
new_agents = get_agents(agent_configs=agent_configs, tool_configs=tool_configs, complete_request=complete_request)
last_agent_name = get_last_agent_name(
@ -274,7 +323,7 @@ async def run_turn_streamed(
# Handle raw response events and accumulate tokens
if event.type == "raw_response_event":
if hasattr(event.data, 'type') and event.data.type == "response.completed":
if hasattr(event.data, 'type') and event.data.type == "response.completed" and event.data.response.usage:
if hasattr(event.data.response, 'usage'):
tokens_used["total"] += event.data.response.usage.total_tokens
tokens_used["prompt"] += event.data.response.usage.input_tokens
@ -616,4 +665,5 @@ async def run_turn_streamed(
except Exception as e:
print(traceback.format_exc())
print(f"Error in stream processing: {str(e)}")
print("Yielding error event:", {'error': str(e), 'state': final_state})
yield ('error', {'error': str(e), 'state': final_state}) # Include final_state in error response

View file

@ -3,6 +3,7 @@ import json
import aiohttp
import jwt
import hashlib
from agents import OpenAIChatCompletionsModel
# Import helper functions needed for get_agents
from .helpers.access import (
@ -31,6 +32,8 @@ MONGO_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/rowboat").s
mongo_client = MongoClient(MONGO_URI)
db = mongo_client["rowboat"]
from src.utils.client import client, PROVIDER_DEFAULT_MODEL
class NewResponse(BaseModel):
messages: List[Dict]
agent: Optional[Any] = None
@ -47,7 +50,9 @@ async def mock_tool(tool_name: str, args: str, description: str, mock_instructio
]
print(f"Generating simulated response for tool: {tool_name}")
response_content = generate_openai_output(messages, output_type='text', model="gpt-4o")
response_content = None
response_content = generate_openai_output(messages, output_type='text', model=PROVIDER_DEFAULT_MODEL)
print("Custom provider client not found, using default model: gpt-4o")
return response_content
except Exception as e:
logger.error(f"Error in mock_tool: {str(e)}")
@ -173,8 +178,6 @@ def get_rag_tool(config: dict, complete_request: dict) -> FunctionTool:
else:
return None
def get_agents(agent_configs, tool_configs, complete_request):
"""
Creates and initializes Agent objects based on their configurations and connections.
@ -246,12 +249,15 @@ def get_agents(agent_configs, tool_configs, complete_request):
# add the name and description to the agent instructions
agent_instructions = f"## Your Name\n{agent_config['name']}\n\n## Description\n{agent_config['description']}\n\n## Instructions\n{agent_config['instructions']}"
try:
model_name = agent_config["model"] if agent_config["model"] else PROVIDER_DEFAULT_MODEL
print(f"Using model: {model_name}")
model=OpenAIChatCompletionsModel(model=model_name, openai_client=client) if client else agent_config["model"]
new_agent = NewAgent(
name=agent_config["name"],
instructions=agent_instructions,
handoff_description=agent_config["description"],
tools=new_tools,
model=agent_config["model"],
model = model,
model_settings=ModelSettings(temperature=0.0)
)