mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Merge 20e8e0c51e into 8dedf0bec1
This commit is contained in:
commit
e5b40b7a1f
3 changed files with 44 additions and 9 deletions
|
|
@ -288,6 +288,14 @@ properties:
|
|||
agent_orchestration_model:
|
||||
type: string
|
||||
description: "Model name for the agent orchestrator (e.g., 'Plano-Orchestrator'). Must match a model in model_providers."
|
||||
token_counting_strategy:
|
||||
type: string
|
||||
enum: [estimate, auto]
|
||||
description: >
|
||||
Strategy for counting input tokens used in metrics and rate limiting.
|
||||
"estimate" (default): fast length-based approximation (~1 token per 4 bytes of the request body).
|
||||
"auto": uses the best available tokenizer for each provider (e.g., tiktoken for
|
||||
OpenAI models), falling back to estimate for unsupported providers.
|
||||
system_prompt:
|
||||
type: string
|
||||
prompt_targets:
|
||||
|
|
|
|||
|
|
@ -206,6 +206,15 @@ pub struct Configuration {
|
|||
pub model_metrics_sources: Option<Vec<MetricsSource>>,
|
||||
}
|
||||
|
||||
/// Strategy for counting the input tokens of a request body.
///
/// Configured through `Overrides::token_counting_strategy`; when that field is
/// absent the caller falls back to the `Default` variant ([`Estimate`]).
/// The serialized names (`"estimate"`, `"auto"`) are fixed per variant via
/// `#[serde(rename = ...)]`.
///
/// [`Estimate`]: TokenCountingStrategy::Estimate
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
|
||||
pub enum TokenCountingStrategy {
|
||||
/// Default variant, serialized as `"estimate"`.
/// The stream context computes this as `json_string.len() / 4`.
#[default]
|
||||
#[serde(rename = "estimate")]
|
||||
Estimate,
|
||||
/// Serialized as `"auto"`.
/// The stream context calls `tokenizer::token_count` for `ProviderId::OpenAI`
/// (falling back to `json_string.len() / 4` if that fails) and uses the
/// `json_string.len() / 4` estimate for every other provider.
#[serde(rename = "auto")]
|
||||
Auto,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Overrides {
|
||||
pub prompt_target_intent_matching_threshold: Option<f64>,
|
||||
|
|
@ -213,6 +222,7 @@ pub struct Overrides {
|
|||
pub use_agent_orchestrator: Option<bool>,
|
||||
pub llm_routing_model: Option<String>,
|
||||
pub agent_orchestration_model: Option<String>,
|
||||
pub token_counting_strategy: Option<TokenCountingStrategy>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use std::sync::Arc;
|
|||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
use common::configuration::{LlmProvider, LlmProviderType, Overrides};
|
||||
use common::configuration::{LlmProvider, LlmProviderType, Overrides, TokenCountingStrategy};
|
||||
use common::consts::{
|
||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH,
|
||||
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
|
|
@ -48,7 +48,7 @@ pub struct StreamContext {
|
|||
ttft_time: Option<u128>,
|
||||
traceparent: Option<String>,
|
||||
request_body_sent_time: Option<u128>,
|
||||
_overrides: Rc<Option<Overrides>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
user_message: Option<String>,
|
||||
upstream_status_code: Option<StatusCode>,
|
||||
binary_frame_decoder: Option<BedrockBinaryFrameDecoder<bytes::BytesMut>>,
|
||||
|
|
@ -66,7 +66,7 @@ impl StreamContext {
|
|||
) -> Self {
|
||||
StreamContext {
|
||||
metrics,
|
||||
_overrides: overrides,
|
||||
overrides,
|
||||
ratelimit_selector: None,
|
||||
streaming_response: false,
|
||||
response_tokens: 0,
|
||||
|
|
@ -270,22 +270,39 @@ impl StreamContext {
|
|||
model: &str,
|
||||
json_string: &str,
|
||||
) -> Result<(), ratelimit::Error> {
|
||||
// Tokenize and record token count.
|
||||
let token_count = tokenizer::token_count(model, json_string).unwrap_or(0);
|
||||
let strategy = (*self.overrides)
|
||||
.as_ref()
|
||||
.and_then(|o| o.token_counting_strategy.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
let (token_count, method) = match strategy {
|
||||
TokenCountingStrategy::Auto => {
|
||||
let provider_id = self.get_provider_id();
|
||||
match provider_id {
|
||||
ProviderId::OpenAI => (
|
||||
tokenizer::token_count(model, json_string).unwrap_or(json_string.len() / 4),
|
||||
"tiktoken",
|
||||
),
|
||||
// Future: add provider-specific tokenizers here
|
||||
// ProviderId::Mistral => (mistral_tokenizer::count(...), "mistral"),
|
||||
_ => (json_string.len() / 4, "estimate"),
|
||||
}
|
||||
}
|
||||
TokenCountingStrategy::Estimate => (json_string.len() / 4, "estimate"),
|
||||
};
|
||||
|
||||
debug!(
|
||||
"request_id={}: token count, model='{}' input_tokens={}",
|
||||
"request_id={}: token count, model='{}' input_tokens={} method={}",
|
||||
self.request_identifier(),
|
||||
model,
|
||||
token_count
|
||||
token_count,
|
||||
method
|
||||
);
|
||||
|
||||
// Record the token count to metrics.
|
||||
self.metrics
|
||||
.input_sequence_length
|
||||
.record(token_count as u64);
|
||||
|
||||
// Check if rate limiting needs to be applied.
|
||||
if let Some(selector) = self.ratelimit_selector.take() {
|
||||
info!(
|
||||
"request_id={}: ratelimit check, model='{}' selector='{}:{}'",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue