diff --git a/metagpt/config.py b/metagpt/config.py index 49d2fe36f..8af137808 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ -Provide configuration as a singleton. +Provide configuration, singleton """ import os import openai @@ -28,7 +28,7 @@ class NotConfiguredException(Exception): class Config(metaclass=Singleton): """ - Typical usage: + Common usage: config = Config("config.yaml") secret_key = config.get_key("MY_SECRET_KEY") print("Secret key:", secret_key) @@ -40,7 +40,7 @@ class Config(metaclass=Singleton): def __init__(self, yaml_file=default_yaml_file): self._configs = {} - self._initialize_with_config_files_and_environment(self._configs, yaml_file) + self._init_with_config_files_and_env(self._configs, yaml_file) logger.info("Config loading done.") self.global_proxy = self._get("GLOBAL_PROXY") self.openai_api_key = self._get("OPENAI_API_KEY") @@ -67,26 +67,26 @@ class Config(metaclass=Singleton): self.google_api_key = self._get("GOOGLE_API_KEY") self.google_cse_id = self._get("GOOGLE_CSE_ID") self.search_engine = self._get("SEARCH_ENGINE", SearchEngineType.SERPAPI_GOOGLE) - + self.web_browser_engine = WebBrowserEngineType(self._get("WEB_BROWSER_ENGINE", "playwright")) self.playwright_browser_type = self._get("PLAYWRIGHT_BROWSER_TYPE", "chromium") self.selenium_browser_type = self._get("SELENIUM_BROWSER_TYPE", "chrome") - + self.long_term_memory = self._get('LONG_TERM_MEMORY', False) if self.long_term_memory: logger.warning("LONG_TERM_MEMORY is True") self.max_budget = self._get("MAX_BUDGET", 10.0) self.total_cost = 0.0 - def _initialize_with_config_files_and_environment(self, configs: dict, yaml_file): - """Load configurations from config/key.yaml, config/config.yaml, and the environment, in decreasing order of priority.""" + def _init_with_config_files_and_env(self, configs: dict, yaml_file): + """Load from config/key.yaml / config/config.yaml / env in decreasing priority""" configs.update(os.environ) for _yaml_file in [yaml_file, self.key_yaml_file]: if not _yaml_file.exists(): continue - # Load local YAML files. + # Load local YAML file with open(_yaml_file, "r", encoding="utf-8") as file: yaml_data = yaml.safe_load(file) if not yaml_data: @@ -98,7 +98,7 @@ class Config(metaclass=Singleton): return self._configs.get(*args, **kwargs) def get(self, key, *args, **kwargs): - """Fetch a value from config/key.yaml, config/config.yaml, or the environment. Raises an error if not found.""" + """Find values from config/key.yaml / config/config.yaml / env, report an error if not found""" value = self._get(key, *args, **kwargs) if value is None: raise ValueError(f"Key '{key}' not found in environment variables or in the YAML file") diff --git a/metagpt/const.py b/metagpt/const.py index 861da7903..c8ce80279 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -9,7 +9,7 @@ from pathlib import Path def get_project_root(): - """Search upwards level by level for the project root directory.""" + """Search upwards to find the project root directory.""" current_path = Path.cwd() while True: if (current_path / '.git').exists() or \ diff --git a/metagpt/document_store/faiss_store.py b/metagpt/document_store/faiss_store.py index baa10ba1e..906963aa1 100644 --- a/metagpt/document_store/faiss_store.py +++ b/metagpt/document_store/faiss_store.py @@ -28,7 +28,7 @@ class FaissStore(LocalStore): def _load(self) -> Optional["FaissStore"]: index_file, store_file = self._get_index_and_store_fname() if not (index_file.exists() and store_file.exists()): - logger.info("Missing at least one of index_file/store_file, load failed and return None") + logger.info("At least one of the index_file/store_file is missing. Loading failed and returns None.") return None index = faiss.read_index(str(index_file)) with open(str(store_file), "rb") as f: diff --git a/metagpt/document_store/milvus_store.py b/metagpt/document_store/milvus_store.py index 0a8ed78d4..175c04d13 100644 --- a/metagpt/document_store/milvus_store.py +++ b/metagpt/document_store/milvus_store.py @@ -19,9 +19,8 @@ type_mapping = { np.ndarray: DataType.FLOAT_VECTOR } - def columns_to_milvus_schema(columns: dict, primary_col_name: str = "", desc: str = ""): - """Assuming the structure of columns is str: regular type.""" + """Assuming the structure of columns is str: standard type""" fields = [] for col, ctype in columns.items(): if ctype == str: @@ -34,13 +33,11 @@ def columns_to_milvus_schema(columns: dict, primary_col_name: str = "", desc: st schema = CollectionSchema(fields, description=desc) return schema - class MilvusConnection(TypedDict): alias: str host: str port: str - class MilvusStore(BaseStore): """ FIXME: ADD TESTS @@ -79,8 +76,8 @@ class MilvusStore(BaseStore): """ FIXME: ADD TESTS https://milvus.io/docs/v2.0.x/search.md - All search and query operations within Milvus are executed in memory. Load the collection to memory before conducting a vector similarity search. - Noting the above description, is this logic serious? The time taken for this should be long, right? + All search and query operations within Milvus are executed in memory. Load the collection into memory before conducting a vector similarity search. + Noting the above description, is this logic serious? This should be time-consuming, right? """ search_params = {"metric_type": "L2", "params": {"nprobe": 10}} results = self.collection.search( @@ -91,7 +88,7 @@ class MilvusStore(BaseStore): expr=None, consistency_level="Strong" ) - # FIXME: results contains ids, but to get the actual values from the ids, the query interface still needs to be called. + # FIXME: results contain an id, but to get the actual value for the id, you still need to call the query interface return results def write(self, name, schema, *args, **kwargs): diff --git a/metagpt/prompts/summarize.py b/metagpt/prompts/summarize.py index a187314f4..348debf07 100644 --- a/metagpt/prompts/summarize.py +++ b/metagpt/prompts/summarize.py @@ -20,71 +20,72 @@ summary. Pick a suitable emoji for every bullet point. Your response should be i a YouTube video, use the following text: {{CONTENT}}. """ -# From GCP-VertexAI-Text Summarization +# From GCP-VertexAI-Text Summarization (SUMMARIZE_PROMPT_2-5 are all from this) # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/examples/prompt-design/text_summarization.ipynb -# For longer documents, a map-reduce process is needed, see the following notebook +# Long documents need a map-reduce process, see the following notebook # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/examples/document-summarization/summarization_large_documents.ipynb SUMMARIZE_PROMPT_2 = """ Provide a very short summary, no more than three sentences, for the following article: -Our quantum computers work by manipulating qubits in a manner we call quantum algorithms. -The challenge is that qubits are extremely sensitive, to the extent that even stray light can introduce calculation errors — a problem that intensifies as quantum computers scale. -This has notable ramifications since the most effective quantum algorithms we know for executing valuable applications necessitate that our qubits' error rates be significantly lower than current levels. -To address this discrepancy, quantum error correction is essential. -Quantum error correction safeguards information by distributing it over several physical qubits, forming a “logical qubit.” This is believed to be the sole method to create a large-scale quantum computer with sufficiently low error rates for practical calculations. -Rather than computing on individual qubits, we will utilize logical qubits. By transforming a greater number of physical qubits on our quantum processor into a single logical qubit, we aim to reduce error rates, enabling viable quantum algorithms. +Our quantum computers work by manipulating qubits in an orchestrated fashion that we call quantum algorithms. +The challenge is that qubits are so sensitive that even stray light can cause calculation errors — and the problem worsens as quantum computers grow. +This has significant consequences, since the best quantum algorithms that we know for running useful applications require the error rates of our qubits to be far lower than we have today. +To bridge this gap, we will need quantum error correction. +Quantum error correction protects information by encoding it across multiple physical qubits to form a “logical qubit,” and is believed to be the only way to produce a large-scale quantum computer with error rates low enough for useful calculations. +Instead of computing on the individual qubits themselves, we will then compute on logical qubits. By encoding larger numbers of physical qubits on our quantum processor into one logical qubit, we hope to reduce the error rates to enable useful quantum algorithms. Summary: """ + SUMMARIZE_PROMPT_3 = """ Provide a TL;DR for the following article: -Our quantum computers operate by controlling qubits in a method termed quantum algorithms. -The problem is that qubits are incredibly delicate, so much so that even minimal light interference can introduce computational errors — and this issue becomes more pronounced as quantum computers expand. -This is consequential because the most potent quantum algorithms we are aware of, for practical applications, demand that our qubits' error rates be substantially below current standards. -To mitigate this, quantum error correction is pivotal. -Quantum error correction secures data by distributing it across numerous physical qubits, generating a “logical qubit.” It's believed to be the exclusive approach to develop a large-scale quantum computer with error rates low enough for practical operations. -Instead of operations on individual qubits, we'll focus on logical qubits. By encoding a greater number of physical qubits on our quantum device into a single logical qubit, we aspire to diminish error rates and enable efficient quantum algorithms. +Our quantum computers work by manipulating qubits in an orchestrated fashion that we call quantum algorithms. +The challenge is that qubits are so sensitive that even stray light can cause calculation errors — and the problem worsens as quantum computers grow. +This has significant consequences, since the best quantum algorithms that we know for running useful applications require the error rates of our qubits to be far lower than we have today. +To bridge this gap, we will need quantum error correction. +Quantum error correction protects information by encoding it across multiple physical qubits to form a “logical qubit,” and is believed to be the only way to produce a large-scale quantum computer with error rates low enough for useful calculations. +Instead of computing on the individual qubits themselves, we will then compute on logical qubits. By encoding larger numbers of physical qubits on our quantum processor into one logical qubit, we hope to reduce the error rates to enable useful quantum algorithms. TL;DR: """ + SUMMARIZE_PROMPT_4 = """ Provide a very short summary in four bullet points for the following article: -Our quantum computers function by manipulating qubits through a method known as quantum algorithms. -The dilemma is that qubits are exceedingly fragile, so much so that even minimal light can lead to computational inaccuracies — and this problem amplifies as quantum computers become larger. -This is significant because the most proficient quantum algorithms known to us, suitable for real-world applications, necessitate that our qubits' error rates be significantly below what we currently observe. -To bridge this disparity, quantum error correction becomes indispensable. -Quantum error correction secures data by spreading it across multiple physical qubits, resulting in a “logical qubit.” It's perceived as the only technique to manufacture a large-scale quantum computer with error rates sufficiently low for practical tasks. -Instead of operating on individual qubits directly, we'll be utilizing logical qubits. By converting more physical qubits on our quantum machine into a single logical qubit, we intend to lower error rates, facilitating effective quantum algorithms. +Our quantum computers work by manipulating qubits in an orchestrated fashion that we call quantum algorithms. +The challenge is that qubits are so sensitive that even stray light can cause calculation errors — and the problem worsens as quantum computers grow. +This has significant consequences, since the best quantum algorithms that we know for running useful applications require the error rates of our qubits to be far lower than we have today. +To bridge this gap, we will need quantum error correction. +Quantum error correction protects information by encoding it across multiple physical qubits to form a “logical qubit,” and is believed to be the only way to produce a large-scale quantum computer with error rates low enough for useful calculations. +Instead of computing on the individual qubits themselves, we will then compute on logical qubits. By encoding larger numbers of physical qubits on our quantum processor into one logical qubit, we hope to reduce the error rates to enable useful quantum algorithms. Bulletpoints: """ + SUMMARIZE_PROMPT_5 = """ Please generate a summary of the following conversation and at the end summarize the to-do's for the support Agent: Customer: Hi, I'm Larry, and I received the wrong item. -Support Agent: Hi, Larry. How would you like this issue to be resolved? +Support Agent: Hi, Larry. How would you like to see this resolved? -Customer: That's alright. I'd like to return the item and receive a refund, please. +Customer: That's alright. I want to return the item and get a refund, please. -Support Agent: Certainly. I can process the refund for you right now. Could I have your order number, please? +Support Agent: Of course. I can process the refund for you now. Can I have your order number, please? Customer: It's [ORDER NUMBER]. -Support Agent: Thanks. I've processed the refund, and you should receive your funds within 14 days. +Support Agent: Thank you. I've processed the refund, and you will receive your money back within 14 days. -Customer: I appreciate it. +Customer: Thank you very much. -Support Agent: You're welcome, Larry. Have a great day! +Support Agent: You're welcome, Larry. Have a good day! Summary: -""" - -# - def summarize(doc: str) -> str # Input a document and receive a summary. +""" \ No newline at end of file diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index a48f4fc9d..4b171917a 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -1,243 +1,187 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ -@Time : 2023/5/5 23:08 +@Time : 2023/5/11 14:43 @Author : alexanderwu -@File : openai.py +@File : engineer.py """ import asyncio -import time -from functools import wraps -from typing import NamedTuple +import shutil +from collections import OrderedDict +from pathlib import Path -import openai - -from metagpt.config import CONFIG +from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.utils.singleton import Singleton -from metagpt.utils.token_counter import ( - TOKEN_COSTS, - count_message_tokens, - count_string_tokens, -) +from metagpt.roles import Role +from metagpt.actions import WriteCode, WriteCodeReview, WriteTasks, WriteDesign +from metagpt.schema import Message +from metagpt.utils.common import CodeParser -def retry(max_retries): - def decorator(f): - @wraps(f) - async def wrapper(*args, **kwargs): - for i in range(max_retries): +async def gather_ordered_k(coros, k) -> list: + tasks = OrderedDict() + results = [None] * len(coros) + done_queue = asyncio.Queue() + + for i, coro in enumerate(coros): + if len(tasks) >= k: + done, _ = await asyncio.wait(tasks.keys(), return_when=asyncio.FIRST_COMPLETED) + for task in done: + index = tasks.pop(task) + await done_queue.put((index, task.result())) + task = asyncio.create_task(coro) + tasks[task] = i + + if tasks: + done, _ = await asyncio.wait(tasks.keys()) + for task in done: + index = tasks[task] + await done_queue.put((index, task.result())) + + while not done_queue.empty(): + index, result = await done_queue.get() + results[index] = result + + return results + + +class Engineer(Role): + def __init__(self, name="Alex", profile="Engineer", goal="Write elegant, readable, extensible, efficient code", + constraints="The code you write should conform to code standard like PEP8, be modular, easy to read and maintain", + n_borg=1, use_code_review=False): + super().__init__(name, profile, goal, constraints) + self._init_actions([WriteCode]) + self.use_code_review = use_code_review + if self.use_code_review: + self._init_actions([WriteCode, WriteCodeReview]) + self._watch([WriteTasks]) + self.todos = [] + self.n_borg = n_borg + + @classmethod + def parse_tasks(cls, task_msg: Message) -> list[str]: + if not task_msg.instruct_content: + return task_msg.instruct_content.dict().get("Task list") + return CodeParser.parse_file_list(block="Task list", text=task_msg.content) + + @classmethod + def parse_code(cls, code_text: str) -> str: + return CodeParser.parse_code(block="", text=code_text) + + @classmethod + def parse_workspace(cls, system_design_msg: Message) -> str: + if not system_design_msg.instruct_content: + return system_design_msg.instruct_content.dict().get("Python package name") + return CodeParser.parse_str(block="Python package name", text=system_design_msg.content) + + def get_workspace(self) -> Path: + msg = self._rc.memory.get_by_action(WriteDesign)[-1] + if not msg: + return WORKSPACE_ROOT / 'src' + workspace = self.parse_workspace(msg) + # Codes are written in workspace/{package_name}/{package_name} + return WORKSPACE_ROOT / workspace / workspace + + def recreate_workspace(self): + workspace = self.get_workspace() + try: + shutil.rmtree(workspace) + except FileNotFoundError: + pass # Directory does not exist, but we don't care + workspace.mkdir(parents=True, exist_ok=True) + + def write_file(self, filename: str, code: str): + workspace = self.get_workspace() + file = workspace / filename + file.parent.mkdir(parents=True, exist_ok=True) + file.write_text(code) + + def recv(self, message: Message) -> None: + self._rc.memory.add(message) + if message in self._rc.important_memory: + self.todos = self.parse_tasks(message) + + async def _act_mp(self) -> Message: + # self.recreate_workspace() + todo_coros = [] + for todo in self.todos: + todo_coro = WriteCode().run( + context=self._rc.memory.get_by_actions([WriteTasks, WriteDesign]), + filename=todo + ) + todo_coros.append(todo_coro) + + rsps = await gather_ordered_k(todo_coros, self.n_borg) + for todo, code_rsp in zip(self.todos, rsps): + _ = self.parse_code(code_rsp) + logger.info(todo) + logger.info(code_rsp) + # self.write_file(todo, code) + msg = Message(content=code_rsp, role=self.profile, cause_by=type(self._rc.todo)) + self._rc.memory.add(msg) + del self.todos[0] + + logger.info(f'Done {self.get_workspace()} generating.') + msg = Message(content="all done.", role=self.profile, cause_by=type(self._rc.todo)) + return msg + + async def _act_sp(self) -> Message: + for todo in self.todos: + code_rsp = await WriteCode().run( + context=self._rc.history, + filename=todo + ) + # logger.info(todo) + # logger.info(code_rsp) + # code = self.parse_code(code_rsp) + self.write_file(todo, code_rsp) + msg = Message(content=code_rsp, role=self.profile, cause_by=type(self._rc.todo)) + self._rc.memory.add(msg) + + logger.info(f'Done {self.get_workspace()} generating.') + msg = Message(content="all done.", role=self.profile, cause_by=type(self._rc.todo)) + return msg + + async def _act_sp_precision(self) -> Message: + for todo in self.todos: + """ + # Select essential information from historical information to reduce prompt length (summarized from human experience) + 1. All from Architect + 2. All from ProjectManager + 3. Do you need other codes (currently needed)? + TODO: The goal is not to need it. Once tasks are split clearly, according to the design idea, the code can be written clearly for each file without other codes. If it can't, it means that it still needs to be defined more clearly, this is the key to write long code. + """ + context = [] + msg = self._rc.memory.get_by_actions([WriteDesign, WriteTasks, WriteCode]) + for m in msg: + context.append(m.content) + context_str = "\n".join(context) + # Writing code + code = await WriteCode().run( + context=context_str, + filename=todo + ) + # Code review + if self.use_code_review: try: - return await f(*args, **kwargs) - except Exception: - if i == max_retries - 1: - raise - await asyncio.sleep(2 ** i) - return wrapper - return decorator + rewrite_code = await WriteCodeReview().run( + context=context_str, + code=code, + filename=todo + ) + code = rewrite_code + except Exception as e: + logger.error("code review failed!", e) + pass + self.write_file(todo, code) + msg = Message(content=code, role=self.profile, cause_by=WriteCode) + self._rc.memory.add(msg) + logger.info(f'Done {self.get_workspace()} generating.') + msg = Message(content="all done.", role=self.profile, cause_by=WriteCode) + return msg -class RateLimiter: - """Rate control class, each call goes through wait_if_needed, sleep if rate control is needed.""" - def __init__(self, rpm): - self.last_call_time = 0 - self.interval = 1.1 * 60 / rpm # Using 1.1 since strict adherence to time can still lead to QoS issues; consider simple error retry later. - self.rpm = rpm - - def split_batches(self, batch): - return [batch[i:i + self.rpm] for i in range(0, len(batch), self.rpm)] - - async def wait_if_needed(self, num_requests): - current_time = time.time() - elapsed_time = current_time - self.last_call_time - - if elapsed_time < self.interval * num_requests: - remaining_time = self.interval * num_requests - elapsed_time - logger.info(f"sleep {remaining_time}") - await asyncio.sleep(remaining_time) - - self.last_call_time = time.time() - - -class Costs(NamedTuple): - total_prompt_tokens: int - total_completion_tokens: int - total_cost: float - total_budget: float - - -class CostManager(metaclass=Singleton): - """Calculate the cost of using the API.""" - def __init__(self): - self.total_prompt_tokens = 0 - self.total_completion_tokens = 0 - self.total_cost = 0 - self.total_budget = 0 - - def update_cost(self, prompt_tokens, completion_tokens, model): - """ - Update the total cost, prompt tokens, and completion tokens. - - Args: - prompt_tokens (int): The number of tokens used in the prompt. - completion_tokens (int): The number of tokens used in the completion. - model (str): The model used for the API call. - """ - self.total_prompt_tokens += prompt_tokens - self.total_completion_tokens += completion_tokens - cost = ( - prompt_tokens * TOKEN_COSTS[model]["prompt"] - + completion_tokens * TOKEN_COSTS[model]["completion"] - ) / 1000 - self.total_cost += cost - logger.info(f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | " - f"Current cost: ${cost:.3f}, {prompt_tokens=}, {completion_tokens=}") - CONFIG.total_cost = self.total_cost - - def get_total_prompt_tokens(self): - """Get the total number of prompt tokens.""" - return self.total_prompt_tokens - - def get_total_completion_tokens(self): - """Get the total number of completion tokens.""" - return self.total_completion_tokens - - def get_total_cost(self): - """Get the total cost of API calls.""" - return self.total_cost - - def get_costs(self) -> Costs: - """Get all costs.""" - return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget) - - -class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): - """ - Check https://platform.openai.com/examples for examples. - """ - def __init__(self): - self.__init_openai(CONFIG) - self.llm = openai - self.model = CONFIG.openai_api_model - self._cost_manager = CostManager() - RateLimiter.__init__(self, rpm=self.rpm) - - def __init_openai(self, config): - openai.api_key = config.openai_api_key - if config.openai_api_base: - openai.api_base = config.openai_api_base - if config.openai_api_type: - openai.api_type = config.openai_api_type - openai.api_version = config.openai_api_version - self.rpm = int(config.get("RPM", 10)) - - async def _achat_completion_stream(self, messages: list[dict]) -> str: - response = await openai.ChatCompletion.acreate( - **self._cons_kwargs(messages), - stream=True - ) - - # create variables to collect the stream of chunks - collected_chunks = [] - collected_messages = [] - # iterate through the stream of events - async for chunk in response: - collected_chunks.append(chunk) # save the event response - chunk_message = chunk['choices'][0]['delta'] # extract the message - collected_messages.append(chunk_message) # save the message - if "content" in chunk_message: - print(chunk_message["content"], end="") - print() - - full_reply_content = ''.join([m.get('content', '') for m in collected_messages]) - usage = self._calc_usage(messages, full_reply_content) - self._update_costs(usage) - return full_reply_content - - def _cons_kwargs(self, messages: list[dict]) -> dict: - if CONFIG.openai_api_type == 'azure': - kwargs = { - "deployment_id": CONFIG.deployment_id, - "messages": messages, - "max_tokens": CONFIG.max_tokens_rsp, - "n": 1, - "stop": None, - "temperature": 0.3 - } - else: - kwargs = { - "model": self.model, - "messages": messages, - "max_tokens": CONFIG.max_tokens_rsp, - "n": 1, - "stop": None, - "temperature": 0.3 - } - return kwargs - - async def _achat_completion(self, messages: list[dict]) -> dict: - rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages)) - self._update_costs(rsp.get('usage')) - return rsp - - def _chat_completion(self, messages: list[dict]) -> dict: - rsp = self.llm.ChatCompletion.create(**self._cons_kwargs(messages)) - self._update_costs(rsp) - return rsp - - def completion(self, messages: list[dict]) -> dict: - return self._chat_completion(messages) - - async def acompletion(self, messages: list[dict]) -> dict: - return await self._achat_completion(messages) - - @retry(max_retries=6) - async def acompletion_text(self, messages: list[dict], stream=False) -> str: - """When streaming, print each token in place.""" - if stream: - return await self._achat_completion_stream(messages) - rsp = await self._achat_completion(messages) - return self.get_choice_text(rsp) - - def _calc_usage(self, messages: list[dict], rsp: str) -> dict: - usage = {} - prompt_tokens = count_message_tokens(messages, self.model) - completion_tokens = count_string_tokens(rsp, self.model) - usage['prompt_tokens'] = prompt_tokens - usage['completion_tokens'] = completion_tokens - return usage - - async def acompletion_batch(self, batch: list[list[dict]]) -> list[dict]: - """Returns the full JSON.""" - split_batches = self.split_batches(batch) - all_results = [] - - for small_batch in split_batches: - logger.info(small_batch) - await self.wait_if_needed(len(small_batch)) - - future = [self.acompletion(prompt) for prompt in small_batch] - results = await asyncio.gather(*future) - logger.info(results) - all_results.extend(results) - - return all_results - - async def acompletion_batch_text(self, batch: list[list[dict]]) -> list[str]: - """Returns only plain text.""" - raw_results = await self.acompletion_batch(batch) - results = [] - for idx, raw_result in enumerate(raw_results, start=1): - result = self.get_choice_text(raw_result) - results.append(result) - logger.info(f"Result of task {idx}: {result}") - return results - - def _update_costs(self, usage: dict): - prompt_tokens = int(usage['prompt_tokens']) - completion_tokens = int(usage['completion_tokens']) - self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) - - def get_costs(self) -> Costs: - return self._cost_manager.get_costs() + async def _act(self) -> Message: + if self.use_code_review: + return await self._act_sp_precision() + return await self._act_sp() diff --git a/metagpt/tools/sd_engine.py b/metagpt/tools/sd_engine.py index aa776f662..606952b99 100644 --- a/metagpt/tools/sd_engine.py +++ b/metagpt/tools/sd_engine.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # @Date : 2023/7/19 16:28 # @Author : stellahong (stellahong@fuzhi.ai) -# @Desc : +# @Description : + import os import asyncio from os.path import join @@ -55,6 +56,7 @@ payload = { default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution" + class SDEngine: def __init__(self): # Initialize the SDEngine with configuration @@ -65,7 +67,7 @@ class SDEngine: self.payload = payload logger.info(self.sd_t2i_url) - def construct_payload(self, prompt, negtive_prompt=default_negative_prompt, width=512, height=512, + def construct_payload(self, prompt, negative_prompt=default_negative_prompt, width=512, height=512, sd_model="galaxytimemachinesGTM_photoV20"): # Configure the payload with provided inputs self.payload["prompt"] = prompt @@ -101,11 +103,11 @@ class SDEngine: return imgs async def run_i2i(self): - # TODO: Add image-to-image API call + # TODO: Add a method to call the image-to-image interface raise NotImplementedError async def run_sam(self): - # TODO: Add SAM API call + # TODO: Add a method to call the SAM interface raise NotImplementedError def decode_base64_to_image(img, save_name): diff --git a/metagpt/tools/ut_writer.py b/metagpt/tools/ut_writer.py index aca335246..23604ac54 100644 --- a/metagpt/tools/ut_writer.py +++ b/metagpt/tools/ut_writer.py @@ -6,18 +6,18 @@ from pathlib import Path from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI -ICL_SAMPLE = '''API Definition: +ICL_SAMPLE = '''Interface definition: ```text -API Name: Tag Elements -API Path: /projects/{project_key}/node-tags -Method: POST +Interface Name: Tag Elements +Interface Path: /projects/{project_key}/node-tags +Method: POST Request Parameters: Path Parameters: project_key Body Parameters: -Name Type Required Default Value Description +Name Type Required Default Value Description nodes array Yes Nodes node_key string No Node key tags array No Original node tag list @@ -26,11 +26,11 @@ operations array Yes tags array No Operation tag list mode string No Operation type ADD / DELETE -Response Data: -Name Type Required Default Value Description +Return Data: +Name Type Required Default Value Description code integer Yes Status code msg string Yes Message -data object Yes Response data +data object Yes Return data list array No Node list true / false node_type string No Node type DATASET / RECIPE node_key string No Node key @@ -43,7 +43,7 @@ Unit Test: [ ("project_key", [{"node_key": "dataset_001", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "success"), ("project_key", [{"node_key": "dataset_002", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["tag1"], "mode": "DELETE"}], "success"), -("", [{"node_key": "dataset_001", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "Missing required parameter project_key"), +("", [{"node_key": "dataset_001", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "Missing necessary parameter project_key"), (123, [{"node_key": "dataset_001", "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "Incorrect parameter type"), ("project_key", [{"node_key": "a"*201, "tags": ["tag1", "tag2"], "node_type": "DATASET"}], [{"tags": ["new_tag1"], "mode": "ADD"}], "Request parameter exceeds field boundary") ] @@ -51,64 +51,67 @@ Unit Test: def test_node_tags(project_key, nodes, operations, expected_msg): pass ``` -Above is an example of an API definition and a unit test sample. -Next, please play the role of a test manager from Google with 20 years of experience. After I provide the API definition, reply with the unit test. There are a few requirements: -1. Only output one '@pytest.mark.parametrize' and its corresponding 'test_' function (with only a 'pass' statement inside, no implementation). --- The function parameters should include 'expected_msg' for result validation. -2. The generated test cases should use shorter text or numbers and be as compact as possible. +The above is an example of interface definition and unit test. +Next, please act as an expert test manager with 20 years of experience at Google. +After I provide the interface definition, please reply with the unit test. +There are a few requirements: +1. Only output one `@pytest.mark.parametrize` and the corresponding test_ function + (with a pass inside, not implemented). + -- The function parameters should include expected_msg for result validation. +2. The generated test cases should use shorter text or numbers and be as concise as possible. 3. If comments are needed, use Chinese. -If you understand, please wait for me to provide the API definition and only reply with "Understood" to save tokens. +If you understand, please wait for me to provide the interface definition +and only reply with "Understood" to save tokens. ''' -ACT_PROMPT_PREFIX = '''Reference test types: such as missing request parameters, field boundary validation, incorrect field type. +ACT_PROMPT_PREFIX = '''Reference test types: such as missing request parameters, field boundary checks, incorrect field types. Please output 10 test cases within a `@pytest.mark.parametrize` scope. ```text ''' -YFT_PROMPT_PREFIX = '''Reference test types: such as SQL injection, cross-site scripting (XSS), illegal access and unauthorized access, authentication and authorization, parameter validation, exception handling, file upload and download. +YFT_PROMPT_PREFIX = '''Reference test types: such as SQL injection, cross-site scripting (XSS), illegal access and unauthorized access, authentication and authorization, parameter verification, exception handling, file upload and download. Please output 10 test cases within a `@pytest.mark.parametrize` scope. ```text ''' - OCR_API_DOC = '''```text -API Name: OCR Recognition -API Path: /api/v1/contract/treaty/task/ocr -Method: POST +API Name: OCR Recognition +API Path: /api/v1/contract/treaty/task/ocr +Method: POST Request Parameters: Path Parameters: Body Parameters: -Name Type Required Default Value Remarks -file_id string Yes -box array Yes -contract_id number Yes Contract ID -start_time string No yyyy-mm-dd -end_time string No yyyy-mm-dd -extract_type number No Recognition type 1-During import 2-After import, default is 1 +Name Type Mandatory Default Value Remarks +file_id string Yes +box array Yes +contract_id number Yes Contract ID +start_time string No yyyy-mm-dd +end_time string No yyyy-mm-dd +extract_type number No Recognition Type 1- During Import 2- After Import, Default is 1 -Response Data: -Name Type Required Default Value Remarks -code integer Yes -message string Yes -data object Yes +Return Data: +Name Type Mandatory Default Value Remarks +code integer Yes +message string Yes +data object Yes + ''' - class UTGenerator: - """UT Generator: Construct UT through API documentation.""" + """UT Generator: Constructs UT from API documentation.""" def __init__(self, swagger_file: str, ut_py_path: str, questions_path: str, chatgpt_method: str = "API", template_prefix=YFT_PROMPT_PREFIX) -> None: - """Initialize the UT Generator. + """Initializes the UT generator. Args: - swagger_file: Path to the swagger file. + swagger_file: Path to the swagger. ut_py_path: Path to store test cases. - questions_path: Path to store templates for future investigation. - chatgpt_method: API - template_prefix: Use template, default is YFT_UT_PROMPT. + questions_path: Path to store templates for further investigation. + chatgpt_method: API. + template_prefix: Use template, defaults to YFT_UT_PROMPT. """ self.swagger_file = swagger_file self.ut_py_path = ut_py_path @@ -116,56 +119,56 @@ class UTGenerator: assert chatgpt_method in ["API"], "Invalid chatgpt_method" self.chatgpt_method = chatgpt_method - # ICL: In-Context Learning. Provide an example here for GPT to mimic. + # ICL: In-Context Learning. Here, an example is provided for GPT to follow. self.icl_sample = ICL_SAMPLE self.template_prefix = template_prefix def get_swagger_json(self) -> dict: - """Load Swagger JSON from a local file.""" + """Loads Swagger JSON from a local file.""" with open(self.swagger_file, "r", encoding="utf-8") as file: swagger_json = json.load(file) return swagger_json - def __parameter_to_string(self, prop, required, name=""): + def __para_to_str(self, prop, required, name=""): name = name or prop["name"] ptype = prop["type"] title = prop.get("title", "") desc = prop.get("description", "") return f'{name}\t{ptype}\t{"Yes" if required else "No"}\t{title}\t{desc}' - def _parameter_to_string(self, prop): + def _para_to_str(self, prop): required = prop.get("required", False) - return self.__parameter_to_string(prop, required) + return self.__para_to_str(prop, required) - def parameter_to_string(self, name, prop, prop_object_required): + def para_to_str(self, name, prop, prop_object_required): required = name in prop_object_required - return self.__parameter_to_string(prop, required, name) + return self.__para_to_str(prop, required, name) def build_object_properties(self, node, prop_object_required, level: int = 0) -> str: - """Recursively output properties of object and array[object] types. + """Recursively outputs properties of object and array[object] types. Args: - node (_type_): Value of the sub-item. - prop_object_required (_type_): Indicates if it's a required item. + node: Value of the sub-item. + prop_object_required: Whether it's a required item. level: Current recursion depth. """ doc = "" def dive_into_object(node): - """If it's an object type, recursively output its properties.""" + """If it's an object type, recursively outputs its properties.""" if node.get("type") == "object": sub_properties = node.get("properties", {}) return self.build_object_properties(sub_properties, prop_object_required, level=level + 1) return "" if node.get("in", "") in ["query", "header", "formData"]: - doc += f'{"\t" * level}{self._parameter_to_string(node)}\n' + doc += f'{"\t" * level}{self._para_to_str(node)}\n' doc += dive_into_object(node) return doc for name, prop in node.items(): - doc += f'{"\t" * level}{self.parameter_to_string(name, prop, prop_object_required)}\n' + doc += f'{"\t" * level}{self.para_to_str(name, prop, prop_object_required)}\n' doc += dive_into_object(prop) if prop["type"] == "array": items = prop.get("items", {}) @@ -173,10 +176,10 @@ class UTGenerator: return doc def get_tags_mapping(self) -> dict: - """Process tag and path. + """Handles tag and path mapping. Returns: - Dictionary: Correspondence of tag to path. + Dict: Mapping of tag to path. """ swagger_data = self.get_swagger_json() paths = swagger_data["paths"] @@ -194,7 +197,7 @@ class UTGenerator: return tags def generate_ut(self, include_tags) -> bool: - """Generate test case files.""" + """Generates test case files.""" tags = self.get_tags_mapping() for tag, paths in tags.items(): if include_tags is None or tag in include_tags: @@ -204,19 +207,19 @@ class UTGenerator: def build_api_doc(self, node: dict, path: str, method: str) -> str: summary = node["summary"] - doc = f"API Name: {summary}\nAPI Path: {path}\nMethod: {method.upper()}\n" - doc += "\nRequest Parameters:\n" + doc = f"Interface name: {summary}\nInterface path: {path}\nMethod: {method.upper()}\n" + doc += "\nRequest parameters:\n" if "parameters" in node: parameters = node["parameters"] - doc += "Path Parameters:\n" + doc += "Path parameters:\n" # param["in"]: path / formData / body / query / header for param in parameters: if param["in"] == "path": - doc += f'{param["name"]}\n' + doc += f'{param["name"]} \n' - doc += "\nBody Parameters:\n" - doc += "Name\tType\tRequired\tDefault Value\tRemarks\n" + doc += "\nBody parameters:\n" + doc += "Name\tType\tRequired\tDefault\tNotes\n" for param in parameters: if param["in"] == "body": schema = param.get("schema", {}) @@ -227,8 +230,8 @@ class UTGenerator: doc += self.build_object_properties(param, []) # Output return data information - doc += "\nReturn Data:\n" - doc += "Name\tType\tRequired\tDefault Value\tRemarks\n" + doc += "\nReturn data:\n" + doc += "Name\tType\tRequired\tDefault\tNotes\n" responses = node["responses"] response = responses.get("200", {}) schema = response.get("schema", {}) @@ -242,12 +245,13 @@ class UTGenerator: return doc def _store(self, data, base, folder, fname): + """Store data in a file.""" file_path = self.get_file_path(Path(base) / folder, fname) with open(file_path, "w", encoding="utf-8") as file: file.write(data) def ask_gpt_and_save(self, question: str, tag: str, fname: str): - """Generate a question and store both question and answer.""" + """Generate a question and save both the question and answer.""" messages = [self.icl_sample, question] result = self.gpt_msgs_to_code(messages=messages) @@ -255,11 +259,11 @@ class UTGenerator: self._store(result, self.ut_py_path, tag, f"{fname}.py") def _generate_ut(self, tag, paths): - """Handle structure under the data path. + """Process the structure under the data path. Args: - tag (_type_): Module name. - paths (_type_): Path Object. + tag: Module name. + paths: Path Object. """ for path, path_obj in paths.items(): for method, node in path_obj.items(): @@ -269,7 +273,7 @@ class UTGenerator: self.ask_gpt_and_save(question, tag, summary) def gpt_msgs_to_code(self, messages: list) -> str: - """Choose based on different invocation methods.""" + """Choose based on different call methods.""" result = '' if self.chatgpt_method == "API": result = GPTAPI().ask_code(msgs=messages) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index fb91d2c57..aa2f5bb98 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -14,10 +14,11 @@ from typing import List, Tuple from metagpt.logs import logger -def check_command_exists(command) -> int: - """ Check if a command exists. +def check_cmd_exists(command) -> int: + """Check if a command exists. + :param command: Command to check. - :return: Returns 0 if the command exists, else returns non-zero. + :return: Returns 0 if the command exists, otherwise returns a non-zero value. """ check_command = 'command -v ' + command + ' >/dev/null 2>&1 || { echo >&2 "no mermaid"; exit 1; }' result = os.system(check_command) @@ -28,19 +29,19 @@ class OutputParser: @classmethod def parse_blocks(cls, text: str): - # Firstly, split the text into different blocks based on "##". + # First, split the text into different blocks based on "##". blocks = text.split("##") # Create a dictionary to store the title and content of each block. block_dict = {} - # Loop through all blocks. + # Iterate through all blocks. for block in blocks: - # If block is not empty, continue processing. + # If the block is not empty, continue processing. if block.strip() != "": - # Split block's title and content and trim them. + # Split the block's title and content and trim whitespace. block_title, block_content = block.split("\n", 1) - # There may be errors in LLM, correct it here. + # LLM might make mistakes; correct it here. if block_title[-1] == ":": block_title = block_title[:-1] block_dict[block_title.strip()] = block_content.strip() @@ -84,7 +85,7 @@ class OutputParser: block_dict = cls.parse_blocks(data) parsed_data = {} for block, content in block_dict.items(): - # Try to remove code markers. + # Try to remove the code marker. try: content = cls.parse_code(text=content) except Exception: @@ -103,7 +104,7 @@ class OutputParser: block_dict = cls.parse_blocks(data) parsed_data = {} for block, content in block_dict.items(): - # Try to remove code markers. + # Try to remove the code marker. try: content = cls.parse_code(text=content) except Exception: @@ -135,17 +136,17 @@ class CodeParser: @classmethod def parse_blocks(cls, text: str): - # Firstly, split the text into different blocks based on "##". + # First, split the text into different blocks based on "##". blocks = text.split("##") # Create a dictionary to store the title and content of each block. block_dict = {} - # Loop through all blocks. + # Iterate through all blocks. for block in blocks: - # If block is not empty, continue processing. + # If the block is not empty, continue processing. if block.strip() != "": - # Split block's title and content and trim them. + # Split the block's title and content and trim whitespace. block_title, block_content = block.split("\n", 1) block_dict[block_title.strip()] = block_content.strip() @@ -160,7 +161,7 @@ class CodeParser: if match: code = match.group(1) else: - logger.error(f"{pattern} did not match the following text:") + logger.error(f"{pattern} not match following text:") logger.error(text) raise Exception return code @@ -213,14 +214,4 @@ def print_members(module, indent=0): prefix = ' ' * indent for name, obj in inspect.getmembers(module): print(name, obj) - if inspect.isclass(obj): - print(f'{prefix}Class: {name}') - # print the methods within the class - if name in ['__class__', '__base__']: - continue - print_members(obj, indent + 2) - elif inspect.isfunction(obj): - print(f'{prefix}Function: {name}') - elif inspect.ismethod(obj): - print(f'{prefix}Method: {name}') - + if inspect