mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-02 04:12:45 +02:00
Merge branch 'geekan:main' into main
This commit is contained in:
commit
9e4e32e7c7
33 changed files with 321 additions and 140 deletions
|
|
@ -3,7 +3,7 @@
|
|||
{
|
||||
"name": "Python 3",
|
||||
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
|
||||
"image": "mcr.microsoft.com/devcontainers/python:0-3.11",
|
||||
"image": "metagpt/metagpt:latest",
|
||||
|
||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||
// "features": {},
|
||||
|
|
@ -18,7 +18,7 @@
|
|||
]
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
// Use 'postCreateCommand' to run commands after the container is created.
|
||||
"postCreateCommand": "./.devcontainer/postCreateCommand.sh"
|
||||
|
||||
|
|
|
|||
|
|
@ -85,8 +85,8 @@ # Check https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html
|
|||
|
||||
```yaml
|
||||
llm:
|
||||
api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options
|
||||
model: "gpt-4-turbo" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
|
||||
api_type: "openai" # or azure / ollama / groq etc. Check LLMType for more options
|
||||
model: "gpt-4-turbo" # or gpt-3.5-turbo
|
||||
base_url: "https://api.openai.com/v1" # or forward url / other llm url
|
||||
api_key: "YOUR_API_KEY"
|
||||
```
|
||||
|
|
@ -107,7 +107,7 @@ ### Usage
|
|||
print(repo) # it will print the repo structure with files
|
||||
```
|
||||
|
||||
You can also use its [Data Interpreter](https://github.com/geekan/MetaGPT/tree/main/examples/di)
|
||||
You can also use [Data Interpreter](https://github.com/geekan/MetaGPT/tree/main/examples/di) to write code:
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
|
|
|
|||
|
|
@ -1,27 +1,23 @@
|
|||
llm:
|
||||
api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options
|
||||
api_type: "openai" # or azure / ollama / groq etc.
|
||||
base_url: "YOUR_BASE_URL"
|
||||
api_key: "YOUR_API_KEY"
|
||||
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
|
||||
model: "gpt-4-turbo" # or gpt-3.5-turbo
|
||||
proxy: "YOUR_PROXY" # for LLM API requests
|
||||
# timeout: 600 # Optional. If set to 0, default value is 300.
|
||||
pricing_plan: "" # Optional. If invalid, it will be automatically filled in with the value of the `model`.
|
||||
# Azure-exclusive pricing plan mappings:
|
||||
# - gpt-3.5-turbo 4k: "gpt-3.5-turbo-1106"
|
||||
# - gpt-4-turbo: "gpt-4-turbo-preview"
|
||||
# - gpt-4-turbo-vision: "gpt-4-vision-preview"
|
||||
# - gpt-4 8k: "gpt-4"
|
||||
# See for more: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
|
||||
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
|
||||
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's
|
||||
|
||||
|
||||
# RAG Embedding.
|
||||
# For backward compatibility, if the embedding is not set and the llm's api_type is either openai or azure, the llm's config will be used.
|
||||
embedding:
|
||||
api_type: "" # openai / azure / gemini / ollama etc. Check EmbeddingType for more options.
|
||||
base_url: ""
|
||||
api_key: ""
|
||||
model: ""
|
||||
api_version: ""
|
||||
embed_batch_size: 100
|
||||
api_type: "" # openai / azure / gemini / ollama etc. Check EmbeddingType for more options.
|
||||
base_url: ""
|
||||
api_key: ""
|
||||
model: ""
|
||||
api_version: ""
|
||||
embed_batch_size: 100
|
||||
|
||||
repair_llm_output: true # when the output is not a valid json, try to repair it
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py
|
||||
# Config Docs: https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html
|
||||
llm:
|
||||
api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options
|
||||
model: "gpt-4-turbo" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
|
||||
api_type: "openai" # or azure / ollama / groq etc.
|
||||
model: "gpt-4-turbo" # or gpt-3.5-turbo
|
||||
base_url: "https://api.openai.com/v1" # or forward url / other llm url
|
||||
api_key: "YOUR_API_KEY"
|
||||
5
config/examples/anthropic-claude-3-opus.yaml
Normal file
5
config/examples/anthropic-claude-3-opus.yaml
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
llm:
|
||||
api_type: 'claude' # or anthropic
|
||||
base_url: 'https://api.anthropic.com'
|
||||
api_key: 'YOUR_API_KEY'
|
||||
model: 'claude-3-opus-20240229'
|
||||
4
config/examples/google-gemini.yaml
Normal file
4
config/examples/google-gemini.yaml
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
llm:
|
||||
api_type: 'gemini'
|
||||
api_key: 'YOUR_API_KEY'
|
||||
model: 'gemini-pro'
|
||||
5
config/examples/groq-llama3-70b.yaml
Normal file
5
config/examples/groq-llama3-70b.yaml
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
llm:
|
||||
# Visit https://console.groq.com/keys to create api key
|
||||
base_url: "https://api.groq.com/openai/v1"
|
||||
api_key: "YOUR_API_KEY"
|
||||
model: "llama3-70b-8192" # llama3-8b-8192,llama3-70b-8192,llama2-70b-4096 ,mixtral-8x7b-32768,gemma-7b-it
|
||||
5
config/examples/openai-gpt-3.5-turbo.yaml
Normal file
5
config/examples/openai-gpt-3.5-turbo.yaml
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
llm:
|
||||
api_key: "YOUR_API_KEY"
|
||||
model: "gpt-3.5-turbo"
|
||||
#proxy: "http://<ip>:<port>"
|
||||
#base_url: "https://<forward_url>/v1"
|
||||
6
config/examples/openai-gpt-4-turbo.yaml
Normal file
6
config/examples/openai-gpt-4-turbo.yaml
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
llm:
|
||||
api_key: "YOUR_API_KEY"
|
||||
model: "gpt-4-turbo"
|
||||
#proxy: "http://<ip>:<port>"
|
||||
#base_url: "https://<forward_url>/v1"
|
||||
|
||||
5
config/examples/openrouter-llama3-70b-instruct.yaml
Normal file
5
config/examples/openrouter-llama3-70b-instruct.yaml
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
llm:
|
||||
api_type: openrouter
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
api_key: "YOUR_API_KEY"
|
||||
model: meta-llama/llama-3-70b-instruct
|
||||
|
|
@ -38,9 +38,9 @@ ### Chief Evangelist (Monthly Rotation)
|
|||
### FAQ
|
||||
|
||||
1. Code truncation/ Parsing failure:
|
||||
1. Check if it's due to exceeding length. Consider using the gpt-4-turbo-preview or other long token versions.
|
||||
1. Check if it's due to exceeding length. Consider using the gpt-4-turbo or other long token versions.
|
||||
2. Success rate:
|
||||
1. There hasn't been a quantitative analysis yet, but the success rate of code generated by gpt-4-turbo-preview is significantly higher than that of gpt-3.5-turbo.
|
||||
1. There hasn't been a quantitative analysis yet, but the success rate of code generated by gpt-4-turbo is significantly higher than that of gpt-3.5-turbo.
|
||||
3. Support for incremental, differential updates (if you wish to continue a half-done task):
|
||||
1. There is now an experimental version. Specify `--inc --project-path "<path>"` or `--inc --project-name "<name>"` on the command line and enter the corresponding requirements to try it.
|
||||
4. Can existing code be loaded?
|
||||
|
|
|
|||
|
|
@ -14,9 +14,9 @@ from metagpt.roles import Role
|
|||
from metagpt.team import Team
|
||||
|
||||
gpt35 = Config.default()
|
||||
gpt35.llm.model = "gpt-3.5-turbo-1106"
|
||||
gpt35.llm.model = "gpt-3.5-turbo"
|
||||
gpt4 = Config.default()
|
||||
gpt4.llm.model = "gpt-4-1106-preview"
|
||||
gpt4.llm.model = "gpt-4-turbo"
|
||||
action1 = Action(config=gpt4, name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
action2 = Action(config=gpt35, name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2])
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
@Time : 2023/5/6 14:13
|
||||
@Author : alexanderwu
|
||||
@File : llm_hello_world.py
|
||||
@File : hello_world.py
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
|
|
@ -11,20 +11,15 @@ from metagpt.llm import LLM
|
|||
from metagpt.logs import logger
|
||||
|
||||
|
||||
async def main():
|
||||
llm = LLM()
|
||||
# llm type check
|
||||
question = "what's your name"
|
||||
logger.info(f"{question}: ")
|
||||
logger.info(await llm.aask(question))
|
||||
logger.info("\n\n")
|
||||
async def ask_and_print(question: str, llm: LLM, system_prompt) -> str:
|
||||
logger.info(f"Q: {question}")
|
||||
rsp = await llm.aask(question, system_msgs=[system_prompt])
|
||||
logger.info(f"A: {rsp}")
|
||||
return rsp
|
||||
|
||||
logger.info(
|
||||
await llm.aask(
|
||||
"who are you", system_msgs=["act as a robot, just answer 'I'am robot' if the question is 'who are you'"]
|
||||
)
|
||||
)
|
||||
|
||||
async def lowlevel_api_example(llm: LLM):
|
||||
logger.info("low level api example")
|
||||
logger.info(await llm.aask_batch(["hi", "write python hello world."]))
|
||||
|
||||
hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}]
|
||||
|
|
@ -39,5 +34,12 @@ async def main():
|
|||
logger.info(llm.completion(hello_msg))
|
||||
|
||||
|
||||
async def main():
|
||||
llm = LLM()
|
||||
await ask_and_print("what's your name?", llm, "I'm a helpful AI assistant.")
|
||||
await ask_and_print("who are you?", llm, "just answer 'I am a robot' if the question is 'who are you'")
|
||||
await lowlevel_api_example(llm)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
29
examples/ping.py
Normal file
29
examples/ping.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/4/22 14:28
|
||||
@Author : alexanderwu
|
||||
@File : ping.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
async def ask_and_print(question: str, llm: LLM, system_prompt) -> str:
|
||||
logger.info(f"Q: {question}")
|
||||
rsp = await llm.aask(question, system_msgs=[system_prompt])
|
||||
logger.info(f"A: {rsp}")
|
||||
logger.info("\n")
|
||||
return rsp
|
||||
|
||||
|
||||
async def main():
|
||||
llm = LLM()
|
||||
await ask_and_print("ping?", llm, "Just answer pong when ping.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -40,7 +40,10 @@ class Player(BaseModel):
|
|||
|
||||
|
||||
class RAGExample:
|
||||
"""Show how to use RAG."""
|
||||
"""Show how to use RAG.
|
||||
|
||||
Default engine use LLM Reranker, if the answer from the LLM is incorrect, may encounter `IndexError: list index out of range`.
|
||||
"""
|
||||
|
||||
def __init__(self, engine: SimpleEngine = None):
|
||||
self._engine = engine
|
||||
|
|
@ -59,6 +62,7 @@ class RAGExample:
|
|||
def engine(self, value: SimpleEngine):
|
||||
self._engine = value
|
||||
|
||||
@handle_exception
|
||||
async def run_pipeline(self, question=QUESTION, print_title=True):
|
||||
"""This example run rag pipeline, use faiss retriever and llm ranker, will print something like:
|
||||
|
||||
|
|
@ -79,6 +83,7 @@ class RAGExample:
|
|||
answer = await self.engine.aquery(question)
|
||||
self._print_query_result(answer)
|
||||
|
||||
@handle_exception
|
||||
async def add_docs(self):
|
||||
"""This example show how to add docs.
|
||||
|
||||
|
|
@ -148,6 +153,7 @@ class RAGExample:
|
|||
except Exception as e:
|
||||
logger.error(f"nodes is empty, llm don't answer correctly, exception: {e}")
|
||||
|
||||
@handle_exception
|
||||
async def init_objects(self):
|
||||
"""This example show how to from objs, will print something like:
|
||||
|
||||
|
|
@ -160,6 +166,7 @@ class RAGExample:
|
|||
await self.add_objects(print_title=False)
|
||||
self.engine = pre_engine
|
||||
|
||||
@handle_exception
|
||||
async def init_and_query_chromadb(self):
|
||||
"""This example show how to use chromadb. how to save and load index. will print something like:
|
||||
|
||||
|
|
@ -233,7 +240,7 @@ class RAGExample:
|
|||
|
||||
|
||||
async def main():
|
||||
"""RAG pipeline"""
|
||||
"""RAG pipeline."""
|
||||
e = RAGExample()
|
||||
await e.run_pipeline()
|
||||
await e.add_docs()
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ class WriteCode(Action):
|
|||
if not task_doc.content:
|
||||
task_doc = project_repo.docs.task.get(filename=task_doc.filename)
|
||||
m = json.loads(task_doc.content)
|
||||
code_filenames = m.get(TASK_LIST.key, []) if use_inc else m.get(REFINED_TASK_LIST.key, [])
|
||||
code_filenames = m.get(TASK_LIST.key, []) if not use_inc else m.get(REFINED_TASK_LIST.key, [])
|
||||
codes = []
|
||||
src_file_repo = project_repo.srcs
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
# Context
|
||||
{context}
|
||||
|
||||
-----
|
||||
|
||||
## Code to be Reviewed: {filename}
|
||||
```Code
|
||||
{code}
|
||||
|
|
@ -38,7 +40,8 @@ EXAMPLE_AND_INSTRUCTION = """
|
|||
{format_example}
|
||||
|
||||
|
||||
# Instruction: Based on the actual code situation, follow one of the "Format example". Return only 1 file under review.
|
||||
# Instruction: Based on the actual code, follow one of the "Code Review Format example".
|
||||
- Note the code filename should be `{filename}`. Return the only ONE file `{filename}` under review.
|
||||
|
||||
## Code Review: Ordered List. Based on the "Code to be Reviewed", provide key, clear, concise, and specific answer. If any answer is no, explain how to fix it step by step.
|
||||
1. Is the code implemented as per the requirements? If not, how to achieve it? Analyse it step by step.
|
||||
|
|
@ -56,7 +59,9 @@ LGTM/LBTM
|
|||
"""
|
||||
|
||||
FORMAT_EXAMPLE = """
|
||||
# Format example 1
|
||||
-----
|
||||
|
||||
# Code Review Format example 1
|
||||
## Code Review: {filename}
|
||||
1. No, we should fix the logic of class A due to ...
|
||||
2. ...
|
||||
|
|
@ -92,7 +97,9 @@ FORMAT_EXAMPLE = """
|
|||
## Code Review Result
|
||||
LBTM
|
||||
|
||||
# Format example 2
|
||||
-----
|
||||
|
||||
# Code Review Format example 2
|
||||
## Code Review: {filename}
|
||||
1. Yes.
|
||||
2. Yes.
|
||||
|
|
@ -106,10 +113,12 @@ pass
|
|||
|
||||
## Code Review Result
|
||||
LGTM
|
||||
|
||||
-----
|
||||
"""
|
||||
|
||||
REWRITE_CODE_TEMPLATE = """
|
||||
# Instruction: rewrite code based on the Code Review and Actions
|
||||
# Instruction: rewrite the `{filename}` based on the Code Review and Actions
|
||||
## Rewrite Code: CodeBlock. If it still has some bugs, rewrite {filename} with triple quotes. Do your utmost to optimize THIS SINGLE FILE. Return all completed codes and prohibit the return of unfinished codes.
|
||||
```Code
|
||||
## {filename}
|
||||
|
|
@ -169,6 +178,7 @@ class WriteCodeReview(Action):
|
|||
)
|
||||
cr_prompt = EXAMPLE_AND_INSTRUCTION.format(
|
||||
format_example=format_example,
|
||||
filename=self.i_context.code_doc.filename,
|
||||
)
|
||||
len1 = len(iterative_code) if iterative_code else 0
|
||||
len2 = len(self.i_context.code_doc.content) if self.i_context.code_doc.content else 0
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import json
|
|||
import os
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
|
|
@ -63,7 +63,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
response_synthesizer: Optional[BaseSynthesizer] = None,
|
||||
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
index: Optional[BaseIndex] = None,
|
||||
transformations: Optional[list[TransformComponent]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
retriever=retriever,
|
||||
|
|
@ -71,7 +71,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
node_postprocessors=node_postprocessors,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
self.index = index
|
||||
self._transformations = transformations or self._default_transformations()
|
||||
|
||||
@classmethod
|
||||
def from_docs(
|
||||
|
|
@ -103,12 +103,17 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
|
||||
cls._fix_document_metadata(documents)
|
||||
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents=documents,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
|
||||
transformations = transformations or cls._default_transformations()
|
||||
nodes = run_transformations(documents, transformations=transformations)
|
||||
|
||||
return cls._from_nodes(
|
||||
nodes=nodes,
|
||||
transformations=transformations,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_objs(
|
||||
|
|
@ -137,12 +142,15 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
|
||||
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
index = VectorStoreIndex(
|
||||
|
||||
return cls._from_nodes(
|
||||
nodes=nodes,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
|
||||
transformations=transformations,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_index(
|
||||
|
|
@ -183,7 +191,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
documents = SimpleDirectoryReader(input_files=input_files).load_data()
|
||||
self._fix_document_metadata(documents)
|
||||
|
||||
nodes = run_transformations(documents, transformations=self.index._transformations)
|
||||
nodes = run_transformations(documents, transformations=self._transformations)
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def add_objs(self, objs: list[RAGObject]):
|
||||
|
|
@ -199,6 +207,29 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
|
||||
self._persist(str(persist_dir), **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _from_nodes(
|
||||
cls,
|
||||
nodes: list[BaseNode],
|
||||
transformations: Optional[list[TransformComponent]] = None,
|
||||
embed_model: BaseEmbedding = None,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
|
||||
llm = llm or get_rag_llm()
|
||||
|
||||
retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model)
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
|
||||
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
node_postprocessors=rankers,
|
||||
response_synthesizer=get_response_synthesizer(llm=llm),
|
||||
transformations=transformations,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_index(
|
||||
cls,
|
||||
|
|
@ -208,6 +239,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
llm = llm or get_rag_llm()
|
||||
|
||||
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
|
||||
|
||||
|
|
@ -215,7 +247,6 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
retriever=retriever,
|
||||
node_postprocessors=rankers,
|
||||
response_synthesizer=get_response_synthesizer(llm=llm),
|
||||
index=index,
|
||||
)
|
||||
|
||||
def _ensure_retriever_modifiable(self):
|
||||
|
|
@ -266,3 +297,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
return embed_model or get_rag_embedding()
|
||||
|
||||
@staticmethod
|
||||
def _default_transformations():
|
||||
return [SentenceSplitter()]
|
||||
|
|
|
|||
|
|
@ -36,19 +36,26 @@ class ConfigBasedFactory(GenericFactory):
|
|||
"""Designed to get objects based on object type."""
|
||||
|
||||
def get_instance(self, key: Any, **kwargs) -> Any:
|
||||
"""Key is config, such as a pydantic model.
|
||||
"""Get instance by the type of key.
|
||||
|
||||
Call func by the type of key, and the key will be passed to func.
|
||||
Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func.
|
||||
Raise Exception if key not found.
|
||||
"""
|
||||
creator = self._creators.get(type(key))
|
||||
if creator:
|
||||
return creator(key, **kwargs)
|
||||
|
||||
self._raise_for_key(key)
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"Unknown config: `{type(key)}`, {key}")
|
||||
|
||||
@staticmethod
|
||||
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
|
||||
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
|
||||
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.
|
||||
|
||||
Return None if not found.
|
||||
"""
|
||||
if config is not None and hasattr(config, key):
|
||||
val = getattr(config, key)
|
||||
if val is not None:
|
||||
|
|
@ -57,6 +64,4 @@ class ConfigBasedFactory(GenericFactory):
|
|||
if key in kwargs:
|
||||
return kwargs[key]
|
||||
|
||||
raise KeyError(
|
||||
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
|
||||
)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
"""RAG Retriever Factory."""
|
||||
|
||||
import copy
|
||||
|
||||
from functools import wraps
|
||||
|
||||
import chromadb
|
||||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
|
|
@ -24,10 +27,25 @@ from metagpt.rag.schema import (
|
|||
ElasticsearchKeywordRetrieverConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
IndexRetrieverConfig,
|
||||
)
|
||||
|
||||
|
||||
def get_or_build_index(build_index_func):
|
||||
"""Decorator to get or build an index.
|
||||
|
||||
Get index using `_extract_index` method, if not found, using build_index_func.
|
||||
"""
|
||||
|
||||
@wraps(build_index_func)
|
||||
def wrapper(self, config, **kwargs):
|
||||
index = self._extract_index(config, **kwargs)
|
||||
if index is not None:
|
||||
return index
|
||||
return build_index_func(self, config, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class RetrieverFactory(ConfigBasedFactory):
|
||||
"""Modify creators for dynamically instance implementation."""
|
||||
|
||||
|
|
@ -54,48 +72,79 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]
|
||||
|
||||
def _create_default(self, **kwargs) -> RAGRetriever:
|
||||
return self._extract_index(**kwargs).as_retriever()
|
||||
index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs)
|
||||
|
||||
return index.as_retriever()
|
||||
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
config.index = self._build_faiss_index(config, **kwargs)
|
||||
|
||||
return FAISSRetriever(**config.model_dump())
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
|
||||
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
|
||||
index = self._extract_index(config, **kwargs)
|
||||
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)
|
||||
|
||||
return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
|
||||
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
|
||||
|
||||
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
|
||||
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
config.index = self._build_chroma_index(config, **kwargs)
|
||||
|
||||
return ChromaRetriever(**config.model_dump())
|
||||
|
||||
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
config.index = self._build_es_index(config, **kwargs)
|
||||
|
||||
return ElasticsearchRetriever(**config.model_dump())
|
||||
|
||||
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
|
||||
return self._val_from_config_or_kwargs("index", config, **kwargs)
|
||||
|
||||
def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]:
|
||||
return self._val_from_config_or_kwargs("nodes", config, **kwargs)
|
||||
|
||||
def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding:
|
||||
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
|
||||
|
||||
def _build_default_index(self, **kwargs) -> VectorStoreIndex:
|
||||
index = VectorStoreIndex(
|
||||
nodes=self._extract_nodes(**kwargs),
|
||||
embed_model=self._extract_embed_model(**kwargs),
|
||||
)
|
||||
|
||||
return index
|
||||
|
||||
@get_or_build_index
|
||||
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
@get_or_build_index
|
||||
def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
@get_or_build_index
|
||||
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
def _build_index_from_vector_store(
|
||||
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
|
||||
self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
|
||||
) -> VectorStoreIndex:
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
old_index = self._extract_index(config, **kwargs)
|
||||
new_index = VectorStoreIndex(
|
||||
nodes=list(old_index.docstore.docs.values()),
|
||||
index = VectorStoreIndex(
|
||||
nodes=self._extract_nodes(config, **kwargs),
|
||||
storage_context=storage_context,
|
||||
embed_model=old_index._embed_model,
|
||||
embed_model=self._extract_embed_model(config, **kwargs),
|
||||
)
|
||||
return new_index
|
||||
|
||||
return index
|
||||
|
||||
|
||||
get_retriever = RetrieverFactory().get_retriever
|
||||
|
|
|
|||
|
|
@ -406,7 +406,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
elif isinstance(response, Message):
|
||||
msg = response
|
||||
else:
|
||||
msg = Message(content=response, role=self.profile, cause_by=self.rc.todo, sent_from=self)
|
||||
msg = Message(content=response or "", role=self.profile, cause_by=self.rc.todo, sent_from=self)
|
||||
self.rc.memory.add(msg)
|
||||
|
||||
return msg
|
||||
|
|
|
|||
|
|
@ -123,9 +123,10 @@ def startup(
|
|||
|
||||
DEFAULT_CONFIG = """# Full Example: https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml
|
||||
# Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py
|
||||
# Config Docs: https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html
|
||||
llm:
|
||||
api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options
|
||||
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
|
||||
api_type: "openai" # or azure / ollama / groq etc.
|
||||
model: "gpt-4-turbo" # or gpt-3.5-turbo
|
||||
base_url: "https://api.openai.com/v1" # or forward url / other llm url
|
||||
api_key: "YOUR_API_KEY"
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class GitRepository:
|
|||
self._repository = Repo.init(path=Path(local_path))
|
||||
|
||||
gitignore_filename = Path(local_path) / ".gitignore"
|
||||
ignores = ["__pycache__", "*.pyc"]
|
||||
ignores = ["__pycache__", "*.pyc", ".vs"]
|
||||
with open(str(gitignore_filename), mode="w") as writer:
|
||||
writer.write("\n".join(ignores))
|
||||
self._repository.index.add([".gitignore"])
|
||||
|
|
|
|||
|
|
@ -32,9 +32,9 @@ TOKEN_COSTS = {
|
|||
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator
|
||||
"gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
|
||||
|
|
@ -57,6 +57,8 @@ TOKEN_COSTS = {
|
|||
"yi-34b-chat-0205": {"prompt": 0.0003, "completion": 0.0003},
|
||||
"yi-34b-chat-200k": {"prompt": 0.0017, "completion": 0.0017},
|
||||
"microsoft/wizardlm-2-8x22b": {"prompt": 0.00108, "completion": 0.00108}, # for openrouter, start
|
||||
"meta-llama/llama-3-70b-instruct": {"prompt": 0.008, "completion": 0.008},
|
||||
"llama3-70b-8192": {"prompt": 0.0059, "completion": 0.0079},
|
||||
"openai/gpt-3.5-turbo-0125": {"prompt": 0.0005, "completion": 0.0015},
|
||||
"openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
}
|
||||
|
|
@ -155,8 +157,8 @@ FIREWORKS_GRADE_TOKEN_COSTS = {
|
|||
TOKEN_MAX = {
|
||||
"gpt-4-0125-preview": 128000,
|
||||
"gpt-4-turbo-preview": 128000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"gpt-4-1106-preview": 128000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"gpt-4-vision-preview": 128000,
|
||||
"gpt-4-1106-vision-preview": 128000,
|
||||
"gpt-4": 8192,
|
||||
|
|
@ -190,6 +192,8 @@ TOKEN_MAX = {
|
|||
"yi-34b-chat-0205": 4000,
|
||||
"yi-34b-chat-200k": 200000,
|
||||
"microsoft/wizardlm-2-8x22b": 65536,
|
||||
"meta-llama/llama-3-70b-instruct": 8192,
|
||||
"llama3-70b-8192": 8192,
|
||||
"openai/gpt-3.5-turbo-0125": 16385,
|
||||
"openai/gpt-4-turbo-preview": 128000,
|
||||
}
|
||||
|
|
@ -217,7 +221,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
|
|||
"gpt-4-turbo",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4-1106-vision-preview",
|
||||
}:
|
||||
|
|
|
|||
|
|
@ -12,12 +12,12 @@ typer==0.9.0
|
|||
lancedb==0.4.0
|
||||
loguru==0.6.0
|
||||
meilisearch==0.21.0
|
||||
numpy==1.24.3
|
||||
openai==1.6.1
|
||||
numpy>=1.24.3
|
||||
openai>=1.6.1
|
||||
openpyxl
|
||||
beautifulsoup4==4.12.3
|
||||
pandas==2.1.1
|
||||
pydantic==2.5.3
|
||||
pydantic>=2.5.3
|
||||
#pygame==2.1.3
|
||||
#pymilvus==2.2.8
|
||||
# pytest==7.2.2 # test extras require
|
||||
|
|
@ -58,7 +58,7 @@ typing-extensions==4.9.0
|
|||
socksio~=1.0.0
|
||||
gitignore-parser==0.1.9
|
||||
# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py
|
||||
websockets~=11.0
|
||||
websockets>=10.0,<12.0
|
||||
networkx~=3.2.1
|
||||
google-generativeai==0.4.1
|
||||
playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -74,7 +74,7 @@ extras_require["dev"] = (["pylint~=3.0.3", "black~=23.3.0", "isort~=5.12.0", "pr
|
|||
|
||||
setup(
|
||||
name="metagpt",
|
||||
version="0.8.0",
|
||||
version="0.8.1",
|
||||
description="The Multi-Agent Framework",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
llm:
|
||||
base_url: "https://api.openai.com/v1"
|
||||
api_key: "sk-xxx"
|
||||
model: "gpt-3.5-turbo-1106"
|
||||
model: "gpt-3.5-turbo"
|
||||
|
||||
search:
|
||||
api_type: "serpapi"
|
||||
|
|
|
|||
|
|
@ -25,10 +25,6 @@ class TestSimpleEngine:
|
|||
def mock_simple_directory_reader(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_retriever(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
|
||||
|
|
@ -45,7 +41,6 @@ class TestSimpleEngine:
|
|||
self,
|
||||
mocker,
|
||||
mock_simple_directory_reader,
|
||||
mock_vector_store_index,
|
||||
mock_get_retriever,
|
||||
mock_get_rankers,
|
||||
mock_get_response_synthesizer,
|
||||
|
|
@ -81,11 +76,8 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_vector_store_index.assert_called_once()
|
||||
mock_get_retriever.assert_called_once_with(
|
||||
configs=retriever_configs, index=mock_vector_store_index.return_value
|
||||
)
|
||||
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
|
||||
mock_get_retriever.assert_called_once()
|
||||
mock_get_rankers.assert_called_once()
|
||||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
|
||||
|
|
@ -119,7 +111,7 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is not None
|
||||
assert engine._transformations is not None
|
||||
|
||||
def test_from_objs_with_bm25_config(self):
|
||||
# Setup
|
||||
|
|
@ -137,6 +129,7 @@ class TestSimpleEngine:
|
|||
def test_from_index(self, mocker, mock_llm, mock_embedding):
|
||||
# Mock
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index.as_retriever.return_value = "retriever"
|
||||
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
|
||||
mock_get_index.return_value = mock_index
|
||||
|
||||
|
|
@ -149,7 +142,7 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is mock_index
|
||||
assert engine._retriever == "retriever"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, mocker):
|
||||
|
|
@ -200,14 +193,11 @@ class TestSimpleEngine:
|
|||
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index._transformations = mocker.MagicMock()
|
||||
|
||||
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
|
||||
mock_run_transformations.return_value = ["node1", "node2"]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
|
||||
# Exec
|
||||
|
|
@ -230,7 +220,7 @@ class TestSimpleEngine:
|
|||
return ""
|
||||
|
||||
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
|
||||
# Exec
|
||||
engine.add_objs(objs=objs)
|
||||
|
|
|
|||
|
|
@ -97,6 +97,5 @@ class TestConfigBasedFactory:
|
|||
def test_val_from_config_or_kwargs_key_error(self):
|
||||
# Test KeyError when the key is not found in both config object and kwargs
|
||||
config = DummyConfig(name=None)
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
|
||||
val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert val is None
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import faiss
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.embeddings import MockEmbedding
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
|
||||
|
|
@ -43,6 +45,14 @@ class TestRetrieverFactory:
|
|||
def mock_es_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ElasticsearchStore)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_nodes(self, mocker):
|
||||
return [TextNode(text="msg")]
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding(self):
|
||||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
|
||||
|
|
@ -52,42 +62,40 @@ class TestRetrieverFactory:
|
|||
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_nodes):
|
||||
mock_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes)
|
||||
|
||||
assert isinstance(retriever, DynamicBM25Retriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=1)
|
||||
mock_bm25_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
retriever = self.retriever_factory.get_retriever(
|
||||
configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding
|
||||
)
|
||||
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding):
|
||||
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
|
||||
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
|
||||
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
|
||||
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
|
||||
|
||||
assert isinstance(retriever, ChromaRetriever)
|
||||
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
|
||||
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
|
||||
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
|
||||
|
||||
assert isinstance(retriever, ElasticsearchRetriever)
|
||||
|
||||
|
|
@ -111,3 +119,19 @@ class TestRetrieverFactory:
|
|||
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
||||
def test_get_or_build_when_get(self, mocker):
|
||||
want = "existing_index"
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=want)
|
||||
|
||||
got = self.retriever_factory._build_es_index(None)
|
||||
|
||||
assert got == want
|
||||
|
||||
def test_get_or_build_when_build(self, mocker):
|
||||
want = "call_build_es_index"
|
||||
mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want)
|
||||
|
||||
got = self.retriever_factory._build_es_index(None)
|
||||
|
||||
assert got == want
|
||||
|
|
|
|||
|
|
@ -105,11 +105,11 @@ def test_config_mixin_4_multi_inheritance_override_config():
|
|||
async def test_config_priority():
|
||||
"""If action's config is set, then its llm will be set, otherwise, it will use the role's llm"""
|
||||
home_dir = Path.home() / CONFIG_ROOT
|
||||
gpt4t = Config.from_home("gpt-4-1106-preview.yaml")
|
||||
gpt4t = Config.from_home("gpt-4-turbo.yaml")
|
||||
if not home_dir.exists():
|
||||
assert gpt4t is None
|
||||
gpt35 = Config.default()
|
||||
gpt35.llm.model = "gpt-3.5-turbo-1106"
|
||||
gpt35.llm.model = "gpt-4-turbo"
|
||||
gpt4 = Config.default()
|
||||
gpt4.llm.model = "gpt-4-0613"
|
||||
|
||||
|
|
@ -127,8 +127,8 @@ async def test_config_priority():
|
|||
env = Environment(desc="US election live broadcast")
|
||||
Team(investment=10.0, env=env, roles=[A, B, C])
|
||||
|
||||
assert a1.llm.model == "gpt-4-1106-preview" if Path(home_dir / "gpt-4-1106-preview.yaml").exists() else "gpt-4-0613"
|
||||
assert a1.llm.model == "gpt-4-turbo" if Path(home_dir / "gpt-4-turbo.yaml").exists() else "gpt-4-0613"
|
||||
assert a2.llm.model == "gpt-4-0613"
|
||||
assert a3.llm.model == "gpt-3.5-turbo-1106"
|
||||
assert a3.llm.model == "gpt-4-turbo"
|
||||
|
||||
# history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="a1", n_round=3)
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class TestUTWriter:
|
|||
)
|
||||
],
|
||||
created=1706710532,
|
||||
model="gpt-3.5-turbo-1106",
|
||||
model="gpt-4-turbo",
|
||||
object="chat.completion",
|
||||
system_fingerprint="fp_04f9a1eebf",
|
||||
usage=CompletionUsage(completion_tokens=35, prompt_tokens=1982, total_tokens=2017),
|
||||
|
|
|
|||
|
|
@ -12,11 +12,11 @@ from metagpt.utils.cost_manager import CostManager
|
|||
|
||||
def test_cost_manager():
|
||||
cm = CostManager(total_budget=20)
|
||||
cm.update_cost(prompt_tokens=1000, completion_tokens=100, model="gpt-4-1106-preview")
|
||||
cm.update_cost(prompt_tokens=1000, completion_tokens=100, model="gpt-4-turbo")
|
||||
assert cm.get_total_prompt_tokens() == 1000
|
||||
assert cm.get_total_completion_tokens() == 100
|
||||
assert cm.get_total_cost() == 0.013
|
||||
cm.update_cost(prompt_tokens=100, completion_tokens=10, model="gpt-4-1106-preview")
|
||||
cm.update_cost(prompt_tokens=100, completion_tokens=10, model="gpt-4-turbo")
|
||||
assert cm.get_total_prompt_tokens() == 1100
|
||||
assert cm.get_total_completion_tokens() == 110
|
||||
assert cm.get_total_cost() == 0.0143
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue