From 7706b88f03a8edc378aee8c279537f339eff0042 Mon Sep 17 00:00:00 2001 From: leiwu30 <2495165664@qq.com> Date: Wed, 27 Mar 2024 14:57:23 +0800 Subject: [PATCH 01/12] feat(core): Add stream data return and reception 1. add file: utils/steam_pipe.py 2. add demo: samples/flask_web_api.py 3. Other core code modifications, Add and use the StreamPipe class at night 4. Add flask library to requirements --- examples/flask_web_api.py | 108 ++++++++++++++++++++++++++ metagpt/actions/action.py | 4 + metagpt/provider/anthropic_api.py | 3 + metagpt/provider/base_llm.py | 2 + metagpt/provider/dashscope_api.py | 2 + metagpt/provider/google_gemini_api.py | 2 + metagpt/provider/ollama_api.py | 2 + metagpt/provider/openai_api.py | 3 + metagpt/provider/qianfan_api.py | 2 + metagpt/provider/zhipuai_api.py | 2 + metagpt/roles/role.py | 4 + metagpt/roles/tutorial_assistant.py | 10 ++- metagpt/utils/stream_pipe.py | 70 +++++++++++++++++ requirements.txt | 3 +- 14 files changed, 213 insertions(+), 4 deletions(-) create mode 100644 examples/flask_web_api.py create mode 100644 metagpt/utils/stream_pipe.py diff --git a/examples/flask_web_api.py b/examples/flask_web_api.py new file mode 100644 index 000000000..e87455ed1 --- /dev/null +++ b/examples/flask_web_api.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/3/27 9:44 +@Author : leiwu30 +@File : flask_web_api.py +@Description : Stream log information and communicate over the network via web api. +""" +import os +import json +import socket +import asyncio +import threading + +from metagpt.utils.stream_pipe import StreamPipe +from metagpt.roles.tutorial_assistant import TutorialAssistant +from metagpt.const import METAGPT_ROOT + +from flask import Flask, Response +from flask import request, jsonify, send_from_directory + +app = Flask(__name__) + + +def write_tutorial(message): + async def main(idea, stream_pipe): + role = TutorialAssistant(stream_pipe=stream_pipe) + await role.run(idea) + + def thread_run(idea: str, stream_pipe: StreamPipe = None): + """ + Convert asynchronous function to thread function + """ + asyncio.run(main(idea, stream_pipe)) + + stream_pipe = StreamPipe() + thread = threading.Thread(target=thread_run, args=(message["content"], stream_pipe,)) + thread.start() + + while not stream_pipe.finish: + stream_pipe.wait() + msg = stream_pipe.get_message() + yield stream_pipe.msg2stream(msg) + + # 文件位置 + md_file = stream_pipe.get_k_message("file_name") + + yield stream_pipe.msg2stream( + f"\n\n[{os.path.basename(md_file)}](http://{server_address}:{server_port}/download/{md_file})") + + +@app.route('/v1/chat/completions', methods=['POST']) +def completions(): + """ + data: { + "model": "write_tutorial", + "stream": true, + "messages": [ + { + "role": "user", + "content": "Write a tutorial about MySQL" + } + ] + } + """ + + data = json.loads(request.data) + print(json.dumps(data, indent=4, ensure_ascii=False)) + + # Non-streaming interfaces are not supported yet + stream_type = True if "stream" in data.keys() and data["stream"] else False + if not stream_type: + return jsonify({"status": 200}) + + # Only accept the last user information + last_message = data["messages"][-1] + model = data["model"] + + # write_tutorial + if model == "write_tutorial": + return Response(write_tutorial(last_message), mimetype="text/plain") + else: + return jsonify({"status": 200}) + # return Response(event_stream(), mimetype="text/plain") + + +@app.route('/download/') +def download_file(filename): + return send_from_directory(METAGPT_ROOT, filename, as_attachment=True) + + +if __name__ == "__main__": + """ + curl https://$server_address:$server_port/v1/chat/completions -X POST -d '{ + "model": "gpt-3.5-turbo", + "stream": true, + "messages": [ + { + "role": "user", + "content": "Write a tutorial about MySQL" + } + ] + }' + """ + server_port = 7860 + server_address = socket.gethostbyname(socket.gethostname()) + + app.run(port=server_port, host=server_address) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 1b93213f7..53fdd59f7 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -23,6 +23,7 @@ from metagpt.schema import ( TestingContext, ) from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.stream_pipe import StreamPipe class Action(SerializationMixin, ContextMixin, BaseModel): @@ -35,6 +36,7 @@ class Action(SerializationMixin, ContextMixin, BaseModel): prefix: str = "" # aask*时会加上prefix,作为system_message desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) + stream_pipe: Optional[StreamPipe] = None @property def repo(self) -> ProjectRepo: @@ -90,6 +92,8 @@ class Action(SerializationMixin, ContextMixin, BaseModel): async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: """Append default prefix""" + if self.stream_pipe and not self.llm.stream_pipe: + self.llm.stream_pipe = self.stream_pipe return await self.llm.aask(prompt, system_msgs) async def _run_action_node(self, *args, **kwargs): diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index 1aeacbe83..00f66b72f 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -62,6 +62,9 @@ class AnthropicLLM(BaseLLM): elif event_type == "content_block_delta": content = event.delta.text log_llm_stream(content) + if self.stream_pipe: + self.stream_pipe.set_message(content) + collected_content.append(content) elif event_type == "message_delta": usage.output_tokens = event.usage.output_tokens # update final output_tokens diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index db2757ec3..1b4be24e9 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -28,6 +28,7 @@ from metagpt.logs import logger from metagpt.schema import Message from metagpt.utils.common import log_and_reraise from metagpt.utils.cost_manager import CostManager, Costs +from metagpt.utils.stream_pipe import StreamPipe class BaseLLM(ABC): @@ -42,6 +43,7 @@ class BaseLLM(ABC): cost_manager: Optional[CostManager] = None model: Optional[str] = None # deprecated pricing_plan: Optional[str] = None + stream_pipe: Optional[StreamPipe] = None @abstractmethod def __init__(self, config: LLMConfig): diff --git a/metagpt/provider/dashscope_api.py b/metagpt/provider/dashscope_api.py index 82224e893..10845ed87 100644 --- a/metagpt/provider/dashscope_api.py +++ b/metagpt/provider/dashscope_api.py @@ -221,6 +221,8 @@ class DashScopeLLM(BaseLLM): content = chunk.output.choices[0]["message"]["content"] usage = dict(chunk.usage) # each chunk has usage log_llm_stream(content) + if self.stream_pipe: + self.stream_pipe.set_message(content) collected_content.append(content) log_llm_stream("\n") self._update_costs(usage) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index e4b3a3f17..49a533792 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -149,6 +149,8 @@ class GeminiLLM(BaseLLM): logger.warning(f"messages: {messages}\nerrors: {e}\n{BlockedPromptException(str(chunk))}") raise BlockedPromptException(str(chunk)) log_llm_stream(content) + if self.stream_pipe: + self.stream_pipe.set_message(content) collected_content.append(content) log_llm_stream("\n") diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 2913eb1dd..450346ab7 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -83,6 +83,8 @@ class OllamaLLM(BaseLLM): content = self.get_choice_text(chunk) collected_content.append(content) log_llm_stream(content) + if self.stream_pipe: + self.stream_pipe.set_message(content) else: # stream finished usage = self.get_usage(chunk) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index dbfed72df..2cfd86cfb 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -87,6 +87,9 @@ class OpenAILLM(BaseLLM): chunk.choices[0].finish_reason if chunk.choices and hasattr(chunk.choices[0], "finish_reason") else None ) log_llm_stream(chunk_message) + if self.stream_pipe: + self.stream_pipe.set_message(chunk_message) + collected_messages.append(chunk_message) if finish_reason: if hasattr(chunk, "usage"): diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py index 3d78c8bfc..c45a03d9b 100644 --- a/metagpt/provider/qianfan_api.py +++ b/metagpt/provider/qianfan_api.py @@ -124,6 +124,8 @@ class QianFanLLM(BaseLLM): content = chunk.body.get("result", "") usage = chunk.body.get("usage", {}) log_llm_stream(content) + if self.stream_pipe: + self.stream_pipe.set_message(content) collected_content.append(content) log_llm_stream("\n") diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 2db441991..d42f34fdd 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -73,6 +73,8 @@ class ZhiPuAILLM(BaseLLM): content = self.get_choice_delta_text(chunk) collected_content.append(content) log_llm_stream(content) + if self.stream_pipe: + self.stream_pipe.set_message(content) log_llm_stream("\n") diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index e0f8a7ea6..599662145 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -39,6 +39,8 @@ from metagpt.strategy.planner import Planner from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output +from metagpt.utils.stream_pipe import StreamPipe + if TYPE_CHECKING: from metagpt.environment import Environment # noqa: F401 @@ -139,6 +141,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel): role_id: str = "" states: list[str] = [] + stream_pipe: Optional[StreamPipe] = None + # scenarios to set action system_prompt: # 1. `__init__` while using Role(actions=[...]) # 2. add action to role while using `role.set_action(action)` diff --git a/metagpt/roles/tutorial_assistant.py b/metagpt/roles/tutorial_assistant.py index 6cf3a6469..4da419cc0 100644 --- a/metagpt/roles/tutorial_assistant.py +++ b/metagpt/roles/tutorial_assistant.py @@ -10,7 +10,7 @@ from datetime import datetime from typing import Dict from metagpt.actions.write_tutorial import WriteContent, WriteDirectory -from metagpt.const import TUTORIAL_PATH +from metagpt.const import TUTORIAL_PATH, METAGPT_ROOT from metagpt.logs import logger from metagpt.roles.role import Role, RoleReactMode from metagpt.schema import Message @@ -40,7 +40,7 @@ class TutorialAssistant(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - self.set_actions([WriteDirectory(language=self.language)]) + self.set_actions([WriteDirectory(language=self.language, stream_pipe=self.stream_pipe)]) self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value) async def _handle_directory(self, titles: Dict) -> Message: @@ -58,7 +58,7 @@ class TutorialAssistant(Role): self.total_content += f"# {self.main_title}" actions = list() for first_dir in titles.get("directory"): - actions.append(WriteContent(language=self.language, directory=first_dir)) + actions.append(WriteContent(language=self.language, directory=first_dir, stream_pipe=self.stream_pipe)) key = list(first_dir.keys())[0] directory += f"- {key}\n" for second_dir in first_dir[key]: @@ -91,4 +91,8 @@ class TutorialAssistant(Role): root_path = TUTORIAL_PATH / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") await File.write(root_path, f"{self.main_title}.md", self.total_content.encode("utf-8")) msg.content = str(root_path / f"{self.main_title}.md") + + if self.stream_pipe: + self.stream_pipe.set_k_message("file_name", msg.content.replace(str(METAGPT_ROOT), "")) + self.stream_pipe.with_finish() return msg diff --git a/metagpt/utils/stream_pipe.py b/metagpt/utils/stream_pipe.py new file mode 100644 index 000000000..5fa4556ea --- /dev/null +++ b/metagpt/utils/stream_pipe.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/3/27 10:00 +# @Author : leiwu30 +# @File : stream_pipe.py +# @Version : None +# @Description : None + +import time +import json +from multiprocessing import Pipe + + +class StreamPipe: + parent_conn, child_conn = Pipe() + + variable: list = {} + finish: bool = False + + format_data = { + "id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur", + "object": "chat.completion.chunk", + "created": 1711361191, + "model": "gpt-3.5-turbo-0125", + "system_fingerprint": "fp_3bc1b5746c", + "choices": [ + { + "index": 0, + "delta": + { + "role": "assistant", + "content": "content" + }, + "logprobs": None, + "finish_reason": None + } + ] + } + + def set_message(self, msg): + self.parent_conn.send(msg) + + def wait(self): + pass + + def get_message(self): + if self.child_conn.poll(timeout=3): + return self.child_conn.recv() + else: + return None + + def set_k_message(self, k, msg): + self.variable[k] = msg + + def get_k_message(self, k): + return self.variable[k] + + def msg2stream(self, msg): + self.format_data['created'] = int(time.time()) + self.format_data['choices'][0]['delta']['content'] = msg + return f"data: {json.dumps(self.format_data, ensure_ascii=False)}\n".encode("utf-8") + + def with_finish(self, timeout: int = 3): + """ + Args: + timeout: /s + """ + # Pipe is not empty waiting for pipe condition + # while self.child_conn.poll(timeout=timeout): + # time.sleep(0.5) + self.finish = True diff --git a/requirements.txt b/requirements.txt index da8aa26b2..f5e9b15c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -69,4 +69,5 @@ imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py qianfan==0.3.2 dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation -jieba==0.42.1 # for tool recommendation \ No newline at end of file +jieba==0.42.1 # for tool recommendation +flask # for web api \ No newline at end of file From 29fecffa3f4b96acc325e4c7df76fc671f94ff2a Mon Sep 17 00:00:00 2001 From: leiwu30 <2495165664@qq.com> Date: Wed, 27 Mar 2024 16:55:38 +0800 Subject: [PATCH 02/12] fix: Make flask an extra dependency https://github.com/geekan/MetaGPT/pull/1118#discussion_r1540656272 --- requirements.txt | 3 +-- setup.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index f5e9b15c1..da8aa26b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -69,5 +69,4 @@ imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py qianfan==0.3.2 dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation -jieba==0.42.1 # for tool recommendation -flask # for web api \ No newline at end of file +jieba==0.42.1 # for tool recommendation \ No newline at end of file diff --git a/setup.py b/setup.py index f834b4c44..257263bba 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ extras_require = { "llama-index-vector-stores-faiss==0.1.1", "chromadb==0.4.23", ], + "web-api": ["flask==3.0.2"], } extras_require["test"] = [ From 923150b2f32d734bd8c5bce95c49566b429d43cc Mon Sep 17 00:00:00 2001 From: leiwu30 <2495165664@qq.com> Date: Thu, 28 Mar 2024 09:38:44 +0800 Subject: [PATCH 03/12] Revert "feat(core): Add stream data return and reception" This reverts commit 7706b88f03a8edc378aee8c279537f339eff0042. --- examples/flask_web_api.py | 108 -------------------------- metagpt/actions/action.py | 4 - metagpt/provider/anthropic_api.py | 3 - metagpt/provider/base_llm.py | 2 - metagpt/provider/dashscope_api.py | 2 - metagpt/provider/google_gemini_api.py | 2 - metagpt/provider/ollama_api.py | 2 - metagpt/provider/openai_api.py | 3 - metagpt/provider/qianfan_api.py | 2 - metagpt/provider/zhipuai_api.py | 2 - metagpt/roles/role.py | 4 - metagpt/roles/tutorial_assistant.py | 10 +-- metagpt/utils/stream_pipe.py | 70 ----------------- 13 files changed, 3 insertions(+), 211 deletions(-) delete mode 100644 examples/flask_web_api.py delete mode 100644 metagpt/utils/stream_pipe.py diff --git a/examples/flask_web_api.py b/examples/flask_web_api.py deleted file mode 100644 index e87455ed1..000000000 --- a/examples/flask_web_api.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2024/3/27 9:44 -@Author : leiwu30 -@File : flask_web_api.py -@Description : Stream log information and communicate over the network via web api. -""" -import os -import json -import socket -import asyncio -import threading - -from metagpt.utils.stream_pipe import StreamPipe -from metagpt.roles.tutorial_assistant import TutorialAssistant -from metagpt.const import METAGPT_ROOT - -from flask import Flask, Response -from flask import request, jsonify, send_from_directory - -app = Flask(__name__) - - -def write_tutorial(message): - async def main(idea, stream_pipe): - role = TutorialAssistant(stream_pipe=stream_pipe) - await role.run(idea) - - def thread_run(idea: str, stream_pipe: StreamPipe = None): - """ - Convert asynchronous function to thread function - """ - asyncio.run(main(idea, stream_pipe)) - - stream_pipe = StreamPipe() - thread = threading.Thread(target=thread_run, args=(message["content"], stream_pipe,)) - thread.start() - - while not stream_pipe.finish: - stream_pipe.wait() - msg = stream_pipe.get_message() - yield stream_pipe.msg2stream(msg) - - # 文件位置 - md_file = stream_pipe.get_k_message("file_name") - - yield stream_pipe.msg2stream( - f"\n\n[{os.path.basename(md_file)}](http://{server_address}:{server_port}/download/{md_file})") - - -@app.route('/v1/chat/completions', methods=['POST']) -def completions(): - """ - data: { - "model": "write_tutorial", - "stream": true, - "messages": [ - { - "role": "user", - "content": "Write a tutorial about MySQL" - } - ] - } - """ - - data = json.loads(request.data) - print(json.dumps(data, indent=4, ensure_ascii=False)) - - # Non-streaming interfaces are not supported yet - stream_type = True if "stream" in data.keys() and data["stream"] else False - if not stream_type: - return jsonify({"status": 200}) - - # Only accept the last user information - last_message = data["messages"][-1] - model = data["model"] - - # write_tutorial - if model == "write_tutorial": - return Response(write_tutorial(last_message), mimetype="text/plain") - else: - return jsonify({"status": 200}) - # return Response(event_stream(), mimetype="text/plain") - - -@app.route('/download/') -def download_file(filename): - return send_from_directory(METAGPT_ROOT, filename, as_attachment=True) - - -if __name__ == "__main__": - """ - curl https://$server_address:$server_port/v1/chat/completions -X POST -d '{ - "model": "gpt-3.5-turbo", - "stream": true, - "messages": [ - { - "role": "user", - "content": "Write a tutorial about MySQL" - } - ] - }' - """ - server_port = 7860 - server_address = socket.gethostbyname(socket.gethostname()) - - app.run(port=server_port, host=server_address) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 53fdd59f7..1b93213f7 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -23,7 +23,6 @@ from metagpt.schema import ( TestingContext, ) from metagpt.utils.project_repo import ProjectRepo -from metagpt.utils.stream_pipe import StreamPipe class Action(SerializationMixin, ContextMixin, BaseModel): @@ -36,7 +35,6 @@ class Action(SerializationMixin, ContextMixin, BaseModel): prefix: str = "" # aask*时会加上prefix,作为system_message desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) - stream_pipe: Optional[StreamPipe] = None @property def repo(self) -> ProjectRepo: @@ -92,8 +90,6 @@ class Action(SerializationMixin, ContextMixin, BaseModel): async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: """Append default prefix""" - if self.stream_pipe and not self.llm.stream_pipe: - self.llm.stream_pipe = self.stream_pipe return await self.llm.aask(prompt, system_msgs) async def _run_action_node(self, *args, **kwargs): diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index 00f66b72f..1aeacbe83 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -62,9 +62,6 @@ class AnthropicLLM(BaseLLM): elif event_type == "content_block_delta": content = event.delta.text log_llm_stream(content) - if self.stream_pipe: - self.stream_pipe.set_message(content) - collected_content.append(content) elif event_type == "message_delta": usage.output_tokens = event.usage.output_tokens # update final output_tokens diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 1b4be24e9..db2757ec3 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -28,7 +28,6 @@ from metagpt.logs import logger from metagpt.schema import Message from metagpt.utils.common import log_and_reraise from metagpt.utils.cost_manager import CostManager, Costs -from metagpt.utils.stream_pipe import StreamPipe class BaseLLM(ABC): @@ -43,7 +42,6 @@ class BaseLLM(ABC): cost_manager: Optional[CostManager] = None model: Optional[str] = None # deprecated pricing_plan: Optional[str] = None - stream_pipe: Optional[StreamPipe] = None @abstractmethod def __init__(self, config: LLMConfig): diff --git a/metagpt/provider/dashscope_api.py b/metagpt/provider/dashscope_api.py index 10845ed87..82224e893 100644 --- a/metagpt/provider/dashscope_api.py +++ b/metagpt/provider/dashscope_api.py @@ -221,8 +221,6 @@ class DashScopeLLM(BaseLLM): content = chunk.output.choices[0]["message"]["content"] usage = dict(chunk.usage) # each chunk has usage log_llm_stream(content) - if self.stream_pipe: - self.stream_pipe.set_message(content) collected_content.append(content) log_llm_stream("\n") self._update_costs(usage) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 49a533792..e4b3a3f17 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -149,8 +149,6 @@ class GeminiLLM(BaseLLM): logger.warning(f"messages: {messages}\nerrors: {e}\n{BlockedPromptException(str(chunk))}") raise BlockedPromptException(str(chunk)) log_llm_stream(content) - if self.stream_pipe: - self.stream_pipe.set_message(content) collected_content.append(content) log_llm_stream("\n") diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 450346ab7..2913eb1dd 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -83,8 +83,6 @@ class OllamaLLM(BaseLLM): content = self.get_choice_text(chunk) collected_content.append(content) log_llm_stream(content) - if self.stream_pipe: - self.stream_pipe.set_message(content) else: # stream finished usage = self.get_usage(chunk) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 2cfd86cfb..dbfed72df 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -87,9 +87,6 @@ class OpenAILLM(BaseLLM): chunk.choices[0].finish_reason if chunk.choices and hasattr(chunk.choices[0], "finish_reason") else None ) log_llm_stream(chunk_message) - if self.stream_pipe: - self.stream_pipe.set_message(chunk_message) - collected_messages.append(chunk_message) if finish_reason: if hasattr(chunk, "usage"): diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py index c45a03d9b..3d78c8bfc 100644 --- a/metagpt/provider/qianfan_api.py +++ b/metagpt/provider/qianfan_api.py @@ -124,8 +124,6 @@ class QianFanLLM(BaseLLM): content = chunk.body.get("result", "") usage = chunk.body.get("usage", {}) log_llm_stream(content) - if self.stream_pipe: - self.stream_pipe.set_message(content) collected_content.append(content) log_llm_stream("\n") diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index d42f34fdd..2db441991 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -73,8 +73,6 @@ class ZhiPuAILLM(BaseLLM): content = self.get_choice_delta_text(chunk) collected_content.append(content) log_llm_stream(content) - if self.stream_pipe: - self.stream_pipe.set_message(content) log_llm_stream("\n") diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 599662145..e0f8a7ea6 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -39,8 +39,6 @@ from metagpt.strategy.planner import Planner from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output -from metagpt.utils.stream_pipe import StreamPipe - if TYPE_CHECKING: from metagpt.environment import Environment # noqa: F401 @@ -141,8 +139,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel): role_id: str = "" states: list[str] = [] - stream_pipe: Optional[StreamPipe] = None - # scenarios to set action system_prompt: # 1. `__init__` while using Role(actions=[...]) # 2. add action to role while using `role.set_action(action)` diff --git a/metagpt/roles/tutorial_assistant.py b/metagpt/roles/tutorial_assistant.py index 4da419cc0..6cf3a6469 100644 --- a/metagpt/roles/tutorial_assistant.py +++ b/metagpt/roles/tutorial_assistant.py @@ -10,7 +10,7 @@ from datetime import datetime from typing import Dict from metagpt.actions.write_tutorial import WriteContent, WriteDirectory -from metagpt.const import TUTORIAL_PATH, METAGPT_ROOT +from metagpt.const import TUTORIAL_PATH from metagpt.logs import logger from metagpt.roles.role import Role, RoleReactMode from metagpt.schema import Message @@ -40,7 +40,7 @@ class TutorialAssistant(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - self.set_actions([WriteDirectory(language=self.language, stream_pipe=self.stream_pipe)]) + self.set_actions([WriteDirectory(language=self.language)]) self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value) async def _handle_directory(self, titles: Dict) -> Message: @@ -58,7 +58,7 @@ class TutorialAssistant(Role): self.total_content += f"# {self.main_title}" actions = list() for first_dir in titles.get("directory"): - actions.append(WriteContent(language=self.language, directory=first_dir, stream_pipe=self.stream_pipe)) + actions.append(WriteContent(language=self.language, directory=first_dir)) key = list(first_dir.keys())[0] directory += f"- {key}\n" for second_dir in first_dir[key]: @@ -91,8 +91,4 @@ class TutorialAssistant(Role): root_path = TUTORIAL_PATH / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") await File.write(root_path, f"{self.main_title}.md", self.total_content.encode("utf-8")) msg.content = str(root_path / f"{self.main_title}.md") - - if self.stream_pipe: - self.stream_pipe.set_k_message("file_name", msg.content.replace(str(METAGPT_ROOT), "")) - self.stream_pipe.with_finish() return msg diff --git a/metagpt/utils/stream_pipe.py b/metagpt/utils/stream_pipe.py deleted file mode 100644 index 5fa4556ea..000000000 --- a/metagpt/utils/stream_pipe.py +++ /dev/null @@ -1,70 +0,0 @@ -# -*- coding: utf-8 -*- -# @Time : 2024/3/27 10:00 -# @Author : leiwu30 -# @File : stream_pipe.py -# @Version : None -# @Description : None - -import time -import json -from multiprocessing import Pipe - - -class StreamPipe: - parent_conn, child_conn = Pipe() - - variable: list = {} - finish: bool = False - - format_data = { - "id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur", - "object": "chat.completion.chunk", - "created": 1711361191, - "model": "gpt-3.5-turbo-0125", - "system_fingerprint": "fp_3bc1b5746c", - "choices": [ - { - "index": 0, - "delta": - { - "role": "assistant", - "content": "content" - }, - "logprobs": None, - "finish_reason": None - } - ] - } - - def set_message(self, msg): - self.parent_conn.send(msg) - - def wait(self): - pass - - def get_message(self): - if self.child_conn.poll(timeout=3): - return self.child_conn.recv() - else: - return None - - def set_k_message(self, k, msg): - self.variable[k] = msg - - def get_k_message(self, k): - return self.variable[k] - - def msg2stream(self, msg): - self.format_data['created'] = int(time.time()) - self.format_data['choices'][0]['delta']['content'] = msg - return f"data: {json.dumps(self.format_data, ensure_ascii=False)}\n".encode("utf-8") - - def with_finish(self, timeout: int = 3): - """ - Args: - timeout: /s - """ - # Pipe is not empty waiting for pipe condition - # while self.child_conn.poll(timeout=timeout): - # time.sleep(0.5) - self.finish = True From db2ea905b4d838dda2f456dc06cb69d70059c220 Mon Sep 17 00:00:00 2001 From: leiwu30 <2495165664@qq.com> Date: Thu, 28 Mar 2024 10:27:15 +0800 Subject: [PATCH 04/12] fix: Only add demo and a tool class to complete the function --- examples/flask_web_api.py | 119 +++++++++++++++++++++++++++++++++++ metagpt/utils/stream_pipe.py | 57 +++++++++++++++++ 2 files changed, 176 insertions(+) create mode 100644 examples/flask_web_api.py create mode 100644 metagpt/utils/stream_pipe.py diff --git a/examples/flask_web_api.py b/examples/flask_web_api.py new file mode 100644 index 000000000..b8baa4359 --- /dev/null +++ b/examples/flask_web_api.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/3/27 9:44 +@Author : leiwu30 +@File : flask_web_api.py +@Description : Stream log information and communicate over the network via web api. +""" +import os +import json +import socket +import asyncio +import threading + +from metagpt.team import Team +from metagpt.const import METAGPT_ROOT +from metagpt.logs import set_llm_stream_logfunc +from metagpt.utils.stream_pipe import StreamPipe +from metagpt.roles.tutorial_assistant import TutorialAssistant + +from contextvars import ContextVar +from flask import Flask, Response +from flask import request, jsonify, send_from_directory + +app = Flask(__name__) + + +def stream_pipe_log(content): + print(content, end="") + stream_pipe = stream_pipe_var.get(None) + if stream_pipe: + stream_pipe.set_message(content) + + +def write_tutorial(message): + async def main(idea, stream_pipe): + stream_pipe_var.set(stream_pipe) + role = TutorialAssistant(stream_pipe=stream_pipe) + await role.run(idea) + + def thread_run(idea: str, stream_pipe: StreamPipe = None): + """ + Convert asynchronous function to thread function + """ + asyncio.run(main(idea, stream_pipe)) + + stream_pipe = StreamPipe() + thread = threading.Thread(target=thread_run, args=(message["content"], stream_pipe,)) + thread.start() + + while thread.is_alive(): + msg = stream_pipe.get_message() + yield stream_pipe.msg2stream(msg) + + # 文件位置 + # md_file = stream_pipe.get_k_message("file_name") + # yield stream_pipe.msg2stream( + # f"\n\n[{os.path.basename(md_file)}](http://{server_address}:{server_port}/download/{md_file})") + + +@app.route('/v1/chat/completions', methods=['POST']) +def completions(): + """ + data: { + "model": "write_tutorial", + "stream": true, + "messages": [ + { + "role": "user", + "content": "Write a tutorial about MySQL" + } + ] + } + """ + + data = json.loads(request.data) + print(json.dumps(data, indent=4, ensure_ascii=False)) + + # Non-streaming interfaces are not supported yet + stream_type = True if "stream" in data.keys() and data["stream"] else False + if not stream_type: + return jsonify({"status": 200}) + + # Only accept the last user information + last_message = data["messages"][-1] + model = data["model"] + + # write_tutorial + if model == "write_tutorial": + return Response(write_tutorial(last_message), mimetype="text/plain") + else: + return jsonify({"status": 200}) + # return Response(event_stream(), mimetype="text/plain") + + +@app.route('/download/') +def download_file(filename): + return send_from_directory(METAGPT_ROOT, filename, as_attachment=True) + + +if __name__ == "__main__": + """ + curl https://$server_address:$server_port/v1/chat/completions -X POST -d '{ + "model": "write_tutorial", + "stream": true, + "messages": [ + { + "role": "user", + "content": "Write a tutorial about MySQL" + } + ] + }' + """ + server_port = 7860 + server_address = socket.gethostbyname(socket.gethostname()) + + set_llm_stream_logfunc(stream_pipe_log) + stream_pipe_var: ContextVar[StreamPipe] = ContextVar("stream_pipe") + app.run(port=server_port, host=server_address) diff --git a/metagpt/utils/stream_pipe.py b/metagpt/utils/stream_pipe.py new file mode 100644 index 000000000..6e319411c --- /dev/null +++ b/metagpt/utils/stream_pipe.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/3/27 10:00 +# @Author : leiwu30 +# @File : stream_pipe.py +# @Version : None +# @Description : None + +import time +import json +from multiprocessing import Pipe + + +class StreamPipe: + parent_conn, child_conn = Pipe() + + variable: list = {} + finish: bool = False + + format_data = { + "id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur", + "object": "chat.completion.chunk", + "created": 1711361191, + "model": "gpt-3.5-turbo-0125", + "system_fingerprint": "fp_3bc1b5746c", + "choices": [ + { + "index": 0, + "delta": + { + "role": "assistant", + "content": "content" + }, + "logprobs": None, + "finish_reason": None + } + ] + } + + def set_message(self, msg): + self.parent_conn.send(msg) + + def get_message(self, timeout: int = 3): + if self.child_conn.poll(timeout): + return self.child_conn.recv() + else: + return None + + def set_k_message(self, k, msg): + self.variable[k] = msg + + def get_k_message(self, k): + return self.variable[k] + + def msg2stream(self, msg): + self.format_data['created'] = int(time.time()) + self.format_data['choices'][0]['delta']['content'] = msg + return f"data: {json.dumps(self.format_data, ensure_ascii=False)}\n".encode("utf-8") From a2484f64207c67bac82b0de26afc595fceb0e891 Mon Sep 17 00:00:00 2001 From: leiwu30 <2495165664@qq.com> Date: Sun, 7 Apr 2024 10:12:26 +0800 Subject: [PATCH 05/12] fix: Fix bug for merging --- ...sk_web_api.py => stream_output_via_api.py} | 34 ++++++++----------- metagpt/utils/stream_pipe.py | 8 ----- 2 files changed, 14 insertions(+), 28 deletions(-) rename examples/{flask_web_api.py => stream_output_via_api.py} (78%) diff --git a/examples/flask_web_api.py b/examples/stream_output_via_api.py similarity index 78% rename from examples/flask_web_api.py rename to examples/stream_output_via_api.py index b8baa4359..94709b5bf 100644 --- a/examples/flask_web_api.py +++ b/examples/stream_output_via_api.py @@ -3,25 +3,24 @@ """ @Time : 2024/3/27 9:44 @Author : leiwu30 -@File : flask_web_api.py +@File : stream_output_via_api.py @Description : Stream log information and communicate over the network via web api. """ -import os import json import socket import asyncio import threading -from metagpt.team import Team -from metagpt.const import METAGPT_ROOT -from metagpt.logs import set_llm_stream_logfunc -from metagpt.utils.stream_pipe import StreamPipe -from metagpt.roles.tutorial_assistant import TutorialAssistant - from contextvars import ContextVar from flask import Flask, Response from flask import request, jsonify, send_from_directory +from metagpt.logs import logger +from metagpt.const import TUTORIAL_PATH +from metagpt.logs import set_llm_stream_logfunc +from metagpt.utils.stream_pipe import StreamPipe +from metagpt.roles.tutorial_assistant import TutorialAssistant + app = Flask(__name__) @@ -35,7 +34,7 @@ def stream_pipe_log(content): def write_tutorial(message): async def main(idea, stream_pipe): stream_pipe_var.set(stream_pipe) - role = TutorialAssistant(stream_pipe=stream_pipe) + role = TutorialAssistant() await role.run(idea) def thread_run(idea: str, stream_pipe: StreamPipe = None): @@ -52,11 +51,6 @@ def write_tutorial(message): msg = stream_pipe.get_message() yield stream_pipe.msg2stream(msg) - # 文件位置 - # md_file = stream_pipe.get_k_message("file_name") - # yield stream_pipe.msg2stream( - # f"\n\n[{os.path.basename(md_file)}](http://{server_address}:{server_port}/download/{md_file})") - @app.route('/v1/chat/completions', methods=['POST']) def completions(): @@ -74,14 +68,15 @@ def completions(): """ data = json.loads(request.data) - print(json.dumps(data, indent=4, ensure_ascii=False)) + logger.info(json.dumps(data, indent=4, ensure_ascii=False)) # Non-streaming interfaces are not supported yet - stream_type = True if "stream" in data.keys() and data["stream"] else False + stream_type = True if data.get("stream") else False if not stream_type: - return jsonify({"status": 200}) + return jsonify({"status": 400, "msg": "Non-streaming requests are not supported, please use `stream=True`."}) # Only accept the last user information + # openai['model'] ~ MetaGPT['agent'] last_message = data["messages"][-1] model = data["model"] @@ -89,13 +84,12 @@ def completions(): if model == "write_tutorial": return Response(write_tutorial(last_message), mimetype="text/plain") else: - return jsonify({"status": 200}) - # return Response(event_stream(), mimetype="text/plain") + return jsonify({"status": 400, "msg": "No suitable agent found."}) @app.route('/download/') def download_file(filename): - return send_from_directory(METAGPT_ROOT, filename, as_attachment=True) + return send_from_directory(TUTORIAL_PATH, filename, as_attachment=True) if __name__ == "__main__": diff --git a/metagpt/utils/stream_pipe.py b/metagpt/utils/stream_pipe.py index 6e319411c..d3d3cff32 100644 --- a/metagpt/utils/stream_pipe.py +++ b/metagpt/utils/stream_pipe.py @@ -12,8 +12,6 @@ from multiprocessing import Pipe class StreamPipe: parent_conn, child_conn = Pipe() - - variable: list = {} finish: bool = False format_data = { @@ -45,12 +43,6 @@ class StreamPipe: else: return None - def set_k_message(self, k, msg): - self.variable[k] = msg - - def get_k_message(self, k): - return self.variable[k] - def msg2stream(self, msg): self.format_data['created'] = int(time.time()) self.format_data['choices'][0]['delta']['content'] = msg From f646715fde4563924585f7ba917520ac02bbecd2 Mon Sep 17 00:00:00 2001 From: leiwu30 <2495165664@qq.com> Date: Mon, 8 Apr 2024 10:23:53 +0800 Subject: [PATCH 06/12] fix: Delete flask dependency --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 257263bba..f834b4c44 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,6 @@ extras_require = { "llama-index-vector-stores-faiss==0.1.1", "chromadb==0.4.23", ], - "web-api": ["flask==3.0.2"], } extras_require["test"] = [ From 1ff50e85c2909b38946a9731d52ca86d8f4646ad Mon Sep 17 00:00:00 2001 From: XueFeng <1158231926@qq.com> Date: Mon, 8 Apr 2024 16:54:00 +0800 Subject: [PATCH 07/12] Update `zhipuai_api.py` for custom `max_tokens` and `temperature` config. --- metagpt/provider/zhipuai_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 2db441991..a45081fcf 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -43,7 +43,9 @@ class ZhiPuAILLM(BaseLLM): self.llm = ZhiPuModelAPI(api_key=self.api_key) def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: - kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3} + max_tokens = self.config.max_token if self.config.max_token > 0 else 1024 + temperature = self.config.temperature if self.config.temperature > 0.0 else 0.3 + kwargs = {"model": self.model, "max_tokens": max_tokens, "messages": messages, "stream": stream, "temperature": temperature} return kwargs def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict: From 8a27b6dd4a15d96cb572d4bec71e5f62933e392d Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 10 Apr 2024 11:42:54 +0800 Subject: [PATCH 08/12] use gpt-4-turbo as default --- config/config2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config2.yaml b/config/config2.yaml index 8e5825b57..ba071e804 100644 --- a/config/config2.yaml +++ b/config/config2.yaml @@ -2,6 +2,6 @@ # Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py llm: api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options - model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview + model: "gpt-4-turbo" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview base_url: "https://api.openai.com/v1" # or forward url / other llm url api_key: "YOUR_API_KEY" \ No newline at end of file From c494844c02cd58c9492c00095821a4a818d99eb2 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 10 Apr 2024 14:12:45 +0800 Subject: [PATCH 09/12] make embedding configurable and add gpt-4-turbo. --- config/config2.example.yaml | 10 ++ examples/rag_pipeline.py | 5 +- metagpt/config2.py | 4 + metagpt/configs/embedding_config.py | 32 ++++++ metagpt/rag/factories/base.py | 3 + metagpt/rag/factories/embedding.py | 88 ++++++++++++++--- metagpt/rag/factories/llm.py | 7 +- metagpt/rag/schema.py | 20 +++- metagpt/utils/async_helper.py | 15 +++ metagpt/utils/token_counter.py | 3 + setup.py | 3 + tests/metagpt/rag/factories/test_embedding.py | 97 +++++++++++++++---- 12 files changed, 250 insertions(+), 37 deletions(-) create mode 100644 metagpt/configs/embedding_config.py diff --git a/config/config2.example.yaml b/config/config2.example.yaml index c5454ec32..7f4758acb 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -13,6 +13,16 @@ llm: # - gpt-4 8k: "gpt-4" # See for more: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ +# RAG Embedding. +# For backward compatibility, if the embedding is not set and the llm's api_type is either openai or azure, the llm's config will be used. +embedding: + api_type: "" # openai / azure / gemini / ollama etc. Check EmbeddingType for more options. + base_url: "YOU_BASE_URL" + api_key: "YOU_API_KEY" + model: "YOU_MODEL" + api_version: "YOU_API_VERSION" + embed_batch_size: 100 + repair_llm_output: true # when the output is not a valid json, try to repair it proxy: "YOUR_PROXY" # for tools like requests, playwright, selenium, etc. diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index b5111b75c..1687d556b 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -8,7 +8,6 @@ from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine from metagpt.rag.schema import ( - BM25RetrieverConfig, ChromaIndexConfig, ChromaRetrieverConfig, ElasticsearchIndexConfig, @@ -51,7 +50,7 @@ class RAGExample: if not self._engine: self._engine = SimpleEngine.from_docs( input_files=[DOC_PATH], - retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], + retriever_configs=[FAISSRetrieverConfig()], ranker_configs=[LLMRankerConfig()], ) return self._engine @@ -61,7 +60,7 @@ class RAGExample: self._engine = value async def run_pipeline(self, question=QUESTION, print_title=True): - """This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like: + """This example run rag pipeline, use faiss retriever and llm ranker, will print something like: Retrieve Result: 0. Productivi..., 10.0 diff --git a/metagpt/config2.py b/metagpt/config2.py index ed68b4db2..58a99c920 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Literal, Optional from pydantic import BaseModel, model_validator from metagpt.configs.browser_config import BrowserConfig +from metagpt.configs.embedding_config import EmbeddingConfig from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.configs.mermaid_config import MermaidConfig from metagpt.configs.redis_config import RedisConfig @@ -47,6 +48,9 @@ class Config(CLIParams, YamlModel): # Key Parameters llm: LLMConfig + # RAG Embedding + embedding: EmbeddingConfig = EmbeddingConfig() + # Global Proxy. Will be used if llm.proxy is not set proxy: str = "" diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py new file mode 100644 index 000000000..545c2a9cc --- /dev/null +++ b/metagpt/configs/embedding_config.py @@ -0,0 +1,32 @@ +from enum import Enum +from typing import Optional + +from pydantic import field_validator + +from metagpt.utils.yaml_model import YamlModel + + +class EmbeddingType(Enum): + OPENAI = "openai" + AZURE = "azure" + GEMINI = "gemini" + OLLAMA = "ollama" + + +class EmbeddingConfig(YamlModel): + """Config for Embedding.""" + + api_type: Optional[EmbeddingType] = None + api_key: Optional[str] = None + base_url: Optional[str] = None + api_version: Optional[str] = None + + model: Optional[str] = None + embed_batch_size: Optional[int] = None + + @field_validator("api_type", mode="before") + @classmethod + def check_api_type(cls, v): + if v == "": + return None + return v diff --git a/metagpt/rag/factories/base.py b/metagpt/rag/factories/base.py index fbdfbf1a8..fcfec03ec 100644 --- a/metagpt/rag/factories/base.py +++ b/metagpt/rag/factories/base.py @@ -26,6 +26,9 @@ class GenericFactory: if creator: return creator(**kwargs) + self._raise_for_key(key) + + def _raise_for_key(self, key: Any): raise ValueError(f"Creator not registered for key: {key}") diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py index 4247db256..3613fd228 100644 --- a/metagpt/rag/factories/embedding.py +++ b/metagpt/rag/factories/embedding.py @@ -1,37 +1,103 @@ """RAG Embedding Factory.""" +from __future__ import annotations + +from typing import Any from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding +from llama_index.embeddings.gemini import GeminiEmbedding +from llama_index.embeddings.ollama import OllamaEmbedding from llama_index.embeddings.openai import OpenAIEmbedding from metagpt.config2 import config +from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from metagpt.rag.factories.base import GenericFactory class RAGEmbeddingFactory(GenericFactory): - """Create LlamaIndex Embedding with MetaGPT's config.""" + """Create LlamaIndex Embedding with MetaGPT's embedding config.""" def __init__(self): creators = { + EmbeddingType.OPENAI: self._create_openai, + EmbeddingType.AZURE: self._create_azure, + EmbeddingType.GEMINI: self._create_gemini, + EmbeddingType.OLLAMA: self._create_ollama, + # For backward compatibility LLMType.OPENAI: self._create_openai, LLMType.AZURE: self._create_azure, } super().__init__(creators) - def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding: - """Key is LLMType, default use config.llm.api_type.""" - return super().get_instance(key or config.llm.api_type) + def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding: + """Key is EmbeddingType.""" + return super().get_instance(key or self._resolve_embedding_type()) - def _create_openai(self): - return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url) + def _resolve_embedding_type(self) -> EmbeddingType | LLMType: + """Resolves the embedding type. - def _create_azure(self): - return AzureOpenAIEmbedding( - azure_endpoint=config.llm.base_url, - api_key=config.llm.api_key, - api_version=config.llm.api_version, + If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE. + Raise TypeError if embedding type not found. + """ + if config.embedding.api_type: + return config.embedding.api_type + + if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]: + return config.llm.api_type + + raise TypeError("To use RAG, please set your embedding in config2.yaml.") + + def _create_openai(self) -> OpenAIEmbedding: + params = dict( + api_key=config.embedding.api_key or config.llm.api_key, + api_base=config.embedding.base_url or config.llm.base_url, ) + self._try_set_model_and_batch_size(params) + + return OpenAIEmbedding(**params) + + def _create_azure(self) -> AzureOpenAIEmbedding: + params = dict( + api_key=config.embedding.api_key or config.llm.api_key, + azure_endpoint=config.embedding.base_url or config.llm.base_url, + api_version=config.embedding.api_version or config.llm.api_version, + ) + + self._try_set_model_and_batch_size(params) + + return AzureOpenAIEmbedding(**params) + + def _create_gemini(self) -> GeminiEmbedding: + params = dict( + api_key=config.embedding.api_key, + api_base=config.embedding.base_url, + ) + + self._try_set_model_and_batch_size(params) + + return GeminiEmbedding(**params) + + def _create_ollama(self) -> OllamaEmbedding: + params = dict( + base_url=config.embedding.base_url, + ) + + self._try_set_model_and_batch_size(params) + + return OllamaEmbedding(**params) + + def _try_set_model_and_batch_size(self, params: dict): + """Set the model_name and embed_batch_size only when they are specified.""" + if config.embedding.model: + params["model_name"] = config.embedding.model + + if config.embedding.embed_batch_size: + params["embed_batch_size"] = config.embedding.embed_batch_size + + def _raise_for_key(self, key: Any): + raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}") + get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 17c499b76..9fd19cab5 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -1,5 +1,5 @@ """RAG LLM.""" - +import asyncio from typing import Any from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW @@ -15,7 +15,7 @@ from pydantic import Field from metagpt.config2 import config from metagpt.llm import LLM from metagpt.provider.base_llm import BaseLLM -from metagpt.utils.async_helper import run_coroutine_in_new_loop +from metagpt.utils.async_helper import NestAsyncio from metagpt.utils.token_counter import TOKEN_MAX @@ -39,7 +39,8 @@ class RAGLLM(CustomLLM): @llm_completion_callback() def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs)) + NestAsyncio.apply_once() + return asyncio.get_event_loop().run_until_complete(self.acomplete(prompt, **kwargs)) @llm_completion_callback() async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse: diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 183f6e0c7..582297b93 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,14 +1,16 @@ """RAG schemas.""" from pathlib import Path -from typing import Any, Literal, Union +from typing import Any, ClassVar, Literal, Union from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode from llama_index.core.vector_stores.types import VectorStoreQueryMode -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from metagpt.config2 import config +from metagpt.configs.embedding_config import EmbeddingType from metagpt.rag.interface import RAGObject @@ -31,7 +33,19 @@ class IndexRetrieverConfig(BaseRetrieverConfig): class FAISSRetrieverConfig(IndexRetrieverConfig): """Config for FAISS-based retrievers.""" - dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.") + dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.") + + _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = { + EmbeddingType.GEMINI: 768, + EmbeddingType.OLLAMA: 4096, + } + + @model_validator(mode="after") + def check_dimensions(self): + if self.dimensions == 0: + self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536) + + return self class BM25RetrieverConfig(IndexRetrieverConfig): diff --git a/metagpt/utils/async_helper.py b/metagpt/utils/async_helper.py index ee440ef44..cecb20c5d 100644 --- a/metagpt/utils/async_helper.py +++ b/metagpt/utils/async_helper.py @@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any: new_loop.call_soon_threadsafe(new_loop.stop) t.join() new_loop.close() + + +class NestAsyncio: + """Make asyncio event loop reentrant.""" + + is_applied = False + + @classmethod + def apply_once(cls): + """Ensures `nest_asyncio.apply()` is called only once.""" + if not cls.is_applied: + import nest_asyncio + + nest_asyncio.apply() + cls.is_applied = True diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 0ba2daa89..0ca22cf35 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -28,6 +28,7 @@ TOKEN_COSTS = { "gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12}, "gpt-4-0613": {"prompt": 0.06, "completion": 0.12}, "gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03}, + "gpt-4-turbo": {"prompt": 0.01, "completion": 0.03}, "gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03}, "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator @@ -147,6 +148,7 @@ FIREWORKS_GRADE_TOKEN_COSTS = { TOKEN_MAX = { "gpt-4-0125-preview": 128000, "gpt-4-turbo-preview": 128000, + "gpt-4-turbo": 128000, "gpt-4-1106-preview": 128000, "gpt-4-vision-preview": 128000, "gpt-4-1106-vision-preview": 128000, @@ -202,6 +204,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"): "gpt-4-32k-0314", "gpt-4-0613", "gpt-4-32k-0613", + "gpt-4-turbo", "gpt-4-turbo-preview", "gpt-4-0125-preview", "gpt-4-1106-preview", diff --git a/setup.py b/setup.py index c54ace90a..e43bf3ed0 100644 --- a/setup.py +++ b/setup.py @@ -32,12 +32,15 @@ extras_require = { "llama-index-core==0.10.15", "llama-index-embeddings-azure-openai==0.1.6", "llama-index-embeddings-openai==0.1.5", + "llama-index-embeddings-gemini==0.1.6", + "llama-index-embeddings-ollama==0.1.2", "llama-index-llms-azure-openai==0.1.4", "llama-index-readers-file==0.1.4", "llama-index-retrievers-bm25==0.1.3", "llama-index-vector-stores-faiss==0.1.1", "llama-index-vector-stores-elasticsearch==0.1.6", "llama-index-vector-stores-chroma==0.1.6", + "docx2txt==0.8", ], "android_assistant": ["pyshine==0.0.9", "opencv-python==4.6.0.66"], } diff --git a/tests/metagpt/rag/factories/test_embedding.py b/tests/metagpt/rag/factories/test_embedding.py index 1ded6b4a8..1a9e9b2c9 100644 --- a/tests/metagpt/rag/factories/test_embedding.py +++ b/tests/metagpt/rag/factories/test_embedding.py @@ -1,5 +1,6 @@ import pytest +from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from metagpt.rag.factories.embedding import RAGEmbeddingFactory @@ -10,30 +11,51 @@ class TestRAGEmbeddingFactory: self.embedding_factory = RAGEmbeddingFactory() @pytest.fixture - def mock_openai_embedding(self, mocker): + def mock_config(self, mocker): + return mocker.patch("metagpt.rag.factories.embedding.config") + + @staticmethod + def mock_openai_embedding(mocker): return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding") - @pytest.fixture - def mock_azure_embedding(self, mocker): + @staticmethod + def mock_azure_embedding(mocker): return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding") - def test_get_rag_embedding_openai(self, mock_openai_embedding): - # Exec - self.embedding_factory.get_rag_embedding(LLMType.OPENAI) + @staticmethod + def mock_gemini_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding") - # Assert - mock_openai_embedding.assert_called_once() + @staticmethod + def mock_ollama_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding") - def test_get_rag_embedding_azure(self, mock_azure_embedding): - # Exec - self.embedding_factory.get_rag_embedding(LLMType.AZURE) - - # Assert - mock_azure_embedding.assert_called_once() - - def test_get_rag_embedding_default(self, mocker, mock_openai_embedding): + @pytest.mark.parametrize( + ("mock_func", "embedding_type"), + [ + (mock_openai_embedding, LLMType.OPENAI), + (mock_azure_embedding, LLMType.AZURE), + (mock_openai_embedding, EmbeddingType.OPENAI), + (mock_azure_embedding, EmbeddingType.AZURE), + (mock_gemini_embedding, EmbeddingType.GEMINI), + (mock_ollama_embedding, EmbeddingType.OLLAMA), + ], + ) + def test_get_rag_embedding(self, mock_func, embedding_type, mocker): # Mock - mock_config = mocker.patch("metagpt.rag.factories.embedding.config") + mock = mock_func(mocker) + + # Exec + self.embedding_factory.get_rag_embedding(embedding_type) + + # Assert + mock.assert_called_once() + + def test_get_rag_embedding_default(self, mocker, mock_config): + # Mock + mock_openai_embedding = self.mock_openai_embedding(mocker) + + mock_config.embedding.api_type = None mock_config.llm.api_type = LLMType.OPENAI # Exec @@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory: # Assert mock_openai_embedding.assert_called_once() + + @pytest.mark.parametrize( + "model, embed_batch_size, expected_params", + [("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})], + ) + def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params): + # Mock + mock_config.embedding.model = model + mock_config.embedding.embed_batch_size = embed_batch_size + + # Setup + test_params = {} + + # Exec + self.embedding_factory._try_set_model_and_batch_size(test_params) + + # Assert + assert test_params == expected_params + + def test_resolve_embedding_type(self, mock_config): + # Mock + mock_config.embedding.api_type = EmbeddingType.OPENAI + + # Exec + embedding_type = self.embedding_factory._resolve_embedding_type() + + # Assert + assert embedding_type == EmbeddingType.OPENAI + + def test_resolve_embedding_type_exception(self, mock_config): + # Mock + mock_config.embedding.api_type = None + mock_config.llm.api_type = LLMType.GEMINI + + # Assert + with pytest.raises(TypeError): + self.embedding_factory._resolve_embedding_type() + + def test_raise_for_key(self): + with pytest.raises(ValueError): + self.embedding_factory._raise_for_key("key") From 60d34f4a50f5e39afb898f960fcaceb1eb756021 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 10 Apr 2024 14:22:52 +0800 Subject: [PATCH 10/12] use gpt-4-turbo as default --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 44fcfab18..8f5cc5393 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ # Check https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html ```yaml llm: api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options - model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview + model: "gpt-4-turbo" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview base_url: "https://api.openai.com/v1" # or forward url / other llm url api_key: "YOUR_API_KEY" ``` From caa13001634abaff2ca26c6bbf4b573a73035123 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 10 Apr 2024 14:52:48 +0800 Subject: [PATCH 11/12] make embedding configurable and add gpt-4-turbo. --- config/config2.example.yaml | 8 ++++---- metagpt/configs/embedding_config.py | 20 +++++++++++++++++++- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 7f4758acb..7cfd70347 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -17,10 +17,10 @@ llm: # For backward compatibility, if the embedding is not set and the llm's api_type is either openai or azure, the llm's config will be used. embedding: api_type: "" # openai / azure / gemini / ollama etc. Check EmbeddingType for more options. - base_url: "YOU_BASE_URL" - api_key: "YOU_API_KEY" - model: "YOU_MODEL" - api_version: "YOU_API_VERSION" + base_url: "" + api_key: "" + model: "" + api_version: "" embed_batch_size: 100 repair_llm_output: true # when the output is not a valid json, try to repair it diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py index 545c2a9cc..20de47999 100644 --- a/metagpt/configs/embedding_config.py +++ b/metagpt/configs/embedding_config.py @@ -14,7 +14,25 @@ class EmbeddingType(Enum): class EmbeddingConfig(YamlModel): - """Config for Embedding.""" + """Config for Embedding. + + Examples: + --------- + api_type: "openai" + api_key: "YOU_API_KEY" + + api_type: "azure" + api_key: "YOU_API_KEY" + base_url: "YOU_BASE_URL" + api_version: "YOU_API_VERSION" + + api_type: "gemini" + api_key: "YOU_API_KEY" + + api_type: "ollama" + base_url: "YOU_BASE_URL" + model: "YOU_MODEL" + """ api_type: Optional[EmbeddingType] = None api_key: Optional[str] = None From 549cb2d90b31f63690e8967f12fb7ed845058618 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 10 Apr 2024 15:54:08 +0800 Subject: [PATCH 12/12] format code --- examples/stream_output_via_api.py | 26 +++++++++++++++----------- metagpt/provider/zhipuai_api.py | 8 +++++++- metagpt/utils/stream_pipe.py | 19 +++++-------------- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/examples/stream_output_via_api.py b/examples/stream_output_via_api.py index 94709b5bf..5961f3a08 100644 --- a/examples/stream_output_via_api.py +++ b/examples/stream_output_via_api.py @@ -6,20 +6,18 @@ @File : stream_output_via_api.py @Description : Stream log information and communicate over the network via web api. """ +import asyncio import json import socket -import asyncio import threading - from contextvars import ContextVar -from flask import Flask, Response -from flask import request, jsonify, send_from_directory -from metagpt.logs import logger +from flask import Flask, Response, jsonify, request, send_from_directory + from metagpt.const import TUTORIAL_PATH -from metagpt.logs import set_llm_stream_logfunc -from metagpt.utils.stream_pipe import StreamPipe +from metagpt.logs import logger, set_llm_stream_logfunc from metagpt.roles.tutorial_assistant import TutorialAssistant +from metagpt.utils.stream_pipe import StreamPipe app = Flask(__name__) @@ -39,12 +37,18 @@ def write_tutorial(message): def thread_run(idea: str, stream_pipe: StreamPipe = None): """ - Convert asynchronous function to thread function + Convert asynchronous function to thread function """ asyncio.run(main(idea, stream_pipe)) stream_pipe = StreamPipe() - thread = threading.Thread(target=thread_run, args=(message["content"], stream_pipe,)) + thread = threading.Thread( + target=thread_run, + args=( + message["content"], + stream_pipe, + ), + ) thread.start() while thread.is_alive(): @@ -52,7 +56,7 @@ def write_tutorial(message): yield stream_pipe.msg2stream(msg) -@app.route('/v1/chat/completions', methods=['POST']) +@app.route("/v1/chat/completions", methods=["POST"]) def completions(): """ data: { @@ -87,7 +91,7 @@ def completions(): return jsonify({"status": 400, "msg": "No suitable agent found."}) -@app.route('/download/') +@app.route("/download/") def download_file(filename): return send_from_directory(TUTORIAL_PATH, filename, as_attachment=True) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index a45081fcf..acac44aaf 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -45,7 +45,13 @@ class ZhiPuAILLM(BaseLLM): def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: max_tokens = self.config.max_token if self.config.max_token > 0 else 1024 temperature = self.config.temperature if self.config.temperature > 0.0 else 0.3 - kwargs = {"model": self.model, "max_tokens": max_tokens, "messages": messages, "stream": stream, "temperature": temperature} + kwargs = { + "model": self.model, + "max_tokens": max_tokens, + "messages": messages, + "stream": stream, + "temperature": temperature, + } return kwargs def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict: diff --git a/metagpt/utils/stream_pipe.py b/metagpt/utils/stream_pipe.py index d3d3cff32..4c4485158 100644 --- a/metagpt/utils/stream_pipe.py +++ b/metagpt/utils/stream_pipe.py @@ -5,8 +5,8 @@ # @Version : None # @Description : None -import time import json +import time from multiprocessing import Pipe @@ -21,17 +21,8 @@ class StreamPipe: "model": "gpt-3.5-turbo-0125", "system_fingerprint": "fp_3bc1b5746c", "choices": [ - { - "index": 0, - "delta": - { - "role": "assistant", - "content": "content" - }, - "logprobs": None, - "finish_reason": None - } - ] + {"index": 0, "delta": {"role": "assistant", "content": "content"}, "logprobs": None, "finish_reason": None} + ], } def set_message(self, msg): @@ -44,6 +35,6 @@ class StreamPipe: return None def msg2stream(self, msg): - self.format_data['created'] = int(time.time()) - self.format_data['choices'][0]['delta']['content'] = msg + self.format_data["created"] = int(time.time()) + self.format_data["choices"][0]["delta"]["content"] = msg return f"data: {json.dumps(self.format_data, ensure_ascii=False)}\n".encode("utf-8")