mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
fix llama3 chat template bug
This commit is contained in:
parent
784c6265ee
commit
0cdca1b642
3 changed files with 41 additions and 17 deletions
|
|
@ -96,4 +96,3 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
return self._chat_completion_stream(messages)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
import json
|
||||
from typing import Literal
|
||||
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
|
||||
from metagpt.provider.bedrock.utils import messages_to_prompt_llama
|
||||
from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3
|
||||
|
||||
|
||||
class MistralProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
return messages_to_prompt_llama(messages)
|
||||
return messages_to_prompt_llama2(messages)
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["outputs"][0]["text"]
|
||||
|
|
@ -36,8 +37,14 @@ class MetaProvider(BaseBedrockProvider):
|
|||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
|
||||
max_tokens_field_name = "max_gen_len"
|
||||
|
||||
def __init__(self, llama_version: Literal["llama2", "llama3"]) -> None:
|
||||
self.llama_version = llama_version
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
return messages_to_prompt_llama(messages)
|
||||
if self.llama_version == "llama2":
|
||||
return messages_to_prompt_llama2(messages)
|
||||
else:
|
||||
return messages_to_prompt_llama3(messages)
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["generation"]
|
||||
|
|
@ -72,17 +79,20 @@ class AmazonProvider(BaseBedrockProvider):
|
|||
|
||||
|
||||
PROVIDERS = {
|
||||
"mistral": MistralProvider(),
|
||||
"meta": MetaProvider(),
|
||||
"ai21": Ai21Provider(),
|
||||
"cohere": CohereProvider(),
|
||||
"anthropic": AnthropicProvider(),
|
||||
"amazon": AmazonProvider()
|
||||
"mistral": MistralProvider,
|
||||
"meta": MetaProvider,
|
||||
"ai21": Ai21Provider,
|
||||
"cohere": CohereProvider,
|
||||
"anthropic": AnthropicProvider,
|
||||
"amazon": AmazonProvider
|
||||
}
|
||||
|
||||
|
||||
def get_provider(model_id: str):
|
||||
model_name = model_id.split(".")[0] # meta、mistral……
|
||||
if model_name not in PROVIDERS:
|
||||
raise KeyError(f"{model_name} is not supported!")
|
||||
return PROVIDERS[model_name]
|
||||
provider, model_name = model_id.split(".")[0:2] # meta、mistral……
|
||||
if provider not in PROVIDERS:
|
||||
raise KeyError(f"{provider} is not supported!")
|
||||
if provider == "meta":
|
||||
# distinguish llama2 and llama3
|
||||
return PROVIDERS[provider](model_name[:6])
|
||||
return PROVIDERS[provider]()
|
||||
|
|
|
|||
|
|
@ -29,8 +29,10 @@ SUPPORT_STREAM_MODELS = {
|
|||
"mistral.mistral-large-2402-v1:0": 32000,
|
||||
}
|
||||
|
||||
# TODO: use a general function for constructing chat templates.
|
||||
|
||||
def messages_to_prompt_llama(messages: list[dict]):
|
||||
|
||||
def messages_to_prompt_llama2(messages: list[dict]):
|
||||
BOS, EOS = "<s>", "</s>"
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
|
|
@ -53,7 +55,20 @@ def messages_to_prompt_llama(messages: list[dict]):
|
|||
return prompt
|
||||
|
||||
|
||||
def messages_to_prompt_llama3(messages: list[dict]) -> str:
    """Render chat messages as a single Llama 3 prompt string.

    The prompt starts with one ``<|begin_of_text|>`` token. Every message is
    then written as a role header (``<|start_header_id|>``, the role,
    ``<|end_header_id|>``), a blank line, the content and ``<|eot_id|>``.
    If the final message is not from the assistant (or there are no messages
    at all), an open assistant header is appended so the model continues with
    the assistant's turn.

    Args:
        messages: Chat messages in order; each dict must provide the keys
            ``"role"`` and ``"content"``.

    Returns:
        The assembled prompt.

    Raises:
        KeyError: If a message lacks ``"role"`` or ``"content"``.
    """
    header = "<|start_header_id|>{role}<|end_header_id|>"
    turn = header + "\n\n{content}<|eot_id|>"

    parts = ["<|begin_of_text|>"]
    # Tracked outside the loop so an empty ``messages`` list does not leave
    # the role unbound when deciding whether to open an assistant turn.
    last_role = None
    for message in messages:
        last_role = message["role"]
        parts.append(turn.format(role=last_role, content=message["content"]))

    # Only cue the model to answer when the conversation does not already end
    # with an assistant message.
    if last_role != "assistant":
        parts.append(header.format(role="assistant"))

    return "".join(parts)
|
||||
|
||||
|
||||
def get_max_tokens(model_id) -> int:
    """Look up the token limit registered for ``model_id``.

    The non-streaming and streaming model tables are merged, with streaming
    entries taking precedence on duplicate keys, and the merged table is
    indexed by ``model_id``. An unknown ``model_id`` raises ``KeyError``.
    """
    token_limits = NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS
    return token_limits[model_id]
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue