From 4d53297c1735261440c370d7086000dbf38fb479 Mon Sep 17 00:00:00 2001 From: Tang Quoc Thai Date: Thu, 15 Jan 2026 00:06:28 +0100 Subject: [PATCH 1/3] feat: add passthrough_auth option for forwarding client Authorization header (#687) * feat: add passthrough_auth option for forwarding client Authorization header * fix tests * Update comment to reflect upstream forwarding * Apply suggestions from code review --------- Co-authored-by: Adil Hafeez Co-authored-by: Adil Hafeez --- cli/planoai/config_generator.py | 11 ++- config/arch_config_schema.yaml | 14 ++-- config/test_passthrough.yaml | 37 ++++++++++ crates/common/src/configuration.rs | 2 + crates/llm_gateway/src/stream_context.rs | 32 ++++++--- .../llm_providers/supported_providers.rst | 69 +++++++++++++++++++ .../includes/arch_config_full_reference.yaml | 20 +++--- .../arch_config_full_reference_rendered.yaml | 18 +++++ 8 files changed, 177 insertions(+), 26 deletions(-) create mode 100644 config/test_passthrough.yaml diff --git a/cli/planoai/config_generator.py b/cli/planoai/config_generator.py index 3b920181..636e2d22 100644 --- a/cli/planoai/config_generator.py +++ b/cli/planoai/config_generator.py @@ -236,10 +236,19 @@ def validate_and_render_schema(): for routing_preference in model_provider.get("routing_preferences", []): if routing_preference.get("name") in model_usage_name_keys: raise Exception( - f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference" + f'Duplicate routing preference name "{routing_preference.get("name")}", please provide unique name for each routing preference' ) model_usage_name_keys.add(routing_preference.get("name")) + # Warn if both passthrough_auth and access_key are configured + if model_provider.get("passthrough_auth") and model_provider.get( + "access_key" + ): + print( + f"WARNING: Model provider '{model_provider.get('name')}' has both 'passthrough_auth: true' and 'access_key' configured. " + f"The access_key will be ignored and the client's Authorization header will be forwarded instead." + ) + model_provider["model"] = model_id model_provider["provider_interface"] = provider model_provider_name_set.add(model_provider.get("name")) diff --git a/config/arch_config_schema.yaml b/config/arch_config_schema.yaml index 78856adf..71a4e3e9 100644 --- a/config/arch_config_schema.yaml +++ b/config/arch_config_schema.yaml @@ -1,4 +1,4 @@ -$schema: "http://json-schema.org/draft-07/schema#" +$schema: 'http://json-schema.org/draft-07/schema#' type: object properties: version: @@ -109,12 +109,12 @@ properties: endpoints: type: object patternProperties: - "^[a-zA-Z][a-zA-Z0-9_]*$": + '^[a-zA-Z][a-zA-Z0-9_]*$': type: object properties: endpoint: type: string - pattern: "^.*$" + pattern: '^.*$' connect_timeout: type: string protocol: @@ -143,6 +143,9 @@ properties: type: boolean base_url: type: string + passthrough_auth: + type: boolean + description: "When true, forwards the client's Authorization header to upstream instead of using the configured access_key. Useful for routing to services like LiteLLM that validate their own virtual keys." http_host: type: string provider_interface: @@ -187,6 +190,9 @@ properties: type: boolean base_url: type: string + passthrough_auth: + type: boolean + description: "When true, forwards the client's Authorization header to upstream instead of using the configured access_key. Useful for routing to services like LiteLLM that validate their own virtual keys." http_host: type: string provider_interface: @@ -219,7 +225,7 @@ properties: model_aliases: type: object patternProperties: - "^.*$": + '^.*$': type: object properties: target: diff --git a/config/test_passthrough.yaml b/config/test_passthrough.yaml new file mode 100644 index 00000000..7e59370e --- /dev/null +++ b/config/test_passthrough.yaml @@ -0,0 +1,37 @@ +# Test configuration for passthrough_auth feature +# This config demonstrates forwarding client's Authorization header to upstream +# instead of using a configured access_key. +# +# Use case: Deploying Plano in front of LiteLLM, OpenRouter, or other LLM proxies +# that manage their own API key validation. +# +# To test: +# docker build -t plano-passthrough-test . +# docker run -d -p 10000:10000 -v $(pwd)/config/test_passthrough.yaml:/app/arch_config.yaml plano-passthrough-test +# +# curl http://localhost:10000/v1/chat/completions \ +# -H "Authorization: Bearer sk-your-virtual-key" \ +# -H "Content-Type: application/json" \ +# -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hello"}]}' + +version: v0.3.0 + +listeners: + - name: llm + type: model + port: 10000 + +model_providers: + # Passthrough auth example - forwards client's Authorization header + # Replace base_url with your LiteLLM or proxy endpoint + - model: openai/gpt-4o + base_url: 'https://litellm.example.com' + passthrough_auth: true + default: true + + # Example with both passthrough_auth and access_key (access_key will be ignored) + # This configuration will trigger a warning during startup + - model: openai/gpt-4o-mini + base_url: 'https://litellm.example.com' + passthrough_auth: true + access_key: 'this-will-be-ignored' diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 58ea1e3e..60fd20d0 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -324,6 +324,7 @@ pub struct LlmProvider { pub cluster_name: Option, pub base_url_path_prefix: Option, pub internal: Option, + pub passthrough_auth: Option, } pub trait IntoModels { @@ -367,6 +368,7 @@ impl Default for LlmProvider { cluster_name: None, base_url_path_prefix: None, internal: None, + passthrough_auth: None, } } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 420a1035..8da0f92a 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -149,6 +149,23 @@ impl StreamContext { } fn modify_auth_headers(&mut self) -> Result<(), ServerError> { + if self.llm_provider().passthrough_auth == Some(true) { + // Check if client provided an Authorization header + if self.get_http_request_header("Authorization").is_none() { + warn!( + "[PLANO_REQ_ID:{}] AUTH_PASSTHROUGH: passthrough_auth enabled but no Authorization header present in client request", + self.request_identifier() + ); + } else { + debug!( + "[PLANO_REQ_ID:{}] AUTH_PASSTHROUGH: preserving client Authorization header for provider '{}'", + self.request_identifier(), + self.llm_provider().name + ); + } + return Ok(()); + } + let llm_provider_api_key_value = self.llm_provider() .access_key @@ -778,16 +795,11 @@ impl HttpContext for StreamContext { //We need to update the upstream path if there is a variation for a provider like Gemini/Groq, etc. self.update_upstream_path(&request_path); - if self.llm_provider().endpoint.is_some() { - self.add_http_request_header( - ARCH_ROUTING_HEADER, - &self - .llm_provider() - .cluster_name - .as_ref() - .unwrap() - .to_string(), - ); + // Clone cluster_name to avoid borrowing self while calling add_http_request_header (which requires mut self) + let cluster_name_opt = self.llm_provider().cluster_name.clone(); + + if let Some(cluster_name) = cluster_name_opt { + self.add_http_request_header(ARCH_ROUTING_HEADER, &cluster_name); } else { self.add_http_request_header( ARCH_ROUTING_HEADER, diff --git a/docs/source/concepts/llm_providers/supported_providers.rst b/docs/source/concepts/llm_providers/supported_providers.rst index acdb8381..188f35a0 100644 --- a/docs/source/concepts/llm_providers/supported_providers.rst +++ b/docs/source/concepts/llm_providers/supported_providers.rst @@ -728,6 +728,75 @@ Configure routing preferences for dynamic model selection: - name: creative_writing description: creative content generation, storytelling, and writing assistance +.. _passthrough_auth: + +Passthrough Authentication +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When deploying Plano in front of LLM proxy services that manage their own API key validation (such as LiteLLM, OpenRouter, or custom gateways), you may want to forward the client's original ``Authorization`` header instead of replacing it with a configured ``access_key``. + +The ``passthrough_auth`` option enables this behavior: + +.. code-block:: yaml + + llm_providers: + # Forward client's Authorization header to LiteLLM + - model: openai/gpt-4o-litellm + base_url: https://litellm.example.com + passthrough_auth: true + default: true + + # Forward to OpenRouter + - model: openai/claude-3-opus + base_url: https://openrouter.ai/api/v1 + passthrough_auth: true + +**How it works:** + +1. Client sends a request with ``Authorization: Bearer `` +2. Plano preserves this header instead of replacing it with ``access_key`` +3. The upstream service (e.g., LiteLLM) validates the virtual key +4. Response flows back through Plano to the client + +**Use Cases:** + +- **LiteLLM Integration**: Route requests to LiteLLM which manages virtual keys and rate limits +- **OpenRouter**: Forward requests to OpenRouter with per-user API keys +- **Custom API Gateways**: Integrate with internal gateways that have their own authentication +- **Multi-tenant Deployments**: Allow different clients to use their own credentials + +**Important Notes:** + +- When ``passthrough_auth: true`` is set, the ``access_key`` field is ignored (a warning is logged if both are configured) +- If the client doesn't provide an ``Authorization`` header, the request is forwarded without authentication (upstream will likely return 401) +- The ``base_url`` is typically required when using ``passthrough_auth`` + +**Configuration with LiteLLM example:** + +.. code-block:: yaml + + # plano_config.yaml + version: v0.3.0 + + listeners: + - name: llm + type: model + port: 10000 + + model_providers: + - model: openai/gpt-4o + base_url: https://litellm.example.com + passthrough_auth: true + default: true + +.. code-block:: bash + + # Client request - virtual key is forwarded to upstream + curl http://localhost:10000/v1/chat/completions \ + -H "Authorization: Bearer sk-litellm-virtual-key-abc123" \ + -H "Content-Type: application/json" \ + -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hello"}]}' + Model Selection Guidelines -------------------------- diff --git a/docs/source/resources/includes/arch_config_full_reference.yaml b/docs/source/resources/includes/arch_config_full_reference.yaml index aa186c26..be3c18a2 100644 --- a/docs/source/resources/includes/arch_config_full_reference.yaml +++ b/docs/source/resources/includes/arch_config_full_reference.yaml @@ -1,26 +1,22 @@ - # Arch Gateway configuration version version: v0.3.0 - # External HTTP agents - API type is controlled by request path (/v1/responses, /v1/messages, /v1/chat/completions) agents: - - id: weather_agent # Example agent for weather + - id: weather_agent # Example agent for weather url: http://host.docker.internal:10510 - - id: flight_agent # Example agent for flights + - id: flight_agent # Example agent for flights url: http://host.docker.internal:10520 - # MCP filters applied to requests/responses (e.g., input validation, query rewriting) filters: - - id: input_guards # Example filter for input validation + - id: input_guards # Example filter for input validation url: http://host.docker.internal:10500 # type: mcp (default) # transport: streamable-http (default) # tool: input_guards (default - same as filter id) - # LLM provider configurations with API keys and model routing model_providers: - model: openai/gpt-4o @@ -36,6 +32,12 @@ model_providers: - model: mistral/ministral-3b-latest access_key: $MISTRAL_API_KEY + # Example: Passthrough authentication for LiteLLM or similar proxies + # When passthrough_auth is true, client's Authorization header is forwarded + # instead of using the configured access_key + - model: openai/gpt-4o-litellm + base_url: https://litellm.example.com + passthrough_auth: true # Model aliases - use friendly names instead of full provider model names model_aliases: @@ -45,7 +47,6 @@ model_aliases: smart-llm: target: gpt-4o - # HTTP listeners - entry points for agent routing, prompt targets, and direct LLM access listeners: # Agent listener for routing requests to multiple agents @@ -73,7 +74,6 @@ listeners: port: 10000 # This listener is used for prompt_targets and function calling - # Reusable service endpoints endpoints: app_server: @@ -83,7 +83,6 @@ endpoints: mistral_local: endpoint: 127.0.0.1:8001 - # Prompt targets for function calling and API orchestration prompt_targets: - name: get_current_weather @@ -103,7 +102,6 @@ prompt_targets: path: /weather http_method: POST - # OpenTelemetry tracing configuration tracing: # Random sampling percentage (1-100) diff --git a/docs/source/resources/includes/arch_config_full_reference_rendered.yaml b/docs/source/resources/includes/arch_config_full_reference_rendered.yaml index 4ba89a92..a33878b6 100644 --- a/docs/source/resources/includes/arch_config_full_reference_rendered.yaml +++ b/docs/source/resources/includes/arch_config_full_reference_rendered.yaml @@ -64,6 +64,15 @@ listeners: model: ministral-3b-latest name: mistral/ministral-3b-latest provider_interface: mistral + - base_url: https://litellm.example.com + cluster_name: openai_litellm.example.com + endpoint: litellm.example.com + model: gpt-4o-litellm + name: openai/gpt-4o-litellm + passthrough_auth: true + port: 443 + protocol: https + provider_interface: openai name: egress_traffic port: 12000 timeout: 30s @@ -91,6 +100,15 @@ model_providers: model: ministral-3b-latest name: mistral/ministral-3b-latest provider_interface: mistral +- base_url: https://litellm.example.com + cluster_name: openai_litellm.example.com + endpoint: litellm.example.com + model: gpt-4o-litellm + name: openai/gpt-4o-litellm + passthrough_auth: true + port: 443 + protocol: https + provider_interface: openai - internal: true model: Arch-Function name: arch-function From 626f556cc6571afc453cce7285268ef4709a09d4 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 16 Jan 2026 15:38:43 -0800 Subject: [PATCH 2/3] reduce number of info statements in pipeline processor (#698) Co-authored-by: Adil Hafeez --- crates/brightstaff/src/handlers/pipeline_processor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs index 8b9bf21a..09520617 100644 --- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -811,7 +811,7 @@ impl PipelineProcessor { }); } - info!( + debug!( "Response from HTTP agent {}: {}", agent.id, String::from_utf8_lossy(&response_bytes) From cdc1d7cee22050c8159e6098e28894edbd0d1d7d Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Fri, 16 Jan 2026 16:24:03 -0800 Subject: [PATCH 3/3] making Messages.Content optional, and having the upstream LLM fail if the right fields aren't set (#699) Co-authored-by: Salman Paracha --- .../src/handlers/agent_chat_completions.rs | 2 +- .../src/handlers/function_calling.rs | 20 +-- .../src/handlers/integration_tests.rs | 4 +- .../src/handlers/pipeline_processor.rs | 2 +- .../brightstaff/src/handlers/router_chat.rs | 4 +- .../src/router/orchestrator_model_v1.rs | 30 +++-- .../brightstaff/src/router/router_model_v1.rs | 30 +++-- crates/brightstaff/src/signals/analyzer.rs | 17 +-- crates/hermesllm/src/apis/amazon_bedrock.rs | 19 +-- crates/hermesllm/src/apis/anthropic.rs | 2 +- crates/hermesllm/src/apis/openai.rs | 123 +++++++++++++----- crates/hermesllm/src/apis/openai_responses.rs | 8 +- crates/hermesllm/src/providers/request.rs | 12 +- crates/hermesllm/src/transforms/lib.rs | 5 +- .../src/transforms/request/from_anthropic.rs | 21 +-- .../src/transforms/request/from_openai.rs | 53 ++++---- tests/e2e/test_model_alias_routing.py | 75 +++++++++++ 17 files changed, 294 insertions(+), 133 deletions(-) diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs index bac9607b..5ced34c0 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs @@ -402,7 +402,7 @@ async fn handle_agent_chat( // and add it to the conversation history current_messages.push(OpenAIMessage { role: hermesllm::apis::openai::Role::Assistant, - content: hermesllm::apis::openai::MessageContent::Text(response_text), + content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)), name: Some(agent_name.clone()), tool_calls: None, tool_call_id: None, diff --git a/crates/brightstaff/src/handlers/function_calling.rs b/crates/brightstaff/src/handlers/function_calling.rs index 8f641df6..7ba15e2d 100644 --- a/crates/brightstaff/src/handlers/function_calling.rs +++ b/crates/brightstaff/src/handlers/function_calling.rs @@ -638,7 +638,7 @@ impl ArchFunctionHandler { let system_prompt = self.format_system_prompt(tools)?; processed_messages.push(Message { role: Role::System, - content: MessageContent::Text(system_prompt), + content: Some(MessageContent::Text(system_prompt)), name: None, tool_calls: None, tool_call_id: None, @@ -649,8 +649,9 @@ impl ArchFunctionHandler { for (idx, message) in messages.iter().enumerate() { let mut role = message.role.clone(); let mut content = match &message.content { - MessageContent::Text(text) => text.clone(), - MessageContent::Parts(_) => String::new(), + Some(MessageContent::Text(text)) => text.clone(), + Some(MessageContent::Parts(_)) => String::new(), + None => String::new(), }; // Handle tool calls @@ -675,7 +676,8 @@ impl ArchFunctionHandler { } else { // Get the tool call from previous message if idx > 0 { - if let MessageContent::Text(prev_content) = &messages[idx - 1].content { + if let Some(MessageContent::Text(prev_content)) = &messages[idx - 1].content + { let mut tool_call_msg = prev_content.clone(); // Strip markdown code blocks @@ -721,7 +723,7 @@ impl ArchFunctionHandler { processed_messages.push(Message { role, - content: MessageContent::Text(content), + content: Some(MessageContent::Text(content)), name: message.name.clone(), tool_calls: None, tool_call_id: None, @@ -740,7 +742,7 @@ impl ArchFunctionHandler { // Add extra instruction if provided if let Some(instruction) = extra_instruction { if let Some(last) = processed_messages.last_mut() { - if let MessageContent::Text(content) = &mut last.content { + if let Some(MessageContent::Text(content)) = &mut last.content { content.push('\n'); content.push_str(instruction); } @@ -761,7 +763,7 @@ impl ArchFunctionHandler { // Keep system message if present if let Some(first) = messages.first() { if first.role == Role::System { - if let MessageContent::Text(content) = &first.content { + if let Some(MessageContent::Text(content)) = &first.content { num_tokens += content.len() / 4; // Approximate 4 chars per token } conversation_idx = 1; @@ -772,7 +774,7 @@ impl ArchFunctionHandler { // Start with message_idx pointing past the end (will be used if no truncation needed) let mut message_idx = messages.len(); for i in (conversation_idx..messages.len()).rev() { - if let MessageContent::Text(content) = &messages[i].content { + if let Some(MessageContent::Text(content)) = &messages[i].content { num_tokens += content.len() / 4; if num_tokens >= max_tokens && messages[i].role == Role::User { // Set message_idx to current position and break @@ -802,7 +804,7 @@ impl ArchFunctionHandler { pub fn prefill_message(&self, mut messages: Vec, prefill: &str) -> Vec { messages.push(Message { role: Role::Assistant, - content: MessageContent::Text(prefill.to_string()), + content: Some(MessageContent::Text(prefill.to_string())), name: None, tool_calls: None, tool_call_id: None, diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index 29552f83..9239f94a 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -28,7 +28,7 @@ mod tests { fn create_test_message(role: Role, content: &str) -> Message { Message { role, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: None, tool_call_id: None, @@ -129,7 +129,7 @@ mod tests { let processed_messages = result.unwrap(); // With empty filter chain, should return the original messages unchanged assert_eq!(processed_messages.len(), 1); - if let MessageContent::Text(content) = &processed_messages[0].content { + if let Some(MessageContent::Text(content)) = &processed_messages[0].content { assert_eq!(content, "Hello world!"); } else { panic!("Expected text content"); diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs index 09520617..bc36de01 100644 --- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -887,7 +887,7 @@ mod tests { fn create_test_message(role: Role, content: &str) -> Message { Message { role, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: None, tool_call_id: None, diff --git a/crates/brightstaff/src/handlers/router_chat.rs b/crates/brightstaff/src/handlers/router_chat.rs index 67e25338..701e8e51 100644 --- a/crates/brightstaff/src/handlers/router_chat.rs +++ b/crates/brightstaff/src/handlers/router_chat.rs @@ -95,7 +95,9 @@ pub async fn router_chat_get_upstream_model( .messages .last() .map_or("None".to_string(), |msg| { - msg.content.to_string().replace('\n', "\\n") + msg.content + .as_ref() + .map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n")) }); const MAX_MESSAGE_LENGTH: usize = 50; diff --git a/crates/brightstaff/src/router/orchestrator_model_v1.rs b/crates/brightstaff/src/router/orchestrator_model_v1.rs index ef32db83..8d64f8e7 100644 --- a/crates/brightstaff/src/router/orchestrator_model_v1.rs +++ b/crates/brightstaff/src/router/orchestrator_model_v1.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use common::configuration::{AgentUsagePreference, OrchestrationPreference}; use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role}; +use hermesllm::transforms::lib::ExtractText; use serde::{ser::Serialize as SerializeTrait, Deserialize, Serialize}; use tracing::{debug, warn}; @@ -181,7 +182,9 @@ impl OrchestratorModel for OrchestratorModelV1 { let messages_vec = messages .iter() .filter(|m| { - m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty() + m.role != Role::System + && m.role != Role::Tool + && !m.content.extract_text().is_empty() }) .collect::>(); @@ -190,7 +193,7 @@ impl OrchestratorModel for OrchestratorModelV1 { let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR; let mut selected_messages_list_reversed: Vec<&Message> = vec![]; for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() { - let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR; + let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR; token_count += message_token_count; if token_count > self.max_token_length { debug!( @@ -240,7 +243,12 @@ impl OrchestratorModel for OrchestratorModelV1 { .rev() .map(|message| Message { role: message.role.clone(), - content: MessageContent::Text(message.content.to_string()), + content: Some(MessageContent::Text( + message + .content + .as_ref() + .map_or(String::new(), |c| c.to_string()), + )), name: None, tool_calls: None, tool_call_id: None, @@ -262,7 +270,7 @@ impl OrchestratorModel for OrchestratorModelV1 { ChatCompletionsRequest { model: self.orchestration_model.clone(), messages: vec![Message { - content: MessageContent::Text(orchestrator_message), + content: Some(MessageContent::Text(orchestrator_message)), role: Role::User, name: None, tool_calls: None, @@ -539,7 +547,7 @@ If no routes are needed, return an empty list for `route`. let req = orchestrator.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -618,7 +626,7 @@ If no routes are needed, return an empty list for `route`. }]); let req = orchestrator.generate_request(&conversation, &usage_preferences); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -689,7 +697,7 @@ If no routes are needed, return an empty list for `route`. let req = orchestrator.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -761,7 +769,7 @@ If no routes are needed, return an empty list for `route`. let req = orchestrator.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -848,7 +856,7 @@ If no routes are needed, return an empty list for `route`. let req = orchestrator.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -940,7 +948,7 @@ If no routes are needed, return an empty list for `route`. let req = orchestrator.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -1058,7 +1066,7 @@ If no routes are needed, return an empty list for `route`. let req: ChatCompletionsRequest = orchestrator.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index 84680928..796dfaac 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use common::configuration::{ModelUsagePreference, RoutingPreference}; use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role}; +use hermesllm::transforms::lib::ExtractText; use serde::{Deserialize, Serialize}; use tracing::{debug, warn}; @@ -78,7 +79,9 @@ impl RouterModel for RouterModelV1 { let messages_vec = messages .iter() .filter(|m| { - m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty() + m.role != Role::System + && m.role != Role::Tool + && !m.content.extract_text().is_empty() }) .collect::>(); @@ -87,7 +90,7 @@ impl RouterModel for RouterModelV1 { let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR; let mut selected_messages_list_reversed: Vec<&Message> = vec![]; for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() { - let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR; + let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR; token_count += message_token_count; if token_count > self.max_token_length { debug!( @@ -136,7 +139,12 @@ impl RouterModel for RouterModelV1 { Message { role: message.role.clone(), // we can unwrap here because we have already filtered out messages without content - content: MessageContent::Text(message.content.to_string()), + content: Some(MessageContent::Text( + message + .content + .as_ref() + .map_or(String::new(), |c| c.to_string()), + )), name: None, tool_calls: None, tool_call_id: None, @@ -154,7 +162,7 @@ impl RouterModel for RouterModelV1 { ChatCompletionsRequest { model: self.routing_model.clone(), messages: vec![Message { - content: MessageContent::Text(router_message), + content: Some(MessageContent::Text(router_message)), role: Role::User, name: None, tool_calls: None, @@ -344,7 +352,7 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -409,7 +417,7 @@ Based on your analysis, provide your response in the following JSON formats if y }]); let req = router.generate_request(&conversation, &usage_preferences); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -469,7 +477,7 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -530,7 +538,7 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -598,7 +606,7 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -667,7 +675,7 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -762,7 +770,7 @@ Based on your analysis, provide your response in the following JSON formats if y let req: ChatCompletionsRequest = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } diff --git a/crates/brightstaff/src/signals/analyzer.rs b/crates/brightstaff/src/signals/analyzer.rs index 9880bf2c..5ee3c7d9 100644 --- a/crates/brightstaff/src/signals/analyzer.rs +++ b/crates/brightstaff/src/signals/analyzer.rs @@ -1122,9 +1122,9 @@ pub struct TextBasedSignalAnalyzer { impl TextBasedSignalAnalyzer { /// Extract text content from MessageContent, skipping non-text content - fn extract_text(content: &hermesllm::apis::openai::MessageContent) -> Option { + fn extract_text(content: &Option) -> Option { match content { - hermesllm::apis::openai::MessageContent::Text(text) => Some(text.clone()), + Some(hermesllm::apis::openai::MessageContent::Text(text)) => Some(text.clone()), // Tool calls and other structured content are skipped _ => None, } @@ -1941,12 +1941,13 @@ impl Default for TextBasedSignalAnalyzer { mod tests { use super::*; use hermesllm::apis::openai::MessageContent; + use hermesllm::transforms::lib::ExtractText; use std::time::Instant; fn create_message(role: Role, content: &str) -> Message { Message { role, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: None, tool_call_id: None, @@ -2130,7 +2131,7 @@ mod tests { .iter() .enumerate() .map(|(i, msg)| { - let text = msg.content.to_string(); + let text = msg.content.extract_text(); (i, msg.role.clone(), NormalizedMessage::from_text(&text)) }) .collect() @@ -2532,7 +2533,7 @@ mod tests { |content: &str, tool_id: &str, tool_name: &str, args: &str| -> Message { Message { role: Role::Assistant, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: Some(vec![ToolCall { id: tool_id.to_string(), @@ -2550,7 +2551,7 @@ mod tests { let create_tool_message = |tool_call_id: &str, content: &str| -> Message { Message { role: Role::Tool, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: None, tool_call_id: Some(tool_call_id.to_string()), @@ -2665,7 +2666,7 @@ mod tests { |content: &str, tool_id: &str, tool_name: &str, args: &str| -> Message { Message { role: Role::Assistant, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: Some(vec![ToolCall { id: tool_id.to_string(), @@ -2683,7 +2684,7 @@ mod tests { let create_tool_message = |tool_call_id: &str, content: &str| -> Message { Message { role: Role::Tool, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: None, tool_call_id: Some(tool_call_id.to_string()), diff --git a/crates/hermesllm/src/apis/amazon_bedrock.rs b/crates/hermesllm/src/apis/amazon_bedrock.rs index dbada283..e2cbc201 100644 --- a/crates/hermesllm/src/apis/amazon_bedrock.rs +++ b/crates/hermesllm/src/apis/amazon_bedrock.rs @@ -225,7 +225,7 @@ impl ProviderRequest for ConverseRequest { if let SystemContentBlock::Text { text } = sys_block { openai_messages.push(Message { role: Role::System, - content: MessageContent::Text(text.clone()), + content: Some(MessageContent::Text(text.clone())), name: None, tool_calls: None, tool_call_id: None, @@ -258,7 +258,7 @@ impl ProviderRequest for ConverseRequest { openai_messages.push(Message { role, - content: MessageContent::Text(content), + content: Some(MessageContent::Text(content)), name: None, tool_calls: None, tool_call_id: None, @@ -279,7 +279,7 @@ impl ProviderRequest for ConverseRequest { for msg in messages { match msg.role { crate::apis::openai::Role::System => { - if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content { system_blocks.push(SystemContentBlock::Text { text: text.clone() }); } } @@ -290,12 +290,13 @@ impl ProviderRequest for ConverseRequest { _ => continue, }; - let content = - if let crate::apis::openai::MessageContent::Text(text) = &msg.content { - vec![ContentBlock::Text { text: text.clone() }] - } else { - vec![] - }; + let content = if let Some(crate::apis::openai::MessageContent::Text(text)) = + &msg.content + { + vec![ContentBlock::Text { text: text.clone() }] + } else { + vec![] + }; bedrock_messages.push(crate::apis::amazon_bedrock::Message { role, content }); } diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index ed3317ce..6e53e6db 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -584,7 +584,7 @@ impl ProviderRequest for MessagesRequest { let system_text = system_messages .iter() .filter_map(|msg| { - if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content { Some(text.as_str()) } else { None diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 834c33ec..cd4e7d0b 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -155,7 +155,8 @@ pub enum Role { #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Message { pub role: Role, - pub content: MessageContent, + /// The contents of the message. Required unless tool_calls is specified (for assistant role) + pub content: Option, pub name: Option, /// Tool calls made by the assistant (only present for assistant role) pub tool_calls: Option>, @@ -204,8 +205,7 @@ impl ResponseMessage { content: self .content .as_ref() - .map(|s| MessageContent::Text(s.clone())) - .unwrap_or(MessageContent::Text(String::new())), + .map(|s| MessageContent::Text(s.clone())), name: None, // Response messages don't have names in the same way request messages do tool_calls: self.tool_calls.clone(), tool_call_id: None, // Response messages don't have tool_call_id @@ -233,6 +233,12 @@ impl ExtractText for MessageContent { } } +impl ExtractText for Option { + fn extract_text(&self) -> String { + self.as_ref().map(|c| c.extract_text()).unwrap_or_default() + } +} + impl ExtractText for Vec { fn extract_text(&self) -> String { self.iter() @@ -247,23 +253,7 @@ impl ExtractText for Vec { impl Display for MessageContent { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MessageContent::Text(text) => write!(f, "{}", text), - MessageContent::Parts(parts) => { - let text_parts: Vec = parts - .iter() - .filter_map(|part| match part { - ContentPart::Text { text } => Some(text.clone()), - ContentPart::ImageUrl { .. } => { - // skip image URLs or their data in text representation - None - } - }) - .collect(); - let combined_text = text_parts.join("\n"); - write!(f, "{}", combined_text) - } - } + write!(f, "{}", self.extract_text()) } } @@ -622,8 +612,10 @@ impl ProviderRequest for ChatCompletionsRequest { fn extract_messages_text(&self) -> String { self.messages.iter().fold(String::new(), |acc, m| { - acc + " " - + &match &m.content { + let content_text = m + .content + .as_ref() + .map(|content| match content { MessageContent::Text(text) => text.clone(), MessageContent::Parts(parts) => parts .iter() @@ -633,16 +625,18 @@ impl ProviderRequest for ChatCompletionsRequest { }) .collect::>() .join(" "), - } + }) + .unwrap_or_default(); + acc + " " + &content_text }) } fn get_recent_user_message(&self) -> Option { self.messages.last().and_then(|msg| { - match &msg.content { + msg.content.as_ref().and_then(|content| match content { MessageContent::Text(text) => Some(text.clone()), MessageContent::Parts(_) => None, // No user message in parts - } + }) }) } @@ -778,7 +772,8 @@ mod tests { let message = &deserialized_request.messages[0]; assert_eq!(message.role, Role::User); - if let MessageContent::Text(content) = &message.content { + assert!(message.content.is_some()); + if let Some(MessageContent::Text(content)) = &message.content { assert_eq!(content, "Hello, world!"); } else { panic!("Expected text content"); @@ -822,7 +817,8 @@ mod tests { let message = &deserialized_request.messages[0]; assert_eq!(message.role, Role::User); - if let MessageContent::Text(content) = &message.content { + assert!(message.content.is_some()); + if let Some(MessageContent::Text(content)) = &message.content { assert_eq!(content, "Test message"); } else { panic!("Expected text content"); @@ -947,7 +943,8 @@ mod tests { // Validate first message (user with multimodal content) let user_message = &deserialized_request.messages[0]; assert_eq!(user_message.role, Role::User); - if let MessageContent::Parts(ref content_parts) = user_message.content { + assert!(user_message.content.is_some()); + if let Some(MessageContent::Parts(ref content_parts)) = user_message.content { assert_eq!(content_parts.len(), 2); // Validate text content part @@ -971,7 +968,8 @@ mod tests { // Validate second message (assistant with tool calls) let assistant_message = &deserialized_request.messages[1]; assert_eq!(assistant_message.role, Role::Assistant); - if let MessageContent::Text(text) = &assistant_message.content { + assert!(assistant_message.content.is_some()); + if let Some(MessageContent::Text(text)) = &assistant_message.content { assert_eq!( text, "I can see a beautiful cityscape. Let me check the weather for you." @@ -997,7 +995,8 @@ mod tests { // Validate third message (tool response) let tool_message = &deserialized_request.messages[2]; assert_eq!(tool_message.role, Role::Tool); - if let MessageContent::Text(text) = &tool_message.content { + assert!(tool_message.content.is_some()); + if let Some(MessageContent::Text(text)) = &tool_message.content { assert_eq!(text, "Current weather in New York: 72°F, sunny"); } else { panic!("Expected text content for tool message"); @@ -1061,6 +1060,62 @@ mod tests { assert!((original_temp - serialized_temp).abs() < 1e-6); } + #[test] + fn test_assistant_message_with_tool_calls_no_content() { + // Test that assistant messages can have tool_calls without content + let json_with_tool_calls_no_content = json!({ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "What's the weather in San Francisco?" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"San Francisco, CA\"}" + } + } + ] + } + ] + }); + + // Should deserialize successfully + let request: ChatCompletionsRequest = + serde_json::from_value(json_with_tool_calls_no_content.clone()).unwrap(); + + assert_eq!(request.messages.len(), 2); + + // Check user message + let user_msg = &request.messages[0]; + assert_eq!(user_msg.role, Role::User); + assert!(user_msg.content.is_some()); + + // Check assistant message - should have tool_calls but no content + let assistant_msg = &request.messages[1]; + assert_eq!(assistant_msg.role, Role::Assistant); + assert!(assistant_msg.content.is_none()); + assert!(assistant_msg.tool_calls.is_some()); + + let tool_calls = assistant_msg.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id, "call_123"); + assert_eq!(tool_calls[0].function.name, "get_weather"); + + // Should serialize back without content field + let serialized = serde_json::to_value(&request).unwrap(); + // Verify the assistant message doesn't have a content field in serialized JSON + let serialized_assistant_msg = &serialized["messages"][1]; + assert!(serialized_assistant_msg.get("content").is_none()); + assert!(serialized_assistant_msg.get("tool_calls").is_some()); + } + #[test] fn test_api_provider_trait() { // Test the ApiDefinition trait implementation @@ -1097,7 +1152,7 @@ mod tests { let deserialized_user: Message = serde_json::from_value(user_json.clone()).unwrap(); assert_eq!(deserialized_user.role, Role::User); - if let MessageContent::Text(content) = &deserialized_user.content { + if let Some(MessageContent::Text(content)) = &deserialized_user.content { assert_eq!(content, "Hello!"); } else { panic!("Expected text content"); @@ -1128,7 +1183,7 @@ mod tests { let deserialized_assistant: Message = serde_json::from_value(assistant_json.clone()).unwrap(); assert_eq!(deserialized_assistant.role, Role::Assistant); - if let MessageContent::Text(content) = &deserialized_assistant.content { + if let Some(MessageContent::Text(content)) = &deserialized_assistant.content { assert_eq!(content, "I'll help with that."); } else { panic!("Expected text content"); @@ -1154,7 +1209,7 @@ mod tests { let deserialized_tool: Message = serde_json::from_value(tool_json.clone()).unwrap(); assert_eq!(deserialized_tool.role, Role::Tool); - if let MessageContent::Text(content) = &deserialized_tool.content { + if let Some(MessageContent::Text(content)) = &deserialized_tool.content { assert_eq!(content, "Weather is sunny"); } else { panic!("Expected text content"); @@ -1193,7 +1248,7 @@ mod tests { // Test conversion from ResponseMessage to Message let converted = deserialized_response.to_message(); assert_eq!(converted.role, Role::Assistant); - if let MessageContent::Text(text) = converted.content { + if let Some(MessageContent::Text(text)) = converted.content { assert_eq!(text, "Response content"); } else { panic!("Expected text content"); diff --git a/crates/hermesllm/src/apis/openai_responses.rs b/crates/hermesllm/src/apis/openai_responses.rs index 720e24d3..dbc82f8b 100644 --- a/crates/hermesllm/src/apis/openai_responses.rs +++ b/crates/hermesllm/src/apis/openai_responses.rs @@ -1146,7 +1146,7 @@ impl ProviderRequest for ResponsesAPIRequest { .iter() .filter(|msg| msg.role == crate::apis::openai::Role::System) .filter_map(|msg| { - if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content { Some(text.as_str()) } else { None @@ -1170,7 +1170,8 @@ impl ProviderRequest for ResponsesAPIRequest { if !input_messages.is_empty() { // If there's only one message, use Text format if input_messages.len() == 1 { - if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content + if let Some(crate::apis::openai::MessageContent::Text(text)) = + &input_messages[0].content { self.input = crate::apis::openai_responses::InputParam::Text(text.clone()); } @@ -1180,7 +1181,8 @@ impl ProviderRequest for ResponsesAPIRequest { let combined_text = input_messages .iter() .filter_map(|msg| { - if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content + { Some(format!( "{}: {}", match msg.role { diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index d1d85888..e97e8a68 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -671,14 +671,16 @@ mod tests { messages: vec![ Message { role: Role::System, - content: MessageContent::Text("You are a helpful assistant".to_string()), + content: Some(MessageContent::Text( + "You are a helpful assistant".to_string(), + )), name: None, tool_calls: None, tool_call_id: None, }, Message { role: Role::User, - content: MessageContent::Text("Hello!".to_string()), + content: Some(MessageContent::Text("Hello!".to_string())), name: None, tool_calls: None, tool_call_id: None, @@ -900,7 +902,7 @@ mod tests { model: "gpt-4".to_string(), messages: vec![Message { role: Role::User, - content: MessageContent::Text("Hello!".to_string()), + content: Some(MessageContent::Text("Hello!".to_string())), name: None, tool_calls: None, tool_call_id: None, @@ -993,14 +995,14 @@ mod tests { messages: vec![ Message { role: Role::System, - content: MessageContent::Text("You are helpful".to_string()), + content: Some(MessageContent::Text("You are helpful".to_string())), name: None, tool_calls: None, tool_call_id: None, }, Message { role: Role::User, - content: MessageContent::Text("Hello!".to_string()), + content: Some(MessageContent::Text("Hello!".to_string())), name: None, tool_calls: None, tool_call_id: None, diff --git a/crates/hermesllm/src/transforms/lib.rs b/crates/hermesllm/src/transforms/lib.rs index a44f8d79..115f061c 100644 --- a/crates/hermesllm/src/transforms/lib.rs +++ b/crates/hermesllm/src/transforms/lib.rs @@ -188,7 +188,7 @@ pub fn convert_openai_message_to_anthropic_content( // Handle regular content match &message.content { - MessageContent::Text(text) => { + Some(MessageContent::Text(text)) => { if !text.is_empty() { blocks.push(MessagesContentBlock::Text { text: text.clone(), @@ -196,7 +196,7 @@ pub fn convert_openai_message_to_anthropic_content( }); } } - MessageContent::Parts(parts) => { + Some(MessageContent::Parts(parts)) => { for part in parts { match part { ContentPart::Text { text } => { @@ -212,6 +212,7 @@ pub fn convert_openai_message_to_anthropic_content( } } } + None => {} } // Handle tool calls diff --git a/crates/hermesllm/src/transforms/request/from_anthropic.rs b/crates/hermesllm/src/transforms/request/from_anthropic.rs index c07be4e5..82dbe547 100644 --- a/crates/hermesllm/src/transforms/request/from_anthropic.rs +++ b/crates/hermesllm/src/transforms/request/from_anthropic.rs @@ -174,7 +174,7 @@ impl TryFrom for Vec { MessagesMessageContent::Single(text) => { result.push(Message { role: message.role.into(), - content: MessageContent::Text(text), + content: Some(MessageContent::Text(text)), name: None, tool_calls: None, tool_call_id: None, @@ -186,7 +186,7 @@ impl TryFrom for Vec { for (tool_use_id, result_text, _is_error) in tool_results { result.push(Message { role: Role::Tool, - content: MessageContent::Text(result_text), + content: Some(MessageContent::Text(result_text)), name: None, tool_calls: None, tool_call_id: Some(tool_use_id), @@ -260,7 +260,7 @@ impl From for Message { Message { role: Role::System, - content: system_content, + content: Some(system_content), name: None, tool_calls: None, tool_call_id: None, @@ -317,16 +317,19 @@ fn convert_anthropic_tool_choice( fn build_openai_content( content_parts: Vec, tool_calls: &[ToolCall], -) -> MessageContent { - if content_parts.len() == 1 && tool_calls.is_empty() { +) -> Option { + if content_parts.is_empty() && !tool_calls.is_empty() { + // For assistant messages with only tool calls, content is optional + None + } else if content_parts.len() == 1 && tool_calls.is_empty() { match &content_parts[0] { - ContentPart::Text { text } => MessageContent::Text(text.clone()), - _ => MessageContent::Parts(content_parts), + ContentPart::Text { text } => Some(MessageContent::Text(text.clone())), + _ => Some(MessageContent::Parts(content_parts)), } } else if content_parts.is_empty() { - MessageContent::Text("".to_string()) + Some(MessageContent::Text("".to_string())) } else { - MessageContent::Parts(content_parts) + Some(MessageContent::Parts(content_parts)) } } diff --git a/crates/hermesllm/src/transforms/request/from_openai.rs b/crates/hermesllm/src/transforms/request/from_openai.rs index e39cfed3..ddc3b1ca 100644 --- a/crates/hermesllm/src/transforms/request/from_openai.rs +++ b/crates/hermesllm/src/transforms/request/from_openai.rs @@ -18,7 +18,6 @@ use crate::apis::openai_responses::{ ResponsesAPIRequest, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice, }; use crate::clients::TransformError; -use crate::transforms::lib::ExtractText; use crate::transforms::lib::*; use crate::transforms::*; @@ -48,7 +47,7 @@ impl TryFrom for Vec { if let Some(instructions) = converter.instructions { messages.push(Message { role: Role::System, - content: MessageContent::Text(instructions), + content: Some(MessageContent::Text(instructions)), name: None, tool_call_id: None, tool_calls: None, @@ -58,7 +57,7 @@ impl TryFrom for Vec { // Add the user message messages.push(Message { role: Role::User, - content: MessageContent::Text(text), + content: Some(MessageContent::Text(text)), name: None, tool_call_id: None, tool_calls: None, @@ -74,7 +73,7 @@ impl TryFrom for Vec { if let Some(instructions) = converter.instructions { converted_messages.push(Message { role: Role::System, - content: MessageContent::Text(instructions), + content: Some(MessageContent::Text(instructions)), name: None, tool_call_id: None, tool_calls: None, @@ -154,7 +153,7 @@ impl TryFrom for Vec { converted_messages.push(Message { role, - content, + content: Some(content), name: None, tool_call_id: None, tool_calls: None, @@ -174,11 +173,7 @@ impl TryFrom for Vec { impl From for MessagesSystemPrompt { fn from(val: Message) -> Self { - let system_text = match val.content { - MessageContent::Text(text) => text, - MessageContent::Parts(parts) => parts.extract_text(), - }; - MessagesSystemPrompt::Single(system_text) + MessagesSystemPrompt::Single(val.content.extract_text()) } } @@ -191,6 +186,8 @@ impl TryFrom for MessagesMessage { Role::Assistant => MessagesRole::Assistant, Role::Tool => { // Tool messages become user messages with tool results + // Extract content text first, before moving tool_call_id + let content_text = message.content.extract_text(); let tool_call_id = message.tool_call_id.ok_or_else(|| { TransformError::MissingField( "tool_call_id required for Tool messages".to_string(), @@ -204,7 +201,7 @@ impl TryFrom for MessagesMessage { tool_use_id: tool_call_id, is_error: None, content: ToolResultContent::Blocks(vec![MessagesContentBlock::Text { - text: message.content.extract_text(), + text: content_text, cache_control: None, }]), cache_control: None, @@ -248,12 +245,12 @@ impl TryFrom for BedrockMessage { Role::User => { // Convert user message content to content blocks match message.content { - MessageContent::Text(text) => { + Some(MessageContent::Text(text)) => { if !text.is_empty() { content_blocks.push(ContentBlock::Text { text }); } } - MessageContent::Parts(parts) => { + Some(MessageContent::Parts(parts)) => { // Convert OpenAI content parts to Bedrock ContentBlocks for part in parts { match part { @@ -293,6 +290,9 @@ impl TryFrom for BedrockMessage { } } } + None => { + // Empty content for user - shouldn't happen but handle gracefully + } } // Ensure we have at least one content block @@ -550,10 +550,7 @@ impl TryFrom for ConverseRequest { for message in req.messages { match message.role { Role::System => { - let system_text = match message.content { - MessageContent::Text(text) => text, - MessageContent::Parts(parts) => parts.extract_text(), - }; + let system_text = message.content.extract_text(); system_messages.push(SystemContentBlock::Text { text: system_text }); } _ => { @@ -778,14 +775,16 @@ mod tests { messages: vec![ Message { role: Role::System, - content: MessageContent::Text("You are a helpful assistant.".to_string()), + content: Some(MessageContent::Text( + "You are a helpful assistant.".to_string(), + )), name: None, tool_call_id: None, tool_calls: None, }, Message { role: Role::User, - content: MessageContent::Text("Hello, how are you?".to_string()), + content: Some(MessageContent::Text("Hello, how are you?".to_string())), name: None, tool_call_id: None, tool_calls: None, @@ -840,7 +839,7 @@ mod tests { model: "gpt-4".to_string(), messages: vec![Message { role: Role::User, - content: MessageContent::Text("What's the weather like?".to_string()), + content: Some(MessageContent::Text("What's the weather like?".to_string())), name: None, tool_call_id: None, tool_calls: None, @@ -907,7 +906,7 @@ mod tests { model: "gpt-4".to_string(), messages: vec![Message { role: Role::User, - content: MessageContent::Text("Help me with something".to_string()), + content: Some(MessageContent::Text("Help me with something".to_string())), name: None, tool_call_id: None, tool_calls: None, @@ -950,28 +949,30 @@ mod tests { messages: vec![ Message { role: Role::System, - content: MessageContent::Text("Be concise".to_string()), + content: Some(MessageContent::Text("Be concise".to_string())), name: None, tool_call_id: None, tool_calls: None, }, Message { role: Role::User, - content: MessageContent::Text("Hello".to_string()), + content: Some(MessageContent::Text("Hello".to_string())), name: None, tool_call_id: None, tool_calls: None, }, Message { role: Role::Assistant, - content: MessageContent::Text("Hi there! How can I help you?".to_string()), + content: Some(MessageContent::Text( + "Hi there! How can I help you?".to_string(), + )), name: None, tool_call_id: None, tool_calls: None, }, Message { role: Role::User, - content: MessageContent::Text("What's 2+2?".to_string()), + content: Some(MessageContent::Text("What's 2+2?".to_string())), name: None, tool_call_id: None, tool_calls: None, @@ -1009,7 +1010,7 @@ mod tests { fn test_openai_message_to_bedrock_conversion() { let openai_message = Message { role: Role::User, - content: MessageContent::Text("Test message".to_string()), + content: Some(MessageContent::Text("Test message".to_string())), name: None, tool_call_id: None, tool_calls: None, diff --git a/tests/e2e/test_model_alias_routing.py b/tests/e2e/test_model_alias_routing.py index 7af14df1..f20c24af 100644 --- a/tests/e2e/test_model_alias_routing.py +++ b/tests/e2e/test_model_alias_routing.py @@ -22,6 +22,81 @@ LLM_GATEWAY_ENDPOINT = os.getenv( # ============================================================================= +def test_assistant_message_with_null_content_and_tool_calls(): + """Test that assistant messages with null content and tool_calls are properly handled""" + logger.info( + "Testing assistant message with null content and tool_calls (multi-turn conversation)" + ) + + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI( + api_key="test-key", + base_url=f"{base_url}/v1", + ) + + # Simulate a multi-turn conversation where: + # 1. User asks a question + # 2. Assistant makes a tool call (with null content) + # 3. Tool responds + # 4. Assistant should provide final answer + completion = client.chat.completions.create( + model="gpt-4o", + max_tokens=500, + messages=[ + { + "role": "system", + "content": "You are a weather assistant. Use the get_weather tool to fetch weather information.", + }, + {"role": "user", "content": "What's the weather in Seattle?"}, + { + "role": "assistant", + "content": None, # This is the key test - null content with tool_calls + "tool_calls": [ + { + "id": "call_test123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Seattle"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_test123", + "content": '{"location": "Seattle", "temperature": "10°C", "condition": "Partly cloudy"}', + }, + ], + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"} + }, + "required": ["city"], + }, + }, + } + ], + ) + + response_content = completion.choices[0].message.content + logger.info(f"Response after tool call: {response_content}") + + # The assistant should provide a final response using the tool result + assert response_content is not None + assert len(response_content) > 0 + logger.info( + "✓ Assistant message with null content and tool_calls handled correctly" + ) + + def test_openai_client_with_alias_arch_summarize_v1(): """Test OpenAI client using model alias 'arch.summarize.v1' which should resolve to '4o-mini'""" logger.info("Testing OpenAI client with alias 'arch.summarize.v1' -> '4o-mini'")