mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-28 10:26:32 +02:00
add instance filter
This commit is contained in:
parent
d2b34f5897
commit
fc23e8f27e
1 changed file with 9 additions and 0 deletions
|
|
@@ -15,6 +15,7 @@ from metagpt.logs import logger
|
|||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
from metagpt.utils import count_string_tokens
|
||||
from metagpt.utils.recovery_util import save_history
|
||||
from data.inference.const import SCIKIT_LEARN_IDS
|
||||
|
||||
# Replace with your own
|
||||
MAX_TOKEN = 128000
|
||||
|
|
@@ -70,6 +71,7 @@ async def openai_inference(
|
|||
for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
|
||||
di = DataInterpreter(use_reflection=use_reflection)
|
||||
instance_id = datum["instance_id"]
|
||||
|
||||
if instance_id in existing_ids:
|
||||
continue
|
||||
output_dict = {"instance_id": instance_id}
|
||||
|
|
@@ -124,12 +126,19 @@ async def main(
|
|||
dataset = dataset[split]
|
||||
lens = np.array(list(map(len, dataset["text"])))
|
||||
dataset = dataset.select(np.argsort(lens))
|
||||
|
||||
if len(existing_ids) > 0:
|
||||
dataset = dataset.filter(
|
||||
lambda x: x["instance_id"] not in existing_ids,
|
||||
desc="Filtering out existing ids",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
if len(SCIKIT_LEARN_IDS) > 0:
|
||||
dataset = dataset.filter(
|
||||
lambda x: x["instance_id"] in SCIKIT_LEARN_IDS,
|
||||
desc="Filtering out subset_instance_ids",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
inference_args = {
|
||||
"test_dataset": dataset,
|
||||
"model_name_or_path": model_name_or_path,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue