mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-05-13 17:22:37 +02:00
Merge pull request #175 from rowboatlabs/tool-override
Allow mocking tools over API
This commit is contained in:
commit
6895e54425
8 changed files with 54 additions and 18 deletions
|
|
@ -68,6 +68,21 @@ chat = StatefulChat(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Tool overrides
|
||||||
|
|
||||||
|
You can provide tool override instructions to test a specific configuration:
|
||||||
|
|
||||||
|
```python
|
||||||
|
chat = StatefulChat(
|
||||||
|
client,
|
||||||
|
mock_tools={
|
||||||
|
"weather_lookup": "The weather in any city is sunny and 25°C.",
|
||||||
|
"calculator": "The result of any calculation is 42.",
|
||||||
|
"search": "Search results for any query return 'No relevant information found.'"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
### Low-Level Usage
|
### Low-Level Usage
|
||||||
|
|
||||||
For more control over the conversation, you can use the `Client` class directly:
|
For more control over the conversation, you can use the `Client` class directly:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "rowboat"
|
name = "rowboat"
|
||||||
version = "3.1.0"
|
version = "4.0.0"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Ramnique Singh", email = "ramnique@rowboatlabs.com" },
|
{ name = "Ramnique Singh", email = "ramnique@rowboatlabs.com" },
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -22,13 +22,15 @@ class Client:
|
||||||
messages: List[ApiMessage],
|
messages: List[ApiMessage],
|
||||||
state: Optional[Dict[str, Any]] = None,
|
state: Optional[Dict[str, Any]] = None,
|
||||||
workflow_id: Optional[str] = None,
|
workflow_id: Optional[str] = None,
|
||||||
test_profile_id: Optional[str] = None
|
test_profile_id: Optional[str] = None,
|
||||||
|
mock_tools: Optional[Dict[str, str]] = None
|
||||||
) -> ApiResponse:
|
) -> ApiResponse:
|
||||||
request = ApiRequest(
|
request = ApiRequest(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
state=state,
|
state=state,
|
||||||
workflowId=workflow_id,
|
workflowId=workflow_id,
|
||||||
testProfileId=test_profile_id
|
testProfileId=test_profile_id,
|
||||||
|
mockTools=mock_tools
|
||||||
)
|
)
|
||||||
json_data = request.model_dump()
|
json_data = request.model_dump()
|
||||||
response = requests.post(self.base_url, headers=self.headers, json=json_data)
|
response = requests.post(self.base_url, headers=self.headers, json=json_data)
|
||||||
|
|
@ -52,7 +54,8 @@ class Client:
|
||||||
messages: List[ApiMessage],
|
messages: List[ApiMessage],
|
||||||
state: Optional[Dict[str, Any]] = None,
|
state: Optional[Dict[str, Any]] = None,
|
||||||
workflow_id: Optional[str] = None,
|
workflow_id: Optional[str] = None,
|
||||||
test_profile_id: Optional[str] = None
|
test_profile_id: Optional[str] = None,
|
||||||
|
mock_tools: Optional[Dict[str, str]] = None,
|
||||||
) -> ApiResponse:
|
) -> ApiResponse:
|
||||||
"""Stateless chat method that handles a single conversation turn"""
|
"""Stateless chat method that handles a single conversation turn"""
|
||||||
|
|
||||||
|
|
@ -61,10 +64,11 @@ class Client:
|
||||||
messages=messages,
|
messages=messages,
|
||||||
state=state,
|
state=state,
|
||||||
workflow_id=workflow_id,
|
workflow_id=workflow_id,
|
||||||
test_profile_id=test_profile_id
|
test_profile_id=test_profile_id,
|
||||||
|
mock_tools=mock_tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not response_data.messages[-1].agenticResponseType == 'external':
|
if not response_data.messages[-1].responseType == 'external':
|
||||||
raise ValueError("Last message was not an external message")
|
raise ValueError("Last message was not an external message")
|
||||||
|
|
||||||
return response_data
|
return response_data
|
||||||
|
|
@ -76,13 +80,15 @@ class StatefulChat:
|
||||||
self,
|
self,
|
||||||
client: Client,
|
client: Client,
|
||||||
workflow_id: Optional[str] = None,
|
workflow_id: Optional[str] = None,
|
||||||
test_profile_id: Optional[str] = None
|
test_profile_id: Optional[str] = None,
|
||||||
|
mock_tools: Optional[Dict[str, str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.client = client
|
self.client = client
|
||||||
self.messages: List[ApiMessage] = []
|
self.messages: List[ApiMessage] = []
|
||||||
self.state: Optional[Dict[str, Any]] = None
|
self.state: Optional[Dict[str, Any]] = None
|
||||||
self.workflow_id = workflow_id
|
self.workflow_id = workflow_id
|
||||||
self.test_profile_id = test_profile_id
|
self.test_profile_id = test_profile_id
|
||||||
|
self.mock_tools = mock_tools
|
||||||
|
|
||||||
def run(self, message: Union[str]) -> str:
|
def run(self, message: Union[str]) -> str:
|
||||||
"""Handle a single user turn in the conversation"""
|
"""Handle a single user turn in the conversation"""
|
||||||
|
|
@ -96,7 +102,8 @@ class StatefulChat:
|
||||||
messages=self.messages,
|
messages=self.messages,
|
||||||
state=self.state,
|
state=self.state,
|
||||||
workflow_id=self.workflow_id,
|
workflow_id=self.workflow_id,
|
||||||
test_profile_id=self.test_profile_id
|
test_profile_id=self.test_profile_id,
|
||||||
|
mock_tools=self.mock_tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update internal state
|
# Update internal state
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Optional, Union, Any, Literal
|
from typing import List, Optional, Union, Any, Literal, Dict
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
class SystemMessage(BaseModel):
|
class SystemMessage(BaseModel):
|
||||||
|
|
@ -12,8 +12,8 @@ class UserMessage(BaseModel):
|
||||||
class AssistantMessage(BaseModel):
|
class AssistantMessage(BaseModel):
|
||||||
role: Literal['assistant']
|
role: Literal['assistant']
|
||||||
content: str
|
content: str
|
||||||
agenticSender: Optional[str] = None
|
agenticName: Optional[str] = None
|
||||||
agenticResponseType: Literal['internal', 'external']
|
responseType: Literal['internal', 'external']
|
||||||
|
|
||||||
class FunctionCall(BaseModel):
|
class FunctionCall(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
@ -27,15 +27,14 @@ class ToolCall(BaseModel):
|
||||||
class AssistantMessageWithToolCalls(BaseModel):
|
class AssistantMessageWithToolCalls(BaseModel):
|
||||||
role: Literal['assistant']
|
role: Literal['assistant']
|
||||||
content: Optional[str] = None
|
content: Optional[str] = None
|
||||||
tool_calls: List[ToolCall]
|
toolCalls: List[ToolCall]
|
||||||
agenticSender: Optional[str] = None
|
agenticName: Optional[str] = None
|
||||||
agenticResponseType: Literal['internal', 'external']
|
|
||||||
|
|
||||||
class ToolMessage(BaseModel):
|
class ToolMessage(BaseModel):
|
||||||
role: Literal['tool']
|
role: Literal['tool']
|
||||||
content: str
|
content: str
|
||||||
tool_call_id: str
|
toolCallId: str
|
||||||
tool_name: str
|
toolName: str
|
||||||
|
|
||||||
ApiMessage = Union[
|
ApiMessage = Union[
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
|
|
@ -50,7 +49,8 @@ class ApiRequest(BaseModel):
|
||||||
state: Any
|
state: Any
|
||||||
workflowId: Optional[str] = None
|
workflowId: Optional[str] = None
|
||||||
testProfileId: Optional[str] = None
|
testProfileId: Optional[str] = None
|
||||||
|
mockTools: Optional[Dict[str, str]] = None
|
||||||
|
|
||||||
class ApiResponse(BaseModel):
|
class ApiResponse(BaseModel):
|
||||||
messages: List[ApiMessage]
|
messages: List[ApiMessage]
|
||||||
state: Any
|
state: Optional[Any] = None
|
||||||
|
|
@ -51,6 +51,7 @@ export async function POST(
|
||||||
return Response.json({ error: `Invalid request body: ${result.error.message}` }, { status: 400 });
|
return Response.json({ error: `Invalid request body: ${result.error.message}` }, { status: 400 });
|
||||||
}
|
}
|
||||||
const reqMessages = result.data.messages;
|
const reqMessages = result.data.messages;
|
||||||
|
const mockToolOverrides = result.data.mockTools;
|
||||||
|
|
||||||
// fetch published workflow id
|
// fetch published workflow id
|
||||||
const project = await projectsCollection.findOne({
|
const project = await projectsCollection.findOne({
|
||||||
|
|
@ -80,6 +81,11 @@ export async function POST(
|
||||||
return Response.json({ error: "Workflow not found" }, { status: 404 });
|
return Response.json({ error: "Workflow not found" }, { status: 404 });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// override mock instructions
|
||||||
|
if (mockToolOverrides) {
|
||||||
|
workflow.mockTools = mockToolOverrides;
|
||||||
|
}
|
||||||
|
|
||||||
// check billing authorization
|
// check billing authorization
|
||||||
if (USE_BILLING && billingCustomerId) {
|
if (USE_BILLING && billingCustomerId) {
|
||||||
const agentModels = workflow.agents.reduce((acc, agent) => {
|
const agentModels = workflow.agents.reduce((acc, agent) => {
|
||||||
|
|
|
||||||
|
|
@ -797,7 +797,13 @@ async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof
|
||||||
function createTools(logger: PrefixLogger, workflow: z.infer<typeof Workflow>, toolConfig: Record<string, z.infer<typeof WorkflowTool>>): Record<string, Tool> {
|
function createTools(logger: PrefixLogger, workflow: z.infer<typeof Workflow>, toolConfig: Record<string, z.infer<typeof WorkflowTool>>): Record<string, Tool> {
|
||||||
const tools: Record<string, Tool> = {};
|
const tools: Record<string, Tool> = {};
|
||||||
for (const [toolName, config] of Object.entries(toolConfig)) {
|
for (const [toolName, config] of Object.entries(toolConfig)) {
|
||||||
if (config.isMcp) {
|
if (workflow.mockTools?.[toolName]) {
|
||||||
|
tools[toolName] = createMockTool(logger, {
|
||||||
|
...config,
|
||||||
|
mockInstructions: workflow.mockTools?.[toolName], // override mock instructions
|
||||||
|
});
|
||||||
|
logger.log(`created mock tool: ${toolName}`);
|
||||||
|
} else if (config.isMcp) {
|
||||||
tools[toolName] = createMcpTool(logger, config, workflow.projectId);
|
tools[toolName] = createMcpTool(logger, config, workflow.projectId);
|
||||||
logger.log(`created mcp tool: ${toolName}`);
|
logger.log(`created mcp tool: ${toolName}`);
|
||||||
} else if (config.isComposio) {
|
} else if (config.isComposio) {
|
||||||
|
|
|
||||||
|
|
@ -160,6 +160,7 @@ export const ApiRequest = z.object({
|
||||||
state: z.unknown(),
|
state: z.unknown(),
|
||||||
workflowId: z.string().nullable().optional(),
|
workflowId: z.string().nullable().optional(),
|
||||||
testProfileId: z.string().nullable().optional(),
|
testProfileId: z.string().nullable().optional(),
|
||||||
|
mockTools: z.record(z.string(), z.string()).nullable().optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export const ApiResponse = z.object({
|
export const ApiResponse = z.object({
|
||||||
|
|
|
||||||
|
|
@ -69,6 +69,7 @@ export const Workflow = z.object({
|
||||||
createdAt: z.string().datetime(),
|
createdAt: z.string().datetime(),
|
||||||
lastUpdatedAt: z.string().datetime(),
|
lastUpdatedAt: z.string().datetime(),
|
||||||
projectId: z.string(),
|
projectId: z.string(),
|
||||||
|
mockTools: z.record(z.string(), z.string()).optional(), // a dict of toolName => mockInstructions
|
||||||
});
|
});
|
||||||
export const WorkflowTemplate = Workflow
|
export const WorkflowTemplate = Workflow
|
||||||
.omit({
|
.omit({
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue