mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
updated code based on code review comments
This commit is contained in:
parent
2dd2e05f7d
commit
cfa0585a53
5 changed files with 8 additions and 40 deletions
|
|
@ -43,7 +43,7 @@ let request = ProviderRequestType::try_from((request_bytes.as_bytes(), &Provider
|
|||
|
||||
// Access request properties
|
||||
println!("Model: {}", request.model());
|
||||
println!("User message: {:?}", request.extract_user_message());
|
||||
println!("User message: {:?}", request.get_recent_user_message());
|
||||
println!("Is streaming: {}", request.is_streaming());
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -549,13 +549,6 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
self.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&mut self) {
|
||||
self.stream = Some(true);
|
||||
if self.stream_options.is_none() {
|
||||
self.stream_options = Some(StreamOptions { include_usage: Some(true) });
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
self.messages.iter().fold(String::new(), |acc, m| {
|
||||
acc + " " + &match &m.content {
|
||||
|
|
@ -568,7 +561,7 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
})
|
||||
}
|
||||
|
||||
fn extract_user_message(&self) -> Option<String> {
|
||||
fn get_recent_user_message(&self) -> Option<String> {
|
||||
self.messages.last().and_then(|msg| {
|
||||
match &msg.content {
|
||||
MessageContent::Text(text) => Some(text.clone()),
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ mod tests {
|
|||
|
||||
let request = result.unwrap();
|
||||
assert_eq!(request.model(), "gpt-4");
|
||||
assert_eq!(request.extract_user_message(), Some("Hello!".to_string()));
|
||||
assert_eq!(request.get_recent_user_message(), Some("Hello!".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -46,14 +46,11 @@ pub trait ProviderRequest: Send + Sync {
|
|||
/// Check if this is a streaming request
|
||||
fn is_streaming(&self) -> bool;
|
||||
|
||||
/// Set streaming options (e.g., include_usage)
|
||||
fn set_streaming_options(&mut self);
|
||||
|
||||
/// Extract text content from messages for token counting
|
||||
fn extract_messages_text(&self) -> String;
|
||||
|
||||
/// Extract the user message for tracing/logging purposes
|
||||
fn extract_user_message(&self) -> Option<String>;
|
||||
fn get_recent_user_message(&self) -> Option<String>;
|
||||
|
||||
/// Convert the request to bytes for transmission
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
|
||||
|
|
@ -78,21 +75,15 @@ impl ProviderRequest for ProviderRequestType {
|
|||
}
|
||||
}
|
||||
|
||||
fn set_streaming_options(&mut self) {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.set_streaming_options(),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.extract_messages_text(),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_user_message(&self) -> Option<String> {
|
||||
fn get_recent_user_message(&self) -> Option<String> {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.extract_user_message(),
|
||||
Self::ChatCompletionsRequest(r) => r.get_recent_user_message(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -356,7 +356,7 @@ impl HttpContext for StreamContext {
|
|||
deserialized_body.set_model(resolved_model.clone());
|
||||
|
||||
// Extract user message for tracing
|
||||
self.user_message = deserialized_body.extract_user_message();
|
||||
self.user_message = deserialized_body.get_recent_user_message();
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||
|
|
@ -368,11 +368,6 @@ impl HttpContext for StreamContext {
|
|||
// Use provider interface for streaming detection and setup
|
||||
self.streaming_response = deserialized_body.is_streaming();
|
||||
|
||||
// Set streaming options if needed
|
||||
if self.streaming_response {
|
||||
deserialized_body.set_streaming_options();
|
||||
}
|
||||
|
||||
// Use provider interface for text extraction (after potential mutation)
|
||||
let input_tokens_str = deserialized_body.extract_messages_text();
|
||||
// enforce ratelimits on ingress
|
||||
|
|
@ -385,9 +380,6 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
let llm_provider_str = self.llm_provider().provider_interface.to_string();
|
||||
let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||
|
||||
// Convert chat completion request to llm provider specific request using provider interface
|
||||
let deserialized_body_bytes = match deserialized_body.to_bytes() {
|
||||
Ok(bytes) => bytes,
|
||||
|
|
@ -562,17 +554,9 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
}
|
||||
|
||||
let llm_provider_str = self.llm_provider().provider_interface.to_string();
|
||||
let _provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||
|
||||
if self.streaming_response {
|
||||
debug!("processing streaming response");
|
||||
|
||||
// Parse streaming response using OpenAI-compatible format
|
||||
// Since all providers use OpenAI-compatible streaming format
|
||||
let provider_id = self.get_provider_id();
|
||||
|
||||
match ProviderStreamResponseIter::try_from((&body[..], &provider_id)) {
|
||||
match ProviderStreamResponseIter::try_from((&body[..], &self.get_provider_id())) {
|
||||
Ok(mut streaming_response) => {
|
||||
// Process each streaming chunk
|
||||
while let Some(chunk_result) = streaming_response.next() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue