diff --git a/examples/flask_web_api.py b/examples/flask_web_api.py new file mode 100644 index 000000000..e87455ed1 --- /dev/null +++ b/examples/flask_web_api.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/3/27 9:44 +@Author : leiwu30 +@File : flask_web_api.py +@Description : Stream log information and communicate over the network via web api. +""" +import os +import json +import socket +import asyncio +import threading + +from metagpt.utils.stream_pipe import StreamPipe +from metagpt.roles.tutorial_assistant import TutorialAssistant +from metagpt.const import METAGPT_ROOT + +from flask import Flask, Response +from flask import request, jsonify, send_from_directory + +app = Flask(__name__) + + +def write_tutorial(message): + async def main(idea, stream_pipe): + role = TutorialAssistant(stream_pipe=stream_pipe) + await role.run(idea) + + def thread_run(idea: str, stream_pipe: StreamPipe = None): + """ + Convert asynchronous function to thread function + """ + asyncio.run(main(idea, stream_pipe)) + + stream_pipe = StreamPipe() + thread = threading.Thread(target=thread_run, args=(message["content"], stream_pipe,)) + thread.start() + + while not stream_pipe.finish: + stream_pipe.wait() + msg = stream_pipe.get_message() + yield stream_pipe.msg2stream(msg) + + # 文件位置 + md_file = stream_pipe.get_k_message("file_name") + + yield stream_pipe.msg2stream( + f"\n\n[{os.path.basename(md_file)}](http://{server_address}:{server_port}/download/{md_file})") + + +@app.route('/v1/chat/completions', methods=['POST']) +def completions(): + """ + data: { + "model": "write_tutorial", + "stream": true, + "messages": [ + { + "role": "user", + "content": "Write a tutorial about MySQL" + } + ] + } + """ + + data = json.loads(request.data) + print(json.dumps(data, indent=4, ensure_ascii=False)) + + # Non-streaming interfaces are not supported yet + stream_type = True if "stream" in data.keys() and data["stream"] else False + if not stream_type: + return 
jsonify({"status": 200}) + + # Only accept the last user information + last_message = data["messages"][-1] + model = data["model"] + + # write_tutorial + if model == "write_tutorial": + return Response(write_tutorial(last_message), mimetype="text/plain") + else: + return jsonify({"status": 200}) + # return Response(event_stream(), mimetype="text/plain") + + +@app.route('/download/<path:filename>') +def download_file(filename): + return send_from_directory(METAGPT_ROOT, filename, as_attachment=True) + + +if __name__ == "__main__": + """ + curl https://$server_address:$server_port/v1/chat/completions -X POST -d '{ + "model": "write_tutorial", + "stream": true, + "messages": [ + { + "role": "user", + "content": "Write a tutorial about MySQL" + } + ] + }' + """ + server_port = 7860 + server_address = socket.gethostbyname(socket.gethostname()) + + app.run(port=server_port, host=server_address) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 1b93213f7..53fdd59f7 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -23,6 +23,7 @@ from metagpt.schema import ( TestingContext, ) from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.stream_pipe import StreamPipe class Action(SerializationMixin, ContextMixin, BaseModel): @@ -35,6 +36,7 @@ class Action(SerializationMixin, ContextMixin, BaseModel): prefix: str = "" # aask*时会加上prefix,作为system_message desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) + stream_pipe: Optional[StreamPipe] = None @property def repo(self) -> ProjectRepo: @@ -90,6 +92,8 @@ class Action(SerializationMixin, ContextMixin, BaseModel): async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: """Append default prefix""" + if self.stream_pipe and not self.llm.stream_pipe: + self.llm.stream_pipe = self.stream_pipe return await self.llm.aask(prompt, system_msgs) async def _run_action_node(self, *args, **kwargs): diff --git
a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index 1aeacbe83..00f66b72f 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -62,6 +62,9 @@ class AnthropicLLM(BaseLLM): elif event_type == "content_block_delta": content = event.delta.text log_llm_stream(content) + if self.stream_pipe: + self.stream_pipe.set_message(content) + collected_content.append(content) elif event_type == "message_delta": usage.output_tokens = event.usage.output_tokens # update final output_tokens diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index db2757ec3..1b4be24e9 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -28,6 +28,7 @@ from metagpt.logs import logger from metagpt.schema import Message from metagpt.utils.common import log_and_reraise from metagpt.utils.cost_manager import CostManager, Costs +from metagpt.utils.stream_pipe import StreamPipe class BaseLLM(ABC): @@ -42,6 +43,7 @@ class BaseLLM(ABC): cost_manager: Optional[CostManager] = None model: Optional[str] = None # deprecated pricing_plan: Optional[str] = None + stream_pipe: Optional[StreamPipe] = None @abstractmethod def __init__(self, config: LLMConfig): diff --git a/metagpt/provider/dashscope_api.py b/metagpt/provider/dashscope_api.py index 82224e893..10845ed87 100644 --- a/metagpt/provider/dashscope_api.py +++ b/metagpt/provider/dashscope_api.py @@ -221,6 +221,8 @@ class DashScopeLLM(BaseLLM): content = chunk.output.choices[0]["message"]["content"] usage = dict(chunk.usage) # each chunk has usage log_llm_stream(content) + if self.stream_pipe: + self.stream_pipe.set_message(content) collected_content.append(content) log_llm_stream("\n") self._update_costs(usage) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index e4b3a3f17..49a533792 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -149,6 +149,8 @@ class 
GeminiLLM(BaseLLM): logger.warning(f"messages: {messages}\nerrors: {e}\n{BlockedPromptException(str(chunk))}") raise BlockedPromptException(str(chunk)) log_llm_stream(content) + if self.stream_pipe: + self.stream_pipe.set_message(content) collected_content.append(content) log_llm_stream("\n") diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 2913eb1dd..450346ab7 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -83,6 +83,8 @@ class OllamaLLM(BaseLLM): content = self.get_choice_text(chunk) collected_content.append(content) log_llm_stream(content) + if self.stream_pipe: + self.stream_pipe.set_message(content) else: # stream finished usage = self.get_usage(chunk) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index dbfed72df..2cfd86cfb 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -87,6 +87,9 @@ class OpenAILLM(BaseLLM): chunk.choices[0].finish_reason if chunk.choices and hasattr(chunk.choices[0], "finish_reason") else None ) log_llm_stream(chunk_message) + if self.stream_pipe: + self.stream_pipe.set_message(chunk_message) + collected_messages.append(chunk_message) if finish_reason: if hasattr(chunk, "usage"): diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py index 3d78c8bfc..c45a03d9b 100644 --- a/metagpt/provider/qianfan_api.py +++ b/metagpt/provider/qianfan_api.py @@ -124,6 +124,8 @@ class QianFanLLM(BaseLLM): content = chunk.body.get("result", "") usage = chunk.body.get("usage", {}) log_llm_stream(content) + if self.stream_pipe: + self.stream_pipe.set_message(content) collected_content.append(content) log_llm_stream("\n") diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 2db441991..d42f34fdd 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -73,6 +73,8 @@ class ZhiPuAILLM(BaseLLM): content = self.get_choice_delta_text(chunk) 
collected_content.append(content) log_llm_stream(content) + if self.stream_pipe: + self.stream_pipe.set_message(content) log_llm_stream("\n") diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index e0f8a7ea6..599662145 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -39,6 +39,8 @@ from metagpt.strategy.planner import Planner from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output +from metagpt.utils.stream_pipe import StreamPipe + if TYPE_CHECKING: from metagpt.environment import Environment # noqa: F401 @@ -139,6 +141,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel): role_id: str = "" states: list[str] = [] + stream_pipe: Optional[StreamPipe] = None + # scenarios to set action system_prompt: # 1. `__init__` while using Role(actions=[...]) # 2. add action to role while using `role.set_action(action)` diff --git a/metagpt/roles/tutorial_assistant.py b/metagpt/roles/tutorial_assistant.py index 6cf3a6469..4da419cc0 100644 --- a/metagpt/roles/tutorial_assistant.py +++ b/metagpt/roles/tutorial_assistant.py @@ -10,7 +10,7 @@ from datetime import datetime from typing import Dict from metagpt.actions.write_tutorial import WriteContent, WriteDirectory -from metagpt.const import TUTORIAL_PATH +from metagpt.const import TUTORIAL_PATH, METAGPT_ROOT from metagpt.logs import logger from metagpt.roles.role import Role, RoleReactMode from metagpt.schema import Message @@ -40,7 +40,7 @@ class TutorialAssistant(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - self.set_actions([WriteDirectory(language=self.language)]) + self.set_actions([WriteDirectory(language=self.language, stream_pipe=self.stream_pipe)]) self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value) async def _handle_directory(self, titles: Dict) -> Message: @@ -58,7 +58,7 @@ class TutorialAssistant(Role): 
self.total_content += f"# {self.main_title}" actions = list() for first_dir in titles.get("directory"): - actions.append(WriteContent(language=self.language, directory=first_dir)) + actions.append(WriteContent(language=self.language, directory=first_dir, stream_pipe=self.stream_pipe)) key = list(first_dir.keys())[0] directory += f"- {key}\n" for second_dir in first_dir[key]: @@ -91,4 +91,8 @@ class TutorialAssistant(Role): root_path = TUTORIAL_PATH / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") await File.write(root_path, f"{self.main_title}.md", self.total_content.encode("utf-8")) msg.content = str(root_path / f"{self.main_title}.md") + + if self.stream_pipe: + self.stream_pipe.set_k_message("file_name", msg.content.replace(str(METAGPT_ROOT), "")) + self.stream_pipe.with_finish() return msg diff --git a/metagpt/utils/stream_pipe.py b/metagpt/utils/stream_pipe.py new file mode 100644 index 000000000..5fa4556ea --- /dev/null +++ b/metagpt/utils/stream_pipe.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/3/27 10:00 +# @Author : leiwu30 +# @File : stream_pipe.py +# @Version : None +# @Description : None + +import time +import json +from multiprocessing import Pipe + + +class StreamPipe: + parent_conn, child_conn = Pipe()  # NOTE(review): class attribute — every StreamPipe instance shares this one Pipe; confirm single-instance use or move to __init__ + + variable: dict = {} + finish: bool = False + + format_data = { + "id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur", + "object": "chat.completion.chunk", + "created": 1711361191, + "model": "gpt-3.5-turbo-0125", + "system_fingerprint": "fp_3bc1b5746c", + "choices": [ + { + "index": 0, + "delta": + { + "role": "assistant", + "content": "content" + }, + "logprobs": None, + "finish_reason": None + } + ] + } + + def set_message(self, msg): + self.parent_conn.send(msg) + + def wait(self): + pass + + def get_message(self): + if self.child_conn.poll(timeout=3): + return self.child_conn.recv() + else: + return None + + def set_k_message(self, k, msg): + self.variable[k] = msg + + def get_k_message(self, k): + return self.variable[k] + + def
msg2stream(self, msg): + self.format_data['created'] = int(time.time()) + self.format_data['choices'][0]['delta']['content'] = msg + return f"data: {json.dumps(self.format_data, ensure_ascii=False)}\n".encode("utf-8") + + def with_finish(self, timeout: int = 3): + """ + Args: + timeout: /s + """ + # Pipe is not empty waiting for pipe condition + # while self.child_conn.poll(timeout=timeout): + # time.sleep(0.5) + self.finish = True diff --git a/requirements.txt b/requirements.txt index da8aa26b2..f5e9b15c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -69,4 +69,5 @@ imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py qianfan==0.3.2 dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation -jieba==0.42.1 # for tool recommendation \ No newline at end of file +jieba==0.42.1 # for tool recommendation +flask # for web api \ No newline at end of file