mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
more changes
This commit is contained in:
parent
59dbbd6743
commit
9befd6364c
4 changed files with 68 additions and 79 deletions
|
|
@ -9,7 +9,8 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn openai_builder() {
|
||||
let request = OpenAIRequest::builder("gpt-3.5-turbo")
|
||||
let request = OpenAIRequest::builder()
|
||||
.model("gpt-3.5-turbo")
|
||||
.temperature(0.7)
|
||||
.top_p(0.9)
|
||||
.n(1)
|
||||
|
|
@ -18,16 +19,17 @@ mod tests {
|
|||
.stop(vec!["\n".to_string()])
|
||||
.presence_penalty(0.0)
|
||||
.frequency_penalty(0.0)
|
||||
.build();
|
||||
.build()
|
||||
.expect("Failed to build OpenAIRequest");
|
||||
|
||||
assert_eq!(request.model, "gpt-3.5-turbo");
|
||||
assert_eq!(request.temperature, Some(0.7));
|
||||
assert_eq!(request.top_p, Some(0.9));
|
||||
assert_eq!(request.n, Some(1));
|
||||
assert_eq!(request.max_tokens, Some(100));
|
||||
assert_eq!(request.stream, Some(false));
|
||||
assert_eq!(request.stop, Some(vec!["\n".to_string()]));
|
||||
assert_eq!(request.presence_penalty, Some(0.0));
|
||||
assert_eq!(request.frequency_penalty, Some(0.0));
|
||||
assert_eq!(request.base.model, "gpt-3.5-turbo");
|
||||
assert_eq!(request.base.temperature, Some(0.7));
|
||||
assert_eq!(request.base.top_p, Some(0.9));
|
||||
assert_eq!(request.base.n, Some(1));
|
||||
assert_eq!(request.base.max_tokens, Some(100));
|
||||
assert_eq!(request.base.stream, Some(false));
|
||||
assert_eq!(request.base.stop, Some(vec!["\n".to_string()]));
|
||||
assert_eq!(request.base.presence_penalty, Some(0.0));
|
||||
assert_eq!(request.base.frequency_penalty, Some(0.0));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
pub mod types;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::providers::groq::types::{GroqRequest, GroqResponse};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GroqError {
|
||||
#[error("json error: {0}")]
|
||||
JsonParseError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, GroqError>;
|
||||
|
||||
impl TryFrom<&[u8]> for GroqRequest {
|
||||
type Error = GroqError;
|
||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(GroqError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for GroqResponse {
|
||||
type Error = GroqError;
|
||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(GroqError::from)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,9 +1,12 @@
|
|||
pub mod openai;
|
||||
pub mod groq;
|
||||
pub mod deepseek;
|
||||
pub mod common_types;
|
||||
|
||||
/// Supported LLM providers.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Provider {
|
||||
Grok,
|
||||
Groq,
|
||||
OpenAI,
|
||||
DeepSeek,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,38 +1,15 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use crate::providers::common_types::{ChatRequestBase, ChatResponseBase};
|
||||
|
||||
/// Represents a request to the OpenAI API (compatible with both chat and completion endpoints).
|
||||
///
|
||||
/// Fields are based on the OpenAI API schema:
|
||||
/// https://platform.openai.com/docs/api-reference/chat/create
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIRequest {
|
||||
/// The model to use (e.g., "gpt-3.5-turbo", "gpt-4").
|
||||
pub model: String,
|
||||
/// The list of messages for chat endpoints (use None for completion).
|
||||
pub messages: Option<Vec<Message>>,
|
||||
/// Sampling temperature to use (higher values = more random).
|
||||
pub temperature: Option<f32>,
|
||||
/// Nucleus sampling parameter.
|
||||
pub top_p: Option<f32>,
|
||||
/// How many completions to generate for each prompt/message.
|
||||
pub n: Option<u32>,
|
||||
/// Maximum number of tokens to generate.
|
||||
pub max_tokens: Option<u32>,
|
||||
/// Whether to stream back partial progress.
|
||||
pub stream: Option<bool>,
|
||||
/// Up to 4 sequences where the API will stop generating further tokens.
|
||||
pub stop: Option<Vec<String>>,
|
||||
/// Penalizes new tokens based on whether they appear in the text so far.
|
||||
pub presence_penalty: Option<f32>,
|
||||
/// Penalizes new tokens based on their frequency in the text so far.
|
||||
pub frequency_penalty: Option<f32>,
|
||||
#[serde(flatten)]
|
||||
pub base: ChatRequestBase,
|
||||
}
|
||||
|
||||
/// Builder for `OpenAIRequest`.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct OpenAIRequestBuilder {
|
||||
model: String,
|
||||
messages: Option<Vec<Message>>,
|
||||
model: Option<String>,
|
||||
messages: Option<Vec<crate::providers::common_types::Message>>,
|
||||
temperature: Option<f32>,
|
||||
top_p: Option<f32>,
|
||||
n: Option<u32>,
|
||||
|
|
@ -44,14 +21,16 @@ pub struct OpenAIRequestBuilder {
|
|||
}
|
||||
|
||||
impl OpenAIRequestBuilder {
|
||||
pub fn new(model: impl Into<String>) -> Self {
|
||||
Self {
|
||||
model: model.into(),
|
||||
..Default::default()
|
||||
}
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn messages(mut self, messages: Vec<Message>) -> Self {
|
||||
pub fn model(mut self, model: impl Into<String>) -> Self {
|
||||
self.model = Some(model.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn messages(mut self, messages: Vec<crate::providers::common_types::Message>) -> Self {
|
||||
self.messages = Some(messages);
|
||||
self
|
||||
}
|
||||
|
|
@ -96,9 +75,10 @@ impl OpenAIRequestBuilder {
|
|||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> OpenAIRequest {
|
||||
OpenAIRequest {
|
||||
model: self.model,
|
||||
pub fn build(self) -> Result<OpenAIRequest, &'static str> {
|
||||
let model = self.model.ok_or("model is required")?;
|
||||
let base = crate::providers::common_types::ChatRequestBase {
|
||||
model,
|
||||
messages: self.messages,
|
||||
temperature: self.temperature,
|
||||
top_p: self.top_p,
|
||||
|
|
@ -108,44 +88,21 @@ impl OpenAIRequestBuilder {
|
|||
stop: self.stop,
|
||||
presence_penalty: self.presence_penalty,
|
||||
frequency_penalty: self.frequency_penalty,
|
||||
}
|
||||
};
|
||||
Ok(OpenAIRequest { base })
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIRequest {
|
||||
pub fn builder(model: impl Into<String>) -> OpenAIRequestBuilder {
|
||||
OpenAIRequestBuilder::new(model)
|
||||
pub fn builder() -> OpenAIRequestBuilder {
|
||||
OpenAIRequestBuilder::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a message in the OpenAI chat API.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
/// The role of the message sender ("system", "user", or "assistant").
|
||||
pub role: String,
|
||||
/// The content of the message.
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct OpenAIResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Option<Usage>,
|
||||
#[serde(flatten)]
|
||||
pub base: ChatResponseBase,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Choice {
|
||||
pub index: u32,
|
||||
pub message: Message,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
pub use crate::providers::common_types::{Message, Choice, Usage};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue