fix: fix OPENAI_API_KEY bug in retrieval

This commit is contained in:
Abhishek Kumar 2026-01-17 18:12:56 +05:30
parent 692ef27751
commit d35eeb1b7b
11 changed files with 508 additions and 115 deletions

View file

@ -69,6 +69,8 @@ class PipecatEngine:
node_transition_callback: Optional[
Callable[[str, Optional[str]], Awaitable[None]]
] = None,
embeddings_api_key: Optional[str] = None,
embeddings_model: Optional[str] = None,
):
self.task = task
self.llm = llm
@ -103,6 +105,10 @@ class PipecatEngine:
# Custom tool manager (initialized in initialize())
self._custom_tool_manager: Optional[CustomToolManager] = None
# Embeddings configuration (passed from run_pipeline.py)
self._embeddings_api_key: Optional[str] = embeddings_api_key
self._embeddings_model: Optional[str] = embeddings_model
async def _get_organization_id(self) -> Optional[int]:
"""Get and cache the organization ID from workflow run."""
if self._custom_tool_manager:
@ -318,11 +324,19 @@ class PipecatEngine:
"Organization ID not available for knowledge base retrieval"
)
if not self._embeddings_api_key:
raise ValueError(
"Embeddings API key not configured. Please set your API key in "
"Model Configurations > Embedding."
)
result = await retrieve_from_knowledge_base(
query=query,
organization_id=organization_id,
document_uuids=document_uuids,
limit=3, # Return top 3 most relevant chunks
embeddings_api_key=self._embeddings_api_key,
embeddings_model=self._embeddings_model,
)
await function_call_params.result_callback(result)

View file

@ -1,13 +1,11 @@
from __future__ import annotations
import json
import os
from typing import TYPE_CHECKING, Any, List
from loguru import logger
from openai import AsyncOpenAI
from opentelemetry import trace
from api.services.gen_ai.json_parser import parse_llm_json
from api.services.pipecat.tracing_config import is_tracing_enabled
from api.services.workflow.dto import ExtractionVariableDTO
from pipecat.processors.aggregators.llm_context import LLMContext
@ -32,7 +30,6 @@ class VariableExtractionManager:
# and update internal counters / extracted variable state.
self._engine = engine
self._context = engine.context
self._model = "gpt-4o"
# ------------------------------------------------------------------
# Internal helpers
@ -147,46 +144,43 @@ class VariableExtractionManager:
extraction_context.set_messages(extraction_messages)
# ------------------------------------------------------------------
# Use independent OpenAI client for LLM call
# Use engine's LLM for out-of-band inference (no pipeline frames)
# ------------------------------------------------------------------
client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
llm_response = await self._engine.llm.run_inference(extraction_context)
# Direct API call - no pipeline involvement
response = await client.chat.completions.create(
model=self._model,
messages=extraction_messages,
temperature=0.0,
response_format={"type": "json_object"},
)
llm_response = response.choices[0].message.content
# Get model name for tracing
model_name = getattr(self._engine.llm, "model_name", "unknown")
if is_tracing_enabled():
tracer = trace.get_tracer("pipecat")
with tracer.start_as_current_span(
"variable_extraction", context=parent_ctx
"llm-variable-extraction", context=parent_ctx
) as span:
add_llm_span_attributes(
span,
service_name="OpenAILLMService",
model=self._model,
operation_name="variable_extraction",
service_name=self._engine.llm.__class__.__name__,
model=model_name,
operation_name="llm-variable-extraction",
messages=extraction_messages,
output=llm_response,
stream=False,
parameters={"temperature": 0.0, "response_format": "json_object"},
parameters={},
)
# ------------------------------------------------------------------
# Parse the assistant output; fall back to raw text if it is not valid JSON.
# Uses parse_llm_json which handles common LLM mistakes like markdown
# code blocks (```json ... ```) and extra text around the JSON.
# ------------------------------------------------------------------
try:
extracted = json.loads(llm_response)
except json.JSONDecodeError:
logger.warning(
"Extractor returned invalid JSON; storing raw content instead."
)
extracted = {"raw": llm_response}
if llm_response is None:
logger.warning("Extractor returned no response; returning empty result.")
extracted = {}
else:
extracted = parse_llm_json(llm_response)
if "raw" in extracted and len(extracted) == 1:
logger.warning(
"Extractor returned invalid JSON; storing raw content instead."
)
logger.debug(f"Extracted variables: {extracted}")
return extracted