more changes

This commit is contained in:
Adil Hafeez 2025-06-03 00:17:22 -07:00
parent 59dbbd6743
commit 9befd6364c
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
4 changed files with 68 additions and 79 deletions

View file

@ -9,7 +9,8 @@ mod tests {
#[test]
fn openai_builder() {
let request = OpenAIRequest::builder("gpt-3.5-turbo")
let request = OpenAIRequest::builder()
.model("gpt-3.5-turbo")
.temperature(0.7)
.top_p(0.9)
.n(1)
@ -18,16 +19,17 @@ mod tests {
.stop(vec!["\n".to_string()])
.presence_penalty(0.0)
.frequency_penalty(0.0)
.build();
.build()
.expect("Failed to build OpenAIRequest");
assert_eq!(request.model, "gpt-3.5-turbo");
assert_eq!(request.temperature, Some(0.7));
assert_eq!(request.top_p, Some(0.9));
assert_eq!(request.n, Some(1));
assert_eq!(request.max_tokens, Some(100));
assert_eq!(request.stream, Some(false));
assert_eq!(request.stop, Some(vec!["\n".to_string()]));
assert_eq!(request.presence_penalty, Some(0.0));
assert_eq!(request.frequency_penalty, Some(0.0));
assert_eq!(request.base.model, "gpt-3.5-turbo");
assert_eq!(request.base.temperature, Some(0.7));
assert_eq!(request.base.top_p, Some(0.9));
assert_eq!(request.base.n, Some(1));
assert_eq!(request.base.max_tokens, Some(100));
assert_eq!(request.base.stream, Some(false));
assert_eq!(request.base.stop, Some(vec!["\n".to_string()]));
assert_eq!(request.base.presence_penalty, Some(0.0));
assert_eq!(request.base.frequency_penalty, Some(0.0));
}
}

View file

@ -0,0 +1,27 @@
pub mod types;
use thiserror::Error;
use crate::providers::groq::types::{GroqRequest, GroqResponse};
#[derive(Debug, Error)]
pub enum GroqError {
#[error("json error: {0}")]
JsonParseError(#[from] serde_json::Error),
}
type Result<T> = std::result::Result<T, GroqError>;
impl TryFrom<&[u8]> for GroqRequest {
type Error = GroqError;
fn try_from(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(GroqError::from)
}
}
impl TryFrom<&[u8]> for GroqResponse {
type Error = GroqError;
fn try_from(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(GroqError::from)
}
}

View file

@ -1,9 +1,12 @@
pub mod openai;
pub mod groq;
pub mod deepseek;
pub mod common_types;
/// Supported LLM providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Provider {
Grok,
Groq,
OpenAI,
DeepSeek,
}

View file

@ -1,38 +1,15 @@
use serde::{Deserialize, Serialize};
use crate::providers::common_types::{ChatRequestBase, ChatResponseBase};
/// Represents a request to the OpenAI API (compatible with both chat and completion endpoints).
///
/// Fields are based on the OpenAI API schema:
/// https://platform.openai.com/docs/api-reference/chat/create
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIRequest {
/// The model to use (e.g., "gpt-3.5-turbo", "gpt-4").
pub model: String,
/// The list of messages for chat endpoints (use None for completion).
pub messages: Option<Vec<Message>>,
/// Sampling temperature to use (higher values = more random).
pub temperature: Option<f32>,
/// Nucleus sampling parameter.
pub top_p: Option<f32>,
/// How many completions to generate for each prompt/message.
pub n: Option<u32>,
/// Maximum number of tokens to generate.
pub max_tokens: Option<u32>,
/// Whether to stream back partial progress.
pub stream: Option<bool>,
/// Up to 4 sequences where the API will stop generating further tokens.
pub stop: Option<Vec<String>>,
/// Penalizes new tokens based on whether they appear in the text so far.
pub presence_penalty: Option<f32>,
/// Penalizes new tokens based on their frequency in the text so far.
pub frequency_penalty: Option<f32>,
#[serde(flatten)]
pub base: ChatRequestBase,
}
/// Builder for `OpenAIRequest`.
#[derive(Debug, Default, Clone)]
pub struct OpenAIRequestBuilder {
model: String,
messages: Option<Vec<Message>>,
model: Option<String>,
messages: Option<Vec<crate::providers::common_types::Message>>,
temperature: Option<f32>,
top_p: Option<f32>,
n: Option<u32>,
@ -44,14 +21,16 @@ pub struct OpenAIRequestBuilder {
}
impl OpenAIRequestBuilder {
pub fn new(model: impl Into<String>) -> Self {
Self {
model: model.into(),
..Default::default()
}
pub fn new() -> Self {
Self::default()
}
pub fn messages(mut self, messages: Vec<Message>) -> Self {
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
pub fn messages(mut self, messages: Vec<crate::providers::common_types::Message>) -> Self {
self.messages = Some(messages);
self
}
@ -96,9 +75,10 @@ impl OpenAIRequestBuilder {
self
}
pub fn build(self) -> OpenAIRequest {
OpenAIRequest {
model: self.model,
pub fn build(self) -> Result<OpenAIRequest, &'static str> {
let model = self.model.ok_or("model is required")?;
let base = crate::providers::common_types::ChatRequestBase {
model,
messages: self.messages,
temperature: self.temperature,
top_p: self.top_p,
@ -108,44 +88,21 @@ impl OpenAIRequestBuilder {
stop: self.stop,
presence_penalty: self.presence_penalty,
frequency_penalty: self.frequency_penalty,
}
};
Ok(OpenAIRequest { base })
}
}
impl OpenAIRequest {
pub fn builder(model: impl Into<String>) -> OpenAIRequestBuilder {
OpenAIRequestBuilder::new(model)
pub fn builder() -> OpenAIRequestBuilder {
OpenAIRequestBuilder::new()
}
}
/// Represents a message in the OpenAI chat API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// The role of the message sender ("system", "user", or "assistant").
pub role: String,
/// The content of the message.
pub content: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OpenAIResponse {
pub id: String,
pub object: String,
pub created: u64,
pub choices: Vec<Choice>,
pub usage: Option<Usage>,
#[serde(flatten)]
pub base: ChatResponseBase,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Choice {
pub index: u32,
pub message: Message,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
pub use crate::providers::common_types::{Message, Choice, Usage};