mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: set calculator as custom tool on demand
This commit is contained in:
parent
89fce77438
commit
f368fe5134
13 changed files with 265 additions and 157 deletions
|
|
@ -40,20 +40,14 @@ from api.services.workflow.pipecat_engine_context_composer import (
|
|||
)
|
||||
from api.services.workflow.pipecat_engine_custom_tools import (
|
||||
CustomToolManager,
|
||||
get_function_schema,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
|
||||
from api.services.workflow.tools.knowledge_base import (
|
||||
retrieve_from_knowledge_base,
|
||||
)
|
||||
from api.services.workflow.tools.timezone import (
|
||||
convert_time,
|
||||
get_current_time,
|
||||
get_time_tools,
|
||||
)
|
||||
from api.services.workflow.tools.timezone import get_current_time
|
||||
from api.utils.template_renderer import render_template
|
||||
|
||||
|
||||
|
|
@ -93,9 +87,6 @@ class PipecatEngine:
|
|||
# access to _context
|
||||
self._variable_extraction_manager = None
|
||||
|
||||
# Lazy loaded built-in function schemas
|
||||
self._builtin_function_schemas: Optional[list[dict]] = None
|
||||
|
||||
# Track current LLM reference text for TTS aggregation correction
|
||||
self._current_llm_generation_reference_text: str = ""
|
||||
|
||||
|
|
@ -144,36 +135,6 @@ class PipecatEngine:
|
|||
return None
|
||||
return tracing_ctx.get_turn_context() or tracing_ctx.get_conversation_context()
|
||||
|
||||
@property
|
||||
def builtin_function_schemas(self) -> list[dict]:
|
||||
"""Get built-in function schemas (calculator and timezone tools)."""
|
||||
if self._builtin_function_schemas is None:
|
||||
self._builtin_function_schemas = []
|
||||
|
||||
# Transform calculator tools to get_function_schema format
|
||||
for tool in get_calculator_tools():
|
||||
func = tool["function"]
|
||||
schema = get_function_schema(
|
||||
func["name"],
|
||||
func["description"],
|
||||
properties=func["parameters"]["properties"],
|
||||
required=func["parameters"]["required"],
|
||||
)
|
||||
self._builtin_function_schemas.append(schema)
|
||||
|
||||
# Transform timezone tools to get_function_schema format
|
||||
for tool in get_time_tools():
|
||||
func = tool["function"]
|
||||
schema = get_function_schema(
|
||||
func["name"],
|
||||
func["description"],
|
||||
properties=func["parameters"]["properties"],
|
||||
required=func["parameters"]["required"],
|
||||
)
|
||||
self._builtin_function_schemas.append(schema)
|
||||
|
||||
return self._builtin_function_schemas
|
||||
|
||||
async def initialize(self):
|
||||
# TODO: May be set_node in a separate task so that we return from initialize immediately
|
||||
if self._initialized:
|
||||
|
|
@ -197,9 +158,6 @@ class PipecatEngine:
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to fetch current EST time: {e}")
|
||||
|
||||
# Register built-in functions with the LLM
|
||||
await self._register_builtin_functions()
|
||||
|
||||
await self.set_node(self.workflow.start_node_id)
|
||||
|
||||
logger.debug(f"{self.__class__.__name__} initialized")
|
||||
|
|
@ -316,57 +274,6 @@ class PipecatEngine:
|
|||
cancel_on_interruption=False,
|
||||
)
|
||||
|
||||
async def _register_builtin_functions(self):
|
||||
"""Register built-in functions (calculator and timezone) with the LLM."""
|
||||
logger.debug("Registering built-in functions with LLM")
|
||||
|
||||
# Register calculator function
|
||||
async def calculate_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: safe_calculator")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
|
||||
try:
|
||||
expr = function_call_params.arguments.get("expression", "")
|
||||
result = safe_calculator(expr)
|
||||
await function_call_params.result_callback(
|
||||
{"expression": expr, "result": result}
|
||||
)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback({"error": str(e)})
|
||||
|
||||
# Register timezone functions
|
||||
async def get_current_time_func(
|
||||
function_call_params: FunctionCallParams,
|
||||
) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: get_current_time")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
|
||||
try:
|
||||
timezone = function_call_params.arguments.get("timezone", "UTC")
|
||||
result = get_current_time(timezone)
|
||||
await function_call_params.result_callback(result)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback({"error": str(e)})
|
||||
|
||||
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: convert_time")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
|
||||
try:
|
||||
result = convert_time(
|
||||
function_call_params.arguments.get("source_timezone"),
|
||||
function_call_params.arguments.get("time"),
|
||||
function_call_params.arguments.get("target_timezone"),
|
||||
)
|
||||
await function_call_params.result_callback(result)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback({"error": str(e)})
|
||||
|
||||
# Register all built-in functions
|
||||
self.llm.register_function("safe_calculator", calculate_func)
|
||||
self.llm.register_function("get_current_time", get_current_time_func)
|
||||
self.llm.register_function("convert_time", convert_time_func)
|
||||
|
||||
async def _register_knowledge_base_function(
|
||||
self, document_uuids: list[str]
|
||||
) -> None:
|
||||
|
|
@ -553,7 +460,6 @@ class PipecatEngine:
|
|||
)
|
||||
functions = await compose_functions_for_node(
|
||||
node=node,
|
||||
builtin_function_schemas=self.builtin_function_schemas,
|
||||
custom_tool_manager=self._custom_tool_manager,
|
||||
)
|
||||
await self._update_llm_context(system_prompt, functions)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue