mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
fix: fix OPENAI_API_KEY bug in retrieval
This commit is contained in:
parent
692ef27751
commit
d35eeb1b7b
11 changed files with 508 additions and 115 deletions
|
|
@ -69,6 +69,8 @@ class PipecatEngine:
|
|||
node_transition_callback: Optional[
|
||||
Callable[[str, Optional[str]], Awaitable[None]]
|
||||
] = None,
|
||||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
):
|
||||
self.task = task
|
||||
self.llm = llm
|
||||
|
|
@ -103,6 +105,10 @@ class PipecatEngine:
|
|||
# Custom tool manager (initialized in initialize())
|
||||
self._custom_tool_manager: Optional[CustomToolManager] = None
|
||||
|
||||
# Embeddings configuration (passed from run_pipeline.py)
|
||||
self._embeddings_api_key: Optional[str] = embeddings_api_key
|
||||
self._embeddings_model: Optional[str] = embeddings_model
|
||||
|
||||
async def _get_organization_id(self) -> Optional[int]:
|
||||
"""Get and cache the organization ID from workflow run."""
|
||||
if self._custom_tool_manager:
|
||||
|
|
@ -318,11 +324,19 @@ class PipecatEngine:
|
|||
"Organization ID not available for knowledge base retrieval"
|
||||
)
|
||||
|
||||
if not self._embeddings_api_key:
|
||||
raise ValueError(
|
||||
"Embeddings API key not configured. Please set your API key in "
|
||||
"Model Configurations > Embedding."
|
||||
)
|
||||
|
||||
result = await retrieve_from_knowledge_base(
|
||||
query=query,
|
||||
organization_id=organization_id,
|
||||
document_uuids=document_uuids,
|
||||
limit=3, # Return top 3 most relevant chunks
|
||||
embeddings_api_key=self._embeddings_api_key,
|
||||
embeddings_model=self._embeddings_model,
|
||||
)
|
||||
|
||||
await function_call_params.result_callback(result)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from opentelemetry import trace
|
||||
|
||||
from api.services.gen_ai.json_parser import parse_llm_json
|
||||
from api.services.pipecat.tracing_config import is_tracing_enabled
|
||||
from api.services.workflow.dto import ExtractionVariableDTO
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
|
@ -32,7 +30,6 @@ class VariableExtractionManager:
|
|||
# and update internal counters / extracted variable state.
|
||||
self._engine = engine
|
||||
self._context = engine.context
|
||||
self._model = "gpt-4o"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
|
|
@ -147,46 +144,43 @@ class VariableExtractionManager:
|
|||
extraction_context.set_messages(extraction_messages)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Use independent OpenAI client for LLM call
|
||||
# Use engine's LLM for out-of-band inference (no pipeline frames)
|
||||
# ------------------------------------------------------------------
|
||||
client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
llm_response = await self._engine.llm.run_inference(extraction_context)
|
||||
|
||||
# Direct API call - no pipeline involvement
|
||||
response = await client.chat.completions.create(
|
||||
model=self._model,
|
||||
messages=extraction_messages,
|
||||
temperature=0.0,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
llm_response = response.choices[0].message.content
|
||||
# Get model name for tracing
|
||||
model_name = getattr(self._engine.llm, "model_name", "unknown")
|
||||
|
||||
if is_tracing_enabled():
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
with tracer.start_as_current_span(
|
||||
"variable_extraction", context=parent_ctx
|
||||
"llm-variable-extraction", context=parent_ctx
|
||||
) as span:
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name="OpenAILLMService",
|
||||
model=self._model,
|
||||
operation_name="variable_extraction",
|
||||
service_name=self._engine.llm.__class__.__name__,
|
||||
model=model_name,
|
||||
operation_name="llm-variable-extraction",
|
||||
messages=extraction_messages,
|
||||
output=llm_response,
|
||||
stream=False,
|
||||
parameters={"temperature": 0.0, "response_format": "json_object"},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Parse the assistant output – fall back to raw text if it is not valid JSON.
|
||||
# Uses parse_llm_json which handles common LLM mistakes like markdown
|
||||
# code blocks (```json ... ```) and extra text around the JSON.
|
||||
# ------------------------------------------------------------------
|
||||
try:
|
||||
extracted = json.loads(llm_response)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Extractor returned invalid JSON; storing raw content instead."
|
||||
)
|
||||
extracted = {"raw": llm_response}
|
||||
if llm_response is None:
|
||||
logger.warning("Extractor returned no response; returning empty result.")
|
||||
extracted = {}
|
||||
else:
|
||||
extracted = parse_llm_json(llm_response)
|
||||
if "raw" in extracted and len(extracted) == 1:
|
||||
logger.warning(
|
||||
"Extractor returned invalid JSON; storing raw content instead."
|
||||
)
|
||||
|
||||
logger.debug(f"Extracted variables: {extracted}")
|
||||
return extracted
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue