add instance filter

This commit is contained in:
stellahsr 2024-03-19 20:07:18 +08:00
parent d2b34f5897
commit fc23e8f27e

View file

@@ -15,6 +15,7 @@ from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.utils import count_string_tokens
from metagpt.utils.recovery_util import save_history
from data.inference.const import SCIKIT_LEARN_IDS
# Replace with your own token limit
MAX_TOKEN = 128000
@@ -70,6 +71,7 @@ async def openai_inference(
for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
di = DataInterpreter(use_reflection=use_reflection)
instance_id = datum["instance_id"]
if instance_id in existing_ids:
continue
output_dict = {"instance_id": instance_id}
@@ -124,12 +126,19 @@ async def main(
dataset = dataset[split]
lens = np.array(list(map(len, dataset["text"])))
dataset = dataset.select(np.argsort(lens))
if len(existing_ids) > 0:
dataset = dataset.filter(
lambda x: x["instance_id"] not in existing_ids,
desc="Filtering out existing ids",
load_from_cache_file=False,
)
if len(SCIKIT_LEARN_IDS) > 0:
dataset = dataset.filter(
lambda x: x["instance_id"] in SCIKIT_LEARN_IDS,
desc="Filtering out subset_instance_ids",
load_from_cache_file=False,
)
inference_args = {
"test_dataset": dataset,
"model_name_or_path": model_name_or_path,