add support for using custom upstream llm (#365)

Adil Hafeez 2025-01-17 18:25:55 -08:00 committed by GitHub
parent 3fc21de60c
commit 07ef3149b8
13 changed files with 263 additions and 52 deletions


@@ -80,7 +80,7 @@ impl StreamContext {
     fn select_llm_provider(&mut self) {
         let provider_hint = self
             .get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
-            .map(|provider_name| provider_name.into());
+            .map(|llm_name| llm_name.into());
         debug!("llm provider hint: {:?}", provider_hint);
         self.llm_provider = Some(routing::get_llm_provider(
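
Context for the hunk above: the hint header feeds routing::get_llm_provider, which is not shown in this diff. Below is a standalone sketch of hint-based selection, assuming a match-by-name lookup with a fall-back to the configured default provider; LlmProvider, get_llm_provider, and the field names here are illustrative stand-ins, not the gateway's real types.

// Standalone sketch of hint-based provider selection; all names are
// illustrative stand-ins, not the gateway's actual types.
#[derive(Debug, Clone)]
struct LlmProvider {
    name: String,
    endpoint: Option<String>, // set only for a custom upstream provider
    default: bool,
}

// Pick the provider named by the hint header, falling back to the default.
fn get_llm_provider<'a>(providers: &'a [LlmProvider], hint: Option<&str>) -> &'a LlmProvider {
    hint.and_then(|h| providers.iter().find(|p| p.name == h))
        .or_else(|| providers.iter().find(|p| p.default))
        .unwrap_or(&providers[0])
}

fn main() {
    let providers = vec![
        LlmProvider { name: "openai".into(), endpoint: None, default: true },
        LlmProvider { name: "my-llm".into(), endpoint: Some("llm.internal:8000".into()), default: false },
    ];
    // Hint header present: the custom upstream provider is selected.
    println!("{:?}", get_llm_provider(&providers, Some("my-llm")));
    // No hint: fall back to the default provider.
    println!("{:?}", get_llm_provider(&providers, None));
}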
@@ -174,10 +174,22 @@ impl HttpContext for StreamContext {
     // the lifecycle of the http request and response.
     fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
         self.select_llm_provider();
-        self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
+        // if endpoint is not set then use provider name as routing header so envoy can resolve the cluster name
+        if self.llm_provider().endpoint.is_none() {
+            self.add_http_request_header(
+                ARCH_ROUTING_HEADER,
+                &self.llm_provider().provider_interface.to_string(),
+            );
+        } else {
+            self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
+        }
         if let Err(error) = self.modify_auth_headers() {
-            self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
+            // ensure that the provider has an endpoint if the access key is missing else return a bad request
+            if self.llm_provider.as_ref().unwrap().endpoint.is_none() {
+                self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
+            }
         }
         self.delete_content_length_header();
         self.save_ratelimit_header();
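
The new branch above chooses what goes into ARCH_ROUTING_HEADER: when the provider has no explicit endpoint, the provider interface names the Envoy cluster, while a configured endpoint means the provider's own name is used so the request lands on the custom upstream cluster. A self-contained restatement of that decision follows; the types and the interface string "openai" are assumptions for illustration, not taken from the repo.

// Sketch of the routing-header choice made in on_http_request_headers above;
// ProviderInterface and LlmProvider are stand-ins for the gateway's real types.
use std::fmt;

enum ProviderInterface {
    OpenAi,
}

impl fmt::Display for ProviderInterface {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProviderInterface::OpenAi => write!(f, "openai"),
        }
    }
}

struct LlmProvider {
    name: String,
    endpoint: Option<String>,
    provider_interface: ProviderInterface,
}

// Value written into ARCH_ROUTING_HEADER: the interface name when no custom
// endpoint is configured (Envoy resolves a built-in cluster), otherwise the
// provider's own name so the custom upstream cluster is chosen.
fn routing_header_value(provider: &LlmProvider) -> String {
    if provider.endpoint.is_none() {
        provider.provider_interface.to_string()
    } else {
        provider.name.clone()
    }
}

fn main() {
    let hosted = LlmProvider {
        name: "gpt-4o-mini".into(),
        endpoint: None,
        provider_interface: ProviderInterface::OpenAi,
    };
    let custom = LlmProvider {
        name: "my-local-llm".into(),
        endpoint: Some("host.docker.internal:8000".into()),
        provider_interface: ProviderInterface::OpenAi,
    };
    println!("hosted -> {}", routing_header_value(&hosted)); // "openai"
    println!("custom -> {}", routing_header_value(&custom)); // "my-local-llm"
}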
@@ -334,16 +346,18 @@ impl HttpContext for StreamContext {
             // Record the latency to the latency histogram
             self.metrics.request_latency.record(duration_ms as u64);
-            // Compute the time per output token
-            let tpot = duration_ms as u64 / self.response_tokens as u64;
-            debug!("Time per output token: {} milliseconds", tpot);
-            // Record the time per output token
-            self.metrics.time_per_output_token.record(tpot);
-            debug!("Tokens per second: {}", 1000 / tpot);
-            // Record the tokens per second
-            self.metrics.tokens_per_second.record(1000 / tpot);
+            if self.response_tokens > 0 {
+                // Compute the time per output token
+                let tpot = duration_ms as u64 / self.response_tokens as u64;
+                debug!("Time per output token: {} milliseconds", tpot);
+                // Record the time per output token
+                self.metrics.time_per_output_token.record(tpot);
+                debug!("Tokens per second: {}", 1000 / tpot);
+                // Record the tokens per second
+                self.metrics.tokens_per_second.record(1000 / tpot);
+            }
         }
         Err(e) => {
             warn!("SystemTime error: {:?}", e);
@@ -381,11 +395,13 @@ impl HttpContext for StreamContext {
                 self.llm_provider().name.to_string(),
             );
-            llm_span.add_event(Event::new(
-                "time_to_first_token".to_string(),
-                self.ttft_time.unwrap(),
-            ));
-            trace_data.add_span(llm_span);
+            if self.ttft_time.is_some() {
+                llm_span.add_event(Event::new(
+                    "time_to_first_token".to_string(),
+                    self.ttft_time.unwrap(),
+                ));
+                trace_data.add_span(llm_span);
+            }
             self.traces_queue.lock().unwrap().push_back(trace_data);
         }
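
Same defensive pattern for tracing: the time_to_first_token event is only attached when a first token was actually observed, avoiding the earlier unconditional unwrap(); note the diff also moves trace_data.add_span inside the guard, so spans without a TTFT are not added to the trace at all. A stripped-down sketch with hypothetical Span/Event stand-ins:

// Sketch of the guarded time_to_first_token span event; Span and Event are
// simplified stand-ins for the gateway's tracing types.
#[derive(Debug)]
struct Event {
    name: String,
    timestamp_ns: u128,
}

#[derive(Debug, Default)]
struct Span {
    events: Vec<Event>,
}

impl Span {
    fn add_event(&mut self, event: Event) {
        self.events.push(event);
    }
}

// Only attach the event (and hand the span back for the trace) when a first
// token was actually observed, instead of unwrapping a possibly-None ttft_time.
fn finish_llm_span(mut span: Span, ttft_time: Option<u128>) -> Option<Span> {
    ttft_time.map(|ttft| {
        span.add_event(Event {
            name: "time_to_first_token".to_string(),
            timestamp_ns: ttft,
        });
        span
    })
}

fn main() {
    println!("{:?}", finish_llm_span(Span::default(), Some(1_234_567)));
    println!("{:?}", finish_llm_span(Span::default(), None));
}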