implement meta

This commit is contained in:
usamimeri_renko 2024-04-25 20:05:25 +08:00
parent a741488410
commit 9775a2b1eb
4 changed files with 60 additions and 23 deletions

View file

@ -51,6 +51,7 @@ class AmazonBedrockLLM(BaseLLM):
modelId=self.config.model, body=request_body
)
completions = self.provider.get_choice_text(response)
log_llm_stream(completions)
return completions
def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
@ -59,11 +60,10 @@ class AmazonBedrockLLM(BaseLLM):
response = self.__client.invoke_model_with_response_stream(
modelId=self.config.model, body=request_body
)
collected_content = []
for event in response.get("body"):
chunk_text = json.loads(event["chunk"]["bytes"])[
"outputs"][0]["text"]
collected_content = []
for event in response["body"]:
chunk_text = self.provider.get_choice_text_from_stream(event)
collected_content.append(chunk_text)
log_llm_stream(chunk_text)
@ -84,7 +84,11 @@ class AmazonBedrockLLM(BaseLLM):
if __name__ == '__main__':
from .config import my_config
prompt = "write an essay for living on mars in 1000 word"
messages = [{"role": "user", "content": prompt}]
messages = [
{"role": "system", "content": "your name is Bob"},
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hello,my friend"},
{"role": "user", "content": "What is your name?"}]
llm = AmazonBedrockLLM(my_config)
llm.completion(messages)
llm._chat_completion_stream(messages)

View file

@ -7,16 +7,16 @@ class BaseBedrockProvider(object):
return json.dumps({"prompt": self.messages_to_prompt(messages)} | generate_kwargs)
def get_choice_text(self, response) -> str:
response_body = json.loads(response["body"].read())
response_body = self._get_response_body_json(response)
completions = response_body["outputs"][0]['text']
return completions
def get_choice_text_from_stream(self, event):
return json.loads(event["chunk"]["bytes"])["outputs"][0]["text"]
def _get_response_body_json(self, response):
return json.loads(response["body"].read())
def messages_to_prompt(self, messages: list[dict]):
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
def format_prompt(self, prompt: str) -> str:
return prompt
def format_messages(self, messages: list[dict]) -> list[dict]:
return messages

View file

@ -1,18 +1,16 @@
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
import json
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
from metagpt.provider.bedrock.utils import messages_to_prompt_llama
class MistralProvider(BaseBedrockProvider):
def format_prompt(self, prompt: str) -> str:
# for mixtral and llama
return f"<s>[INST]{prompt}[/INST]"
def get_request_body(self, messages, **generate_kwargs):
return json.dumps({"prompt": self.format_prompt(self.messages_to_prompt(messages))} | generate_kwargs)
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
def messages_to_prompt(self, messages: list[dict]):
return messages_to_prompt_llama(messages)
class AnthropicProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
pass
@ -21,7 +19,16 @@ class CohereProvider(BaseBedrockProvider):
class MetaProvider(BaseBedrockProvider):
pass
def messages_to_prompt(self, messages: list[dict]):
return messages_to_prompt_llama(messages)
def get_choice_text(self, response) -> str:
response_body = self._get_response_body_json(response)
completions = response_body['generation']
return completions
def get_choice_text_from_stream(self, event):
return json.loads(event["chunk"]["bytes"])["generation"]
class Ai21Provider(BaseBedrockProvider):
@ -29,7 +36,8 @@ class Ai21Provider(BaseBedrockProvider):
PROVIDERS = {
"mistral": MistralProvider()
"mistral": MistralProvider(),
"meta": MetaProvider(),
}
NOT_SUUPORT_STREAM_MODELS = {

View file

@ -0,0 +1,25 @@
from metagpt.logs import logger
def messages_to_prompt_llama(messages: list[dict]):
BOS, EOS = "<s>", "</s>"
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
prompt = f"{BOS}"
for message in messages:
role = message["role"]
content = message["content"]
if role == "system":
prompt += f"{B_SYS} {content} {E_SYS}"
elif role == "user":
prompt += f"{B_INST} {content} {E_INST}"
elif role == "assistant":
prompt += f"{content}"
else:
logger.warning(
f"Unknown role name {role} when formatting messages")
prompt += f"{content}"
return prompt