From a7414884100f67fbf6e3aed0abaceecb30472873 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 15:16:05 +0800 Subject: [PATCH] add stream --- .../provider/bedrock/amazon_bedrock_api.py | 41 +++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 11a2bb3f5..d8aaed8e9 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -42,16 +42,8 @@ class AmazonBedrockLLM(BaseLLM): def _generate_kwargs(self): return { "temperature": self.config.get("temperature", 0.3), - "top_p": self.config.get("top_p", 0.95), - "top_k": self.config.get("top_k", 1), } - def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - pass - - def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - pass - def completion(self, messages: list[dict]): request_body = self.provider.get_request_body( messages, **self._generate_kwargs) @@ -61,13 +53,38 @@ class AmazonBedrockLLM(BaseLLM): completions = self.provider.get_choice_text(response) return completions - def acompletion(self, messages: list[dict]): - pass + def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + request_body = self.provider.get_request_body( + messages, **self._generate_kwargs) + response = self.__client.invoke_model_with_response_stream( + modelId=self.config.model, body=request_body + ) + collected_content = [] + + for event in response.get("body"): + chunk_text = json.loads(event["chunk"]["bytes"])[ + "outputs"][0]["text"] + collected_content.append(chunk_text) + log_llm_stream(chunk_text) + + log_llm_stream("\n") + full_text = ("".join(collected_content)).lstrip() + return full_text + + async def acompletion(self, messages: list[dict]): + return self._achat_completion(messages) + + async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + # TODO:make it async + return self.completion(messages) + + async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + return self._chat_completion_stream(messages) if __name__ == '__main__': from .config import my_config - prompt = "who are you?" + prompt = "write an essay for living on mars in 1000 word" messages = [{"role": "user", "content": prompt}] llm = AmazonBedrockLLM(my_config) - print(llm.completion(messages)) + llm._chat_completion_stream(messages)