Merge pull request #1118 from 2495165664/main

feat(core): Add stream data return and reception
This commit is contained in:
Alexander Wu 2024-04-08 20:40:37 +08:00 committed by GitHub
commit db65554c49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 162 additions and 0 deletions

View file

@ -0,0 +1,113 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/3/27 9:44
@Author : leiwu30
@File : stream_output_via_api.py
@Description : Stream log information and communicate over the network via web api.
"""
import json
import socket
import asyncio
import threading
from contextvars import ContextVar
from flask import Flask, Response
from flask import request, jsonify, send_from_directory
from metagpt.logs import logger
from metagpt.const import TUTORIAL_PATH
from metagpt.logs import set_llm_stream_logfunc
from metagpt.utils.stream_pipe import StreamPipe
from metagpt.roles.tutorial_assistant import TutorialAssistant
app = Flask(__name__)
def stream_pipe_log(content):
print(content, end="")
stream_pipe = stream_pipe_var.get(None)
if stream_pipe:
stream_pipe.set_message(content)
def write_tutorial(message):
async def main(idea, stream_pipe):
stream_pipe_var.set(stream_pipe)
role = TutorialAssistant()
await role.run(idea)
def thread_run(idea: str, stream_pipe: StreamPipe = None):
"""
Convert asynchronous function to thread function
"""
asyncio.run(main(idea, stream_pipe))
stream_pipe = StreamPipe()
thread = threading.Thread(target=thread_run, args=(message["content"], stream_pipe,))
thread.start()
while thread.is_alive():
msg = stream_pipe.get_message()
yield stream_pipe.msg2stream(msg)
@app.route('/v1/chat/completions', methods=['POST'])
def completions():
"""
data: {
"model": "write_tutorial",
"stream": true,
"messages": [
{
"role": "user",
"content": "Write a tutorial about MySQL"
}
]
}
"""
data = json.loads(request.data)
logger.info(json.dumps(data, indent=4, ensure_ascii=False))
# Non-streaming interfaces are not supported yet
stream_type = True if data.get("stream") else False
if not stream_type:
return jsonify({"status": 400, "msg": "Non-streaming requests are not supported, please use `stream=True`."})
# Only accept the last user information
# openai['model'] ~ MetaGPT['agent']
last_message = data["messages"][-1]
model = data["model"]
# write_tutorial
if model == "write_tutorial":
return Response(write_tutorial(last_message), mimetype="text/plain")
else:
return jsonify({"status": 400, "msg": "No suitable agent found."})
@app.route('/download/<path:filename>')
def download_file(filename):
return send_from_directory(TUTORIAL_PATH, filename, as_attachment=True)
if __name__ == "__main__":
"""
curl https://$server_address:$server_port/v1/chat/completions -X POST -d '{
"model": "write_tutorial",
"stream": true,
"messages": [
{
"role": "user",
"content": "Write a tutorial about MySQL"
}
]
}'
"""
server_port = 7860
server_address = socket.gethostbyname(socket.gethostname())
set_llm_stream_logfunc(stream_pipe_log)
stream_pipe_var: ContextVar[StreamPipe] = ContextVar("stream_pipe")
app.run(port=server_port, host=server_address)

View file

@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# @Time : 2024/3/27 10:00
# @Author : leiwu30
# @File : stream_pipe.py
# @Version : None
# @Description : None
import time
import json
from multiprocessing import Pipe
class StreamPipe:
parent_conn, child_conn = Pipe()
finish: bool = False
format_data = {
"id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur",
"object": "chat.completion.chunk",
"created": 1711361191,
"model": "gpt-3.5-turbo-0125",
"system_fingerprint": "fp_3bc1b5746c",
"choices": [
{
"index": 0,
"delta":
{
"role": "assistant",
"content": "content"
},
"logprobs": None,
"finish_reason": None
}
]
}
def set_message(self, msg):
self.parent_conn.send(msg)
def get_message(self, timeout: int = 3):
if self.child_conn.poll(timeout):
return self.child_conn.recv()
else:
return None
def msg2stream(self, msg):
self.format_data['created'] = int(time.time())
self.format_data['choices'][0]['delta']['content'] = msg
return f"data: {json.dumps(self.format_data, ensure_ascii=False)}\n".encode("utf-8")