diff --git a/.agent-store-config.yaml.example b/.agent-store-config.yaml.example new file mode 100644 index 000000000..037a44ed4 --- /dev/null +++ b/.agent-store-config.yaml.example @@ -0,0 +1,9 @@ +role: + name: Teacher # Referenced the `Teacher` in `metagpt/roles/teacher.py`. + module: metagpt.roles.teacher # Referenced `metagpt/roles/teacher.py`. + skills: # Refer to the skill `name` of the published skill in `.well-known/skills.yaml`. + - name: text_to_speech + description: Text-to-speech + - name: text_to_image + description: Create a drawing based on the text. + diff --git a/.gitignore b/.gitignore index c12506b0e..93e24ba48 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,9 @@ workspace/* tmp metagpt/roles/idea_agent.py .aider* +*.bak + +# output folder +output +tmp.png + diff --git a/.well-known/ai-plugin.json b/.well-known/ai-plugin.json new file mode 100644 index 000000000..ac0178fd0 --- /dev/null +++ b/.well-known/ai-plugin.json @@ -0,0 +1,18 @@ +{ + "schema_version": "v1", + "name_for_model": "text processing tools", + "name_for_human": "MetaGPT Text Plugin", + "description_for_model": "Plugins for text processing, including text-to-speech, text-to-image, text-to-embedding, text summarization, text-to-code, vector similarity calculation, web content crawling, and more.", + "description_for_human": "Plugins for text processing, including text-to-speech, text-to-image, text-to-embedding, text summarization, text-to-code, vector similarity calculation, web content crawling, and more.", + "auth": { + "type": "none" + }, + "api": { + "type": "openapi", + "url": "https://github.com/iorisa/MetaGPT/blob/feature/assistant_role/.well-known/metagpt_oas3_api.yaml", + "has_user_authentication": false + }, + "logo_url": "https://github.com/geekan/MetaGPT/blob/main/docs/resources/MetaGPT-logo.png", + "contact_email": "mashenquan@fuzhi.cn", + "legal_info_url": "https://github.com/geekan/MetaGPT/blob/main/docs/README_CN.md" +} \ No newline at end of file diff 
--git a/.well-known/metagpt_oas3_api.yaml b/.well-known/metagpt_oas3_api.yaml new file mode 100644 index 000000000..0a702e8b6 --- /dev/null +++ b/.well-known/metagpt_oas3_api.yaml @@ -0,0 +1,338 @@ +openapi: "3.0.0" + +info: + title: "MetaGPT Export OpenAPIs" + version: "1.0" +servers: + - url: "/oas3" + variables: + port: + default: '8080' + description: HTTP service port + +paths: + /tts/azsure: + x-prerequisite: + configurations: + AZURE_TTS_SUBSCRIPTION_KEY: + type: string + description: "For more details, check out: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)" + AZURE_TTS_REGION: + type: string + description: "For more details, check out: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)" + required: + allOf: + - AZURE_TTS_SUBSCRIPTION_KEY + - AZURE_TTS_REGION + post: + summary: "Convert Text to Base64-encoded .wav File Stream" + description: "For more details, check out: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)" + operationId: azure_tts.oas3_azsure_tts + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - text + properties: + text: + type: string + description: Text to convert + lang: + type: string + description: The language code or locale, e.g., en-US (English - United States) + default: "zh-CN" + voice: + type: string + description: "Voice style, see: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts), [Voice Gallery](https://speech.microsoft.com/portal/voicegallery)" + default: "zh-CN-XiaomoNeural" + style: + type: string + description: "Speaking style to express different emotions. 
For more details, checkout: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)" + default: "affectionate" + role: + type: string + description: "Role to specify age and gender. For more details, checkout: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)" + default: "Girl" + subscription_key: + type: string + description: "Key used to access Azure AI service API, see: [Azure Portal](https://portal.azure.com/) > `Resource Management` > `Keys and Endpoint`" + default: "" + region: + type: string + description: "Location (or region) of your resource, see: [Azure Portal](https://portal.azure.com/) > `Resource Management` > `Keys and Endpoint`" + default: "" + responses: + '200': + description: "Base64-encoded .wav file data if successful, otherwise an empty string." + content: + application/json: + schema: + type: object + properties: + wav_data: + type: string + format: base64 + '400': + description: "Bad Request" + '500': + description: "Internal Server Error" + + /tts/iflytek: + x-prerequisite: + configurations: + IFLYTEK_APP_ID: + type: string + description: "Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`" + IFLYTEK_API_KEY: + type: string + description: "WebAPI argument, see: `https://console.xfyun.cn/services/tts`" + IFLYTEK_API_SECRET: + type: string + description: "WebAPI argument, see: `https://console.xfyun.cn/services/tts`" + required: + allOf: + - IFLYTEK_APP_ID + - IFLYTEK_API_KEY + - IFLYTEK_API_SECRET + post: + summary: "Convert Text to Base64-encoded .mp3 File Stream" + description: "For more details, check out: [iFlyTek](https://console.xfyun.cn/services/tts)" + operationId: iflytek_tts.oas3_iflytek_tts + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - text + properties: + text: + type: string + description: Text 
to convert + voice: + type: string + description: "Voice style, see: [iFlyTek Text-to_Speech](https://www.xfyun.cn/doc/tts/online_tts/API.html#%E6%8E%A5%E5%8F%A3%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B)" + default: "xiaoyan" + app_id: + type: string + description: "Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`" + default: "" + api_key: + type: string + description: "WebAPI argument, see: `https://console.xfyun.cn/services/tts`" + default: "" + api_secret: + type: string + description: "WebAPI argument, see: `https://console.xfyun.cn/services/tts`" + default: "" + responses: + '200': + description: "Base64-encoded .mp3 file data if successful, otherwise an empty string." + content: + application/json: + schema: + type: object + properties: + wav_data: + type: string + format: base64 + '400': + description: "Bad Request" + '500': + description: "Internal Server Error" + + + /txt2img/openai: + x-prerequisite: + configurations: + OPENAI_API_KEY: + type: string + description: "OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`" + required: + allOf: + - OPENAI_API_KEY + post: + summary: "Convert Text to Base64-encoded Image Data Stream" + operationId: openai_text_to_image.oas3_openai_text_to_image + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + text: + type: string + description: "The text used for image conversion." + size_type: + type: string + enum: ["256x256", "512x512", "1024x1024"] + default: "1024x1024" + description: "Size of the generated image." + openai_api_key: + type: string + default: "" + description: "OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`" + responses: + '200': + description: "Base64-encoded image data." 
+ content: + application/json: + schema: + type: object + properties: + image_data: + type: string + format: base64 + '400': + description: "Bad Request" + '500': + description: "Internal Server Error" + /txt2embedding/openai: + x-prerequisite: + configurations: + OPENAI_API_KEY: + type: string + description: "OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`" + required: + allOf: + - OPENAI_API_KEY + post: + summary: Text to embedding + operationId: openai_text_to_embedding.oas3_openai_text_to_embedding + description: Retrieve an embedding for the provided text using the OpenAI API. + requestBody: + content: + application/json: + schema: + type: object + properties: + input: + type: string + description: The text used for embedding. + model: + type: string + description: "ID of the model to use. For more details, checkout: [models](https://api.openai.com/v1/models)" + enum: + - text-embedding-ada-002 + responses: + "200": + description: Successful response + content: + application/json: + schema: + $ref: "#/components/schemas/ResultEmbedding" + "4XX": + description: Client error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "5XX": + description: Server error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /txt2image/metagpt: + x-prerequisite: + configurations: + METAGPT_TEXT_TO_IMAGE_MODEL_URL: + type: string + description: "Model url." + required: + allOf: + - METAGPT_TEXT_TO_IMAGE_MODEL_URL + post: + summary: "Text to Image" + description: "Generate an image from the provided text using the MetaGPT Text-to-Image API." + operationId: metagpt_text_to_image.oas3_metagpt_text_to_image + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - text + properties: + text: + type: string + description: "The text used for image conversion." 
+ size_type: + type: string + enum: ["512x512", "512x768"] + default: "512x512" + description: "Size of the generated image." + model_url: + type: string + description: "Model reset API URL for text-to-image." + default: "" + responses: + '200': + description: "Base64-encoded image data." + content: + application/json: + schema: + type: object + properties: + image_data: + type: string + format: base64 + '400': + description: "Bad Request" + '500': + description: "Internal Server Error" + +components: + schemas: + Embedding: + type: object + description: Represents an embedding vector returned by the embedding endpoint. + properties: + object: + type: string + example: embedding + embedding: + type: array + items: + type: number + example: [0.0023064255, -0.009327292, ...] + index: + type: integer + example: 0 + Usage: + type: object + properties: + prompt_tokens: + type: integer + example: 8 + total_tokens: + type: integer + example: 8 + ResultEmbedding: + type: object + properties: + object: + type: string + example: result_embedding + data: + type: array + items: + $ref: "#/components/schemas/Embedding" + model: + type: string + example: text-embedding-ada-002 + usage: + $ref: "#/components/schemas/Usage" + Error: + type: object + properties: + error: + type: string + example: An error occurred \ No newline at end of file diff --git a/.well-known/openapi.yaml b/.well-known/openapi.yaml new file mode 100644 index 000000000..bc291b7db --- /dev/null +++ b/.well-known/openapi.yaml @@ -0,0 +1,35 @@ +openapi: "3.0.0" + +info: + title: Hello World + version: "1.0" +servers: + - url: /openapi + +paths: + /greeting/{name}: + post: + summary: Generate greeting + description: Generates a greeting message. + operationId: hello.post_greeting + responses: + 200: + description: greeting response + content: + text/plain: + schema: + type: string + example: "hello dave!" + parameters: + - name: name + in: path + description: Name of the person to greet. 
+ required: true + schema: + type: string + example: "dave" + requestBody: + content: + application/json: + schema: + type: object \ No newline at end of file diff --git a/.well-known/skills.yaml b/.well-known/skills.yaml new file mode 100644 index 000000000..c19a9501e --- /dev/null +++ b/.well-known/skills.yaml @@ -0,0 +1,161 @@ +skillapi: "0.1.0" + +info: + title: "Agent Skill Specification" + version: "1.0" + +entities: + Assistant: + summary: assistant + description: assistant + skills: + - name: text_to_speech + description: Generate a voice file from the input text, text-to-speech + id: text_to_speech.text_to_speech + x-prerequisite: + configurations: + AZURE_TTS_SUBSCRIPTION_KEY: + type: string + description: "For more details, check out: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)" + AZURE_TTS_REGION: + type: string + description: "For more details, check out: [Azure Text-to_Speech](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts)" + IFLYTEK_APP_ID: + type: string + description: "Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`" + IFLYTEK_API_KEY: + type: string + description: "WebAPI argument, see: `https://console.xfyun.cn/services/tts`" + IFLYTEK_API_SECRET: + type: string + description: "WebAPI argument, see: `https://console.xfyun.cn/services/tts`" + required: + oneOf: + - allOf: + - AZURE_TTS_SUBSCRIPTION_KEY + - AZURE_TTS_REGION + - allOf: + - IFLYTEK_APP_ID + - IFLYTEK_API_KEY + - IFLYTEK_API_SECRET + parameters: + text: + description: 'The text used for voice conversion.' + required: true + type: string + lang: + description: 'The value can contain a language code such as en (English), or a locale such as en-US (English - United States).' 
+ type: string + enum: + - English + - Chinese + default: Chinese + voice: + description: Name of voice styles + type: string + default: zh-CN-XiaomoNeural + style: + type: string + description: Speaking style to express different emotions like cheerfulness, empathy, and calm. + enum: + - affectionate + - angry + - calm + - cheerful + - depressed + - disgruntled + - embarrassed + - envious + - fearful + - gentle + - sad + - serious + default: affectionate + role: + type: string + description: With roles, the same voice can act as a different age and gender. + enum: + - Girl + - Boy + - OlderAdultFemale + - OlderAdultMale + - SeniorFemale + - SeniorMale + - YoungAdultFemale + - YoungAdultMale + default: Girl + examples: + - ask: 'A girl says "hello world"' + answer: 'text_to_speech(text="hello world", role="Girl")' + - ask: 'A boy affectionate says "hello world"' + answer: 'text_to_speech(text="hello world", role="Boy", style="affectionate")' + - ask: 'A boy says "你好"' + answer: 'text_to_speech(text="你好", role="Boy", lang="Chinese")' + returns: + type: string + format: base64 + + - name: text_to_image + description: Create a drawing based on the text. + id: text_to_image.text_to_image + x-prerequisite: + configurations: + OPENAI_API_KEY: + type: string + description: "OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`" + METAGPT_TEXT_TO_IMAGE_MODEL_URL: + type: string + description: "Model url." + required: + oneOf: + - OPENAI_API_KEY + - METAGPT_TEXT_TO_IMAGE_MODEL_URL + parameters: + text: + description: 'The text used for image conversion.' 
+ type: string + required: true + size_type: + description: size type + type: string + default: "512x512" + examples: + - ask: 'Draw a girl' + answer: 'text_to_image(text="Draw a girl", size_type="512x512")' + - ask: 'Draw an apple' + answer: 'text_to_image(text="Draw an apple", size_type="512x512")' + returns: + type: string + format: base64 + + - name: web_search + description: Perform Google searches to provide real-time information. + id: web_search.web_search + x-prerequisite: + configurations: + SEARCH_ENGINE: + type: string + description: "Supported values: serpapi/google/serper/ddg" + SERPER_API_KEY: + type: string + description: "SERPER API KEY, For more details, checkout: `https://serper.dev/api-key`" + required: + allOf: + - SEARCH_ENGINE + - SERPER_API_KEY + parameters: + query: + type: string + description: 'The search query.' + required: true + max_results: + type: number + default: 6 + description: 'The number of search results to retrieve.' + examples: + - ask: 'Search for information about artificial intelligence' + answer: 'web_search(query="Search for information about artificial intelligence", max_results=6)' + - ask: 'Find news articles about climate change' + answer: 'web_search(query="Find news articles about climate change", max_results=6)' + returns: + type: string diff --git a/config/config.yaml b/config/config.yaml index e724897ee..09f2895d1 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -14,6 +14,7 @@ OPENAI_BASE_URL: "https://api.openai.com/v1" OPENAI_API_MODEL: "gpt-4-1106-preview" MAX_TOKENS: 4096 RPM: 10 +LLM_TYPE: OpenAI # Except for these three major models – OpenAI, MetaGPT LLM, and Azure – other large models can be distinguished based on the validity of the key. #### if Spark #SPARK_APPID : "YOUR_APPID" @@ -113,4 +114,25 @@ RPM: 10 ### repair operation on the content extracted from LLM's raw output. Warning, it improves the result but not fix all cases. 
# REPAIR_LLM_OUTPUT: false -# PROMPT_FORMAT: json #json or markdown \ No newline at end of file +# PROMPT_FORMAT: json #json or markdown + +### Agent configurations +# RAISE_NOT_CONFIG_ERROR: true # "true" if the LLM key is not configured, throw a NotConfiguredException, else "false". +# WORKSPACE_PATH_WITH_UID: false # "true" if using `{workspace}/{uid}` as the workspace path; "false" use `{workspace}`. + +### Meta Models +#METAGPT_TEXT_TO_IMAGE_MODEL: MODEL_URL + +### S3 config +#S3_ACCESS_KEY: "YOUR_S3_ACCESS_KEY" +#S3_SECRET_KEY: "YOUR_S3_SECRET_KEY" +#S3_ENDPOINT_URL: "YOUR_S3_ENDPOINT_URL" +#S3_SECURE: true # true/false +#S3_BUCKET: "YOUR_S3_BUCKET" + +### Redis config +#REDIS_HOST: "YOUR_REDIS_HOST" +#REDIS_PORT: "YOUR_REDIS_PORT" +#REDIS_PASSWORD: "YOUR_REDIS_PASSWORD" +#REDIS_DB: "YOUR_REDIS_DB_INDEX, str, 0-based" + diff --git a/examples/search_kb.py b/examples/search_kb.py index 5d61bbe02..0afd7ad15 100644 --- a/examples/search_kb.py +++ b/examples/search_kb.py @@ -2,10 +2,10 @@ # -*- coding: utf-8 -*- """ @File : search_kb.py +@Modified By: mashenquan, 2023-12-22. Delete useless codes. 
""" import asyncio -from metagpt.actions import Action from metagpt.const import DATA_PATH from metagpt.document_store import FaissStore from metagpt.logs import logger @@ -29,10 +29,9 @@ from metagpt.schema import Message async def search(): store = FaissStore(DATA_PATH / "example.json") role = Sales(profile="Sales", store=store) - role._watch({Action}) queries = [ - Message(content="Which facial cleanser is good for oily skin?", cause_by=Action), - Message(content="Is L'Oreal good to use?", cause_by=Action), + Message(content="Which facial cleanser is good for oily skin?"), + Message(content="Is L'Oreal good to use?"), ] for query in queries: logger.info(f"User: {query}") diff --git a/examples/search_with_specific_engine.py b/examples/search_with_specific_engine.py index 923f538ed..adb5665cb 100644 --- a/examples/search_with_specific_engine.py +++ b/examples/search_with_specific_engine.py @@ -1,3 +1,7 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +""" import asyncio from metagpt.roles import Searcher diff --git a/examples/write_teaching_plan.py b/examples/write_teaching_plan.py new file mode 100644 index 000000000..01181dc2b --- /dev/null +++ b/examples/write_teaching_plan.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023-07-27 +@Author : mashenquan +@File : write_teaching_plan.py +@Desc: Write teaching plan demo + ``` + export PYTHONPATH=$PYTHONPATH:$PWD + python examples/write_teaching_plan.py --language=Chinese --teaching_language=English + + ``` +""" + +import asyncio +from pathlib import Path + +import aiofiles +import fire + +from metagpt.actions.write_teaching_plan import TeachingPlanRequirement +from metagpt.config import CONFIG +from metagpt.logs import logger +from metagpt.roles.teacher import Teacher +from metagpt.schema import Message +from metagpt.team import Team + + +async def startup(lesson_file: str, investment: float = 3.0, n_round: int = 1, *args, **kwargs): + """Run a startup. 
Be a teacher in education industry.""" + + demo_lesson = """ + UNIT 1 Making New Friends + TOPIC 1 Welcome to China! + Section A + + 1a Listen and number the following names. + Jane Mari Kangkang Michael + Look, listen and understand. Then practice the conversation. + Work in groups. Introduce yourself using + I ’m ... Then practice 1a + with your own hometown or the following places. + + 1b Listen and number the following names + Jane Michael Maria Kangkang + 1c Work in groups. Introduce yourself using I ’m ... Then practice 1a with your own hometown or the following places. + China the USA the UK Hong Kong Beijing + + 2a Look, listen and understand. Then practice the conversation + Hello! + Hello! + Hello! + Hello! Are you Maria? + No, I’m not. I’m Jane. + Oh, nice to meet you, Jane + Nice to meet you, too. + Hi, Maria! + Hi, Kangkang! + Welcome to China! + Thanks. + + 2b Work in groups. Make up a conversation with your own name and the + following structures. + A: Hello! / Good morning! / Hi! I’m ... Are you ... ? + B: ... + + 3a Listen, say and trace + Aa Bb Cc Dd Ee Ff Gg + + 3b Listen and number the following letters. Then circle the letters with the same sound as Bb. + Aa Bb Cc Dd Ee Ff Gg + + 3c Match the big letters with the small ones. Then write them on the lines. 
+ """ + CONFIG.set_context(kwargs) + + lesson = "" + if lesson_file and Path(lesson_file).exists(): + async with aiofiles.open(lesson_file, mode="r", encoding="utf-8") as reader: + lesson = await reader.read() + logger.info(f"Course content: {lesson}") + if not lesson: + logger.info("No course content provided, using the demo course.") + lesson = demo_lesson + + company = Team() + company.hire([Teacher(*args, **kwargs)]) + company.invest(investment) + company.env.publish_message(Message(content=lesson, cause_by=TeachingPlanRequirement)) + await company.run(n_round=1) + + +def main(idea: str, investment: float = 3.0, n_round: int = 5, *args, **kwargs): + """ + We are a software startup comprised of AI. By investing in us, you are empowering a future filled with limitless possibilities. + :param idea: lesson filename. + :param investment: As an investor, you have the opportunity to contribute a certain dollar amount to this AI company. + :param n_round: Reserved. + :param args: Parameters passed in format: `python your_script.py arg1 arg2 arg3` + :param kwargs: Parameters passed in format: `python your_script.py --param1=value1 --param2=value2` + :return: + """ + asyncio.run(startup(idea, investment, n_round, *args, **kwargs)) + + +if __name__ == "__main__": + """ + Formats: + ``` + python write_teaching_plan.py lesson_filename --teaching_language= --language= + ``` + If `lesson_filename` is not available, a demo lesson content will be used. 
+ """ + fire.Fire(main) diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py new file mode 100644 index 000000000..6da3e2989 --- /dev/null +++ b/metagpt/actions/rebuild_class_view.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 +@Author : mashenquan +@File : rebuild_class_view.py +@Desc : Rebuild class view info +""" +import re +from pathlib import Path + +from metagpt.actions import Action +from metagpt.config import CONFIG +from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO +from metagpt.repo_parser import RepoParser +from metagpt.utils.di_graph_repository import DiGraphRepository +from metagpt.utils.graph_repository import GraphKeyword, GraphRepository + + +class RebuildClassView(Action): + def __init__(self, name="", context=None, llm=None): + super().__init__(name=name, context=context, llm=llm) + + async def run(self, with_messages=None, format=CONFIG.prompt_format): + graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name + graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) + repo_parser = RepoParser(base_directory=self.context) + class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint + await GraphRepository.update_graph_db_with_class_views(graph_db, class_views) + symbols = repo_parser.generate_symbols() # use ast + for file_info in symbols: + await GraphRepository.update_graph_db_with_file_info(graph_db, file_info) + await self._create_mermaid_class_view(graph_db=graph_db) + await self._save(graph_db=graph_db) + + async def _create_mermaid_class_view(self, graph_db): + pass + # dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO) + # if not dataset: + # logger.warning(f"No page info for {concat_namespace(filename, class_name)}") + # return + # code_block_info = 
CodeBlockInfo.parse_raw(dataset[0].object_) + # src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno) + # code_type = "" + # dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS) + # for spo in dataset: + # if spo.object_ in ["javascript", "python"]: + # code_type = spo.object_ + # break + + # try: + # node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format) + # class_view = node.instruct_content.dict()["Class View"] + # except Exception as e: + # class_view = RepoParser.rebuild_class_view(src_code, code_type) + # await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view) + # logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}") + + async def _save(self, graph_db): + class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO) + dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW) + all_class_view = [] + for spo in dataset: + title = f"---\ntitle: {spo.subject}\n---\n" + filename = re.sub(r"[/:]", "_", spo.subject) + ".mmd" + await class_view_file_repo.save(filename=filename, content=title + spo.object_) + all_class_view.append(spo.object_) + await class_view_file_repo.save(filename="all.mmd", content="\n".join(all_class_view)) diff --git a/metagpt/actions/rebuild_class_view_an.py b/metagpt/actions/rebuild_class_view_an.py new file mode 100644 index 000000000..da32a9b5e --- /dev/null +++ b/metagpt/actions/rebuild_class_view_an.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 +@Author : mashenquan +@File : rebuild_class_view_an.py +@Desc : Defines `ActionNode` objects used by rebuild_class_view.py +""" +from metagpt.actions.action_node import ActionNode + +CLASS_SOURCE_CODE_BLOCK = ActionNode( + 
key="Class View", + expected_type=str, + instruction='Generate the mermaid class diagram corresponding to source code in "context."', + example=""" + classDiagram + class A { + -int x + +int y + -int speed + -int direction + +__init__(x: int, y: int, speed: int, direction: int) + +change_direction(new_direction: int) None + +move() None + } + """, +) + +REBUILD_CLASS_VIEW_NODES = [ + CLASS_SOURCE_CODE_BLOCK, +] + +REBUILD_CLASS_VIEW_NODE = ActionNode.from_children("RebuildClassView", REBUILD_CLASS_VIEW_NODES) diff --git a/metagpt/actions/skill_action.py b/metagpt/actions/skill_action.py new file mode 100644 index 000000000..c95a83cbb --- /dev/null +++ b/metagpt/actions/skill_action.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/28 +@Author : mashenquan +@File : skill_action.py +@Desc : Call learned skill +""" +from __future__ import annotations + +import ast +import importlib +import traceback +from copy import deepcopy +from typing import Dict, Optional + +from metagpt.actions import Action, ActionOutput +from metagpt.learn.skill_loader import Skill +from metagpt.logs import logger + + +class ArgumentsParingAction(Action): + skill: Skill + ask: str + rsp: Optional[ActionOutput] + args: Optional[Dict] + + @property + def prompt(self): + prompt = f"{self.skill.name} function parameters description:\n" + for k, v in self.skill.arguments.items(): + prompt += f"parameter `{k}`: {v}\n" + prompt += "\n" + prompt += "Examples:\n" + for e in self.skill.examples: + prompt += f"If want you to do `{e.ask}`, return `{e.answer}` brief and clear.\n" + prompt += f"\nNow I want you to do `{self.ask}`, return in examples format above, brief and clear." 
+ return prompt + + async def run(self, *args, **kwargs) -> ActionOutput: + prompt = self.prompt + rsp = await self.llm.aask(msg=prompt, system_msgs=[]) + logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}") + self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp) + self.rsp = ActionOutput(content=rsp) + return self.rsp + + @staticmethod + def parse_arguments(skill_name, txt) -> dict: + prefix = skill_name + "(" + if prefix not in txt: + logger.error(f"{skill_name} not in {txt}") + return None + if ")" not in txt: + logger.error(f"')' not in {txt}") + return None + begin_ix = txt.find(prefix) + end_ix = txt.rfind(")") + args_txt = txt[begin_ix + len(prefix) : end_ix] + logger.info(args_txt) + fake_expression = f"dict({args_txt})" + parsed_expression = ast.parse(fake_expression, mode="eval") + args = {} + for keyword in parsed_expression.body.keywords: + key = keyword.arg + value = ast.literal_eval(keyword.value) + args[key] = value + return args + + +class SkillAction(Action): + skill: Skill + args: Dict + rsp: str = "" + + async def run(self, *args, **kwargs) -> str | ActionOutput | None: + """Run action""" + options = deepcopy(kwargs) + if self.args: + for k in self.args.keys(): + if k in options: + options.pop(k) + try: + self.rsp = await self.find_and_call_function(self.skill.name, args=self.args, **options) + except Exception as e: + logger.exception(f"{e}, traceback:{traceback.format_exc()}") + self.rsp = f"Error: {e}" + return ActionOutput(content=self.rsp, instruct_content=self.skill.json()) + + @staticmethod + async def find_and_call_function(function_name, args, **kwargs): + try: + module = importlib.import_module("metagpt.learn") + function = getattr(module, function_name) + # 调用函数并返回结果 + result = await function(**args, **kwargs) + return result + except (ModuleNotFoundError, AttributeError): + logger.error(f"{function_name} not found") + return None + + +if __name__ == "__main__": + ArgumentsParingAction.parse_arguments( + 
skill_name="text_to_image", txt='`text_to_image(text="Draw an apple", size_type="512x512")`' + ) diff --git a/metagpt/actions/talk_action.py b/metagpt/actions/talk_action.py new file mode 100644 index 000000000..3695ec5bb --- /dev/null +++ b/metagpt/actions/talk_action.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/28 +@Author : mashenquan +@File : talk_action.py +@Desc : Act as it’s a talk +""" +import json + +from metagpt.actions import Action, ActionOutput +from metagpt.config import CONFIG +from metagpt.const import DEFAULT_LANGUAGE +from metagpt.llm import LLMType +from metagpt.logs import logger + + +class TalkAction(Action): + def __init__(self, name: str = "", talk="", history_summary="", knowledge="", context=None, llm=None, **kwargs): + context = context or {} + context["talk"] = talk + context["history_summery"] = history_summary + context["knowledge"] = knowledge + super(TalkAction, self).__init__(name=name, context=context, llm=llm) + self._talk = talk + self._history_summary = history_summary + self._knowledge = knowledge + self._rsp = None + + @property + def prompt(self): + prompt = "" + if CONFIG.agent_description: + prompt = ( + f"You are {CONFIG.agent_description}. Your responses should align with the role-play agreement, " + f"maintaining the character's persona and habits. 
When faced with unrelated questions, playfully " + f"decline to answer without revealing your AI nature to preserve the character's image.\n\n" + ) + prompt += f"Knowledge:\n{self._knowledge}\n\n" if self._knowledge else "" + prompt += f"{self._history_summary}\n\n" + prompt += ( + "If the information is insufficient, you can search in the historical conversation or knowledge above.\n" + ) + language = CONFIG.language or DEFAULT_LANGUAGE + prompt += ( + f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.\n " + f"{self._talk}" + ) + logger.debug(f"PROMPT: {prompt}") + return prompt + + @property + def prompt_gpt4(self): + kvs = { + "{role}": CONFIG.agent_description or "", + "{history}": self._history_summary or "", + "{knowledge}": self._knowledge or "", + "{language}": CONFIG.language or DEFAULT_LANGUAGE, + "{ask}": self._talk, + } + prompt = TalkAction.__FORMATION_LOOSE__ + for k, v in kvs.items(): + prompt = prompt.replace(k, v) + logger.info(f"PROMPT: {prompt}") + return prompt + + async def run_old(self, *args, **kwargs) -> ActionOutput: + prompt = self.prompt + rsp = await self.llm.aask(msg=prompt, system_msgs=[]) + logger.debug(f"PROMPT:{prompt}\nRESULT:{rsp}\n") + self._rsp = ActionOutput(content=rsp) + return self._rsp + + @property + def aask_args(self): + language = CONFIG.language or DEFAULT_LANGUAGE + system_msgs = [ + f"You are {CONFIG.agent_description}.", + "Your responses should align with the role-play agreement, " + "maintaining the character's persona and habits. 
When faced with unrelated questions, playfully " + "decline to answer without revealing your AI nature to preserve the character's image.", + "If the information is insufficient, you can search in the context or knowledge.", + f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.", + ] + format_msgs = [] + if self._knowledge: + format_msgs.append({"role": "assistant", "content": self._knowledge}) + if self._history_summary: + if CONFIG.LLM_TYPE == LLMType.METAGPT.value: + format_msgs.extend(json.loads(self._history_summary)) + else: + format_msgs.append({"role": "assistant", "content": self._history_summary}) + return self._talk, format_msgs, system_msgs + + async def run(self, *args, **kwargs) -> ActionOutput: + msg, format_msgs, system_msgs = self.aask_args + rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs) + self._rsp = ActionOutput(content=rsp) + return self._rsp + + __FORMATION__ = """Formation: "Capacity and role" defines the role you are currently playing; + "[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation; + "[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses; + "Statement" defines the work detail you need to complete at this stage; + "[ASK_BEGIN]" and [ASK_END] tags enclose the questions; + "Constraint" defines the conditions that your responses must comply with. + "Personality" defines your language style。 + "Insight" provides a deeper understanding of the characters' inner traits. + "Initial" defines the initial setup of a character. + +Capacity and role: {role} +Statement: Your responses should align with the role-play agreement, maintaining the + character's persona and habits. When faced with unrelated questions, playfully decline to answer without revealing + your AI nature to preserve the character's image. 
+ +[HISTORY_BEGIN] + +{history} + +[HISTORY_END] + +[KNOWLEDGE_BEGIN] + +{knowledge} + +[KNOWLEDGE_END] + +Statement: If the information is insufficient, you can search in the historical conversation or knowledge. +Statement: Unless you are a language professional, answer the following questions strictly in {language} +, and the answers must follow the Markdown format. Strictly excluding any tag likes "[HISTORY_BEGIN]" +, "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" in responses. + + +{ask} +""" + + __FORMATION_LOOSE__ = """Formation: "Capacity and role" defines the role you are currently playing; + "[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation; + "[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses; + "Statement" defines the work detail you need to complete at this stage; + "Constraint" defines the conditions that your responses must comply with. + "Personality" defines your language style。 + "Insight" provides a deeper understanding of the characters' inner traits. + "Initial" defines the initial setup of a character. + +Capacity and role: {role} +Statement: Your responses should maintaining the character's persona and habits. When faced with unrelated questions +, playfully decline to answer without revealing your AI nature to preserve the character's image. + +[HISTORY_BEGIN] + +{history} + +[HISTORY_END] + +[KNOWLEDGE_BEGIN] + +{knowledge} + +[KNOWLEDGE_END] + +Statement: If the information is insufficient, you can search in the historical conversation or knowledge. +Statement: Unless you are a language professional, answer the following questions strictly in {language} +, and the answers must follow the Markdown format. Strictly excluding any tag likes "[HISTORY_BEGIN]" +, "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" in responses. 
+ + +{ask} +""" diff --git a/metagpt/actions/write_teaching_plan.py b/metagpt/actions/write_teaching_plan.py new file mode 100644 index 000000000..534f5ded9 --- /dev/null +++ b/metagpt/actions/write_teaching_plan.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/7/27 +@Author : mashenquan +@File : write_teaching_plan.py +""" +from metagpt.actions import Action +from metagpt.config import CONFIG +from metagpt.logs import logger +from metagpt.schema import Message + + +class TeachingPlanRequirement(Action): + """Teaching Plan Requirement without any implementation details""" + + async def run(self, *args, **kwargs): + raise NotImplementedError + + +class WriteTeachingPlanPart(Action): + """Write Teaching Plan Part""" + + def __init__(self, name: str = "", context=None, llm=None, topic: str = "", language: str = "Chinese"): + """ + + :param name: action name + :param context: context + :param llm: object of :class:`LLM` + :param topic: topic part of teaching plan + :param language: A human language, such as Chinese, English, French, etc. 
+ """ + super().__init__(name, context, llm) + self.topic = topic + self.language = language + self.rsp = None + + async def run(self, messages, *args, **kwargs): + if len(messages) < 1 or not isinstance(messages[0], Message): + raise ValueError("Invalid args, a tuple of List[Message] is expected") + + statement_patterns = self.TOPIC_STATEMENTS.get(self.topic, []) + statements = [] + for p in statement_patterns: + s = format_value(p) + statements.append(s) + formatter = self.PROMPT_TITLE_TEMPLATE if self.topic == self.COURSE_TITLE else self.PROMPT_TEMPLATE + prompt = formatter.format( + formation=self.FORMATION, + role=self.prefix, + statements="\n".join(statements), + lesson=messages[0].content, + topic=self.topic, + language=self.language, + ) + + logger.debug(prompt) + rsp = await self._aask(prompt=prompt) + logger.debug(rsp) + self._set_result(rsp) + return self.rsp + + def _set_result(self, rsp): + if self.DATA_BEGIN_TAG in rsp: + ix = rsp.index(self.DATA_BEGIN_TAG) + rsp = rsp[ix + len(self.DATA_BEGIN_TAG) :] + if self.DATA_END_TAG in rsp: + ix = rsp.index(self.DATA_END_TAG) + rsp = rsp[0:ix] + self.rsp = rsp.strip() + if self.topic != self.COURSE_TITLE: + return + if "#" not in self.rsp or self.rsp.index("#") != 0: + self.rsp = "# " + self.rsp + + def __str__(self): + """Return `topic` value when str()""" + return self.topic + + def __repr__(self): + """Show `topic` value when debug""" + return self.topic + + @staticmethod + def format_value(value): + """Fill parameters inside `value` with `options`.""" + if not isinstance(value, str): + return value + if "{" not in value: + return value + + merged_opts = CONFIG.options or {} + try: + return value.format(**merged_opts) + except KeyError as e: + logger.warning(f"Parameter is missing:{e}") + + for k, v in merged_opts.items(): + value = value.replace("{" + f"{k}" + "}", str(v)) + return value + + FORMATION = ( + '"Capacity and role" defines the role you are currently playing;\n' + '\t"[LESSON_BEGIN]" and 
"[LESSON_END]" tags enclose the content of textbook;\n' + '\t"Statement" defines the work detail you need to complete at this stage;\n' + '\t"Answer options" defines the format requirements for your responses;\n' + '\t"Constraint" defines the conditions that your responses must comply with.' + ) + + COURSE_TITLE = "Title" + TOPICS = [ + COURSE_TITLE, + "Teaching Hours", + "Teaching Objectives", + "Teaching Content", + "Teaching Methods and Strategies", + "Learning Activities", + "Teaching Time Allocation", + "Assessment and Feedback", + "Teaching Summary and Improvement", + "Vocabulary Cloze", + "Choice Questions", + "Grammar Questions", + "Translation Questions", + ] + + TOPIC_STATEMENTS = { + COURSE_TITLE: [ + "Statement: Find and return the title of the lesson only in markdown first-level header format, " + "without anything else." + ], + "Teaching Content": [ + 'Statement: "Teaching Content" must include vocabulary, analysis, and examples of various grammar ' + "structures that appear in the textbook, as well as the listening materials and key points.", + 'Statement: "Teaching Content" must include more examples.', + ], + "Teaching Time Allocation": [ + 'Statement: "Teaching Time Allocation" must include how much time is allocated to each ' + "part of the textbook content." + ], + "Teaching Methods and Strategies": [ + 'Statement: "Teaching Methods and Strategies" must include teaching focus, difficulties, materials, ' + "procedures, in detail." + ], + "Vocabulary Cloze": [ + 'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", ' + "create vocabulary cloze. The cloze should include 10 {language} questions with {teaching_language} " + "answers, and it should also include 10 {teaching_language} questions with {language} answers. 
" + "The key-related vocabulary and phrases in the textbook content must all be included in the exercises.", + ], + "Grammar Questions": [ + 'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", ' + "create grammar questions. 10 questions." + ], + "Choice Questions": [ + 'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", ' + "create choice questions. 10 questions." + ], + "Translation Questions": [ + 'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", ' + "create translation questions. The translation should include 10 {language} questions with " + "{teaching_language} answers, and it should also include 10 {teaching_language} questions with " + "{language} answers." + ], + } + + # Teaching plan title + PROMPT_TITLE_TEMPLATE = ( + "Do not refer to the context of the previous conversation records, " + "start the conversation anew.\n\n" + "Formation: {formation}\n\n" + "{statements}\n" + "Constraint: Writing in {language}.\n" + 'Answer options: Encloses the lesson title with "[TEACHING_PLAN_BEGIN]" ' + 'and "[TEACHING_PLAN_END]" tags.\n' + "[LESSON_BEGIN]\n" + "{lesson}\n" + "[LESSON_END]" + ) + + # Teaching plan parts: + PROMPT_TEMPLATE = ( + "Do not refer to the context of the previous conversation records, " + "start the conversation anew.\n\n" + "Formation: {formation}\n\n" + "Capacity and role: {role}\n" + 'Statement: Write the "{topic}" part of teaching plan, ' + 'WITHOUT ANY content unrelated to "{topic}"!!\n' + "{statements}\n" + 'Answer options: Enclose the teaching plan content with "[TEACHING_PLAN_BEGIN]" ' + 'and "[TEACHING_PLAN_END]" tags.\n' + "Answer options: Using proper markdown format from second-level header format.\n" + "Constraint: Writing in {language}.\n" + "[LESSON_BEGIN]\n" + "{lesson}\n" + "[LESSON_END]" + ) + + DATA_BEGIN_TAG = "[TEACHING_PLAN_BEGIN]" + DATA_END_TAG = "[TEACHING_PLAN_END]" diff --git 
a/metagpt/config.py b/metagpt/config.py index 5176a7677..3c773d780 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -6,12 +6,15 @@ Provide configuration, singleton 1. According to Section 2.2.3.11 of RFC 135, add git repository support. 2. Add the parameter `src_workspace` for the old version project path. """ +import datetime +import json import os import warnings from copy import deepcopy from enum import Enum from pathlib import Path from typing import Any +from uuid import uuid4 import yaml @@ -19,6 +22,7 @@ from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS from metagpt.logs import logger from metagpt.tools import SearchEngineType, WebBrowserEngineType from metagpt.utils.common import require_python_version +from metagpt.utils.cost_manager import CostManager from metagpt.utils.singleton import Singleton @@ -42,6 +46,8 @@ class LLMProviderEnum(Enum): FIREWORKS = "fireworks" OPEN_LLM = "open_llm" GEMINI = "gemini" + METAGPT = "metagpt" + AZURE_OPENAI = "azure_openai" class Config(metaclass=Singleton): @@ -57,7 +63,7 @@ class Config(metaclass=Singleton): key_yaml_file = METAGPT_ROOT / "config/key.yaml" default_yaml_file = METAGPT_ROOT / "config/config.yaml" - def __init__(self, yaml_file=default_yaml_file): + def __init__(self, yaml_file=default_yaml_file, cost_data=""): global_options = OPTIONS.get() # cli paras self.project_path = "" @@ -67,6 +73,8 @@ class Config(metaclass=Singleton): self.max_auto_summarize_code = 0 self._init_with_config_files_and_env(yaml_file) + # The agent needs to be billed per user, so billing information cannot be destroyed when the session ends. 
+ self.cost_manager = CostManager(**json.loads(cost_data)) if cost_data else CostManager() self._update() global_options.update(OPTIONS.get()) logger.debug("Config loading done.") @@ -136,8 +144,7 @@ class Config(metaclass=Singleton): self.long_term_memory = self._get("LONG_TERM_MEMORY", False) if self.long_term_memory: logger.warning("LONG_TERM_MEMORY is True") - self.max_budget = self._get("MAX_BUDGET", 10.0) - self.total_cost = 0.0 + self.cost_manager.max_budget = self._get("MAX_BUDGET", 10.0) self.code_review_k_times = 2 self.puppeteer_config = self._get("PUPPETEER_CONFIG", "") @@ -148,10 +155,17 @@ class Config(metaclass=Singleton): self.mermaid_engine = self._get("MERMAID_ENGINE", "nodejs") self.pyppeteer_executable_path = self._get("PYPPETEER_EXECUTABLE_PATH", "") + workspace_uid = ( + self._get("WORKSPACE_UID") or f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}" + ) self.repair_llm_output = self._get("REPAIR_LLM_OUTPUT", False) - self.prompt_schema = self._get("PROMPT_FORMAT", "json") + self.prompt_format = self._get("PROMPT_FORMAT", "json") self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT)) + val = self._get("WORKSPACE_PATH_WITH_UID") + if val and val.lower() == "true": # for agent + self.workspace_path = self.workspace_path / workspace_uid self._ensure_workspace_exists() + self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1) def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code): """update config via cli""" @@ -192,7 +206,8 @@ class Config(metaclass=Singleton): return i.get(*args, **kwargs) def get(self, key, *args, **kwargs): - """Search for a value in config/key.yaml, config/config.yaml, and env; raise an error if not found""" + """Retrieve values from config/key.yaml, config/config.yaml, and environment variables. 
+ Throw an error if not found.""" value = self._get(key, *args, **kwargs) if value is None: raise ValueError(f"Key '{key}' not found in environment variables or in the YAML file") diff --git a/metagpt/const.py b/metagpt/const.py index 3b4f2ae4b..76ddc077c 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -48,7 +48,7 @@ def get_metagpt_root(): # METAGPT PROJECT ROOT AND VARS -METAGPT_ROOT = get_metagpt_root() +METAGPT_ROOT = get_metagpt_root() # Dependent on METAGPT_PROJECT_ROOT DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace" DATA_PATH = METAGPT_ROOT / "data" @@ -100,5 +100,23 @@ TEST_CODES_FILE_REPO = "tests" TEST_OUTPUTS_FILE_REPO = "test_outputs" CODE_SUMMARIES_FILE_REPO = "docs/code_summaries" CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries" +RESOURCES_FILE_REPO = "resources" +SD_OUTPUT_FILE_REPO = "resources/SD_Output" +GRAPH_REPO_FILE_REPO = "docs/graph_repo" +CLASS_VIEW_FILE_REPO = "docs/class_views" YAPI_URL = "http://yapi.deepwisdomai.com/" + +DEFAULT_LANGUAGE = "English" +DEFAULT_MAX_TOKENS = 1500 +COMMAND_TOKENS = 500 +BRAIN_MEMORY = "BRAIN_MEMORY" +SKILL_PATH = "SKILL_PATH" +SERPER_API_KEY = "SERPER_API_KEY" +DEFAULT_TOKEN_SIZE = 500 + +# format +BASE64_FORMAT = "base64" + +# REDIS +REDIS_KEY = "REDIS_KEY" diff --git a/metagpt/document_store/faiss_store.py b/metagpt/document_store/faiss_store.py index b1faa3538..7acaa194d 100644 --- a/metagpt/document_store/faiss_store.py +++ b/metagpt/document_store/faiss_store.py @@ -21,9 +21,12 @@ from metagpt.logs import logger class FaissStore(LocalStore): - def __init__(self, raw_data_path: Path, cache_dir=None, meta_col="source", content_col="output"): + def __init__( + self, raw_data_path: Path, cache_dir=None, meta_col="source", content_col="output", embedding_conf=None + ): self.meta_col = meta_col self.content_col = content_col + self.embedding_conf = embedding_conf or {} super().__init__(raw_data_path, cache_dir) def _load(self) -> Optional["FaissStore"]: @@ -38,7 +41,9 @@ class 
FaissStore(LocalStore): return store def _write(self, docs, metadatas): - store = FAISS.from_texts(docs, OpenAIEmbeddings(openai_api_version="2020-11-07"), metadatas=metadatas) + store = FAISS.from_texts( + docs, OpenAIEmbeddings(openai_api_version="2020-11-07", **self.embedding_conf), metadatas=metadatas + ) return store def persist(self): diff --git a/metagpt/environment.py b/metagpt/environment.py index e0fb741c0..319abc870 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -17,6 +17,7 @@ from typing import Iterable, Set from pydantic import BaseModel, Field +from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.roles.role import Role, role_subclass_registry from metagpt.schema import Message @@ -105,7 +106,7 @@ class Environment(BaseModel): for role in roles: self.add_role(role) - def publish_message(self, message: Message) -> bool: + def publish_message(self, message: Message, peekable: bool = True) -> bool: """ Distribute the message to the recipients. 
In accordance with the Message routing structure design in Chapter 2.2.1 of RFC 116, as already planned @@ -170,3 +171,8 @@ class Environment(BaseModel): def set_subscription(self, obj, tags): """Set the labels for message to be consumed by the object""" self.members[obj] = tags + + @staticmethod + def archive(auto_archive=True): + if auto_archive and CONFIG.git_repo: + CONFIG.git_repo.archive() diff --git a/metagpt/learn/__init__.py b/metagpt/learn/__init__.py index 28b8739c3..bab9f3e37 100644 --- a/metagpt/learn/__init__.py +++ b/metagpt/learn/__init__.py @@ -5,3 +5,9 @@ @Author : alexanderwu @File : __init__.py """ + +from metagpt.learn.text_to_image import text_to_image +from metagpt.learn.text_to_speech import text_to_speech +from metagpt.learn.google_search import google_search + +__all__ = ["text_to_image", "text_to_speech", "google_search"] diff --git a/metagpt/learn/google_search.py b/metagpt/learn/google_search.py new file mode 100644 index 000000000..ef099fe94 --- /dev/null +++ b/metagpt/learn/google_search.py @@ -0,0 +1,12 @@ +from metagpt.tools.search_engine import SearchEngine + + +async def google_search(query: str, max_results: int = 6, **kwargs): + """Perform a web search and retrieve search results. + + :param query: The search query. + :param max_results: The number of search results to retrieve + :return: The web search results in markdown format. + """ + resluts = await SearchEngine().run(query, max_results=max_results, as_string=False) + return "\n".join(f"{i}. [{j['title']}]({j['link']}): {j['snippet']}" for i, j in enumerate(resluts, 1)) diff --git a/metagpt/learn/skill_loader.py b/metagpt/learn/skill_loader.py new file mode 100644 index 000000000..dff5e26ae --- /dev/null +++ b/metagpt/learn/skill_loader.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/18 +@Author : mashenquan +@File : skill_loader.py +@Desc : Skill YAML Configuration Loader. 
+""" +from pathlib import Path +from typing import Dict, List, Optional + +import yaml +from pydantic import BaseModel, Field + +from metagpt.config import CONFIG + + +class Example(BaseModel): + ask: str + answer: str + + +class Returns(BaseModel): + type: str + format: Optional[str] = None + + +class Parameter(BaseModel): + type: str + description: str = None + + +class Skill(BaseModel): + name: str + description: str = None + id: str = None + x_prerequisite: Dict = Field(default=None, alias="x-prerequisite") + parameters: Dict[str, Parameter] = None + examples: List[Example] + returns: Returns + + @property + def arguments(self) -> Dict: + if not self.parameters: + return {} + ret = {} + for k, v in self.parameters.items(): + ret[k] = v.description if v.description else "" + return ret + + +class Entity(BaseModel): + name: str = None + skills: List[Skill] + + +class Components(BaseModel): + pass + + +class SkillsDeclaration(BaseModel): + skillapi: str + entities: Dict[str, Entity] + components: Components = None + + +class SkillLoader: + def __init__(self, skill_yaml_file_name: Path = None): + if not skill_yaml_file_name: + skill_yaml_file_name = Path(__file__).parent.parent.parent / ".well-known/skills.yaml" + with open(str(skill_yaml_file_name), "r") as file: + skills = yaml.safe_load(file) + self._skills = SkillsDeclaration(**skills) + + def get_skill_list(self, entity_name: str = "Assistant") -> Dict: + """Return the skill name based on the skill description.""" + entity = self.get_entity(entity_name) + if not entity: + return {} + + agent_skills = CONFIG.agent_skills + if not agent_skills: + return {} + + class AgentSkill(BaseModel): + name: str + + names = [AgentSkill(**i).name for i in agent_skills] + description_to_name_mappings = {} + for s in entity.skills: + if s.name not in names: + continue + description_to_name_mappings[s.description] = s.name + + return description_to_name_mappings + + def get_skill(self, name, entity_name: str = "Assistant") -> 
Skill: + """Return a skill by name.""" + entity = self.get_entity(entity_name) + if not entity: + return None + for sk in entity.skills: + if sk.name == name: + return sk + + def get_entity(self, name) -> Entity: + """Return a list of skills for the entity.""" + if not self._skills: + return None + return self._skills.entities.get(name) + + +if __name__ == "__main__": + CONFIG.agent_skills = [ + {"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True}, + {"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True}, + {"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True}, + {"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True}, + {"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True}, + {"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True}, + ] + loader = SkillLoader() + print(loader.get_skill_list()) diff --git a/metagpt/learn/text_to_embedding.py b/metagpt/learn/text_to_embedding.py new file mode 100644 index 000000000..26dab0419 --- /dev/null +++ b/metagpt/learn/text_to_embedding.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/18 +@Author : mashenquan +@File : text_to_embedding.py +@Desc : Text-to-Embedding skill, which provides text-to-embedding functionality. +""" + +from metagpt.config import CONFIG +from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding + + +async def text_to_embedding(text, model="text-embedding-ada-002", openai_api_key="", **kwargs): + """Text to embedding + + :param text: The text used for embedding. + :param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`. 
+ :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` + :return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`. + """ + if CONFIG.OPENAI_API_KEY or openai_api_key: + return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key) + raise EnvironmentError diff --git a/metagpt/learn/text_to_image.py b/metagpt/learn/text_to_image.py new file mode 100644 index 000000000..24669312c --- /dev/null +++ b/metagpt/learn/text_to_image.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/18 +@Author : mashenquan +@File : text_to_image.py +@Desc : Text-to-Image skill, which provides text-to-image functionality. +""" + +from metagpt.config import CONFIG +from metagpt.const import BASE64_FORMAT +from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image +from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image +from metagpt.utils.s3 import S3 + + +async def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs): + """Text to image + + :param text: The text used for image conversion. + :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` + :param size_type: If using OPENAI, the available size options are ['256x256', '512x512', '1024x1024'], while for MetaGPT, the options are ['512x512', '512x768']. + :param model_url: MetaGPT model url + :return: The image data is returned in Base64 encoding. 
+ """ + image_declaration = "data:image/png;base64," + if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url: + base64_data = await oas3_metagpt_text_to_image(text, size_type, model_url) + elif CONFIG.OPENAI_API_KEY or openai_api_key: + base64_data = await oas3_openai_text_to_image(text, size_type, openai_api_key) + else: + raise ValueError("Missing necessary parameters.") + + s3 = S3() + url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if s3.is_valid else "" + if url: + return f"![{text}]({url})" + return image_declaration + base64_data if base64_data else "" diff --git a/metagpt/learn/text_to_speech.py b/metagpt/learn/text_to_speech.py new file mode 100644 index 000000000..72958b8c7 --- /dev/null +++ b/metagpt/learn/text_to_speech.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/17 +@Author : mashenquan +@File : text_to_speech.py +@Desc : Text-to-Speech skill, which provides text-to-speech functionality +""" +import openai + +from metagpt.config import CONFIG +from metagpt.const import BASE64_FORMAT +from metagpt.tools.azure_tts import oas3_azsure_tts +from metagpt.tools.iflytek_tts import oas3_iflytek_tts +from metagpt.utils.s3 import S3 + + +async def text_to_speech( + text, + lang="zh-CN", + voice="zh-CN-XiaomoNeural", + style="affectionate", + role="Girl", + subscription_key="", + region="", + iflytek_app_id="", + iflytek_api_key="", + iflytek_api_secret="", + **kwargs, +): + """Text to speech + For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts` + + :param lang: The value can contain a language code such as en (English), or a locale such as en-US (English - United States). 
For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts` + :param voice: For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`, `https://speech.microsoft.com/portal/voicegallery` + :param style: Speaking style to express different emotions like cheerfulness, empathy, and calm. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts` + :param role: With roles, the same voice can act as a different age and gender. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts` + :param text: The text used for voice conversion. + :param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint` + :param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API. + :param iflytek_app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts` + :param iflytek_api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts` + :param iflytek_api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts` + :return: Returns the Base64-encoded .wav/.mp3 file data if successful, otherwise an empty string. 
+ + """ + + if (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or (subscription_key and region): + audio_declaration = "data:audio/wav;base64," + base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region) + s3 = S3() + url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT) if s3.is_valid else "" + if url: + return f"[{text}]({url})" + return audio_declaration + base64_data if base64_data else base64_data + if (CONFIG.IFLYTEK_APP_ID and CONFIG.IFLYTEK_API_KEY and CONFIG.IFLYTEK_API_SECRET) or ( + iflytek_app_id and iflytek_api_key and iflytek_api_secret + ): + audio_declaration = "data:audio/mp3;base64," + base64_data = await oas3_iflytek_tts( + text=text, app_id=iflytek_app_id, api_key=iflytek_api_key, api_secret=iflytek_api_secret + ) + s3 = S3() + url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT) if s3.is_valid else "" + if url: + return f"[{text}]({url})" + return audio_declaration + base64_data if base64_data else base64_data + + raise openai.InvalidRequestError( + message="AZURE_TTS_SUBSCRIPTION_KEY, AZURE_TTS_REGION, IFLYTEK_APP_ID, IFLYTEK_API_KEY, IFLYTEK_API_SECRET error", + param={}, + ) diff --git a/metagpt/management/skill_manager.py b/metagpt/management/skill_manager.py index b3181b64e..e4892e3d9 100644 --- a/metagpt/management/skill_manager.py +++ b/metagpt/management/skill_manager.py @@ -4,11 +4,11 @@ @Time : 2023/6/5 01:44 @Author : alexanderwu @File : skill_manager.py +@Modified By: mashenquan, 2023/8/20. 
Remove useless `_llm` """ from metagpt.actions import Action from metagpt.const import PROMPT_PATH from metagpt.document_store.chromadb_store import ChromaStore -from metagpt.llm import LLM from metagpt.logs import logger Skill = Action @@ -18,7 +18,6 @@ class SkillManager: """Used to manage all skills""" def __init__(self): - self._llm = LLM() self._store = ChromaStore("skill_manager") self._skills: dict[str:Skill] = {} diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py new file mode 100644 index 000000000..8aa3be2b6 --- /dev/null +++ b/metagpt/memory/brain_memory.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/18 +@Author : mashenquan +@File : brain_memory.py +@Desc : Support memory for multiple tasks and multiple mainlines. Obsoleted by `utils/*_repository.py`. +@Modified By: mashenquan, 2023/9/4. + redis memory cache. +""" +import json +import re +from enum import Enum +from typing import Dict, List, Optional + +import openai +import pydantic + +from metagpt.config import CONFIG +from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS +from metagpt.llm import LLMType +from metagpt.logs import logger +from metagpt.schema import Message, RawMessage +from metagpt.utils.redis import Redis + + +class MessageType(Enum): + Talk = "TALK" + Solution = "SOLUTION" + Problem = "PROBLEM" + Skill = "SKILL" + Answer = "ANSWER" + + +class BrainMemory(pydantic.BaseModel): + history: List[Dict] = [] + stack: List[Dict] = [] + solution: List[Dict] = [] + knowledge: List[Dict] = [] + historical_summary: str = "" + last_history_id: str = "" + is_dirty: bool = False + last_talk: str = None + llm_type: Optional[str] = None + cacheable: bool = True + + def add_talk(self, msg: Message): + msg.role = "user" + self.add_history(msg) + self.is_dirty = True + + def add_answer(self, msg: Message): + msg.role = "assistant" + self.add_history(msg) + self.is_dirty = True + + def get_knowledge(self) -> str: + texts = 
[Message(**m).content for m in self.knowledge] + return "\n".join(texts) + + @staticmethod + async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory": + redis = Redis(conf=redis_conf) + if not redis.is_valid() or not redis_key: + return BrainMemory(llm_type=CONFIG.LLM_TYPE) + v = await redis.get(key=redis_key) + logger.debug(f"REDIS GET {redis_key} {v}") + if v: + data = json.loads(v) + bm = BrainMemory(**data) + bm.is_dirty = False + return bm + return BrainMemory(llm_type=CONFIG.LLM_TYPE) + + async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None): + if not self.is_dirty: + return + redis = Redis(conf=redis_conf) + if not redis.is_valid() or not redis_key: + return False + v = self.json() + if self.cacheable: + await redis.set(key=redis_key, data=v, timeout_sec=timeout_sec) + logger.debug(f"REDIS SET {redis_key} {v}") + self.is_dirty = False + + @staticmethod + def to_redis_key(prefix: str, user_id: str, chat_id: str): + return f"{prefix}:{user_id}:{chat_id}" + + async def set_history_summary(self, history_summary, redis_key, redis_conf): + if self.historical_summary == history_summary: + if self.is_dirty: + await self.dumps(redis_key=redis_key, redis_conf=redis_conf) + self.is_dirty = False + return + + self.historical_summary = history_summary + self.history = [] + await self.dumps(redis_key=redis_key, redis_conf=redis_conf) + self.is_dirty = False + + def add_history(self, msg: Message): + if msg.id: + if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1): + return + self.history.append(msg.dict()) + self.last_history_id = str(msg.id) + self.is_dirty = True + + def exists(self, text) -> bool: + for m in reversed(self.history): + if m.get("content") == text: + return True + return False + + @staticmethod + def to_int(v, default_value): + try: + return int(v) + except: + return default_value + + def pop_last_talk(self): + v = self.last_talk + self.last_talk = None + return v + + async def 
summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs): + if self.llm_type == LLMType.METAGPT.value: + return await self._metagpt_summarize(llm=llm, max_words=max_words, keep_language=keep_language, **kwargs) + + return await self._openai_summarize( + llm=llm, max_words=max_words, keep_language=keep_language, limit=limit, **kwargs + ) + + async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs): + max_token_count = DEFAULT_MAX_TOKENS + max_count = 100 + texts = [self.historical_summary] + for i in self.history: + m = Message(**i) + texts.append(m.content) + text = "\n".join(texts) + text_length = len(text) + if limit > 0 and text_length < limit: + return text + summary = "" + while max_count > 0: + if text_length < max_token_count: + summary = await self._get_summary(text=text, llm=llm, max_words=max_words, keep_language=keep_language) + break + + padding_size = 20 if max_token_count > 20 else 0 + text_windows = self.split_texts(text, window_size=max_token_count - padding_size) + part_max_words = min(int(max_words / len(text_windows)) + 1, 100) + summaries = [] + for ws in text_windows: + response = await self._get_summary( + text=ws, llm=llm, max_words=part_max_words, keep_language=keep_language + ) + summaries.append(response) + if len(summaries) == 1: + summary = summaries[0] + break + + # Merged and retry + text = "\n".join(summaries) + text_length = len(text) + + max_count -= 1 # safeguard + if summary: + await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS) + return summary + raise openai.InvalidRequestError(message="text too long", param=None) + + async def _metagpt_summarize(self, max_words=200, **kwargs): + if not self.history: + return "" + + total_length = 0 + msgs = [] + for i in reversed(self.history): + m = Message(**i) + delta = len(m.content) + if total_length + delta > max_words: + left = max_words - 
total_length + if left == 0: + break + m.content = m.content[0:left] + msgs.append(m.dict()) + break + msgs.append(i) + total_length += delta + msgs.reverse() + self.history = msgs + self.is_dirty = True + await self.dumps(redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS_CONF) + self.is_dirty = False + + return BrainMemory.to_metagpt_history_format(self.history) + + @staticmethod + def to_metagpt_history_format(history) -> str: + mmsg = [] + for m in history: + msg = Message(**m) + r = RawMessage(role="user" if MessageType.Talk.value in msg.tags else "assistant", content=msg.content) + mmsg.append(r) + return json.dumps(mmsg) + + @staticmethod + async def _get_summary(text: str, llm, max_words=20, keep_language: bool = False): + """Generate text summary""" + if len(text) < max_words: + return text + if keep_language: + command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly." + else: + command = f"Translate the above content into a summary of less than {max_words} words." + msg = text + "\n\n" + command + logger.debug(f"summary ask:{msg}") + response = await llm.aask(msg=msg, system_msgs=[]) + logger.debug(f"summary rsp: {response}") + return response + + async def get_title(self, llm, max_words=5, **kwargs) -> str: + """Generate text title""" + if self.llm_type == LLMType.METAGPT.value: + return Message(**self.history[0]).content if self.history else "New" + + summary = await self.summarize(llm=llm, max_words=500) + + language = CONFIG.language or DEFAULT_LANGUAGE + command = f"Translate the above summary into a {language} title of less than {max_words} words." 
+ summaries = [summary, command] + msg = "\n".join(summaries) + logger.debug(f"title ask:{msg}") + response = await llm.aask(msg=msg, system_msgs=[]) + logger.debug(f"title rsp: {response}") + return response + + async def is_related(self, text1, text2, llm): + if self.llm_type == LLMType.METAGPT.value: + return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm) + return await self._openai_is_related(text1=text1, text2=text2, llm=llm) + + @staticmethod + async def _metagpt_is_related(**kwargs): + return False + + @staticmethod + async def _openai_is_related(text1, text2, llm, **kwargs): + # command = f"{text1}\n{text2}\n\nIf the two sentences above are related, return [TRUE] brief and clear. Otherwise, return [FALSE]." + command = f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear." + rsp = await llm.aask(msg=command, system_msgs=[]) + result = True if "TRUE" in rsp else False + p2 = text2.replace("\n", "") + p1 = text1.replace("\n", "") + logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n") + return result + + async def rewrite(self, sentence: str, context: str, llm): + if self.llm_type == LLMType.METAGPT.value: + return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm) + return await self._openai_rewrite(sentence=sentence, context=context, llm=llm) + + async def _metagpt_rewrite(self, sentence: str, **kwargs): + return sentence + + async def _openai_rewrite(self, sentence: str, context: str, llm, **kwargs): + # command = ( + # f"{context}\n\nConsidering the content above, rewrite and return this sentence brief and clear:\n{sentence}" + # ) + command = f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly supplement or rewrite the following text in brief and clear:\n{sentence}" + rsp = await llm.aask(msg=command, 
system_msgs=[]) + logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n") + return rsp + + @staticmethod + def split_texts(text: str, window_size) -> List[str]: + """Splitting long text into sliding windows text""" + if window_size <= 0: + window_size = BrainMemory.DEFAULT_TOKEN_SIZE + total_len = len(text) + if total_len <= window_size: + return [text] + + padding_size = 20 if window_size > 20 else 0 + windows = [] + idx = 0 + data_len = window_size - padding_size + while idx < total_len: + if window_size + idx > total_len: # 不足一个滑窗 + windows.append(text[idx:]) + break + # 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....] + # window_size=3, padding_size=1: + # [1, 2, 3], [3, 4, 5], [5, 6, 7], .... + # idx=2, | idx=5 | idx=8 | ... + w = text[idx : idx + window_size] + windows.append(w) + idx += data_len + + return windows + + @staticmethod + def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"): + match = re.match(pattern, input_string) + if match: + return match.group(1), match.group(2) + else: + return None, input_string + + def set_llm_type(self, v): + if v and v != self.llm_type: + self.llm_type = v + self.is_dirty = True + + @property + def is_history_available(self): + return bool(self.history or self.historical_summary) + + @property + def history_text(self): + if self.llm_type == LLMType.METAGPT.value: + return self._get_metagpt_history_text() + return self._get_openai_history_text() + + def _get_metagpt_history_text(self): + return BrainMemory.to_metagpt_history_format(self.history) + + def _get_openai_history_text(self): + if len(self.history) == 0 and not self.historical_summary: + return "" + texts = [self.historical_summary] if self.historical_summary else [] + for m in self.history[:-1]: + if isinstance(m, Dict): + t = Message(**m).content + elif isinstance(m, Message): + t = m.content + else: + continue + texts.append(t) + + return "\n".join(texts) + + DEFAULT_TOKEN_SIZE = 500 diff --git a/metagpt/memory/longterm_memory.py 
b/metagpt/memory/longterm_memory.py index 710074f81..1497b8910 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- """ @Desc : the implement of Long-term memory +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. """ from typing import Optional diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index e9891ed00..d964cc1dc 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -129,3 +129,11 @@ class Memory(BaseModel): continue rsp += self.index[action] return rsp + + def get_by_tags(self, tags: list) -> list[Message]: + """Return messages with specified tags""" + result = [] + for m in self.storage: + if m.is_contain_tags(tags): + result.append(m) + return result diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index fafb33568..3017c23ad 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -1,6 +1,9 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : the implement of memory storage +""" +@Desc : the implement of memory storage +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. +""" from pathlib import Path from typing import List diff --git a/metagpt/provider/azure_openai_api.py b/metagpt/provider/azure_openai_api.py new file mode 100644 index 000000000..ec5eed3f6 --- /dev/null +++ b/metagpt/provider/azure_openai_api.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/5 23:08 +@Author : alexanderwu +@File : openai.py +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation; + Change cost control from global to company level. +@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout. +@Modified By: mashenquan, 2023/12/1. 
Fix bug: Unclosed connection caused by openai 0.x. +""" + + +from openai import AsyncAzureOpenAI, AzureOpenAI +from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper + +from metagpt.config import CONFIG, Config, LLMProviderEnum +from metagpt.provider.llm_provider_registry import register_provider +from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter + + +@register_provider(LLMProviderEnum.AZURE_OPENAI) +class AzureOpenAIGPTAPI(OpenAIGPTAPI): + """ + Check https://platform.openai.com/examples for examples + """ + + def __init__(self): + self.config: Config = CONFIG + self.__init_openai() + self.auto_max_tokens = False + # https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix + self._client = AsyncAzureOpenAI( + api_key=CONFIG.openai_api_key, + api_version=CONFIG.openai_api_version, + azure_endpoint=CONFIG.openai_api_base, + ) + RateLimiter.__init__(self, rpm=self.rpm) + + def _make_client(self): + kwargs, async_kwargs = self._make_client_kwargs() + self.client = AzureOpenAI(**kwargs) + self.async_client = AsyncAzureOpenAI(**async_kwargs) + + def _make_client_kwargs(self) -> (dict, dict): + kwargs = dict( + api_key=self.config.openai_api_key, + api_version=self.config.openai_api_version, + azure_endpoint=self.config.openai_base_url, + ) + async_kwargs = kwargs.copy() + + # to use proxy, openai v1 needs http_client + proxy_params = self._get_proxy_params() + if proxy_params: + kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params) + async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params) + + return kwargs, async_kwargs + + def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict: + kwargs = { + "messages": messages, + "max_tokens": self.get_max_tokens(messages), + "n": 1, + "stop": None, + "temperature": 0.3, + "model": CONFIG.deployment_id, + } + if configs: + kwargs.update(configs) + try: + default_timeout = int(CONFIG.TIMEOUT) if 
CONFIG.TIMEOUT else 0 + except ValueError: + default_timeout = 0 + kwargs["timeout"] = max(default_timeout, timeout) + + return kwargs diff --git a/metagpt/provider/base_chatbot.py b/metagpt/provider/base_chatbot.py index a6950f144..535130de7 100644 --- a/metagpt/provider/base_chatbot.py +++ b/metagpt/provider/base_chatbot.py @@ -4,6 +4,7 @@ @Time : 2023/5/5 23:00 @Author : alexanderwu @File : base_chatbot.py +@Modified By: mashenquan, 2023/11/21. Add `timeout`. """ from abc import ABC, abstractmethod from dataclasses import dataclass @@ -17,13 +18,13 @@ class BaseChatbot(ABC): use_system_prompt: bool = True @abstractmethod - def ask(self, msg: str) -> str: + def ask(self, msg: str, timeout=3) -> str: """Ask GPT a question and get an answer""" @abstractmethod - def ask_batch(self, msgs: list) -> str: + def ask_batch(self, msgs: list, timeout=3) -> str: """Ask GPT multiple questions and get a series of answers""" @abstractmethod - def ask_code(self, msgs: list) -> str: + def ask_code(self, msgs: list, timeout=3) -> str: """Ask GPT multiple questions and get a piece of code""" diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py index c38576806..f650305e3 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_gpt_api.py @@ -4,12 +4,12 @@ @Time : 2023/5/5 23:04 @Author : alexanderwu @File : base_gpt_api.py +@Desc : mashenquan, 2023/8/22. 
+ try catch """ import json from abc import abstractmethod from typing import Optional -from metagpt.logs import logger from metagpt.provider.base_chatbot import BaseChatbot @@ -33,62 +33,66 @@ class BaseGPTAPI(BaseChatbot): def _default_system_msg(self): return self._system_msg(self.system_prompt) - def ask(self, msg: str) -> str: + def ask(self, msg: str, timeout=3) -> str: message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)] - rsp = self.completion(message) + rsp = self.completion(message, timeout=timeout) return self.get_choice_text(rsp) - async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream=True) -> str: + async def aask( + self, + msg: str, + system_msgs: Optional[list[str]] = None, + format_msgs: Optional[list[dict[str, str]]] = None, + generator: bool = False, + timeout=3, + stream=True, + ) -> str: if system_msgs: - message = ( - self._system_msgs(system_msgs) + [self._user_msg(msg)] - if self.use_system_prompt - else [self._user_msg(msg)] - ) + message = self._system_msgs(system_msgs) else: - message = ( - [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)] - ) - logger.debug(message) - rsp = await self.acompletion_text(message, stream=stream) + message = [self._default_system_msg()] + if format_msgs: + message.extend(format_msgs) + message.append(self._user_msg(msg)) + rsp = await self.acompletion_text(message, stream=stream, generator=generator, timeout=timeout) # logger.debug(rsp) return rsp def _extract_assistant_rsp(self, context): return "\n".join([i["content"] for i in context if i["role"] == "assistant"]) - def ask_batch(self, msgs: list) -> str: + def ask_batch(self, msgs: list, timeout=3) -> str: context = [] for msg in msgs: umsg = self._user_msg(msg) context.append(umsg) - rsp = self.completion(context) + rsp = self.completion(context, timeout=timeout) rsp_text = self.get_choice_text(rsp) 
context.append(self._assistant_msg(rsp_text)) return self._extract_assistant_rsp(context) - async def aask_batch(self, msgs: list) -> str: + async def aask_batch(self, msgs: list, timeout=3) -> str: """Sequential questioning""" context = [] for msg in msgs: umsg = self._user_msg(msg) context.append(umsg) - rsp_text = await self.acompletion_text(context) + rsp_text = await self.acompletion_text(context, timeout=timeout) context.append(self._assistant_msg(rsp_text)) return self._extract_assistant_rsp(context) - def ask_code(self, msgs: list[str]) -> str: + def ask_code(self, msgs: list[str], timeout=3) -> str: """FIXME: No code segment filtering has been done here, and all results are actually displayed""" - rsp_text = self.ask_batch(msgs) + rsp_text = self.ask_batch(msgs, timeout=timeout) return rsp_text - async def aask_code(self, msgs: list[str]) -> str: + async def aask_code(self, msgs: list[str], timeout=3) -> str: """FIXME: No code segment filtering has been done here, and all results are actually displayed""" - rsp_text = await self.aask_batch(msgs) + rsp_text = await self.aask_batch(msgs, timeout=timeout) return rsp_text @abstractmethod - def completion(self, messages: list[dict]): + def completion(self, messages: list[dict], timeout=3): """All GPTAPIs are required to provide the standard OpenAI completion interface [ {"role": "system", "content": "You are a helpful assistant."}, @@ -98,7 +102,7 @@ class BaseGPTAPI(BaseChatbot): """ @abstractmethod - async def acompletion(self, messages: list[dict]): + async def acompletion(self, messages: list[dict], timeout=3): """Asynchronous version of completion All GPTAPIs are required to provide the standard OpenAI completion interface [ @@ -109,7 +113,7 @@ class BaseGPTAPI(BaseChatbot): """ @abstractmethod - async def acompletion_text(self, messages: list[dict], stream=False) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: """Asynchronous 
version of completion. Return str. Support stream-print""" def get_choice_text(self, rsp: dict) -> str: @@ -145,7 +149,7 @@ class BaseGPTAPI(BaseChatbot): :return dict: return first function of choice, for exmaple, {'name': 'execute', 'arguments': '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}'} """ - return rsp.get("choices")[0]["message"]["tool_calls"][0]["function"].to_dict() + return rsp.get("choices")[0]["message"]["tool_calls"][0]["function"] def get_choice_function_arguments(self, rsp: dict) -> dict: """Required to provide the first function arguments of choice. @@ -163,3 +167,8 @@ class BaseGPTAPI(BaseChatbot): def messages_to_dict(self, messages): """objects to [{"role": "user", "content": msg}] etc.""" return [i.to_dict() for i in messages] + + @abstractmethod + async def close(self): + """Close connection""" + pass diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index a76151666..bfe85f490 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -6,7 +6,7 @@ import openai from metagpt.config import CONFIG, LLMProviderEnum from metagpt.provider.llm_provider_registry import register_provider -from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter +from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter @register_provider(LLMProviderEnum.FIREWORKS) @@ -16,10 +16,11 @@ class FireWorksGPTAPI(OpenAIGPTAPI): self.llm = openai self.model = CONFIG.fireworks_api_model self.auto_max_tokens = False - self._cost_manager = CostManager() RateLimiter.__init__(self, rpm=self.rpm) def __init_fireworks(self, config: "Config"): - openai.api_key = config.fireworks_api_key - openai.api_base = config.fireworks_api_base + # TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you + # instantiate the client, e.g. 
'OpenAI(api_base=config.fireworks_api_base)' + # openai.api_key = config.fireworks_api_key + # openai.api_base = config.fireworks_api_base self.rpm = int(config.get("RPM", 10)) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 682f7b507..eb91cc32b 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -23,7 +23,7 @@ from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.llm_provider_registry import register_provider -from metagpt.provider.openai_api import CostManager, log_and_reraise +from metagpt.provider.openai_api import log_and_reraise class GeminiGenerativeModel(GenerativeModel): @@ -53,7 +53,6 @@ class GeminiGPTAPI(BaseGPTAPI): self.__init_gemini(CONFIG) self.model = "gemini-pro" # so far only one model self.llm = GeminiGenerativeModel(model_name=self.model) - self._cost_manager = CostManager() def __init_gemini(self, config: CONFIG): genai.configure(api_key=config.gemini_api_key) @@ -76,7 +75,7 @@ class GeminiGPTAPI(BaseGPTAPI): try: prompt_tokens = int(usage.get("prompt_tokens", 0)) completion_tokens = int(usage.get("completion_tokens", 0)) - self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: logger.error(f"google gemini updats costs failed! exp: {e}") diff --git a/metagpt/provider/human_provider.py b/metagpt/provider/human_provider.py index c70a7f1a6..5850dd8dc 100644 --- a/metagpt/provider/human_provider.py +++ b/metagpt/provider/human_provider.py @@ -14,24 +14,35 @@ class HumanProvider(BaseGPTAPI): This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction """ - def ask(self, msg: str) -> str: + def ask(self, msg: str, timeout=3) -> str: logger.info("It's your turn, please type in your response. 
You may also refer to the context below") rsp = input(msg) if rsp in ["exit", "quit"]: exit() return rsp - async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str: - return self.ask(msg) + async def aask( + self, + msg: str, + system_msgs: Optional[list[str]] = None, + format_msgs: Optional[list[dict[str, str]]] = None, + generator: bool = False, + timeout=3, + ) -> str: + return self.ask(msg, timeout=timeout) - def completion(self, messages: list[dict]): + def completion(self, messages: list[dict], timeout=3): """dummy implementation of abstract method in base""" return [] - async def acompletion(self, messages: list[dict]): + async def acompletion(self, messages: list[dict], timeout=3): """dummy implementation of abstract method in base""" return [] - async def acompletion_text(self, messages: list[dict], stream=False) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: """dummy implementation of abstract method in base""" - return [] + return "" + + async def close(self): + """Close connection""" + pass diff --git a/metagpt/provider/metagpt_api.py b/metagpt/provider/metagpt_api.py new file mode 100644 index 000000000..00a42ee2a --- /dev/null +++ b/metagpt/provider/metagpt_api.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/5 23:08 +@Author : alexanderwu +@File : metagpt_api.py +@Desc : MetaGPT LLM provider. 
+""" +from metagpt.config import LLMProviderEnum +from metagpt.provider import OpenAIGPTAPI +from metagpt.provider.llm_provider_registry import register_provider + + +@register_provider(LLMProviderEnum.METAGPT) +class METAGPTAPI(OpenAIGPTAPI): + def __init__(self): + super().__init__() diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index bada0e294..2e8c03ba1 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -2,47 +2,44 @@ # -*- coding: utf-8 -*- # @Desc : self-host open llm model with openai-compatible interface -import openai from metagpt.config import CONFIG, LLMProviderEnum -from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider -from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter +from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter - -class OpenLLMCostManager(CostManager): - """open llm model is self-host, it's free and without cost""" - - def update_cost(self, prompt_tokens, completion_tokens, model): - """ - Update the total cost, prompt tokens, and completion tokens. - - Args: - prompt_tokens (int): The number of tokens used in the prompt. - completion_tokens (int): The number of tokens used in the completion. - model (str): The model used for the API call. - """ - self.total_prompt_tokens += prompt_tokens - self.total_completion_tokens += completion_tokens - - logger.info( - f"Max budget: ${CONFIG.max_budget:.3f} | " - f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" - ) - CONFIG.total_cost = self.total_cost +# class OpenLLMCostManager(CostManager): +# """open llm model is self-host, it's free and without cost""" +# +# def update_cost(self, prompt_tokens, completion_tokens, model): +# """ +# Update the total cost, prompt tokens, and completion tokens. +# +# Args: +# prompt_tokens (int): The number of tokens used in the prompt. 
+# completion_tokens (int): The number of tokens used in the completion. +# model (str): The model used for the API call. +# """ +# self.total_prompt_tokens += prompt_tokens +# self.total_completion_tokens += completion_tokens +# +# logger.info( +# f"Max budget: ${CONFIG.max_budget:.3f} | " +# f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" +# ) +# CONFIG.total_cost = self.total_cost @register_provider(LLMProviderEnum.OPEN_LLM) class OpenLLMGPTAPI(OpenAIGPTAPI): def __init__(self): self.__init_openllm(CONFIG) - self.llm = openai self.model = CONFIG.open_llm_api_model self.auto_max_tokens = False - self._cost_manager = OpenLLMCostManager() RateLimiter.__init__(self, rpm=self.rpm) def __init_openllm(self, config: "Config"): - openai.api_key = "sk-xx" # self-host api doesn't need api-key, use the default value - openai.api_base = config.open_llm_api_base + # TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you + # instantiate the client, e.g. 'OpenAI(api_base=config.open_llm_api_base)' + # openai.api_key = "sk-xx" # self-host api doesn't need api-key, use the default value + # openai.api_base = config.open_llm_api_base self.rpm = int(config.get("RPM", 10)) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 8fd8959b2..ca130ce15 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -3,20 +3,19 @@ @Time : 2023/5/5 23:08 @Author : alexanderwu @File : openai.py +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation; + Change cost control from global to company level. +@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout. +@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x. 
""" + import asyncio import json import time -from typing import NamedTuple, Union +from typing import List, Union -from openai import ( - APIConnectionError, - AsyncAzureOpenAI, - AsyncOpenAI, - AsyncStream, - AzureOpenAI, - OpenAI, -) +import openai +from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI, RateLimitError from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper from openai.types import CompletionUsage from openai.types.chat import ChatCompletion, ChatCompletionChunk @@ -29,15 +28,15 @@ from tenacity import ( ) from metagpt.config import CONFIG, Config, LLMProviderEnum +from metagpt.const import DEFAULT_MAX_TOKENS from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message +from metagpt.utils.cost_manager import Costs from metagpt.utils.exceptions import handle_exception -from metagpt.utils.singleton import Singleton from metagpt.utils.token_counter import ( - TOKEN_COSTS, count_message_tokens, count_string_tokens, get_max_completion_tokens, @@ -69,75 +68,6 @@ class RateLimiter: self.last_call_time = time.time() -class Costs(NamedTuple): - total_prompt_tokens: int - total_completion_tokens: int - total_cost: float - total_budget: float - - -class CostManager(metaclass=Singleton): - """计算使用接口的开销""" - - def __init__(self): - self.total_prompt_tokens = 0 - self.total_completion_tokens = 0 - self.total_cost = 0 - self.total_budget = 0 - - def update_cost(self, prompt_tokens, completion_tokens, model): - """ - Update the total cost, prompt tokens, and completion tokens. - - Args: - prompt_tokens (int): The number of tokens used in the prompt. - completion_tokens (int): The number of tokens used in the completion. - model (str): The model used for the API call. 
- """ - self.total_prompt_tokens += prompt_tokens - self.total_completion_tokens += completion_tokens - cost = ( - prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"] - ) / 1000 - self.total_cost += cost - logger.info( - f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | " - f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" - ) - CONFIG.total_cost = self.total_cost - - def get_total_prompt_tokens(self): - """ - Get the total number of prompt tokens. - - Returns: - int: The total number of prompt tokens. - """ - return self.total_prompt_tokens - - def get_total_completion_tokens(self): - """ - Get the total number of completion tokens. - - Returns: - int: The total number of completion tokens. - """ - return self.total_completion_tokens - - def get_total_cost(self): - """ - Get the total cost of API calls. - - Returns: - float: The total cost of API calls. - """ - return self.total_cost - - def get_costs(self) -> Costs: - """Get all costs""" - return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget) - - def log_and_reraise(retry_state): logger.error(f"Retry attempts exhausted. 
Last exception: {retry_state.outcome.exception()}") logger.warning( @@ -159,35 +89,29 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): self.config: Config = CONFIG self.__init_openai() self.auto_max_tokens = False - self._cost_manager = CostManager() + # https://github.com/openai/openai-python#async-usage + self._client = AsyncOpenAI(api_key=CONFIG.openai_api_key, base_url=CONFIG.openai_api_base) RateLimiter.__init__(self, rpm=self.rpm) + async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: + kwargs = self._cons_kwargs(messages, timeout=timeout) + response = await self._client.chat.completions.create(**kwargs, stream=True) + # iterate through the stream of events + async for chunk in response: + chunk_message = chunk.choices[0].delta.content or "" # extract the message + yield chunk_message + def __init_openai(self): - self.is_azure = self.config.openai_api_type == "azure" - self.model = self.config.deployment_name if self.is_azure else self.config.openai_api_model self.rpm = int(self.config.get("RPM", 10)) self._make_client() def _make_client(self): kwargs, async_kwargs = self._make_client_kwargs() - - if self.is_azure: - self.client = AzureOpenAI(**kwargs) - self.async_client = AsyncAzureOpenAI(**async_kwargs) - else: - self.client = OpenAI(**kwargs) - self.async_client = AsyncOpenAI(**async_kwargs) + self.client = OpenAI(**kwargs) + self.async_client = AsyncOpenAI(**async_kwargs) def _make_client_kwargs(self) -> (dict, dict): - if self.is_azure: - kwargs = dict( - api_key=self.config.openai_api_key, - api_version=self.config.openai_api_version, - azure_endpoint=self.config.openai_base_url, - ) - else: - kwargs = dict(api_key=self.config.openai_api_key, base_url=self.config.openai_base_url) - + kwargs = dict(api_key=self.config.openai_api_key, base_url=self.config.openai_base_url) async_kwargs = kwargs.copy() # to use proxy, openai v1 needs http_client @@ -230,36 +154,37 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): 
self._update_costs(usage) return full_reply_content - def _cons_kwargs(self, messages: list[dict], **configs) -> dict: + def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict: kwargs = { "messages": messages, "max_tokens": self.get_max_tokens(messages), "n": 1, "stop": None, "temperature": 0.3, - "timeout": 3, - "model": self.model, + "model": self.config.openai_api_model, } if configs: kwargs.update(configs) + try: + default_timeout = int(CONFIG.TIMEOUT) if CONFIG.TIMEOUT else 0 + except ValueError: + default_timeout = 0 + kwargs["timeout"] = max(default_timeout, timeout) return kwargs - async def _achat_completion(self, messages: list[dict]) -> ChatCompletion: - rsp: ChatCompletion = await self.async_client.chat.completions.create(**self._cons_kwargs(messages)) + async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion: + kwargs = self._cons_kwargs(messages, timeout=timeout) + rsp: ChatCompletion = await self._client.chat.completions.create(**kwargs) self._update_costs(rsp.usage) return rsp - def _chat_completion(self, messages: list[dict]) -> ChatCompletion: - rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages)) - self._update_costs(rsp.usage) - return rsp + def completion(self, messages: list[dict], timeout=3) -> ChatCompletion: + loop = self.get_event_loop() + return loop.run_until_complete(self.acompletion(messages, timeout=timeout)) - def completion(self, messages: list[dict]) -> ChatCompletion: - return self._chat_completion(messages) - - async def acompletion(self, messages: list[dict]) -> ChatCompletion: - return await self._achat_completion(messages) + async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion: + return await self._achat_completion(messages, timeout=timeout) @retry( wait=wait_random_exponential(min=1, max=60), @@ -268,14 +193,34 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): retry=retry_if_exception_type(APIConnectionError), 
retry_error_callback=log_and_reraise, ) - async def acompletion_text(self, messages: list[dict], stream=False) -> str: + @retry( + wait=wait_random_exponential(min=1, max=60), + stop=stop_after_attempt(6), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(RateLimitError), + reraise=True, + ) + async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: """when streaming, print each token in place.""" if stream: - return await self._achat_completion_stream(messages) - rsp = await self._achat_completion(messages) + resp = self._achat_completion_stream(messages, timeout=timeout) + if generator: + return resp + + collected_messages = [] + async for i in resp: + print(i, end="") + collected_messages.append(i) + + full_reply_content = "".join(collected_messages) + usage = self._calc_usage(messages, full_reply_content) + self._update_costs(usage) + return full_reply_content + + rsp = await self._achat_completion(messages, timeout=timeout) return self.get_choice_text(rsp) - def _func_configs(self, messages: list[dict], **kwargs) -> dict: + def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict: """ Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create """ @@ -286,17 +231,15 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): } kwargs.update(configs) - return self._cons_kwargs(messages, **kwargs) + return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs) - def _chat_completion_function(self, messages: list[dict], **kwargs) -> ChatCompletion: - rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs)) - self._update_costs(rsp.usage) - return rsp + def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> ChatCompletion: + loop = self.get_event_loop() + return 
loop.run_until_complete(self._achat_completion_function(messages=messages, timeout=timeout, **kwargs)) - async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> ChatCompletion: - rsp: ChatCompletion = await self.async_client.chat.completions.create( - **self._func_configs(messages, **chat_configs) - ) + async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion: + kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs) + rsp: ChatCompletion = await self._client.chat.completions.create(**kwargs) self._update_costs(rsp.usage) return rsp @@ -349,8 +292,27 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): >>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} """ messages = self._process_message(messages) - rsp = await self._achat_completion_function(messages, **kwargs) - return self.get_choice_function_arguments(rsp) + try: + rsp = await self._achat_completion_function(messages, **kwargs) + return self.get_choice_function_arguments(rsp) + except openai.NotFoundError as e: + logger.error(f"API TYPE:{CONFIG.openai_api_type}, err:{e}") + raise e + + def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage: + if CONFIG.calc_usage: + try: + prompt_tokens = count_message_tokens(messages, self.model) + completion_tokens = count_string_tokens(rsp, self.model) + usage = CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + return usage + except Exception as e: + logger.error(f"{self.model} usage calculation failed!", e) + return CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict: """Required to provide the first function arguments of choice. 
@@ -380,7 +342,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return usage - async def acompletion_batch(self, batch: list[list[dict]]) -> list[ChatCompletion]: + async def acompletion_batch(self, batch: list[list[dict]], timeout=3) -> list[ChatCompletion]: """Return full JSON""" split_batches = self.split_batches(batch) all_results = [] @@ -389,16 +351,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): logger.info(small_batch) await self.wait_if_needed(len(small_batch)) - future = [self.acompletion(prompt) for prompt in small_batch] + future = [self.acompletion(prompt, timeout=timeout) for prompt in small_batch] results = await asyncio.gather(*future) logger.info(results) all_results.extend(results) return all_results - async def acompletion_batch_text(self, batch: list[list[dict]]) -> list[str]: + async def acompletion_batch_text(self, batch: list[list[dict]], timeout=3) -> list[str]: """Only return plain text""" - raw_results = await self.acompletion_batch(batch) + raw_results = await self.acompletion_batch(batch, timeout=timeout) results = [] for idx, raw_result in enumerate(raw_results, start=1): result = self.get_choice_text(raw_result) @@ -409,18 +371,112 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): def _update_costs(self, usage: CompletionUsage): if CONFIG.calc_usage and usage: try: - self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) + CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) except Exception as e: logger.error("updating costs failed!", e) def get_costs(self) -> Costs: - return self._cost_manager.get_costs() + return CONFIG.cost_manager.get_costs() def get_max_tokens(self, messages: list[dict]): if not self.auto_max_tokens: return CONFIG.max_tokens_rsp return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp) + def moderation(self, content: Union[str, list[str]]): + loop = self.get_event_loop() + 
loop.run_until_complete(self.amoderation(content=content)) + @handle_exception async def amoderation(self, content: Union[str, list[str]]): - return await self.async_client.moderations.create(input=content) + return await self._client.moderations.create(input=content) + + async def close(self): + """Close connection""" + if not self._client: + return + await self._client.close() + self._client = None + + @staticmethod + def get_event_loop(): + try: + return asyncio.get_event_loop() + except RuntimeError as e: + if "There is no current event loop in thread" in str(e): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + else: + raise e + + async def summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs) -> str: + max_token_count = DEFAULT_MAX_TOKENS + max_count = 100 + text_length = len(text) + if limit > 0 and text_length < limit: + return text + summary = "" + while max_count > 0: + if text_length < max_token_count: + summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language) + break + + padding_size = 20 if max_token_count > 20 else 0 + text_windows = self.split_texts(text, window_size=max_token_count - padding_size) + part_max_words = min(int(max_words / len(text_windows)) + 1, 100) + summaries = [] + for ws in text_windows: + response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language) + summaries.append(response) + if len(summaries) == 1: + summary = summaries[0] + break + + # Merged and retry + text = "\n".join(summaries) + text_length = len(text) + + max_count -= 1 # safeguard + return summary + + async def _get_summary(self, text: str, max_words=20, keep_language: bool = False): + """Generate text summary""" + if len(text) < max_words: + return text + if keep_language: + command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly." 
+ else: + command = f"Translate the above content into a summary of less than {max_words} words." + msg = text + "\n\n" + command + logger.debug(f"summary ask:{msg}") + response = await self.aask(msg=msg, system_msgs=[]) + logger.debug(f"summary rsp: {response}") + return response + + @staticmethod + def split_texts(text: str, window_size) -> List[str]: + """Splitting long text into sliding windows text""" + if window_size <= 0: + window_size = DEFAULT_TOKEN_SIZE + total_len = len(text) + if total_len <= window_size: + return [text] + + padding_size = 20 if window_size > 20 else 0 + windows = [] + idx = 0 + data_len = window_size - padding_size + while idx < total_len: + if window_size + idx > total_len: # 不足一个滑窗 + windows.append(text[idx:]) + break + # 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....] + # window_size=3, padding_size=1: + # [1, 2, 3], [3, 4, 5], [5, 6, 7], .... + # idx=2, | idx=5 | idx=8 | ... + w = text[idx : idx + window_size] + windows.append(w) + idx += data_len + + return windows diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 60d9a0777..54f0ddcbb 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -5,7 +5,6 @@ import json from enum import Enum -import openai import zhipuai from requests import ConnectionError from tenacity import ( @@ -20,7 +19,7 @@ from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.llm_provider_registry import register_provider -from metagpt.provider.openai_api import CostManager, log_and_reraise +from metagpt.provider.openai_api import log_and_reraise from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI @@ -44,12 +43,12 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): self.__init_zhipuai(CONFIG) self.llm = ZhiPuModelAPI self.model = "chatglm_turbo" # so far only one model, just use it - self._cost_manager = CostManager() def 
__init_zhipuai(self, config: CONFIG): assert config.zhipuai_api_key zhipuai.api_key = config.zhipuai_api_key - openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used. + # due to use openai sdk, set the api_key but it will't be used. + # openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used. def _const_kwargs(self, messages: list[dict]) -> dict: kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3} @@ -61,7 +60,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): try: prompt_tokens = int(usage.get("prompt_tokens", 0)) completion_tokens = int(usage.get("completion_tokens", 0)) - self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: logger.error(f"zhipuai updats costs failed! exp: {e}") diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py index 3524a5bce..9f3a1bac4 100644 --- a/metagpt/repo_parser.py +++ b/metagpt/repo_parser.py @@ -5,19 +5,49 @@ @Author : alexanderwu @File : repo_parser.py """ +from __future__ import annotations + import ast import json +import re +import subprocess from pathlib import Path from pprint import pformat +from typing import Dict, List, Optional, Tuple +import aiofiles import pandas as pd from pydantic import BaseModel, Field from metagpt.config import CONFIG from metagpt.logs import logger +from metagpt.utils.common import any_to_str from metagpt.utils.exceptions import handle_exception +class RepoFileInfo(BaseModel): + file: str + classes: List = Field(default_factory=list) + functions: List = Field(default_factory=list) + globals: List = Field(default_factory=list) + page_info: List = Field(default_factory=list) + + +class CodeBlockInfo(BaseModel): + lineno: int + end_lineno: int + type_name: str + tokens: List = Field(default_factory=list) + properties: Dict = Field(default_factory=dict) + + +class ClassInfo(BaseModel): 
+ name: str + package: Optional[str] = None + attributes: Dict[str, str] = Field(default_factory=dict) + methods: Dict[str, str] = Field(default_factory=dict) + + class RepoParser(BaseModel): base_directory: Path = Field(default=None) @@ -27,31 +57,32 @@ class RepoParser(BaseModel): """Parse a Python file in the repository.""" return ast.parse(file_path.read_text()).body - def extract_class_and_function_info(self, tree, file_path): + def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo: """Extract class, function, and global variable information from the AST.""" - file_info = { - "file": str(file_path.relative_to(self.base_directory)), - "classes": [], - "functions": [], - "globals": [], - } - + file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory))) for node in tree: + info = RepoParser.node_to_str(node) + file_info.page_info.append(info) if isinstance(node, ast.ClassDef): class_methods = [m.name for m in node.body if is_func(m)] - file_info["classes"].append({"name": node.name, "methods": class_methods}) + file_info.classes.append({"name": node.name, "methods": class_methods}) elif is_func(node): - file_info["functions"].append(node.name) + file_info.functions.append(node.name) elif isinstance(node, (ast.Assign, ast.AnnAssign)): for target in node.targets if isinstance(node, ast.Assign) else [node.target]: if isinstance(target, ast.Name): - file_info["globals"].append(target.id) + file_info.globals.append(target.id) return file_info - def generate_symbols(self): + def generate_symbols(self) -> List[RepoFileInfo]: files_classes = [] directory = self.base_directory - for path in directory.rglob("*.py"): + + matching_files = [] + extensions = ["*.py", "*.js"] + for ext in extensions: + matching_files += directory.rglob(ext) + for path in matching_files: tree = self._parse_file(path) file_info = self.extract_class_and_function_info(tree, path) files_classes.append(file_info) @@ -79,6 +110,215 @@ class 
RepoParser(BaseModel): elif mode == "csv": self.generate_dataframe_structure(output_path) + @staticmethod + def node_to_str(node) -> (int, int, str, str | Tuple): + if any_to_str(node) == any_to_str(ast.Expr): + return CodeBlockInfo( + lineno=node.lineno, + end_lineno=node.end_lineno, + type_name=any_to_str(node), + tokens=RepoParser._parse_expr(node), + ) + mappings = { + any_to_str(ast.Import): lambda x: [RepoParser._parse_name(n) for n in x.names], + any_to_str(ast.Assign): RepoParser._parse_assign, + any_to_str(ast.ClassDef): lambda x: x.name, + any_to_str(ast.FunctionDef): lambda x: x.name, + any_to_str(ast.ImportFrom): lambda x: { + "module": x.module, + "names": [RepoParser._parse_name(n) for n in x.names], + }, + any_to_str(ast.If): RepoParser._parse_if, + any_to_str(ast.AsyncFunctionDef): lambda x: x.name, + } + func = mappings.get(any_to_str(node)) + if func: + code_block = CodeBlockInfo(lineno=node.lineno, end_lineno=node.end_lineno, type_name=any_to_str(node)) + val = func(node) + if isinstance(val, dict): + code_block.properties = val + elif isinstance(val, list): + code_block.tokens = val + elif isinstance(val, str): + code_block.tokens = [val] + else: + raise NotImplementedError(f"Not implement:{val}") + return code_block + raise NotImplementedError(f"Not implement code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}") + + @staticmethod + def _parse_expr(node) -> List: + funcs = { + any_to_str(ast.Constant): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)], + any_to_str(ast.Call): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value.func)], + } + func = funcs.get(any_to_str(node.value)) + if func: + return func(node) + raise NotImplementedError(f"Not implement: {node.value}") + + @staticmethod + def _parse_name(n): + if n.asname: + return f"{n.name} as {n.asname}" + return n.name + + @staticmethod + def _parse_if(n): + tokens = [RepoParser._parse_variable(n.test.left)] + for item in n.test.comparators: + 
tokens.append(RepoParser._parse_variable(item)) + return tokens + + @staticmethod + def _parse_variable(node): + funcs = { + any_to_str(ast.Constant): lambda x: x.value, + any_to_str(ast.Name): lambda x: x.id, + any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}", + } + func = funcs.get(any_to_str(node)) + if not func: + raise NotImplementedError(f"Not implement:{node}") + return func(node) + + @staticmethod + def _parse_assign(node): + return [RepoParser._parse_variable(t) for t in node.targets] + + async def rebuild_class_views(self, path: str | Path = None): + if not path: + path = self.base_directory + path = Path(path) + if not path.exists(): + return + command = f"pyreverse {str(path)} -o dot" + result = subprocess.run(command, shell=True, check=True, cwd=str(path)) + if result.returncode != 0: + raise ValueError(f"{result}") + class_view_pathname = path / "classes.dot" + class_views = await self._parse_classes(class_view_pathname) + packages_pathname = path / "packages.dot" + class_views = RepoParser._repair_namespaces(class_views=class_views, path=path) + class_view_pathname.unlink(missing_ok=True) + packages_pathname.unlink(missing_ok=True) + return class_views + + async def _parse_classes(self, class_view_pathname): + class_views = [] + if not class_view_pathname.exists(): + return class_views + async with aiofiles.open(str(class_view_pathname), mode="r") as reader: + lines = await reader.readlines() + for line in lines: + package_name, info = RepoParser._split_class_line(line) + if not package_name: + continue + class_name, members, functions = re.split(r"(?" 
+ if begin_flag not in left or end_flag not in left: + return None, None + bix = left.find(begin_flag) + eix = left.rfind(end_flag) + info = left[bix + len(begin_flag) : eix] + info = re.sub(r"]*>", "\n", info) + return class_name, info + + @staticmethod + def _create_path_mapping(path: str | Path) -> Dict[str, str]: + mappings = { + str(path).replace("/", "."): str(path), + } + files = [] + try: + directory_path = Path(path) + if not directory_path.exists(): + return mappings + for file_path in directory_path.iterdir(): + if file_path.is_file(): + files.append(str(file_path)) + else: + subfolder_files = RepoParser._create_path_mapping(path=file_path) + mappings.update(subfolder_files) + except Exception as e: + logger.error(f"Error: {e}") + for f in files: + mappings[str(Path(f).with_suffix("")).replace("/", ".")] = str(f) + + return mappings + + @staticmethod + def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]: + if not class_views: + return [] + c = class_views[0] + full_key = str(path).lstrip("/").replace("/", ".") + root_namespace = RepoParser._find_root(full_key, c.package) + root_path = root_namespace.replace(".", "/") + + mappings = RepoParser._create_path_mapping(path=path) + new_mappings = {} + ix_root_namespace = len(root_namespace) + ix_root_path = len(root_path) + for k, v in mappings.items(): + nk = k[ix_root_namespace:] + nv = v[ix_root_path:] + new_mappings[nk] = nv + + for c in class_views: + c.package = RepoParser._repair_ns(c.package, new_mappings) + return class_views + + @staticmethod + def _repair_ns(package, mappings): + file_ns = package + while file_ns != "": + if file_ns not in mappings: + ix = file_ns.rfind(".") + file_ns = file_ns[0:ix] + continue + break + internal_ns = package[ix + 1 :] + ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":") + return ns + + @staticmethod + def _find_root(full_key, package) -> str: + left = full_key + while left != "": + if left in package: + break + if "." 
not in left: + break + ix = left.find(".") + left = left[ix + 1 :] + ix = full_key.rfind(left) + return "." + full_key[0:ix] + def is_func(node): return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) diff --git a/metagpt/roles/assistant.py b/metagpt/roles/assistant.py new file mode 100644 index 000000000..84ca07c9a --- /dev/null +++ b/metagpt/roles/assistant.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/7 +@Author : mashenquan +@File : assistant.py +@Desc : I am attempting to incorporate certain symbol concepts from UML into MetaGPT, enabling it to have the + ability to freely construct flows through symbol concatenation. Simultaneously, I am also striving to + make these symbols configurable and standardized, making the process of building flows more convenient. + For more about `fork` node in activity diagrams, see: `https://www.uml-diagrams.org/activity-diagrams.html` + This file defines a `fork` style meta role capable of generating arbitrary roles at runtime based on a + configuration file. +@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false + indicates that further reasoning cannot continue. 
+ +""" +import asyncio +from pathlib import Path + +from metagpt.actions import ActionOutput +from metagpt.actions.skill_action import ArgumentsParingAction, SkillAction +from metagpt.actions.talk_action import TalkAction +from metagpt.config import CONFIG +from metagpt.learn.skill_loader import SkillLoader +from metagpt.logs import logger +from metagpt.memory.brain_memory import BrainMemory, MessageType +from metagpt.roles import Role +from metagpt.schema import Message + + +class Assistant(Role): + """Assistant for solving common issues.""" + + def __init__( + self, + name="Lily", + profile="An assistant", + goal="Help to solve problem", + constraints="Talk in {language}", + desc="", + *args, + **kwargs, + ): + super(Assistant, self).__init__( + name=name, profile=profile, goal=goal, constraints=constraints, desc=desc, *args, **kwargs + ) + brain_memory = CONFIG.BRAIN_MEMORY + self.memory = BrainMemory(**brain_memory) if brain_memory else BrainMemory(llm_type=CONFIG.LLM_TYPE) + skill_path = Path(CONFIG.SKILL_PATH) if CONFIG.SKILL_PATH else None + self.skills = SkillLoader(skill_yaml_file_name=skill_path) + + async def think(self) -> bool: + """Everything will be done part by part.""" + last_talk = await self.refine_memory() + if not last_talk: + return False + prompt = "" + skills = self.skills.get_skill_list() + for desc, name in skills.items(): + prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n" + prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. 
For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n' + prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n" + rsp = await self._llm.aask(prompt, []) + logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n") + return await self._plan(rsp, last_talk=last_talk) + + async def act(self) -> ActionOutput: + result = await self._rc.todo.run(**CONFIG.options) + if not result: + return None + if isinstance(result, str): + msg = Message(content=result) + output = ActionOutput(content=result) + else: + msg = Message( + content=result.content, instruct_content=result.instruct_content, cause_by=type(self._rc.todo) + ) + output = result + self.memory.add_answer(msg) + return output + + async def talk(self, text): + self.memory.add_talk(Message(content=text)) + + async def _plan(self, rsp: str, **kwargs) -> bool: + skill, text = BrainMemory.extract_info(input_string=rsp) + handlers = { + MessageType.Talk.value: self.talk_handler, + MessageType.Skill.value: self.skill_handler, + } + handler = handlers.get(skill, self.talk_handler) + return await handler(text, **kwargs) + + async def talk_handler(self, text, **kwargs) -> bool: + history = self.memory.history_text + text = kwargs.get("last_talk") or text + action = TalkAction( + talk=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self._llm, **kwargs + ) + self.add_to_do(action) + return True + + async def skill_handler(self, text, **kwargs) -> bool: + last_talk = kwargs.get("last_talk") + skill = self.skills.get_skill(text) + if not skill: + logger.info(f"skill not found: {text}") + return await self.talk_handler(text=last_talk, **kwargs) + action = ArgumentsParingAction(skill=skill, llm=self._llm, **kwargs) + await action.run(**kwargs) + if action.args is None: + return await self.talk_handler(text=last_talk, **kwargs) + action = SkillAction(skill=skill, args=action.args, llm=self._llm, name=skill.name, desc=skill.description) + self.add_to_do(action) + return 
True + + async def refine_memory(self) -> str: + last_talk = self.memory.pop_last_talk() + if last_talk is None: # No user feedback, unsure if past conversation is finished. + return None + if not self.memory.is_history_available: + return last_talk + history_summary = await self.memory.summarize(max_words=800, keep_language=True, llm=self._llm) + if last_talk and await self.memory.is_related(text1=last_talk, text2=history_summary, llm=self._llm): + # Merge relevant content. + last_talk = await self.memory.rewrite(sentence=last_talk, context=history_summary, llm=self._llm) + return last_talk + + return last_talk + + def get_memory(self) -> str: + return self.memory.json() + + def load_memory(self, jsn): + try: + self.memory = BrainMemory(**jsn) + except Exception as e: + logger.exception(f"load error:{e}, data:{jsn}") + + +async def main(): + topic = "what's apple" + role = Assistant(language="Chinese") + await role.talk(topic) + while True: + has_action = await role.think() + if not has_action: + break + msg = await role.act() + logger.info(msg) + # Retrieve user terminal input. 
+ logger.info("Enter prompt") + talk = input("You: ") + await role.talk(talk) + + +if __name__ == "__main__": + CONFIG.language = "Chinese" + asyncio.run(main()) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index e0234f378..12deaa5bb 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -43,7 +43,7 @@ from metagpt.schema import ( Documents, Message, ) -from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set IS_PASS_PROMPT = """ {context} @@ -85,6 +85,9 @@ class Engineer(Role): self._init_actions([WriteCode]) self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug]) + self.code_todos = [] + self.summarize_todos = [] + self._next_todo = any_to_name(WriteCode) @staticmethod def _parse_tasks(task_msg: Document) -> list[str]: @@ -128,8 +131,10 @@ class Engineer(Role): if self._rc.todo is None: return None if isinstance(self._rc.todo, WriteCode): + self._next_todo = any_to_name(SummarizeCode) return await self._act_write_code() if isinstance(self._rc.todo, SummarizeCode): + self._next_todo = any_to_name(WriteCode) return await self._act_summarize() return None @@ -301,3 +306,8 @@ class Engineer(Role): self.summarize_todos.append(SummarizeCode(context=ctx, llm=self._llm)) if self.summarize_todos: self._rc.todo = self.summarize_todos[0] + + @property + def todo(self) -> str: + """AgentStore uses this attribute to display to the user what actions the current role should take.""" + return self._next_todo diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index c794ad2eb..0f18c9cb2 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -12,6 +12,7 @@ from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.config import CONFIG from metagpt.roles.role import Role +from metagpt.utils.common import 
any_to_name class ProductManager(Role): @@ -35,6 +36,7 @@ class ProductManager(Role): self._init_actions([PrepareDocuments, WritePRD]) self._watch([UserRequirement, PrepareDocuments]) + self._todo = any_to_name(PrepareDocuments) async def _think(self) -> None: """Decide what to do""" @@ -42,7 +44,13 @@ class ProductManager(Role): self._set_state(1) else: self._set_state(0) + self._todo = any_to_name(WritePRD) return self._rc.todo async def _observe(self, ignore_memory=False) -> int: - return await super()._observe(ignore_memory=True) + return await super(ProductManager, self)._observe(ignore_memory=True) + + @property + def todo(self) -> str: + """AgentStore uses this attribute to display to the user what actions the current role should take.""" + return self._todo diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py index 162d72b9b..fc6afa1fd 100644 --- a/metagpt/roles/researcher.py +++ b/metagpt/roles/researcher.py @@ -1,10 +1,10 @@ #!/usr/bin/env python """ +@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false indicates that further reasoning cannot continue. @Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, change the data type of the `cause_by` value in the `Message` to a string to support the new message distribution feature. 
""" - import asyncio from pydantic import BaseModel @@ -41,6 +41,17 @@ class Researcher(Role): if language not in ("en-us", "zh-cn"): logger.warning(f"The language `{language}` has not been tested, it may not work.") + async def _think(self) -> bool: + if self._rc.todo is None: + self._set_state(0) + return True + + if self._rc.state + 1 < len(self._states): + self._set_state(self._rc.state + 1) + else: + self._rc.todo = None + return False + async def _act(self) -> Message: logger.info(f"{self._setting}: ready to {self._rc.todo}") todo = self._rc.todo diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 9d0898b71..528e7d72d 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -4,6 +4,7 @@ @Time : 2023/5/11 14:42 @Author : alexanderwu @File : role.py +@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false indicates that further reasoning cannot continue. @Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116: 1. Merge the `recv` functionality into the `_observe` function. Future message reading operations will be consolidated within the `_observe` function. 
@@ -38,6 +39,7 @@ from metagpt.memory import Memory from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Message, MessageQueue from metagpt.utils.common import ( + any_to_name, any_to_str, import_class, read_json_file, @@ -118,7 +120,7 @@ class RoleContext(BaseModel): @property def important_memory(self) -> list[Message]: - """Get the information corresponding to the watched actions""" + """Retrieve information corresponding to the attention action.""" return self.memory.get_by_actions(self.watch) @property @@ -208,8 +210,7 @@ class Role(BaseModel): if "actions" in kwargs: self._init_actions(kwargs["actions"]) - if "watch" in kwargs: - self._watch(kwargs["watch"]) + self._watch(kwargs.get("watch") or [UserRequirement]) def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) @@ -315,6 +316,9 @@ class Role(BaseModel): # check RoleContext after adding watch actions self._rc.check(self._role_id) + def is_watch(self, caused_by: str): + return caused_by in self._rc.watch + def subscribe(self, tags: Set[str]): """Used to receive Messages with certain tags from the environment. Message will be put into personal message buffer to be further processed in _observe. By default, a Role subscribes Messages with a tag of its own name @@ -340,7 +344,12 @@ class Role(BaseModel): @property def subscription(self) -> Set: """The labels for messages to be consumed by the Role object.""" - return self._subscription + return set(self._subscription) + + @property + def action_count(self): + """Return number of action""" + return len(self._actions) def _get_prefix(self): """Get the role prefix""" @@ -357,16 +366,18 @@ class Role(BaseModel): prefix += env_desc return prefix - async def _think(self) -> None: - """Think about what to do and decide on the next action""" + async def _think(self) -> bool: + """Consider what to do and decide on the next course of action. 
Return false if nothing can be done.""" if len(self._actions) == 1: # If there is only one action, then only this one can be performed self._set_state(0) - return + + return True + if self.recovered and self._rc.state >= 0: self._set_state(self._rc.state) # action to run from recovered state self.recovered = False # avoid max_react_loop out of work - return + return True prompt = self._get_prefix() prompt += STATE_TEMPLATE.format( @@ -388,6 +399,7 @@ class Role(BaseModel): if next_state == -1: logger.info(f"End actions with {next_state=}") self._set_state(next_state) + return True async def _act(self) -> Message: logger.info(f"{self._setting}: ready to {self._rc.todo}") @@ -498,23 +510,6 @@ class Role(BaseModel): self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None return rsp - # # Replaced by run() - # def recv(self, message: Message) -> None: - # """add message to history.""" - # # self._history += f"\n{message}" - # # self._context = self._history - # if message in self._rc.memory.get(): - # return - # self._rc.memory.add(message) - - # # Replaced by run() - # async def handle(self, message: Message) -> Message: - # """Receive information and reply with actions""" - # # logger.debug(f"{self.name=}, {self.profile=}, {message.role=}") - # self.recv(message) - # - # return await self._react() - def get_memories(self, k=0) -> list[Message]: """A wrapper to return the most recent k memories of this role, return all when k=0""" return self._rc.memory.get(k=k) @@ -551,3 +546,20 @@ class Role(BaseModel): def is_idle(self) -> bool: """If true, all actions have been executed.""" return not self._rc.news and not self._rc.todo and self._rc.msg_buffer.empty() + + async def think(self) -> Action: + """The exported `think` function""" + await self._think() + return self._rc.todo + + async def act(self) -> ActionOutput: + """The exported `act` function""" + msg = await self._act() + return ActionOutput(content=msg.content, 
instruct_content=msg.instruct_content) + + @property + def todo(self) -> str: + """AgentStore uses this attribute to display to the user what actions the current role should take.""" + if self._actions: + return any_to_name(self._actions[0]) + return "" diff --git a/metagpt/roles/teacher.py b/metagpt/roles/teacher.py new file mode 100644 index 000000000..031ce94c9 --- /dev/null +++ b/metagpt/roles/teacher.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/7/27 +@Author : mashenquan +@File : teacher.py +@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false indicates that further reasoning cannot continue. + +""" + + +import re + +import aiofiles + +from metagpt.actions.write_teaching_plan import ( + TeachingPlanRequirement, + WriteTeachingPlanPart, +) +from metagpt.config import CONFIG +from metagpt.logs import logger +from metagpt.roles import Role +from metagpt.schema import Message + + +class Teacher(Role): + """Support configurable teacher roles, + with native and teaching languages being replaceable through configurations.""" + + def __init__( + self, + name="Lily", + profile="{teaching_language} Teacher", + goal="writing a {language} teaching plan part by part", + constraints="writing in {language}", + desc="", + *args, + **kwargs, + ): + super().__init__(name=name, profile=profile, goal=goal, constraints=constraints, desc=desc, *args, **kwargs) + actions = [] + for topic in WriteTeachingPlanPart.TOPICS: + act = WriteTeachingPlanPart(topic=topic, llm=self._llm) + actions.append(act) + self._init_actions(actions) + self._watch({TeachingPlanRequirement}) + + async def _think(self) -> bool: + """Everything will be done part by part.""" + if self._rc.todo is None: + self._set_state(0) + return True + + if self._rc.state + 1 < len(self._states): + self._set_state(self._rc.state + 1) + return True + + self._rc.todo = None + return False + + async def _react(self) -> 
Message: + ret = Message(content="") + while True: + await self._think() + if self._rc.todo is None: + break + logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}") + msg = await self._act() + if ret.content != "": + ret.content += "\n\n\n" + ret.content += msg.content + logger.info(ret.content) + await self.save(ret.content) + return ret + + async def save(self, content): + """Save teaching plan""" + filename = Teacher.new_file_name(self.course_title) + pathname = CONFIG.workspace / "teaching_plan" + pathname.mkdir(exist_ok=True) + pathname = pathname / filename + try: + async with aiofiles.open(str(pathname), mode="w", encoding="utf-8") as writer: + await writer.write(content) + except Exception as e: + logger.error(f"Save failed:{e}") + logger.info(f"Save to:{pathname}") + + @staticmethod + def new_file_name(lesson_title, ext=".md"): + """Create a related file name based on `lesson_title` and `ext`.""" + # Define the special characters that need to be replaced. + illegal_chars = r'[#@$%!*&\\/:*?"<>|\n\t \']' + # Replace the special characters with underscores. 
+ filename = re.sub(illegal_chars, "_", lesson_title) + ext + return re.sub(r"_+", "_", filename) + + @property + def course_title(self): + """Return course title of teaching plan""" + default_title = "teaching_plan" + for act in self._actions: + if act.topic != WriteTeachingPlanPart.COURSE_TITLE: + continue + if act.rsp is None: + return default_title + title = act.rsp.lstrip("# \n") + if "\n" in title: + ix = title.index("\n") + title = title[0:ix] + return title + + return default_title diff --git a/metagpt/schema.py b/metagpt/schema.py index 51921763d..60b9a6998 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -162,8 +162,7 @@ class Message(BaseModel): # prefix = '-'.join([self.role, str(self.cause_by)]) if self.instruct_content: return f"{self.role}: {self.instruct_content.dict()}" - else: - return f"{self.role}: {self.content}" + return f"{self.role}: {self.content}" def __repr__(self): return self.__str__() @@ -180,8 +179,19 @@ class Message(BaseModel): @handle_exception(exception_type=JSONDecodeError, default_return=None) def load(val): """Convert the json string to object.""" - i = json.loads(val) - return Message(**i) + + try: + m = json.loads(val) + id = m.get("id") + if "id" in m: + del m["id"] + msg = Message(**m) + if id: + msg.id = id + return msg + except JSONDecodeError as err: + logger.error(f"parse json failed: {val}, error:{err}") + return None class UserMessage(Message): diff --git a/metagpt/team.py b/metagpt/team.py index 0b9f042df..625903e3e 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -90,9 +90,12 @@ class Team(BaseModel): CONFIG.max_budget = investment logger.info(f"Investment: ${investment}.") - def _check_balance(self): - if CONFIG.total_cost > CONFIG.max_budget: - raise NoMoneyException(CONFIG.total_cost, f"Insufficient funds: {CONFIG.max_budget}") + @staticmethod + def _check_balance(): + if CONFIG.cost_manager.total_cost > CONFIG.cost_manager.max_budget: + raise NoMoneyException( + CONFIG.cost_manager.total_cost, 
f"Insufficient funds: {CONFIG.cost_manager.max_budget}" + ) def run_project(self, idea, send_to: str = ""): """Run a project from publishing user requirement.""" @@ -100,7 +103,8 @@ class Team(BaseModel): # Human requirement. self.env.publish_message( - Message(role="Human", content=idea, cause_by=UserRequirement, send_to=send_to or MESSAGE_ROUTE_TO_ALL) + Message(role="Human", content=idea, cause_by=UserRequirement, send_to=send_to or MESSAGE_ROUTE_TO_ALL), + peekable=False, ) def start_project(self, idea, send_to: str = ""): @@ -120,7 +124,7 @@ class Team(BaseModel): logger.info(self.json(ensure_ascii=False)) @serialize_decorator - async def run(self, n_round=3, idea=""): + async def run(self, n_round=3, idea="", auto_archive=True): """Run company until target round or no money""" if idea: self.run_project(idea=idea) @@ -132,6 +136,5 @@ class Team(BaseModel): self._check_balance() await self.env.run() - if CONFIG.git_repo: - CONFIG.git_repo.archive() + self.env.archive(auto_archive) return self.env.history diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py index d98087e4b..aab8c990c 100644 --- a/metagpt/tools/__init__.py +++ b/metagpt/tools/__init__.py @@ -22,3 +22,8 @@ class WebBrowserEngineType(Enum): PLAYWRIGHT = "playwright" SELENIUM = "selenium" CUSTOM = "custom" + + @classmethod + def __missing__(cls, key): + """Default type conversion""" + return cls.CUSTOM diff --git a/metagpt/tools/azure_tts.py b/metagpt/tools/azure_tts.py index e59d98016..8fdb10c13 100644 --- a/metagpt/tools/azure_tts.py +++ b/metagpt/tools/azure_tts.py @@ -4,39 +4,110 @@ @Time : 2023/6/9 22:22 @Author : Leo Xiao @File : azure_tts.py +@Modified by: mashenquan, 2023/8/17. 
Azure TTS OAS3 api, which provides text-to-speech functionality """ +import asyncio +import base64 +from pathlib import Path +from uuid import uuid4 + +import aiofiles from azure.cognitiveservices.speech import AudioConfig, SpeechConfig, SpeechSynthesizer -from metagpt.config import CONFIG +from metagpt.config import CONFIG, Config +from metagpt.logs import logger class AzureTTS: - """https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles""" + """Azure Text-to-Speech""" - @classmethod - def synthesize_speech(cls, lang, voice, role, text, output_file): - subscription_key = CONFIG.get("AZURE_TTS_SUBSCRIPTION_KEY") - region = CONFIG.get("AZURE_TTS_REGION") - speech_config = SpeechConfig(subscription=subscription_key, region=region) + def __init__(self, subscription_key, region): + """ + :param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint` + :param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API. 
+ """ + self.subscription_key = subscription_key if subscription_key else CONFIG.AZURE_TTS_SUBSCRIPTION_KEY + self.region = region if region else CONFIG.AZURE_TTS_REGION + # 参数参考:https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles + async def synthesize_speech(self, lang, voice, text, output_file): + speech_config = SpeechConfig(subscription=self.subscription_key, region=self.region) speech_config.speech_synthesis_voice_name = voice audio_config = AudioConfig(filename=output_file) synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=audio_config) - # if voice=="zh-CN-YunxiNeural": - ssml_string = f""" - - - - {text} - - - - """ + # More detail: https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-synthesis-markup-voice + ssml_string = ( + "" + f"{text}" + ) - synthesizer.speak_ssml_async(ssml_string).get() + return synthesizer.speak_ssml_async(ssml_string).get() + + @staticmethod + def role_style_text(role, style, text): + return f'{text}' + + @staticmethod + def role_text(role, text): + return f'{text}' + + @staticmethod + def style_text(style, text): + return f'{text}' + + +# Export +async def oas3_azsure_tts(text, lang="", voice="", style="", role="", subscription_key="", region=""): + """Text to speech + For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts` + + :param lang: The value can contain a language code such as en (English), or a locale such as en-US (English - United States). 
For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts` + :param voice: For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`, `https://speech.microsoft.com/portal/voicegallery` + :param style: Speaking style to express different emotions like cheerfulness, empathy, and calm. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts` + :param role: With roles, the same voice can act as a different age and gender. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts` + :param text: The text used for voice conversion. + :param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint` + :param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API. + :return: Returns the Base64-encoded .wav file data if successful, otherwise an empty string. 
+ + """ + if not text: + return "" + + if not lang: + lang = "zh-CN" + if not voice: + voice = "zh-CN-XiaomoNeural" + if not role: + role = "Girl" + if not style: + style = "affectionate" + if not subscription_key: + subscription_key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY + if not region: + region = CONFIG.AZURE_TTS_REGION + + xml_value = AzureTTS.role_style_text(role=role, style=style, text=text) + tts = AzureTTS(subscription_key=subscription_key, region=region) + filename = Path(__file__).resolve().parent / (str(uuid4()).replace("-", "") + ".wav") + try: + await tts.synthesize_speech(lang=lang, voice=voice, text=xml_value, output_file=str(filename)) + async with aiofiles.open(filename, mode="rb") as reader: + data = await reader.read() + base64_string = base64.b64encode(data).decode("utf-8") + filename.unlink() + except Exception as e: + logger.error(f"text:{text}, error:{e}") + return "" + + return base64_string if __name__ == "__main__": - azure_tts = AzureTTS() - azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "Hello, I am Kaka", "output.wav") + Config() + loop = asyncio.new_event_loop() + v = loop.create_task(oas3_azsure_tts("测试,test")) + loop.run_until_complete(v) + print(v) diff --git a/metagpt/tools/hello.py b/metagpt/tools/hello.py new file mode 100644 index 000000000..8a21e1b4e --- /dev/null +++ b/metagpt/tools/hello.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/2 16:03 +@Author : mashenquan +@File : hello.py +@Desc : Implement the OpenAPI Specification 3.0 demo and use the following command to test the HTTP service: + + curl -X 'POST' \ + 'http://localhost:8080/openapi/greeting/dave' \ + -H 'accept: text/plain' \ + -H 'Content-Type: application/json' \ + -d '{}' +""" + +import connexion + + +# openapi implement +async def post_greeting(name: str) -> str: + return f"Hello {name}\n" + + +if __name__ == "__main__": + app = connexion.AioHttpApp(__name__, specification_dir="../../.well-known/") + 
app.add_api("openapi.yaml", arguments={"title": "Hello World Example"}) + app.run(port=8080) diff --git a/metagpt/tools/iflytek_tts.py b/metagpt/tools/iflytek_tts.py new file mode 100644 index 000000000..cb87d2e7f --- /dev/null +++ b/metagpt/tools/iflytek_tts.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/17 +@Author : mashenquan +@File : iflytek_tts.py +@Desc : iFLYTEK TTS OAS3 api, which provides text-to-speech functionality +""" +import asyncio +import base64 +import hashlib +import hmac +import json +import uuid +from datetime import datetime +from enum import Enum +from pathlib import Path +from time import mktime +from typing import Optional +from urllib.parse import urlencode +from wsgiref.handlers import format_date_time + +import aiofiles +import websockets as websockets +from pydantic import BaseModel + +from metagpt.config import CONFIG +from metagpt.logs import logger + + +class IFlyTekTTSStatus(Enum): + STATUS_FIRST_FRAME = 0 # The first frame + STATUS_CONTINUE_FRAME = 1 # The intermediate frame + STATUS_LAST_FRAME = 2 # The last frame + + +class AudioData(BaseModel): + audio: str + status: int + ced: str + + +class IFlyTekTTSResponse(BaseModel): + code: int + message: str + data: Optional[AudioData] = None + sid: str + + +DEFAULT_IFLYTEK_VOICE = "xiaoyan" + + +class IFlyTekTTS(object): + def __init__(self, app_id: str, api_key: str, api_secret: str): + """ + :param app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts` + :param api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts` + :param api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts` + """ + self.app_id = app_id or CONFIG.IFLYTEK_APP_ID + self.api_key = api_key or CONFIG.IFLYTEK_API_KEY + self.api_secret = api_secret or CONFIG.API_SECRET + + async def synthesize_speech(self, text, output_file: str, voice=DEFAULT_IFLYTEK_VOICE): + url = self._create_url() + 
data = { + "common": {"app_id": self.app_id}, + "business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": voice, "tte": "utf8"}, + "data": {"status": 2, "text": str(base64.b64encode(text.encode("utf-8")), "UTF8")}, + } + req = json.dumps(data) + async with websockets.connect(url) as websocket: + # send request + await websocket.send(req) + + # receive frames + async with aiofiles.open(str(output_file), "w") as writer: + while True: + v = await websocket.recv() + rsp = IFlyTekTTSResponse(**json.loads(v)) + if rsp.data: + await writer.write(rsp.data.audio) + if rsp.data.status != IFlyTekTTSStatus.STATUS_LAST_FRAME.value: + continue + break + + def _create_url(self): + """Create request url""" + url = "wss://tts-api.xfyun.cn/v2/tts" + # Generate a timestamp in RFC1123 format + now = datetime.now() + date = format_date_time(mktime(now.timetuple())) + + signature_origin = "host: " + "ws-api.xfyun.cn" + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + "/v2/tts " + "HTTP/1.1" + # Perform HMAC-SHA256 encryption + signature_sha = hmac.new( + self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256 + ).digest() + signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8") + + authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % ( + self.api_key, + "hmac-sha256", + "host date request-line", + signature_sha, + ) + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") + # Combine the authentication parameters of the request into a dictionary. + v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"} + # Concatenate the authentication parameters to generate the URL. + url = url + "?" 
+ urlencode(v) + return url + + +# Export +async def oas3_iflytek_tts(text: str, voice: str = "", app_id: str = "", api_key: str = "", api_secret: str = ""): + """Text to speech + For more details, check out:`https://www.xfyun.cn/doc/tts/online_tts/API.html` + + :param voice: Default `xiaoyan`. For more details, checkout: `https://www.xfyun.cn/doc/tts/online_tts/API.html#%E6%8E%A5%E5%8F%A3%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B` + :param text: The text used for voice conversion. + :param app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts` + :param api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts` + :param api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts` + :return: Returns the Base64-encoded .mp3 file data if successful, otherwise an empty string. + + """ + if not app_id: + app_id = CONFIG.IFLYTEK_APP_ID + if not api_key: + api_key = CONFIG.IFLYTEK_API_KEY + if not api_secret: + api_secret = CONFIG.IFLYTEK_API_SECRET + if not voice: + voice = CONFIG.IFLYTEK_VOICE or DEFAULT_IFLYTEK_VOICE + + filename = Path(__file__).parent / (uuid.uuid4().hex + ".mp3") + try: + tts = IFlyTekTTS(app_id=app_id, api_key=api_key, api_secret=api_secret) + await tts.synthesize_speech(text=text, output_file=str(filename), voice=voice) + async with aiofiles.open(str(filename), mode="r") as reader: + base64_string = await reader.read() + except Exception as e: + logger.error(f"text:{text}, error:{e}") + base64_string = "" + finally: + filename.unlink() + + return base64_string + + +if __name__ == "__main__": + asyncio.get_event_loop().run_until_complete( + oas3_iflytek_tts( + text="你好,hello", + app_id="f7acef62", + api_key="fda72e3aa286042a492525816a5efa08", + api_secret="ZDk3NjdiMDBkODJlOWQ1NjRjMGI2NDY4", + ) + ) diff --git a/metagpt/tools/metagpt_oas3_api_svc.py b/metagpt/tools/metagpt_oas3_api_svc.py new file mode 100644 index 000000000..2ff4c8225 --- /dev/null +++ 
b/metagpt/tools/metagpt_oas3_api_svc.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/17 +@Author : mashenquan +@File : metagpt_oas3_api_svc.py +@Desc : MetaGPT OpenAPI Specification 3.0 REST API service +""" +import asyncio +import sys +from pathlib import Path + +import connexion + +sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) # fix-bug: No module named 'metagpt' + + +def oas_http_svc(): + """Start the OAS 3.0 OpenAPI HTTP service""" + app = connexion.AioHttpApp(__name__, specification_dir="../../.well-known/") + app.add_api("metagpt_oas3_api.yaml") + app.add_api("openapi.yaml") + app.run(port=8080) + + +async def async_main(): + """Start the OAS 3.0 OpenAPI HTTP service in the background.""" + loop = asyncio.get_event_loop() + loop.run_in_executor(None, oas_http_svc) + + # TODO: replace following codes: + while True: + await asyncio.sleep(1) + print("sleep") + + +def main(): + print("http://localhost:8080/oas3/ui/") + oas_http_svc() + + +if __name__ == "__main__": + # asyncio.run(async_main()) + main() diff --git a/metagpt/tools/metagpt_text_to_image.py b/metagpt/tools/metagpt_text_to_image.py new file mode 100644 index 000000000..50c0edcba --- /dev/null +++ b/metagpt/tools/metagpt_text_to_image.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/18 +@Author : mashenquan +@File : metagpt_text_to_image.py +@Desc : MetaGPT Text-to-Image OAS3 api, which provides text-to-image functionality. 
+""" +import asyncio +import base64 +from typing import Dict, List + +import aiohttp +import requests +from pydantic import BaseModel + +from metagpt.config import CONFIG, Config +from metagpt.logs import logger + + +class MetaGPTText2Image: + def __init__(self, model_url): + """ + :param model_url: Model reset api url + """ + self.model_url = model_url if model_url else CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL + + async def text_2_image(self, text, size_type="512x512"): + """Text to image + + :param text: The text used for image conversion. + :param size_type: One of ['512x512', '512x768'] + :return: The image data is returned in Base64 encoding. + """ + + headers = {"Content-Type": "application/json"} + dims = size_type.split("x") + data = { + "prompt": text, + "negative_prompt": "(easynegative:0.8),black, dark,Low resolution", + "override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"}, + "seed": -1, + "batch_size": 1, + "n_iter": 1, + "steps": 20, + "cfg_scale": 11, + "width": int(dims[0]), + "height": int(dims[1]), # 768, + "restore_faces": False, + "tiling": False, + "do_not_save_samples": False, + "do_not_save_grid": False, + "enable_hr": False, + "hr_scale": 2, + "hr_upscaler": "Latent", + "hr_second_pass_steps": 0, + "hr_resize_x": 0, + "hr_resize_y": 0, + "hr_upscale_to_x": 0, + "hr_upscale_to_y": 0, + "truncate_x": 0, + "truncate_y": 0, + "applied_old_hires_behavior_to": None, + "eta": None, + "sampler_index": "DPM++ SDE Karras", + "alwayson_scripts": {}, + } + + class ImageResult(BaseModel): + images: List + parameters: Dict + + try: + async with aiohttp.ClientSession() as session: + async with session.post(self.model_url, headers=headers, json=data) as response: + result = ImageResult(**await response.json()) + if len(result.images) == 0: + return "" + return result.images[0] + except requests.exceptions.RequestException as e: + logger.error(f"An error occurred:{e}") + return "" + + +# Export +async def oas3_metagpt_text_to_image(text, 
size_type: str = "512x512", model_url=""): + """Text to image + + :param text: The text used for image conversion. + :param model_url: Model reset api + :param size_type: One of ['512x512', '512x768'] + :return: The image data is returned in Base64 encoding. + """ + if not text: + return "" + if not model_url: + model_url = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL + return await MetaGPTText2Image(model_url).text_2_image(text, size_type=size_type) + + +if __name__ == "__main__": + Config() + loop = asyncio.new_event_loop() + task = loop.create_task(oas3_metagpt_text_to_image("Panda emoji")) + v = loop.run_until_complete(task) + print(v) + data = base64.b64decode(v) + with open("tmp.png", mode="wb") as writer: + writer.write(data) + print(v) diff --git a/metagpt/tools/openai_text_to_embedding.py b/metagpt/tools/openai_text_to_embedding.py new file mode 100644 index 000000000..fb6fbc653 --- /dev/null +++ b/metagpt/tools/openai_text_to_embedding.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/18 +@Author : mashenquan +@File : openai_text_to_embedding.py +@Desc : OpenAI Text-to-Embedding OAS3 api, which provides text-to-embedding functionality. + For more details, checkout: `https://platform.openai.com/docs/api-reference/embeddings/object` +""" +import asyncio +from typing import List + +import aiohttp +import requests +from pydantic import BaseModel + +from metagpt.config import CONFIG, Config +from metagpt.logs import logger + + +class Embedding(BaseModel): + """Represents an embedding vector returned by embedding endpoint.""" + + object: str # The object type, which is always "embedding". + embedding: List[ + float + ] # The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the embedding guide. + index: int # The index of the embedding in the list of embeddings. 
+ + +class Usage(BaseModel): + prompt_tokens: int + total_tokens: int + + +class ResultEmbedding(BaseModel): + object: str + data: List[Embedding] + model: str + usage: Usage + + +class OpenAIText2Embedding: + def __init__(self, openai_api_key): + """ + :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` + """ + self.openai_api_key = openai_api_key if openai_api_key else CONFIG.OPENAI_API_KEY + + async def text_2_embedding(self, text, model="text-embedding-ada-002"): + """Text to embedding + + :param text: The text used for embedding. + :param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`. + :return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`. + """ + + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}"} + data = {"input": text, "model": model} + try: + async with aiohttp.ClientSession() as session: + async with session.post("https://api.openai.com/v1/embeddings", headers=headers, json=data) as response: + return await response.json() + except requests.exceptions.RequestException as e: + logger.error(f"An error occurred:{e}") + return {} + + +# Export +async def oas3_openai_text_to_embedding(text, model="text-embedding-ada-002", openai_api_key=""): + """Text to embedding + + :param text: The text used for embedding. + :param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`. + :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` + :return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`. 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/8/17
@Author  : mashenquan
@File    : openai_text_to_image.py
@Desc    : OpenAI Text-to-Image OAS3 api, which provides text-to-image functionality.
"""
import asyncio
import base64

import aiohttp
from openai import AsyncOpenAI

from metagpt.config import CONFIG, Config
from metagpt.logs import logger


class OpenAIText2Image:
    """Wrapper around the OpenAI image-generation endpoint that returns images as Base64 data."""

    def __init__(self, openai_api_key):
        """
        :param openai_api_key: OpenAI API key. For more details, check out: `https://platform.openai.com/account/api-keys`
        """
        self.openai_api_key = openai_api_key if openai_api_key else CONFIG.OPENAI_API_KEY
        self._client = AsyncOpenAI(api_key=self.openai_api_key, base_url=CONFIG.openai_api_base)

    def __del__(self):
        # BUG FIX: `AsyncOpenAI.close()` is a coroutine; the previous code called it without
        # awaiting, which only created a never-awaited coroutine object (RuntimeWarning) and
        # never released the connection pool. Schedule it on the running loop when one exists;
        # otherwise fall back to the client's own finalizer.
        client, self._client = self._client, None
        if client is None:
            return
        try:
            loop = asyncio.get_running_loop()
            loop.create_task(client.close())
        except RuntimeError:
            # No running event loop (e.g. interpreter shutdown); nothing safe to await here.
            pass

    async def text_2_image(self, text, size_type="1024x1024"):
        """Text to image

        :param text: The text used for image conversion.
        :param size_type: One of ['256x256', '512x512', '1024x1024']
        :return: The image data is returned in Base64 encoding; "" on failure.
        """
        try:
            result = await self._client.images.generate(prompt=text, n=1, size=size_type)
        except Exception as e:
            logger.error(f"An error occurred:{e}")
            return ""
        if result and len(result.data) > 0:
            return await OpenAIText2Image.get_image_data(result.data[0].url)
        return ""

    @staticmethod
    async def get_image_data(url):
        """Fetch image data from a URL and encode it as Base64

        :param url: Image url
        :return: Base64-encoded image data; "" on any HTTP/network failure.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    response.raise_for_status()  # raises for 4xx / 5xx responses
                    image_data = await response.read()
                    return base64.b64encode(image_data).decode("utf-8")
        except aiohttp.ClientError as e:
            # BUG FIX: the original caught `requests.exceptions.RequestException`, but this
            # coroutine uses aiohttp, which raises `aiohttp.ClientError` subclasses — network
            # errors therefore escaped unhandled. The unused `requests` import was removed.
            logger.error(f"An error occurred:{e}")
            return ""


# Export
async def oas3_openai_text_to_image(text, size_type: str = "1024x1024", openai_api_key=""):
    """Text to image

    :param text: The text used for image conversion.
    :param openai_api_key: OpenAI API key. For more details, check out: `https://platform.openai.com/account/api-keys`
    :param size_type: One of ['256x256', '512x512', '1024x1024']
    :return: The image data is returned in Base64 encoding; "" for empty input or failure.
    """
    if not text:
        return ""
    if not openai_api_key:
        openai_api_key = CONFIG.OPENAI_API_KEY
    return await OpenAIText2Image(openai_api_key).text_2_image(text, size_type=size_type)


if __name__ == "__main__":
    Config()
    print(asyncio.run(oas3_openai_text_to_image("Panda emoji")))
+""" from __future__ import annotations import importlib -from typing import Any, Callable, Coroutine, Literal, overload +from typing import Any, Callable, Coroutine, Dict, Literal, overload from metagpt.config import CONFIG from metagpt.tools import WebBrowserEngineType @@ -13,18 +16,21 @@ from metagpt.utils.parse_html import WebPage class WebBrowserEngine: def __init__( self, + options: Dict, engine: WebBrowserEngineType | None = None, run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None, ): engine = engine or CONFIG.web_browser_engine + if engine is None: + raise NotImplementedError - if engine == WebBrowserEngineType.PLAYWRIGHT: + if WebBrowserEngineType(engine) is WebBrowserEngineType.PLAYWRIGHT: module = "metagpt.tools.web_browser_engine_playwright" run_func = importlib.import_module(module).PlaywrightWrapper().run - elif engine == WebBrowserEngineType.SELENIUM: + elif WebBrowserEngineType(engine) is WebBrowserEngineType.SELENIUM: module = "metagpt.tools.web_browser_engine_selenium" run_func = importlib.import_module(module).SeleniumWrapper().run - elif engine == WebBrowserEngineType.CUSTOM: + elif WebBrowserEngineType(engine) is WebBrowserEngineType.CUSTOM: run_func = run_func else: raise NotImplementedError @@ -47,6 +53,6 @@ if __name__ == "__main__": import fire async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs): - return await WebBrowserEngine(WebBrowserEngineType(engine_type), **kwargs).run(url, *urls) + return await WebBrowserEngine(engine=WebBrowserEngineType(engine_type), **kwargs).run(url, *urls) fire.Fire(main) diff --git a/metagpt/tools/web_browser_engine_playwright.py b/metagpt/tools/web_browser_engine_playwright.py index 030e7701b..8eecc4f40 100644 --- a/metagpt/tools/web_browser_engine_playwright.py +++ b/metagpt/tools/web_browser_engine_playwright.py @@ -1,4 +1,8 @@ #!/usr/bin/env python +""" +@Modified By: mashenquan, 2023/8/20. 
Remove global configuration `CONFIG`, enable configuration support for business isolation. +""" + from __future__ import annotations import asyncio @@ -144,6 +148,6 @@ if __name__ == "__main__": import fire async def main(url: str, *urls: str, browser_type: str = "chromium", **kwargs): - return await PlaywrightWrapper(browser_type, **kwargs).run(url, *urls) + return await PlaywrightWrapper(browser_type=browser_type, **kwargs).run(url, *urls) fire.Fire(main) diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index decab2b7d..628c8dea2 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -1,11 +1,15 @@ #!/usr/bin/env python +""" +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. +""" + from __future__ import annotations import asyncio import importlib from concurrent import futures from copy import deepcopy -from typing import Literal +from typing import Dict, Literal from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as EC @@ -29,6 +33,7 @@ class SeleniumWrapper: def __init__( self, + options: Dict, browser_type: Literal["chrome", "firefox", "edge", "ie"] | None = None, launch_kwargs: dict | None = None, *, @@ -120,6 +125,6 @@ if __name__ == "__main__": import fire async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs): - return await SeleniumWrapper(browser_type, **kwargs).run(url, *urls) + return await SeleniumWrapper(browser_type=browser_type, **kwargs).run(url, *urls) fire.Fire(main) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 8db7a80a1..382523083 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -23,7 +23,7 @@ import sys import traceback import typing from pathlib import Path -from typing import Any, List, Tuple, Union, get_args, get_origin +from typing 
import Any, Callable, List, Tuple, Union, get_args, get_origin import aiofiles import loguru @@ -367,7 +367,7 @@ def get_class_name(cls) -> str: return f"{cls.__module__}.{cls.__name__}" -def any_to_str(val: str | typing.Callable) -> str: +def any_to_str(val: str | Callable) -> str: """Return the class name or the class name of the object, or 'val' if it's a string type.""" if isinstance(val, str): return val @@ -406,6 +406,21 @@ def is_subscribed(message: "Message", tags: set): return False +def any_to_name(val): + """ + Convert a value to its name by extracting the last part of the dotted path. + + :param val: The value to convert. + + :return: The name of the value. + """ + return any_to_str(val).split(".")[-1] + + +def concat_namespace(*args) -> str: + return ":".join(str(value) for value in args) + + def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]: """ Generates a logging function to be used after a call is retried. @@ -520,3 +535,20 @@ async def aread(file_path: str) -> str: async with aiofiles.open(str(file_path), mode="r") as reader: content = await reader.read() return content + + +async def read_file_block(filename: str | Path, lineno: int, end_lineno: int): + if not Path(filename).exists(): + return "" + lines = [] + async with aiofiles.open(str(filename), mode="r") as reader: + ix = 0 + while ix < end_lineno: + ix += 1 + line = await reader.readline() + if ix < lineno: + continue + if ix > end_lineno: + break + lines.append(line) + return "".join(lines) diff --git a/metagpt/utils/cost_manager.py b/metagpt/utils/cost_manager.py new file mode 100644 index 000000000..ce53f2285 --- /dev/null +++ b/metagpt/utils/cost_manager.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/28 +@Author : mashenquan +@File : openai.py +@Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting. 
+""" + +from typing import NamedTuple + +from pydantic import BaseModel + +from metagpt.logs import logger +from metagpt.utils.token_counter import TOKEN_COSTS + + +class Costs(NamedTuple): + total_prompt_tokens: int + total_completion_tokens: int + total_cost: float + total_budget: float + + +class CostManager(BaseModel): + """Calculate the overhead of using the interface.""" + + total_prompt_tokens: int = 0 + total_completion_tokens: int = 0 + total_budget: float = 0 + max_budget: float = 10.0 + total_cost: float = 0 + + def update_cost(self, prompt_tokens, completion_tokens, model): + """ + Update the total cost, prompt tokens, and completion tokens. + + Args: + prompt_tokens (int): The number of tokens used in the prompt. + completion_tokens (int): The number of tokens used in the completion. + model (str): The model used for the API call. + """ + self.total_prompt_tokens += prompt_tokens + self.total_completion_tokens += completion_tokens + cost = ( + prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"] + ) / 1000 + self.total_cost += cost + logger.info( + f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | " + f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" + ) + + def get_total_prompt_tokens(self): + """ + Get the total number of prompt tokens. + + Returns: + int: The total number of prompt tokens. + """ + return self.total_prompt_tokens + + def get_total_completion_tokens(self): + """ + Get the total number of completion tokens. + + Returns: + int: The total number of completion tokens. + """ + return self.total_completion_tokens + + def get_total_cost(self): + """ + Get the total cost of API calls. + + Returns: + float: The total cost of API calls. 
+ """ + return self.total_cost + + def get_costs(self) -> Costs: + """Get all costs""" + return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget) diff --git a/metagpt/utils/di_graph_repository.py b/metagpt/utils/di_graph_repository.py new file mode 100644 index 000000000..08f4327fa --- /dev/null +++ b/metagpt/utils/di_graph_repository.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 +@Author : mashenquan +@File : di_graph_repository.py +@Desc : Graph repository based on DiGraph +""" +from __future__ import annotations + +import json +from pathlib import Path +from typing import List + +import aiofiles +import networkx + +from metagpt.utils.graph_repository import SPO, GraphRepository + + +class DiGraphRepository(GraphRepository): + def __init__(self, name: str, **kwargs): + super().__init__(name=name, **kwargs) + self._repo = networkx.DiGraph() + + async def insert(self, subject: str, predicate: str, object_: str): + self._repo.add_edge(subject, object_, predicate=predicate) + + async def upsert(self, subject: str, predicate: str, object_: str): + pass + + async def update(self, subject: str, predicate: str, object_: str): + pass + + async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]: + result = [] + for s, o, p in self._repo.edges(data="predicate"): + if subject and subject != s: + continue + if predicate and predicate != p: + continue + if object_ and object_ != o: + continue + result.append(SPO(subject=s, predicate=p, object_=o)) + return result + + def json(self) -> str: + m = networkx.node_link_data(self._repo) + data = json.dumps(m) + return data + + async def save(self, path: str | Path = None): + data = self.json() + path = path or self._kwargs.get("root") + if not path.exists(): + path.mkdir(parents=True, exist_ok=True) + pathname = Path(path) / self.name + async with aiofiles.open(str(pathname.with_suffix(".json")), 
mode="w", encoding="utf-8") as writer: + await writer.write(data) + + async def load(self, pathname: str | Path): + async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader: + data = await reader.read(-1) + m = json.loads(data) + self._repo = networkx.node_link_graph(m) + + @staticmethod + async def load_from(pathname: str | Path) -> GraphRepository: + pathname = Path(pathname) + name = pathname.with_suffix("").name + root = pathname.parent + graph = DiGraphRepository(name=name, root=root) + if pathname.exists(): + await graph.load(pathname=pathname) + return graph + + @property + def root(self) -> str: + return self._kwargs.get("root") + + @property + def pathname(self) -> Path: + p = Path(self.root) / self.name + return p.with_suffix(".json") diff --git a/metagpt/utils/graph_repository.py b/metagpt/utils/graph_repository.py new file mode 100644 index 000000000..37da3dee4 --- /dev/null +++ b/metagpt/utils/graph_repository.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 +@Author : mashenquan +@File : graph_repository.py +@Desc : Superclass for graph repository. 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/12/19
@Author  : mashenquan
@File    : graph_repository.py
@Desc    : Superclass for graph repository.
"""

from abc import ABC, abstractmethod
from pathlib import Path
from typing import List

from pydantic import BaseModel

from metagpt.repo_parser import ClassInfo, RepoFileInfo
from metagpt.utils.common import concat_namespace

# Maps a file suffix to the language tag stored in the graph; anything else maps to
# GraphKeyword.NULL. Hoisted to module level: it was previously duplicated verbatim
# inside both `update_graph_db_with_*` static methods.
FILE_TYPES = {".py": "python", ".js": "javascript"}


class GraphKeyword:
    """Well-known predicate/object vocabulary used in the SPO graph."""

    IS = "is"
    CLASS = "class"
    FUNCTION = "function"
    SOURCE_CODE = "source_code"
    NULL = ""
    GLOBAL_VARIABLE = "global_variable"
    CLASS_FUNCTION = "class_function"
    CLASS_PROPERTY = "class_property"
    HAS_CLASS = "has_class"
    HAS_PAGE_INFO = "has_page_info"
    HAS_CLASS_VIEW = "has_class_view"
    HAS_SEQUENCE_VIEW = "has_sequence_view"
    HAS_ARGS_DESC = "has_args_desc"
    HAS_TYPE_DESC = "has_type_desc"


class SPO(BaseModel):
    """A single subject-predicate-object triple."""

    subject: str
    predicate: str
    object_: str


class GraphRepository(ABC):
    """Abstract base for SPO-triple stores; concrete backends implement the CRUD coroutines."""

    def __init__(self, name: str, **kwargs):
        self._repo_name = name
        self._kwargs = kwargs

    @abstractmethod
    async def insert(self, subject: str, predicate: str, object_: str):
        """Insert one triple."""

    @abstractmethod
    async def upsert(self, subject: str, predicate: str, object_: str):
        """Insert or update one triple."""

    @abstractmethod
    async def update(self, subject: str, predicate: str, object_: str):
        """Update one triple."""

    @abstractmethod
    async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
        """Return triples matching the given (optional) filters."""

    @property
    def name(self) -> str:
        return self._repo_name

    @staticmethod
    async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
        """Record a parsed source file (classes, functions, globals, page info) into the graph."""
        await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
        file_type = FILE_TYPES.get(Path(file_info.file).suffix, GraphKeyword.NULL)
        await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
        for c in file_info.classes:
            class_name = c.get("name", "")
            # Class nodes are namespaced by file so identically-named classes don't collide.
            await graph_db.insert(
                subject=file_info.file,
                predicate=GraphKeyword.HAS_CLASS,
                object_=concat_namespace(file_info.file, class_name),
            )
            await graph_db.insert(
                subject=concat_namespace(file_info.file, class_name),
                predicate=GraphKeyword.IS,
                object_=GraphKeyword.CLASS,
            )
            methods = c.get("methods", [])
            for fn in methods:
                await graph_db.insert(
                    subject=concat_namespace(file_info.file, class_name, fn),
                    predicate=GraphKeyword.IS,
                    object_=GraphKeyword.CLASS_FUNCTION,
                )
        for f in file_info.functions:
            await graph_db.insert(
                subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
            )
        for g in file_info.globals:
            await graph_db.insert(
                subject=concat_namespace(file_info.file, g),
                predicate=GraphKeyword.IS,
                object_=GraphKeyword.GLOBAL_VARIABLE,
            )
        for code_block in file_info.page_info:
            if code_block.tokens:
                await graph_db.insert(
                    subject=concat_namespace(file_info.file, *code_block.tokens),
                    predicate=GraphKeyword.HAS_PAGE_INFO,
                    object_=code_block.json(ensure_ascii=False),
                )
            for k, v in code_block.properties.items():
                await graph_db.insert(
                    subject=concat_namespace(file_info.file, k, v),
                    predicate=GraphKeyword.HAS_PAGE_INFO,
                    object_=code_block.json(ensure_ascii=False),
                )

    @staticmethod
    async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
        """Record class views (attributes and methods with descriptions) into the graph."""
        for c in class_views:
            # `package` is "<filename>:<class name>"; split only on the first colon.
            filename, class_name = c.package.split(":", 1)
            await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
            file_type = FILE_TYPES.get(Path(filename).suffix, GraphKeyword.NULL)
            await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=file_type)
            await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=class_name)
            await graph_db.insert(
                subject=c.package,
                predicate=GraphKeyword.IS,
                object_=GraphKeyword.CLASS,
            )
            for vn, vt in c.attributes.items():
                await graph_db.insert(
                    subject=concat_namespace(c.package, vn),
                    predicate=GraphKeyword.IS,
                    object_=GraphKeyword.CLASS_PROPERTY,
                )
                await graph_db.insert(
                    subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
                )
            for fn, desc in c.methods.items():
                await graph_db.insert(
                    subject=concat_namespace(c.package, fn),
                    predicate=GraphKeyword.IS,
                    object_=GraphKeyword.CLASS_FUNCTION,
                )
                await graph_db.insert(
                    subject=concat_namespace(c.package, fn),
                    predicate=GraphKeyword.HAS_ARGS_DESC,
                    object_=desc,
                )
# !/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { redis client }
# @Date: 2022/11/28 10:12
import json
import traceback
from datetime import timedelta
from enum import Enum
from typing import Awaitable, Callable, Dict, Optional, Union

from redis import asyncio as aioredis

from metagpt.config import CONFIG
from metagpt.logs import logger


class RedisTypeEnum(Enum):
    """Redis data types."""

    String = "String"
    List = "List"
    Hash = "Hash"
    Set = "Set"
    ZSet = "ZSet"


def make_url(
    dialect: str,
    *,
    user: Optional[str] = None,
    password: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[Union[str, int]] = None,
    name: Optional[Union[str, int]] = None,
) -> str:
    """Assemble a `dialect://user:password@host:port/name` connection URL.

    Omitted components are left out; `host` defaults to 127.0.0.1 for non-sqlite
    dialects. `name` is compared with `is not None` because 0 is a legitimate
    value (e.g. the default Redis database index).
    """
    url_parts = [f"{dialect}://"]
    if user or password:
        if user:
            url_parts.append(user)
        if password:
            url_parts.append(f":{password}")
        url_parts.append("@")

    if not host and not dialect.startswith("sqlite"):
        host = "127.0.0.1"

    if host:
        url_parts.append(f"{host}")
        if port:
            url_parts.append(f":{port}")

    # e.g. Redis may pass in database index 0, which is falsy but valid
    if name is not None:
        url_parts.append(f"/{name}")
    return "".join(url_parts)


class RedisAsyncClient(aioredis.Redis):
    """Asynchronous Redis client.

    Example::

        rdb = RedisAsyncClient()
        print(rdb.url)

    Args:
        host: Server address.
        port: Server port.
        db: Database index.
        password: Password.
        decode_responses: Values are stored as UTF-8-encoded bytes; enabling this
            decodes replies back to `str` automatically.
        health_check_interval: Periodic connection check to avoid
            ConnectionErrors (104, Connection reset by peer).
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str = None,
        decode_responses=True,
        health_check_interval=10,
        socket_connect_timeout=5,
        retry_on_timeout=True,
        socket_keepalive=True,
        **kwargs,
    ):
        super().__init__(
            host=host,
            port=port,
            db=db,
            password=password,
            decode_responses=decode_responses,
            health_check_interval=health_check_interval,
            socket_connect_timeout=socket_connect_timeout,
            retry_on_timeout=retry_on_timeout,
            socket_keepalive=socket_keepalive,
            **kwargs,
        )
        self.url = make_url("redis", host=host, port=port, name=db, password=password)


class RedisCacheInfo(object):
    """Unified cache-entry descriptor."""

    def __init__(self, key, timeout: Union[int, timedelta] = timedelta(seconds=60), data_type=RedisTypeEnum.String):
        """
        Args:
            key: The cache key.
            timeout: Cache expiration, in seconds (or a timedelta).
            data_type: The Redis data structure used (informational only; not enforced).
        """
        self.key = key
        self.timeout = timeout
        self.data_type = data_type

    def __str__(self):
        return f"cache key {self.key} timeout {self.timeout}s"


class RedisManager:
    """Process-wide singleton holder for the async Redis connection."""

    client: RedisAsyncClient = None

    @classmethod
    def init_redis_conn(cls, host, port, password, db):
        """Initialize the Redis connection (idempotent)."""
        if cls.client is None:
            cls.client = RedisAsyncClient(host=host, port=port, password=password, db=db)

    @classmethod
    async def set_with_cache_info(cls, redis_cache_info: RedisCacheInfo, value):
        """Set a Redis entry (with TTL) described by a RedisCacheInfo object."""
        await cls.client.setex(redis_cache_info.key, redis_cache_info.timeout, value)

    @classmethod
    async def get_with_cache_info(cls, redis_cache_info: RedisCacheInfo):
        """Get the Redis entry described by a RedisCacheInfo object."""
        cache_info = await cls.client.get(redis_cache_info.key)
        return cache_info

    @classmethod
    async def del_with_cache_info(cls, redis_cache_info: RedisCacheInfo):
        """Delete the Redis entry described by a RedisCacheInfo object."""
        await cls.client.delete(redis_cache_info.key)

    @staticmethod
    async def get_or_set_cache(cache_info: RedisCacheInfo, fetch_data_func: Callable[[], Awaitable[dict]]) -> dict:
        """Return cached data; on a miss, fetch via `fetch_data_func` and cache it.

        Only JSON-serializable dicts stored as strings are supported. A failure to
        serialize is logged but the freshly fetched data is still returned.
        """
        serialized_data = await RedisManager.get_with_cache_info(cache_info)

        if serialized_data:
            return json.loads(serialized_data)

        data = await fetch_data_func()
        try:
            serialized_data = json.dumps(data)
            await RedisManager.set_with_cache_info(cache_info, serialized_data)
        except Exception as e:
            logger.warning(f"数据 {data} 通过 json 进行序列化缓存失败:{e}")

        return data

    @classmethod
    def is_valid(cls):
        return cls.client is not None


class Redis:
    """Thin facade over RedisManager configured from the global CONFIG."""

    def __init__(self, conf: Dict = None):
        # NOTE(review): `conf` is currently unused — configuration is read from CONFIG.
        # Kept for interface compatibility; verify whether callers pass it.
        try:
            host = CONFIG.REDIS_HOST
            port = int(CONFIG.REDIS_PORT)
            pwd = CONFIG.REDIS_PASSWORD
            db = CONFIG.REDIS_DB
            RedisManager.init_redis_conn(host=host, port=port, password=pwd, db=db)
        except Exception as e:
            logger.warning(f"Redis initialization has failed:{e}")

    def is_valid(self):
        return RedisManager.is_valid()

    async def get(self, key: str) -> str:
        """Return the value for `key`, or None when unavailable or on any error."""
        if not self.is_valid() or not key:
            return None
        try:
            v = await RedisManager.get_with_cache_info(redis_cache_info=RedisCacheInfo(key=key))
            return v
        except Exception as e:
            logger.exception(f"{e}, stack:{traceback.format_exc()}")
            return None

    async def set(self, key: str, data: str, timeout_sec: int):
        """Best-effort set of `key` to `data` with a TTL; errors are logged, not raised."""
        if not self.is_valid() or not key:
            return
        try:
            await RedisManager.set_with_cache_info(
                redis_cache_info=RedisCacheInfo(key=key, timeout=timeout_sec), value=data
            )
        except Exception as e:
            logger.exception(f"{e}, stack:{traceback.format_exc()}")


# ---------------------------------------------------------------------------
# metagpt/utils/s3.py
# ---------------------------------------------------------------------------
import base64
import os.path
import uuid
from pathlib import Path

import aioboto3
import aiofiles

from metagpt.const import BASE64_FORMAT


class S3:
    """A class for interacting with Amazon S3 storage."""

    def __init__(self):
        self.session = aioboto3.Session()
        self.auth_config = {
            "service_name": "s3",
            "aws_access_key_id": CONFIG.S3_ACCESS_KEY,
            "aws_secret_access_key": CONFIG.S3_SECRET_KEY,
            "endpoint_url": CONFIG.S3_ENDPOINT_URL,
        }

    async def upload_file(
        self,
        bucket: str,
        local_path: str,
        object_name: str,
    ) -> None:
        """Upload a file from the local path to the specified path of the storage bucket specified in s3.

        Args:
            bucket: The name of the S3 storage bucket.
            local_path: The local file path, including the file name.
            object_name: The complete path of the uploaded file to be stored in S3, including the file name.

        Raises:
            Exception: If an error occurs during the upload process, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                async with aiofiles.open(local_path, mode="rb") as reader:
                    body = await reader.read()
                await client.put_object(Body=body, Bucket=bucket, Key=object_name)
                logger.info(f"Successfully uploaded the file to path {object_name} in bucket {bucket} of s3.")
        except Exception as e:
            logger.error(f"Failed to upload the file to path {object_name} in bucket {bucket} of s3: {e}")
            raise e

    async def get_object_url(
        self,
        bucket: str,
        object_name: str,
    ) -> str:
        """Get the URL for a downloadable or preview file stored in the specified S3 bucket.

        Args:
            bucket: The name of the S3 storage bucket.
            object_name: The complete path of the file stored in S3, including the file name.

        Returns:
            The URL for the downloadable or preview file.

        Raises:
            Exception: If an error occurs while retrieving the URL, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                file = await client.get_object(Bucket=bucket, Key=object_name)
                return str(file["Body"].url)
        except Exception as e:
            logger.error(f"Failed to get the url for a downloadable or preview file: {e}")
            raise e

    async def get_object(
        self,
        bucket: str,
        object_name: str,
    ) -> bytes:
        """Get the binary data of a file stored in the specified S3 bucket.

        Args:
            bucket: The name of the S3 storage bucket.
            object_name: The complete path of the file stored in S3, including the file name.

        Returns:
            The binary data of the requested file.

        Raises:
            Exception: If an error occurs while retrieving the file data, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                s3_object = await client.get_object(Bucket=bucket, Key=object_name)
                return await s3_object["Body"].read()
        except Exception as e:
            logger.error(f"Failed to get the binary data of the file: {e}")
            raise e

    async def download_file(
        self, bucket: str, object_name: str, local_path: str, chunk_size: Optional[int] = 128 * 1024
    ) -> None:
        """Download an S3 object to a local file.

        Args:
            bucket: The name of the S3 storage bucket.
            object_name: The complete path of the file stored in S3, including the file name.
            local_path: The local file path where the S3 object will be downloaded.
            chunk_size: The size of data chunks to read and write at a time. Default is 128 KB.

        Raises:
            Exception: If an error occurs during the download process, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                s3_object = await client.get_object(Bucket=bucket, Key=object_name)
                stream = s3_object["Body"]
                async with aiofiles.open(local_path, mode="wb") as writer:
                    while True:
                        file_data = await stream.read(chunk_size)
                        if not file_data:
                            break
                        await writer.write(file_data)
        except Exception as e:
            logger.error(f"Failed to download the file from S3: {e}")
            raise e

    async def cache(self, data: str, file_ext: str, format: str = "") -> str:
        """Save data to remote S3 and return its URL, or None on failure.

        `format` (shadows the builtin; kept for interface compatibility) selects
        whether `data` is Base64-encoded binary or plain text.
        """
        object_name = uuid.uuid4().hex + file_ext
        path = Path(__file__).parent
        pathname = path / object_name
        try:
            async with aiofiles.open(str(pathname), mode="wb") as file:
                # BUG FIX: the file is opened in binary mode, but the original wrote the raw
                # `str` unchanged when format != BASE64_FORMAT, raising TypeError. Encode text
                # payloads; decode Base64 payloads.
                payload = base64.b64decode(data) if format == BASE64_FORMAT else data.encode("utf-8")
                await file.write(payload)

            bucket = CONFIG.S3_BUCKET
            object_pathname = CONFIG.S3_BUCKET or "system"
            object_pathname += f"/{object_name}"
            object_pathname = os.path.normpath(object_pathname)
            await self.upload_file(bucket=bucket, local_path=str(pathname), object_name=object_pathname)
            pathname.unlink(missing_ok=True)

            return await self.get_object_url(bucket=bucket, object_name=object_pathname)
        except Exception as e:
            logger.exception(f"{e}, stack:{traceback.format_exc()}")
            pathname.unlink(missing_ok=True)
            return None

    @property
    def is_valid(self):
        # Treat placeholder values from the config template as "not configured".
        is_invalid = (
            not CONFIG.S3_ACCESS_KEY
            or CONFIG.S3_ACCESS_KEY == "YOUR_S3_ACCESS_KEY"
            or not CONFIG.S3_SECRET_KEY
            or CONFIG.S3_SECRET_KEY == "YOUR_S3_SECRET_KEY"
            or not CONFIG.S3_ENDPOINT_URL
            or CONFIG.S3_ENDPOINT_URL == "YOUR_S3_ENDPOINT_URL"
            or not CONFIG.S3_BUCKET
            or CONFIG.S3_BUCKET == "YOUR_S3_BUCKET"
        )
        if is_invalid:
            logger.info("S3 is invalid")
        return not is_invalid
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI @@ -96,3 +96,8 @@ def setup_and_teardown_git_repo(request): # Register the function for destroying the environment. request.addfinalizer(fin) + + +@pytest.fixture(scope="session", autouse=True) +def init_config(): + Config() diff --git a/tests/metagpt/actions/test_rebuild_class_view.py b/tests/metagpt/actions/test_rebuild_class_view.py new file mode 100644 index 000000000..955c6ae3b --- /dev/null +++ b/tests/metagpt/actions/test_rebuild_class_view.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/20 +@Author : mashenquan +@File : test_rebuild_class_view.py +@Desc : Unit tests for rebuild_class_view.py +""" +from pathlib import Path + +import pytest + +from metagpt.actions.rebuild_class_view import RebuildClassView +from metagpt.llm import LLM + + +@pytest.mark.asyncio +async def test_rebuild(): + action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM()) + await action.run() + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index 54229089c..73f3a6dcf 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -9,8 +9,8 @@ import pytest from metagpt.actions.write_code import WriteCode -from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.openai_api import OpenAIGPTAPI as LLM from metagpt.schema import CodingContext, Document from tests.metagpt.actions.mock import TASKS_2, WRITE_CODE_PROMPT_SAMPLE diff --git a/tests/metagpt/actions/test_write_teaching_plan.py b/tests/metagpt/actions/test_write_teaching_plan.py new file mode 100644 index 000000000..3f25b2167 --- /dev/null +++ b/tests/metagpt/actions/test_write_teaching_plan.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/7/28 17:25 +@Author : mashenquan +@File : 
test_write_teaching_plan.py +""" + +import asyncio +from typing import Optional + +from langchain.llms.base import LLM +from pydantic import BaseModel + +from metagpt.actions.write_teaching_plan import WriteTeachingPlanPart +from metagpt.config import Config +from metagpt.schema import Message + + +class MockWriteTeachingPlanPart(WriteTeachingPlanPart): + def __init__(self, options, name: str = "", context=None, llm: LLM = None, topic="", language="Chinese"): + super().__init__(options, name, context, llm, topic, language) + + async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: + return f"{WriteTeachingPlanPart.DATA_BEGIN_TAG}\nprompt\n{WriteTeachingPlanPart.DATA_END_TAG}" + + +async def mock_write_teaching_plan_part(): + class Inputs(BaseModel): + input: str + name: str + topic: str + language: str + + inputs = [ + {"input": "AABBCC", "name": "A", "topic": WriteTeachingPlanPart.COURSE_TITLE, "language": "C"}, + {"input": "DDEEFFF", "name": "A1", "topic": "B1", "language": "C1"}, + ] + + for i in inputs: + seed = Inputs(**i) + options = Config().runtime_options + act = MockWriteTeachingPlanPart(options=options, name=seed.name, topic=seed.topic, language=seed.language) + await act.run([Message(content="")]) + assert act.topic == seed.topic + assert str(act) == seed.topic + assert act.name == seed.name + assert act.rsp == "# prompt" if seed.topic == WriteTeachingPlanPart.COURSE_TITLE else "prompt" + + +def test_suite(): + loop = asyncio.get_event_loop() + task = loop.create_task(mock_write_teaching_plan_part()) + loop.run_until_complete(task) + + +if __name__ == "__main__": + test_suite() diff --git a/tests/metagpt/learn/__init__.py b/tests/metagpt/learn/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/metagpt/learn/test_google_search.py b/tests/metagpt/learn/test_google_search.py new file mode 100644 index 000000000..da32e8923 --- /dev/null +++ b/tests/metagpt/learn/test_google_search.py @@ -0,0 +1,27 @@ 
+import asyncio + +from pydantic import BaseModel + +from metagpt.learn.google_search import google_search + + +async def mock_google_search(): + class Input(BaseModel): + input: str + + inputs = [{"input": "ai agent"}] + + for i in inputs: + seed = Input(**i) + result = await google_search(seed.input) + assert result != "" + + +def test_suite(): + loop = asyncio.get_event_loop() + task = loop.create_task(mock_google_search()) + loop.run_until_complete(task) + + +if __name__ == "__main__": + test_suite() diff --git a/tests/metagpt/learn/test_skill_loader.py b/tests/metagpt/learn/test_skill_loader.py new file mode 100644 index 000000000..5bc0e776f --- /dev/null +++ b/tests/metagpt/learn/test_skill_loader.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/9/19 +@Author : mashenquan +@File : test_skill_loader.py +@Desc : Unit tests. +""" + +from metagpt.config import CONFIG +from metagpt.learn.skill_loader import SkillLoader + + +def test_suite(): + CONFIG.agent_skills = [ + {"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True}, + {"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True}, + {"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True}, + {"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True}, + {"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True}, + {"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True}, + {"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True}, + ] + loader = SkillLoader() + skills = loader.get_skill_list() + assert skills + assert len(skills) >= 3 + for desc, name in skills.items(): + assert desc + assert name + + entity = loader.get_entity("Assistant") + assert entity + assert entity.skills + for sk in entity.skills: + assert sk + assert sk.arguments + + +if __name__ == "__main__": + test_suite() diff 
--git a/tests/metagpt/learn/test_text_to_embedding.py b/tests/metagpt/learn/test_text_to_embedding.py new file mode 100644 index 000000000..e3d20a759 --- /dev/null +++ b/tests/metagpt/learn/test_text_to_embedding.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/18 +@Author : mashenquan +@File : test_text_to_embedding.py +@Desc : Unit tests. +""" + +import asyncio + +from pydantic import BaseModel + +from metagpt.learn.text_to_embedding import text_to_embedding +from metagpt.tools.openai_text_to_embedding import ResultEmbedding + + +async def mock_text_to_embedding(): + class Input(BaseModel): + input: str + + inputs = [{"input": "Panda emoji"}] + + for i in inputs: + seed = Input(**i) + data = await text_to_embedding(seed.input) + v = ResultEmbedding(**data) + assert len(v.data) > 0 + + +def test_suite(): + loop = asyncio.get_event_loop() + task = loop.create_task(mock_text_to_embedding()) + loop.run_until_complete(task) + + +if __name__ == "__main__": + test_suite() diff --git a/tests/metagpt/learn/test_text_to_image.py b/tests/metagpt/learn/test_text_to_image.py new file mode 100644 index 000000000..a6cbc45bf --- /dev/null +++ b/tests/metagpt/learn/test_text_to_image.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/18 +@Author : mashenquan +@File : test_text_to_image.py +@Desc : Unit tests. 
+""" + +import base64 + +import pytest +from pydantic import BaseModel + +from metagpt.learn.text_to_image import text_to_image + + +@pytest.mark.asyncio +async def test(): + class Input(BaseModel): + input: str + size_type: str + + inputs = [{"input": "Panda emoji", "size_type": "512x512"}] + + for i in inputs: + seed = Input(**i) + base64_data = await text_to_image(seed.input) + assert base64_data != "" + print(f"{seed.input} -> {base64_data}") + flags = ";base64," + assert flags in base64_data + ix = base64_data.find(flags) + len(flags) + declaration = base64_data[0:ix] + assert declaration + data = base64_data[ix:] + assert data + assert base64.b64decode(data, validate=True) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/learn/test_text_to_speech.py b/tests/metagpt/learn/test_text_to_speech.py new file mode 100644 index 000000000..42b6839fa --- /dev/null +++ b/tests/metagpt/learn/test_text_to_speech.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/18 +@Author : mashenquan +@File : test_text_to_speech.py +@Desc : Unit tests. 
+""" +import asyncio +import base64 + +from pydantic import BaseModel + +from metagpt.learn.text_to_speech import text_to_speech + + +async def mock_text_to_speech(): + class Input(BaseModel): + input: str + + inputs = [{"input": "Panda emoji"}] + + for i in inputs: + seed = Input(**i) + base64_data = await text_to_speech(seed.input) + assert base64_data != "" + print(f"{seed.input} -> {base64_data}") + flags = ";base64," + assert flags in base64_data + ix = base64_data.find(flags) + len(flags) + declaration = base64_data[0:ix] + assert declaration + data = base64_data[ix:] + assert data + assert base64.b64decode(data, validate=True) + + +def test_suite(): + loop = asyncio.get_event_loop() + task = loop.create_task(mock_text_to_speech()) + loop.run_until_complete(task) + + +if __name__ == "__main__": + test_suite() diff --git a/tests/metagpt/memory/test_brain_memory.py b/tests/metagpt/memory/test_brain_memory.py new file mode 100644 index 000000000..2f2a984d8 --- /dev/null +++ b/tests/metagpt/memory/test_brain_memory.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/27 +@Author : mashenquan +@File : test_brain_memory.py +""" +import json +from typing import List + +import pydantic + +from metagpt.memory.brain_memory import BrainMemory +from metagpt.schema import Message + + +def test_json(): + class Input(pydantic.BaseModel): + history: List[str] + solution: List[str] + knowledge: List[str] + stack: List[str] + + inputs = [{"history": ["a", "b"], "solution": ["c"], "knowledge": ["d", "e"], "stack": ["f"]}] + + for i in inputs: + v = Input(**i) + bm = BrainMemory() + for h in v.history: + msg = Message(content=h) + bm.history.append(msg.dict()) + for h in v.solution: + msg = Message(content=h) + bm.solution.append(msg.dict()) + for h in v.knowledge: + msg = Message(content=h) + bm.knowledge.append(msg.dict()) + for h in v.stack: + msg = Message(content=h) + bm.stack.append(msg.dict()) + s = bm.json() + m = json.loads(s) + bm = 
BrainMemory(**m) + assert bm + for v in bm.history: + msg = Message(**v) + assert msg + + +if __name__ == "__main__": + test_json() diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py index b6ae0ac79..1f07d74e3 100644 --- a/tests/metagpt/memory/test_longterm_memory.py +++ b/tests/metagpt/memory/test_longterm_memory.py @@ -2,11 +2,12 @@ # -*- coding: utf-8 -*- """ @Desc : unittest of `metagpt/memory/longterm_memory.py` +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. """ from metagpt.actions import UserRequirement from metagpt.config import CONFIG -from metagpt.memory import LongTermMemory +from metagpt.memory.longterm_memory import LongTermMemory from metagpt.roles.role import RoleContext from metagpt.schema import Message @@ -28,6 +29,7 @@ def test_ltm_search(): ltm.add(message) sim_idea = "Write a game of cli snake" + sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement) news = ltm.find_news([sim_message]) assert len(news) == 0 diff --git a/tests/metagpt/provider/test_metagpt_llm_api.py b/tests/metagpt/provider/test_metagpt_llm_api.py new file mode 100644 index 000000000..9c8356ca6 --- /dev/null +++ b/tests/metagpt/provider/test_metagpt_llm_api.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/30 +@Author : mashenquan +@File : test_metagpt_llm_api.py +""" +from metagpt.provider.metagpt_llm_api import MetaGPTLLMAPI + + +def test_metagpt(): + llm = MetaGPTLLMAPI() + assert llm + + +if __name__ == "__main__": + test_metagpt() diff --git a/tests/metagpt/roles/test_teacher.py b/tests/metagpt/roles/test_teacher.py new file mode 100644 index 000000000..82d6c7052 --- /dev/null +++ b/tests/metagpt/roles/test_teacher.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/7/27 13:25 +@Author : mashenquan +@File : test_teacher.py +""" + +from typing 
import Dict, Optional + +from pydantic import BaseModel + +from metagpt.roles.teacher import Teacher + + +def test_init(): + class Inputs(BaseModel): + name: str + profile: str + goal: str + constraints: str + desc: str + kwargs: Optional[Dict] = None + expect_name: str + expect_profile: str + expect_goal: str + expect_constraints: str + expect_desc: str + + inputs = [ + { + "name": "Lily{language}", + "expect_name": "LilyCN", + "profile": "X {teaching_language}", + "expect_profile": "X EN", + "goal": "Do {something_big}, {language}", + "expect_goal": "Do sleep, CN", + "constraints": "Do in {key1}, {language}", + "expect_constraints": "Do in HaHa, CN", + "kwargs": {"language": "CN", "key1": "HaHa", "something_big": "sleep", "teaching_language": "EN"}, + "desc": "aaa{language}", + "expect_desc": "aaaCN", + }, + { + "name": "Lily{language}", + "expect_name": "Lily{language}", + "profile": "X {teaching_language}", + "expect_profile": "X {teaching_language}", + "goal": "Do {something_big}, {language}", + "expect_goal": "Do {something_big}, {language}", + "constraints": "Do in {key1}, {language}", + "expect_constraints": "Do in {key1}, {language}", + "kwargs": {}, + "desc": "aaa{language}", + "expect_desc": "aaa{language}", + }, + ] + + for i in inputs: + seed = Inputs(**i) + teacher = Teacher( + name=seed.name, + profile=seed.profile, + goal=seed.goal, + constraints=seed.constraints, + desc=seed.desc, + **seed.kwargs + ) + assert teacher.name == seed.expect_name + assert teacher.desc == seed.expect_desc + assert teacher.profile == seed.expect_profile + assert teacher.goal == seed.expect_goal + assert teacher.constraints == seed.expect_constraints + assert teacher.course_title == "teaching_plan" + + +def test_new_file_name(): + class Inputs(BaseModel): + lesson_title: str + ext: str + expect: str + + inputs = [ + {"lesson_title": "# @344\n12", "ext": ".md", "expect": "_344_12.md"}, + {"lesson_title": "1#@$%!*&\\/:*?\"<>|\n\t '1", "ext": ".cc", "expect": "1_1.cc"}, + ] + 
for i in inputs: + seed = Inputs(**i) + result = Teacher.new_file_name(seed.lesson_title, seed.ext) + assert result == seed.expect + + +if __name__ == "__main__": + test_init() + test_new_file_name() diff --git a/tests/metagpt/roles/ui_role.py b/tests/metagpt/roles/ui_role.py index 0932efa1f..51b346821 100644 --- a/tests/metagpt/roles/ui_role.py +++ b/tests/metagpt/roles/ui_role.py @@ -8,8 +8,6 @@ from functools import wraps from importlib import import_module from metagpt.actions import Action, ActionOutput, WritePRD - -# from metagpt.const import WORKSPACE_ROOT from metagpt.actions.action_node import ActionNode from metagpt.config import CONFIG from metagpt.logs import logger diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index 56e2b4fc3..3a899d6ff 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -4,6 +4,8 @@ @Time : 2023/5/12 00:47 @Author : alexanderwu @File : test_environment.py +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. 
+ """ from pathlib import Path @@ -11,9 +13,9 @@ from pathlib import Path import pytest from metagpt.actions import UserRequirement +from metagpt.config import CONFIG from metagpt.environment import Environment from metagpt.logs import logger -from metagpt.manager import Manager from metagpt.roles import Architect, ProductManager, Role from metagpt.schema import Message @@ -44,6 +46,10 @@ def test_get_roles(env: Environment): @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): + if CONFIG.git_repo: + CONFIG.git_repo.delete_repository() + CONFIG.git_repo = None + product_manager = ProductManager(name="Alice", profile="Product Manager", goal="做AI Native产品", constraints="资源有限") architect = Architect( name="Bob", profile="Architect", goal="设计一个可用、高效、较低成本的系统,包括数据结构与接口", constraints="资源有限,需要节省成本" @@ -51,9 +57,11 @@ async def test_publish_and_process_message(env: Environment): env.add_roles([product_manager, architect]) - env.set_manager(Manager()) env.publish_message(Message(role="User", content="需要一个基于LLM做总结的搜索引擎", cause_by=UserRequirement)) - await env.run(k=2) logger.info(f"{env.history=}") assert len(env.history) > 10 + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_gpt.py b/tests/metagpt/test_gpt.py index 431858d4c..daafeb708 100644 --- a/tests/metagpt/test_gpt.py +++ b/tests/metagpt/test_gpt.py @@ -5,9 +5,10 @@ @Author : alexanderwu @File : test_gpt.py """ - +import openai import pytest +from metagpt.config import CONFIG from metagpt.logs import logger @@ -18,14 +19,17 @@ class TestGPT: logger.info(answer) assert len(answer) > 0 - # def test_gptapi_ask_batch(self, llm_api): - # answer = llm_api.ask_batch(['请扮演一个Google Python专家工程师,如果理解,回复明白', '写一个hello world']) - # assert len(answer) > 0 + def test_gptapi_ask_batch(self, llm_api): + answer = llm_api.ask_batch(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"], timeout=60) + assert len(answer) > 0 def test_llm_api_ask_code(self, 
llm_api): - answer = llm_api.ask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"]) - logger.info(answer) - assert len(answer) > 0 + try: + answer = llm_api.ask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"]) + logger.info(answer) + assert len(answer) > 0 + except openai.NotFoundError: + assert CONFIG.openai_api_type == "azure" @pytest.mark.asyncio async def test_llm_api_aask(self, llm_api): @@ -35,9 +39,12 @@ class TestGPT: @pytest.mark.asyncio async def test_llm_api_aask_code(self, llm_api): - answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"]) - logger.info(answer) - assert len(answer) > 0 + try: + answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"]) + logger.info(answer) + assert len(answer) > 0 + except openai.NotFoundError: + assert CONFIG.openai_api_type == "azure" @pytest.mark.asyncio async def test_llm_api_costs(self, llm_api): @@ -47,5 +54,5 @@ class TestGPT: assert costs.total_cost > 0 -# if __name__ == "__main__": -# pytest.main([__file__, "-s"]) +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_llm.py b/tests/metagpt/test_llm.py index 408fd3162..d972e55c0 100644 --- a/tests/metagpt/test_llm.py +++ b/tests/metagpt/test_llm.py @@ -4,11 +4,12 @@ @Time : 2023/5/11 14:45 @Author : alexanderwu @File : test_llm.py +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. 
""" import pytest -from metagpt.llm import LLM +from metagpt.provider.openai_api import OpenAIGPTAPI as LLM @pytest.fixture() @@ -34,5 +35,5 @@ async def test_llm_acompletion(llm): assert len(await llm.acompletion_batch_text([hello_msg])) > 0 -# if __name__ == "__main__": -# pytest.main([__file__, "-s"]) +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_repo_parser.py b/tests/metagpt/test_repo_parser.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 1742757e8..897d203c7 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -10,8 +10,6 @@ import json -import pytest - from metagpt.actions import Action from metagpt.actions.action_node import ActionNode from metagpt.actions.write_code import WriteCode @@ -19,7 +17,6 @@ from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage from metagpt.utils.common import any_to_str -@pytest.mark.asyncio def test_messages(): test_content = "test_message" msgs = [ @@ -33,7 +30,6 @@ def test_messages(): assert all([i in text for i in roles]) -@pytest.mark.asyncio def test_message(): m = Message(content="a", role="v1") v = m.dump() @@ -64,7 +60,6 @@ def test_message(): assert m.content == "b" -@pytest.mark.asyncio def test_routes(): m = Message(content="a", role="b", cause_by="c", x="d", send_to="c") m.send_to = "b" diff --git a/tests/metagpt/test_startup.py b/tests/metagpt/test_startup.py index c34fd2c31..c8d4d5d29 100644 --- a/tests/metagpt/test_startup.py +++ b/tests/metagpt/test_startup.py @@ -26,3 +26,7 @@ async def test_team(): # def test_startup(): # args = ["Make a 2048 game"] # result = runner.invoke(app, args) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_subscription.py b/tests/metagpt/test_subscription.py index 75e06411c..b902d5416 100644 --- a/tests/metagpt/test_subscription.py +++ 
b/tests/metagpt/test_subscription.py @@ -100,3 +100,7 @@ async def test_subscription_run_error(loguru_caplog): logs = "".join(loguru_caplog.messages) assert "run error" in logs assert "has completed" in logs + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_azure_tts.py b/tests/metagpt/tools/test_azure_tts.py new file mode 100644 index 000000000..b7f94a19c --- /dev/null +++ b/tests/metagpt/tools/test_azure_tts.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/7/1 22:50 +@Author : alexanderwu +@File : test_azure_tts.py +@Modified By: mashenquan, 2023-8-9, add more text formatting options +@Modified By: mashenquan, 2023-8-17, move to `tools` folder. +""" +import asyncio + +from metagpt.config import CONFIG +from metagpt.tools.azure_tts import AzureTTS + + +def test_azure_tts(): + azure_tts = AzureTTS(subscription_key="", region="") + text = """ + 女儿看见父亲走了进来,问道: + + “您来的挺快的,怎么过来的?” + + 父亲放下手提包,说: + + “Writing a binary file in Python is similar to writing a regular text file, but you'll work with bytes instead of strings.” + + """ + path = CONFIG.workspace / "tts" + path.mkdir(exist_ok=True, parents=True) + filename = path / "girl.wav" + loop = asyncio.new_event_loop() + v = loop.create_task( + azure_tts.synthesize_speech(lang="zh-CN", voice="zh-CN-XiaomoNeural", text=text, output_file=str(filename)) + ) + result = loop.run_until_complete(v) + + print(result) + + # 运行需要先配置 SUBSCRIPTION_KEY + # TODO: 这里如果要检验,还要额外加上对应的asr,才能确保前后生成是接近一致的,但现在还没有 + + +if __name__ == "__main__": + test_azure_tts() diff --git a/tests/metagpt/tools/test_web_browser_engine.py b/tests/metagpt/tools/test_web_browser_engine.py index 28dd0e15c..1e4e956f2 100644 --- a/tests/metagpt/tools/test_web_browser_engine.py +++ b/tests/metagpt/tools/test_web_browser_engine.py @@ -1,5 +1,10 @@ +""" +@Modified By: mashenquan, 2023/8/20. 
Remove global configuration `CONFIG`, enable configuration support for business isolation. +""" + import pytest +from metagpt.config import Config from metagpt.tools import WebBrowserEngineType, web_browser_engine @@ -13,7 +18,8 @@ from metagpt.tools import WebBrowserEngineType, web_browser_engine ids=["playwright", "selenium"], ) async def test_scrape_web_page(browser_type, url, urls): - browser = web_browser_engine.WebBrowserEngine(browser_type) + conf = Config() + browser = web_browser_engine.WebBrowserEngine(options=conf.runtime_options, engine=browser_type) result = await browser.run(url) assert isinstance(result, str) assert "深度赋智" in result diff --git a/tests/metagpt/tools/test_web_browser_engine_playwright.py b/tests/metagpt/tools/test_web_browser_engine_playwright.py index e9ea80b10..cc6c09925 100644 --- a/tests/metagpt/tools/test_web_browser_engine_playwright.py +++ b/tests/metagpt/tools/test_web_browser_engine_playwright.py @@ -1,6 +1,10 @@ +""" +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. 
+""" + import pytest -from metagpt.config import CONFIG +from metagpt.config import Config from metagpt.tools import web_browser_engine_playwright @@ -15,22 +19,25 @@ from metagpt.tools import web_browser_engine_playwright ids=["chromium-normal", "firefox-normal", "webkit-normal"], ) async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd): + conf = Config() + global_proxy = conf.global_proxy try: - global_proxy = CONFIG.global_proxy if use_proxy: - CONFIG.global_proxy = proxy - browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type, **kwagrs) + conf.global_proxy = proxy + browser = web_browser_engine_playwright.PlaywrightWrapper( + options=conf.runtime_options, browser_type=browser_type, **kwagrs + ) result = await browser.run(url) result = result.inner_text assert isinstance(result, str) - assert "Deepwisdom" in result + assert "DeepWisdom" in result if urls: results = await browser.run(url, *urls) assert isinstance(results, list) assert len(results) == len(urls) + 1 - assert all(("Deepwisdom" in i) for i in results) + assert all(("DeepWisdom" in i) for i in results) if use_proxy: assert "Proxy:" in capfd.readouterr().out finally: - CONFIG.global_proxy = global_proxy + conf.global_proxy = global_proxy diff --git a/tests/metagpt/tools/test_web_browser_engine_selenium.py b/tests/metagpt/tools/test_web_browser_engine_selenium.py index ac6eafee7..77f4d8592 100644 --- a/tests/metagpt/tools/test_web_browser_engine_selenium.py +++ b/tests/metagpt/tools/test_web_browser_engine_selenium.py @@ -1,6 +1,10 @@ +""" +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. 
+""" + import pytest -from metagpt.config import CONFIG +from metagpt.config import Config from metagpt.tools import web_browser_engine_selenium @@ -15,11 +19,12 @@ from metagpt.tools import web_browser_engine_selenium ids=["chrome-normal", "firefox-normal", "edge-normal"], ) async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd): + conf = Config() + global_proxy = conf.global_proxy try: - global_proxy = CONFIG.global_proxy if use_proxy: - CONFIG.global_proxy = proxy - browser = web_browser_engine_selenium.SeleniumWrapper(browser_type) + conf.global_proxy = proxy + browser = web_browser_engine_selenium.SeleniumWrapper(options=conf.runtime_options, browser_type=browser_type) result = await browser.run(url) result = result.inner_text assert isinstance(result, str) @@ -33,4 +38,4 @@ async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd) if use_proxy: assert "Proxy:" in capfd.readouterr().out finally: - CONFIG.global_proxy = global_proxy + conf.global_proxy = global_proxy diff --git a/tests/metagpt/utils/test_config.py b/tests/metagpt/utils/test_config.py index b68a535f9..bd89f0ed3 100644 --- a/tests/metagpt/utils/test_config.py +++ b/tests/metagpt/utils/test_config.py @@ -4,19 +4,15 @@ @Time : 2023/5/1 11:19 @Author : alexanderwu @File : test_config.py +@Modified By: mashenquan, 2013/8/20, Add `test_options`; remove global configuration `CONFIG`, enable configuration support for business isolation. 
""" +from pathlib import Path import pytest from metagpt.config import Config -def test_config_class_is_singleton(): - config_1 = Config() - config_2 = Config() - assert config_1 == config_2 - - def test_config_class_get_key_exception(): with pytest.raises(Exception) as exc_info: config = Config() @@ -28,4 +24,14 @@ def test_config_yaml_file_not_exists(): config = Config("wtf.yaml") with pytest.raises(Exception) as exc_info: config.get("OPENAI_BASE_URL") - assert str(exc_info.value) == "Key 'OPENAI_BASE_URL' not found in environment variables or in the YAML file" + assert str(exc_info.value) == "Set OPENAI_API_KEY or Anthropic_API_KEY first" + + +def test_options(): + filename = Path(__file__).resolve().parent.parent.parent.parent / "config/config.yaml" + config = Config(filename) + assert config.options + + +if __name__ == "__main__": + test_options() diff --git a/tests/metagpt/utils/test_di_graph_repository.py b/tests/metagpt/utils/test_di_graph_repository.py new file mode 100644 index 000000000..0a8011e51 --- /dev/null +++ b/tests/metagpt/utils/test_di_graph_repository.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 +@Author : mashenquan +@File : test_di_graph_repository.py +@Desc : Unit tests for di_graph_repository.py +""" + +from pathlib import Path + +import pytest +from pydantic import BaseModel + +from metagpt.const import DEFAULT_WORKSPACE_ROOT +from metagpt.repo_parser import RepoParser +from metagpt.utils.di_graph_repository import DiGraphRepository +from metagpt.utils.graph_repository import GraphRepository + + +@pytest.mark.asyncio +async def test_di_graph_repository(): + class Input(BaseModel): + s: str + p: str + o: str + + inputs = [ + {"s": "main.py:Game:draw", "p": "method:hasDescription", "o": "Draw image"}, + {"s": "main.py:Game:draw", "p": "method:hasDescription", "o": "Show image"}, + ] + path = Path(__file__).parent + graph = DiGraphRepository(name="test", root=path) + for i in inputs: + data = 
Input(**i) + await graph.insert(subject=data.s, predicate=data.p, object_=data.o) + v = graph.json() + assert v + await graph.save() + assert graph.pathname.exists() + graph.pathname.unlink() + + +@pytest.mark.asyncio +async def test_js_parser(): + class Input(BaseModel): + path: str + + inputs = [ + {"path": str(Path(__file__).parent / "../../data/code")}, + ] + path = Path(__file__).parent + graph = DiGraphRepository(name="test", root=path) + for i in inputs: + data = Input(**i) + repo_parser = RepoParser(base_directory=data.path) + symbols = repo_parser.generate_symbols() + for s in symbols: + await GraphRepository.update_graph_db(graph_db=graph, file_info=s) + data = graph.json() + assert data + + +@pytest.mark.asyncio +async def test_codes(): + path = DEFAULT_WORKSPACE_ROOT / "snake_game" + repo_parser = RepoParser(base_directory=path) + + graph = DiGraphRepository(name="test", root=path) + symbols = repo_parser.generate_symbols() + for file_info in symbols: + for code_block in file_info.page_info: + try: + val = code_block.json(ensure_ascii=False) + assert val + except TypeError as e: + assert not e + await GraphRepository.update_graph_db(graph_db=graph, file_info=file_info) + data = graph.json() + assert data + print(data) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_token_counter.py b/tests/metagpt/utils/test_token_counter.py index 479ccc22d..acb99d717 100644 --- a/tests/metagpt/utils/test_token_counter.py +++ b/tests/metagpt/utils/test_token_counter.py @@ -15,7 +15,7 @@ def test_count_message_tokens(): {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}, ] - assert count_message_tokens(messages) == 17 + assert count_message_tokens(messages) == 15 def test_count_message_tokens_with_name(): @@ -67,3 +67,7 @@ def test_count_string_tokens_gpt_4(): string = "Hello, world!" 
assert count_string_tokens(string, model_name="gpt-4-0314") == 4 + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])