mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add more changes
This commit is contained in:
parent
cd90faf50c
commit
66a971b086
8 changed files with 35 additions and 9 deletions
|
|
@ -79,6 +79,8 @@ properties:
|
|||
properties:
|
||||
prompt_target_intent_matching_threshold:
|
||||
type: number
|
||||
optimize_context_window:
|
||||
type: boolean
|
||||
system_prompt:
|
||||
type: string
|
||||
prompt_targets:
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ pub struct Configuration {
|
|||
/// Optional behaviour overrides carried by the configuration.
///
/// Every field is an `Option`, so an absent key deserializes to `None` and
/// `Default` yields a value with no overrides set.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Overrides {
|
||||
/// Numeric threshold override. NOTE(review): the name suggests it tunes
/// prompt-target intent matching, but no consumer is visible here — confirm.
pub prompt_target_intent_matching_threshold: Option<f64>,
|
||||
/// When `Some(true)`, the stream context adds the metadata entry
/// `"optimize_context_window" = "true"` to the chat-completions request it
/// sends to the model server. `None` is read via `unwrap_or_default()`, so
/// it behaves like `Some(false)`.
pub optimize_context_window: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
|
|
|
|||
|
|
@ -137,9 +137,20 @@ impl HttpContext for StreamContext {
|
|||
.map(|(_, pt)| pt.into())
|
||||
.collect();
|
||||
|
||||
let mut metadata = deserialized_body.metadata.clone();
|
||||
|
||||
if let Some(overrides) = self.overrides.as_ref() {
|
||||
if overrides.optimize_context_window.unwrap_or_default() {
|
||||
if metadata.is_none() {
|
||||
metadata = Some(HashMap::new());
|
||||
}
|
||||
metadata.as_mut().unwrap().insert("optimize_context_window".to_string(), "true".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let arch_fc_chat_completion_request = ChatCompletionsRequest {
|
||||
messages: deserialized_body.messages.clone(),
|
||||
metadata: deserialized_body.metadata.clone(),
|
||||
metadata,
|
||||
stream: deserialized_body.stream,
|
||||
model: "--".to_string(),
|
||||
stream_options: deserialized_body.stream_options.clone(),
|
||||
|
|
@ -157,7 +168,7 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
debug!("sending request to model server");
|
||||
trace!("request body: {}", json_data);
|
||||
debug!("request body: {}", json_data);
|
||||
|
||||
let mut headers = vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME),
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ pub struct StreamCallContext {
|
|||
pub struct StreamContext {
|
||||
system_prompt: Rc<Option<String>>,
|
||||
pub prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
_overrides: Rc<Option<Overrides>>,
|
||||
pub overrides: Rc<Option<Overrides>>,
|
||||
pub metrics: Rc<Metrics>,
|
||||
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
|
||||
pub context_id: u32,
|
||||
|
|
@ -89,7 +89,7 @@ impl StreamContext {
|
|||
streaming_response: false,
|
||||
user_prompt: None,
|
||||
is_chat_completions_request: false,
|
||||
_overrides: overrides,
|
||||
overrides: overrides,
|
||||
request_id: None,
|
||||
traceparent: None,
|
||||
_tracing: tracing,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,9 @@ listener:
|
|||
port: 8080 #If you configure port 443, you'll need to update the listener with tls_certificates
|
||||
message_format: huggingface
|
||||
|
||||
overrides:
|
||||
optimize_context_window: true
|
||||
|
||||
endpoints:
|
||||
spotify:
|
||||
endpoint: api.spotify.com
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
req.messages, req.tools, self.extra_instruction
|
||||
)
|
||||
|
||||
logger.info(f"[request]: {json.dumps(messages)}")
|
||||
logger.info(f"[request to arch-fc (intent)]: {json.dumps(messages)}")
|
||||
|
||||
model_response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
|
|
@ -519,9 +519,11 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
"""
|
||||
logger.info("[Arch-Function] - ChatCompletion")
|
||||
|
||||
messages = self._process_messages(req.messages, req.tools)
|
||||
messages = self._process_messages(
|
||||
req.messages, req.tools, metadata=req.metadata
|
||||
)
|
||||
|
||||
logger.info(f"[request]: {json.dumps(messages)}")
|
||||
logger.info(f"[request to arch-fc]: {json.dumps(messages)}")
|
||||
|
||||
# always enable `stream=True` to collect model responses
|
||||
response = self.client.chat.completions.create(
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ class ArchGuardHanlder:
|
|||
raise NotImplementedError(f"{req.task} is not supported!")
|
||||
|
||||
logger.info("[Arch-Guard] - Prediction")
|
||||
logger.info(f"[request]: {req.input}")
|
||||
logger.info(f"[request arch-guard]: {req.input}")
|
||||
|
||||
if len(req.input.split()) < max_num_words:
|
||||
result = self._predict_text(req.task, req.input)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ class Message(BaseModel):
|
|||
# Request model holding the chat messages, the tool definitions and optional
# string-to-string metadata for a request.
class ChatMessage(BaseModel):
|
||||
messages: List[Message] = []
|
||||
tools: List[Dict[str, Any]] = []
|
||||
# Free-form string flags; the handler passes this to `_process_messages`,
# which reads the "optimize_context_window" key with `metadata.get(...)`.
# NOTE(review): `Optional` permits an explicit `None`, and `metadata.get`
# would then raise AttributeError — confirm callers never send null here.
metadata: Optional[Dict[str, str]] = {}
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
|
|
@ -123,6 +124,7 @@ class ArchBaseHandler:
|
|||
tools: List[Dict[str, Any]] = None,
|
||||
extra_instruction: str = None,
|
||||
max_tokens=4096,
|
||||
metadata: Dict[str, str] = {},
|
||||
):
|
||||
"""
|
||||
Processes a list of messages and formats them appropriately.
|
||||
|
|
@ -157,7 +159,12 @@ class ArchBaseHandler:
|
|||
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
|
||||
elif role == "tool":
|
||||
role = "user"
|
||||
content = f"<tool_response>\n\n</tool_response>"
|
||||
if metadata.get("optimize_context_window", "false").lower() == "true":
|
||||
content = f"<tool_response>\n\n</tool_response>"
|
||||
else:
|
||||
content = (
|
||||
f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
|
||||
)
|
||||
|
||||
processed_messages.append({"role": role, "content": content})
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue