rowboat/apps/python-sdk/lib.py

171 lines
5.8 KiB
Python
Raw Normal View History

from typing import Dict, List, Optional, Any, Callable, Union, Tuple
import requests
import json
from schema import ApiRequest, ApiResponse, ApiMessage, ToolMessage, UserMessage, SystemMessage, AssistantMessage, AssistantMessageWithToolCalls
class Client:
def __init__(self, host: str, project_id: str, project_secret: str) -> None:
self.base_url: str = f'{host}/api/v1/{project_id}/chat'
self.headers: Dict[str, str] = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {project_secret}'
}
def _call_api(
self,
messages: List[ApiMessage],
state: Optional[Dict[str, Any]] = None,
) -> ApiResponse:
request = ApiRequest(
messages=messages,
state=state
)
# print request json
print(request.model_dump_json())
response = requests.post(self.base_url, headers=self.headers, data=request.model_dump_json())
if not response.status_code == 200:
raise ValueError(f"Error: {response.status_code} - {response.text}")
response_data = ApiResponse.model_validate(response.json())
if not response_data.messages:
raise ValueError("No response")
last_message = response_data.messages[-1]
if not isinstance(last_message, (AssistantMessage, AssistantMessageWithToolCalls)):
raise ValueError("Last message was not an assistant message")
return response_data
def _process_tool_calls(
self,
tool_calls: List[Any],
tools: Dict[str, Callable[..., str]]
) -> List[ToolMessage]:
"""Process tool calls and return a list of tool response messages"""
tool_messages = []
for tool_call in tool_calls:
tool_name = tool_call.function.name
tool_arguments = json.loads(tool_call.function.arguments)
if tool_name not in tools:
raise ValueError(f'Missing tool: {tool_name}')
tool_response = tools[tool_name](**tool_arguments)
tool_msg = ToolMessage(
role='tool',
content=tool_response,
tool_call_id=tool_call.id,
tool_name=tool_name
)
tool_messages.append(tool_msg)
return tool_messages
def chat(
self,
messages: List[ApiMessage],
tools: Optional[Dict[str, Callable[..., str]]] = None,
state: Optional[Dict[str, Any]] = None,
max_turns: int = 3
) -> Tuple[List[ApiMessage], Optional[Dict[str, Any]]]:
"""Stateless chat method that handles a single conversation turn with multiple tool call rounds"""
current_messages = messages
current_state = state
turns = 0
response_messages = []
response_state = None
has_tool_calls = False
while turns < max_turns:
# call api
response_data = self._call_api(
messages=current_messages,
state=current_state
)
current_messages.extend(response_data.messages)
current_state = response_data.state
response_messages = response_data.messages
response_state = response_data.state
# Process tool calls if present and tools are provided
last_message = response_data.messages[-1]
has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
if has_tool_calls:
tool_messages = self._process_tool_calls(last_message.tool_calls, tools)
current_messages.extend(tool_messages)
# If no tool calls were made, we're done
if not has_tool_calls:
break
turns += 1
if turns == max_turns and has_tool_calls:
raise ValueError("Max turns reached")
if not last_message.agenticResponseType == 'external':
raise ValueError("Last message was not an external message")
return response_messages, response_state
class StatefulChat:
"""Maintains conversation state across multiple turns"""
def __init__(
self,
client: Client,
tools: Optional[Dict[str, Callable[..., str]]] = None,
system_prompt: Optional[str] = None,
) -> None:
self.client = client
self.tools = tools
self.messages: List[ApiMessage] = []
self.state: Optional[Dict[str, Any]] = None
if system_prompt:
self.messages.append(SystemMessage(role='system', content=system_prompt))
def run(self, message: Union[str]) -> str:
"""Handle a single user turn in the conversation"""
# Process the message
user_msg = UserMessage(role='user', content=message)
self.messages.append(user_msg)
# Get response using the client's chat method
new_messages, new_state = self.client.chat(
messages=self.messages,
tools=self.tools,
state=self.state
)
# Update internal state
self.messages = new_messages
self.state = new_state
# Return only the final message content
last_message = new_messages[-1]
return last_message.content
def weather_lookup_tool(city_name: str) -> str:
return f"The weather in {city_name} is 22°C."
if __name__ == "__main__":
host: str = "<HOST>"
project_id: str = "<PROJECT_ID>"
project_secret: str = "<PROJECT_SECRET>"
client = Client(host, project_id, project_secret)
tools: Dict[str, Callable[..., str]] = {
'weather_lookup': weather_lookup_tool
}
chat_session = StatefulChat(client, tools)
resp = chat_session.run("whats the weather in london?")
print(resp)