feat: set calculator as custom tool on demand

This commit is contained in:
Abhishek Kumar 2026-04-02 14:07:03 +05:30
parent 89fce77438
commit f368fe5134
13 changed files with 265 additions and 157 deletions

View file

@ -40,20 +40,14 @@ from api.services.workflow.pipecat_engine_context_composer import (
)
from api.services.workflow.pipecat_engine_custom_tools import (
CustomToolManager,
get_function_schema,
)
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
from api.services.workflow.tools.knowledge_base import (
retrieve_from_knowledge_base,
)
from api.services.workflow.tools.timezone import (
convert_time,
get_current_time,
get_time_tools,
)
from api.services.workflow.tools.timezone import get_current_time
from api.utils.template_renderer import render_template
@ -93,9 +87,6 @@ class PipecatEngine:
# access to _context
self._variable_extraction_manager = None
# Lazy loaded built-in function schemas
self._builtin_function_schemas: Optional[list[dict]] = None
# Track current LLM reference text for TTS aggregation correction
self._current_llm_generation_reference_text: str = ""
@ -144,36 +135,6 @@ class PipecatEngine:
return None
return tracing_ctx.get_turn_context() or tracing_ctx.get_conversation_context()
@property
def builtin_function_schemas(self) -> list[dict]:
"""Get built-in function schemas (calculator and timezone tools)."""
if self._builtin_function_schemas is None:
self._builtin_function_schemas = []
# Transform calculator tools to get_function_schema format
for tool in get_calculator_tools():
func = tool["function"]
schema = get_function_schema(
func["name"],
func["description"],
properties=func["parameters"]["properties"],
required=func["parameters"]["required"],
)
self._builtin_function_schemas.append(schema)
# Transform timezone tools to get_function_schema format
for tool in get_time_tools():
func = tool["function"]
schema = get_function_schema(
func["name"],
func["description"],
properties=func["parameters"]["properties"],
required=func["parameters"]["required"],
)
self._builtin_function_schemas.append(schema)
return self._builtin_function_schemas
async def initialize(self):
# TODO: May be set_node in a separate task so that we return from initialize immediately
if self._initialized:
@ -197,9 +158,6 @@ class PipecatEngine:
except Exception as e:
logger.error(f"Failed to fetch current EST time: {e}")
# Register built-in functions with the LLM
await self._register_builtin_functions()
await self.set_node(self.workflow.start_node_id)
logger.debug(f"{self.__class__.__name__} initialized")
@ -316,57 +274,6 @@ class PipecatEngine:
cancel_on_interruption=False,
)
async def _register_builtin_functions(self):
"""Register built-in functions (calculator and timezone) with the LLM."""
logger.debug("Registering built-in functions with LLM")
# Register calculator function
async def calculate_func(function_call_params: FunctionCallParams) -> None:
logger.info(f"LLM Function Call EXECUTED: safe_calculator")
logger.info(f"Arguments: {function_call_params.arguments}")
try:
expr = function_call_params.arguments.get("expression", "")
result = safe_calculator(expr)
await function_call_params.result_callback(
{"expression": expr, "result": result}
)
except Exception as e:
await function_call_params.result_callback({"error": str(e)})
# Register timezone functions
async def get_current_time_func(
function_call_params: FunctionCallParams,
) -> None:
logger.info(f"LLM Function Call EXECUTED: get_current_time")
logger.info(f"Arguments: {function_call_params.arguments}")
try:
timezone = function_call_params.arguments.get("timezone", "UTC")
result = get_current_time(timezone)
await function_call_params.result_callback(result)
except Exception as e:
await function_call_params.result_callback({"error": str(e)})
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
logger.info(f"LLM Function Call EXECUTED: convert_time")
logger.info(f"Arguments: {function_call_params.arguments}")
try:
result = convert_time(
function_call_params.arguments.get("source_timezone"),
function_call_params.arguments.get("time"),
function_call_params.arguments.get("target_timezone"),
)
await function_call_params.result_callback(result)
except Exception as e:
await function_call_params.result_callback({"error": str(e)})
# Register all built-in functions
self.llm.register_function("safe_calculator", calculate_func)
self.llm.register_function("get_current_time", get_current_time_func)
self.llm.register_function("convert_time", convert_time_func)
async def _register_knowledge_base_function(
self, document_uuids: list[str]
) -> None:
@ -553,7 +460,6 @@ class PipecatEngine:
)
functions = await compose_functions_for_node(
node=node,
builtin_function_schemas=self.builtin_function_schemas,
custom_tool_manager=self._custom_tool_manager,
)
await self._update_llm_context(system_prompt, functions)