Fix/multiple generation (#104)

* fixes #100

* Fix test

* fix: fix bad configuration issue
This commit is contained in:
Abhishek 2026-01-03 12:59:18 +05:30 committed by GitHub
parent 90b690efff
commit 56953bbd09
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 758 additions and 460 deletions

View file

@ -13,9 +13,8 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
FunctionCallResultProperties,
FunctionCallsFromLLMInfoFrame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
@ -104,7 +103,7 @@ class PipecatEngine:
self._builtin_function_schemas: Optional[list[dict]] = None
# Track current LLM reference text for TTS aggregation correction
self._current_llm_reference_text: str = ""
self._current_llm_generation_reference_text: str = ""
# Custom tool manager (initialized in initialize())
self._custom_tool_manager: Optional[CustomToolManager] = None
@ -173,6 +172,9 @@ class PipecatEngine:
await self._register_builtin_functions()
await self.set_node(self.workflow.start_node_id)
# Trigger initial LLM generation
await self.task.queue_frame(LLMContextFrame(self.context))
logger.debug(f"{self.__class__.__name__} initialized")
except Exception as e:
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
@ -218,7 +220,6 @@ class PipecatEngine:
result = {"status": "done"}
properties = FunctionCallResultProperties(
run_llm=False,
on_context_updated=on_context_updated,
)
@ -256,8 +257,6 @@ class PipecatEngine:
"""Register built-in functions (calculator and timezone) with the LLM."""
logger.debug("Registering built-in functions with LLM")
properties = FunctionCallResultProperties(run_llm=True)
# Register calculator function
async def calculate_func(function_call_params: FunctionCallParams) -> None:
logger.info(f"LLM Function Call EXECUTED: safe_calculator")
@ -266,12 +265,10 @@ class PipecatEngine:
expr = function_call_params.arguments.get("expression", "")
result = safe_calculator(expr)
await function_call_params.result_callback(
{"expression": expr, "result": result}, properties=properties
{"expression": expr, "result": result}
)
except Exception as e:
await function_call_params.result_callback(
{"error": str(e)}, properties=properties
)
await function_call_params.result_callback({"error": str(e)})
# Register timezone functions
async def get_current_time_func(
@ -282,13 +279,9 @@ class PipecatEngine:
try:
timezone = function_call_params.arguments.get("timezone", "UTC")
result = get_current_time(timezone)
await function_call_params.result_callback(
result, properties=properties
)
await function_call_params.result_callback(result)
except Exception as e:
await function_call_params.result_callback(
{"error": str(e)}, properties=properties
)
await function_call_params.result_callback({"error": str(e)})
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
logger.info(f"LLM Function Call EXECUTED: convert_time")
@ -299,29 +292,15 @@ class PipecatEngine:
function_call_params.arguments.get("time"),
function_call_params.arguments.get("target_timezone"),
)
await function_call_params.result_callback(
result, properties=properties
)
await function_call_params.result_callback(result)
except Exception as e:
await function_call_params.result_callback(
{"error": str(e)}, properties=properties
)
await function_call_params.result_callback({"error": str(e)})
# Register all built-in functions
self.llm.register_function("safe_calculator", calculate_func)
self.llm.register_function("get_current_time", get_current_time_func)
self.llm.register_function("convert_time", convert_time_func)
async def _queue_tts_response(self, text: str) -> None:
"""Queue TTS frames for static text response."""
await self.task.queue_frames(
[
LLMFullResponseStartFrame(),
TTSSpeakFrame(text=text),
LLMFullResponseEndFrame(),
]
)
async def _perform_variable_extraction_if_needed(
self, previous_node: Optional[Node]
) -> None:
@ -384,7 +363,6 @@ class PipecatEngine:
functions,
) = await self._compose_system_message_functions_for_node(node)
await self._update_llm_context(system_message, functions)
await self.task.queue_frame(LLMContextFrame(self.context))
async def set_node(self, node_id: str):
"""
@ -733,7 +711,7 @@ class PipecatEngine:
async def handle_llm_text_frame(self, text: str):
"""Accumulate LLM text frames to build reference text."""
self._current_llm_reference_text += text
self._current_llm_generation_reference_text += text
def handle_client_disconnected(self):
"""Handle client disconnected event."""