fix: fix OPENAI_API_KEY bug in retrieval

This commit is contained in:
Abhishek Kumar 2026-01-17 18:12:56 +05:30
parent 692ef27751
commit d35eeb1b7b
11 changed files with 508 additions and 115 deletions

View file

@ -6,10 +6,12 @@ from .embedding import (
OpenAIEmbeddingService,
SentenceTransformerEmbeddingService,
)
from .json_parser import parse_llm_json
__all__ = [
"BaseEmbeddingService",
"EmbeddingAPIKeyNotConfiguredError",
"SentenceTransformerEmbeddingService",
"OpenAIEmbeddingService",
"parse_llm_json",
]

View file

@ -0,0 +1,154 @@
"""Robust JSON parser for handling common LLM output mistakes."""
from __future__ import annotations
import json
import re
from typing import Any
def parse_llm_json(raw_content: str) -> dict[str, Any]:
"""Parse JSON from LLM output, handling common formatting issues.
Handles the following common LLM mistakes:
1. JSON wrapped in markdown code blocks (```json ... ``` or ``` ... ```)
2. Extra whitespace or newlines around JSON
3. Text before/after the JSON object
Args:
raw_content: The raw string output from the LLM.
Returns:
Parsed JSON as a dictionary. If parsing fails, returns {"raw": raw_content}.
"""
if not raw_content or not raw_content.strip():
return {}
content = raw_content.strip()
# Attempt 1: Direct parse (ideal case)
parsed = _try_parse_json(content)
if parsed is not None:
return parsed
# Attempt 2: Remove markdown code block wrappers
# Matches ```json ... ``` or ``` ... ```
code_block_pattern = r"```(?:json)?\s*([\s\S]*?)\s*```"
code_block_match = re.search(code_block_pattern, content)
if code_block_match:
extracted = code_block_match.group(1).strip()
parsed = _try_parse_json(extracted)
if parsed is not None:
return parsed
# Attempt 3: Find JSON object by matching braces
parsed = _extract_json_object(content)
if parsed is not None:
return parsed
# Attempt 4: Find JSON array by matching brackets
parsed = _extract_json_array(content)
if parsed is not None:
return parsed
# All attempts failed - return raw content
return {"raw": raw_content}
def _try_parse_json(content: str) -> dict[str, Any] | list | None:
"""Attempt to parse JSON, returning None on failure."""
try:
result = json.loads(content)
if isinstance(result, (dict, list)):
return result
return None
except json.JSONDecodeError:
return None
def _extract_json_object(content: str) -> dict[str, Any] | None:
"""Extract a JSON object from text by finding matching braces."""
# Find the first opening brace
start = content.find("{")
if start == -1:
return None
# Find matching closing brace by counting braces
depth = 0
in_string = False
escape_next = False
end = -1
for i, char in enumerate(content[start:], start=start):
if escape_next:
escape_next = False
continue
if char == "\\":
escape_next = True
continue
if char == '"' and not escape_next:
in_string = not in_string
continue
if in_string:
continue
if char == "{":
depth += 1
elif char == "}":
depth -= 1
if depth == 0:
end = i
break
if end == -1:
return None
json_str = content[start : end + 1]
return _try_parse_json(json_str)
def _extract_json_array(content: str) -> list | None:
"""Extract a JSON array from text by finding matching brackets."""
# Find the first opening bracket
start = content.find("[")
if start == -1:
return None
# Find matching closing bracket by counting brackets
depth = 0
in_string = False
escape_next = False
end = -1
for i, char in enumerate(content[start:], start=start):
if escape_next:
escape_next = False
continue
if char == "\\":
escape_next = True
continue
if char == '"' and not escape_next:
in_string = not in_string
continue
if in_string:
continue
if char == "[":
depth += 1
elif char == "]":
depth -= 1
if depth == 0:
end = i
break
if end == -1:
return None
json_str = content[start : end + 1]
return _try_parse_json(json_str)

View file

@ -29,7 +29,6 @@ from api.services.pipecat.service_factory import (
create_llm_service,
create_stt_service,
create_tts_service,
create_voicemail_classification_llm,
)
from api.services.pipecat.tracing_config import setup_pipeline_tracing
from api.services.pipecat.transport_setup import (
@ -501,12 +500,21 @@ async def _run_pipeline(
node_transition_callback = send_node_transition
# Extract embeddings configuration from user config
embeddings_api_key = None
embeddings_model = None
if user_config and user_config.embeddings:
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
engine = PipecatEngine(
llm=llm,
workflow=workflow_graph,
call_context_vars=merged_call_context_vars,
workflow_run_id=workflow_run_id,
node_transition_callback=node_transition_callback,
embeddings_api_key=embeddings_api_key,
embeddings_model=embeddings_model,
)
# Create pipeline components with audio configuration and engine
@ -562,24 +570,23 @@ async def _run_pipeline(
voicemail_detector = None
start_node = workflow_graph.nodes.get(workflow_graph.start_node_id)
if start_node and start_node.detect_voicemail:
classification_llm = create_voicemail_classification_llm()
if classification_llm:
logger.info(
f"Voicemail detection enabled for workflow run {workflow_run_id}"
)
voicemail_detector = VoicemailDetector(
llm=classification_llm,
voicemail_response_delay=2.0,
)
logger.info(f"Voicemail detection enabled for workflow run {workflow_run_id}")
# Create a separate LLM instance for the voicemail sub-pipeline
# (can't share with main pipeline as it would mess up frame linking)
voicemail_llm = create_llm_service(user_config)
voicemail_detector = VoicemailDetector(
llm=voicemail_llm,
voicemail_response_delay=2.0,
)
# Register event handler to end task when voicemail is detected
@voicemail_detector.event_handler("on_voicemail_detected")
async def _on_voicemail_detected(_processor):
logger.info(f"Voicemail detected for workflow run {workflow_run_id}")
await engine.send_end_task_frame(
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
abort_immediately=True,
)
# Register event handler to end task when voicemail is detected
@voicemail_detector.event_handler("on_voicemail_detected")
async def _on_voicemail_detected(_processor):
logger.info(f"Voicemail detected for workflow run {workflow_run_id}")
await engine.send_end_task_frame(
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
abort_immediately=True,
)
# Build the pipeline with the STT mute filter and context controller
pipeline = build_pipeline(

View file

@ -1,4 +1,3 @@
import os
from typing import TYPE_CHECKING
from fastapi import HTTPException
@ -242,24 +241,3 @@ def create_llm_service(user_config):
)
else:
raise HTTPException(status_code=400, detail="Invalid LLM provider")
def create_voicemail_classification_llm():
"""Create a fast, lightweight LLM service for voicemail classification.
Uses gpt-4o-mini which is fast and cost-effective for simple classification tasks.
The model only needs to output "CONVERSATION" or "VOICEMAIL" based on transcriptions.
Returns:
OpenAILLMService instance, or None if OPENAI_API_KEY is not set.
"""
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
logger.warning("OPENAI_API_KEY not set - voicemail detection will be disabled")
return None
return OpenAILLMService(
api_key=api_key,
model="gpt-4o",
params=OpenAILLMService.InputParams(temperature=0.0),
)

View file

@ -278,7 +278,9 @@ class TelephonyProvider(ABC):
@staticmethod
@abstractmethod
async def generate_inbound_response(websocket_url: str, workflow_run_id: int = None) -> tuple:
async def generate_inbound_response(
websocket_url: str, workflow_run_id: int = None
) -> tuple:
"""
Generate the appropriate response for an inbound webhook.

View file

@ -434,29 +434,37 @@ class CloudonixProvider(TelephonyProvider):
user_agent = headers.get("user-agent", "").lower()
if "cloudonix" in user_agent:
return True
# 2: Check for Cloudonix-specific headers
cloudonix_headers = ["x-cx-apikey", "x-cx-domain", "x-cx-session", "x-cx-source"]
cloudonix_headers = [
"x-cx-apikey",
"x-cx-domain",
"x-cx-session",
"x-cx-source",
]
if any(header in headers for header in cloudonix_headers):
return True
# 3: Check data structure for Cloudonix-specific fields
if ("SessionData" in webhook_data and "Domain" in webhook_data and
webhook_data.get("Domain", "").endswith(".cloudonix.net")):
if (
"SessionData" in webhook_data
and "Domain" in webhook_data
and webhook_data.get("Domain", "").endswith(".cloudonix.net")
):
return True
# Check if AccountSid is a Cloudonix domain
account_sid = webhook_data.get("AccountSid", "")
if account_sid.endswith(".cloudonix.net"):
return True
return False
@staticmethod
def parse_inbound_webhook(webhook_data: Dict[str, Any]) -> NormalizedInboundData:
"""
Parse Cloudonix-specific inbound webhook data into normalized format.
Cloudonix webhook structure includes:
- CallSid: Call id
- From: Caller number
@ -467,13 +475,11 @@ class CloudonixProvider(TelephonyProvider):
session_data = webhook_data.get("SessionData", {})
token = session_data.get("token", "") if isinstance(session_data, dict) else ""
call_id = (webhook_data.get("Session") or
webhook_data.get("CallSid") or
token)
account_id = (webhook_data.get("Domain") or webhook_data.get("AccountSid", ""))
call_id = webhook_data.get("Session") or webhook_data.get("CallSid") or token
account_id = webhook_data.get("Domain") or webhook_data.get("AccountSid", "")
# Extract underlying provider information from SessionData if available
session_data = webhook_data.get("SessionData", {})
underlying_provider = None
@ -482,7 +488,7 @@ class CloudonixProvider(TelephonyProvider):
trunk_headers = profile.get("trunk-sip-headers", {})
if "Twilio-AccountSid" in trunk_headers:
underlying_provider = "twilio"
return NormalizedInboundData(
provider=CloudonixProvider.PROVIDER_NAME,
call_id=call_id,
@ -492,7 +498,7 @@ class CloudonixProvider(TelephonyProvider):
call_status=webhook_data.get("CallStatus", "in-progress"),
account_id=account_id,
from_country=webhook_data.get("FromCountry"),
to_country=webhook_data.get("ToCountry"),
to_country=webhook_data.get("ToCountry"),
raw_data={
**webhook_data,
"underlying_provider": underlying_provider,
@ -503,9 +509,9 @@ class CloudonixProvider(TelephonyProvider):
def validate_account_id(config_data: dict, webhook_account_id: str) -> bool:
"""
Validate that the account_id from webhook matches the Cloudonix configuration.
For Cloudonix:
- webhook_account_id is the Domain field (e.g., "test1.cloudonix.net")
- webhook_account_id is the Domain field (e.g., "test1.cloudonix.net")
- config domain_id stores the same domain string
"""
if not webhook_account_id:
@ -551,30 +557,30 @@ class CloudonixProvider(TelephonyProvider):
) -> bool:
"""
Verify the API key of an inbound Cloudonix webhook for security.
Cloudonix uses x-cx-apikey header validation instead of signature verification.
The API key from the webhook should match the bearer_token in our configuration.
"""
if not api_key:
logger.warning("No x-cx-apikey provided in Cloudonix webhook")
return False
# The bearer_token in config is the same as x-cx-apikey header value
if not self.bearer_token:
logger.warning("No bearer_token configured for Cloudonix provider")
return False
# Compare the API keys
is_valid = api_key == self.bearer_token
if is_valid:
logger.info("Cloudonix x-cx-apikey validation successful")
else:
logger.warning(f"Cloudonix x-cx-apikey validation failed. Expected key ending with ...{self.bearer_token[-8:] if len(self.bearer_token) > 8 else 'SHORT_KEY'}")
return True #TODO: update this post clarification from cloudonix
logger.warning(
f"Cloudonix x-cx-apikey validation failed. Expected key ending with ...{self.bearer_token[-8:] if len(self.bearer_token) > 8 else 'SHORT_KEY'}"
)
return True # TODO: update this post clarification from cloudonix
@staticmethod
async def generate_inbound_response(
@ -582,11 +588,11 @@ class CloudonixProvider(TelephonyProvider):
) -> tuple:
"""
Generate the appropriate CXML response for an inbound Cloudonix webhook.
Returns CXML to connect to WebSocket, same format as outbound calls.
"""
from fastapi import Response
# Generate CXML response (same format as outbound calls)
cxml_content = f"""<?xml version="1.0" encoding="UTF-8"?>
<Response>
@ -595,26 +601,23 @@ class CloudonixProvider(TelephonyProvider):
</Connect>
<Pause length="40"/>
</Response>"""
logger.info(f"Cloudonix inbound CXML response content:")
logger.info(cxml_content)
response = Response(
content=cxml_content,
media_type="application/xml"
)
response = Response(content=cxml_content, media_type="application/xml")
logger.info(f"Cloudonix inbound response object: {response}")
logger.info(f"Response headers: {response.headers}")
logger.info(f"Response media type: {response.media_type}")
return response
@staticmethod
def generate_validation_error_response(error_type) -> tuple:
"""
Generate Cloudonix-specific error response for validation failures.
Since Cloudonix is TwiML-compatible, we use the same XML format.
"""
from fastapi import Response

View file

@ -297,7 +297,7 @@ class TwilioProvider(TelephonyProvider):
) -> bool:
"""
Determine if this provider can handle the incoming webhook.
Twilio webhooks have specific characteristics:
- User-Agent: "TwilioProxy/1.1"
- Headers: "x-twilio-signature", "i-twilio-idempotency-token"
@ -308,21 +308,27 @@ class TwilioProvider(TelephonyProvider):
user_agent = headers.get("user-agent", "")
if "twilioproxy" in user_agent.lower() or user_agent.startswith("TwilioProxy"):
return True
# 2: Check for Twilio-specific headers
twilio_headers = ["x-twilio-signature", "i-twilio-idempotency-token", "x-home-region"]
twilio_headers = [
"x-twilio-signature",
"i-twilio-idempotency-token",
"x-home-region",
]
if any(header in headers for header in twilio_headers):
return True
# 3: Check data structure - CallSid + AccountSid with AC prefix + ApiVersion
if ("CallSid" in webhook_data and
"AccountSid" in webhook_data and
"ApiVersion" in webhook_data):
if (
"CallSid" in webhook_data
and "AccountSid" in webhook_data
and "ApiVersion" in webhook_data
):
# Ensure AccountSid looks like Twilio (starts with AC, not a domain)
account_sid = webhook_data.get("AccountSid", "")
if account_sid.startswith("AC") and not "." in account_sid:
return True
return False
@staticmethod

View file

@ -69,6 +69,8 @@ class PipecatEngine:
node_transition_callback: Optional[
Callable[[str, Optional[str]], Awaitable[None]]
] = None,
embeddings_api_key: Optional[str] = None,
embeddings_model: Optional[str] = None,
):
self.task = task
self.llm = llm
@ -103,6 +105,10 @@ class PipecatEngine:
# Custom tool manager (initialized in initialize())
self._custom_tool_manager: Optional[CustomToolManager] = None
# Embeddings configuration (passed from run_pipeline.py)
self._embeddings_api_key: Optional[str] = embeddings_api_key
self._embeddings_model: Optional[str] = embeddings_model
async def _get_organization_id(self) -> Optional[int]:
"""Get and cache the organization ID from workflow run."""
if self._custom_tool_manager:
@ -318,11 +324,19 @@ class PipecatEngine:
"Organization ID not available for knowledge base retrieval"
)
if not self._embeddings_api_key:
raise ValueError(
"Embeddings API key not configured. Please set your API key in "
"Model Configurations > Embedding."
)
result = await retrieve_from_knowledge_base(
query=query,
organization_id=organization_id,
document_uuids=document_uuids,
limit=3, # Return top 3 most relevant chunks
embeddings_api_key=self._embeddings_api_key,
embeddings_model=self._embeddings_model,
)
await function_call_params.result_callback(result)

View file

@ -1,13 +1,11 @@
from __future__ import annotations
import json
import os
from typing import TYPE_CHECKING, Any, List
from loguru import logger
from openai import AsyncOpenAI
from opentelemetry import trace
from api.services.gen_ai.json_parser import parse_llm_json
from api.services.pipecat.tracing_config import is_tracing_enabled
from api.services.workflow.dto import ExtractionVariableDTO
from pipecat.processors.aggregators.llm_context import LLMContext
@ -32,7 +30,6 @@ class VariableExtractionManager:
# and update internal counters / extracted variable state.
self._engine = engine
self._context = engine.context
self._model = "gpt-4o"
# ------------------------------------------------------------------
# Internal helpers
@ -147,46 +144,43 @@ class VariableExtractionManager:
extraction_context.set_messages(extraction_messages)
# ------------------------------------------------------------------
# Use independent OpenAI client for LLM call
# Use engine's LLM for out-of-band inference (no pipeline frames)
# ------------------------------------------------------------------
client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
llm_response = await self._engine.llm.run_inference(extraction_context)
# Direct API call - no pipeline involvement
response = await client.chat.completions.create(
model=self._model,
messages=extraction_messages,
temperature=0.0,
response_format={"type": "json_object"},
)
llm_response = response.choices[0].message.content
# Get model name for tracing
model_name = getattr(self._engine.llm, "model_name", "unknown")
if is_tracing_enabled():
tracer = trace.get_tracer("pipecat")
with tracer.start_as_current_span(
"variable_extraction", context=parent_ctx
"llm-variable-extraction", context=parent_ctx
) as span:
add_llm_span_attributes(
span,
service_name="OpenAILLMService",
model=self._model,
operation_name="variable_extraction",
service_name=self._engine.llm.__class__.__name__,
model=model_name,
operation_name="llm-variable-extraction",
messages=extraction_messages,
output=llm_response,
stream=False,
parameters={"temperature": 0.0, "response_format": "json_object"},
parameters={},
)
# ------------------------------------------------------------------
# Parse the assistant output fall back to raw text if it is not valid JSON.
# Uses parse_llm_json which handles common LLM mistakes like markdown
# code blocks (```json ... ```) and extra text around the JSON.
# ------------------------------------------------------------------
try:
extracted = json.loads(llm_response)
except json.JSONDecodeError:
logger.warning(
"Extractor returned invalid JSON; storing raw content instead."
)
extracted = {"raw": llm_response}
if llm_response is None:
logger.warning("Extractor returned no response; returning empty result.")
extracted = {}
else:
extracted = parse_llm_json(llm_response)
if "raw" in extracted and len(extracted) == 1:
logger.warning(
"Extractor returned invalid JSON; storing raw content instead."
)
logger.debug(f"Extracted variables: {extracted}")
return extracted