From d35eeb1b7b3ddc449adbee0243c2745103d86f2c Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Date: Sat, 17 Jan 2026 18:12:56 +0530 Subject: [PATCH] fix: fix OPENAI_API_KEY bug in retrieval --- api/routes/telephony.py | 4 +- api/services/gen_ai/__init__.py | 2 + api/services/gen_ai/json_parser.py | 154 ++++++++++++ api/services/pipecat/run_pipeline.py | 43 ++-- api/services/pipecat/service_factory.py | 22 -- api/services/telephony/base.py | 4 +- .../telephony/providers/cloudonix_provider.py | 79 +++--- .../telephony/providers/twilio_provider.py | 22 +- api/services/workflow/pipecat_engine.py | 14 ++ .../pipecat_engine_variable_extractor.py | 48 ++-- api/tests/test_json_parser.py | 231 ++++++++++++++++++ 11 files changed, 508 insertions(+), 115 deletions(-) create mode 100644 api/services/gen_ai/json_parser.py create mode 100644 api/tests/test_json_parser.py diff --git a/api/routes/telephony.py b/api/routes/telephony.py index 15f676b..440eb36 100644 --- a/api/routes/telephony.py +++ b/api/routes/telephony.py @@ -1263,7 +1263,9 @@ async def handle_inbound_telephony( try: webhook_data, data_source = await parse_webhook_request(request) - logger.info(f"Inbound call data with data source: {data_source} and data :{dict(webhook_data)}") + logger.info( + f"Inbound call data with data source: {data_source} and data :{dict(webhook_data)}" + ) headers = dict(request.headers) # Detect provider and normalize data diff --git a/api/services/gen_ai/__init__.py b/api/services/gen_ai/__init__.py index 684a7f7..9a25a6f 100644 --- a/api/services/gen_ai/__init__.py +++ b/api/services/gen_ai/__init__.py @@ -6,10 +6,12 @@ from .embedding import ( OpenAIEmbeddingService, SentenceTransformerEmbeddingService, ) +from .json_parser import parse_llm_json __all__ = [ "BaseEmbeddingService", "EmbeddingAPIKeyNotConfiguredError", "SentenceTransformerEmbeddingService", "OpenAIEmbeddingService", + "parse_llm_json", ] diff --git a/api/services/gen_ai/json_parser.py b/api/services/gen_ai/json_parser.py new file mode 100644 index 0000000..e98828d --- /dev/null +++ b/api/services/gen_ai/json_parser.py @@ -0,0 +1,154 @@ +"""Robust JSON parser for handling common LLM output mistakes.""" + +from __future__ import annotations + +import json +import re +from typing import Any + + +def parse_llm_json(raw_content: str) -> dict[str, Any]: + """Parse JSON from LLM output, handling common formatting issues. + + Handles the following common LLM mistakes: + 1. JSON wrapped in markdown code blocks (```json ... ``` or ``` ... ```) + 2. Extra whitespace or newlines around JSON + 3. Text before/after the JSON object + + Args: + raw_content: The raw string output from the LLM. + + Returns: + Parsed JSON as a dictionary. If parsing fails, returns {"raw": raw_content}. + """ + if not raw_content or not raw_content.strip(): + return {} + + content = raw_content.strip() + + # Attempt 1: Direct parse (ideal case) + parsed = _try_parse_json(content) + if parsed is not None: + return parsed + + # Attempt 2: Remove markdown code block wrappers + # Matches ```json ... ``` or ``` ... ``` + code_block_pattern = r"```(?:json)?\s*([\s\S]*?)\s*```" + code_block_match = re.search(code_block_pattern, content) + if code_block_match: + extracted = code_block_match.group(1).strip() + parsed = _try_parse_json(extracted) + if parsed is not None: + return parsed + + # Attempt 3: Find JSON object by matching braces + parsed = _extract_json_object(content) + if parsed is not None: + return parsed + + # Attempt 4: Find JSON array by matching brackets + parsed = _extract_json_array(content) + if parsed is not None: + return parsed + + # All attempts failed - return raw content + return {"raw": raw_content} + + +def _try_parse_json(content: str) -> dict[str, Any] | list | None: + """Attempt to parse JSON, returning None on failure.""" + try: + result = json.loads(content) + if isinstance(result, (dict, list)): + return result + return None + except json.JSONDecodeError: + return None + + +def _extract_json_object(content: str) -> dict[str, Any] | None: + """Extract a JSON object from text by finding matching braces.""" + # Find the first opening brace + start = content.find("{") + if start == -1: + return None + + # Find matching closing brace by counting braces + depth = 0 + in_string = False + escape_next = False + end = -1 + + for i, char in enumerate(content[start:], start=start): + if escape_next: + escape_next = False + continue + + if char == "\\": + escape_next = True + continue + + if char == '"' and not escape_next: + in_string = not in_string + continue + + if in_string: + continue + + if char == "{": + depth += 1 + elif char == "}": + depth -= 1 + if depth == 0: + end = i + break + + if end == -1: + return None + + json_str = content[start : end + 1] + return _try_parse_json(json_str) + + +def _extract_json_array(content: str) -> list | None: + """Extract a JSON array from text by finding matching brackets.""" + # Find the first opening bracket + start = content.find("[") + if start == -1: + return None + + # Find matching closing bracket by counting brackets + depth = 0 + in_string = False + escape_next = False + end = -1 + + for i, char in enumerate(content[start:], start=start): + if escape_next: + escape_next = False + continue + + if char == "\\": + escape_next = True + continue + + if char == '"' and not escape_next: + in_string = not in_string + continue + + if in_string: + continue + + if char == "[": + depth += 1 + elif char == "]": + depth -= 1 + if depth == 0: + end = i + break + + if end == -1: + return None + + json_str = content[start : end + 1] + return _try_parse_json(json_str) diff --git a/api/services/pipecat/run_pipeline.py b/api/services/pipecat/run_pipeline.py index f056554..11ae7cc 100644 --- a/api/services/pipecat/run_pipeline.py +++ b/api/services/pipecat/run_pipeline.py @@ -29,7 +29,6 @@ from api.services.pipecat.service_factory import ( create_llm_service, create_stt_service, create_tts_service, - create_voicemail_classification_llm, ) from api.services.pipecat.tracing_config import setup_pipeline_tracing from api.services.pipecat.transport_setup import ( @@ -501,12 +500,21 @@ async def _run_pipeline( node_transition_callback = send_node_transition + # Extract embeddings configuration from user config + embeddings_api_key = None + embeddings_model = None + if user_config and user_config.embeddings: + embeddings_api_key = user_config.embeddings.api_key + embeddings_model = user_config.embeddings.model + engine = PipecatEngine( llm=llm, workflow=workflow_graph, call_context_vars=merged_call_context_vars, workflow_run_id=workflow_run_id, node_transition_callback=node_transition_callback, + embeddings_api_key=embeddings_api_key, + embeddings_model=embeddings_model, ) # Create pipeline components with audio configuration and engine @@ -562,24 +570,23 @@ async def _run_pipeline( voicemail_detector = None start_node = workflow_graph.nodes.get(workflow_graph.start_node_id) if start_node and start_node.detect_voicemail: - classification_llm = create_voicemail_classification_llm() - if classification_llm: - logger.info( - f"Voicemail detection enabled for workflow run {workflow_run_id}" - ) - voicemail_detector = VoicemailDetector( - llm=classification_llm, - voicemail_response_delay=2.0, - ) + logger.info(f"Voicemail detection enabled for workflow run {workflow_run_id}") + # Create a separate LLM instance for the voicemail sub-pipeline + # (can't share with main pipeline as it would mess up frame linking) + voicemail_llm = create_llm_service(user_config) + voicemail_detector = VoicemailDetector( + llm=voicemail_llm, + voicemail_response_delay=2.0, + ) - # Register event handler to end task when voicemail is detected - @voicemail_detector.event_handler("on_voicemail_detected") - async def _on_voicemail_detected(_processor): - logger.info(f"Voicemail detected for workflow run {workflow_run_id}") - await engine.send_end_task_frame( - reason=EndTaskReason.VOICEMAIL_DETECTED.value, - abort_immediately=True, - ) + # Register event handler to end task when voicemail is detected + @voicemail_detector.event_handler("on_voicemail_detected") + async def _on_voicemail_detected(_processor): + logger.info(f"Voicemail detected for workflow run {workflow_run_id}") + await engine.send_end_task_frame( + reason=EndTaskReason.VOICEMAIL_DETECTED.value, + abort_immediately=True, + ) # Build the pipeline with the STT mute filter and context controller pipeline = build_pipeline( diff --git a/api/services/pipecat/service_factory.py b/api/services/pipecat/service_factory.py index be55b5f..3dd3d71 100644 --- a/api/services/pipecat/service_factory.py +++ b/api/services/pipecat/service_factory.py @@ -1,4 +1,3 @@ -import os from typing import TYPE_CHECKING from fastapi import HTTPException @@ -242,24 +241,3 @@ def create_llm_service(user_config): ) else: raise HTTPException(status_code=400, detail="Invalid LLM provider") - - -def create_voicemail_classification_llm(): - """Create a fast, lightweight LLM service for voicemail classification. - - Uses gpt-4o-mini which is fast and cost-effective for simple classification tasks. - The model only needs to output "CONVERSATION" or "VOICEMAIL" based on transcriptions. - - Returns: - OpenAILLMService instance, or None if OPENAI_API_KEY is not set. - """ - api_key = os.environ.get("OPENAI_API_KEY") - if not api_key: - logger.warning("OPENAI_API_KEY not set - voicemail detection will be disabled") - return None - - return OpenAILLMService( - api_key=api_key, - model="gpt-4o", - params=OpenAILLMService.InputParams(temperature=0.0), - ) diff --git a/api/services/telephony/base.py b/api/services/telephony/base.py index be2a78d..72d33d1 100644 --- a/api/services/telephony/base.py +++ b/api/services/telephony/base.py @@ -278,7 +278,9 @@ class TelephonyProvider(ABC): @staticmethod @abstractmethod - async def generate_inbound_response(websocket_url: str, workflow_run_id: int = None) -> tuple: + async def generate_inbound_response( + websocket_url: str, workflow_run_id: int = None + ) -> tuple: """ Generate the appropriate response for an inbound webhook. diff --git a/api/services/telephony/providers/cloudonix_provider.py b/api/services/telephony/providers/cloudonix_provider.py index 2951b70..048f49c 100644 --- a/api/services/telephony/providers/cloudonix_provider.py +++ b/api/services/telephony/providers/cloudonix_provider.py @@ -434,29 +434,37 @@ class CloudonixProvider(TelephonyProvider): user_agent = headers.get("user-agent", "").lower() if "cloudonix" in user_agent: return True - + # 2: Check for Cloudonix-specific headers - cloudonix_headers = ["x-cx-apikey", "x-cx-domain", "x-cx-session", "x-cx-source"] + cloudonix_headers = [ + "x-cx-apikey", + "x-cx-domain", + "x-cx-session", + "x-cx-source", + ] if any(header in headers for header in cloudonix_headers): return True - + # 3: Check data structure for Cloudonix-specific fields - if ("SessionData" in webhook_data and "Domain" in webhook_data and - webhook_data.get("Domain", "").endswith(".cloudonix.net")): + if ( + "SessionData" in webhook_data + and "Domain" in webhook_data + and webhook_data.get("Domain", "").endswith(".cloudonix.net") + ): return True - + # Check if AccountSid is a Cloudonix domain account_sid = webhook_data.get("AccountSid", "") if account_sid.endswith(".cloudonix.net"): return True - + return False @staticmethod def parse_inbound_webhook(webhook_data: Dict[str, Any]) -> NormalizedInboundData: """ Parse Cloudonix-specific inbound webhook data into normalized format. - + Cloudonix webhook structure includes: - CallSid: Call id - From: Caller number @@ -467,13 +475,11 @@ class CloudonixProvider(TelephonyProvider): session_data = webhook_data.get("SessionData", {}) token = session_data.get("token", "") if isinstance(session_data, dict) else "" - - call_id = (webhook_data.get("Session") or - webhook_data.get("CallSid") or - token) - - account_id = (webhook_data.get("Domain") or webhook_data.get("AccountSid", "")) - + + call_id = webhook_data.get("Session") or webhook_data.get("CallSid") or token + + account_id = webhook_data.get("Domain") or webhook_data.get("AccountSid", "") + # Extract underlying provider information from SessionData if available session_data = webhook_data.get("SessionData", {}) underlying_provider = None @@ -482,7 +488,7 @@ class CloudonixProvider(TelephonyProvider): trunk_headers = profile.get("trunk-sip-headers", {}) if "Twilio-AccountSid" in trunk_headers: underlying_provider = "twilio" - + return NormalizedInboundData( provider=CloudonixProvider.PROVIDER_NAME, call_id=call_id, @@ -492,7 +498,7 @@ class CloudonixProvider(TelephonyProvider): call_status=webhook_data.get("CallStatus", "in-progress"), account_id=account_id, from_country=webhook_data.get("FromCountry"), - to_country=webhook_data.get("ToCountry"), + to_country=webhook_data.get("ToCountry"), raw_data={ **webhook_data, "underlying_provider": underlying_provider, @@ -503,9 +509,9 @@ class CloudonixProvider(TelephonyProvider): def validate_account_id(config_data: dict, webhook_account_id: str) -> bool: """ Validate that the account_id from webhook matches the Cloudonix configuration. - + For Cloudonix: - - webhook_account_id is the Domain field (e.g., "test1.cloudonix.net") + - webhook_account_id is the Domain field (e.g., "test1.cloudonix.net") - config domain_id stores the same domain string """ if not webhook_account_id: @@ -551,30 +557,30 @@ class CloudonixProvider(TelephonyProvider): ) -> bool: """ Verify the API key of an inbound Cloudonix webhook for security. - + Cloudonix uses x-cx-apikey header validation instead of signature verification. The API key from the webhook should match the bearer_token in our configuration. """ if not api_key: logger.warning("No x-cx-apikey provided in Cloudonix webhook") return False - + # The bearer_token in config is the same as x-cx-apikey header value if not self.bearer_token: logger.warning("No bearer_token configured for Cloudonix provider") return False - + # Compare the API keys is_valid = api_key == self.bearer_token - + if is_valid: logger.info("Cloudonix x-cx-apikey validation successful") else: - logger.warning(f"Cloudonix x-cx-apikey validation failed. Expected key ending with ...{self.bearer_token[-8:] if len(self.bearer_token) > 8 else 'SHORT_KEY'}") - - return True #TODO: update this post clarification from cloudonix + logger.warning( + f"Cloudonix x-cx-apikey validation failed. Expected key ending with ...{self.bearer_token[-8:] if len(self.bearer_token) > 8 else 'SHORT_KEY'}" + ) - + return True # TODO: update this post clarification from cloudonix @staticmethod async def generate_inbound_response( @@ -582,11 +588,11 @@ class CloudonixProvider(TelephonyProvider): ) -> tuple: """ Generate the appropriate CXML response for an inbound Cloudonix webhook. - + Returns CXML to connect to WebSocket, same format as outbound calls. """ from fastapi import Response - + # Generate CXML response (same format as outbound calls) cxml_content = f""" @@ -595,26 +601,23 @@ class CloudonixProvider(TelephonyProvider): """ - + logger.info(f"Cloudonix inbound CXML response content:") logger.info(cxml_content) - - response = Response( - content=cxml_content, - media_type="application/xml" - ) - + + response = Response(content=cxml_content, media_type="application/xml") + logger.info(f"Cloudonix inbound response object: {response}") logger.info(f"Response headers: {response.headers}") logger.info(f"Response media type: {response.media_type}") - + return response @staticmethod def generate_validation_error_response(error_type) -> tuple: """ Generate Cloudonix-specific error response for validation failures. - + Since Cloudonix is TwiML-compatible, we use the same XML format. """ from fastapi import Response diff --git a/api/services/telephony/providers/twilio_provider.py b/api/services/telephony/providers/twilio_provider.py index 68dbde4..7ac451f 100644 --- a/api/services/telephony/providers/twilio_provider.py +++ b/api/services/telephony/providers/twilio_provider.py @@ -297,7 +297,7 @@ class TwilioProvider(TelephonyProvider): ) -> bool: """ Determine if this provider can handle the incoming webhook. - + Twilio webhooks have specific characteristics: - User-Agent: "TwilioProxy/1.1" - Headers: "x-twilio-signature", "i-twilio-idempotency-token" @@ -308,21 +308,27 @@ class TwilioProvider(TelephonyProvider): user_agent = headers.get("user-agent", "") if "twilioproxy" in user_agent.lower() or user_agent.startswith("TwilioProxy"): return True - + # 2: Check for Twilio-specific headers - twilio_headers = ["x-twilio-signature", "i-twilio-idempotency-token", "x-home-region"] + twilio_headers = [ + "x-twilio-signature", + "i-twilio-idempotency-token", + "x-home-region", + ] if any(header in headers for header in twilio_headers): return True - + # 3: Check data structure - CallSid + AccountSid with AC prefix + ApiVersion - if ("CallSid" in webhook_data and - "AccountSid" in webhook_data and - "ApiVersion" in webhook_data): + if ( + "CallSid" in webhook_data + and "AccountSid" in webhook_data + and "ApiVersion" in webhook_data + ): # Ensure AccountSid looks like Twilio (starts with AC, not a domain) account_sid = webhook_data.get("AccountSid", "") if account_sid.startswith("AC") and not "." in account_sid: return True - + return False @staticmethod diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index 0ab36de..99d4d11 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -69,6 +69,8 @@ class PipecatEngine: node_transition_callback: Optional[ Callable[[str, Optional[str]], Awaitable[None]] ] = None, + embeddings_api_key: Optional[str] = None, + embeddings_model: Optional[str] = None, ): self.task = task self.llm = llm @@ -103,6 +105,10 @@ class PipecatEngine: # Custom tool manager (initialized in initialize()) self._custom_tool_manager: Optional[CustomToolManager] = None + # Embeddings configuration (passed from run_pipeline.py) + self._embeddings_api_key: Optional[str] = embeddings_api_key + self._embeddings_model: Optional[str] = embeddings_model + async def _get_organization_id(self) -> Optional[int]: """Get and cache the organization ID from workflow run.""" if self._custom_tool_manager: @@ -318,11 +324,19 @@ class PipecatEngine: "Organization ID not available for knowledge base retrieval" ) + if not self._embeddings_api_key: + raise ValueError( + "Embeddings API key not configured. Please set your API key in " + "Model Configurations > Embedding." + ) + result = await retrieve_from_knowledge_base( query=query, organization_id=organization_id, document_uuids=document_uuids, limit=3, # Return top 3 most relevant chunks + embeddings_api_key=self._embeddings_api_key, + embeddings_model=self._embeddings_model, ) await function_call_params.result_callback(result) diff --git a/api/services/workflow/pipecat_engine_variable_extractor.py b/api/services/workflow/pipecat_engine_variable_extractor.py index 7b1eed6..30acc08 100644 --- a/api/services/workflow/pipecat_engine_variable_extractor.py +++ b/api/services/workflow/pipecat_engine_variable_extractor.py @@ -1,13 +1,11 @@ from __future__ import annotations -import json -import os from typing import TYPE_CHECKING, Any, List from loguru import logger -from openai import AsyncOpenAI from opentelemetry import trace +from api.services.gen_ai.json_parser import parse_llm_json from api.services.pipecat.tracing_config import is_tracing_enabled from api.services.workflow.dto import ExtractionVariableDTO from pipecat.processors.aggregators.llm_context import LLMContext @@ -32,7 +30,6 @@ class VariableExtractionManager: # and update internal counters / extracted variable state. self._engine = engine self._context = engine.context - self._model = "gpt-4o" # ------------------------------------------------------------------ # Internal helpers @@ -147,46 +144,43 @@ class VariableExtractionManager: extraction_context.set_messages(extraction_messages) # ------------------------------------------------------------------ - # Use independent OpenAI client for LLM call + # Use engine's LLM for out-of-band inference (no pipeline frames) # ------------------------------------------------------------------ - client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + llm_response = await self._engine.llm.run_inference(extraction_context) - # Direct API call - no pipeline involvement - response = await client.chat.completions.create( - model=self._model, - messages=extraction_messages, - temperature=0.0, - response_format={"type": "json_object"}, - ) - - llm_response = response.choices[0].message.content + # Get model name for tracing + model_name = getattr(self._engine.llm, "model_name", "unknown") if is_tracing_enabled(): tracer = trace.get_tracer("pipecat") with tracer.start_as_current_span( - "variable_extraction", context=parent_ctx + "llm-variable-extraction", context=parent_ctx ) as span: add_llm_span_attributes( span, - service_name="OpenAILLMService", - model=self._model, - operation_name="variable_extraction", + service_name=self._engine.llm.__class__.__name__, + model=model_name, + operation_name="llm-variable-extraction", messages=extraction_messages, output=llm_response, stream=False, - parameters={"temperature": 0.0, "response_format": "json_object"}, + parameters={}, ) # ------------------------------------------------------------------ # Parse the assistant output – fall back to raw text if it is not valid JSON. + # Uses parse_llm_json which handles common LLM mistakes like markdown + # code blocks (```json ... ```) and extra text around the JSON. # ------------------------------------------------------------------ - try: - extracted = json.loads(llm_response) - except json.JSONDecodeError: - logger.warning( - "Extractor returned invalid JSON; storing raw content instead." - ) - extracted = {"raw": llm_response} + if llm_response is None: + logger.warning("Extractor returned no response; returning empty result.") + extracted = {} + else: + extracted = parse_llm_json(llm_response) + if "raw" in extracted and len(extracted) == 1: + logger.warning( + "Extractor returned invalid JSON; storing raw content instead." + ) logger.debug(f"Extracted variables: {extracted}") return extracted diff --git a/api/tests/test_json_parser.py b/api/tests/test_json_parser.py new file mode 100644 index 0000000..92569d0 --- /dev/null +++ b/api/tests/test_json_parser.py @@ -0,0 +1,231 @@ +from api.services.gen_ai.json_parser import ( + _extract_json_array, + _extract_json_object, + _try_parse_json, + parse_llm_json, +) + + +class TestParseLlmJson: + """Tests for the main parse_llm_json function.""" + + def test_empty_string(self): + """Empty string returns empty dict.""" + assert parse_llm_json("") == {} + + def test_whitespace_only(self): + """Whitespace-only string returns empty dict.""" + assert parse_llm_json(" \n\t ") == {} + + def test_none_handling(self): + """None input returns empty dict.""" + assert parse_llm_json(None) == {} + + def test_valid_json_direct(self): + """Valid JSON is parsed directly.""" + result = parse_llm_json('{"name": "John", "age": 30}') + assert result == {"name": "John", "age": 30} + + def test_valid_json_with_whitespace(self): + """Valid JSON with surrounding whitespace is parsed.""" + result = parse_llm_json(' \n{"key": "value"}\n ') + assert result == {"key": "value"} + + def test_markdown_json_code_block(self): + """JSON wrapped in ```json ... ``` is extracted and parsed.""" + input_str = """```json +{ + "occupation_of_the_user": "software engineer" +} +```""" + result = parse_llm_json(input_str) + assert result == {"occupation_of_the_user": "software engineer"} + + def test_markdown_generic_code_block(self): + """JSON wrapped in ``` ... ``` (no language) is extracted and parsed.""" + input_str = """``` +{"status": "success", "count": 42} +```""" + result = parse_llm_json(input_str) + assert result == {"status": "success", "count": 42} + + def test_markdown_with_surrounding_text(self): + """Markdown code block with text before/after is handled.""" + input_str = """Here is the extracted data: +```json +{"name": "Alice"} +``` +I hope this helps!""" + result = parse_llm_json(input_str) + assert result == {"name": "Alice"} + + def test_json_with_text_before(self): + """JSON with explanatory text before is extracted.""" + input_str = 'The result is: {"answer": 42}' + result = parse_llm_json(input_str) + assert result == {"answer": 42} + + def test_json_with_text_after(self): + """JSON with text after is extracted.""" + input_str = '{"found": true} - extraction complete' + result = parse_llm_json(input_str) + assert result == {"found": True} + + def test_json_with_text_before_and_after(self): + """JSON with text on both sides is extracted.""" + input_str = 'Based on the conversation: {"mood": "happy"} is my assessment.' + result = parse_llm_json(input_str) + assert result == {"mood": "happy"} + + def test_nested_json_object(self): + """Nested JSON objects are parsed correctly.""" + input_str = '{"user": {"name": "Bob", "address": {"city": "NYC"}}}' + result = parse_llm_json(input_str) + assert result == {"user": {"name": "Bob", "address": {"city": "NYC"}}} + + def test_json_with_string_containing_braces(self): + """JSON with braces inside strings is parsed correctly.""" + input_str = '{"code": "function() { return {}; }"}' + result = parse_llm_json(input_str) + assert result == {"code": "function() { return {}; }"} + + def test_json_with_escaped_quotes(self): + """JSON with escaped quotes is parsed correctly.""" + input_str = '{"message": "He said \\"hello\\""}' + result = parse_llm_json(input_str) + assert result == {"message": 'He said "hello"'} + + def test_json_array_direct(self): + """JSON array is parsed directly.""" + result = parse_llm_json("[1, 2, 3]") + assert result == [1, 2, 3] + + def test_json_array_with_objects(self): + """JSON array of objects is parsed correctly.""" + input_str = '[{"id": 1}, {"id": 2}]' + result = parse_llm_json(input_str) + assert result == [{"id": 1}, {"id": 2}] + + def test_json_array_in_markdown(self): + """JSON array in markdown code block is extracted.""" + input_str = """```json +["apple", "banana", "cherry"] +```""" + result = parse_llm_json(input_str) + assert result == ["apple", "banana", "cherry"] + + def test_invalid_json_returns_raw(self): + """Invalid JSON returns raw content in 'raw' key.""" + input_str = "This is not JSON at all" + result = parse_llm_json(input_str) + assert result == {"raw": "This is not JSON at all"} + + def test_malformed_json_returns_raw(self): + """Malformed JSON returns raw content.""" + input_str = '{"key": "value"' # Missing closing brace + result = parse_llm_json(input_str) + assert result == {"raw": '{"key": "value"'} + + def test_complex_real_world_example(self): + """Test with a realistic LLM output example.""" + input_str = """Based on our conversation, I've extracted the following information: + +```json +{ + "user_name": "John Smith", + "email": "john@example.com", + "preferences": { + "notifications": true, + "theme": "dark" + } +} +``` + +Let me know if you need anything else!""" + result = parse_llm_json(input_str) + assert result == { + "user_name": "John Smith", + "email": "john@example.com", + "preferences": {"notifications": True, "theme": "dark"}, + } + + def test_json_with_newlines_inside(self): + """JSON with newlines inside values is handled.""" + input_str = '{"text": "line1\\nline2"}' + result = parse_llm_json(input_str) + assert result == {"text": "line1\nline2"} + + def test_json_with_unicode(self): + """JSON with unicode characters is parsed correctly.""" + input_str = '{"greeting": "こんにちは", "emoji": "🎉"}' + result = parse_llm_json(input_str) + assert result == {"greeting": "こんにちは", "emoji": "🎉"} + + def test_multiple_code_blocks_uses_first(self): + """When multiple code blocks exist, the first is used.""" + input_str = """```json +{"first": true} +``` +Some text +```json +{"second": true} +```""" + result = parse_llm_json(input_str) + assert result == {"first": True} + + +class TestTryParseJson: + """Tests for the _try_parse_json helper.""" + + def test_valid_dict(self): + assert _try_parse_json('{"a": 1}') == {"a": 1} + + def test_valid_list(self): + assert _try_parse_json("[1, 2]") == [1, 2] + + def test_invalid_returns_none(self): + assert _try_parse_json("not json") is None + + def test_primitive_returns_none(self): + """Primitive values (not dict/list) return None.""" + assert _try_parse_json('"just a string"') is None + assert _try_parse_json("42") is None + assert _try_parse_json("true") is None + + +class TestExtractJsonObject: + """Tests for the _extract_json_object helper.""" + + def test_extracts_from_text(self): + result = _extract_json_object('prefix {"key": "value"} suffix') + assert result == {"key": "value"} + + def test_no_object_returns_none(self): + assert _extract_json_object("no json here") is None + + def test_nested_braces(self): + result = _extract_json_object('{"outer": {"inner": 1}}') + assert result == {"outer": {"inner": 1}} + + def test_braces_in_strings(self): + result = _extract_json_object('{"code": "{ }"}') + assert result == {"code": "{ }"} + + +class TestExtractJsonArray: + """Tests for the _extract_json_array helper.""" + + def test_extracts_from_text(self): + result = _extract_json_array("here is the list: [1, 2, 3] done") + assert result == [1, 2, 3] + + def test_no_array_returns_none(self): + assert _extract_json_array("no array here") is None + + def test_nested_arrays(self): + result = _extract_json_array("[[1, 2], [3, 4]]") + assert result == [[1, 2], [3, 4]] + + def test_brackets_in_strings(self): + result = _extract_json_array('["a[b]c"]') + assert result == ["a[b]c"]