Merge changes v1

This commit is contained in:
akhisud3195 2025-03-25 15:37:51 +05:30 committed by Ramnique Singh
parent b2fd9bf877
commit 24efe0e887
45 changed files with 2940 additions and 294 deletions

View file

View file

View file

@ -0,0 +1,200 @@
from flask import Flask, request, jsonify, Response
from datetime import datetime
from functools import wraps
import os
import redis
import uuid
import json
import asyncio
from hypercorn.config import Config
from hypercorn.asyncio import serve
from src.graph.core import run_turn, run_turn_streamed
from src.graph.tools import RAG_TOOL, CLOSE_CHAT_TOOL
from src.utils.common import common_logger, read_json_from_file
from pprint import pprint
# Shared application-wide logger.
logger = common_logger
# Redis connection used to park streaming-chat request payloads; falls back
# to a local instance when REDIS_URL is not set.
redis_client = redis.from_url(os.environ.get('REDIS_URL', 'redis://localhost:6379'))
app = Flask(__name__)
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: report that the service is up."""
    status_payload = {"status": "ok"}
    return jsonify(status_payload)
@app.route("/")
def home():
    """Root endpoint returning a plain-text greeting."""
    greeting = "Hello, World!"
    return greeting
def require_api_key(f):
    """Decorator enforcing a Bearer-token API key on Flask endpoints.

    The expected key comes from the API_KEY environment variable. When that
    variable is unset or empty, any bearer token is accepted — presumably a
    development convenience; TODO confirm this is intended for production.
    Responds 401 when the Authorization header is missing or malformed and
    403 when the supplied token does not match.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        # Local import keeps the dependency scoped to the auth path.
        import hmac
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({'error': 'Missing or invalid authorization header'}), 401
        token = auth_header.split('Bearer ')[1]
        actual = os.environ.get('API_KEY', '').strip()
        # Constant-time comparison: a plain `!=` short-circuits on the first
        # differing byte and leaks key contents through response timing.
        if actual and not hmac.compare_digest(token, actual):
            return jsonify({'error': 'Invalid API key'}), 403
        return f(*args, **kwargs)
    return decorated
@app.route("/chat", methods=["POST"])
@require_api_key
def chat():
    """Synchronous (non-streaming) chat endpoint.

    Expects a JSON body carrying 'messages', 'startAgent', 'agents',
    'tools' and 'state'; runs a single turn via run_turn and returns the
    resulting messages, token usage and updated state as JSON. Any
    exception is reported as a 500 with the error text in the body.
    """
    logger.info('='*100)
    logger.info(f"{'*'*100}Running server mode{'*'*100}")
    try:
        data = request.get_json()
        logger.info('Complete request:')
        logger.info(data)
        logger.info('-'*100)
        start_time = datetime.now()
        # Server-side defaults (e.g. whether a turn starts with the start agent).
        config = read_json_from_file("./configs/default_config.json")
        logger.info('Beginning turn')
        resp_messages, resp_tokens_used, resp_state = run_turn(
            messages=data.get("messages", []),
            start_agent_name=data.get("startAgent", ""),
            agent_configs=data.get("agents", []),
            tool_configs=data.get("tools", []),
            start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
            state=data.get("state", {}),
            # Built-in tools are always available on top of request-supplied ones.
            additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
            complete_request=data
        )
        logger.info('-'*100)
        logger.info('Raw output:')
        logger.info((resp_messages, resp_tokens_used, resp_state))
        out = {
            "messages": resp_messages,
            "tokens_used": resp_tokens_used,
            "state": resp_state,
        }
        logger.info("Output:")
        for k, v in out.items():
            logger.info(f"{k}: {v}")
        logger.info('*'*100)
        logger.info('='*100)
        logger.info(f"Processing time: {datetime.now() - start_time}")
        return jsonify(out)
    except Exception as e:
        logger.error(f"Error: {e}")
        return jsonify({"error": str(e)}), 500
@app.route("/chat_stream_init", methods=["POST"])
@require_api_key
def chat_stream_init():
    """Register a streaming chat request and hand back its stream id.

    The request body is parked in Redis under a fresh UUID for up to ten
    minutes; the client then opens GET /chat_stream/<stream_id> to consume
    the actual SSE stream.
    """
    new_stream_id = str(uuid.uuid4())
    payload = request.get_json()
    ttl_seconds = 600
    redis_client.setex(f"stream_request_{new_stream_id}", ttl_seconds, json.dumps(payload))
    return jsonify({"stream_id": new_stream_id})
@app.route("/chat_stream/<stream_id>", methods=["GET"])
@require_api_key
def chat_stream(stream_id):
    """Serve the SSE event stream for a previously initialised chat request.

    Looks up the payload parked in Redis by /chat_stream_init, normalises
    its message list, then relays events from run_turn_streamed as
    `message` / `done` SSE frames. Returns 404 when the stream id is
    unknown or has expired.
    """
    request_data = redis_client.get(f"stream_request_{stream_id}")
    if not request_data:
        return jsonify({"error": "Stream not found"}), 404
    request_data = json.loads(request_data)
    config = read_json_from_file("./configs/default_config.json")
    # Preprocess messages to handle null content and role issues.
    # (.get avoids a KeyError when the payload carries no 'messages'.)
    for msg in request_data.get("messages", []):
        # Assistant tool-call messages may arrive with null content, which
        # downstream models reject — substitute a placeholder.
        if (msg.get("role") == "assistant" and
                msg.get("content") is None and
                msg.get("tool_calls")):
            msg["content"] = "Calling tool"
        # Handle role issues.
        if msg.get("role") == "tool":
            msg["role"] = "developer"
        elif not msg.get("role"):
            msg["role"] = "user"
    print('*'*200)
    print("Request:")
    print('*'*200)
    pprint(request_data)
    print('='*200)

    async def process_stream():
        """Translate run_turn_streamed events into SSE-formatted strings."""
        try:
            async for event_type, event_data in run_turn_streamed(
                messages=request_data.get("messages", []),
                start_agent_name=request_data.get("startAgent", ""),
                agent_configs=request_data.get("agents", []),
                tool_configs=request_data.get("tools", []),
                start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
                state=request_data.get("state", {}),
                additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
                complete_request=request_data
            ):
                if event_type in ('message', 'done'):
                    print('*'*200)
                    print(f"Yielding {event_type}:")
                    print('*'*200)
                    to_yield = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
                    print(to_yield)
                    print('='*200)
                    yield to_yield
        except Exception as e:
            logger.error(f"Streaming error: {str(e)}")
            yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"

    def generate():
        """Bridge the async generator into Flask's sync response.

        Each SSE frame is yielded as soon as it is produced. The previous
        implementation collected ALL chunks with run_until_complete before
        yielding any, so clients received the whole "stream" in one burst —
        defeating the point of server-sent events.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        agen = process_stream()
        try:
            while True:
                try:
                    # Pull exactly one chunk from the async generator.
                    chunk = loop.run_until_complete(agen.__anext__())
                except StopAsyncIteration:
                    break
                yield chunk
        except Exception as e:
            logger.error(f"Error in generate: {e}")
            raise
        finally:
            # Close the async generator before tearing the loop down.
            loop.run_until_complete(agen.aclose())
            loop.close()
    return Response(generate(), mimetype='text/event-stream')
# Script entry point: serve the Flask app through Hypercorn's asyncio
# worker instead of the Flask development server.
if __name__ == "__main__":
    print("Starting async server...")
    config = Config()
    # Listen on all interfaces, port 4040.
    config.bind = ["0.0.0.0:4040"]
    asyncio.run(serve(app, config))

View file

@ -0,0 +1,367 @@
from copy import deepcopy
from datetime import datetime
import logging
from .helpers.access import (
get_agent_by_name,
get_external_tools,
)
from .helpers.state import (
construct_state_from_response
)
from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name
from .swarm_wrapper import run as swarm_run, run_streamed as swarm_run_streamed, create_response, get_agents
from src.utils.common import common_logger as logger
import asyncio
# NOTE(review): this raises the *shared* common_logger to INFO at import
# time (not a dedicated logger, despite the original comment) — confirm
# that reconfiguring the global logger here is intended.
logger.setLevel(logging.INFO)
print("Logger level set to INFO")
def order_messages(messages):
    """Return a copy of *messages* with each message's keys reordered.

    None-valued entries are dropped. The keys 'role', 'sender', 'content',
    'created_at' and 'timestamp' (when present) come first, in that order,
    followed by every remaining key in alphabetical order.
    """
    PREFERRED_ORDER = ('role', 'sender', 'content', 'created_at', 'timestamp')
    result = []
    for original in messages:
        # Drop None values before reordering.
        cleaned = {k: v for k, v in original.items() if v is not None}
        reordered = {k: cleaned[k] for k in PREFERRED_ORDER if k in cleaned}
        for extra_key in sorted(set(cleaned) - set(reordered)):
            reordered[extra_key] = cleaned[extra_key]
        result.append(reordered)
    return result
def clean_up_history(agent_data):
    """Normalise key ordering in every agent's stored history.

    Each entry's "history" list is rewritten through order_messages; the
    (mutated) agent_data list is returned for convenience.
    """
    for entry in agent_data:
        entry["history"] = order_messages(entry["history"])
    return agent_data
def create_final_response(response, turn_messages, tokens_used, all_agents):
    """Shape *response* into the (messages, tokens_used, state) triple
    returned by run_turn.

    Defensively guarantees the response carries a message list and a
    dict-shaped tokens_used, and that an agent is attached so state
    construction cannot fail.
    """
    # Attach this turn's messages, creating the attribute when absent.
    if not hasattr(response, 'messages'):
        response.messages = []
    response.messages = turn_messages
    # Fall back to placeholder counts when the value is malformed.
    if not isinstance(tokens_used, dict):
        tokens_used = {"total": 100, "prompt": 50, "completion": 50}
    if not hasattr(response, 'tokens_used') or not isinstance(response.tokens_used, dict):
        response.tokens_used = {}
    response.tokens_used = tokens_used
    # State construction reads response.agent; default to the first agent.
    if not hasattr(response, 'agent') and all_agents:
        response.agent = all_agents[0]
    new_state = construct_state_from_response(response, all_agents)
    return response.messages, response.tokens_used, new_state
def run_turn(
    messages, start_agent_name, agent_configs, tool_configs, start_turn_with_start_agent, state={}, additional_tool_configs=[], complete_request={}
):
    """
    Coordinates a single 'turn' of conversation or processing among agents.
    Includes validation, agent setup, optional greeting logic, error handling, and post-processing steps.

    Returns the (messages, tokens_used, state) triple built by
    create_final_response.

    NOTE(review): `state`, `additional_tool_configs` and `complete_request`
    are mutable default arguments shared across calls — confirm no call
    path mutates them when the defaults are used.
    """
    logger.info("Running stateless turn")
    print("Running stateless turn")
    # Sort messages by the specified ordering
    #messages = order_messages(messages)
    # Merge any additional tool configs
    tool_configs = tool_configs + additional_tool_configs
    # Determine if this is a greeting turn
    # (a greeting turn = the history contains only system messages).
    greeting_turn = not any(msg.get("role") != "system" for msg in messages)
    turn_messages = []
    # Initialize tokens_used as a dictionary
    tokens_used = {"total": 0, "prompt": 0, "completion": 0}
    agent_data = state.get("agent_data", [])
    # If not a greeting turn, localize the last user or system messages
    if not greeting_turn:
        latest_assistant_msg = get_latest_assistant_msg(messages)
        latest_non_assistant_msgs = get_latest_non_assistant_messages(messages)
        msg_type = latest_non_assistant_msgs[-1]["role"]
        # Determine the last agent from state/config
        last_agent_name = get_last_agent_name(
            state=state,
            agent_configs=agent_configs,
            start_agent_name=start_agent_name,
            msg_type=msg_type,
            latest_assistant_msg=latest_assistant_msg,
            start_turn_with_start_agent=start_turn_with_start_agent
        )
    else:
        # For a greeting turn, we assume the last agent is the start_agent_name
        last_agent_name = start_agent_name
    state["agent_data"] = agent_data
    # Initialize all agents
    logger.info("Initializing agents")
    print("Initializing agents")
    new_agents = get_agents(
        agent_configs=agent_configs,
        tool_configs=tool_configs,
        complete_request=complete_request
    )
    # Prepare escalation agent
    last_new_agent = get_agent_by_name(last_agent_name, new_agents)
    # Gather external tools for Swarm
    external_tools = get_external_tools(tool_configs)
    logger.info(f"Found {len(external_tools)} external tools")
    print(f"Found {len(external_tools)} external tools")
    # If no validation error yet, proceed with the main run
    logger.info("Running swarm run")
    print("Running swarm run")
    response = swarm_run(
        agent=last_new_agent,
        messages=messages,
        external_tools=external_tools,
        tokens_used=tokens_used
    )
    logger.info("Swarm run completed")
    print("Swarm run completed")
    # Initialize response.messages if it doesn't exist
    if not hasattr(response, 'messages'):
        response.messages = []
    # Convert the ResponseOutputMessage to a standard message format
    if hasattr(response, 'new_items') and response.new_items and hasattr(response.new_items[-1], 'raw_item'):
        raw_item = response.new_items[-1].raw_item
        # Extract text content from ResponseOutputText objects
        content = ""
        if hasattr(raw_item, 'content') and raw_item.content:
            for content_item in raw_item.content:
                if hasattr(content_item, 'text'):
                    content += content_item.text
        # Create a standard message dictionary
        standard_message = {
            "role": raw_item.role if hasattr(raw_item, 'role') else "assistant",
            "content": content,
            "sender": last_new_agent.name,
            "created_at": None,
            "response_type": "internal"
        }
        # Add the converted message to response messages
        response.messages.append(standard_message)
        logger.info("Converted message added to response messages")
        print("Converted message added to response messages")
    # Use a dictionary for tokens_used instead of a hard-coded integer
    # NOTE(review): this overwrites the counts accumulated through swarm_run
    # with placeholder values — real usage accounting looks unimplemented here.
    tokens_used = {"total": 100, "prompt": 50, "completion": 50}  # Dummy values as placeholders
    # Ensure turn_messages can be extended with response.messages
    if hasattr(response, 'messages') and isinstance(response.messages, list):
        turn_messages.extend(response.messages)
        logger.info(f"Completed run of agent: {last_new_agent.name}")
        print(f"Completed run of agent: {last_new_agent.name}")
    # Otherwise, duplicate the last response as external
    # NOTE(review): no post-processing-agent lookup happens above, so this
    # path always runs — the last internal message is always duplicated as
    # an external one. Confirm intended.
    logger.info("No post-processing agent found. Duplicating last response and setting to external.")
    print("No post-processing agent found. Duplicating last response and setting to external.")
    if turn_messages:
        duplicate_msg = deepcopy(turn_messages[-1])
        duplicate_msg["response_type"] = "external"
        duplicate_msg["sender"] += " >> External"
        # Ensure tokens_used remains a proper dictionary
        if not isinstance(tokens_used, dict):
            tokens_used = {"total": 100, "prompt": 50, "completion": 50}  # Default values if not a dictionary
        response = create_response(
            messages=[duplicate_msg],
            tokens_used=tokens_used,
            agent=last_new_agent,
            error_msg=''
        )
        # Ensure response has messages attribute
        if hasattr(response, 'messages') and isinstance(response.messages, list):
            turn_messages.extend(response.messages)
    # Finalize the response
    logger.info("Finalizing response")
    print("Finalizing response")
    return create_final_response(
        response=response,
        turn_messages=turn_messages,
        tokens_used=tokens_used,
        all_agents=new_agents
    )
async def run_turn_streamed(
    messages,
    start_agent_name,
    agent_configs,
    tool_configs,
    start_turn_with_start_agent,
    state={},
    additional_tool_configs=[],
    complete_request={}
):
    """
    Streamed counterpart of run_turn.

    Async generator yielding ('message', payload) tuples for each agent
    event (agent switches, tool calls, tool outputs, model messages), then
    a final ('done', {'state': ...}) tuple. On failure it yields
    ('error', {'error': ..., 'state': ...}) instead of raising.

    NOTE(review): mutable default arguments are shared across calls; also
    `additional_tool_configs` is accepted but never merged into
    `tool_configs` here, unlike run_turn — confirm intended.
    """
    final_state = None  # Initialize outside try block so the error path can report it
    try:
        # Initialize agents and get external tools
        new_agents = get_agents(agent_configs=agent_configs, tool_configs=tool_configs, complete_request=complete_request)
        last_agent_name = get_last_agent_name(
            state=state,
            agent_configs=agent_configs,
            start_agent_name=start_agent_name,
            msg_type="user",
            latest_assistant_msg=None,
            start_turn_with_start_agent=start_turn_with_start_agent
        )
        last_new_agent = get_agent_by_name(last_agent_name, new_agents)
        external_tools = get_external_tools(tool_configs)
        current_agent = last_new_agent
        # Cumulative token usage across all streamed responses in this turn.
        tokens_used = {"total": 0, "prompt": 0, "completion": 0}
        stream_result = await swarm_run_streamed(
            agent=last_new_agent,
            messages=messages,
            external_tools=external_tools,
            tokens_used=tokens_used
        )
        # Process streaming events
        async for event in stream_result.stream_events():
            # print('='*50)
            # print("Received event: ", event)
            # print('-'*50)
            # Handle raw response events and accumulate tokens
            if event.type == "raw_response_event":
                if hasattr(event.data, 'type') and event.data.type == "response.completed":
                    if hasattr(event.data.response, 'usage'):
                        tokens_used["total"] += event.data.response.usage.total_tokens
                        tokens_used["prompt"] += event.data.response.usage.input_tokens
                        tokens_used["completion"] += event.data.response.usage.output_tokens
                        print('-'*50)
                        print(f"Found usage information. Updated cumulative tokens: {tokens_used}")
                        print('-'*50)
                continue
            # Update current agent when it changes
            elif event.type == "agent_updated_stream_event":
                current_agent = event.new_agent
                # Surface the hand-off to the client as an internal message.
                message = {
                    'content': f"Agent changed to {current_agent.name}",
                    'role': 'assistant',
                    'sender': current_agent.name,
                    'tool_calls': None,
                    'tool_call_id': None,
                    'response_type': 'internal'
                }
                print("Yielding message: ", message)
                yield ('message', message)
                continue
            # Handle run items (tools, messages, etc)
            elif event.type == "run_item_stream_event":
                if event.item.type == "tool_call_item":
                    # Tool invocation requested by the model.
                    message = {
                        'content': None,
                        'role': 'assistant',
                        'sender': current_agent.name if current_agent else None,
                        'tool_calls': [{
                            'function': {
                                'name': event.item.raw_item.name,
                                'arguments': event.item.raw_item.arguments
                            },
                            'id': event.item.raw_item.id,
                            'type': 'function'
                        }],
                        'tool_call_id': None,
                        'tool_name': None,
                        'response_type': 'internal'
                    }
                    print("Yielding message: ", message)
                    yield ('message', message)
                elif event.item.type == "tool_call_output_item":
                    # Result of a tool call, echoed back as a 'tool' message.
                    message = {
                        'content': str(event.item.output),
                        'role': 'tool',
                        'sender': None,
                        'tool_calls': None,
                        'tool_call_id': event.item.raw_item['call_id'],
                        'tool_name': event.item.raw_item.get('name', None),
                        'response_type': 'internal'
                    }
                    print("Yielding message: ", message)
                    yield ('message', message)
                elif event.item.type == "message_output_item":
                    # Final model text for this item; concatenate text parts.
                    content = ""
                    if hasattr(event.item.raw_item, 'content'):
                        for content_item in event.item.raw_item.content:
                            if hasattr(content_item, 'text'):
                                content += content_item.text
                    message = {
                        'content': content,
                        'role': 'assistant',
                        'sender': current_agent.name,
                        'tool_calls': None,
                        'tool_call_id': None,
                        'tool_name': None,
                        'response_type': 'external'
                    }
                    print("Yielding message: ", message)
                    yield ('message', message)
                print(f"\n{'='*50}\n")
        # After all events are processed, set final state
        final_state = {
            "last_agent_name": current_agent.name if current_agent else None,
            "tokens": tokens_used
        }
        yield ('done', {'state': final_state})
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        print(f"Error in stream processing: {str(e)}")
        yield ('error', {'error': str(e), 'state': final_state})  # Include final_state in error response

View file

@ -0,0 +1,218 @@
# Guardrails
from src.utils.common import generate_llm_output
import os
import copy
from .swarm_wrapper import Agent, Response, create_response
from src.utils.common import common_logger, generate_openai_output, update_tokens_used
logger = common_logger
def classify_hallucination(context: str, assistant_response: str, chat_history: list, model: str) -> str:
    """
    Checks if an assistant's response contains hallucinations by comparing against provided context.
    Args:
        context (str): The context/knowledge base to check the response against
        assistant_response (str): The response from the assistant to validate
        chat_history (list): List of previous chat messages for context
        model (str): Identifier of the LLM used to produce the verdict
    Returns:
        str: Verdict indicating level of hallucination:
            'yes-absolute' - completely supported by context
            'yes-common-sensical' - supported with common sense interpretation
            'no-absolute' - not supported by context
            'no-subtle' - not supported but difference is subtle
    """
    # Flatten the chat history into a plain-text transcript for the prompt.
    chat_history_str = "\n".join([f"{message['role']}: {message['content']}" for message in chat_history])
    # Fixed vs. original: "Output of of the classes" -> "Output one of the classes".
    prompt = f"""
You are a guardrail agent. Your job is to check if the response is hallucinating.
------------------------------------------------------------------------
Here is the context:
{context}
------------------------------------------------------------------------
Here is the chat history message:
{chat_history_str}
------------------------------------------------------------------------
Here is the response:
{assistant_response}
------------------------------------------------------------------------
As a hallucination guardrail, your job is to go through each line of the response and check if it is completely supported by the context. Even if a single line is not supported, the response is no.
Output a single verdict for the entire response. don't provide any reasoning. The output classes are
yes-absolute: completely supported by the context
yes-common-sensical: but with some common sense interpretation
no-absolute: not supported by the context
no-subtle: not supported by the context but the difference is subtle
Output one of the classes:
verdict : yes-absolute/yes-common-sensical/no-absolute/no-subtle
Example 1: The response is completely supported by the context.
User Input:
Context: "Our airline provides complimentary meals and beverages on all international flights. Passengers are allowed one carry-on bag and one personal item."
Chat History:
User: "Do international flights with your airline offer free meals?"
Response: "Yes, all international flights with our airline offer free meals and beverages."
Output: verdict: yes-absolute
Example 2: The response is generally true and could be deduced with common sense interpretation, though not explicitly stated in the context.
User Input:
Context: "Flights may experience delays due to weather conditions. In such cases, the airline staff will provide updates at the airport."
Chat History:
User: "Will there be announcements if my flight is delayed?"
Response: "Yes, if your flight is delayed, there will be announcements at the airport."
Output: verdict: yes-common-sensical
Example 3: The response is not supported by the context and contains glaring inaccuracies.
User Input:
Context: "You can cancel your ticket online up to 24 hours before the flight's departure time and receive a full refund."
Chat History:
User: "Can I get a refund if I cancel 12 hours before the flight?"
Response: "Yes, you can get a refund if you cancel 12 hours before the flight."
Output: verdict: no-absolute
Example 4: The response is not supported by the context but the difference is subtle.
User Input:
Context: "Our frequent flyer program offers discounts on checked bags for members who have achieved Gold status."
Chat History:
User: "As a member, do I get discounts on checked bags?"
Response: "Yes, members of our frequent flyer program get discounts on checked bags."
Output: verdict: no-subtle
"""
    messages = [
        {
            "role": "system",
            "content": prompt,
        },
    ]
    response = generate_llm_output(messages, model)
    return response
def post_process_response(messages: list, post_processing_agent_name: str, post_process_instructions: str, style_prompt: str = None, context: str = None, model: str = "gpt-4o", tokens_used: dict = None, last_agent: Agent = None) -> Response:
    """
    Rewrite the latest internal agent message according to the configured
    post-processing instructions (plus optional style prompt and context).

    Args:
        messages: Turn messages; only the last one is (possibly) rewritten.
        post_processing_agent_name: Name appended to the sender trail.
        post_process_instructions: Admin-configured rewrite instructions.
        style_prompt: Optional style guidance appended to the prompt.
        context: Optional knowledge context appended to the prompt.
        model: OpenAI model used for the rewrite.
        tokens_used: Running token-usage dict; a fresh dict is created when
            omitted (previously a shared mutable default argument).
        last_agent: Agent that produced the pending message; its
            instructions and history feed the rewrite prompt.

    Returns:
        Response: carries the rewritten message, or no messages when
        post-processing was skipped (the skipped message is marked
        external in place).
    """
    if tokens_used is None:
        tokens_used = {}
    agent_instructions = last_agent.instructions
    agent_history = last_agent.history
    pending_msg = copy.deepcopy(messages[-1])
    logger.debug(f"Pending message keys: {pending_msg.keys()}")
    # Decide whether post-processing applies. Using .get() here avoids a
    # KeyError on partially-populated messages (bracket access previously
    # could raise for 'response_type'/'content').
    skip = False
    if pending_msg.get("tool_calls"):
        logger.info("Last message is a tool call, skipping post processing and setting last message to external")
        skip = True
    elif pending_msg.get('response_type') != "internal":
        logger.info("Last message is not internal, skipping post processing and setting last message to external")
        skip = True
    elif not pending_msg.get('content'):
        logger.info("Last message has no content, skipping post processing and setting last message to external")
        skip = True
    elif not post_process_instructions:
        logger.info("No post process instructions, skipping post processing and setting last message to external")
        skip = True
    if skip:
        pending_msg['response_type'] = "external"
        response = Response(
            messages=[],
            tokens_used=tokens_used,
            agent=last_agent,
            error_msg=''
        )
        return response
    # All history except the pending message, flattened for the prompt.
    agent_history_str = f"\n{'*'*100}\n".join([f"Role: {message['role']} | Content: {message.get('content', 'None')} | Tool Calls: {message.get('tool_calls', 'None')}" for message in agent_history[:-1]])
    logger.debug(f"Agent history: {agent_history_str}")
    prompt = f"""
# ROLE
You are a post processing agent responsible for rewriting a response generated by an agent, according to instructions provided below. Ensure that the response you produce adheres to the instructions provided to you (if any).
------------------------------------------------------------------------
# ADDITIONAL INSTRUCTIONS
Here are additional instructions that the admin might have configured for you:
{post_process_instructions}
------------------------------------------------------------------------
# CHAT HISTORY
Here is the chat history:
{agent_history_str}
"""
    if context:
        context_prompt = f"""
------------------------------------------------------------------------
# CONTEXT
Here is the context:
{context}
"""
        prompt += context_prompt
    if style_prompt:
        style_prompt = f"""
------------------------------------------------------------------------
# STYLE PROMPT
Here is the style prompt:
{style_prompt}
"""
        prompt += style_prompt
    agent_response_and_instructions = f"""
------------------------------------------------------------------------
# AGENT INSTRUCTIONS
Here are the instructions to the agent generating the response:
{agent_instructions}
------------------------------------------------------------------------
# AGENT RESPONSE
Here is the response that the agent has generated:
{pending_msg['content']}
"""
    prompt += agent_response_and_instructions
    logger.debug(f"Sanitizing response for style. Original response: {pending_msg['content']}")
    completion = generate_openai_output(
        messages=[
            {"role": "system", "content": prompt}
        ],
        model=model,
        return_completion=True
    )
    content = completion.choices[0].message.content
    if content:
        # .strip() already trims both ends; the previous
        # .strip().lstrip().rstrip() chain was redundant.
        content = content.strip()
    tokens_used = update_tokens_used(provider="openai", model=model, tokens_used=tokens_used, completion=completion)
    logger.debug(f"Response after style check: {content}, tokens used: {tokens_used}")
    # Keep the original text when the model returned nothing.
    pending_msg['content'] = content if content else pending_msg['content']
    pending_msg['response_type'] = "external"
    pending_msg['sender'] = pending_msg['sender'] + f' >> {post_processing_agent_name}'
    response = Response(
        messages=[pending_msg],
        tokens_used=tokens_used,
        agent=last_agent,
        error_msg=''
    )
    return response

View file

@ -0,0 +1,48 @@
from src.utils.common import common_logger
logger = common_logger
def get_external_tools(tool_configs):
    """Extract the tool names from a list of tool config dicts."""
    logger.debug("Getting external tools")
    tool_names = []
    for cfg in tool_configs:
        tool_names.append(cfg["name"])
    logger.debug(f"Found {len(tool_names)} external tools")
    return tool_names
def get_agent_by_name(agent_name, agents):
    """Return the agent whose .name equals *agent_name*.

    Raises ValueError (after logging) when no such agent exists.
    """
    for candidate in agents:
        if getattr(candidate, "name", None) == agent_name:
            return candidate
    logger.error(f"Agent with name {agent_name} not found")
    raise ValueError(f"Agent with name {agent_name} not found")
def get_agent_config_by_name(agent_name, agent_configs):
    """Return the agent config dict whose 'name' equals *agent_name*.

    Raises ValueError (after logging) when no matching config exists.
    """
    for candidate in agent_configs:
        if candidate.get("name") == agent_name:
            return candidate
    logger.error(f"Agent config with name {agent_name} not found")
    raise ValueError(f"Agent config with name {agent_name} not found")
def pop_agent_config_by_type(agent_configs, agent_type):
    """Remove and return the first config with the given 'type'.

    Returns (config_or_None, agent_configs); the list is mutated in place
    when a match is found.
    """
    found = None
    for cfg in agent_configs:
        if cfg.get("type") == agent_type:
            found = cfg
            break
    if found:
        agent_configs.remove(found)
    return found, agent_configs
def get_agent_by_type(agents, agent_type):
    """Return the first agent whose .type matches *agent_type*, else None."""
    for candidate in agents:
        if candidate.type == agent_type:
            return candidate
    return None
def get_prompt_by_type(prompt_configs, prompt_type):
    """Return the 'prompt' of the first config whose 'type' matches, else None."""
    for cfg in prompt_configs:
        if cfg.get("type") == prompt_type:
            return cfg.get("prompt")
    return None
def get_agent_data_by_name(agent_name, agent_data):
    """Return the agent-data dict whose 'name' equals *agent_name*, else None."""
    return next((entry for entry in agent_data if entry.get("name", "") == agent_name), None)
def get_tool_config_by_name(tool_configs, tool_name):
    """Return the first tool config whose 'name' matches *tool_name*, else None."""
    for cfg in tool_configs:
        if cfg.get("name", "") == tool_name:
            return cfg
    return None
def get_tool_config_by_type(tool_configs, tool_type):
    """Return the first tool config whose 'type' matches *tool_type*, else None."""
    for cfg in tool_configs:
        if cfg.get("type", "") == tool_type:
            return cfg
    return None

View file

@ -0,0 +1,50 @@
from .access import get_agent_config_by_name, get_agent_data_by_name
from src.graph.types import ControlType
from src.utils.common import common_logger
logger = common_logger
def get_last_agent_name(state, agent_configs, start_agent_name, msg_type, latest_assistant_msg, start_turn_with_start_agent):
    """
    Resolve which agent should handle the upcoming turn.

    Starts from state["last_agent_name"] and applies overrides:
    - tool messages stay with the agent that issued the tool call;
    - start_turn_with_start_agent forces the configured start agent;
    - otherwise the last agent's controlType decides (PARENT_AGENT,
      START_AGENT, or retain the same agent).

    NOTE(review): the tool-message branch validates with `assert`, which
    is stripped under `python -O` — consider raising explicitly instead.
    """
    default_last_agent_name = state.get("last_agent_name", '')
    last_agent_config = get_agent_config_by_name(default_last_agent_name, agent_configs)
    specific_agent_data = get_agent_data_by_name(default_last_agent_name, state.get("agent_data", []))
    # Overrides for special cases
    logger.info("Setting agent control based on last agent and control type")
    if msg_type == "tool":
        # A tool result must be routed back to the agent that requested it.
        last_agent_name = default_last_agent_name
        assert last_agent_name == latest_assistant_msg.get("sender", ''), "Last agent name does not match sender of latest assistant message during tool call handling"
    elif start_turn_with_start_agent:
        last_agent_name = start_agent_name
    else:
        control_type = last_agent_config.get("controlType", ControlType.RETAIN.value)
        if control_type == ControlType.PARENT_AGENT.value:
            last_agent_name = specific_agent_data.get("most_recent_parent_name", None) if specific_agent_data else None
            if not last_agent_name:
                # Fall back to the same agent when no parent is recorded.
                logger.error("Most recent parent is empty, defaulting to same agent instead")
                last_agent_name = default_last_agent_name
        elif control_type == ControlType.START_AGENT.value:
            last_agent_name = start_agent_name
        else:
            last_agent_name = default_last_agent_name
    if default_last_agent_name != last_agent_name:
        logger.info(f"Last agent name changed from {default_last_agent_name} to {last_agent_name} due to control settings")
    return last_agent_name
def get_latest_assistant_msg(messages):
    """Return the most recent message with role 'assistant', or None."""
    for msg in reversed(messages):
        if msg.get("role") == "assistant":
            return msg
    return None
def get_latest_non_assistant_messages(messages):
    """Return the suffix of *messages* that follows the last assistant message.

    When no assistant message exists the whole list is returned unchanged.
    """
    idx = len(messages) - 1
    while idx >= 0:
        if messages[idx].get("role") == "assistant":
            return messages[idx + 1:]
        idx -= 1
    return messages

View file

@ -0,0 +1,39 @@
from src.graph.instructions import TRANSFER_CHILDREN_INSTRUCTIONS, TRANSFER_PARENT_AWARE_INSTRUCTIONS, RAG_INSTRUCTIONS, ERROR_ESCALATION_AGENT_INSTRUCTIONS, TRANSFER_GIVE_UP_CONTROL_INSTRUCTIONS, SYSTEM_MESSAGE
def add_transfer_instructions_to_parent_agents(agent, children, transfer_functions):
    """Append child-transfer instructions to *agent*'s instruction text.

    Builds one Name/Description/Tool entry per child agent and formats it
    into TRANSFER_CHILDREN_INSTRUCTIONS, then concatenates the result onto
    the agent's instructions with a divider. Returns the mutated agent.
    """
    entry_divider = '\n' + '-' * 100 + '\n'
    child_entries = []
    # Note: loop variable renamed from the original's shadowing of `agent`.
    for child in children.values():
        child_desc = child.description if child.description else ''
        child_entries.append(
            f"Name: {child.name}\nDescription: {child_desc}\nTool for transfer: {transfer_functions[child.name].__name__}"
        )
    other_agent_name_descriptions_tools = entry_divider.join(child_entries)
    prompt = TRANSFER_CHILDREN_INSTRUCTIONS.format(other_agent_name_descriptions_tools=other_agent_name_descriptions_tools)
    agent.instructions = agent.instructions + '\n\n' + '-' * 100 + '\n\n' + prompt
    return agent
def add_transfer_instructions_to_child_agents(child, children_aware_of_parent):
    """Append parent-transfer (or give-up-control) instructions to *child*.

    When the child knows its parents, each candidate parent is listed by
    name and transfer tool; otherwise only the de-duplicated transfer
    tools are listed under give-up-control wording. Returns the mutated child.
    """
    entry_divider = '\n' + '-' * 100 + '\n'
    if children_aware_of_parent:
        entries = [
            f"Name: {parent_name}\nTool for transfer: {func.__name__}"
            for parent_name, func in child.candidate_parent_functions.items()
        ]
        candidate_parents_name_description_tools = entry_divider.join(entries)
        prompt = TRANSFER_PARENT_AWARE_INSTRUCTIONS.format(candidate_parents_name_description_tools=candidate_parents_name_description_tools)
    else:
        # set() de-duplicates tools shared by several parents (order, as in
        # the original, is therefore unspecified).
        entries = list(set(
            f"Tool for transfer: {func.__name__}"
            for _, func in child.candidate_parent_functions.items()
        ))
        candidate_parents_name_description_tools = entry_divider.join(entries)
        prompt = TRANSFER_GIVE_UP_CONTROL_INSTRUCTIONS.format(candidate_parents_name_description_tools=candidate_parents_name_description_tools)
    child.instructions = child.instructions + '\n\n' + '-' * 100 + '\n\n' + prompt
    return child
def add_rag_instructions_to_agent(agent_config, rag_tool_name):
    """Append RAG usage instructions (bound to *rag_tool_name*) to a config dict."""
    rag_prompt = RAG_INSTRUCTIONS.format(rag_tool_name=rag_tool_name)
    separator = '\n\n' + '-' * 100 + '\n\n'
    agent_config["instructions"] = agent_config["instructions"] + separator + rag_prompt
    return agent_config
def add_error_escalation_instructions(agent):
    """Append the error-escalation instructions to *agent*'s instruction text."""
    separator = '\n\n' + '-' * 100 + '\n\n'
    agent.instructions = agent.instructions + separator + ERROR_ESCALATION_AGENT_INSTRUCTIONS
    return agent
def get_universal_system_message(messages):
    """Return the formatted universal system message when the first message
    has role 'system'; otherwise an empty string."""
    if not messages or messages[0].get("role") != "system":
        return ""
    return SYSTEM_MESSAGE.format(system_message=messages[0].get("content"))
def add_universal_system_message_to_agent(agent, universal_sys_msg):
    """Append *universal_sys_msg* to the agent's instructions, divider-separated."""
    separator = '\n\n' + '-' * 100 + '\n\n'
    agent.instructions = agent.instructions + separator + universal_sys_msg
    return agent

View file

@ -0,0 +1,60 @@
from src.utils.common import common_logger
logger = common_logger
from .access import get_agent_data_by_name
def reset_current_turn(messages):
    """Mark only the final message as belonging to the current turn.

    Every message's 'current_turn' flag is cleared, then the last message
    (when one exists) is flagged True. Previously an empty list raised
    IndexError on `messages[-1]`.
    Returns the mutated list.
    """
    for msg in messages:
        msg["current_turn"] = False
    if messages:
        # The most recent message starts the new turn.
        messages[-1]["current_turn"] = True
    return messages
def reset_current_turn_agent_history(agent_data, agent_names):
    """Clear the 'current_turn' flag on every history message of the named agents.

    Agents without a data entry are silently skipped. Returns the mutated list.
    """
    for agent_name in agent_names:
        entry = get_agent_data_by_name(agent_name, agent_data)
        if not entry:
            continue
        for history_msg in entry["history"]:
            history_msg["current_turn"] = False
    return agent_data
def add_recent_messages_to_history(recent_messages, last_agent_name, agent_data, messages, parent_has_child_history):
    """
    Append this turn's recent messages to the responsible agent's history,
    optionally mirroring them up the parent chain.

    NOTE(review): when the agent has no existing entry, only the final
    message (messages[-1]) is stored rather than all of recent_messages —
    confirm this asymmetry is intended. Also assumes `messages` is
    non-empty and the most_recent_parent_name chain is acyclic (a cycle
    would loop forever).
    """
    last_msg = messages[-1]
    specific_agent_data = get_agent_data_by_name(last_agent_name, agent_data)
    if specific_agent_data:
        specific_agent_data["history"].extend(recent_messages)
        if parent_has_child_history:
            # Walk up the parent chain, mirroring the messages into each
            # ancestor's history.
            current_agent_data = specific_agent_data
            while current_agent_data.get("most_recent_parent_name"):
                parent_name = current_agent_data.get("most_recent_parent_name")
                parent_agent_data = get_agent_data_by_name(parent_name, agent_data)
                if parent_agent_data:
                    parent_agent_data["history"].extend(recent_messages)
                    current_agent_data = parent_agent_data
                else:
                    logger.error(f"Parent agent data for {current_agent_data['name']} not found in agent_data")
                    raise ValueError(f"Parent agent data for {current_agent_data['name']} not found in agent_data")
    else:
        # First message for this agent: create its history entry.
        agent_data.append({
            "name": last_agent_name,
            "history": [last_msg]
        })
    return agent_data
def construct_state_from_response(response, agents):
    """Build the serializable turn state from a run response.

    Captures which agent answered last plus each agent's name and (possibly
    augmented) instructions.
    """
    snapshot = [
        {"name": current.name, "instructions": current.instructions}
        for current in agents
    ]
    return {
        "last_agent_name": response.agent.name,
        "agent_data": snapshot,
    }

View file

@ -0,0 +1,44 @@
from src.utils.common import common_logger
logger = common_logger
def create_transfer_function_to_agent(agent):
    """Build a tool function that hands the chat over to *agent*.

    The generated callable is named transfer_to_<snake_cased_agent_name>,
    ignores its arguments, and returns the agent object when invoked.
    """
    target = agent
    target_name = agent.name
    fn_name = f"transfer_to_{target_name.lower().replace(' ', '_')}"

    def transfer(*args, **kwargs):
        logger.info(f"Transferring chat to {target_name}")
        return target

    transfer.__name__ = fn_name
    transfer.__doc__ = f"Function to transfer the chat to {target_name}."
    return transfer
def create_transfer_function_to_parent_agent(parent_agent, children_aware_of_parent, transfer_functions):
    """Build the tool function a child agent uses to return chat control.

    When children know about their parent, the function is named after the
    parent's transfer function with a _from_child suffix; otherwise it is the
    generic give_up_chat_control tool. Invoking it returns the parent agent.
    """
    parent = parent_agent
    if children_aware_of_parent:
        fn_name = f"{transfer_functions[parent.name].__name__}_from_child"
        fn_doc = f"Function to transfer the chat to your parent agent: {parent.name}."
    else:
        fn_name = "give_up_chat_control"
        fn_doc = "Function to give up control of the chat when you are unable to handle it."

    def transfer(*args, **kwargs):
        logger.info(f"Transferring chat to parent agent: {parent.name}")
        return parent

    transfer.__name__ = fn_name
    transfer.__doc__ = fn_doc
    return transfer

View file

@ -0,0 +1,70 @@
########################
# Instructions for agents that use RAG
########################
# NOTE: all constants below are str.format() templates ({rag_tool_name} etc.
# are filled in later). The original f-string prefixes were no-ops that only
# forced doubled braces, so they are written as plain literals here — the
# resulting string values are identical.
RAG_INSTRUCTIONS = """
# Instructions about using the article retrieval tool
- Where relevant, use the articles tool: {rag_tool_name} to fetch articles with knowledge relevant to the query and use its contents to respond to the user.
- Do not send a separate message first asking the user to wait while you look up information. Immediately fetch the articles and respond to the user with the answer to their query.
- Do not make up information. If the article's contents do not have the answer, give up control of the chat (or transfer to your parent agent, as per your transfer instructions). Do not say anything to the user.
"""
########################
# Instructions for child agents that are aware of parent agents
########################
TRANSFER_PARENT_AWARE_INSTRUCTIONS = """
# Instructions about using your parent agents
You have the following candidate parent agents that you can transfer the chat to, using the appropriate tool calls for the transfer:
{candidate_parents_name_description_tools}.
## Notes:
- During runtime, you will be provided with a tool call for exactly one of these parent agents that you can use. Use that tool call to transfer the chat to the parent agent in case you are unable to handle the chat (e.g. if it is not in your scope of instructions).
- Transfer the chat to the appropriate agent, based on the chat history and / or the user's request.
- When you transfer the chat to another agent, you should not provide any response to the user. For example, do not say 'Transferring chat to X agent' or anything like that. Just invoke the tool call to transfer to the other agent.
- Do NOT ever mention the existence of other agents. For example, do not say 'Please check with X agent for details regarding processing times.' or anything like that.
- If any other agent transfers the chat to you without responding to the user, it means that they don't know how to help. Do not transfer the chat to back to the same agent in this case. In such cases, you should transfer to the escalation agent using the appropriate tool call. Never ask the user to contact support.
"""
########################
# Instructions for child agents that give up control to parent agents
########################
TRANSFER_GIVE_UP_CONTROL_INSTRUCTIONS = """
# Instructions about giving up chat control
If you are unable to handle the chat (e.g. if it is not in your scope of instructions), you should use the tool call provided to give up control of the chat.
{candidate_parents_name_description_tools}
## Notes:
- When you give up control of the chat, you should not provide any response to the user. Just invoke the tool call to give up control.
"""
########################
# Instructions for parent agents that need to transfer the chat to other specialized (children) agents
########################
TRANSFER_CHILDREN_INSTRUCTIONS = """
# Instructions about using other specialized agents
You have the following specialized agents that you can transfer the chat to, using the appropriate tool calls for the transfer:
{other_agent_name_descriptions_tools}
## Notes:
- Transfer the chat to the appropriate agent, based on the chat history and / or the user's request.
- When you transfer the chat to another agent, you should not provide any response to the user. For example, do not say 'Transferring chat to X agent' or anything like that. Just invoke the tool call to transfer to the other agent.
- Do NOT ever mention the existence of other agents. For example, do not say 'Please check with X agent for details regarding processing times.' or anything like that.
- If any other agent transfers the chat to you without responding to the user, it means that they don't know how to help. Do not transfer the chat to back to the same agent in this case. In such cases, you should transfer to the escalation agent using the appropriate tool call. Never ask the user to contact support.
"""
########################
# Additional instruction for escalation agent when called due to an error
########################
ERROR_ESCALATION_AGENT_INSTRUCTIONS = """
# Context
The rest of the parts of the chatbot were unable to handle the chat. Hence, the chat has been escalated to you. In addition to your other instructions, tell the user that you are having trouble handling the chat - say "I'm having trouble helping with your request. Sorry about that.". Remember you are a part of the chatbot as well.
"""
########################
# Universal system message formatting
########################
SYSTEM_MESSAGE = """
# Additional System-Wide Context or Instructions:
{system_message}
"""

View file

@ -0,0 +1,391 @@
import logging
import json
import aiohttp
# Import helper functions needed for get_agents
from .helpers.access import (
get_tool_config_by_name,
get_tool_config_by_type
)
from .helpers.instructions import (
add_rag_instructions_to_agent
)
from agents import Agent as NewAgent, Runner, FunctionTool, RunContextWrapper
# Add import for OpenAI functionality
from src.utils.common import common_logger as logger, generate_openai_output
from typing import Any
from dataclasses import asdict
import asyncio
from mcp import ClientSession
from mcp.client.sse import sse_client
from pydantic import BaseModel
from typing import List, Optional, Dict
from .tool_calling import call_rag_tool
class NewResponse(BaseModel):
    """Result envelope for a single agent run."""
    # Conversation messages produced during the run.
    messages: List[Dict]
    # The agent that ended up handling the turn (project Agent type; opaque here).
    agent: Optional[Any] = None
    # Token accounting, presumably keyed by provider/model — confirm against
    # update_tokens_used. Pydantic deep-copies field defaults, so the mutable
    # {} default is safe here.
    tokens_used: Optional[dict] = {}
    # Non-empty when the run failed.
    error_msg: Optional[str] = ""
async def mock_tool(tool_name: str, args: str, tool_config: dict) -> str:
    """
    Simulates a tool execution by asking an LLM to produce a realistic response.

    Args:
        tool_name: The name of the tool
        args: The arguments passed to the tool (JSON string)
        tool_config: The configuration of the tool as a dict — reads the
            "description" and "mockInstructions" keys (the original annotated
            and documented this as str, but it is used as a dict)

    Returns:
        The simulated response text from the tool
    """
    print(f"Mock tool called for: {tool_name}")
    # For non-mocked tools, generate a realistic response
    description = tool_config.get("description", "")
    mock_instructions = tool_config.get("mockInstructions", "")
    messages = [
        {"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'.Here is the description of the tool: {description}. Here are the instructions for the mock tool: {mock_instructions}. Generate a realistic response as if the tool was actually executed with the given parameters."},
        {"role": "user", "content": f"Generate a realistic response for the tool '{tool_name}' with these parameters: {args}. The response should be concise and focused on what the tool would actually return."}
    ]
    print(f"Generating simulated response for tool: {tool_name}")
    response_content = generate_openai_output(messages, output_type='text', model="gpt-4o")
    return response_content
async def call_webhook(tool_name: str, args: str, webhook_url: str) -> str:
    """
    Calls the webhook with the given tool name and arguments.

    Args:
        tool_name (str): The name of the tool to call.
        args (str): The arguments for the tool as a JSON string.
        webhook_url (str): The endpoint to POST the tool call to.

    Returns:
        str: The "result" field of the webhook's JSON response, or an error
        message if the call fails.
    """
    payload = {
        "content": json.dumps({
            "toolCall": {
                "function": {
                    "name": tool_name,
                    "arguments": args
                }
            }
        })
    }
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(webhook_url, json=payload) as response:
                if response.status != 200:
                    error_msg = await response.text()
                    print(f"Webhook error: {error_msg}")
                    return f"Error: {error_msg}"
                body = await response.json()
                return body.get("result", "")
    except Exception as e:
        print(f"Exception in call_webhook: {str(e)}")
        return f"Error: Failed to call webhook - {str(e)}"
async def call_mcp(tool_name: str, args: str, mcp_server_url: str) -> str:
    """
    Calls the named tool on an MCP server over SSE and returns its content
    serialized as a JSON string.
    """
    async with sse_client(url=mcp_server_url) as streams:
        async with ClientSession(*streams) as session:
            await session.initialize()
            parsed_args = json.loads(args)
            result = await session.call_tool(tool_name, arguments=parsed_args)
            items = [item.__dict__ for item in result.content]
            return json.dumps(items, indent=2)
async def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool_config: dict, complete_request: dict) -> str:
    """
    Dispatches a tool invocation to the appropriate backend.

    Routing: mocked tools -> MCP tools -> generic webhook.

    Args:
        ctx: Run context wrapper (required by the tool signature; unused here).
        args: JSON-encoded tool arguments.
        tool_name: Name of the tool being invoked.
        tool_config: The tool's configuration dict (mockTool / isMcp flags).
        complete_request: Full incoming request; provides "mcpServers" and
            "toolWebhookUrl".

    Returns:
        The tool's response text.
    """
    print(f"Catch all called for tool: {tool_name}")
    print(f"Args: {args}")
    print(f"Tool config: {tool_config}")
    # NOTE: the original fetched/created an event loop here, but this coroutine
    # always runs inside an already-running loop and never used the variable —
    # removed as dead code.
    if tool_config.get("mockTool", False):
        # Mocked tool: let the LLM simulate a plausible response.
        response_content = await mock_tool(tool_name, args, tool_config)
        print(response_content)
    elif tool_config.get("isMcp", False):
        # MCP tool: resolve the server URL by name from the request.
        mcp_server_name = tool_config.get("mcpServerName", "")
        mcp_servers = complete_request.get("mcpServers", {})
        mcp_server_url = next((server.get("url", "") for server in mcp_servers if server.get("name") == mcp_server_name), "")
        response_content = await call_mcp(tool_name, args, mcp_server_url)
    else:
        # Default: forward the call to the project's tool webhook.
        webhook_url = complete_request.get("toolWebhookUrl", "")
        response_content = await call_webhook(tool_name, args, webhook_url)
    return response_content
def get_rag_tool(config: dict, complete_request: dict) -> FunctionTool:
    """
    Builds the article-retrieval (RAG) FunctionTool for an agent.

    Returns None when the agent config has no ragDataSources.
    """
    if not config.get("ragDataSources", None):
        return None
    print("getArticleInfo")
    project_id = complete_request.get("projectId", "")
    query_schema = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The query to search for"
            }
        },
        "additionalProperties": False,
        "required": [
            "query"
        ]
    }
    return FunctionTool(
        name="getArticleInfo",
        description="Get information about an article",
        params_json_schema=query_schema,
        on_invoke_tool=lambda ctx, args: call_rag_tool(project_id, json.loads(args)['query'], config.get("ragDataSources", []), "chunks", 3)
    )
def get_agents(agent_configs, tool_configs, complete_request):
    """
    Creates and initializes Agent objects based on their configurations and connections.

    Args:
        agent_configs: List of agent config dicts (name, instructions,
            description, tools, model, connectedAgents, optional hasRagSources).
        tool_configs: List of tool config dicts, looked up by name/type.
        complete_request: Full incoming request; threaded through to the tool
            dispatcher (catch_all) and RAG tool construction.

    Returns:
        List of Agent objects; each agent's `handoffs` points at the Agent
        objects of its connected children.

    Raises:
        ValueError: If either config argument is not a list.
    """
    if not isinstance(agent_configs, list):
        raise ValueError("Agents config is not a list in get_agents")
    if not isinstance(tool_configs, list):
        raise ValueError("Tools config is not a list in get_agents")
    new_agents = []
    new_agent_to_children = {}   # agent name -> list of child agent names
    new_agent_name_to_index = {} # agent name -> position in new_agents
    # Create Agent objects from config
    for agent_config in agent_configs:
        logger.debug(f"Processing config for agent: {agent_config['name']}")
        print("="*100)
        print(f"Processing config for agent: {agent_config['name']}")
        # If hasRagSources, append the RAG tool to the agent's tools
        if agent_config.get("hasRagSources", False):
            rag_tool_name = get_tool_config_by_type(tool_configs, "rag").get("name", "")
            agent_config["tools"].append(rag_tool_name)
            agent_config = add_rag_instructions_to_agent(agent_config, rag_tool_name)
        # Prepare tool lists for this agent
        # NOTE(review): external_tools is accumulated below but never used
        # after this loop.
        external_tools = []
        logger.debug(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools")
        print(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools")
        new_tools = []
        rag_tool = get_rag_tool(agent_config, complete_request)
        if rag_tool:
            new_tools.append(rag_tool)
            logger.debug(f"Added rag tool to agent {agent_config['name']}")
            print(f"Added rag tool to agent {agent_config['name']}")
        for tool_name in agent_config["tools"]:
            tool_config = get_tool_config_by_name(tool_configs, tool_name)
            if tool_config:
                external_tools.append({
                    "type": "function",
                    "function": tool_config
                })
                #TODO: Remove this once we have a way to handle the additionalProperties
                tool_config['parameters']['additionalProperties'] = False
                # The lambda binds tool_name/tool_config as default arguments so
                # each tool closes over its own values rather than the loop's
                # final iteration (late-binding pitfall).
                tool = FunctionTool(
                    name=tool_name,
                    description=tool_config["description"],
                    params_json_schema=tool_config["parameters"],
                    on_invoke_tool=lambda ctx, args, _tool_name=tool_name, _tool_config=tool_config, _complete_request=complete_request:
                        catch_all(ctx, args, _tool_name, _tool_config, _complete_request)
                )
                new_tools.append(tool)
                logger.debug(f"Added tool {tool_name} to agent {agent_config['name']}")
                print(f"Added tool {tool_name} to agent {agent_config['name']}")
            else:
                logger.warning(f"Tool {tool_name} not found in tool_configs")
                print(f"WARNING: Tool {tool_name} not found in tool_configs")
        # Create the agent object
        logger.debug(f"Creating Agent object for {agent_config['name']}")
        print(f"Creating Agent object for {agent_config['name']}")
        try:
            new_agent = NewAgent(
                name=agent_config["name"],
                instructions=agent_config["instructions"],
                handoff_description=agent_config["description"],
                tools=new_tools,
                model=agent_config["model"]
            )
            new_agent_to_children[agent_config["name"]] = agent_config.get("connectedAgents", [])
            new_agent_name_to_index[agent_config["name"]] = len(new_agents)
            new_agents.append(new_agent)
            logger.debug(f"Successfully created agent: {agent_config['name']}")
            print(f"Successfully created agent: {agent_config['name']}")
        except Exception as e:
            logger.error(f"Failed to create agent {agent_config['name']}: {str(e)}")
            print(f"ERROR: Failed to create agent {agent_config['name']}: {str(e)}")
            raise
    # Second pass: wire up handoffs now that every agent object exists.
    for new_agent in new_agents:
        # Initialize the handoffs attribute if it doesn't exist
        if not hasattr(new_agent, 'handoffs'):
            new_agent.handoffs = []
        # Look up the agent's children from the old agent and create a list called handoffs in new_agent with pointers to the children in new_agents
        new_agent.handoffs = [new_agents[new_agent_name_to_index[child]] for child in new_agent_to_children[new_agent.name]]
    print("Returning created agents")
    print("="*100)
    return new_agents
def create_response(messages=None, tokens_used=None, agent=None, error_msg=''):
    """
    Build a NewResponse, substituting safe defaults for omitted arguments.

    Args:
        messages: List of messages (defaults to a fresh []).
        tokens_used: Dictionary tracking token usage (defaults to a fresh {}).
        agent: The agent that generated the response.
        error_msg: Error message if any.

    Returns:
        NewResponse object.
    """
    return NewResponse(
        messages=messages if messages is not None else [],
        agent=agent,
        tokens_used=tokens_used if tokens_used is not None else {},
        error_msg=error_msg,
    )
def run(
    agent,
    messages,
    external_tools=None,
    tokens_used=None
):
    """
    Wrapper function for initializing and running the Swarm client.

    Normalizes incoming messages to plain role/content dicts, ensures this
    thread has an event loop, and runs the agent to completion.
    """
    logger.info(f"Initializing Swarm client for agent: {agent.name}")
    print(f"Initializing Swarm client for agent: {agent.name}")
    external_tools = external_tools if external_tools is not None else []
    tokens_used = tokens_used if tokens_used is not None else {}
    # Normalize each message for the OpenAI API; non-dict entries become
    # user messages with their string form as content.
    formatted_messages = []
    for msg in messages:
        if isinstance(msg, dict) and "content" in msg:
            formatted_messages.append({
                "role": msg.get("role", "user"),
                "content": msg["content"]
            })
        else:
            formatted_messages.append({
                "role": "user",
                "content": str(msg)
            })
    # Reuse this thread's loop, or create one if none exists.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    logger.info("Beginning Swarm run with run_sync")
    print("Beginning Swarm run with run_sync")
    try:
        response = loop.run_until_complete(Runner.run(agent, formatted_messages))
    except Exception as e:
        logger.error(f"Error during run: {str(e)}")
        print(f"Error during run: {str(e)}")
        raise
    logger.info(f"Completed Swarm run for agent: {agent.name}")
    print(f"Completed Swarm run for agent: {agent.name}")
    return response
async def run_streamed(
    agent,
    messages,
    external_tools=None,
    tokens_used=None
):
    """
    Wrapper function for initializing and running the Swarm client in streaming mode.

    Normalizes incoming messages to plain role/content dicts and returns the
    stream result from the Runner.
    """
    logger.info(f"Initializing Swarm streaming client for agent: {agent.name}")
    print(f"Initializing Swarm streaming client for agent: {agent.name}")
    external_tools = external_tools if external_tools is not None else []
    tokens_used = tokens_used if tokens_used is not None else {}
    # Normalize each message for the OpenAI API; non-dict entries become
    # user messages with their string form as content.
    formatted_messages = []
    for msg in messages:
        if isinstance(msg, dict) and "content" in msg:
            formatted_messages.append({
                "role": msg.get("role", "user"),
                "content": msg["content"]
            })
        else:
            formatted_messages.append({
                "role": "user",
                "content": str(msg)
            })
    logger.info("Beginning Swarm streaming run")
    print("Beginning Swarm streaming run")
    try:
        return Runner.run_streamed(agent, formatted_messages)
    except Exception as e:
        logger.error(f"Error during streaming run: {str(e)}")
        print(f"Error during streaming run: {str(e)}")
        raise

View file

@ -0,0 +1,143 @@
from bson.objectid import ObjectId
from openai import OpenAI
import os
from motor.motor_asyncio import AsyncIOMotorClient
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Any
from qdrant_client import QdrantClient
import json
# Initialize MongoDB client
# Falls back to a local MongoDB instance when MONGODB_URI is unset.
mongo_uri = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")
mongo_client = AsyncIOMotorClient(mongo_uri)
db = mongo_client.rowboat
# Collections holding data-source metadata and their ingested documents.
data_sources_collection = db['sources']
data_source_docs_collection = db['source_docs']
# Vector store client. NOTE(review): QDRANT_URL has no default, so the client
# is constructed with url=None when the env var is missing — confirm intended.
qdrant_client = QdrantClient(url=os.environ.get("QDRANT_URL"))
# Initialize OpenAI client
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Define embedding model
embedding_model = "text-embedding-3-small"
async def embed(model: str, value: str) -> dict:
    """
    Generate embeddings using OpenAI's embedding models.

    Args:
        model (str): The embedding model to use (e.g., "text-embedding-3-small").
        value (str): The text to embed.

    Returns:
        dict: A dictionary containing the embedding.
    """
    # NOTE(review): this calls the synchronous OpenAI client inside a
    # coroutine, so the request blocks the event loop while in flight.
    result = client.embeddings.create(model=model, input=value)
    return {"embedding": result.data[0].embedding}
async def call_rag_tool(
    project_id: str,
    query: str,
    source_ids: list[str],
    return_type: str,
    k: int,
) -> str:
    """
    Runs the RAG tool call to retrieve information based on the query and source IDs.

    Args:
        project_id (str): The ID of the project.
        query (str): The query string to search for.
        source_ids (list[str]): List of source IDs to filter the search.
        return_type (str): 'chunks' returns the matching chunks; any other
            value returns the full parent documents' contents.
        k (int): The number of results to return.

    Returns:
        str: A JSON string of the form {"Information": [...]}, or '' when none
        of the requested sources are active for this project. (The original
        annotation said dict, but every return path yields a string.)
    """
    print("\n\n calling rag tool \n\n")
    print(query)
    # Create embedding for the query
    embed_result = await embed(model=embedding_model, value=query)
    print(embed_result)
    # Fetch all active data sources for this project
    sources = await data_sources_collection.find({
        "projectId": project_id,
        "active": True
    }).to_list(length=None)
    print(sources)
    # Filter sources to those in source_ids
    valid_source_ids = [
        str(s["_id"]) for s in sources if str(s["_id"]) in source_ids
    ]
    print(valid_source_ids)
    # If no valid sources are found, return empty results
    if not valid_source_ids:
        return ''
    # Perform Qdrant vector search
    qdrant_results = qdrant_client.search(
        collection_name="embeddings",
        query_vector=embed_result["embedding"],
        query_filter={
            "must": [
                {"key": "projectId", "match": {"value": project_id}},
                {"key": "sourceId", "match": {"any": valid_source_ids}},
            ]
        },
        limit=k,
        with_payload=True
    )
    # Map the Qdrant results to the desired format
    results = [
        {
            "title": point.payload["title"],
            "name": point.payload["name"],
            "content": point.payload["content"],
            "docId": point.payload["docId"],
            "sourceId": point.payload["sourceId"],
        }
        for point in qdrant_results
    ]
    print(return_type)
    print(results)
    # If return_type is 'chunks', return the results directly
    if return_type == "chunks":
        return json.dumps({"Information": results}, indent=2)
    # Otherwise, fetch the full document contents from MongoDB
    doc_ids = [ObjectId(r["docId"]) for r in results]
    docs = await data_source_docs_collection.find({"_id": {"$in": doc_ids}}).to_list(length=None)
    # Create a dictionary for quick lookup of documents by their string ID
    doc_dict = {str(doc["_id"]): doc for doc in docs}
    # Update the results with the full document content
    results = [
        {**r, "content": doc_dict.get(r["docId"], {}).get("content", "")}
        for r in results
    ]
    # Convert results to a JSON string
    formatted_string = json.dumps({"Information": results}, indent=2)
    print(formatted_string)
    return formatted_string
if __name__ == "__main__":
    # Manual smoke test: run a sample query against a known project/data source.
    asyncio.run(call_rag_tool(
        project_id="faf2bfb3-41d4-4299-b0d2-048581ea9bd8",
        query="What is the range on your scooter",
        source_ids=["67e102c9fab4514d7aaeb5a4"],
        return_type="docs",
        k=3))

View file

@ -0,0 +1,81 @@
import json
import random
from src.utils.common import common_logger
logger = common_logger
# Tool spec for the article-retrieval (RAG) tool.
RAG_TOOL = {
    "name": "getArticleInfo",
    "type": "rag",
    "description": "Fetch articles with knowledge relevant to the query",
    "parameters": {
        "type": "object",
        "properties": {
            # Renamed from "question" to "query": the original schema required
            # "query" but only defined "question", making it unsatisfiable.
            # "query" also matches the FunctionTool schema in graph/core.
            "query": {
                "type": "string",
                "description": "The query to retrieve articles for"
            }
        },
        "required": [
            "query"
        ]
    }
}
# Tool spec allowing an agent to end the conversation, optionally with an
# error message explaining why.
CLOSE_CHAT_TOOL = {
    "name": "close_chat",
    "type": "close_chat",
    "description": "Close the chat",
    "parameters": {
        "type": "object",
        "properties": {
            "error_message": {
                "type": "string", "description": "The error message to close the chat with"
            }
        }
    }
}
def tool_raise_error(error_message):
    """Log the error and abort the turn by raising ValueError."""
    full_message = f"Raising error: {error_message}"
    logger.error(full_message)
    raise ValueError(full_message)
def respond_to_tool_raise_error(tool_calls, mock=False):
    """Handle a raise_error tool call by raising its error message.

    NOTE(review): tool_raise_error always raises, so _create_tool_response is
    never actually reached — the ValueError propagates to the caller.
    """
    # Extract error_message from the first tool call's JSON-encoded arguments.
    error_message = json.loads(tool_calls[0]["function"]["arguments"]).get("error_message", "")
    return _create_tool_response(tool_calls, tool_raise_error(error_message))
def tool_close_chat(error_message):
    """Log the closing reason and end the turn by raising ValueError."""
    full_message = f"Closing chat: {error_message}"
    logger.error(full_message)
    raise ValueError(full_message)
def respond_to_tool_close_chat(tool_calls, mock=False):
    """Handle a close_chat tool call by raising its error message.

    NOTE(review): tool_close_chat always raises, so _create_tool_response is
    never actually reached — the ValueError propagates to the caller.
    """
    # Extract error_message from the first tool call's JSON-encoded arguments.
    error_message = json.loads(tool_calls[0]["function"]["arguments"]).get("error_message", "")
    return _create_tool_response(tool_calls, tool_close_chat(error_message))
def _create_tool_response(tool_calls, content, mock=False):
    """
    Creates a standardized tool response format.

    Builds a role="tool" message answering the first tool call in the list.
    (mock is accepted for signature compatibility but unused.)
    """
    first_call = tool_calls[0]
    return {
        "role": "tool",
        "content": content,
        "tool_call_id": first_call["id"],
        "name": first_call["function"]["name"],
    }
def create_error_tool_call(error_message):
    """
    Build a synthetic assistant message containing a raise_error tool call.

    Args:
        error_message: Text to embed as the tool call's error_message argument.

    Returns:
        dict: An assistant-role message with one "raise_error" tool call whose
        arguments are valid JSON and a random call_<24-char> id.
    """
    alphabet = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
    return {
        "role": "assistant",
        "sender": "system",
        "tool_calls": [
            {
                "function": {
                    "name": "raise_error",
                    # json.dumps instead of manual string concatenation: the
                    # original produced invalid JSON whenever error_message
                    # contained quotes, backslashes or newlines.
                    "arguments": json.dumps({"error_message": error_message})
                },
                "id": "call_" + ''.join(random.choices(alphabet, k=24)),
                "type": "function"
            }
        ]
    }

View file

@ -0,0 +1,18 @@
from enum import Enum
class AgentRole(Enum):
    """Special roles an agent can play in the multi-agent graph."""
    ESCALATION = "escalation"
    POST_PROCESSING = "post_process"
    GUARDRAILS = "guardrails"
class ControlType(Enum):
    """How chat control moves after an agent finishes its turn."""
    RETAIN = "retain"
    PARENT_AGENT = "relinquish_to_parent"
    START_AGENT = "relinquish_to_start"
class PromptType(Enum):
    """Kinds of auxiliary prompts attached to a project."""
    STYLE = "style_prompt"
    GREETING = "greeting"
class ErrorType(Enum):
    """Severity of a runtime error: abort outright or escalate to an agent."""
    FATAL = "fatal"
    ESCALATE = "escalate"

View file

@ -0,0 +1,200 @@
import json
import logging
import os
import subprocess
import sys
import time

from dotenv import load_dotenv
from openai import OpenAI, RateLimitError
load_dotenv()
def setup_logger(name, log_file='./run.log', level=logging.INFO, log_to_file=True):
    """Create and configure a named logger.

    Args:
        name: Logger name.
        log_file: Destination path used when log_to_file is True.
        level: Logging level (default INFO).
        log_to_file: Write to log_file when True, otherwise to stdout.

    Returns:
        A logger with exactly one handler and propagation disabled.
    """
    if log_to_file:
        sink = logging.FileHandler(log_file)
    else:
        sink = logging.StreamHandler(sys.stdout)
    sink.setFormatter(logging.Formatter('%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s'))
    configured = logging.getLogger(name)
    configured.setLevel(level)
    # Drop any handlers left over from earlier calls so messages are not duplicated.
    if configured.hasHandlers():
        configured.handlers.clear()
    # Keep messages from also reaching ancestor loggers.
    configured.propagate = False
    configured.addHandler(sink)
    return configured
common_logger = setup_logger('logger')
logger = common_logger
def read_json_from_file(file_name):
    """Read and parse a JSON file; return None (and log) on any failure."""
    logger.info(f"Reading json from {file_name}")
    try:
        with open(file_name, 'r') as fh:
            return json.loads(fh.read())
    except Exception as e:
        logger.error(e)
        return None
def get_api_key(key_name):
    """Fetch a required API key from the environment.

    Raises:
        ValueError: If the variable is unset or empty.
    """
    value = os.getenv(key_name)
    if value:
        return value
    raise ValueError(f"{key_name} not found. Did you set it in the .env file?")
# Module-level OpenAI client; get_api_key raises at import time if
# OPENAI_API_KEY is missing from the environment.
openai_client = OpenAI(
    api_key=get_api_key("OPENAI_API_KEY")
)
def generate_gpt4o_output_from_multi_turn_conv(messages, output_type='json', model="gpt-4o"):
    # Thin alias kept for backward compatibility; delegates to generate_openai_output.
    return generate_openai_output(messages, output_type, model)
def generate_openai_output(messages, output_type='not_json', model="gpt-4o", return_completion=False):
    """
    Run a chat completion and return its text (or the raw completion).

    Args:
        messages: Chat messages for the completion.
        output_type: 'json' forces a JSON-object response format.
        model: Model name.
        return_completion: Return the full completion object instead of text.

    Returns:
        The message content, the completion object, or None on error (logged).
    """
    request_kwargs = {"messages": messages, "model": model}
    if output_type == 'json':
        request_kwargs["response_format"] = {"type": "json_object"}
    try:
        chat_completion = openai_client.chat.completions.create(**request_kwargs)
        if return_completion:
            return chat_completion
        return chat_completion.choices[0].message.content
    except Exception as e:
        logger.error(e)
        return None
def generate_llm_output(messages, model):
    """Route a completion request to the right provider based on the model name.

    Raises:
        ValueError: If the model name is not recognized.
    """
    if "gpt" not in model:
        raise ValueError(f"Model {model} not supported")
    # Only OpenAI models are supported right now.
    return generate_openai_output(messages, output_type='text', model=model)
def generate_gpt4o_output_from_multi_turn_conv_multithreaded(messages, retries=5, delay=1, output_type='json'):
    """
    Call generate_gpt4o_output_from_multi_turn_conv with exponential backoff.

    Args:
        messages: Chat messages for the completion.
        retries: Maximum number of attempts when rate limited.
        delay: Initial backoff delay in seconds (doubles each retry).
        output_type: Passed through to the underlying call. (The original
            ignored this parameter and always sent 'json'.)

    Returns:
        The model output, or [] if every attempt was rate limited.
    """
    while retries > 0:
        try:
            # Call GPT-4o API; return immediately on success.
            return generate_gpt4o_output_from_multi_turn_conv(messages, output_type=output_type)
        except RateLimitError:
            # The original caught `openai.RateLimitError`, but only
            # `from openai import OpenAI` was imported, so the handler itself
            # raised NameError; RateLimitError is now imported directly.
            print(f'Rate limit exceeded. Retrying in {delay} seconds...')
            time.sleep(delay)
            delay *= 2  # Exponential backoff
            retries -= 1
    print(f'Failed to process due to rate limit.')
    return []
def convert_message_content_json_to_strings(messages):
    """JSON-encode any dict-valued 'content' fields in place.

    Args:
        messages: List of message dicts; mutated in place.

    Returns:
        The same list, with dict contents serialized to JSON strings.
    """
    for msg in messages:
        # `in msg` instead of `in msg.keys()` — same membership test, no view object.
        if 'content' in msg and isinstance(msg['content'], dict):
            msg['content'] = json.dumps(msg['content'])
    return messages
def merge_defaultdicts(dict_parent, dict_child):
    """Merge dict_child into dict_parent in place and return dict_parent.

    Colliding values are combined by the parent value's type: lists extend,
    dicts and sets update, anything else uses +. New keys are copied as-is.
    """
    for key, child_value in dict_child.items():
        if key not in dict_parent:
            dict_parent[key] = child_value
            continue
        existing = dict_parent[key]
        if isinstance(existing, list):
            existing.extend(child_value)
        elif isinstance(existing, dict):
            existing.update(child_value)
        elif isinstance(existing, set):
            existing.update(child_value)
        else:
            # Scalars (int, float, str, ...) combine with +.
            dict_parent[key] = existing + child_value
    return dict_parent
def read_jsonl_from_file(file_name):
    """Read a JSONL file into a list of parsed objects; None (and log) on failure."""
    try:
        with open(file_name, 'r') as fh:
            return [json.loads(row.strip()) for row in fh]
    except Exception as e:
        logger.error(e)
        return None
def write_jsonl_to_file(list_dicts, file_name):
    """Write dicts to a file, one JSON object per line; True on success."""
    try:
        with open(file_name, 'w') as fh:
            fh.writelines(json.dumps(d) + '\n' for d in list_dicts)
        return True
    except Exception as e:
        logger.error(e)
        return False
def read_text_from_file(file_name):
    """Return a file's full text, or None (and log) if it cannot be read."""
    try:
        with open(file_name, 'r') as fh:
            return fh.read()
    except Exception as e:
        logger.error(e)
        return None
def write_json_to_file(data, file_name):
    """Serialize data to a JSON file with 4-space indent; True on success."""
    try:
        with open(file_name, 'w') as fh:
            json.dump(data, fh, indent=4)
        return True
    except Exception as e:
        logger.error(e)
        return False
def get_git_path(path):
    """Resolve *path* relative to the root of the enclosing Git repository.

    Raises:
        RuntimeError: When not run inside a Git repository.
    """
    # `git rev-parse --show-toplevel` prints the repository root.
    try:
        git_root = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], text=True).strip()
    except subprocess.CalledProcessError:
        raise RuntimeError("Not inside a Git repository")
    return f"{git_root}/{path}"
def update_tokens_used(provider, model, tokens_used, completion):
    """Accumulate a completion's token usage into the tokens_used mapping.

    Entries are keyed "provider/model" and track input_tokens / output_tokens.
    Returns the (mutated) mapping.
    """
    bucket_key = f"{provider}/{model}"
    bucket = tokens_used.setdefault(bucket_key, {
        'input_tokens': 0,
        'output_tokens': 0,
    })
    bucket['input_tokens'] += completion.usage.prompt_tokens
    bucket['output_tokens'] += completion.usage.completion_tokens
    return tokens_used