mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
fix: fix OPENAI_API_KEY bug in retrieval
This commit is contained in:
parent
692ef27751
commit
d35eeb1b7b
11 changed files with 508 additions and 115 deletions
|
|
@ -1263,7 +1263,9 @@ async def handle_inbound_telephony(
|
|||
|
||||
try:
|
||||
webhook_data, data_source = await parse_webhook_request(request)
|
||||
logger.info(f"Inbound call data with data source: {data_source} and data :{dict(webhook_data)}")
|
||||
logger.info(
|
||||
f"Inbound call data with data source: {data_source} and data :{dict(webhook_data)}"
|
||||
)
|
||||
headers = dict(request.headers)
|
||||
|
||||
# Detect provider and normalize data
|
||||
|
|
|
|||
|
|
@ -6,10 +6,12 @@ from .embedding import (
|
|||
OpenAIEmbeddingService,
|
||||
SentenceTransformerEmbeddingService,
|
||||
)
|
||||
from .json_parser import parse_llm_json
|
||||
|
||||
__all__ = [
|
||||
"BaseEmbeddingService",
|
||||
"EmbeddingAPIKeyNotConfiguredError",
|
||||
"SentenceTransformerEmbeddingService",
|
||||
"OpenAIEmbeddingService",
|
||||
"parse_llm_json",
|
||||
]
|
||||
|
|
|
|||
154
api/services/gen_ai/json_parser.py
Normal file
154
api/services/gen_ai/json_parser.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
"""Robust JSON parser for handling common LLM output mistakes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
def parse_llm_json(raw_content: str) -> dict[str, Any]:
|
||||
"""Parse JSON from LLM output, handling common formatting issues.
|
||||
|
||||
Handles the following common LLM mistakes:
|
||||
1. JSON wrapped in markdown code blocks (```json ... ``` or ``` ... ```)
|
||||
2. Extra whitespace or newlines around JSON
|
||||
3. Text before/after the JSON object
|
||||
|
||||
Args:
|
||||
raw_content: The raw string output from the LLM.
|
||||
|
||||
Returns:
|
||||
Parsed JSON as a dictionary. If parsing fails, returns {"raw": raw_content}.
|
||||
"""
|
||||
if not raw_content or not raw_content.strip():
|
||||
return {}
|
||||
|
||||
content = raw_content.strip()
|
||||
|
||||
# Attempt 1: Direct parse (ideal case)
|
||||
parsed = _try_parse_json(content)
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
|
||||
# Attempt 2: Remove markdown code block wrappers
|
||||
# Matches ```json ... ``` or ``` ... ```
|
||||
code_block_pattern = r"```(?:json)?\s*([\s\S]*?)\s*```"
|
||||
code_block_match = re.search(code_block_pattern, content)
|
||||
if code_block_match:
|
||||
extracted = code_block_match.group(1).strip()
|
||||
parsed = _try_parse_json(extracted)
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
|
||||
# Attempt 3: Find JSON object by matching braces
|
||||
parsed = _extract_json_object(content)
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
|
||||
# Attempt 4: Find JSON array by matching brackets
|
||||
parsed = _extract_json_array(content)
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
|
||||
# All attempts failed - return raw content
|
||||
return {"raw": raw_content}
|
||||
|
||||
|
||||
def _try_parse_json(content: str) -> dict[str, Any] | list | None:
|
||||
"""Attempt to parse JSON, returning None on failure."""
|
||||
try:
|
||||
result = json.loads(content)
|
||||
if isinstance(result, (dict, list)):
|
||||
return result
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_json_object(content: str) -> dict[str, Any] | None:
|
||||
"""Extract a JSON object from text by finding matching braces."""
|
||||
# Find the first opening brace
|
||||
start = content.find("{")
|
||||
if start == -1:
|
||||
return None
|
||||
|
||||
# Find matching closing brace by counting braces
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
end = -1
|
||||
|
||||
for i, char in enumerate(content[start:], start=start):
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
|
||||
if char == "\\":
|
||||
escape_next = True
|
||||
continue
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
in_string = not in_string
|
||||
continue
|
||||
|
||||
if in_string:
|
||||
continue
|
||||
|
||||
if char == "{":
|
||||
depth += 1
|
||||
elif char == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
end = i
|
||||
break
|
||||
|
||||
if end == -1:
|
||||
return None
|
||||
|
||||
json_str = content[start : end + 1]
|
||||
return _try_parse_json(json_str)
|
||||
|
||||
|
||||
def _extract_json_array(content: str) -> list | None:
|
||||
"""Extract a JSON array from text by finding matching brackets."""
|
||||
# Find the first opening bracket
|
||||
start = content.find("[")
|
||||
if start == -1:
|
||||
return None
|
||||
|
||||
# Find matching closing bracket by counting brackets
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
end = -1
|
||||
|
||||
for i, char in enumerate(content[start:], start=start):
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
|
||||
if char == "\\":
|
||||
escape_next = True
|
||||
continue
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
in_string = not in_string
|
||||
continue
|
||||
|
||||
if in_string:
|
||||
continue
|
||||
|
||||
if char == "[":
|
||||
depth += 1
|
||||
elif char == "]":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
end = i
|
||||
break
|
||||
|
||||
if end == -1:
|
||||
return None
|
||||
|
||||
json_str = content[start : end + 1]
|
||||
return _try_parse_json(json_str)
|
||||
|
|
@ -29,7 +29,6 @@ from api.services.pipecat.service_factory import (
|
|||
create_llm_service,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
create_voicemail_classification_llm,
|
||||
)
|
||||
from api.services.pipecat.tracing_config import setup_pipeline_tracing
|
||||
from api.services.pipecat.transport_setup import (
|
||||
|
|
@ -501,12 +500,21 @@ async def _run_pipeline(
|
|||
|
||||
node_transition_callback = send_node_transition
|
||||
|
||||
# Extract embeddings configuration from user config
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
if user_config and user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
|
||||
engine = PipecatEngine(
|
||||
llm=llm,
|
||||
workflow=workflow_graph,
|
||||
call_context_vars=merged_call_context_vars,
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_transition_callback=node_transition_callback,
|
||||
embeddings_api_key=embeddings_api_key,
|
||||
embeddings_model=embeddings_model,
|
||||
)
|
||||
|
||||
# Create pipeline components with audio configuration and engine
|
||||
|
|
@ -562,24 +570,23 @@ async def _run_pipeline(
|
|||
voicemail_detector = None
|
||||
start_node = workflow_graph.nodes.get(workflow_graph.start_node_id)
|
||||
if start_node and start_node.detect_voicemail:
|
||||
classification_llm = create_voicemail_classification_llm()
|
||||
if classification_llm:
|
||||
logger.info(
|
||||
f"Voicemail detection enabled for workflow run {workflow_run_id}"
|
||||
)
|
||||
voicemail_detector = VoicemailDetector(
|
||||
llm=classification_llm,
|
||||
voicemail_response_delay=2.0,
|
||||
)
|
||||
logger.info(f"Voicemail detection enabled for workflow run {workflow_run_id}")
|
||||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
voicemail_detector = VoicemailDetector(
|
||||
llm=voicemail_llm,
|
||||
voicemail_response_delay=2.0,
|
||||
)
|
||||
|
||||
# Register event handler to end task when voicemail is detected
|
||||
@voicemail_detector.event_handler("on_voicemail_detected")
|
||||
async def _on_voicemail_detected(_processor):
|
||||
logger.info(f"Voicemail detected for workflow run {workflow_run_id}")
|
||||
await engine.send_end_task_frame(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
# Register event handler to end task when voicemail is detected
|
||||
@voicemail_detector.event_handler("on_voicemail_detected")
|
||||
async def _on_voicemail_detected(_processor):
|
||||
logger.info(f"Voicemail detected for workflow run {workflow_run_id}")
|
||||
await engine.send_end_task_frame(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
# Build the pipeline with the STT mute filter and context controller
|
||||
pipeline = build_pipeline(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
|
@ -242,24 +241,3 @@ def create_llm_service(user_config):
|
|||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid LLM provider")
|
||||
|
||||
|
||||
def create_voicemail_classification_llm():
    """Create the OpenAI LLM service used for voicemail classification.

    The service is built with the ``gpt-4o`` model and a temperature of 0.0.
    The API key is read from the ``OPENAI_API_KEY`` environment variable.

    Returns:
        OpenAILLMService instance, or None if OPENAI_API_KEY is not set
        (a warning is logged in that case).
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        logger.warning("OPENAI_API_KEY not set - voicemail detection will be disabled")
        return None

    # NOTE(review): the previous docstring advertised gpt-4o-mini while the
    # code requested gpt-4o; the code is left as-is - confirm the intended model.
    return OpenAILLMService(
        api_key=api_key,
        model="gpt-4o",
        params=OpenAILLMService.InputParams(temperature=0.0),
    )
|
||||
|
|
|
|||
|
|
@ -278,7 +278,9 @@ class TelephonyProvider(ABC):
|
|||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
async def generate_inbound_response(websocket_url: str, workflow_run_id: int = None) -> tuple:
|
||||
async def generate_inbound_response(
|
||||
websocket_url: str, workflow_run_id: int = None
|
||||
) -> tuple:
|
||||
"""
|
||||
Generate the appropriate response for an inbound webhook.
|
||||
|
||||
|
|
|
|||
|
|
@ -434,29 +434,37 @@ class CloudonixProvider(TelephonyProvider):
|
|||
user_agent = headers.get("user-agent", "").lower()
|
||||
if "cloudonix" in user_agent:
|
||||
return True
|
||||
|
||||
|
||||
# 2: Check for Cloudonix-specific headers
|
||||
cloudonix_headers = ["x-cx-apikey", "x-cx-domain", "x-cx-session", "x-cx-source"]
|
||||
cloudonix_headers = [
|
||||
"x-cx-apikey",
|
||||
"x-cx-domain",
|
||||
"x-cx-session",
|
||||
"x-cx-source",
|
||||
]
|
||||
if any(header in headers for header in cloudonix_headers):
|
||||
return True
|
||||
|
||||
|
||||
# 3: Check data structure for Cloudonix-specific fields
|
||||
if ("SessionData" in webhook_data and "Domain" in webhook_data and
|
||||
webhook_data.get("Domain", "").endswith(".cloudonix.net")):
|
||||
if (
|
||||
"SessionData" in webhook_data
|
||||
and "Domain" in webhook_data
|
||||
and webhook_data.get("Domain", "").endswith(".cloudonix.net")
|
||||
):
|
||||
return True
|
||||
|
||||
|
||||
# Check if AccountSid is a Cloudonix domain
|
||||
account_sid = webhook_data.get("AccountSid", "")
|
||||
if account_sid.endswith(".cloudonix.net"):
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def parse_inbound_webhook(webhook_data: Dict[str, Any]) -> NormalizedInboundData:
|
||||
"""
|
||||
Parse Cloudonix-specific inbound webhook data into normalized format.
|
||||
|
||||
|
||||
Cloudonix webhook structure includes:
|
||||
- CallSid: Call id
|
||||
- From: Caller number
|
||||
|
|
@ -467,13 +475,11 @@ class CloudonixProvider(TelephonyProvider):
|
|||
|
||||
session_data = webhook_data.get("SessionData", {})
|
||||
token = session_data.get("token", "") if isinstance(session_data, dict) else ""
|
||||
|
||||
call_id = (webhook_data.get("Session") or
|
||||
webhook_data.get("CallSid") or
|
||||
token)
|
||||
|
||||
account_id = (webhook_data.get("Domain") or webhook_data.get("AccountSid", ""))
|
||||
|
||||
|
||||
call_id = webhook_data.get("Session") or webhook_data.get("CallSid") or token
|
||||
|
||||
account_id = webhook_data.get("Domain") or webhook_data.get("AccountSid", "")
|
||||
|
||||
# Extract underlying provider information from SessionData if available
|
||||
session_data = webhook_data.get("SessionData", {})
|
||||
underlying_provider = None
|
||||
|
|
@ -482,7 +488,7 @@ class CloudonixProvider(TelephonyProvider):
|
|||
trunk_headers = profile.get("trunk-sip-headers", {})
|
||||
if "Twilio-AccountSid" in trunk_headers:
|
||||
underlying_provider = "twilio"
|
||||
|
||||
|
||||
return NormalizedInboundData(
|
||||
provider=CloudonixProvider.PROVIDER_NAME,
|
||||
call_id=call_id,
|
||||
|
|
@ -492,7 +498,7 @@ class CloudonixProvider(TelephonyProvider):
|
|||
call_status=webhook_data.get("CallStatus", "in-progress"),
|
||||
account_id=account_id,
|
||||
from_country=webhook_data.get("FromCountry"),
|
||||
to_country=webhook_data.get("ToCountry"),
|
||||
to_country=webhook_data.get("ToCountry"),
|
||||
raw_data={
|
||||
**webhook_data,
|
||||
"underlying_provider": underlying_provider,
|
||||
|
|
@ -503,9 +509,9 @@ class CloudonixProvider(TelephonyProvider):
|
|||
def validate_account_id(config_data: dict, webhook_account_id: str) -> bool:
|
||||
"""
|
||||
Validate that the account_id from webhook matches the Cloudonix configuration.
|
||||
|
||||
|
||||
For Cloudonix:
|
||||
- webhook_account_id is the Domain field (e.g., "test1.cloudonix.net")
|
||||
- webhook_account_id is the Domain field (e.g., "test1.cloudonix.net")
|
||||
- config domain_id stores the same domain string
|
||||
"""
|
||||
if not webhook_account_id:
|
||||
|
|
@ -551,30 +557,30 @@ class CloudonixProvider(TelephonyProvider):
|
|||
) -> bool:
|
||||
"""
|
||||
Verify the API key of an inbound Cloudonix webhook for security.
|
||||
|
||||
|
||||
Cloudonix uses x-cx-apikey header validation instead of signature verification.
|
||||
The API key from the webhook should match the bearer_token in our configuration.
|
||||
"""
|
||||
if not api_key:
|
||||
logger.warning("No x-cx-apikey provided in Cloudonix webhook")
|
||||
return False
|
||||
|
||||
|
||||
# The bearer_token in config is the same as x-cx-apikey header value
|
||||
if not self.bearer_token:
|
||||
logger.warning("No bearer_token configured for Cloudonix provider")
|
||||
return False
|
||||
|
||||
|
||||
# Compare the API keys
|
||||
is_valid = api_key == self.bearer_token
|
||||
|
||||
|
||||
if is_valid:
|
||||
logger.info("Cloudonix x-cx-apikey validation successful")
|
||||
else:
|
||||
logger.warning(f"Cloudonix x-cx-apikey validation failed. Expected key ending with ...{self.bearer_token[-8:] if len(self.bearer_token) > 8 else 'SHORT_KEY'}")
|
||||
|
||||
return True #TODO: update this post clarification from cloudonix
|
||||
logger.warning(
|
||||
f"Cloudonix x-cx-apikey validation failed. Expected key ending with ...{self.bearer_token[-8:] if len(self.bearer_token) > 8 else 'SHORT_KEY'}"
|
||||
)
|
||||
|
||||
|
||||
return True # TODO: update this post clarification from cloudonix
|
||||
|
||||
@staticmethod
|
||||
async def generate_inbound_response(
|
||||
|
|
@ -582,11 +588,11 @@ class CloudonixProvider(TelephonyProvider):
|
|||
) -> tuple:
|
||||
"""
|
||||
Generate the appropriate CXML response for an inbound Cloudonix webhook.
|
||||
|
||||
|
||||
Returns CXML to connect to WebSocket, same format as outbound calls.
|
||||
"""
|
||||
from fastapi import Response
|
||||
|
||||
|
||||
# Generate CXML response (same format as outbound calls)
|
||||
cxml_content = f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Response>
|
||||
|
|
@ -595,26 +601,23 @@ class CloudonixProvider(TelephonyProvider):
|
|||
</Connect>
|
||||
<Pause length="40"/>
|
||||
</Response>"""
|
||||
|
||||
|
||||
logger.info(f"Cloudonix inbound CXML response content:")
|
||||
logger.info(cxml_content)
|
||||
|
||||
response = Response(
|
||||
content=cxml_content,
|
||||
media_type="application/xml"
|
||||
)
|
||||
|
||||
|
||||
response = Response(content=cxml_content, media_type="application/xml")
|
||||
|
||||
logger.info(f"Cloudonix inbound response object: {response}")
|
||||
logger.info(f"Response headers: {response.headers}")
|
||||
logger.info(f"Response media type: {response.media_type}")
|
||||
|
||||
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def generate_validation_error_response(error_type) -> tuple:
|
||||
"""
|
||||
Generate Cloudonix-specific error response for validation failures.
|
||||
|
||||
|
||||
Since Cloudonix is TwiML-compatible, we use the same XML format.
|
||||
"""
|
||||
from fastapi import Response
|
||||
|
|
|
|||
|
|
@ -297,7 +297,7 @@ class TwilioProvider(TelephonyProvider):
|
|||
) -> bool:
|
||||
"""
|
||||
Determine if this provider can handle the incoming webhook.
|
||||
|
||||
|
||||
Twilio webhooks have specific characteristics:
|
||||
- User-Agent: "TwilioProxy/1.1"
|
||||
- Headers: "x-twilio-signature", "i-twilio-idempotency-token"
|
||||
|
|
@ -308,21 +308,27 @@ class TwilioProvider(TelephonyProvider):
|
|||
user_agent = headers.get("user-agent", "")
|
||||
if "twilioproxy" in user_agent.lower() or user_agent.startswith("TwilioProxy"):
|
||||
return True
|
||||
|
||||
|
||||
# 2: Check for Twilio-specific headers
|
||||
twilio_headers = ["x-twilio-signature", "i-twilio-idempotency-token", "x-home-region"]
|
||||
twilio_headers = [
|
||||
"x-twilio-signature",
|
||||
"i-twilio-idempotency-token",
|
||||
"x-home-region",
|
||||
]
|
||||
if any(header in headers for header in twilio_headers):
|
||||
return True
|
||||
|
||||
|
||||
# 3: Check data structure - CallSid + AccountSid with AC prefix + ApiVersion
|
||||
if ("CallSid" in webhook_data and
|
||||
"AccountSid" in webhook_data and
|
||||
"ApiVersion" in webhook_data):
|
||||
if (
|
||||
"CallSid" in webhook_data
|
||||
and "AccountSid" in webhook_data
|
||||
and "ApiVersion" in webhook_data
|
||||
):
|
||||
# Ensure AccountSid looks like Twilio (starts with AC, not a domain)
|
||||
account_sid = webhook_data.get("AccountSid", "")
|
||||
if account_sid.startswith("AC") and not "." in account_sid:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -69,6 +69,8 @@ class PipecatEngine:
|
|||
node_transition_callback: Optional[
|
||||
Callable[[str, Optional[str]], Awaitable[None]]
|
||||
] = None,
|
||||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
):
|
||||
self.task = task
|
||||
self.llm = llm
|
||||
|
|
@ -103,6 +105,10 @@ class PipecatEngine:
|
|||
# Custom tool manager (initialized in initialize())
|
||||
self._custom_tool_manager: Optional[CustomToolManager] = None
|
||||
|
||||
# Embeddings configuration (passed from run_pipeline.py)
|
||||
self._embeddings_api_key: Optional[str] = embeddings_api_key
|
||||
self._embeddings_model: Optional[str] = embeddings_model
|
||||
|
||||
async def _get_organization_id(self) -> Optional[int]:
|
||||
"""Get and cache the organization ID from workflow run."""
|
||||
if self._custom_tool_manager:
|
||||
|
|
@ -318,11 +324,19 @@ class PipecatEngine:
|
|||
"Organization ID not available for knowledge base retrieval"
|
||||
)
|
||||
|
||||
if not self._embeddings_api_key:
|
||||
raise ValueError(
|
||||
"Embeddings API key not configured. Please set your API key in "
|
||||
"Model Configurations > Embedding."
|
||||
)
|
||||
|
||||
result = await retrieve_from_knowledge_base(
|
||||
query=query,
|
||||
organization_id=organization_id,
|
||||
document_uuids=document_uuids,
|
||||
limit=3, # Return top 3 most relevant chunks
|
||||
embeddings_api_key=self._embeddings_api_key,
|
||||
embeddings_model=self._embeddings_model,
|
||||
)
|
||||
|
||||
await function_call_params.result_callback(result)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from opentelemetry import trace
|
||||
|
||||
from api.services.gen_ai.json_parser import parse_llm_json
|
||||
from api.services.pipecat.tracing_config import is_tracing_enabled
|
||||
from api.services.workflow.dto import ExtractionVariableDTO
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
|
@ -32,7 +30,6 @@ class VariableExtractionManager:
|
|||
# and update internal counters / extracted variable state.
|
||||
self._engine = engine
|
||||
self._context = engine.context
|
||||
self._model = "gpt-4o"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
|
|
@ -147,46 +144,43 @@ class VariableExtractionManager:
|
|||
extraction_context.set_messages(extraction_messages)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Use independent OpenAI client for LLM call
|
||||
# Use engine's LLM for out-of-band inference (no pipeline frames)
|
||||
# ------------------------------------------------------------------
|
||||
client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
llm_response = await self._engine.llm.run_inference(extraction_context)
|
||||
|
||||
# Direct API call - no pipeline involvement
|
||||
response = await client.chat.completions.create(
|
||||
model=self._model,
|
||||
messages=extraction_messages,
|
||||
temperature=0.0,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
llm_response = response.choices[0].message.content
|
||||
# Get model name for tracing
|
||||
model_name = getattr(self._engine.llm, "model_name", "unknown")
|
||||
|
||||
if is_tracing_enabled():
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
with tracer.start_as_current_span(
|
||||
"variable_extraction", context=parent_ctx
|
||||
"llm-variable-extraction", context=parent_ctx
|
||||
) as span:
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name="OpenAILLMService",
|
||||
model=self._model,
|
||||
operation_name="variable_extraction",
|
||||
service_name=self._engine.llm.__class__.__name__,
|
||||
model=model_name,
|
||||
operation_name="llm-variable-extraction",
|
||||
messages=extraction_messages,
|
||||
output=llm_response,
|
||||
stream=False,
|
||||
parameters={"temperature": 0.0, "response_format": "json_object"},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Parse the assistant output – fall back to raw text if it is not valid JSON.
|
||||
# Uses parse_llm_json which handles common LLM mistakes like markdown
|
||||
# code blocks (```json ... ```) and extra text around the JSON.
|
||||
# ------------------------------------------------------------------
|
||||
try:
|
||||
extracted = json.loads(llm_response)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Extractor returned invalid JSON; storing raw content instead."
|
||||
)
|
||||
extracted = {"raw": llm_response}
|
||||
if llm_response is None:
|
||||
logger.warning("Extractor returned no response; returning empty result.")
|
||||
extracted = {}
|
||||
else:
|
||||
extracted = parse_llm_json(llm_response)
|
||||
if "raw" in extracted and len(extracted) == 1:
|
||||
logger.warning(
|
||||
"Extractor returned invalid JSON; storing raw content instead."
|
||||
)
|
||||
|
||||
logger.debug(f"Extracted variables: {extracted}")
|
||||
return extracted
|
||||
|
|
|
|||
231
api/tests/test_json_parser.py
Normal file
231
api/tests/test_json_parser.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
from api.services.gen_ai.json_parser import (
|
||||
_extract_json_array,
|
||||
_extract_json_object,
|
||||
_try_parse_json,
|
||||
parse_llm_json,
|
||||
)
|
||||
|
||||
|
||||
class TestParseLlmJson:
    """Behavioural checks for the public parse_llm_json entry point."""

    def test_empty_string(self):
        """An empty string produces an empty dict."""
        parsed = parse_llm_json("")
        assert parsed == {}

    def test_whitespace_only(self):
        """A string made only of whitespace produces an empty dict."""
        parsed = parse_llm_json(" \n\t ")
        assert parsed == {}

    def test_none_handling(self):
        """None is tolerated and produces an empty dict."""
        parsed = parse_llm_json(None)
        assert parsed == {}

    def test_valid_json_direct(self):
        """Well-formed JSON is decoded as-is."""
        expected = {"name": "John", "age": 30}
        assert parse_llm_json('{"name": "John", "age": 30}') == expected

    def test_valid_json_with_whitespace(self):
        """Leading and trailing whitespace does not prevent decoding."""
        expected = {"key": "value"}
        assert parse_llm_json(' \n{"key": "value"}\n ') == expected

    def test_markdown_json_code_block(self):
        """A ```json ... ``` fence is unwrapped before decoding."""
        text = """```json
{
"occupation_of_the_user": "software engineer"
}
```"""
        assert parse_llm_json(text) == {
            "occupation_of_the_user": "software engineer"
        }

    def test_markdown_generic_code_block(self):
        """A fence without a language tag is unwrapped before decoding."""
        text = """```
{"status": "success", "count": 42}
```"""
        assert parse_llm_json(text) == {"status": "success", "count": 42}

    def test_markdown_with_surrounding_text(self):
        """Prose around a fenced block is ignored."""
        text = """Here is the extracted data:
```json
{"name": "Alice"}
```
I hope this helps!"""
        assert parse_llm_json(text) == {"name": "Alice"}

    def test_json_with_text_before(self):
        """An object preceded by explanatory prose is extracted."""
        assert parse_llm_json('The result is: {"answer": 42}') == {"answer": 42}

    def test_json_with_text_after(self):
        """An object followed by trailing prose is extracted."""
        assert parse_llm_json('{"found": true} - extraction complete') == {
            "found": True
        }

    def test_json_with_text_before_and_after(self):
        """An object sandwiched between prose is extracted."""
        text = 'Based on the conversation: {"mood": "happy"} is my assessment.'
        assert parse_llm_json(text) == {"mood": "happy"}

    def test_nested_json_object(self):
        """Objects nested several levels deep decode correctly."""
        expected = {"user": {"name": "Bob", "address": {"city": "NYC"}}}
        text = '{"user": {"name": "Bob", "address": {"city": "NYC"}}}'
        assert parse_llm_json(text) == expected

    def test_json_with_string_containing_braces(self):
        """Braces that appear inside string values are left untouched."""
        expected = {"code": "function() { return {}; }"}
        assert parse_llm_json('{"code": "function() { return {}; }"}') == expected

    def test_json_with_escaped_quotes(self):
        """Escaped double quotes inside a value are decoded."""
        expected = {"message": 'He said "hello"'}
        assert parse_llm_json('{"message": "He said \\"hello\\""}') == expected

    def test_json_array_direct(self):
        """A bare JSON array is decoded as a list."""
        parsed = parse_llm_json("[1, 2, 3]")
        assert parsed == [1, 2, 3]

    def test_json_array_with_objects(self):
        """An array whose elements are objects is decoded as a list of dicts."""
        assert parse_llm_json('[{"id": 1}, {"id": 2}]') == [{"id": 1}, {"id": 2}]

    def test_json_array_in_markdown(self):
        """An array inside a fenced block is unwrapped and decoded."""
        text = """```json
["apple", "banana", "cherry"]
```"""
        assert parse_llm_json(text) == ["apple", "banana", "cherry"]

    def test_invalid_json_returns_raw(self):
        """Text without any JSON comes back under the 'raw' key."""
        assert parse_llm_json("This is not JSON at all") == {
            "raw": "This is not JSON at all"
        }

    def test_malformed_json_returns_raw(self):
        """An object missing its closing brace comes back under the 'raw' key."""
        broken = '{"key": "value"'  # Missing closing brace
        assert parse_llm_json(broken) == {"raw": '{"key": "value"'}

    def test_complex_real_world_example(self):
        """A realistic chatty LLM answer with a fenced nested object."""
        text = """Based on our conversation, I've extracted the following information:

```json
{
"user_name": "John Smith",
"email": "john@example.com",
"preferences": {
"notifications": true,
"theme": "dark"
}
}
```

Let me know if you need anything else!"""
        expected = {
            "user_name": "John Smith",
            "email": "john@example.com",
            "preferences": {"notifications": True, "theme": "dark"},
        }
        assert parse_llm_json(text) == expected

    def test_json_with_newlines_inside(self):
        """Escaped newlines inside values are decoded to real newlines."""
        assert parse_llm_json('{"text": "line1\\nline2"}') == {"text": "line1\nline2"}

    def test_json_with_unicode(self):
        """Non-ASCII characters survive decoding unchanged."""
        expected = {"greeting": "こんにちは", "emoji": "🎉"}
        assert parse_llm_json('{"greeting": "こんにちは", "emoji": "🎉"}') == expected

    def test_multiple_code_blocks_uses_first(self):
        """With several fenced blocks present, the first one is decoded."""
        text = """```json
{"first": true}
```
Some text
```json
{"second": true}
```"""
        assert parse_llm_json(text) == {"first": True}
|
||||
|
||||
|
||||
class TestTryParseJson:
    """Checks for the _try_parse_json helper."""

    def test_valid_dict(self):
        """A JSON object decodes to a dict."""
        decoded = _try_parse_json('{"a": 1}')
        assert decoded == {"a": 1}

    def test_valid_list(self):
        """A JSON array decodes to a list."""
        decoded = _try_parse_json("[1, 2]")
        assert decoded == [1, 2]

    def test_invalid_returns_none(self):
        """Text that is not JSON yields None."""
        decoded = _try_parse_json("not json")
        assert decoded is None

    def test_primitive_returns_none(self):
        """Valid JSON that is not a dict or list yields None."""
        for literal in ('"just a string"', "42", "true"):
            assert _try_parse_json(literal) is None
|
||||
|
||||
|
||||
class TestExtractJsonObject:
    """Checks for the _extract_json_object helper."""

    def test_extracts_from_text(self):
        """An object embedded between a prefix and a suffix is returned."""
        assert _extract_json_object('prefix {"key": "value"} suffix') == {
            "key": "value"
        }

    def test_no_object_returns_none(self):
        """Text without an opening brace yields None."""
        found = _extract_json_object("no json here")
        assert found is None

    def test_nested_braces(self):
        """Nested objects are matched up to the outermost closing brace."""
        expected = {"outer": {"inner": 1}}
        assert _extract_json_object('{"outer": {"inner": 1}}') == expected

    def test_braces_in_strings(self):
        """Braces inside a string value do not affect brace matching."""
        expected = {"code": "{ }"}
        assert _extract_json_object('{"code": "{ }"}') == expected
|
||||
|
||||
|
||||
class TestExtractJsonArray:
    """Checks for the _extract_json_array helper."""

    def test_extracts_from_text(self):
        """An array embedded in surrounding prose is returned."""
        expected = [1, 2, 3]
        assert _extract_json_array("here is the list: [1, 2, 3] done") == expected

    def test_no_array_returns_none(self):
        """Text without an opening bracket yields None."""
        found = _extract_json_array("no array here")
        assert found is None

    def test_nested_arrays(self):
        """Nested arrays are matched up to the outermost closing bracket."""
        expected = [[1, 2], [3, 4]]
        assert _extract_json_array("[[1, 2], [3, 4]]") == expected

    def test_brackets_in_strings(self):
        """Brackets inside a string element do not affect bracket matching."""
        expected = ["a[b]c"]
        assert _extract_json_array('["a[b]c"]') == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue