diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml
index f7817a09..a80d7d0e 100644
--- a/config/plano_config_schema.yaml
+++ b/config/plano_config_schema.yaml
@@ -285,9 +285,14 @@ properties:
   agent_orchestration_model:
     type: string
    description: "Model name for the agent orchestrator (e.g., 'Plano-Orchestrator'). Must match a model in model_providers."
-  enable_token_counting:
-    type: boolean
-    description: "Enable tiktoken-based input token counting for metrics and rate limiting. Default is false."
+  token_counting_strategy:
+    type: string
+    enum: [estimate, auto]
+    description: >
+      Strategy for counting input tokens used in metrics and rate limiting.
+      "estimate" (default): fast character-based approximation (~1 token per 4 chars).
+      "auto": uses the best available tokenizer for each provider (e.g., tiktoken for
+      OpenAI models), falling back to estimate for unsupported providers.
   system_prompt:
     type: string
   prompt_targets:
diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs
index 2d7a7f22..a08f69db 100644
--- a/crates/common/src/configuration.rs
+++ b/crates/common/src/configuration.rs
@@ -124,6 +124,15 @@ pub struct Configuration {
     pub state_storage: Option<StateStorage>,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
+pub enum TokenCountingStrategy {
+    #[default]
+    #[serde(rename = "estimate")]
+    Estimate,
+    #[serde(rename = "auto")]
+    Auto,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub struct Overrides {
     pub prompt_target_intent_matching_threshold: Option<f64>,
@@ -131,7 +140,7 @@ pub struct Overrides {
     pub use_agent_orchestrator: Option<bool>,
     pub llm_routing_model: Option<String>,
     pub agent_orchestration_model: Option<String>,
-    pub enable_token_counting: Option<bool>,
+    pub token_counting_strategy: Option<TokenCountingStrategy>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs
index 20b8d3e6..f8ad8251 100644
--- a/crates/llm_gateway/src/stream_context.rs
+++ b/crates/llm_gateway/src/stream_context.rs
@@ -10,7 +10,7 @@ use std::sync::Arc;
 use std::time::{Duration, SystemTime, UNIX_EPOCH};
 
 use crate::metrics::Metrics;
-use common::configuration::{LlmProvider, LlmProviderType, Overrides};
+use common::configuration::{LlmProvider, LlmProviderType, Overrides, TokenCountingStrategy};
 use common::consts::{
     ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH,
     RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
@@ -269,15 +269,25 @@ impl StreamContext {
         model: &str,
         json_string: &str,
     ) -> Result<(), ratelimit::Error> {
-        let use_tiktoken = (*self.overrides)
+        let strategy = (*self.overrides)
             .as_ref()
-            .and_then(|o| o.enable_token_counting)
-            .unwrap_or(false);
+            .and_then(|o| o.token_counting_strategy.clone())
+            .unwrap_or_default();
 
-        let token_count = if use_tiktoken {
-            tokenizer::token_count(model, json_string).unwrap_or(0)
-        } else {
-            json_string.len() / 4
+        let (token_count, method) = match strategy {
+            TokenCountingStrategy::Auto => {
+                let provider_id = self.get_provider_id();
+                match provider_id {
+                    ProviderId::OpenAI => (
+                        tokenizer::token_count(model, json_string).unwrap_or(json_string.len() / 4),
+                        "tiktoken",
+                    ),
+                    // Future: add provider-specific tokenizers here
+                    // ProviderId::Mistral => (mistral_tokenizer::count(...), "mistral"),
+                    _ => (json_string.len() / 4, "estimate"),
+                }
+            }
+            TokenCountingStrategy::Estimate => (json_string.len() / 4, "estimate"),
         };
 
         debug!(
@@ -285,7 +295,7 @@
             self.request_identifier(),
             model,
             token_count,
-            if use_tiktoken { "tiktoken" } else { "estimate" }
+            method
         );
 
         self.metrics