replace enable_token_counting bool with token_counting_strategy enum (estimate|auto)

Adil Hafeez 2026-03-25 05:35:27 +00:00
parent e5f3039924
commit 20e8e0c51e
3 changed files with 37 additions and 13 deletions


@@ -285,9 +285,14 @@ properties:
   agent_orchestration_model:
     type: string
     description: "Model name for the agent orchestrator (e.g., 'Plano-Orchestrator'). Must match a model in model_providers."
-  enable_token_counting:
-    type: boolean
-    description: "Enable tiktoken-based input token counting for metrics and rate limiting. Default is false."
+  token_counting_strategy:
+    type: string
+    enum: [estimate, auto]
+    description: >
+      Strategy for counting input tokens used in metrics and rate limiting.
+      "estimate" (default): fast character-based approximation (~1 token per 4 chars).
+      "auto": uses the best available tokenizer for each provider (e.g., tiktoken for
+      OpenAI models), falling back to estimate for unsupported providers.
   system_prompt:
     type: string
   prompt_targets:
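
For illustration, a config that opts into provider-aware counting would set the new key like so (a minimal sketch: the key's placement next to system_prompt follows the schema fragment above, and the values are placeholders, not from this commit):

token_counting_strategy: auto   # "estimate" is the default when the key is omitted
system_prompt: "You are a helpful assistant."   # placeholder value

With "auto", OpenAI-bound requests are counted with tiktoken while other providers fall back to the ~4-characters-per-token estimate.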


@@ -124,6 +124,15 @@ pub struct Configuration {
     pub state_storage: Option<StateStorageConfig>,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
+pub enum TokenCountingStrategy {
+    #[default]
+    #[serde(rename = "estimate")]
+    Estimate,
+    #[serde(rename = "auto")]
+    Auto,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub struct Overrides {
     pub prompt_target_intent_matching_threshold: Option<f64>,
@@ -131,7 +140,7 @@ pub struct Overrides {
     pub use_agent_orchestrator: Option<bool>,
     pub llm_routing_model: Option<String>,
     pub agent_orchestration_model: Option<String>,
-    pub enable_token_counting: Option<bool>,
+    pub token_counting_strategy: Option<TokenCountingStrategy>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
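
To see the serde attributes above in action, here is a self-contained sketch (not part of the commit; it assumes serde and serde_json are available as dependencies) showing that "auto" deserializes to the Auto variant and that a missing override falls back to Estimate, mirroring the unwrap_or_default() call in the stream context:

use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub enum TokenCountingStrategy {
    #[default]
    #[serde(rename = "estimate")]
    Estimate,
    #[serde(rename = "auto")]
    Auto,
}

fn main() {
    // The rename attribute maps the YAML/JSON string "auto" to the Auto variant.
    let s: TokenCountingStrategy = serde_json::from_str("\"auto\"").unwrap();
    assert_eq!(s, TokenCountingStrategy::Auto);

    // An absent override yields Estimate via the #[default] attribute,
    // which is what unwrap_or_default() relies on downstream.
    let missing: Option<TokenCountingStrategy> = None;
    assert_eq!(missing.unwrap_or_default(), TokenCountingStrategy::Estimate);
}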


@@ -10,7 +10,7 @@ use std::sync::Arc;
 use std::time::{Duration, SystemTime, UNIX_EPOCH};
 
 use crate::metrics::Metrics;
-use common::configuration::{LlmProvider, LlmProviderType, Overrides};
+use common::configuration::{LlmProvider, LlmProviderType, Overrides, TokenCountingStrategy};
 use common::consts::{
     ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH,
     RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
@@ -269,15 +269,25 @@ impl StreamContext {
         model: &str,
         json_string: &str,
     ) -> Result<(), ratelimit::Error> {
-        let use_tiktoken = (*self.overrides)
+        let strategy = (*self.overrides)
             .as_ref()
-            .and_then(|o| o.enable_token_counting)
-            .unwrap_or(false);
+            .and_then(|o| o.token_counting_strategy.clone())
+            .unwrap_or_default();
 
-        let token_count = if use_tiktoken {
-            tokenizer::token_count(model, json_string).unwrap_or(0)
-        } else {
-            json_string.len() / 4
+        let (token_count, method) = match strategy {
+            TokenCountingStrategy::Auto => {
+                let provider_id = self.get_provider_id();
+                match provider_id {
+                    ProviderId::OpenAI => (
+                        tokenizer::token_count(model, json_string).unwrap_or(json_string.len() / 4),
+                        "tiktoken",
+                    ),
+                    // Future: add provider-specific tokenizers here
+                    // ProviderId::Mistral => (mistral_tokenizer::count(...), "mistral"),
+                    _ => (json_string.len() / 4, "estimate"),
+                }
+            }
+            TokenCountingStrategy::Estimate => (json_string.len() / 4, "estimate"),
         };
 
         debug!(
@@ -285,7 +295,7 @@ impl StreamContext {
             self.request_identifier(),
             model,
             token_count,
-            if use_tiktoken { "tiktoken" } else { "estimate" }
+            method
         );
         self.metrics
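
Stripped of the proxy plumbing, the counting dispatch in this hunk reduces to the standalone sketch below (hypothetical stand-ins: openai_token_count plays the role of tokenizer::token_count, and ProviderId is collapsed to two variants):

enum ProviderId { OpenAI, Other }
enum TokenCountingStrategy { Estimate, Auto }

// Stand-in for tokenizer::token_count; returns None here to exercise the fallback.
fn openai_token_count(_text: &str) -> Option<usize> {
    None
}

fn count_tokens(strategy: &TokenCountingStrategy, provider: &ProviderId, text: &str) -> (usize, &'static str) {
    // The ~4-characters-per-token heuristic used by the "estimate" path.
    let estimate = text.len() / 4;
    match strategy {
        // "auto": prefer an exact tokenizer where one exists, else fall back.
        TokenCountingStrategy::Auto => match provider {
            ProviderId::OpenAI => (openai_token_count(text).unwrap_or(estimate), "tiktoken"),
            _ => (estimate, "estimate"),
        },
        // "estimate": always use the heuristic.
        TokenCountingStrategy::Estimate => (estimate, "estimate"),
    }
}

fn main() {
    let (n, method) = count_tokens(&TokenCountingStrategy::Auto, &ProviderId::OpenAI, "hello, how many tokens am I?");
    println!("{n} tokens via {method}");
}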