commit e5b40b7a1f
Author: Adil Hafeez
Date:   2026-04-09 03:28:47 +00:00 (committed by GitHub)
3 changed files with 44 additions and 9 deletions


@@ -288,6 +288,14 @@ properties:
   agent_orchestration_model:
     type: string
     description: "Model name for the agent orchestrator (e.g., 'Plano-Orchestrator'). Must match a model in model_providers."
+  token_counting_strategy:
+    type: string
+    enum: [estimate, auto]
+    description: >
+      Strategy for counting input tokens used in metrics and rate limiting.
+      "estimate" (default): fast character-based approximation (~1 token per 4 chars).
+      "auto": uses the best available tokenizer for each provider (e.g., tiktoken for
+      OpenAI models), falling back to estimate for unsupported providers.
   system_prompt:
     type: string
   prompt_targets:
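To make the "estimate" behavior above concrete, here is a minimal standalone sketch of the character-based approximation the description refers to (~1 token per 4 characters). The function name and sample payload are illustrative, not from this commit.

// Sketch of the "estimate" strategy: roughly one token per four characters
// of input. estimate_tokens is a hypothetical name for illustration only.
fn estimate_tokens(input: &str) -> usize {
    input.len() / 4
}

fn main() {
    let body = r#"{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}"#;
    // 62 bytes of JSON -> ~15 tokens under the estimate strategy.
    println!("~{} tokens for {} bytes", estimate_tokens(body), body.len());
}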


@@ -206,6 +206,15 @@ pub struct Configuration {
     pub model_metrics_sources: Option<Vec<MetricsSource>>,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
+pub enum TokenCountingStrategy {
+    #[default]
+    #[serde(rename = "estimate")]
+    Estimate,
+    #[serde(rename = "auto")]
+    Auto,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub struct Overrides {
     pub prompt_target_intent_matching_threshold: Option<f64>,
@@ -213,6 +222,7 @@ pub struct Overrides {
     pub use_agent_orchestrator: Option<bool>,
     pub llm_routing_model: Option<String>,
     pub agent_orchestration_model: Option<String>,
+    pub token_counting_strategy: Option<TokenCountingStrategy>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
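As a quick sanity check on the serde attributes above, the following self-contained sketch shows the renamed wire format and the #[default] fallback. The use of serde_json here is purely for illustration; it is not a dependency added by this commit.

// Sketch: behavior of the serde attributes on TokenCountingStrategy.
// serde_json is an assumed dev-dependency for this example only.
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
enum TokenCountingStrategy {
    #[default]
    #[serde(rename = "estimate")]
    Estimate,
    #[serde(rename = "auto")]
    Auto,
}

fn main() {
    // Unit variants serialize as their renamed strings.
    assert_eq!(
        serde_json::to_string(&TokenCountingStrategy::Auto).unwrap(),
        "\"auto\""
    );
    // A missing override falls back to Estimate via the #[default] variant.
    let strategy: Option<TokenCountingStrategy> = None;
    assert_eq!(strategy.unwrap_or_default(), TokenCountingStrategy::Estimate);
}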


@@ -10,7 +10,7 @@ use std::sync::Arc;
 use std::time::{Duration, SystemTime, UNIX_EPOCH};
 
 use crate::metrics::Metrics;
-use common::configuration::{LlmProvider, LlmProviderType, Overrides};
+use common::configuration::{LlmProvider, LlmProviderType, Overrides, TokenCountingStrategy};
 use common::consts::{
     ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH,
     RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
@@ -48,7 +48,7 @@ pub struct StreamContext {
     ttft_time: Option<u128>,
     traceparent: Option<String>,
     request_body_sent_time: Option<u128>,
-    _overrides: Rc<Option<Overrides>>,
+    overrides: Rc<Option<Overrides>>,
    user_message: Option<String>,
     upstream_status_code: Option<StatusCode>,
     binary_frame_decoder: Option<BedrockBinaryFrameDecoder<bytes::BytesMut>>,
@@ -66,7 +66,7 @@ impl StreamContext {
     ) -> Self {
         StreamContext {
             metrics,
-            _overrides: overrides,
+            overrides,
             ratelimit_selector: None,
             streaming_response: false,
             response_tokens: 0,
@@ -270,22 +270,39 @@ impl StreamContext {
         model: &str,
         json_string: &str,
     ) -> Result<(), ratelimit::Error> {
-        // Tokenize and record token count.
-        let token_count = tokenizer::token_count(model, json_string).unwrap_or(0);
+        let strategy = (*self.overrides)
+            .as_ref()
+            .and_then(|o| o.token_counting_strategy.clone())
+            .unwrap_or_default();
+        let (token_count, method) = match strategy {
+            TokenCountingStrategy::Auto => {
+                let provider_id = self.get_provider_id();
+                match provider_id {
+                    ProviderId::OpenAI => (
+                        tokenizer::token_count(model, json_string).unwrap_or(json_string.len() / 4),
+                        "tiktoken",
+                    ),
+                    // Future: add provider-specific tokenizers here
+                    // ProviderId::Mistral => (mistral_tokenizer::count(...), "mistral"),
+                    _ => (json_string.len() / 4, "estimate"),
+                }
+            }
+            TokenCountingStrategy::Estimate => (json_string.len() / 4, "estimate"),
+        };
         debug!(
-            "request_id={}: token count, model='{}' input_tokens={}",
+            "request_id={}: token count, model='{}' input_tokens={} method={}",
             self.request_identifier(),
             model,
-            token_count
+            token_count,
+            method
         );
 
         // Record the token count to metrics.
         self.metrics
             .input_sequence_length
             .record(token_count as u64);
 
         // Check if rate limiting needs to be applied.
         if let Some(selector) = self.ratelimit_selector.take() {
             info!(
                 "request_id={}: ratelimit check, model='{}' selector='{}:{}'",