mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
Fix/multiple generation (#104)
* fixes #100 * Fix test * fix: fix bad configuration issue
This commit is contained in:
parent
90b690efff
commit
56953bbd09
18 changed files with 758 additions and 460 deletions
|
|
@ -13,9 +13,8 @@ from pipecat.frames.frames import (
|
|||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallResultProperties,
|
||||
FunctionCallsFromLLMInfoFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
|
|
@ -104,7 +103,7 @@ class PipecatEngine:
|
|||
self._builtin_function_schemas: Optional[list[dict]] = None
|
||||
|
||||
# Track current LLM reference text for TTS aggregation correction
|
||||
self._current_llm_reference_text: str = ""
|
||||
self._current_llm_generation_reference_text: str = ""
|
||||
|
||||
# Custom tool manager (initialized in initialize())
|
||||
self._custom_tool_manager: Optional[CustomToolManager] = None
|
||||
|
|
@ -173,6 +172,9 @@ class PipecatEngine:
|
|||
await self._register_builtin_functions()
|
||||
|
||||
await self.set_node(self.workflow.start_node_id)
|
||||
|
||||
# Trigger initial LLM generation
|
||||
await self.task.queue_frame(LLMContextFrame(self.context))
|
||||
logger.debug(f"{self.__class__.__name__} initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
|
||||
|
|
@ -218,7 +220,6 @@ class PipecatEngine:
|
|||
result = {"status": "done"}
|
||||
|
||||
properties = FunctionCallResultProperties(
|
||||
run_llm=False,
|
||||
on_context_updated=on_context_updated,
|
||||
)
|
||||
|
||||
|
|
@ -256,8 +257,6 @@ class PipecatEngine:
|
|||
"""Register built-in functions (calculator and timezone) with the LLM."""
|
||||
logger.debug("Registering built-in functions with LLM")
|
||||
|
||||
properties = FunctionCallResultProperties(run_llm=True)
|
||||
|
||||
# Register calculator function
|
||||
async def calculate_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: safe_calculator")
|
||||
|
|
@ -266,12 +265,10 @@ class PipecatEngine:
|
|||
expr = function_call_params.arguments.get("expression", "")
|
||||
result = safe_calculator(expr)
|
||||
await function_call_params.result_callback(
|
||||
{"expression": expr, "result": result}, properties=properties
|
||||
{"expression": expr, "result": result}
|
||||
)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e)}, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback({"error": str(e)})
|
||||
|
||||
# Register timezone functions
|
||||
async def get_current_time_func(
|
||||
|
|
@ -282,13 +279,9 @@ class PipecatEngine:
|
|||
try:
|
||||
timezone = function_call_params.arguments.get("timezone", "UTC")
|
||||
result = get_current_time(timezone)
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback(result)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e)}, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback({"error": str(e)})
|
||||
|
||||
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: convert_time")
|
||||
|
|
@ -299,29 +292,15 @@ class PipecatEngine:
|
|||
function_call_params.arguments.get("time"),
|
||||
function_call_params.arguments.get("target_timezone"),
|
||||
)
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback(result)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e)}, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback({"error": str(e)})
|
||||
|
||||
# Register all built-in functions
|
||||
self.llm.register_function("safe_calculator", calculate_func)
|
||||
self.llm.register_function("get_current_time", get_current_time_func)
|
||||
self.llm.register_function("convert_time", convert_time_func)
|
||||
|
||||
async def _queue_tts_response(self, text: str) -> None:
|
||||
"""Queue TTS frames for static text response."""
|
||||
await self.task.queue_frames(
|
||||
[
|
||||
LLMFullResponseStartFrame(),
|
||||
TTSSpeakFrame(text=text),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
async def _perform_variable_extraction_if_needed(
|
||||
self, previous_node: Optional[Node]
|
||||
) -> None:
|
||||
|
|
@ -384,7 +363,6 @@ class PipecatEngine:
|
|||
functions,
|
||||
) = await self._compose_system_message_functions_for_node(node)
|
||||
await self._update_llm_context(system_message, functions)
|
||||
await self.task.queue_frame(LLMContextFrame(self.context))
|
||||
|
||||
async def set_node(self, node_id: str):
|
||||
"""
|
||||
|
|
@ -733,7 +711,7 @@ class PipecatEngine:
|
|||
|
||||
async def handle_llm_text_frame(self, text: str):
|
||||
"""Accumulate LLM text frames to build reference text."""
|
||||
self._current_llm_reference_text += text
|
||||
self._current_llm_generation_reference_text += text
|
||||
|
||||
def handle_client_disconnected(self):
|
||||
"""Handle client disconnected event."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue