#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2024/3/27 9:44
@Author  : leiwu30
@File    : flask_web_api.py
@Description : Stream log information and communicate over the network via web api.
"""
import os
import json
import socket
import asyncio
import threading
from contextvars import ContextVar

from flask import Flask, Response, request, jsonify, send_from_directory

from metagpt.team import Team
from metagpt.const import METAGPT_ROOT
from metagpt.logs import set_llm_stream_logfunc
from metagpt.utils.stream_pipe import StreamPipe
from metagpt.roles.tutorial_assistant import TutorialAssistant

app = Flask(__name__)

# Declared at module level (not under the __main__ guard, as before) so that
# stream_pipe_log() / write_tutorial() do not raise NameError when this module
# is imported instead of run as a script. Each worker thread binds its own
# StreamPipe into this ContextVar.
stream_pipe_var: ContextVar[StreamPipe] = ContextVar("stream_pipe")


def stream_pipe_log(content: str) -> None:
    """LLM stream-log hook: echo to stdout and forward to the current pipe, if any."""
    print(content, end="")
    stream_pipe = stream_pipe_var.get(None)
    if stream_pipe:
        stream_pipe.set_message(content)


def write_tutorial(message: dict):
    """Run TutorialAssistant in a worker thread and yield its log stream.

    :param message: last chat message, e.g. {"role": "user", "content": "..."}.
    :return: generator of OpenAI-style "data: ..." byte chunks.
    """

    async def main(idea: str, stream_pipe: StreamPipe):
        # Bind the pipe in this thread's context so stream_pipe_log can find it.
        stream_pipe_var.set(stream_pipe)
        role = TutorialAssistant(stream_pipe=stream_pipe)
        await role.run(idea)

    def thread_run(idea: str, stream_pipe: StreamPipe = None):
        """Convert the asynchronous entry point into a plain thread target."""
        asyncio.run(main(idea, stream_pipe))

    stream_pipe = StreamPipe()
    thread = threading.Thread(target=thread_run, args=(message["content"], stream_pipe))
    thread.start()

    while thread.is_alive():
        msg = stream_pipe.get_message()
        # get_message() returns None when its poll times out; skip those instead
        # of emitting a chunk whose "content" is null.
        if msg is not None:
            yield stream_pipe.msg2stream(msg)

    # File location of the generated tutorial (optional download link):
    # md_file = stream_pipe.get_k_message("file_name")
    # yield stream_pipe.msg2stream(
    #     f"\n\n[{os.path.basename(md_file)}](http://{server_address}:{server_port}/download/{md_file})")


@app.route('/v1/chat/completions', methods=['POST'])
def completions():
    """OpenAI-compatible chat endpoint (streaming only).

    data: {
        "model": "write_tutorial",
        "stream": true,
        "messages": [
            {
                "role": "user",
                "content": "Write a tutorial about MySQL"
            }
        ]
    }
    """
    data = json.loads(request.data)
    print(json.dumps(data, indent=4, ensure_ascii=False))

    # Non-streaming interfaces are not supported yet
    stream_type = bool(data.get("stream"))
    if not stream_type:
        return jsonify({"status": 200})

    # Only accept the last user message
    last_message = data["messages"][-1]
    model = data["model"]

    # write_tutorial
    if model == "write_tutorial":
        return Response(write_tutorial(last_message), mimetype="text/plain")
    return jsonify({"status": 200})


# A URL converter is required so Flask can pass `filename` to the view; the
# bare '/download/' rule never supplied the argument. `path` (not `string`)
# because generated file paths may contain slashes.
@app.route('/download/<path:filename>')
def download_file(filename):
    """Serve a generated file from METAGPT_ROOT as an attachment."""
    return send_from_directory(METAGPT_ROOT, filename, as_attachment=True)


if __name__ == "__main__":
    """
    curl http://$server_address:$server_port/v1/chat/completions -X POST -d '{
        "model": "write_tutorial",
        "stream": true,
        "messages": [
            {
                "role": "user",
                "content": "Write a tutorial about MySQL"
            }
        ]
    }'
    """
    server_port = 7860
    server_address = socket.gethostbyname(socket.gethostname())

    set_llm_stream_logfunc(stream_pipe_log)
    app.run(port=server_port, host=server_address)
# @Author  : leiwu30
# @File    : stream_pipe.py
# @Version : None
# @Description : Message pipe between a worker thread and a streaming HTTP
#                response, with helpers to render messages as OpenAI-style
#                chat-completion stream chunks.

import time
import json
from multiprocessing import Pipe


class StreamPipe:
    """One-way message channel plus OpenAI stream-chunk formatting.

    All state is created per instance in __init__. The previous version kept
    the Pipe ends, the key/value store and the chunk template as CLASS
    attributes, so every StreamPipe shared one pipe and one mutable dict —
    concurrent requests would consume each other's messages and clobber each
    other's chunk content.
    """

    def __init__(self):
        # Writer end is used by set_message(), reader end by get_message().
        self.parent_conn, self.child_conn = Pipe()
        # Out-of-band key/value store (e.g. generated file name).
        # Was annotated `list` but initialized as a dict; it is a dict.
        self.variable: dict = {}
        self.finish: bool = False
        # Template for one OpenAI-compatible streaming chunk; `created` and
        # the delta content are overwritten by msg2stream().
        self.format_data = {
            "id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur",
            "object": "chat.completion.chunk",
            "created": 1711361191,
            "model": "gpt-3.5-turbo-0125",
            "system_fingerprint": "fp_3bc1b5746c",
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "role": "assistant",
                        "content": "content",
                    },
                    "logprobs": None,
                    "finish_reason": None,
                }
            ],
        }

    def set_message(self, msg):
        """Send one message into the pipe."""
        self.parent_conn.send(msg)

    def get_message(self, timeout: int = 3):
        """Receive one message, or None if nothing arrives within `timeout` seconds."""
        if self.child_conn.poll(timeout):
            return self.child_conn.recv()
        return None

    def set_k_message(self, k, msg):
        """Store an out-of-band value under key `k`."""
        self.variable[k] = msg

    def get_k_message(self, k):
        """Fetch the out-of-band value stored under key `k` (KeyError if absent)."""
        return self.variable[k]

    def msg2stream(self, msg) -> bytes:
        """Render `msg` as one `data: {...}\\n` OpenAI streaming chunk (UTF-8 bytes)."""
        self.format_data['created'] = int(time.time())
        self.format_data['choices'][0]['delta']['content'] = msg
        return f"data: {json.dumps(self.format_data, ensure_ascii=False)}\n".encode("utf-8")