mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-05-22 18:45:19 +02:00
add web search tool in run_streamed
This commit is contained in:
parent
6312fe8e83
commit
8e16ac6204
1 changed files with 228 additions and 25 deletions
|
|
@ -282,6 +282,54 @@ async def run_turn_streamed(
|
||||||
print('-'*50)
|
print('-'*50)
|
||||||
print(f"Found usage information. Updated cumulative tokens: {tokens_used}")
|
print(f"Found usage information. Updated cumulative tokens: {tokens_used}")
|
||||||
print('-'*50)
|
print('-'*50)
|
||||||
|
|
||||||
|
# Handle ResponseFunctionWebSearch specifically
|
||||||
|
if hasattr(event, 'data') and hasattr(event.data, 'raw_item'):
|
||||||
|
raw_item = event.data.raw_item
|
||||||
|
|
||||||
|
# Check if it's a web search call
|
||||||
|
if (hasattr(raw_item, 'type') and raw_item.type == 'web_search_call') or (
|
||||||
|
isinstance(raw_item, dict) and raw_item.get('type') == 'web_search_call'
|
||||||
|
):
|
||||||
|
# Get call_id safely, regardless of structure
|
||||||
|
call_id = None
|
||||||
|
if hasattr(raw_item, 'id'):
|
||||||
|
call_id = raw_item.id
|
||||||
|
elif isinstance(raw_item, dict) and 'id' in raw_item:
|
||||||
|
call_id = raw_item['id']
|
||||||
|
else:
|
||||||
|
call_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Get status safely
|
||||||
|
status = 'unknown'
|
||||||
|
if hasattr(raw_item, 'status'):
|
||||||
|
status = raw_item.status
|
||||||
|
elif isinstance(raw_item, dict) and 'status' in raw_item:
|
||||||
|
status = raw_item['status']
|
||||||
|
|
||||||
|
# Emit a tool call for web search
|
||||||
|
message = {
|
||||||
|
'content': None,
|
||||||
|
'role': 'assistant',
|
||||||
|
'sender': current_agent.name if current_agent else None,
|
||||||
|
'tool_calls': [{
|
||||||
|
'function': {
|
||||||
|
'name': 'web_search',
|
||||||
|
'arguments': json.dumps({
|
||||||
|
'search_id': call_id,
|
||||||
|
'status': status
|
||||||
|
})
|
||||||
|
},
|
||||||
|
'id': call_id,
|
||||||
|
'type': 'function'
|
||||||
|
}],
|
||||||
|
'tool_call_id': None,
|
||||||
|
'tool_name': None,
|
||||||
|
'response_type': 'internal'
|
||||||
|
}
|
||||||
|
print("Yielding web search raw response message: ", message)
|
||||||
|
yield ('message', message)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Update current agent when it changes
|
# Update current agent when it changes
|
||||||
|
|
@ -334,6 +382,29 @@ async def run_turn_streamed(
|
||||||
elif event.type == "run_item_stream_event":
|
elif event.type == "run_item_stream_event":
|
||||||
current_agent = event.item.agent
|
current_agent = event.item.agent
|
||||||
if event.item.type == "tool_call_item":
|
if event.item.type == "tool_call_item":
|
||||||
|
# Check if it's a ResponseFunctionWebSearch object
|
||||||
|
if hasattr(event.item.raw_item, 'type') and event.item.raw_item.type == 'web_search_call':
|
||||||
|
call_id = event.item.raw_item.id if hasattr(event.item.raw_item, 'id') else str(uuid.uuid4())
|
||||||
|
message = {
|
||||||
|
'content': None,
|
||||||
|
'role': 'assistant',
|
||||||
|
'sender': current_agent.name if current_agent else None,
|
||||||
|
'tool_calls': [{
|
||||||
|
'function': {
|
||||||
|
'name': 'web_search',
|
||||||
|
'arguments': json.dumps({
|
||||||
|
'search_id': call_id
|
||||||
|
})
|
||||||
|
},
|
||||||
|
'id': call_id,
|
||||||
|
'type': 'function'
|
||||||
|
}],
|
||||||
|
'tool_call_id': None,
|
||||||
|
'tool_name': None,
|
||||||
|
'response_type': 'internal'
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Handle normal tool calls
|
||||||
message = {
|
message = {
|
||||||
'content': None,
|
'content': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
|
@ -354,25 +425,69 @@ async def run_turn_streamed(
|
||||||
yield ('message', message)
|
yield ('message', message)
|
||||||
|
|
||||||
elif event.item.type == "tool_call_output_item":
|
elif event.item.type == "tool_call_output_item":
|
||||||
|
# Check if it's a web search result
|
||||||
|
if isinstance(event.item.raw_item, dict) and event.item.raw_item.get('type') == 'web_search_results':
|
||||||
|
call_id = event.item.raw_item.get('search_id', event.item.raw_item.get('id', str(uuid.uuid4())))
|
||||||
message = {
|
message = {
|
||||||
'content': str(event.item.output),
|
'content': str(event.item.output),
|
||||||
'role': 'tool',
|
'role': 'tool',
|
||||||
'sender': None,
|
'sender': None,
|
||||||
'tool_calls': None,
|
'tool_calls': None,
|
||||||
'tool_call_id': event.item.raw_item['call_id'],
|
'tool_call_id': call_id,
|
||||||
'tool_name': event.item.raw_item.get('name', None),
|
'tool_name': 'web_search',
|
||||||
'response_type': 'internal'
|
'response_type': 'internal'
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
# Safe extraction of call_id and name
|
||||||
|
call_id = None
|
||||||
|
tool_name = None
|
||||||
|
|
||||||
|
# Handle different types of raw_item
|
||||||
|
if isinstance(event.item.raw_item, dict):
|
||||||
|
call_id = event.item.raw_item.get('call_id')
|
||||||
|
tool_name = event.item.raw_item.get('name')
|
||||||
|
elif hasattr(event.item.raw_item, 'call_id'):
|
||||||
|
call_id = event.item.raw_item.call_id
|
||||||
|
if hasattr(event.item.raw_item, 'name'):
|
||||||
|
tool_name = event.item.raw_item.name
|
||||||
|
|
||||||
|
message = {
|
||||||
|
'content': str(event.item.output),
|
||||||
|
'role': 'tool',
|
||||||
|
'sender': None,
|
||||||
|
'tool_calls': None,
|
||||||
|
'tool_call_id': call_id,
|
||||||
|
'tool_name': tool_name,
|
||||||
|
'response_type': 'internal'
|
||||||
|
}
|
||||||
|
|
||||||
print("Yielding message: ", message)
|
print("Yielding message: ", message)
|
||||||
yield ('message', message)
|
yield ('message', message)
|
||||||
|
|
||||||
elif event.item.type == "message_output_item":
|
elif event.item.type == "message_output_item":
|
||||||
content = ""
|
content = ""
|
||||||
|
url_citations = []
|
||||||
|
|
||||||
|
# Extract text content and any URL citations
|
||||||
if hasattr(event.item.raw_item, 'content'):
|
if hasattr(event.item.raw_item, 'content'):
|
||||||
for content_item in event.item.raw_item.content:
|
for content_item in event.item.raw_item.content:
|
||||||
|
# Handle text content
|
||||||
if hasattr(content_item, 'text'):
|
if hasattr(content_item, 'text'):
|
||||||
content += content_item.text
|
content += content_item.text
|
||||||
|
|
||||||
|
# Extract URL citations if present
|
||||||
|
if hasattr(content_item, 'annotations'):
|
||||||
|
for annotation in content_item.annotations:
|
||||||
|
if hasattr(annotation, 'type') and annotation.type == 'url_citation':
|
||||||
|
citation = {
|
||||||
|
'url': annotation.url if hasattr(annotation, 'url') else '',
|
||||||
|
'title': annotation.title if hasattr(annotation, 'title') else '',
|
||||||
|
'start_index': annotation.start_index if hasattr(annotation, 'start_index') else 0,
|
||||||
|
'end_index': annotation.end_index if hasattr(annotation, 'end_index') else 0
|
||||||
|
}
|
||||||
|
url_citations.append(citation)
|
||||||
|
|
||||||
|
# Create message with URL citations if they exist
|
||||||
message = {
|
message = {
|
||||||
'content': content,
|
'content': content,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
|
@ -382,9 +497,97 @@ async def run_turn_streamed(
|
||||||
'tool_name': None,
|
'tool_name': None,
|
||||||
'response_type': 'external'
|
'response_type': 'external'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Add citations if any were found
|
||||||
|
if url_citations:
|
||||||
|
message['citations'] = url_citations
|
||||||
|
|
||||||
print("Yielding message: ", message)
|
print("Yielding message: ", message)
|
||||||
yield ('message', message)
|
yield ('message', message)
|
||||||
|
|
||||||
|
# Handle web search function call events
|
||||||
|
elif event.item.type == "web_search_call_item" or (hasattr(event.item, 'raw_item') and hasattr(event.item.raw_item, 'type') and event.item.raw_item.type == 'web_search_call'):
|
||||||
|
# Extract web search call ID if available
|
||||||
|
call_id = None
|
||||||
|
if hasattr(event.item.raw_item, 'id'):
|
||||||
|
call_id = event.item.raw_item.id
|
||||||
|
|
||||||
|
message = {
|
||||||
|
'content': None,
|
||||||
|
'role': 'assistant',
|
||||||
|
'sender': current_agent.name if current_agent else None,
|
||||||
|
'tool_calls': [{
|
||||||
|
'function': {
|
||||||
|
'name': 'web_search',
|
||||||
|
'arguments': json.dumps({
|
||||||
|
'search_id': call_id
|
||||||
|
})
|
||||||
|
},
|
||||||
|
'id': call_id or str(uuid.uuid4()),
|
||||||
|
'type': 'function'
|
||||||
|
}],
|
||||||
|
'tool_call_id': None,
|
||||||
|
'tool_name': None,
|
||||||
|
'response_type': 'internal'
|
||||||
|
}
|
||||||
|
print("Yielding web search message: ", message)
|
||||||
|
yield ('message', message)
|
||||||
|
|
||||||
|
# Handle web search results
|
||||||
|
elif event.item.type == "web_search_results_item" or (
|
||||||
|
hasattr(event.item, 'raw_item') and (
|
||||||
|
(hasattr(event.item.raw_item, 'type') and event.item.raw_item.type == 'web_search_results') or
|
||||||
|
(isinstance(event.item.raw_item, dict) and event.item.raw_item.get('type') == 'web_search_results')
|
||||||
|
)
|
||||||
|
):
|
||||||
|
# Extract call_id safely
|
||||||
|
call_id = None
|
||||||
|
raw_item = event.item.raw_item
|
||||||
|
|
||||||
|
# Try several ways to get the search_id or id
|
||||||
|
if hasattr(raw_item, 'search_id'):
|
||||||
|
call_id = raw_item.search_id
|
||||||
|
elif isinstance(raw_item, dict) and 'search_id' in raw_item:
|
||||||
|
call_id = raw_item['search_id']
|
||||||
|
elif hasattr(raw_item, 'id'):
|
||||||
|
call_id = raw_item.id
|
||||||
|
elif isinstance(raw_item, dict) and 'id' in raw_item:
|
||||||
|
call_id = raw_item['id']
|
||||||
|
else:
|
||||||
|
call_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Extract results content safely
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
# Try event.item.output first
|
||||||
|
if hasattr(event.item, 'output'):
|
||||||
|
results = event.item.output
|
||||||
|
# Then try raw_item.results
|
||||||
|
elif hasattr(raw_item, 'results'):
|
||||||
|
results = raw_item.results
|
||||||
|
elif isinstance(raw_item, dict) and 'results' in raw_item:
|
||||||
|
results = raw_item['results']
|
||||||
|
|
||||||
|
# Format the results for output
|
||||||
|
results_str = ""
|
||||||
|
try:
|
||||||
|
results_str = json.dumps(results) if results else ""
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error serializing results: {str(e)}")
|
||||||
|
results_str = str(results)
|
||||||
|
|
||||||
|
message = {
|
||||||
|
'content': results_str,
|
||||||
|
'role': 'tool',
|
||||||
|
'sender': None,
|
||||||
|
'tool_calls': None,
|
||||||
|
'tool_call_id': call_id,
|
||||||
|
'tool_name': 'web_search',
|
||||||
|
'response_type': 'internal'
|
||||||
|
}
|
||||||
|
print("Yielding web search results: ", message)
|
||||||
|
yield ('message', message)
|
||||||
|
|
||||||
print(f"\n{'='*50}\n")
|
print(f"\n{'='*50}\n")
|
||||||
|
|
||||||
# After all events are processed, set final state
|
# After all events are processed, set final state
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue