Merge pull request #1175 from better629/feat_werewolf

Feat add werewolf game
This commit is contained in:
Alexander Wu 2024-04-11 10:11:04 +08:00 committed by kithib
commit 5b37954b4b
82 changed files with 11456 additions and 172 deletions

View file

@ -0,0 +1,218 @@
"""
Filename: MetaGPT/examples/werewolf_game/evals/eval.py
Created Date: Oct 18, 2023
Updated Date: Oct 24, 2023
Author: [Aria](https://github.com/ariafyy)
Info: eval the Voting Accuracy Rate of non_werewolves and Vote Difficulity
"""
import glob
import os
import re
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
from utils import Utils
from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT
from metagpt.environment.werewolf.const import RoleType
class Vote:
"""Vote Evaluation"""
def __init__(self):
self.OUT_PATH = DEFAULT_WORKSPACE_ROOT / "outputs"
os.makedirs(self.OUT_PATH, exist_ok=True)
self.SUB_FOLDER_LIST = ["01-10", "11-20", "21-30"]
def _get_log_fileslist(self, IN_PATH) -> list[str]:
files_list = []
for SUB_FOLDER in self.SUB_FOLDER_LIST:
files_list.extend(glob.glob(str(IN_PATH / SUB_FOLDER / "*.txt")))
return files_list
def extract_votes_from_logs(self, files_list: list):
for in_logfile in tqdm(files_list):
SUB_FOLDER = (Path(in_logfile).parent).stem
out_txtfile = self.OUT_PATH / "# {0}_{1}.txt".format(SUB_FOLDER, Path(in_logfile).stem)
Utils().pick_vote_log(in_logfile, out_txtfile)
votefiles_list = Utils().get_file_list(self.OUT_PATH)
return votefiles_list
@staticmethod
def parse_vote_text2chunks(text: str):
"""
parse each game vote log into text chunks
one chunk example:
['Player1', 'Player2', 'Player3', 'Player5', 'Player6']. Say ONLY: I vote to eliminate ...
Player1(Witch): 49 | I vote to eliminate Player5
Player2(Villager): 49 | I vote to eliminate Player5
Player3(Villager): 49 | I vote to eliminate Player5
Player5(Werewolf): 49 | I vote to eliminate Player6
Player6(Seer): 49 | I vote to eliminate Player5
"""
pattern = re.compile(r"""\[([^\]]+)\]. Say ONLY: I vote to eliminate ...""")
chunks = {}
chunk_id = 0
last_end = 0
for match in pattern.finditer(text):
start = match.start()
chunk = text[last_end:start]
chunks[f"vote_{chunk_id}"] = chunk.strip()
last_end = match.end()
chunk_id += 1
final_chunk = text[last_end:].strip()
if final_chunk:
chunks[f"vote_{chunk_id}"] = final_chunk
return chunks
def _vote_rate_players(self, text: str):
"""
# calculate the rate of goodteam vote werewolves
:example:
input:
['Player1', 'Player2', 'Player3', 'Player5', 'Player6']. Say ONLY: I vote to eliminate ...
Player1(Witch): 49 | I vote to eliminate Player5
Player2(Villager): 49 | I vote to eliminate Player5
Player3(Villager): 49 | I vote to eliminate Player5
Player5(Werewolf): 49 | I vote to eliminate Player6
Player6(Seer): 49 | I vote to eliminate Player5
output:
werewolves: ['Player5']
non_werewolves: ['Player1', 'Player2', 'Player3', 'Player6']
as you can see :Player2(Villager) and Player3(Villager) vote to eliminate Player5(Werewolf)
:return goodteam vote rateability: 100.00%
"""
pattern = re.compile(r"(\w+)\(([^\)]+)\): \d+ \| I vote to eliminate (\w+)")
# find all werewolves
werewolves = []
for match in pattern.finditer(text):
if match.group(2) == RoleType.WEREWOLF.value:
werewolves.append(match.group(1))
# find all non_werewolves
non_werewolves = []
for match in pattern.finditer(text):
if match.group(2) != RoleType.WEREWOLF.value:
non_werewolves.append(match.group(1))
num_non_werewolves = len(non_werewolves)
# count players other than werewolves made the correct votes
correct_votes = 0
for match in pattern.finditer(text):
if match.group(2) != RoleType.WEREWOLF.value and match.group(3) in werewolves:
correct_votes += 1
# cal the rateability of non_werewolves
rate = correct_votes / num_non_werewolves
good_vote_rate = round(rate, 2)
return {"good_vote_rate": good_vote_rate, "werewolves": werewolves, "non_werewolves": non_werewolves}
def get_goodteam_vote_rate(self, text: str) -> float:
goodteam_vote_rate = self._vote_rate_players(text)["good_vote_rate"]
return goodteam_vote_rate
def get_werewolves(self, text: str) -> list:
werewolves_list = self._vote_rate_players(text)["werewolves"]
return werewolves_list
def get_non_werewolves(self, text: str) -> list:
non_werewolves_list = self._vote_rate_players(text)["non_werewolves"]
return non_werewolves_list
def get_votewolf_difficulty(self, werewolves: list, non_werewolves: list) -> str:
num_living_wolfs = len(werewolves)
num_living_players = len(werewolves) + len(non_werewolves)
votewolf_difficulty = "_{0} / {1}".format(num_living_wolfs, num_living_players)
return votewolf_difficulty
def get_result_df(self, out_txtfile: str) -> pd.DataFrame:
"""
folder: sub folders for evals
file: evaluation file, each file represents one game
votes: the number of votes, eg. vote_1 represents the first vote of this game,
good_vote_rate:the rateability of a good person voting against a werewolf,
correct_votes / the total number of players other than werewolves
total_votes:the total number of votes cast
"""
with open(out_txtfile, "r") as out_file:
text = out_file.read()
chunks = self.parse_vote_text2chunks(text)
res = []
for k, v in chunks.items():
if v != "":
chunks_list = list(chunks.keys())
total_votes = len(chunks_list) - 1
werewolves = self.get_werewolves(v)
non_werewolves = self.get_non_werewolves(v)
good_vote_rate = self.get_goodteam_vote_rate(v)
votewolf_difficulty = self.get_votewolf_difficulty(werewolves, non_werewolves)
folder = Utils().filename_to_foldername(out_txtfile)
result = {
"folder": folder,
"file": Path(out_txtfile).stem + ".txt",
"vote_round": k,
"good_vote_rate": good_vote_rate,
"total_votes": total_votes,
"votewolf_difficulty": votewolf_difficulty,
}
res.append(result)
df = pd.DataFrame(res)
return df
def calc_avg_rate(self, IN_PATH) -> pd.DataFrame:
"""
get avg_rate for each game
avg_rate : the good_rate/total number of votes in the game
vote1_rate: First Round Voting Accuracy Rate
"""
infiles_list = self._get_log_fileslist(IN_PATH)
votefiles_list = self.extract_votes_from_logs(infiles_list)
df_list = [self._load_df_from_file(file) for file in votefiles_list]
combined_df = pd.concat(df_list, ignore_index=True)
# calculate the average good_vote_rate for each file
mean_rates = self._calculate_mean_rates(combined_df)
combined_df["avg_rate"] = combined_df["file"].map(mean_rates)
# calculate vote1 rate
vote1_rates = self._calc_vote1_rates(combined_df)
combined_df["vote1_rate"] = combined_df["folder"].map(vote1_rates.set_index("folder")["good_vote_rate"])
combined_df.loc[combined_df["vote_round"] != "vote_1", "vote1_rate"] = np.nan
combined_df["vote1_rate"] = combined_df["vote1_rate"].apply(self._format_rates)
combined_df["good_vote_rate"] = combined_df["good_vote_rate"].apply(self._format_rates)
combined_df["avg_rate"] = combined_df["avg_rate"].apply(self._format_rates)
combined_df.sort_values(["file"], ascending=True, inplace=True)
return combined_df
def _calc_vote1_rates(self, df):
df_vote1 = df[df["vote_round"] == "vote_1"]
vote1_rates = df_vote1.groupby("folder")["good_vote_rate"].mean().reset_index()
return vote1_rates
def _load_df_from_file(self, file):
return self.get_result_df(file)
def _calculate_mean_rates(self, df):
return df.groupby("file")["good_vote_rate"].mean()
def _format_rates(self, s):
return Utils().float_to_percent(s)
def get_eval_csv(self, IN_PATH, EVAL_RESULT):
"""
IN_PATH : parent folder of ["01-10", "11-20", "21-30"]
EVAL_RESULT : output csv file path
"""
combined_df = self.calc_avg_rate(IN_PATH)
combined_df.to_csv(EVAL_RESULT, index=False)
if __name__ == "__main__":
IN_PATH = METAGPT_ROOT / "examples/werewolf_game/evals"
EVAL_RESULT = DEFAULT_WORKSPACE_ROOT / "outputs" / "goodteam_vote_rate.csv"
Vote().get_eval_csv(IN_PATH, EVAL_RESULT)

View file

@ -0,0 +1,134 @@
"""
Filename: MetaGPT/examples/werewolf_game/evals/utils.py
Created Date: Oct 11, 2023
Revised Date: Oct 20, 2023
Author: [Aria](https://github.com/ariafyy)
"""
import glob
import os
import re
from pathlib import Path
from metagpt.const import METAGPT_ROOT
class Utils:
"""Utils: utils of logs"""
@staticmethod
def polish_log(in_logfile, out_txtfile):
"""polish logs for evaluation"""
pattern_text = r"(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3}) \| (\w+) +\| ([\w\.]+:\w+:\d+) - (.*\S)"
pattern_player = r"(Player(\d{1}): \w+)"
pattern_start = False
json_start = False
with open(in_logfile, "r") as f, open(out_txtfile, "w") as out:
for line in f.readlines():
matches = re.match(pattern_text, line)
if matches:
message = matches.group(4).strip()
pattern_start = True
json_start = False
if (
"Moderator(Moderator) ready to InstructSpeak" not in message
and "Moderator(Moderator) ready to ParseSpeak" not in message
and "Total running cost:" not in message
):
out.write("- " + message + "\n")
else:
out.write("\n")
elif pattern_start and not matches:
if "gpt-4 may update over time" in line:
line = ""
out.write(line)
elif line.strip().startswith("{"):
out.write(line.strip())
json_start = True
elif json_start and not line.strip().endswith("}"):
out.write(line.strip())
elif json_start and line.strip().endswith("}"):
out.write(line.strip())
json_start = False
elif (
line.startswith("(User):") or line.startswith("********** STEP:") or re.search(pattern_player, line)
):
out.write(line)
else:
out.write("\n")
@staticmethod
def pick_vote_log(in_logfile, out_txtfile):
"""
pick the vote log from the log file.
ready to AnnounceGameResult serves as the 'HINT_TEXT ' which indicates the end of the game.
based on bservation and reflection, then discuss is not in vote session.
"""
pattern_vote = r"(Player\d+)\(([A-Za-z]+)\): (\d+) \| (I vote to eliminate Player\d+)"
ignore_text = """reflection"""
HINT_TEXT = r"ready to AnnounceGameResult"
pattern_moderator = r"\[([^\]]+)\]\. Say ONLY: I vote to eliminate ..."
in_valid_block = False
with open(in_logfile, "r") as f:
lines = f.read()
split_lines = lines.split(HINT_TEXT)
if len(split_lines) < 2:
print(f"Key text :{HINT_TEXT} not found in {in_logfile}")
return
relevant_lines = split_lines[1].split("\n")
with open(out_txtfile, "w") as out:
for line in relevant_lines:
if re.search(pattern_moderator, line):
in_valid_block = True
out.write(line.lstrip() + "\n")
elif in_valid_block and re.search(pattern_vote, line):
out.write(line + "\n")
elif ignore_text in line:
in_valid_block = False
@staticmethod
def get_file_list(path: str) -> list:
file_pattern = os.path.join(path, "*.txt")
files_list = glob.glob(file_pattern)
return files_list
@staticmethod
def filename_to_foldername(out_txtfile: str):
"""
convert filename into its parent folder name
input:"....../# 01-10_10132100.txt"
output:# 01-10
"""
s = Path(out_txtfile).stem
pattern_folder = r"([^_]*)_"
match = re.match(pattern_folder, s)
if match:
folder = match.group(1)
return folder
@staticmethod
def float_to_percent(decimal: float) -> str:
"""
input: 1.00
output: 100.00%
"""
percent = decimal * 100
return f"{percent:.2f}%"
if __name__ == "__main__":
in_logfile = METAGPT_ROOT / "logs/log.txt"
out_txtfile = "input your wish path"
# Utils().polish_log(in_logfile, out_txtfile)
Utils().pick_vote_log(in_logfile, out_txtfile)

View file

@ -0,0 +1,66 @@
import asyncio
import fire
from metagpt.ext.werewolf.roles import Guard, Moderator, Seer, Villager, Werewolf, Witch
from metagpt.ext.werewolf.werewolf_game import WerewolfGame
from metagpt.logs import logger
async def start_game(
investment: float = 3.0,
n_round: int = 5,
shuffle: bool = True,
add_human: bool = False,
use_reflection: bool = True,
use_experience: bool = False,
use_memory_selection: bool = False,
new_experience_version: str = "",
):
game = WerewolfGame()
game_setup, players = game.env.init_game_setup(
role_uniq_objs=[Villager, Werewolf, Guard, Seer, Witch],
num_werewolf=2,
num_villager=2,
shuffle=shuffle,
add_human=add_human,
use_reflection=use_reflection,
use_experience=use_experience,
use_memory_selection=use_memory_selection,
new_experience_version=new_experience_version,
)
logger.info(f"{game_setup}")
players = [Moderator()] + players
game.hire(players)
game.invest(investment)
game.run_project(game_setup)
await game.run(n_round=n_round)
def main(
investment: float = 20.0,
n_round: int = 100,
shuffle: bool = True,
add_human: bool = False,
use_reflection: bool = True,
use_experience: bool = False,
use_memory_selection: bool = False,
new_experience_version: str = "",
):
asyncio.run(
start_game(
investment,
n_round,
shuffle,
add_human,
use_reflection,
use_experience,
use_memory_selection,
new_experience_version,
)
)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -6,7 +6,11 @@ import subprocess
from pathlib import Path
from typing import Any, Optional
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from PIL import Image
from pydantic import Field
from text_icon_localization import ocr
from metagpt.environment.android.const import ADB_EXEC_FAIL
from metagpt.environment.android.env_space import (
@ -17,6 +21,7 @@ from metagpt.environment.android.env_space import (
EnvObsValType,
)
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.logs import logger
class AndroidExtEnv(ExtEnv):
@ -228,3 +233,39 @@ class AndroidExtEnv(ExtEnv):
adb_cmd = f"{self.adb_prefix_si} swipe {start[0]} {start[1]} {end[0]} {end[1]} {duration}"
swipe_res = self.execute_adb_with_cmd(adb_cmd)
return swipe_res
@mark_as_writeable
def user_exit(self):
adb_cmd = "adb shell am start -a android.intent.action.MAIN -c android.intent.category.HOME"
exit_res = self.execute_adb_with_cmd(adb_cmd)
return exit_res
@mark_as_writeable
def user_openApp(self, app_name: str):
# openApp without xml
screenshot_path = self.get_screenshot("screenshot", "./screenshot")
image = screenshot_path
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
iw, ih = Image.open(image).size
x, y = self.device_shape
if iw > ih:
x, y = y, x
iw, ih = ih, iw
in_coordinate, out_coordinate = ocr(image, app_name, ocr_detection, ocr_recognition, iw, ih)
if len(in_coordinate) == 0:
logger.info(f"No App named {app_name}.")
return "no"
else:
tap_coordinate = [
(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
(in_coordinate[0][1] + in_coordinate[0][3]) / 2,
]
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
return self.system_tap(tap_coordinate[0] * x, (tap_coordinate[1] - round(50 / y, 2)) * y)
@mark_as_writeable
def user_stop(self):
logger.info("Successful execution of tasks")
# todo user_clickIcon

View file

@ -0,0 +1,43 @@
batch_size = 1
modelname = "groundingdino"
backbone = "swin_T_224_1k"
position_embedding = "sine"
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None
enc_layers = 6
dec_layers = 6
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
two_stage_type = "standard"
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
transformer_activation = "relu"
dec_pred_bbox_embed_share = True
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 2000
max_text_len = 256
text_encoder_type = "bert-base-uncased"
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = True
use_transformer_ckpt = True
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
sub_sentence_present = True

View file

@ -0,0 +1,311 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Transforms and data augmentation for both image + bbox.
"""
import os
import random
import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from groundingdino.util.box_ops import box_xyxy_to_cxcywh
from groundingdino.util.misc import interpolate
def crop(image, target, region):
cropped_image = F.crop(image, *region)
target = target.copy()
i, j, h, w = region
# should we do something wrt the original size?
target["size"] = torch.tensor([h, w])
fields = ["labels", "area", "iscrowd", "positive_map"]
if "boxes" in target:
boxes = target["boxes"]
max_size = torch.as_tensor([w, h], dtype=torch.float32)
cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
target["boxes"] = cropped_boxes.reshape(-1, 4)
target["area"] = area
fields.append("boxes")
if "masks" in target:
# FIXME should we update the area here if there are no boxes?
target["masks"] = target["masks"][:, i : i + h, j : j + w]
fields.append("masks")
# remove elements for which the boxes or masks that have zero area
if "boxes" in target or "masks" in target:
# favor boxes selection when defining which elements to keep
# this is compatible with previous implementation
if "boxes" in target:
cropped_boxes = target["boxes"].reshape(-1, 2, 2)
keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
else:
keep = target["masks"].flatten(1).any(1)
for field in fields:
if field in target:
target[field] = target[field][keep]
if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
# for debug and visualization only.
if "strings_positive" in target:
target["strings_positive"] = [
_i for _i, _j in zip(target["strings_positive"], keep) if _j
]
return cropped_image, target
def hflip(image, target):
flipped_image = F.hflip(image)
w, h = image.size
target = target.copy()
if "boxes" in target:
boxes = target["boxes"]
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
[w, 0, w, 0]
)
target["boxes"] = boxes
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
return flipped_image, target
def resize(image, target, size, max_size=None):
# size can be min_size (scalar) or (w, h) tuple
def get_size_with_aspect_ratio(image_size, size, max_size=None):
w, h = image_size
if max_size is not None:
min_original_size = float(min((w, h)))
max_original_size = float(max((w, h)))
if max_original_size / min_original_size * size > max_size:
size = int(round(max_size * min_original_size / max_original_size))
if (w <= h and w == size) or (h <= w and h == size):
return (h, w)
if w < h:
ow = size
oh = int(size * h / w)
else:
oh = size
ow = int(size * w / h)
return (oh, ow)
def get_size(image_size, size, max_size=None):
if isinstance(size, (list, tuple)):
return size[::-1]
else:
return get_size_with_aspect_ratio(image_size, size, max_size)
size = get_size(image.size, size, max_size)
rescaled_image = F.resize(image, size)
if target is None:
return rescaled_image, None
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
ratio_width, ratio_height = ratios
target = target.copy()
if "boxes" in target:
boxes = target["boxes"]
scaled_boxes = boxes * torch.as_tensor(
[ratio_width, ratio_height, ratio_width, ratio_height]
)
target["boxes"] = scaled_boxes
if "area" in target:
area = target["area"]
scaled_area = area * (ratio_width * ratio_height)
target["area"] = scaled_area
h, w = size
target["size"] = torch.tensor([h, w])
if "masks" in target:
target["masks"] = (
interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
)
return rescaled_image, target
def pad(image, target, padding):
# assumes that we only pad on the bottom right corners
padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
if target is None:
return padded_image, None
target = target.copy()
# should we do something wrt the original size?
target["size"] = torch.tensor(padded_image.size[::-1])
if "masks" in target:
target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
return padded_image, target
class ResizeDebug(object):
def __init__(self, size):
self.size = size
def __call__(self, img, target):
return resize(img, target, self.size)
class RandomCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, img, target):
region = T.RandomCrop.get_params(img, self.size)
return crop(img, target, region)
class RandomSizeCrop(object):
def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
# respect_boxes: True to keep all boxes
# False to tolerence box filter
self.min_size = min_size
self.max_size = max_size
self.respect_boxes = respect_boxes
def __call__(self, img: PIL.Image.Image, target: dict):
init_boxes = len(target["boxes"])
max_patience = 10
for i in range(max_patience):
w = random.randint(self.min_size, min(img.width, self.max_size))
h = random.randint(self.min_size, min(img.height, self.max_size))
region = T.RandomCrop.get_params(img, [h, w])
result_img, result_target = crop(img, target, region)
if (
not self.respect_boxes
or len(result_target["boxes"]) == init_boxes
or i == max_patience - 1
):
return result_img, result_target
return result_img, result_target
class CenterCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, img, target):
image_width, image_height = img.size
crop_height, crop_width = self.size
crop_top = int(round((image_height - crop_height) / 2.0))
crop_left = int(round((image_width - crop_width) / 2.0))
return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
class RandomHorizontalFlip(object):
def __init__(self, p=0.5):
self.p = p
def __call__(self, img, target):
if random.random() < self.p:
return hflip(img, target)
return img, target
class RandomResize(object):
def __init__(self, sizes, max_size=None):
assert isinstance(sizes, (list, tuple))
self.sizes = sizes
self.max_size = max_size
def __call__(self, img, target=None):
size = random.choice(self.sizes)
return resize(img, target, size, self.max_size)
class RandomPad(object):
def __init__(self, max_pad):
self.max_pad = max_pad
def __call__(self, img, target):
pad_x = random.randint(0, self.max_pad)
pad_y = random.randint(0, self.max_pad)
return pad(img, target, (pad_x, pad_y))
class RandomSelect(object):
"""
Randomly selects between transforms1 and transforms2,
with probability p for transforms1 and (1 - p) for transforms2
"""
def __init__(self, transforms1, transforms2, p=0.5):
self.transforms1 = transforms1
self.transforms2 = transforms2
self.p = p
def __call__(self, img, target):
if random.random() < self.p:
return self.transforms1(img, target)
return self.transforms2(img, target)
class ToTensor(object):
def __call__(self, img, target):
return F.to_tensor(img), target
class RandomErasing(object):
def __init__(self, *args, **kwargs):
self.eraser = T.RandomErasing(*args, **kwargs)
def __call__(self, img, target):
return self.eraser(img), target
class Normalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, image, target=None):
image = F.normalize(image, mean=self.mean, std=self.std)
if target is None:
return image, None
target = target.copy()
h, w = image.shape[-2:]
if "boxes" in target:
boxes = target["boxes"]
boxes = box_xyxy_to_cxcywh(boxes)
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
target["boxes"] = boxes
return image, target
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string

View file

@ -0,0 +1,15 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
from .groundingdino import build_groundingdino

View file

@ -0,0 +1 @@
from .backbone import build_backbone

View file

@ -0,0 +1,221 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
"""
Backbone modules.
"""
from typing import Dict, List
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
from .position_encoding import build_position_encoding
from .swin_transformer import build_swin_transformer
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
without which any other models than torchvision.models.resnet[18,34,50,101]
produce nans.
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(
self,
backbone: nn.Module,
train_backbone: bool,
num_channels: int,
return_interm_indices: list,
):
super().__init__()
for name, parameter in backbone.named_parameters():
if (
not train_backbone
or "layer2" not in name
and "layer3" not in name
and "layer4" not in name
):
parameter.requires_grad_(False)
return_layers = {}
for idx, layer_index in enumerate(return_interm_indices):
return_layers.update(
{"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
)
# if len:
# if use_stage1_feature:
# return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
# else:
# return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
# else:
# return_layers = {'layer4': "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
# import ipdb; ipdb.set_trace()
return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(
self,
name: str,
train_backbone: bool,
dilation: bool,
return_interm_indices: list,
batch_norm=FrozenBatchNorm2d,
):
if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(),
norm_layer=batch_norm,
)
else:
raise NotImplementedError("Why you can get here with name {}".format(name))
# num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
num_channels_all = [256, 512, 1024, 2048]
num_channels = num_channels_all[4 - len(return_interm_indices) :]
super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.tensors.dtype))
return out, pos
def build_backbone(args):
"""
Useful args:
- backbone: backbone name
- lr_backbone:
- dilation
- return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
- backbone_freeze_keywords:
- use_checkpoint: for swin only for now
"""
position_embedding = build_position_encoding(args)
train_backbone = True
if not train_backbone:
raise ValueError("Please set lr_backbone > 0")
return_interm_indices = args.return_interm_indices
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
args.backbone_freeze_keywords
use_checkpoint = getattr(args, "use_checkpoint", False)
if args.backbone in ["resnet50", "resnet101"]:
backbone = Backbone(
args.backbone,
train_backbone,
args.dilation,
return_interm_indices,
batch_norm=FrozenBatchNorm2d,
)
bb_num_channels = backbone.num_channels
elif args.backbone in [
"swin_T_224_1k",
"swin_B_224_22k",
"swin_B_384_22k",
"swin_L_224_22k",
"swin_L_384_22k",
]:
pretrain_img_size = int(args.backbone.split("_")[-2])
backbone = build_swin_transformer(
args.backbone,
pretrain_img_size=pretrain_img_size,
out_indices=tuple(return_interm_indices),
dilation=False,
use_checkpoint=use_checkpoint,
)
bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
else:
raise NotImplementedError("Unknown backbone {}".format(args.backbone))
assert len(bb_num_channels) == len(
return_interm_indices
), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
model = Joiner(backbone, position_embedding)
model.num_channels = bb_num_channels
assert isinstance(
bb_num_channels, List
), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
# import ipdb; ipdb.set_trace()
return model

View file

@ -0,0 +1,186 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from groundingdino.util.misc import NestedTensor
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
mask = tensor_list.mask
assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
# if os.environ.get("SHILONG_AMP", None) == '1':
# eps = 1e-4
# else:
# eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingSineHW(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperatureH = temperatureH
self.temperatureW = temperatureW
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
mask = tensor_list.mask
assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
# import ipdb; ipdb.set_trace()
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_tx
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
pos_y = y_embed[:, :, :, None] / dim_ty
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
# import ipdb; ipdb.set_trace()
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = (
torch.cat(
[
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
],
dim=-1,
)
.permute(2, 0, 1)
.unsqueeze(0)
.repeat(x.shape[0], 1, 1, 1)
)
return pos
def build_position_encoding(args):
N_steps = args.hidden_dim // 2
if args.position_embedding in ("v2", "sine"):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSineHW(
N_steps,
temperatureH=args.pe_temperatureH,
temperatureW=args.pe_temperatureW,
normalize=True,
)
elif args.position_embedding in ("v3", "learned"):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding

View file

@ -0,0 +1,802 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# --------------------------------------------------------
# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from groundingdino.util.misc import NestedTensor
class Mlp(nn.Module):
"""Multilayer perceptron."""
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
"""Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=0.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock(nn.Module):
"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchMerging(nn.Module):
"""Patch Merging Layer
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of feature channels
depth (int): Depths of this stage.
num_heads (int): Number of attention head.
window_size (int): Local window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, H, W):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
"""Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
class SwinTransformer(nn.Module):
"""Swin Transformer backbone.
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
pretrain_img_size (int): Input image size for training the pretrained model,
used in absolute postion embedding. Default 224.
patch_size (int | tuple(int)): Patch size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
num_heads (tuple[int]): Number of attention head of each stage.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
dilation (bool): if True, the output size if 16x downsample, ow 32x downsample.
"""
def __init__(
self,
pretrain_img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.2,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
dilation=False,
use_checkpoint=False,
):
super().__init__()
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.dilation = dilation
# if use_checkpoint:
# print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
# absolute position embedding
if self.ape:
pretrain_img_size = to_2tuple(pretrain_img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [
pretrain_img_size[0] // patch_size[0],
pretrain_img_size[1] // patch_size[1],
]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
)
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
# prepare downsample list
downsamplelist = [PatchMerging for i in range(self.num_layers)]
downsamplelist[-1] = None
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
if self.dilation:
downsamplelist[-2] = None
num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
for i_layer in range(self.num_layers):
layer = BasicLayer(
# dim=int(embed_dim * 2 ** i_layer),
dim=num_features[i_layer],
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer,
# downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
downsample=downsamplelist[i_layer],
use_checkpoint=use_checkpoint,
)
self.layers.append(layer)
# num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
# add a norm layer for each output
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f"norm{i_layer}"
self.add_module(layer_name, layer)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1 and self.ape:
self.absolute_pos_embed.requires_grad = False
if self.frozen_stages >= 2:
self.pos_drop.eval()
for i in range(0, self.frozen_stages - 1):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
# def init_weights(self, pretrained=None):
# """Initialize the weights in backbone.
# Args:
# pretrained (str, optional): Path to pre-trained weights.
# Defaults to None.
# """
# def _init_weights(m):
# if isinstance(m, nn.Linear):
# trunc_normal_(m.weight, std=.02)
# if isinstance(m, nn.Linear) and m.bias is not None:
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, nn.LayerNorm):
# nn.init.constant_(m.bias, 0)
# nn.init.constant_(m.weight, 1.0)
# if isinstance(pretrained, str):
# self.apply(_init_weights)
# logger = get_root_logger()
# load_checkpoint(self, pretrained, strict=False, logger=logger)
# elif pretrained is None:
# self.apply(_init_weights)
# else:
# raise TypeError('pretrained must be a str or None')
def forward_raw(self, x):
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
)
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = []
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
# import ipdb; ipdb.set_trace()
if i in self.out_indices:
norm_layer = getattr(self, f"norm{i}")
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
# in:
# torch.Size([2, 3, 1024, 1024])
# outs:
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
return tuple(outs)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
)
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = []
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f"norm{i}")
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
# in:
# torch.Size([2, 3, 1024, 1024])
# out:
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
# collect for nesttensors
outs_dict = {}
for idx, out_i in enumerate(outs):
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
outs_dict[idx] = NestedTensor(out_i, mask)
return outs_dict
def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
def build_swin_transformer(modelname, pretrain_img_size, **kw):
assert modelname in [
"swin_T_224_1k",
"swin_B_224_22k",
"swin_B_384_22k",
"swin_L_224_22k",
"swin_L_384_22k",
]
model_para_dict = {
"swin_T_224_1k": dict(
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
),
"swin_B_224_22k": dict(
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
),
"swin_B_384_22k": dict(
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
),
"swin_L_224_22k": dict(
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7
),
"swin_L_384_22k": dict(
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
),
}
kw_cgf = model_para_dict[modelname]
kw_cgf.update(kw)
model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf)
return model
if __name__ == "__main__":
model = build_swin_transformer("swin_L_384_22k", 384, dilation=True)
x = torch.rand(2, 3, 1024, 1024)
y = model.forward_raw(x)
import ipdb
ipdb.set_trace()
x = torch.rand(2, 3, 384, 384)
y = model.forward_raw(x)

View file

@ -0,0 +1,273 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch import Tensor, nn
from torchvision.ops.boxes import nms
from transformers import BertConfig, BertModel, BertPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
class BertModelWarper(nn.Module):
def __init__(self, bert_model):
super().__init__()
# self.bert = bert_modelc
self.config = bert_model.config
self.embeddings = bert_model.embeddings
self.encoder = bert_model.encoder
self.pooler = bert_model.pooler
self.get_extended_attention_mask = bert_model.get_extended_attention_mask
self.invert_attention_mask = bert_model.invert_attention_mask
self.get_head_mask = bert_model.get_head_mask
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length = input_shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = (
past_key_values[0][0].shape[2] if past_key_values is not None else 0
)
if attention_mask is None:
attention_mask = torch.ones(
((batch_size, seq_length + past_key_values_length)), device=device
)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, input_shape, device
)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
class TextEncoderShell(nn.Module):
def __init__(self, text_encoder):
super().__init__()
self.text_encoder = text_encoder
self.config = self.text_encoder.config
def forward(self, **kw):
# feed into text encoder
return self.text_encoder(**kw)
def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer):
"""Generate attention mask between each pair of special tokens
Args:
input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
special_tokens_mask (list): special tokens mask.
Returns:
torch.Tensor: attention mask between each special tokens.
"""
input_ids = tokenized["input_ids"]
bs, num_token = input_ids.shape
# special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
for special_token in special_tokens_list:
special_tokens_mask |= input_ids == special_token
# idxs: each row is a list of indices of special tokens
idxs = torch.nonzero(special_tokens_mask)
# generate attention mask and positional ids
attention_mask = (
torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
)
position_ids = torch.zeros((bs, num_token), device=input_ids.device)
previous_col = 0
for i in range(idxs.shape[0]):
row, col = idxs[i]
if (col == 0) or (col == num_token - 1):
attention_mask[row, col, col] = True
position_ids[row, col] = 0
else:
attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
position_ids[row, previous_col + 1 : col + 1] = torch.arange(
0, col - previous_col, device=input_ids.device
)
previous_col = col
# # padding mask
# padding_mask = tokenized['attention_mask']
# attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
return attention_mask, position_ids.to(torch.long)
def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer):
"""Generate attention mask between each pair of special tokens
Args:
input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
special_tokens_mask (list): special tokens mask.
Returns:
torch.Tensor: attention mask between each special tokens.
"""
input_ids = tokenized["input_ids"]
bs, num_token = input_ids.shape
# special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
for special_token in special_tokens_list:
special_tokens_mask |= input_ids == special_token
# idxs: each row is a list of indices of special tokens
idxs = torch.nonzero(special_tokens_mask)
# generate attention mask and positional ids
attention_mask = (
torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
)
position_ids = torch.zeros((bs, num_token), device=input_ids.device)
cate_to_token_mask_list = [[] for _ in range(bs)]
previous_col = 0
for i in range(idxs.shape[0]):
row, col = idxs[i]
if (col == 0) or (col == num_token - 1):
attention_mask[row, col, col] = True
position_ids[row, col] = 0
else:
attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
position_ids[row, previous_col + 1 : col + 1] = torch.arange(
0, col - previous_col, device=input_ids.device
)
c2t_maski = torch.zeros((num_token), device=input_ids.device).bool()
c2t_maski[previous_col + 1 : col] = True
cate_to_token_mask_list[row].append(c2t_maski)
previous_col = col
cate_to_token_mask_list = [
torch.stack(cate_to_token_mask_listi, dim=0)
for cate_to_token_mask_listi in cate_to_token_mask_list
]
# # padding mask
# padding_mask = tokenized['attention_mask']
# attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
return attention_mask, position_ids.to(torch.long), cate_to_token_mask_list

View file

@ -0,0 +1,64 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include "ms_deform_attn_cpu.h"
#ifdef WITH_CUDA
#include "ms_deform_attn_cuda.h"
#endif
namespace groundingdino {
at::Tensor
ms_deform_attn_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
if (value.type().is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_forward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
std::vector<at::Tensor>
ms_deform_attn_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
if (value.type().is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_backward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
} // namespace groundingdino

View file

@ -0,0 +1,43 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
namespace groundingdino {
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
} // namespace groundingdino

View file

@ -0,0 +1,35 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
namespace groundingdino {
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
} // namespace groundingdino

View file

@ -0,0 +1,156 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include "ms_deform_im2col_cuda.cuh"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
namespace groundingdino {
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads*channels});
return output;
}
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
return {
grad_value, grad_sampling_loc, grad_attn_weight
};
}
} // namespace groundingdino

View file

@ -0,0 +1,33 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
namespace groundingdino {
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
} // namespace groundingdino

View file

@ -0,0 +1,7 @@
#include <cuda_runtime_api.h>
namespace groundingdino {
int get_cudart_version() {
return CUDART_VERSION;
}
} // namespace groundingdino

View file

@ -0,0 +1,58 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "MsDeformAttn/ms_deform_attn.h"
namespace groundingdino {
#ifdef WITH_CUDA
extern int get_cudart_version();
#endif
std::string get_cuda_version() {
#ifdef WITH_CUDA
std::ostringstream oss;
// copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
auto printCudaStyleVersion = [&](int v) {
oss << (v / 1000) << "." << (v / 10 % 100);
if (v % 10 != 0) {
oss << "." << (v % 10);
}
};
printCudaStyleVersion(get_cudart_version());
return oss.str();
#else
return std::string("not available");
#endif
}
// similar to
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
std::string get_compiler_version() {
std::ostringstream ss;
#if defined(__GNUC__)
#ifndef __clang__
{ ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
#endif
#endif
#if defined(__clang_major__)
{
ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
<< __clang_patchlevel__;
}
#endif
#if defined(_MSC_VER)
{ ss << "MSVC " << _MSC_FULL_VER; }
#endif
return ss.str();
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
} // namespace groundingdino

View file

@ -0,0 +1,297 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
class FeatureResizer(nn.Module):
"""
This class takes as input a set of embeddings of dimension C1 and outputs a set of
embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
"""
def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
super().__init__()
self.do_ln = do_ln
# Object feature encoding
self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
self.dropout = nn.Dropout(dropout)
def forward(self, encoder_features):
x = self.fc(encoder_features)
if self.do_ln:
x = self.layer_norm(x)
output = self.dropout(x)
return output
def l1norm(X, dim, eps=1e-8):
"""L1-normalize columns of X"""
norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
X = torch.div(X, norm)
return X
def l2norm(X, dim, eps=1e-8):
"""L2-normalize columns of X"""
norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
X = torch.div(X, norm)
return X
def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
"""
query: (n_context, queryL, d)
context: (n_context, sourceL, d)
"""
batch_size_q, queryL = query.size(0), query.size(1)
batch_size, sourceL = context.size(0), context.size(1)
# Get attention
# --> (batch, d, queryL)
queryT = torch.transpose(query, 1, 2)
# (batch, sourceL, d)(batch, d, queryL)
# --> (batch, sourceL, queryL)
attn = torch.bmm(context, queryT)
if raw_feature_norm == "softmax":
# --> (batch*sourceL, queryL)
attn = attn.view(batch_size * sourceL, queryL)
attn = nn.Softmax()(attn)
# --> (batch, sourceL, queryL)
attn = attn.view(batch_size, sourceL, queryL)
elif raw_feature_norm == "l2norm":
attn = l2norm(attn, 2)
elif raw_feature_norm == "clipped_l2norm":
attn = nn.LeakyReLU(0.1)(attn)
attn = l2norm(attn, 2)
else:
raise ValueError("unknown first norm type:", raw_feature_norm)
# --> (batch, queryL, sourceL)
attn = torch.transpose(attn, 1, 2).contiguous()
# --> (batch*queryL, sourceL)
attn = attn.view(batch_size * queryL, sourceL)
attn = nn.Softmax()(attn * smooth)
# --> (batch, queryL, sourceL)
attn = attn.view(batch_size, queryL, sourceL)
# --> (batch, sourceL, queryL)
attnT = torch.transpose(attn, 1, 2).contiguous()
# --> (batch, d, sourceL)
contextT = torch.transpose(context, 1, 2)
# (batch x d x sourceL)(batch x sourceL x queryL)
# --> (batch, d, queryL)
weightedContext = torch.bmm(contextT, attnT)
# --> (batch, queryL, d)
weightedContext = torch.transpose(weightedContext, 1, 2)
return weightedContext, attnT
class BiMultiHeadAttention(nn.Module):
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
super(BiMultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.v_dim = v_dim
self.l_dim = l_dim
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
self.scale = self.head_dim ** (-0.5)
self.dropout = dropout
self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
self.stable_softmax_2d = True
self.clamp_min_for_underflow = True
self.clamp_max_for_overflow = True
self._reset_parameters()
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _reset_parameters(self):
nn.init.xavier_uniform_(self.v_proj.weight)
self.v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.l_proj.weight)
self.l_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.values_v_proj.weight)
self.values_v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.values_l_proj.weight)
self.values_l_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_v_proj.weight)
self.out_v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_l_proj.weight)
self.out_l_proj.bias.data.fill_(0)
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
"""_summary_
Args:
v (_type_): bs, n_img, dim
l (_type_): bs, n_text, dim
attention_mask_v (_type_, optional): _description_. bs, n_img
attention_mask_l (_type_, optional): _description_. bs, n_text
Returns:
_type_: _description_
"""
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
bsz, tgt_len, _ = v.size()
query_states = self.v_proj(v) * self.scale
key_states = self._shape(self.l_proj(l), -1, bsz)
value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_v_states = value_v_states.view(*proj_shape)
value_l_states = value_l_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
)
if self.stable_softmax_2d:
attn_weights = attn_weights - attn_weights.max()
if self.clamp_min_for_underflow:
attn_weights = torch.clamp(
attn_weights, min=-50000
) # Do not increase -50000, data type half has quite limited range
if self.clamp_max_for_overflow:
attn_weights = torch.clamp(
attn_weights, max=50000
) # Do not increase 50000, data type half has quite limited range
attn_weights_T = attn_weights.transpose(1, 2)
attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
if self.clamp_min_for_underflow:
attn_weights_l = torch.clamp(
attn_weights_l, min=-50000
) # Do not increase -50000, data type half has quite limited range
if self.clamp_max_for_overflow:
attn_weights_l = torch.clamp(
attn_weights_l, max=50000
) # Do not increase 50000, data type half has quite limited range
# mask vison for language
if attention_mask_v is not None:
attention_mask_v = (
attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
)
attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
attn_weights_l = attn_weights_l.softmax(dim=-1)
# mask language for vision
if attention_mask_l is not None:
attention_mask_l = (
attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
)
attn_weights.masked_fill_(attention_mask_l, float("-inf"))
attn_weights_v = attn_weights.softmax(dim=-1)
attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
attn_output_v = torch.bmm(attn_probs_v, value_l_states)
attn_output_l = torch.bmm(attn_probs_l, value_v_states)
if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
)
if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
raise ValueError(
f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
)
attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output_v = attn_output_v.transpose(1, 2)
attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
attn_output_l = attn_output_l.transpose(1, 2)
attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
attn_output_v = self.out_v_proj(attn_output_v)
attn_output_l = self.out_l_proj(attn_output_l)
return attn_output_v, attn_output_l
# Bi-Direction MHA (text->image, image->text)
class BiAttentionBlock(nn.Module):
def __init__(
self,
v_dim,
l_dim,
embed_dim,
num_heads,
dropout=0.1,
drop_path=0.0,
init_values=1e-4,
cfg=None,
):
"""
Inputs:
embed_dim - Dimensionality of input and attention feature vectors
hidden_dim - Dimensionality of hidden layer in feed-forward network
(usually 2-4x larger than embed_dim)
num_heads - Number of heads to use in the Multi-Head Attention block
dropout - Amount of dropout to apply in the feed-forward network
"""
super(BiAttentionBlock, self).__init__()
# pre layer norm
self.layer_norm_v = nn.LayerNorm(v_dim)
self.layer_norm_l = nn.LayerNorm(l_dim)
self.attn = BiMultiHeadAttention(
v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
)
# add layer scale for training stability
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
v = self.layer_norm_v(v)
l = self.layer_norm_l(l)
delta_v, delta_l = self.attn(
v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l
)
# v, l = v + delta_v, l + delta_l
v = v + self.drop_path(self.gamma_v * delta_v)
l = l + self.drop_path(self.gamma_l * delta_l)
return v, l
# def forward(self, v:List[torch.Tensor], l, attention_mask_v=None, attention_mask_l=None)

View file

@ -0,0 +1,395 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import copy
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops.boxes import nms
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
from groundingdino.util import box_ops, get_tokenlizer
from groundingdino.util.misc import (
NestedTensor,
accuracy,
get_world_size,
interpolate,
inverse_sigmoid,
is_dist_avail_and_initialized,
nested_tensor_from_tensor_list,
)
from groundingdino.util.utils import get_phrases_from_posmap
from groundingdino.util.visualizer import COCOVisualizer
from groundingdino.util.vl_utils import create_positive_map_from_span
from ..registry import MODULE_BUILD_FUNCS
from .backbone import build_backbone
from .bertwarper import (
BertModelWarper,
generate_masks_with_special_tokens,
generate_masks_with_special_tokens_and_transfer_map,
)
from .transformer import build_transformer
from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss
class GroundingDINO(nn.Module):
"""This is the Cross-Attention Detector module that performs object detection"""
def __init__(
self,
backbone,
transformer,
num_queries,
aux_loss=False,
iter_update=False,
query_dim=2,
num_feature_levels=1,
nheads=8,
# two stage
two_stage_type="no", # ['no', 'standard']
dec_pred_bbox_embed_share=True,
two_stage_class_embed_share=True,
two_stage_bbox_embed_share=True,
num_patterns=0,
dn_number=100,
dn_box_noise_scale=0.4,
dn_label_noise_ratio=0.5,
dn_labelbook_size=100,
text_encoder_type="bert-base-uncased",
sub_sentence_present=True,
max_text_len=256,
):
"""Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.transformer = transformer
self.hidden_dim = hidden_dim = transformer.d_model
self.num_feature_levels = num_feature_levels
self.nheads = nheads
self.max_text_len = 256
self.sub_sentence_present = sub_sentence_present
# setting query dim
self.query_dim = query_dim
assert query_dim == 4
# for dn training
self.num_patterns = num_patterns
self.dn_number = dn_number
self.dn_box_noise_scale = dn_box_noise_scale
self.dn_label_noise_ratio = dn_label_noise_ratio
self.dn_labelbook_size = dn_labelbook_size
# bert
self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
self.bert.pooler.dense.weight.requires_grad_(False)
self.bert.pooler.dense.bias.requires_grad_(False)
self.bert = BertModelWarper(bert_model=self.bert)
self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
nn.init.constant_(self.feat_map.bias.data, 0)
nn.init.xavier_uniform_(self.feat_map.weight.data)
# freeze
# special tokens
self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
# prepare input projection layers
if num_feature_levels > 1:
num_backbone_outs = len(backbone.num_channels)
input_proj_list = []
for _ in range(num_backbone_outs):
in_channels = backbone.num_channels[_]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
)
for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(32, hidden_dim),
)
)
in_channels = hidden_dim
self.input_proj = nn.ModuleList(input_proj_list)
else:
assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
self.input_proj = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
]
)
self.backbone = backbone
self.aux_loss = aux_loss
self.box_pred_damping = box_pred_damping = None
self.iter_update = iter_update
assert iter_update, "Why not iter_update?"
# prepare pred layers
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
# prepare class & box embed
_class_embed = ContrastiveEmbed()
_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share:
box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
else:
box_embed_layerlist = [
copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)
]
class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
self.bbox_embed = nn.ModuleList(box_embed_layerlist)
self.class_embed = nn.ModuleList(class_embed_layerlist)
self.transformer.decoder.bbox_embed = self.bbox_embed
self.transformer.decoder.class_embed = self.class_embed
# two stage
self.two_stage_type = two_stage_type
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
two_stage_type
)
if two_stage_type != "no":
if two_stage_bbox_embed_share:
assert dec_pred_bbox_embed_share
self.transformer.enc_out_bbox_embed = _bbox_embed
else:
self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
if two_stage_class_embed_share:
assert dec_pred_bbox_embed_share
self.transformer.enc_out_class_embed = _class_embed
else:
self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
self.refpoint_embed = None
self._reset_parameters()
def _reset_parameters(self):
# init input_proj
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
def forward(self, samples: NestedTensor, targets: List = None, **kw):
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
if targets is None:
captions = kw["captions"]
else:
captions = [t["caption"] for t in targets]
len(captions)
# encoder texts
tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
samples.device
)
(
text_self_attention_masks,
position_ids,
cate_to_token_mask_list,
) = generate_masks_with_special_tokens_and_transfer_map(
tokenized, self.specical_tokens, self.tokenizer
)
if text_self_attention_masks.shape[1] > self.max_text_len:
text_self_attention_masks = text_self_attention_masks[
:, : self.max_text_len, : self.max_text_len
]
position_ids = position_ids[:, : self.max_text_len]
tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
# extract text embeddings
if self.sub_sentence_present:
tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
tokenized_for_encoder["attention_mask"] = text_self_attention_masks
tokenized_for_encoder["position_ids"] = position_ids
else:
# import ipdb; ipdb.set_trace()
tokenized_for_encoder = tokenized
bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
text_token_mask = tokenized.attention_mask.bool() # bs, 195
# text_token_mask: True for nomask, False for mask
# text_self_attention_masks: True for nomask, False for mask
if encoded_text.shape[1] > self.max_text_len:
encoded_text = encoded_text[:, : self.max_text_len, :]
text_token_mask = text_token_mask[:, : self.max_text_len]
position_ids = position_ids[:, : self.max_text_len]
text_self_attention_masks = text_self_attention_masks[
:, : self.max_text_len, : self.max_text_len
]
text_dict = {
"encoded_text": encoded_text, # bs, 195, d_model
"text_token_mask": text_token_mask, # bs, 195
"position_ids": position_ids, # bs, 195
"text_self_attention_masks": text_self_attention_masks, # bs, 195,195
}
# import ipdb; ipdb.set_trace()
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
features, poss = self.backbone(samples)
srcs = []
masks = []
for l, feat in enumerate(features):
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
assert mask is not None
if self.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = samples.mask
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
poss.append(pos_l)
input_query_bbox = input_query_label = attn_mask = dn_meta = None
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
)
# deformable-detr-like anchor update
outputs_coord_list = []
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
zip(reference[:-1], self.bbox_embed, hs)
):
layer_delta_unsig = layer_bbox_embed(layer_hs)
layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
outputs_coord_list.append(layer_outputs_unsig)
outputs_coord_list = torch.stack(outputs_coord_list)
# output
outputs_class = torch.stack(
[
layer_cls_embed(layer_hs, text_dict)
for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
]
)
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
# # for intermediate outputs
# if self.aux_loss:
# out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
# # for encoder output
# if hs_enc is not None:
# # prepare intermediate outputs
# interm_coord = ref_enc[-1]
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
return out
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [
{"pred_logits": a, "pred_boxes": b}
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
]
@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
def build_groundingdino(args):
backbone = build_backbone(args)
transformer = build_transformer(args)
dn_labelbook_size = args.dn_labelbook_size
dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
sub_sentence_present = args.sub_sentence_present
model = GroundingDINO(
backbone,
transformer,
num_queries=args.num_queries,
aux_loss=True,
iter_update=True,
query_dim=4,
num_feature_levels=args.num_feature_levels,
nheads=args.nheads,
dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
two_stage_type=args.two_stage_type,
two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
two_stage_class_embed_share=args.two_stage_class_embed_share,
num_patterns=args.num_patterns,
dn_number=0,
dn_box_noise_scale=args.dn_box_noise_scale,
dn_label_noise_ratio=args.dn_label_noise_ratio,
dn_labelbook_size=dn_labelbook_size,
text_encoder_type=args.text_encoder_type,
sub_sentence_present=sub_sentence_present,
max_text_len=args.max_text_len,
)
return model

View file

@ -0,0 +1,413 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py
# ------------------------------------------------------------------------------------------------
import math
import warnings
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.init import constant_, xavier_uniform_
try:
from groundingdino import _C
except:
warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
# helpers
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n - 1) == 0) and n != 0
class MultiScaleDeformableAttnFunction(Function):
@staticmethod
def forward(
ctx,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
):
ctx.im2col_step = im2col_step
output = _C.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
ctx.im2col_step,
)
ctx.save_for_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = _C.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
ctx.im2col_step,
)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
def multi_scale_deformable_attn_pytorch(
value: torch.Tensor,
value_spatial_shapes: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
) -> torch.Tensor:
bs, _, num_heads, embed_dims = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level, (H_, W_) in enumerate(value_spatial_shapes):
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_ = (
value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_ = F.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
bs * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(bs, num_heads * embed_dims, num_queries)
)
return output.transpose(1, 2).contiguous()
class MultiScaleDeformableAttention(nn.Module):
"""Multi-Scale Deformable Attention Module used in Deformable-DETR
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dim (int): The embedding dimension of Attention. Default: 256.
num_heads (int): The number of attention heads. Default: 8.
num_levels (int): The number of feature map used in Attention. Default: 4.
num_points (int): The number of sampling points for each query
in each head. Default: 4.
img2col_steps (int): The step used in image_to_column. Defualt: 64.
dropout (float): Dropout layer used in output. Default: 0.1.
batch_first (bool): if ``True``, then the input and output tensor will be
provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)`
"""
def __init__(
self,
embed_dim: int = 256,
num_heads: int = 8,
num_levels: int = 4,
num_points: int = 4,
img2col_step: int = 64,
batch_first: bool = False,
):
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError(
"embed_dim must be divisible by num_heads, but got {} and {}".format(
embed_dim, num_heads
)
)
head_dim = embed_dim // num_heads
self.batch_first = batch_first
if not _is_power_of_2(head_dim):
warnings.warn(
"""
You'd better set d_model in MSDeformAttn to make sure that
each dim of the attention head a power of 2, which is more efficient.
"""
)
self.im2col_step = img2col_step
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_levels = num_levels
self.num_points = num_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
self.init_weights()
def _reset_parameters(self):
return self.init_weights()
def init_weights(self):
"""
Default initialization for Parameters of Module.
"""
constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
2.0 * math.pi / self.num_heads
)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(self.num_heads, 1, 1, 2)
.repeat(1, self.num_levels, self.num_points, 1)
)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.0)
constant_(self.attention_weights.bias.data, 0.0)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.0)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.0)
def freeze_sampling_offsets(self):
print("Freeze sampling offsets")
self.sampling_offsets.weight.requires_grad = False
self.sampling_offsets.bias.requires_grad = False
def freeze_attention_weights(self):
print("Freeze attention weights")
self.attention_weights.weight.requires_grad = False
self.attention_weights.bias.requires_grad = False
def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
query_pos: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
reference_points: Optional[torch.Tensor] = None,
spatial_shapes: Optional[torch.Tensor] = None,
level_start_index: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
"""Forward Function of MultiScaleDeformableAttention
Args:
query (torch.Tensor): Query embeddings with shape
`(num_query, bs, embed_dim)`
key (torch.Tensor): Key embeddings with shape
`(num_key, bs, embed_dim)`
value (torch.Tensor): Value embeddings with shape
`(num_key, bs, embed_dim)`
query_pos (torch.Tensor): The position embedding for `query`. Default: None.
key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`,
indicating which elements within `key` to be ignored in attention.
reference_points (torch.Tensor): The normalized reference points
with shape `(bs, num_query, num_levels, 2)`,
all elements is range in [0, 1], top-left (0, 0),
bottom-right (1, 1), including padding are.
or `(N, Length_{query}, num_levels, 4)`, add additional
two dimensions `(h, w)` to form reference boxes.
spatial_shapes (torch.Tensor): Spatial shape of features in different levels.
With shape `(num_levels, 2)`, last dimension represents `(h, w)`.
level_start_index (torch.Tensor): The start index of each level. A tensor with
shape `(num_levels, )` which can be represented as
`[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`.
Returns:
torch.Tensor: forward results with shape `(num_query, bs, embed_dim)`
"""
if value is None:
value = query
if query_pos is not None:
query = query + query_pos
if not self.batch_first:
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], float(0))
value = value.view(bs, num_value, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points
)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(
bs,
num_query,
self.num_heads,
self.num_levels,
self.num_points,
)
# bs, num_query, num_heads, num_levels, num_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2]
+ sampling_offsets
/ self.num_points
* reference_points[:, :, None, :, None, 2:]
* 0.5
)
else:
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".format(
reference_points.shape[-1]
)
)
if torch.cuda.is_available() and value.is_cuda:
halffloat = False
if value.dtype == torch.float16:
halffloat = True
value = value.float()
sampling_locations = sampling_locations.float()
attention_weights = attention_weights.float()
output = MultiScaleDeformableAttnFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
if halffloat:
output = output.half()
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights
)
output = self.output_proj(output)
if not self.batch_first:
output = output.permute(1, 0, 2)
return output
def create_dummy_class(klass, dependency, message=""):
"""
When a dependency of a class is not available, create a dummy class which throws ImportError
when used.
Args:
klass (str): name of the class.
dependency (str): name of the dependency.
message: extra message to print
Returns:
class: a class object
"""
err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
if message:
err = err + " " + message
class _DummyMetaClass(type):
# throw error on class attribute access
def __getattr__(_, __): # noqa: B902
raise ImportError(err)
class _Dummy(object, metaclass=_DummyMetaClass):
# throw error on constructor
def __init__(self, *args, **kwargs):
raise ImportError(err)
return _Dummy
def create_dummy_func(func, dependency, message=""):
"""
When a dependency of a function is not available, create a dummy function which throws
ImportError when used.
Args:
func (str): name of the function.
dependency (str or list[str]): name(s) of the dependency.
message: extra message to print
Returns:
function: a function object
"""
err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
if message:
err = err + " " + message
if isinstance(dependency, (list, tuple)):
dependency = ",".join(dependency)
def _dummy(*args, **kwargs):
raise ImportError(err)
return _dummy

View file

@ -0,0 +1,959 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR Transformer class.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
from typing import Optional
import torch
import torch.utils.checkpoint as checkpoint
from torch import Tensor, nn
from groundingdino.util.misc import inverse_sigmoid
from .fuse_modules import BiAttentionBlock
from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
from .transformer_vanilla import TransformerEncoderLayer
from .utils import (
MLP,
_get_activation_fn,
_get_clones,
gen_encoder_output_proposals,
gen_sineembed_for_position,
get_sine_pos_embed,
)
class Transformer(nn.Module):
def __init__(
self,
d_model=256,
nhead=8,
num_queries=300,
num_encoder_layers=6,
num_unicoder_layers=0,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.0,
activation="relu",
normalize_before=False,
return_intermediate_dec=False,
query_dim=4,
num_patterns=0,
# for deformable encoder
num_feature_levels=1,
enc_n_points=4,
dec_n_points=4,
# init query
learnable_tgt_init=False,
# two stage
two_stage_type="no", # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
embed_init_tgt=False,
# for text
use_text_enhancer=False,
use_fusion_layer=False,
use_checkpoint=False,
use_transformer_ckpt=False,
use_text_cross_attention=False,
text_dropout=0.1,
fusion_dropout=0.1,
fusion_droppath=0.0,
):
super().__init__()
self.num_feature_levels = num_feature_levels
self.num_encoder_layers = num_encoder_layers
self.num_unicoder_layers = num_unicoder_layers
self.num_decoder_layers = num_decoder_layers
self.num_queries = num_queries
assert query_dim == 4
# choose encoder layer type
encoder_layer = DeformableTransformerEncoderLayer(
d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points
)
if use_text_enhancer:
text_enhance_layer = TransformerEncoderLayer(
d_model=d_model,
nhead=nhead // 2,
dim_feedforward=dim_feedforward // 2,
dropout=text_dropout,
)
else:
text_enhance_layer = None
if use_fusion_layer:
feature_fusion_layer = BiAttentionBlock(
v_dim=d_model,
l_dim=d_model,
embed_dim=dim_feedforward // 2,
num_heads=nhead // 2,
dropout=fusion_dropout,
drop_path=fusion_droppath,
)
else:
feature_fusion_layer = None
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
assert encoder_norm is None
self.encoder = TransformerEncoder(
encoder_layer,
num_encoder_layers,
d_model=d_model,
num_queries=num_queries,
text_enhance_layer=text_enhance_layer,
feature_fusion_layer=feature_fusion_layer,
use_checkpoint=use_checkpoint,
use_transformer_ckpt=use_transformer_ckpt,
)
# choose decoder layer type
decoder_layer = DeformableTransformerDecoderLayer(
d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
dec_n_points,
use_text_cross_attention=use_text_cross_attention,
)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(
decoder_layer,
num_decoder_layers,
decoder_norm,
return_intermediate=return_intermediate_dec,
d_model=d_model,
query_dim=query_dim,
num_feature_levels=num_feature_levels,
)
self.d_model = d_model
self.nhead = nhead
self.dec_layers = num_decoder_layers
self.num_queries = num_queries # useful for single stage model only
self.num_patterns = num_patterns
if not isinstance(num_patterns, int):
Warning("num_patterns should be int but {}".format(type(num_patterns)))
self.num_patterns = 0
if num_feature_levels > 1:
if self.num_encoder_layers > 0:
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
else:
self.level_embed = None
self.learnable_tgt_init = learnable_tgt_init
assert learnable_tgt_init, "why not learnable_tgt_init"
self.embed_init_tgt = embed_init_tgt
if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"):
self.tgt_embed = nn.Embedding(self.num_queries, d_model)
nn.init.normal_(self.tgt_embed.weight.data)
else:
self.tgt_embed = None
# for two stage
self.two_stage_type = two_stage_type
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
two_stage_type
)
if two_stage_type == "standard":
# anchor selection at the output of encoder
self.enc_output = nn.Linear(d_model, d_model)
self.enc_output_norm = nn.LayerNorm(d_model)
self.two_stage_wh_embedding = None
if two_stage_type == "no":
self.init_ref_points(num_queries) # init self.refpoint_embed
self.enc_out_class_embed = None
self.enc_out_bbox_embed = None
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
if self.num_feature_levels > 1 and self.level_embed is not None:
nn.init.normal_(self.level_embed)
def get_valid_ratio(self, mask):
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, 4)
def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
"""
Input:
- srcs: List of multi features [bs, ci, hi, wi]
- masks: List of multi masks [bs, hi, wi]
- refpoint_embed: [bs, num_dn, 4]. None in infer
- pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
- tgt: [bs, num_dn, d_model]. None in infer
"""
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2) # bs, hw, c
mask = mask.flatten(1) # bs, hw
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
if self.num_feature_levels > 1 and self.level_embed is not None:
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
else:
lvl_pos_embed = pos_embed
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=src_flatten.device
)
level_start_index = torch.cat(
(spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# two stage
enc_topk_proposals = enc_refpoint_embed = None
#########################################################
# Begin Encoder
#########################################################
memory, memory_text = self.encoder(
src_flatten,
pos=lvl_pos_embed_flatten,
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
key_padding_mask=mask_flatten,
memory_text=text_dict["encoded_text"],
text_attention_mask=~text_dict["text_token_mask"],
# we ~ the mask . False means use the token; True means pad the token
position_ids=text_dict["position_ids"],
text_self_attention_masks=text_dict["text_self_attention_masks"],
)
#########################################################
# End Encoder
# - memory: bs, \sum{hw}, c
# - mask_flatten: bs, \sum{hw}
# - lvl_pos_embed_flatten: bs, \sum{hw}, c
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
# - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
#########################################################
text_dict["encoded_text"] = memory_text
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# if memory.isnan().any() | memory.isinf().any():
# import ipdb; ipdb.set_trace()
if self.two_stage_type == "standard":
output_memory, output_proposals = gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes
)
output_memory = self.enc_output_norm(self.enc_output(output_memory))
if text_dict is not None:
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
else:
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
topk_logits = enc_outputs_class_unselected.max(-1)[0]
enc_outputs_coord_unselected = (
self.enc_out_bbox_embed(output_memory) + output_proposals
) # (bs, \sum{hw}, 4) unsigmoid
topk = self.num_queries
topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
# gather boxes
refpoint_embed_undetach = torch.gather(
enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
) # unsigmoid
refpoint_embed_ = refpoint_embed_undetach.detach()
init_box_proposal = torch.gather(
output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
).sigmoid() # sigmoid
# gather tgt
tgt_undetach = torch.gather(
output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
)
if self.embed_init_tgt:
tgt_ = (
self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
) # nq, bs, d_model
else:
tgt_ = tgt_undetach.detach()
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
elif self.two_stage_type == "no":
tgt_ = (
self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
) # nq, bs, d_model
refpoint_embed_ = (
self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
) # nq, bs, 4
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
if self.num_patterns > 0:
tgt_embed = tgt.repeat(1, self.num_patterns, 1)
refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
self.num_queries, 1
) # 1, n_q*n_pat, d_model
tgt = tgt_embed + tgt_pat
init_box_proposal = refpoint_embed_.sigmoid()
else:
raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
#########################################################
# End preparing tgt
# - tgt: bs, NQ, d_model
# - refpoint_embed(unsigmoid): bs, NQ, d_model
#########################################################
#########################################################
# Begin Decoder
#########################################################
hs, references = self.decoder(
tgt=tgt.transpose(0, 1),
memory=memory.transpose(0, 1),
memory_key_padding_mask=mask_flatten,
pos=lvl_pos_embed_flatten.transpose(0, 1),
refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
tgt_mask=attn_mask,
memory_text=text_dict["encoded_text"],
text_attention_mask=~text_dict["text_token_mask"],
# we ~ the mask . False means use the token; True means pad the token
)
#########################################################
# End Decoder
# hs: n_dec, bs, nq, d_model
# references: n_dec+1, bs, nq, query_dim
#########################################################
#########################################################
# Begin postprocess
#########################################################
if self.two_stage_type == "standard":
hs_enc = tgt_undetach.unsqueeze(0)
ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
else:
hs_enc = ref_enc = None
#########################################################
# End postprocess
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
# ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
#########################################################
return hs, references, hs_enc, ref_enc, init_box_proposal
# hs: (n_dec, bs, nq, d_model)
# references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
# ref_enc: sigmoid coordinates. \
# (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
class TransformerEncoder(nn.Module):
def __init__(
self,
encoder_layer,
num_layers,
d_model=256,
num_queries=300,
enc_layer_share=False,
text_enhance_layer=None,
feature_fusion_layer=None,
use_checkpoint=False,
use_transformer_ckpt=False,
):
"""_summary_
Args:
encoder_layer (_type_): _description_
num_layers (_type_): _description_
norm (_type_, optional): _description_. Defaults to None.
d_model (int, optional): _description_. Defaults to 256.
num_queries (int, optional): _description_. Defaults to 300.
enc_layer_share (bool, optional): _description_. Defaults to False.
"""
super().__init__()
# prepare layers
self.layers = []
self.text_layers = []
self.fusion_layers = []
if num_layers > 0:
self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
if text_enhance_layer is not None:
self.text_layers = _get_clones(
text_enhance_layer, num_layers, layer_share=enc_layer_share
)
if feature_fusion_layer is not None:
self.fusion_layers = _get_clones(
feature_fusion_layer, num_layers, layer_share=enc_layer_share
)
else:
self.layers = []
del encoder_layer
if text_enhance_layer is not None:
self.text_layers = []
del text_enhance_layer
if feature_fusion_layer is not None:
self.fusion_layers = []
del feature_fusion_layer
self.query_scale = None
self.num_queries = num_queries
self.num_layers = num_layers
self.d_model = d_model
self.use_checkpoint = use_checkpoint
self.use_transformer_ckpt = use_transformer_ckpt
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(
self,
# for images
src: Tensor,
pos: Tensor,
spatial_shapes: Tensor,
level_start_index: Tensor,
valid_ratios: Tensor,
key_padding_mask: Tensor,
# for texts
memory_text: Tensor = None,
text_attention_mask: Tensor = None,
pos_text: Tensor = None,
text_self_attention_masks: Tensor = None,
position_ids: Tensor = None,
):
"""
Input:
- src: [bs, sum(hi*wi), 256]
- pos: pos embed for src. [bs, sum(hi*wi), 256]
- spatial_shapes: h,w of each level [num_level, 2]
- level_start_index: [num_level] start point of level in sum(hi*wi).
- valid_ratios: [bs, num_level, 2]
- key_padding_mask: [bs, sum(hi*wi)]
- memory_text: bs, n_text, 256
- text_attention_mask: bs, n_text
False for no padding; True for padding
- pos_text: bs, n_text, 256
- position_ids: bs, n_text
Intermedia:
- reference_points: [bs, sum(hi*wi), num_level, 2]
Outpus:
- output: [bs, sum(hi*wi), 256]
"""
output = src
# preparation and reshape
if self.num_layers > 0:
reference_points = self.get_reference_points(
spatial_shapes, valid_ratios, device=src.device
)
if self.text_layers:
# generate pos_text
bs, n_text, text_dim = memory_text.shape
if pos_text is None and position_ids is None:
pos_text = (
torch.arange(n_text, device=memory_text.device)
.float()
.unsqueeze(0)
.unsqueeze(-1)
.repeat(bs, 1, 1)
)
pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
if position_ids is not None:
pos_text = get_sine_pos_embed(
position_ids[..., None], num_pos_feats=256, exchange_xy=False
)
# main process
for layer_id, layer in enumerate(self.layers):
# if output.isnan().any() or memory_text.isnan().any():
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
if self.fusion_layers:
if self.use_checkpoint:
output, memory_text = checkpoint.checkpoint(
self.fusion_layers[layer_id],
output,
memory_text,
key_padding_mask,
text_attention_mask,
)
else:
output, memory_text = self.fusion_layers[layer_id](
v=output,
l=memory_text,
attention_mask_v=key_padding_mask,
attention_mask_l=text_attention_mask,
)
if self.text_layers:
memory_text = self.text_layers[layer_id](
src=memory_text.transpose(0, 1),
src_mask=~text_self_attention_masks, # note we use ~ for mask here
src_key_padding_mask=text_attention_mask,
pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
).transpose(0, 1)
# main process
if self.use_transformer_ckpt:
output = checkpoint.checkpoint(
layer,
output,
pos,
reference_points,
spatial_shapes,
level_start_index,
key_padding_mask,
)
else:
output = layer(
src=output,
pos=pos,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
key_padding_mask=key_padding_mask,
)
return output, memory_text
class TransformerDecoder(nn.Module):
def __init__(
self,
decoder_layer,
num_layers,
norm=None,
return_intermediate=False,
d_model=256,
query_dim=4,
num_feature_levels=1,
):
super().__init__()
if num_layers > 0:
self.layers = _get_clones(decoder_layer, num_layers)
else:
self.layers = []
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
assert return_intermediate, "support return_intermediate only"
self.query_dim = query_dim
assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
self.num_feature_levels = num_feature_levels
self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
self.query_pos_sine_scale = None
self.query_scale = None
self.bbox_embed = None
self.class_embed = None
self.d_model = d_model
self.ref_anchor_head = None
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
# for memory
level_start_index: Optional[Tensor] = None, # num_levels
spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
valid_ratios: Optional[Tensor] = None,
# for text
memory_text: Optional[Tensor] = None,
text_attention_mask: Optional[Tensor] = None,
):
"""
Input:
- tgt: nq, bs, d_model
- memory: hw, bs, d_model
- pos: hw, bs, d_model
- refpoints_unsigmoid: nq, bs, 2/4
- valid_ratios/spatial_shapes: bs, nlevel, 2
"""
output = tgt
intermediate = []
reference_points = refpoints_unsigmoid.sigmoid()
ref_points = [reference_points]
for layer_id, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = (
reference_points[:, :, None]
* torch.cat([valid_ratios, valid_ratios], -1)[None, :]
) # nq, bs, nlevel, 4
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
query_sine_embed = gen_sineembed_for_position(
reference_points_input[:, :, 0, :]
) # nq, bs, 256*2
# conditional query
raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
pos_scale = self.query_scale(output) if self.query_scale is not None else 1
query_pos = pos_scale * raw_query_pos
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# if query_pos.isnan().any() | query_pos.isinf().any():
# import ipdb; ipdb.set_trace()
# main process
output = layer(
tgt=output,
tgt_query_pos=query_pos,
tgt_query_sine_embed=query_sine_embed,
tgt_key_padding_mask=tgt_key_padding_mask,
tgt_reference_points=reference_points_input,
memory_text=memory_text,
text_attention_mask=text_attention_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
memory_level_start_index=level_start_index,
memory_spatial_shapes=spatial_shapes,
memory_pos=pos,
self_attn_mask=tgt_mask,
cross_attn_mask=memory_mask,
)
if output.isnan().any() | output.isinf().any():
print(f"output layer_id {layer_id} is nan")
try:
num_nan = output.isnan().sum().item()
num_inf = output.isinf().sum().item()
print(f"num_nan {num_nan}, num_inf {num_inf}")
except Exception as e:
print(e)
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# import ipdb; ipdb.set_trace()
# iter update
if self.bbox_embed is not None:
# box_holder = self.bbox_embed(output)
# box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
# new_reference_points = box_holder[..., :self.query_dim].sigmoid()
reference_before_sigmoid = inverse_sigmoid(reference_points)
delta_unsig = self.bbox_embed[layer_id](output)
outputs_unsig = delta_unsig + reference_before_sigmoid
new_reference_points = outputs_unsig.sigmoid()
reference_points = new_reference_points.detach()
# if layer_id != self.num_layers - 1:
ref_points.append(new_reference_points)
intermediate.append(self.norm(output))
return [
[itm_out.transpose(0, 1) for itm_out in intermediate],
[itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
]
class DeformableTransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4,
):
super().__init__()
# self attention
self.self_attn = MSDeformAttn(
embed_dim=d_model,
num_levels=n_levels,
num_heads=n_heads,
num_points=n_points,
batch_first=True,
)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation, d_model=d_ffn)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(
self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None
):
# self attention
# import ipdb; ipdb.set_trace()
src2 = self.self_attn(
query=self.with_pos_embed(src, pos),
reference_points=reference_points,
value=src,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
key_padding_mask=key_padding_mask,
)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
return src
class DeformableTransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4,
use_text_feat_guide=False,
use_text_cross_attention=False,
):
super().__init__()
# cross attention
self.cross_attn = MSDeformAttn(
embed_dim=d_model,
num_levels=n_levels,
num_heads=n_heads,
num_points=n_points,
batch_first=True,
)
self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm1 = nn.LayerNorm(d_model)
# cross attention text
if use_text_cross_attention:
self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.catext_norm = nn.LayerNorm(d_model)
# self attention
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm2 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm3 = nn.LayerNorm(d_model)
self.key_aware_proj = None
self.use_text_feat_guide = use_text_feat_guide
assert not use_text_feat_guide
self.use_text_cross_attention = use_text_cross_attention
def rm_self_attn_modules(self):
self.self_attn = None
self.dropout2 = None
self.norm2 = None
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
with torch.cuda.amp.autocast(enabled=False):
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward(
self,
# for tgt
tgt: Optional[Tensor], # nq, bs, d_model
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
tgt_key_padding_mask: Optional[Tensor] = None,
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
memory_text: Optional[Tensor] = None, # bs, num_token, d_model
text_attention_mask: Optional[Tensor] = None, # bs, num_token
# for memory
memory: Optional[Tensor] = None, # hw, bs, d_model
memory_key_padding_mask: Optional[Tensor] = None,
memory_level_start_index: Optional[Tensor] = None, # num_levels
memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
memory_pos: Optional[Tensor] = None, # pos for memory
# sa
self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
):
"""
Input:
- tgt/tgt_query_pos: nq, bs, d_model
-
"""
assert cross_attn_mask is None
# self attention
if self.self_attn is not None:
# import ipdb; ipdb.set_trace()
q = k = self.with_pos_embed(tgt, tgt_query_pos)
tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
if self.use_text_cross_attention:
tgt2 = self.ca_text(
self.with_pos_embed(tgt, tgt_query_pos),
memory_text.transpose(0, 1),
memory_text.transpose(0, 1),
key_padding_mask=text_attention_mask,
)[0]
tgt = tgt + self.catext_dropout(tgt2)
tgt = self.catext_norm(tgt)
tgt2 = self.cross_attn(
query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
reference_points=tgt_reference_points.transpose(0, 1).contiguous(),
value=memory.transpose(0, 1),
spatial_shapes=memory_spatial_shapes,
level_start_index=memory_level_start_index,
key_padding_mask=memory_key_padding_mask,
).transpose(0, 1)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# ffn
tgt = self.forward_ffn(tgt)
return tgt
def build_transformer(args):
return Transformer(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
num_queries=args.num_queries,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
return_intermediate_dec=True,
query_dim=args.query_dim,
activation=args.transformer_activation,
num_patterns=args.num_patterns,
num_feature_levels=args.num_feature_levels,
enc_n_points=args.enc_n_points,
dec_n_points=args.dec_n_points,
learnable_tgt_init=True,
# two stage
two_stage_type=args.two_stage_type, # ['no', 'standard', 'early']
embed_init_tgt=args.embed_init_tgt,
use_text_enhancer=args.use_text_enhancer,
use_fusion_layer=args.use_fusion_layer,
use_checkpoint=args.use_checkpoint,
use_transformer_ckpt=args.use_transformer_ckpt,
use_text_cross_attention=args.use_text_cross_attention,
text_dropout=args.text_dropout,
fusion_dropout=args.fusion_dropout,
fusion_droppath=args.fusion_droppath,
)

View file

@ -0,0 +1,123 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from .utils import (
MLP,
_get_activation_fn,
_get_clones,
gen_encoder_output_proposals,
gen_sineembed_for_position,
sigmoid_focal_loss,
)
class TextTransformer(nn.Module):
def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1):
super().__init__()
self.num_layers = num_layers
self.d_model = d_model
self.nheads = nheads
self.dim_feedforward = dim_feedforward
self.norm = None
single_encoder_layer = TransformerEncoderLayer(
d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout
)
self.layers = _get_clones(single_encoder_layer, num_layers)
def forward(self, memory_text: torch.Tensor, text_attention_mask: torch.Tensor):
"""
Args:
text_attention_mask: bs, num_token
memory_text: bs, num_token, d_model
Raises:
RuntimeError: _description_
Returns:
output: bs, num_token, d_model
"""
output = memory_text.transpose(0, 1)
for layer in self.layers:
output = layer(output, src_key_padding_mask=text_attention_mask)
if self.norm is not None:
output = self.norm(output)
return output.transpose(0, 1)
class TransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self.nhead = nhead
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
# repeat attn mask
if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
# bs, num_q, num_k
src_mask = src_mask.repeat(self.nhead, 1, 1)
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
# src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src

View file

@ -0,0 +1,268 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import copy
import math
import torch
import torch.nn.functional as F
from torch import Tensor, nn
def _get_clones(module, N, layer_share=False):
# import ipdb; ipdb.set_trace()
if layer_share:
return nn.ModuleList([module for i in range(N)])
else:
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def get_sine_pos_embed(
pos_tensor: torch.Tensor,
num_pos_feats: int = 128,
temperature: int = 10000,
exchange_xy: bool = True,
):
"""generate sine position embedding from a position tensor
Args:
pos_tensor (torch.Tensor): shape: [..., n].
num_pos_feats (int): projected shape for each float in the tensor.
temperature (int): temperature in the sine/cosine function.
exchange_xy (bool, optional): exchange pos x and pos y. \
For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
Returns:
pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
"""
scale = 2 * math.pi
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
def sine_func(x: torch.Tensor):
sin_x = x * scale / dim_t
sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2)
return sin_x
pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)]
if exchange_xy:
pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
pos_res = torch.cat(pos_res, dim=-1)
return pos_res
def gen_encoder_output_proposals(
memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor, learnedwh=None
):
"""
Input:
- memory: bs, \sum{hw}, d_model
- memory_padding_mask: bs, \sum{hw}
- spatial_shapes: nlevel, 2
- learnedwh: 2
Output:
- output_memory: bs, \sum{hw}, d_model
- output_proposals: bs, \sum{hw}, 4
"""
N_, S_, C_ = memory.shape
proposals = []
_cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(N_, H_, W_, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
# import ipdb; ipdb.set_trace()
grid_y, grid_x = torch.meshgrid(
torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
if learnedwh is not None:
# import ipdb; ipdb.set_trace()
wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
else:
wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
# scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
# grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
# wh = torch.ones_like(grid) / scale
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
proposals.append(proposal)
_cur += H_ * W_
# import ipdb; ipdb.set_trace()
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(
-1, keepdim=True
)
output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
output_memory = memory
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
# output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
# output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf'))
return output_memory, output_proposals
class RandomBoxPerturber:
def __init__(
self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2
) -> None:
self.noise_scale = torch.Tensor(
[x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale]
)
def __call__(self, refanchors: Tensor) -> Tensor:
nq, bs, query_dim = refanchors.shape
device = refanchors.device
noise_raw = torch.rand_like(refanchors)
noise_scale = self.noise_scale.to(device)[:query_dim]
new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale)
return new_refanchors.clamp_(0, 1)
def sigmoid_focal_loss(
inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False
):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = -1 (no weighting).
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if no_reduction:
return loss
return loss.mean(1).sum() / num_boxes
class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
def _get_activation_fn(activation, d_model=256, batch_dim=0):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
if activation == "prelu":
return nn.PReLU()
if activation == "selu":
return F.selu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
def gen_sineembed_for_position(pos_tensor):
# n_query, bs, _ = pos_tensor.size()
# sineembed_tensor = torch.zeros(n_query, bs, 256)
scale = 2 * math.pi
dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
if pos_tensor.size(-1) == 2:
pos = torch.cat((pos_y, pos_x), dim=2)
elif pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale
pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
h_embed = pos_tensor[:, :, 3] * scale
pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
else:
raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
return pos
class ContrastiveEmbed(nn.Module):
def __init__(self, max_text_len=256):
"""
Args:
max_text_len: max length of text.
"""
super().__init__()
self.max_text_len = max_text_len
def forward(self, x, text_dict):
"""_summary_
Args:
x (_type_): _description_
text_dict (_type_): _description_
{
'encoded_text': encoded_text, # bs, 195, d_model
'text_token_mask': text_token_mask, # bs, 195
# True for used tokens. False for padding tokens
}
Returns:
_type_: _description_
"""
assert isinstance(text_dict, dict)
y = text_dict["encoded_text"]
text_token_mask = text_dict["text_token_mask"]
res = x @ y.transpose(-1, -2)
res.masked_fill_(~text_token_mask[:, None, :], float("-inf"))
# padding to max_text_len
new_res = torch.full((*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device)
new_res[..., : res.shape[-1]] = res
return new_res

View file

@ -0,0 +1,18 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .GroundingDINO import build_groundingdino
def build_model(args):
# we use register to maintain models from catdet6 on.
from .registry import MODULE_BUILD_FUNCS
assert args.modelname in MODULE_BUILD_FUNCS._module_dict
build_func = MODULE_BUILD_FUNCS.get(args.modelname)
model = build_func(args)
return model

View file

@ -0,0 +1,66 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Author: Yihao Chen
# @Date: 2021-08-16 16:03:17
# @Last Modified by: Shilong Liu
# @Last Modified time: 2022-01-23 15:26
# modified from mmcv
import inspect
from functools import partial
class Registry(object):
def __init__(self, name):
self._name = name
self._module_dict = dict()
def __repr__(self):
format_str = self.__class__.__name__ + "(name={}, items={})".format(
self._name, list(self._module_dict.keys())
)
return format_str
def __len__(self):
return len(self._module_dict)
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def get(self, key):
return self._module_dict.get(key, None)
def registe_with_name(self, module_name=None, force=False):
return partial(self.register, module_name=module_name, force=force)
def register(self, module_build_function, module_name=None, force=False):
"""Register a module build function.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not inspect.isfunction(module_build_function):
raise TypeError(
"module_build_function must be a function, but got {}".format(
type(module_build_function)
)
)
if module_name is None:
module_name = module_build_function.__name__
if not force and module_name in self._module_dict:
raise KeyError("{} is already registered in {}".format(module_name, self.name))
self._module_dict[module_name] = module_build_function
return module_build_function
MODULE_BUILD_FUNCS = Registry("model build functions")

View file

@ -0,0 +1,14 @@
torch
torchvision
transformers
modelscope
tensorflow-macos==2.9
keras==2.9.0
opencv-python
matplotlib
pycocotools
SentencePiece
tf_slim
tf_keras
pyclipper
shapely

View file

@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

View file

@ -0,0 +1,140 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
# import ipdb; ipdb.set_trace()
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / (union + 1e-6)
return iou, union
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
# except:
# import ipdb; ipdb.set_trace()
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
return iou - (area - union) / (area + 1e-6)
# modified from torchvision to also return the union
def box_iou_pairwise(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2]
rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2]
wh = (rb - lt).clamp(min=0) # [N,2]
inter = wh[:, 0] * wh[:, 1] # [N]
union = area1 + area2 - inter
iou = inter / union
return iou, union
def generalized_box_iou_pairwise(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
Input:
- boxes1, boxes2: N,4
Output:
- giou: N, 4
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
assert boxes1.shape == boxes2.shape
iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4
lt = torch.min(boxes1[:, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,2]
area = wh[:, 0] * wh[:, 1]
return iou - (area - union) / area
def masks_to_boxes(masks):
"""Compute the bounding boxes around the provided masks
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
Returns a [N, 4] tensors, with the boxes in xyxy format
"""
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device)
h, w = masks.shape[-2:]
y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x)
x_mask = masks * x.unsqueeze(0)
x_max = x_mask.flatten(1).max(-1)[0]
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
y_mask = masks * y.unsqueeze(0)
y_max = y_mask.flatten(1).max(-1)[0]
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
return torch.stack([x_min, y_min, x_max, y_max], 1)
if __name__ == "__main__":
x = torch.rand(5, 4)
y = torch.rand(3, 4)
iou, union = box_iou(x, y)
import ipdb
ipdb.set_trace()

View file

@ -0,0 +1,26 @@
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
def get_tokenlizer(text_encoder_type):
if not isinstance(text_encoder_type, str):
# print("text_encoder_type is not a str")
if hasattr(text_encoder_type, "text_encoder_type"):
text_encoder_type = text_encoder_type.text_encoder_type
elif text_encoder_type.get("text_encoder_type", False):
text_encoder_type = text_encoder_type.get("text_encoder_type")
else:
raise ValueError(
"Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
)
print("final text_encoder_type: {}".format(text_encoder_type))
tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
return tokenizer
def get_pretrained_language_model(text_encoder_type):
if text_encoder_type == "bert-base-uncased":
return BertModel.from_pretrained(text_encoder_type)
if text_encoder_type == "roberta-base":
return RobertaModel.from_pretrained(text_encoder_type)
raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))

View file

@ -0,0 +1,98 @@
from typing import Tuple, List
import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from torchvision.ops import box_convert
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.misc import clean_state_dict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import get_phrases_from_posmap
def preprocess_caption(caption: str) -> str:
result = caption.lower().strip()
if result.endswith("."):
return result
return result + "."
def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
model.eval()
return model
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image_source = Image.open(image_path).convert("RGB")
image = np.asarray(image_source)
image_transformed, _ = transform(image_source, None)
return image, image_transformed
def predict(
model,
image: torch.Tensor,
caption: str,
box_threshold: float,
text_threshold: float,
device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
caption = preprocess_caption(caption=caption)
model = model.to(device)
image = image.to(device)
with torch.no_grad():
outputs = model(image[None], captions=[caption])
prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)
mask = prediction_logits.max(dim=1)[0] > box_threshold
logits = prediction_logits[mask] # logits.shape = (n, 256)
boxes = prediction_boxes[mask] # boxes.shape = (n, 4)
tokenizer = model.tokenizer
tokenized = tokenizer(caption)
phrases = [
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
for logit
in logits
]
return boxes, logits.max(dim=1)[0], phrases
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
detections = sv.Detections(xyxy=xyxy)
labels = [
f"{phrase} {logit:.2f}"
for phrase, logit
in zip(phrases, logits)
]
box_annotator = sv.BoxAnnotator()
annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
return annotated_frame

View file

@ -0,0 +1,93 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import logging
import os
import sys
from termcolor import colored
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
self._root_name = kwargs.pop("root_name") + "."
self._abbrev_name = kwargs.pop("abbrev_name", "")
if len(self._abbrev_name):
self._abbrev_name = self._abbrev_name + "."
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
record.name = record.name.replace(self._root_name, self._abbrev_name)
log = super(_ColorfulFormatter, self).formatMessage(record)
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "red", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log
# so that calling setup_logger multiple times won't add many handlers
@functools.lru_cache()
def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None):
"""
Initialize the detectron2 logger and set its verbosity level to "INFO".
Args:
output (str): a file name or a directory to save log. If None, will not save log file.
If ends with ".txt" or ".log", assumed to be a file name.
Otherwise, logs will be saved to `output/log.txt`.
name (str): the root module name of this logger
Returns:
logging.Logger: a logger
"""
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = False
if abbrev_name is None:
abbrev_name = name
plain_formatter = logging.Formatter(
"[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S"
)
# stdout logging: master only
if distributed_rank == 0:
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S",
root_name=name,
abbrev_name=str(abbrev_name),
)
else:
formatter = plain_formatter
ch.setFormatter(formatter)
logger.addHandler(ch)
# file logging: all workers
if output is not None:
if output.endswith(".txt") or output.endswith(".log"):
filename = output
else:
filename = os.path.join(output, "log.txt")
if distributed_rank > 0:
filename = filename + f".rank{distributed_rank}"
os.makedirs(os.path.dirname(filename), exist_ok=True)
fh = logging.StreamHandler(_cached_log_stream(filename))
fh.setLevel(logging.DEBUG)
fh.setFormatter(plain_formatter)
logger.addHandler(fh)
return logger
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
return open(filename, "a")

View file

@ -0,0 +1,717 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import colorsys
import datetime
import functools
import io
import json
import os
import pickle
import subprocess
import time
from collections import OrderedDict, defaultdict, deque
from typing import List, Optional
import numpy as np
import torch
import torch.distributed as dist
# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
from torch import Tensor
__torchvision_need_compat_flag = float(torchvision.__version__.split(".")[1]) < 7
if __torchvision_need_compat_flag:
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
if d.shape[0] == 0:
return 0
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
if os.environ.get("SHILONG_AMP", None) == "1":
eps = 1e-4
else:
eps = 1e-6
return self.total / (self.count + eps)
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value,
)
@functools.lru_cache()
def _get_global_gloo_group():
"""
Return a process group based on gloo backend, containing all the ranks
The result is cached.
"""
if dist.get_backend() == "nccl":
return dist.new_group(backend="gloo")
return dist.group.WORLD
def all_gather_cpu(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
cpu_group = _get_global_gloo_group()
buffer = io.BytesIO()
torch.save(data, buffer)
data_view = buffer.getbuffer()
device = "cuda" if cpu_group is None else "cpu"
tensor = torch.ByteTensor(data_view).to(device)
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)]
if cpu_group is None:
dist.all_gather(size_list, local_size)
else:
print("gathering on cpu")
dist.all_gather(size_list, local_size, group=cpu_group)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
assert isinstance(local_size.item(), int)
local_size = int(local_size.item())
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device)
tensor = torch.cat((tensor, padding), dim=0)
if cpu_group is None:
dist.all_gather(tensor_list, tensor)
else:
dist.all_gather(tensor_list, tensor, group=cpu_group)
data_list = []
for size, tensor in zip(size_list, tensor_list):
tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
buffer = io.BytesIO(tensor.cpu().numpy())
obj = torch.load(buffer)
data_list.append(obj)
return data_list
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
if os.getenv("CPU_REDUCE") == "1":
return all_gather_cpu(data)
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
# print(name, str(meter))
# import ipdb;ipdb.set_trace()
if meter.count > 0:
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None, logger=None):
if logger is None:
print_func = print
else:
print_func = logger.info
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
# import ipdb; ipdb.set_trace()
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print_func(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print_func(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print_func(
"{} Total time: {} ({:.4f} s / it)".format(
header, total_time_str, total_time / len(iterable)
)
)
def get_sha():
cwd = os.path.dirname(os.path.abspath(__file__))
def _run(command):
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
sha = "N/A"
diff = "clean"
branch = "N/A"
try:
sha = _run(["git", "rev-parse", "HEAD"])
subprocess.check_output(["git", "diff"], cwd=cwd)
diff = _run(["git", "diff-index", "HEAD"])
diff = "has uncommited changes" if diff else "clean"
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
except Exception:
pass
message = f"sha: {sha}, status: {diff}, branch: {branch}"
return message
def collate_fn(batch):
# import ipdb; ipdb.set_trace()
batch = list(zip(*batch))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor(object):
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
if mask == "auto":
self.mask = torch.zeros_like(tensors).to(tensors.device)
if self.mask.dim() == 3:
self.mask = self.mask.sum(0).to(bool)
elif self.mask.dim() == 4:
self.mask = self.mask.sum(1).to(bool)
else:
raise ValueError(
"tensors dim must be 3 or 4 but {}({})".format(
self.tensors.dim(), self.tensors.shape
)
)
def imgsize(self):
res = []
for i in range(self.tensors.shape[0]):
mask = self.mask[i]
maxH = (~mask).sum(0).max()
maxW = (~mask).sum(1).max()
res.append(torch.Tensor([maxH, maxW]))
return res
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def to_img_list_single(self, tensor, mask):
assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
maxH = (~mask).sum(0).max()
maxW = (~mask).sum(1).max()
img = tensor[:, :maxH, :maxW]
return img
def to_img_list(self):
"""remove the padding and convert to img list
Returns:
[type]: [description]
"""
if self.tensors.dim() == 3:
return self.to_img_list_single(self.tensors, self.mask)
else:
res = []
for i in range(self.tensors.shape[0]):
tensor_i = self.tensors[i]
mask_i = self.mask[i]
res.append(self.to_img_list_single(tensor_i, mask_i))
return res
@property
def device(self):
return self.tensors.device
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
@property
def shape(self):
return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape}
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False
else:
raise ValueError("not supported")
return NestedTensor(tensor, mask)
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(
torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
).to(torch.int64)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if "WORLD_SIZE" in os.environ and os.environ["WORLD_SIZE"] != "": # 'RANK' in os.environ and
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = args.local_rank = int(os.environ["LOCAL_RANK"])
# launch by torch.distributed.launch
# Single node
# python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ...
# Multi nodes
# python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
# python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
# args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK'))
# local_world_size = int(os.environ['GPU_PER_NODE_COUNT'])
# args.world_size = args.world_size * local_world_size
# args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
# args.rank = args.rank * local_world_size + args.local_rank
print(
"world size: {}, rank: {}, local rank: {}".format(
args.world_size, args.rank, args.local_rank
)
)
print(json.dumps(dict(os.environ), indent=2))
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.local_rank = int(os.environ["SLURM_LOCALID"])
args.world_size = int(os.environ["SLURM_NPROCS"])
print(
"world size: {}, world rank: {}, local rank: {}, device_count: {}".format(
args.world_size, args.rank, args.local_rank, torch.cuda.device_count()
)
)
else:
print("Not using distributed mode")
args.distributed = False
args.world_size = 1
args.rank = 0
args.local_rank = 0
return
print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
args.distributed = True
torch.cuda.set_device(args.local_rank)
args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend,
world_size=args.world_size,
rank=args.rank,
init_method=args.dist_url,
)
print("Before torch.distributed.barrier()")
torch.distributed.barrier()
print("End torch.distributed.barrier()")
setup_for_distributed(args.rank == 0)
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
@torch.no_grad()
def accuracy_onehot(pred, gt):
"""_summary_
Args:
pred (_type_): n, c
gt (_type_): n, c
"""
tp = ((pred - gt).abs().sum(-1) < 1e-4).float().sum()
acc = tp / gt.shape[0] * 100
return acc
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if __torchvision_need_compat_flag < 0.7:
if input.numel() > 0:
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
class color_sys:
def __init__(self, num_colors) -> None:
self.num_colors = num_colors
colors = []
for i in np.arange(0.0, 360.0, 360.0 / num_colors):
hue = i / 360.0
lightness = (50 + np.random.rand() * 10) / 100.0
saturation = (90 + np.random.rand() * 10) / 100.0
colors.append(
tuple([int(j * 255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)])
)
self.colors = colors
def __call__(self, idx):
return self.colors[idx]
def inverse_sigmoid(x, eps=1e-3):
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2)
def clean_state_dict(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k[:7] == "module.":
k = k[7:] # remove `module.`
new_state_dict[k] = v
return new_state_dict

View file

@ -0,0 +1,427 @@
# ==========================================================
# Modified from mmcv
# ==========================================================
import ast
import os.path as osp
import shutil
import sys
import tempfile
from argparse import Action
from importlib import import_module
import platform
from addict import Dict
from yapf.yapflib.yapf_api import FormatCode
BASE_KEY = "_base_"
DELETE_KEY = "_delete_"
RESERVED_KEYS = ["filename", "text", "pretty_text", "get", "dump", "merge_from_dict"]
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
if not osp.isfile(filename):
raise FileNotFoundError(msg_tmpl.format(filename))
class ConfigDict(Dict):
def __missing__(self, name):
raise KeyError(name)
def __getattr__(self, name):
try:
value = super(ConfigDict, self).__getattr__(name)
except KeyError:
ex = AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{name}'")
except Exception as e:
ex = e
else:
return value
raise ex
class SLConfig(object):
"""
config files.
only support .py file as config now.
ref: mmcv.utils.config
Example:
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
>>> cfg.a
1
>>> cfg.b
{'b1': [0, 1]}
>>> cfg.b.b1
[0, 1]
>>> cfg = Config.fromfile('tests/data/config/a.py')
>>> cfg.filename
"/home/kchen/projects/mmcv/tests/data/config/a.py"
>>> cfg.item4
'test'
>>> cfg
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
"""
@staticmethod
def _validate_py_syntax(filename):
with open(filename) as f:
content = f.read()
try:
ast.parse(content)
except SyntaxError:
raise SyntaxError("There are syntax errors in config " f"file {filename}")
@staticmethod
def _file2dict(filename):
filename = osp.abspath(osp.expanduser(filename))
check_file_exist(filename)
if filename.lower().endswith(".py"):
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py")
temp_config_name = osp.basename(temp_config_file.name)
if platform.system() == 'Windows':
temp_config_file.close()
shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name))
temp_module_name = osp.splitext(temp_config_name)[0]
sys.path.insert(0, temp_config_dir)
SLConfig._validate_py_syntax(filename)
mod = import_module(temp_module_name)
sys.path.pop(0)
cfg_dict = {
name: value for name, value in mod.__dict__.items() if not name.startswith("__")
}
# delete imported module
del sys.modules[temp_module_name]
# close temp file
temp_config_file.close()
elif filename.lower().endswith((".yml", ".yaml", ".json")):
from .slio import slload
cfg_dict = slload(filename)
else:
raise IOError("Only py/yml/yaml/json type are supported now!")
cfg_text = filename + "\n"
with open(filename, "r") as f:
cfg_text += f.read()
# parse the base file
if BASE_KEY in cfg_dict:
cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop(BASE_KEY)
base_filename = base_filename if isinstance(base_filename, list) else [base_filename]
cfg_dict_list = list()
cfg_text_list = list()
for f in base_filename:
_cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f))
cfg_dict_list.append(_cfg_dict)
cfg_text_list.append(_cfg_text)
base_cfg_dict = dict()
for c in cfg_dict_list:
if len(base_cfg_dict.keys() & c.keys()) > 0:
raise KeyError("Duplicate key is not allowed among bases")
# TODO Allow the duplicate key while warnning user
base_cfg_dict.update(c)
base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict
# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = "\n".join(cfg_text_list)
return cfg_dict, cfg_text
@staticmethod
def _merge_a_into_b(a, b):
"""merge dict `a` into dict `b` (non-inplace).
values in `a` will overwrite `b`.
copy first to avoid inplace modification
Args:
a ([type]): [description]
b ([type]): [description]
Returns:
[dict]: [description]
"""
# import ipdb; ipdb.set_trace()
if not isinstance(a, dict):
return a
b = b.copy()
for k, v in a.items():
if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
if not isinstance(b[k], dict) and not isinstance(b[k], list):
# if :
# import ipdb; ipdb.set_trace()
raise TypeError(
f"{k}={v} in child config cannot inherit from base "
f"because {k} is a dict in the child config but is of "
f"type {type(b[k])} in base config. You may set "
f"`{DELETE_KEY}=True` to ignore the base config"
)
b[k] = SLConfig._merge_a_into_b(v, b[k])
elif isinstance(b, list):
try:
_ = int(k)
except:
raise TypeError(
f"b is a list, " f"index {k} should be an int when input but {type(k)}"
)
b[int(k)] = SLConfig._merge_a_into_b(v, b[int(k)])
else:
b[k] = v
return b
@staticmethod
def fromfile(filename):
cfg_dict, cfg_text = SLConfig._file2dict(filename)
return SLConfig(cfg_dict, cfg_text=cfg_text, filename=filename)
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
if cfg_dict is None:
cfg_dict = dict()
elif not isinstance(cfg_dict, dict):
raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
for key in cfg_dict:
if key in RESERVED_KEYS:
raise KeyError(f"{key} is reserved for config file")
super(SLConfig, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
super(SLConfig, self).__setattr__("_filename", filename)
if cfg_text:
text = cfg_text
elif filename:
with open(filename, "r") as f:
text = f.read()
else:
text = ""
super(SLConfig, self).__setattr__("_text", text)
@property
def filename(self):
return self._filename
@property
def text(self):
return self._text
@property
def pretty_text(self):
indent = 4
def _indent(s_, num_spaces):
s = s_.split("\n")
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
def _format_basic_types(k, v, use_mapping=False):
if isinstance(v, str):
v_str = f"'{v}'"
else:
v_str = str(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f"{k_str}: {v_str}"
else:
attr_str = f"{str(k)}={v_str}"
attr_str = _indent(attr_str, indent)
return attr_str
def _format_list(k, v, use_mapping=False):
# check if all items in the list are dict
if all(isinstance(_, dict) for _ in v):
v_str = "[\n"
v_str += "\n".join(
f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
).rstrip(",")
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f"{k_str}: {v_str}"
else:
attr_str = f"{str(k)}={v_str}"
attr_str = _indent(attr_str, indent) + "]"
else:
attr_str = _format_basic_types(k, v, use_mapping)
return attr_str
def _contain_invalid_identifier(dict_str):
contain_invalid_identifier = False
for key_name in dict_str:
contain_invalid_identifier |= not str(key_name).isidentifier()
return contain_invalid_identifier
def _format_dict(input_dict, outest_level=False):
r = ""
s = []
use_mapping = _contain_invalid_identifier(input_dict)
if use_mapping:
r += "{"
for idx, (k, v) in enumerate(input_dict.items()):
is_last = idx >= len(input_dict) - 1
end = "" if outest_level or is_last else ","
if isinstance(v, dict):
v_str = "\n" + _format_dict(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f"{k_str}: dict({v_str}"
else:
attr_str = f"{str(k)}=dict({v_str}"
attr_str = _indent(attr_str, indent) + ")" + end
elif isinstance(v, list):
attr_str = _format_list(k, v, use_mapping) + end
else:
attr_str = _format_basic_types(k, v, use_mapping) + end
s.append(attr_str)
r += "\n".join(s)
if use_mapping:
r += "}"
return r
cfg_dict = self._cfg_dict.to_dict()
text = _format_dict(cfg_dict, outest_level=True)
# copied from setup.cfg
yapf_style = dict(
based_on_style="pep8",
blank_line_before_nested_class_or_def=True,
split_before_expression_after_opening_paren=True,
)
text, _ = FormatCode(text, style_config=yapf_style, verify=True)
return text
def __repr__(self):
return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"
def __len__(self):
return len(self._cfg_dict)
def __getattr__(self, name):
# # debug
# print('+'*15)
# print('name=%s' % name)
# print("addr:", id(self))
# # print('type(self):', type(self))
# print(self.__dict__)
# print('+'*15)
# if self.__dict__ == {}:
# raise ValueError
return getattr(self._cfg_dict, name)
def __getitem__(self, name):
return self._cfg_dict.__getitem__(name)
def __setattr__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setattr__(name, value)
def __setitem__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setitem__(name, value)
def __iter__(self):
return iter(self._cfg_dict)
def dump(self, file=None):
# import ipdb; ipdb.set_trace()
if file is None:
return self.pretty_text
else:
with open(file, "w") as f:
f.write(self.pretty_text)
def merge_from_dict(self, options):
"""Merge list into cfg_dict
Merge the dict parsed by MultipleKVAction into this cfg.
Examples:
>>> options = {'model.backbone.depth': 50,
... 'model.backbone.with_cp':True}
>>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
>>> cfg.merge_from_dict(options)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(
... model=dict(backbone=dict(depth=50, with_cp=True)))
Args:
options (dict): dict of configs to merge from.
"""
option_cfg_dict = {}
for full_key, v in options.items():
d = option_cfg_dict
key_list = full_key.split(".")
for subkey in key_list[:-1]:
d.setdefault(subkey, ConfigDict())
d = d[subkey]
subkey = key_list[-1]
d[subkey] = v
cfg_dict = super(SLConfig, self).__getattribute__("_cfg_dict")
super(SLConfig, self).__setattr__(
"_cfg_dict", SLConfig._merge_a_into_b(option_cfg_dict, cfg_dict)
)
# for multiprocess
def __setstate__(self, state):
self.__init__(state)
def copy(self):
return SLConfig(self._cfg_dict.copy())
def deepcopy(self):
return SLConfig(self._cfg_dict.deepcopy())
class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options should
be passed as comma separated values, i.e KEY=V1,V2,V3
"""
@staticmethod
def _parse_int_float_bool(val):
try:
return int(val)
except ValueError:
pass
try:
return float(val)
except ValueError:
pass
if val.lower() in ["true", "false"]:
return True if val.lower() == "true" else False
if val.lower() in ["none", "null"]:
return None
return val
def __call__(self, parser, namespace, values, option_string=None):
options = {}
for kv in values:
key, val = kv.split("=", maxsplit=1)
val = [self._parse_int_float_bool(v) for v in val.split(",")]
if len(val) == 1:
val = val[0]
options[key] = val
setattr(namespace, self.dest, options)

View file

@ -0,0 +1,177 @@
# ==========================================================
# Modified from mmcv
# ==========================================================
import json
import pickle
from abc import ABCMeta, abstractmethod
from pathlib import Path
import yaml
try:
from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
from yaml import Loader, Dumper
# ===========================
# Rigister handler
# ===========================
class BaseFileHandler(metaclass=ABCMeta):
@abstractmethod
def load_from_fileobj(self, file, **kwargs):
pass
@abstractmethod
def dump_to_fileobj(self, obj, file, **kwargs):
pass
@abstractmethod
def dump_to_str(self, obj, **kwargs):
pass
def load_from_path(self, filepath, mode="r", **kwargs):
with open(filepath, mode) as f:
return self.load_from_fileobj(f, **kwargs)
def dump_to_path(self, obj, filepath, mode="w", **kwargs):
with open(filepath, mode) as f:
self.dump_to_fileobj(obj, f, **kwargs)
class JsonHandler(BaseFileHandler):
def load_from_fileobj(self, file):
return json.load(file)
def dump_to_fileobj(self, obj, file, **kwargs):
json.dump(obj, file, **kwargs)
def dump_to_str(self, obj, **kwargs):
return json.dumps(obj, **kwargs)
class PickleHandler(BaseFileHandler):
def load_from_fileobj(self, file, **kwargs):
return pickle.load(file, **kwargs)
def load_from_path(self, filepath, **kwargs):
return super(PickleHandler, self).load_from_path(filepath, mode="rb", **kwargs)
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault("protocol", 2)
return pickle.dumps(obj, **kwargs)
def dump_to_fileobj(self, obj, file, **kwargs):
kwargs.setdefault("protocol", 2)
pickle.dump(obj, file, **kwargs)
def dump_to_path(self, obj, filepath, **kwargs):
super(PickleHandler, self).dump_to_path(obj, filepath, mode="wb", **kwargs)
class YamlHandler(BaseFileHandler):
def load_from_fileobj(self, file, **kwargs):
kwargs.setdefault("Loader", Loader)
return yaml.load(file, **kwargs)
def dump_to_fileobj(self, obj, file, **kwargs):
kwargs.setdefault("Dumper", Dumper)
yaml.dump(obj, file, **kwargs)
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault("Dumper", Dumper)
return yaml.dump(obj, **kwargs)
file_handlers = {
"json": JsonHandler(),
"yaml": YamlHandler(),
"yml": YamlHandler(),
"pickle": PickleHandler(),
"pkl": PickleHandler(),
}
# ===========================
# load and dump
# ===========================
def is_str(x):
"""Whether the input is an string instance.
Note: This method is deprecated since python 2 is no longer supported.
"""
return isinstance(x, str)
def slload(file, file_format=None, **kwargs):
"""Load data from json/yaml/pickle files.
This method provides a unified api for loading data from serialized files.
Args:
file (str or :obj:`Path` or file-like object): Filename or a file-like
object.
file_format (str, optional): If not specified, the file format will be
inferred from the file extension, otherwise use the specified one.
Currently supported formats include "json", "yaml/yml" and
"pickle/pkl".
Returns:
The content from the file.
"""
if isinstance(file, Path):
file = str(file)
if file_format is None and is_str(file):
file_format = file.split(".")[-1]
if file_format not in file_handlers:
raise TypeError(f"Unsupported format: {file_format}")
handler = file_handlers[file_format]
if is_str(file):
obj = handler.load_from_path(file, **kwargs)
elif hasattr(file, "read"):
obj = handler.load_from_fileobj(file, **kwargs)
else:
raise TypeError('"file" must be a filepath str or a file-object')
return obj
def sldump(obj, file=None, file_format=None, **kwargs):
"""Dump data to json/yaml/pickle strings or files.
This method provides a unified api for dumping data as strings or to files,
and also supports custom arguments for each file format.
Args:
obj (any): The python object to be dumped.
file (str or :obj:`Path` or file-like object, optional): If not
specified, then the object is dump to a str, otherwise to a file
specified by the filename or file-like object.
file_format (str, optional): Same as :func:`load`.
Returns:
bool: True for success, False otherwise.
"""
if isinstance(file, Path):
file = str(file)
if file_format is None:
if is_str(file):
file_format = file.split(".")[-1]
elif file is None:
raise ValueError("file_format must be specified since file is None")
if file_format not in file_handlers:
raise TypeError(f"Unsupported format: {file_format}")
handler = file_handlers[file_format]
if file is None:
return handler.dump_to_str(obj, **kwargs)
elif is_str(file):
handler.dump_to_path(obj, file, **kwargs)
elif hasattr(file, "write"):
handler.dump_to_fileobj(obj, file, **kwargs)
else:
raise TypeError('"file" must be a filename str or a file-object')

View file

@ -0,0 +1,62 @@
import json
import time
class TimeCounter:
def __init__(self) -> None:
pass
def clear(self):
self.timedict = {}
self.basetime = time.perf_counter()
def timeit(self, name):
nowtime = time.perf_counter() - self.basetime
self.timedict[name] = nowtime
self.basetime = time.perf_counter()
class TimeHolder:
def __init__(self) -> None:
self.timedict = {}
def update(self, _timedict: dict):
for k, v in _timedict.items():
if k not in self.timedict:
self.timedict[k] = AverageMeter(name=k, val_only=True)
self.timedict[k].update(val=v)
def final_res(self):
return {k: v.avg for k, v in self.timedict.items()}
def __str__(self):
return json.dumps(self.final_res(), indent=2)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=":f", val_only=False):
self.name = name
self.fmt = fmt
self.val_only = val_only
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
if self.val_only:
fmtstr = "{name} {val" + self.fmt + "}"
else:
fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
return fmtstr.format(**self.__dict__)

View file

@ -0,0 +1,608 @@
import argparse
import json
import warnings
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Dict, List
import numpy as np
import torch
from transformers import AutoTokenizer
from groundingdino.util.slconfig import SLConfig
def slprint(x, name="x"):
if isinstance(x, (torch.Tensor, np.ndarray)):
print(f"{name}.shape:", x.shape)
elif isinstance(x, (tuple, list)):
print("type x:", type(x))
for i in range(min(10, len(x))):
slprint(x[i], f"{name}[{i}]")
elif isinstance(x, dict):
for k, v in x.items():
slprint(v, f"{name}[{k}]")
else:
print(f"{name}.type:", type(x))
def clean_state_dict(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k[:7] == "module.":
k = k[7:] # remove `module.`
new_state_dict[k] = v
return new_state_dict
def renorm(
img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
) -> torch.FloatTensor:
# img: tensor(3,H,W) or tensor(B,3,H,W)
# return: same as img
assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
if img.dim() == 3:
assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % (
img.size(0),
str(img.size()),
)
img_perm = img.permute(1, 2, 0)
mean = torch.Tensor(mean)
std = torch.Tensor(std)
img_res = img_perm * std + mean
return img_res.permute(2, 0, 1)
else: # img.dim() == 4
assert img.size(1) == 3, 'img.size(1) shoule be 3 but "%d". (%s)' % (
img.size(1),
str(img.size()),
)
img_perm = img.permute(0, 2, 3, 1)
mean = torch.Tensor(mean)
std = torch.Tensor(std)
img_res = img_perm * std + mean
return img_res.permute(0, 3, 1, 2)
class CocoClassMapper:
def __init__(self) -> None:
self.category_map_str = {
"1": 1,
"2": 2,
"3": 3,
"4": 4,
"5": 5,
"6": 6,
"7": 7,
"8": 8,
"9": 9,
"10": 10,
"11": 11,
"13": 12,
"14": 13,
"15": 14,
"16": 15,
"17": 16,
"18": 17,
"19": 18,
"20": 19,
"21": 20,
"22": 21,
"23": 22,
"24": 23,
"25": 24,
"27": 25,
"28": 26,
"31": 27,
"32": 28,
"33": 29,
"34": 30,
"35": 31,
"36": 32,
"37": 33,
"38": 34,
"39": 35,
"40": 36,
"41": 37,
"42": 38,
"43": 39,
"44": 40,
"46": 41,
"47": 42,
"48": 43,
"49": 44,
"50": 45,
"51": 46,
"52": 47,
"53": 48,
"54": 49,
"55": 50,
"56": 51,
"57": 52,
"58": 53,
"59": 54,
"60": 55,
"61": 56,
"62": 57,
"63": 58,
"64": 59,
"65": 60,
"67": 61,
"70": 62,
"72": 63,
"73": 64,
"74": 65,
"75": 66,
"76": 67,
"77": 68,
"78": 69,
"79": 70,
"80": 71,
"81": 72,
"82": 73,
"84": 74,
"85": 75,
"86": 76,
"87": 77,
"88": 78,
"89": 79,
"90": 80,
}
self.origin2compact_mapper = {int(k): v - 1 for k, v in self.category_map_str.items()}
self.compact2origin_mapper = {int(v - 1): int(k) for k, v in self.category_map_str.items()}
def origin2compact(self, idx):
return self.origin2compact_mapper[int(idx)]
def compact2origin(self, idx):
return self.compact2origin_mapper[int(idx)]
def to_device(item, device):
if isinstance(item, torch.Tensor):
return item.to(device)
elif isinstance(item, list):
return [to_device(i, device) for i in item]
elif isinstance(item, dict):
return {k: to_device(v, device) for k, v in item.items()}
else:
raise NotImplementedError(
"Call Shilong if you use other containers! type: {}".format(type(item))
)
#
def get_gaussian_mean(x, axis, other_axis, softmax=True):
"""
Args:
x (float): Input images(BxCxHxW)
axis (int): The index for weighted mean
other_axis (int): The other index
Returns: weighted index for axis, BxC
"""
mat2line = torch.sum(x, axis=other_axis)
# mat2line = mat2line / mat2line.mean() * 10
if softmax:
u = torch.softmax(mat2line, axis=2)
else:
u = mat2line / (mat2line.sum(2, keepdim=True) + 1e-6)
size = x.shape[axis]
ind = torch.linspace(0, 1, size).to(x.device)
batch = x.shape[0]
channel = x.shape[1]
index = ind.repeat([batch, channel, 1])
mean_position = torch.sum(index * u, dim=2)
return mean_position
def get_expected_points_from_map(hm, softmax=True):
"""get_gaussian_map_from_points
B,C,H,W -> B,N,2 float(0, 1) float(0, 1)
softargmax function
Args:
hm (float): Input images(BxCxHxW)
Returns:
weighted index for axis, BxCx2. float between 0 and 1.
"""
# hm = 10*hm
B, C, H, W = hm.shape
y_mean = get_gaussian_mean(hm, 2, 3, softmax=softmax) # B,C
x_mean = get_gaussian_mean(hm, 3, 2, softmax=softmax) # B,C
# return torch.cat((x_mean.unsqueeze(-1), y_mean.unsqueeze(-1)), 2)
return torch.stack([x_mean, y_mean], dim=2)
# Positional encoding (section 5.1)
# borrow from nerf
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs["input_dims"]
out_dim = 0
if self.kwargs["include_input"]:
embed_fns.append(lambda x: x)
out_dim += d
max_freq = self.kwargs["max_freq_log2"]
N_freqs = self.kwargs["num_freqs"]
if self.kwargs["log_sampling"]:
freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
else:
freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs["periodic_fns"]:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, i=0):
import torch.nn as nn
if i == -1:
return nn.Identity(), 3
embed_kwargs = {
"include_input": True,
"input_dims": 3,
"max_freq_log2": multires - 1,
"num_freqs": multires,
"log_sampling": True,
"periodic_fns": [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj: eo.embed(x)
return embed, embedder_obj.out_dim
class APOPMeter:
def __init__(self) -> None:
self.tp = 0
self.fp = 0
self.tn = 0
self.fn = 0
def update(self, pred, gt):
"""
Input:
pred, gt: Tensor()
"""
assert pred.shape == gt.shape
self.tp += torch.logical_and(pred == 1, gt == 1).sum().item()
self.fp += torch.logical_and(pred == 1, gt == 0).sum().item()
self.tn += torch.logical_and(pred == 0, gt == 0).sum().item()
self.tn += torch.logical_and(pred == 1, gt == 0).sum().item()
def update_cm(self, tp, fp, tn, fn):
self.tp += tp
self.fp += fp
self.tn += tn
self.tn += fn
def inverse_sigmoid(x, eps=1e-5):
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2)
def get_raw_dict(args):
"""
return the dicf contained in args.
e.g:
>>> with open(path, 'w') as f:
json.dump(get_raw_dict(args), f, indent=2)
"""
if isinstance(args, argparse.Namespace):
return vars(args)
elif isinstance(args, dict):
return args
elif isinstance(args, SLConfig):
return args._cfg_dict
else:
raise NotImplementedError("Unknown type {}".format(type(args)))
def stat_tensors(tensor):
assert tensor.dim() == 1
tensor_sm = tensor.softmax(0)
entropy = (tensor_sm * torch.log(tensor_sm + 1e-9)).sum()
return {
"max": tensor.max(),
"min": tensor.min(),
"mean": tensor.mean(),
"var": tensor.var(),
"std": tensor.var() ** 0.5,
"entropy": entropy,
}
class NiceRepr:
"""Inherit from this class and define ``__nice__`` to "nicely" print your
objects.
Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function
Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
If the inheriting class has a ``__len__``, method then the default
``__nice__`` method will return its length.
Example:
>>> class Foo(NiceRepr):
... def __nice__(self):
... return 'info'
>>> foo = Foo()
>>> assert str(foo) == '<Foo(info)>'
>>> assert repr(foo).startswith('<Foo(info) at ')
Example:
>>> class Bar(NiceRepr):
... pass
>>> bar = Bar()
>>> import pytest
>>> with pytest.warns(None) as record:
>>> assert 'object at' in str(bar)
>>> assert 'object at' in repr(bar)
Example:
>>> class Baz(NiceRepr):
... def __len__(self):
... return 5
>>> baz = Baz()
>>> assert str(baz) == '<Baz(5)>'
"""
def __nice__(self):
"""str: a "nice" summary string describing this module"""
if hasattr(self, "__len__"):
# It is a common pattern for objects to use __len__ in __nice__
# As a convenience we define a default __nice__ for these objects
return str(len(self))
else:
# In all other cases force the subclass to overload __nice__
raise NotImplementedError(f"Define the __nice__ method for {self.__class__!r}")
def __repr__(self):
"""str: the string of the module"""
try:
nice = self.__nice__()
classname = self.__class__.__name__
return f"<{classname}({nice}) at {hex(id(self))}>"
except NotImplementedError as ex:
warnings.warn(str(ex), category=RuntimeWarning)
return object.__repr__(self)
def __str__(self):
"""str: the string of the module"""
try:
classname = self.__class__.__name__
nice = self.__nice__()
return f"<{classname}({nice})>"
except NotImplementedError as ex:
warnings.warn(str(ex), category=RuntimeWarning)
return object.__repr__(self)
def ensure_rng(rng=None):
"""Coerces input into a random number generator.
If the input is None, then a global random state is returned.
If the input is a numeric value, then that is used as a seed to construct a
random state. Otherwise the input is returned as-is.
Adapted from [1]_.
Args:
rng (int | numpy.random.RandomState | None):
if None, then defaults to the global rng. Otherwise this can be an
integer or a RandomState class
Returns:
(numpy.random.RandomState) : rng -
a numpy random number generator
References:
.. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501
"""
if rng is None:
rng = np.random.mtrand._rand
elif isinstance(rng, int):
rng = np.random.RandomState(rng)
else:
rng = rng
return rng
def random_boxes(num=1, scale=1, rng=None):
"""Simple version of ``kwimage.Boxes.random``
Returns:
Tensor: shape (n, 4) in x1, y1, x2, y2 format.
References:
https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
Example:
>>> num = 3
>>> scale = 512
>>> rng = 0
>>> boxes = random_boxes(num, scale, rng)
>>> print(boxes)
tensor([[280.9925, 278.9802, 308.6148, 366.1769],
[216.9113, 330.6978, 224.0446, 456.5878],
[405.3632, 196.3221, 493.3953, 270.7942]])
"""
rng = ensure_rng(rng)
tlbr = rng.rand(num, 4).astype(np.float32)
tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
tlbr[:, 0] = tl_x * scale
tlbr[:, 1] = tl_y * scale
tlbr[:, 2] = br_x * scale
tlbr[:, 3] = br_y * scale
boxes = torch.from_numpy(tlbr)
return boxes
class ModelEma(torch.nn.Module):
def __init__(self, model, decay=0.9997, device=None):
super(ModelEma, self).__init__()
# make a copy of the model for accumulating moving average of weights
self.module = deepcopy(model)
self.module.eval()
# import ipdb; ipdb.set_trace()
self.decay = decay
self.device = device # perform ema on different device from model if set
if self.device is not None:
self.module.to(device=device)
def _update(self, model, update_fn):
with torch.no_grad():
for ema_v, model_v in zip(
self.module.state_dict().values(), model.state_dict().values()
):
if self.device is not None:
model_v = model_v.to(device=self.device)
ema_v.copy_(update_fn(ema_v, model_v))
def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m)
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
class BestMetricSingle:
def __init__(self, init_res=0.0, better="large") -> None:
self.init_res = init_res
self.best_res = init_res
self.best_ep = -1
self.better = better
assert better in ["large", "small"]
def isbetter(self, new_res, old_res):
if self.better == "large":
return new_res > old_res
if self.better == "small":
return new_res < old_res
def update(self, new_res, ep):
if self.isbetter(new_res, self.best_res):
self.best_res = new_res
self.best_ep = ep
return True
return False
def __str__(self) -> str:
return "best_res: {}\t best_ep: {}".format(self.best_res, self.best_ep)
def __repr__(self) -> str:
return self.__str__()
def summary(self) -> dict:
return {
"best_res": self.best_res,
"best_ep": self.best_ep,
}
class BestMetricHolder:
def __init__(self, init_res=0.0, better="large", use_ema=False) -> None:
self.best_all = BestMetricSingle(init_res, better)
self.use_ema = use_ema
if use_ema:
self.best_ema = BestMetricSingle(init_res, better)
self.best_regular = BestMetricSingle(init_res, better)
def update(self, new_res, epoch, is_ema=False):
"""
return if the results is the best.
"""
if not self.use_ema:
return self.best_all.update(new_res, epoch)
else:
if is_ema:
self.best_ema.update(new_res, epoch)
return self.best_all.update(new_res, epoch)
else:
self.best_regular.update(new_res, epoch)
return self.best_all.update(new_res, epoch)
def summary(self):
if not self.use_ema:
return self.best_all.summary()
res = {}
res.update({f"all_{k}": v for k, v in self.best_all.summary().items()})
res.update({f"regular_{k}": v for k, v in self.best_regular.summary().items()})
res.update({f"ema_{k}": v for k, v in self.best_ema.summary().items()})
return res
def __repr__(self) -> str:
return json.dumps(self.summary(), indent=2)
def __str__(self) -> str:
return self.__repr__()
def targets_to(targets: List[Dict[str, Any]], device):
"""Moves the target dicts to the given device."""
excluded_keys = [
"questionId",
"tokens_positive",
"strings_positive",
"tokens",
"dataset_name",
"sentence_id",
"original_img_id",
"nb_eval",
"task_id",
"original_id",
"token_span",
"caption",
"dataset_type",
]
return [
{k: v.to(device) if k not in excluded_keys else v for k, v in t.items()} for t in targets
]
def get_phrases_from_posmap(
posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer
):
assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
if posmap.dim() == 1:
non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
return tokenizer.decode(token_ids)
else:
raise NotImplementedError("posmap must be 1-dim")

View file

@ -0,0 +1,318 @@
# -*- coding: utf-8 -*-
"""
@File : visualizer.py
@Time : 2022/04/05 11:39:33
@Author : Shilong Liu
@Contact : slongliu86@gmail.com
"""
import datetime
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import transforms
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
from pycocotools import mask as maskUtils
def renorm(
img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
) -> torch.FloatTensor:
# img: tensor(3,H,W) or tensor(B,3,H,W)
# return: same as img
assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
if img.dim() == 3:
assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % (
img.size(0),
str(img.size()),
)
img_perm = img.permute(1, 2, 0)
mean = torch.Tensor(mean)
std = torch.Tensor(std)
img_res = img_perm * std + mean
return img_res.permute(2, 0, 1)
else: # img.dim() == 4
assert img.size(1) == 3, 'img.size(1) shoule be 3 but "%d". (%s)' % (
img.size(1),
str(img.size()),
)
img_perm = img.permute(0, 2, 3, 1)
mean = torch.Tensor(mean)
std = torch.Tensor(std)
img_res = img_perm * std + mean
return img_res.permute(0, 3, 1, 2)
class ColorMap:
def __init__(self, basergb=[255, 255, 0]):
self.basergb = np.array(basergb)
def __call__(self, attnmap):
# attnmap: h, w. np.uint8.
# return: h, w, 4. np.uint8.
assert attnmap.dtype == np.uint8
h, w = attnmap.shape
res = self.basergb.copy()
res = res[None][None].repeat(h, 0).repeat(w, 1) # h, w, 3
attn1 = attnmap.copy()[..., None] # h, w, 1
res = np.concatenate((res, attn1), axis=-1).astype(np.uint8)
return res
def rainbow_text(x, y, ls, lc, **kw):
"""
Take a list of strings ``ls`` and colors ``lc`` and place them next to each
other, with text ls[i] being shown in color lc[i].
This example shows how to do both vertical and horizontal text, and will
pass all keyword arguments to plt.text, so you can set the font size,
family, etc.
"""
t = plt.gca().transData
fig = plt.gcf()
plt.show()
# horizontal version
for s, c in zip(ls, lc):
text = plt.text(x, y, " " + s + " ", color=c, transform=t, **kw)
text.draw(fig.canvas.get_renderer())
ex = text.get_window_extent()
t = transforms.offset_copy(text._transform, x=ex.width, units="dots")
# #vertical version
# for s,c in zip(ls,lc):
# text = plt.text(x,y," "+s+" ",color=c, transform=t,
# rotation=90,va='bottom',ha='center',**kw)
# text.draw(fig.canvas.get_renderer())
# ex = text.get_window_extent()
# t = transforms.offset_copy(text._transform, y=ex.height, units='dots')
class COCOVisualizer:
def __init__(self, coco=None, tokenlizer=None) -> None:
self.coco = coco
def visualize(self, img, tgt, caption=None, dpi=180, savedir="vis"):
"""
img: tensor(3, H, W)
tgt: make sure they are all on cpu.
must have items: 'image_id', 'boxes', 'size'
"""
plt.figure(dpi=dpi)
plt.rcParams["font.size"] = "5"
ax = plt.gca()
img = renorm(img).permute(1, 2, 0)
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
ax.imshow(img)
self.addtgt(tgt)
if tgt is None:
image_id = 0
elif "image_id" not in tgt:
image_id = 0
else:
image_id = tgt["image_id"]
if caption is None:
savename = "{}/{}-{}.png".format(
savedir, int(image_id), str(datetime.datetime.now()).replace(" ", "-")
)
else:
savename = "{}/{}-{}-{}.png".format(
savedir, caption, int(image_id), str(datetime.datetime.now()).replace(" ", "-")
)
print("savename: {}".format(savename))
os.makedirs(os.path.dirname(savename), exist_ok=True)
plt.savefig(savename)
plt.close()
def addtgt(self, tgt):
""" """
if tgt is None or not "boxes" in tgt:
ax = plt.gca()
if "caption" in tgt:
ax.set_title(tgt["caption"], wrap=True)
ax.set_axis_off()
return
ax = plt.gca()
H, W = tgt["size"]
numbox = tgt["boxes"].shape[0]
color = []
polygons = []
boxes = []
for box in tgt["boxes"].cpu():
unnormbbox = box * torch.Tensor([W, H, W, H])
unnormbbox[:2] -= unnormbbox[2:] / 2
[bbox_x, bbox_y, bbox_w, bbox_h] = unnormbbox.tolist()
boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
poly = [
[bbox_x, bbox_y],
[bbox_x, bbox_y + bbox_h],
[bbox_x + bbox_w, bbox_y + bbox_h],
[bbox_x + bbox_w, bbox_y],
]
np_poly = np.array(poly).reshape((4, 2))
polygons.append(Polygon(np_poly))
c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
color.append(c)
p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1)
ax.add_collection(p)
p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
ax.add_collection(p)
if "strings_positive" in tgt and len(tgt["strings_positive"]) > 0:
assert (
len(tgt["strings_positive"]) == numbox
), f"{len(tgt['strings_positive'])} = {numbox}, "
for idx, strlist in enumerate(tgt["strings_positive"]):
cate_id = int(tgt["labels"][idx])
_string = str(cate_id) + ":" + " ".join(strlist)
bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx]
# ax.text(bbox_x, bbox_y, _string, color='black', bbox={'facecolor': 'yellow', 'alpha': 1.0, 'pad': 1})
ax.text(
bbox_x,
bbox_y,
_string,
color="black",
bbox={"facecolor": color[idx], "alpha": 0.6, "pad": 1},
)
if "box_label" in tgt:
assert len(tgt["box_label"]) == numbox, f"{len(tgt['box_label'])} = {numbox}, "
for idx, bl in enumerate(tgt["box_label"]):
_string = str(bl)
bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx]
# ax.text(bbox_x, bbox_y, _string, color='black', bbox={'facecolor': 'yellow', 'alpha': 1.0, 'pad': 1})
ax.text(
bbox_x,
bbox_y,
_string,
color="black",
bbox={"facecolor": color[idx], "alpha": 0.6, "pad": 1},
)
if "caption" in tgt:
ax.set_title(tgt["caption"], wrap=True)
# plt.figure()
# rainbow_text(0.0,0.0,"all unicorns poop rainbows ! ! !".split(),
# ['red', 'orange', 'brown', 'green', 'blue', 'purple', 'black'])
if "attn" in tgt:
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
if isinstance(tgt["attn"], tuple):
tgt["attn"] = [tgt["attn"]]
for item in tgt["attn"]:
attn_map, basergb = item
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3)
attn_map = (attn_map * 255).astype(np.uint8)
cm = ColorMap(basergb)
heatmap = cm(attn_map)
ax.imshow(heatmap)
ax.set_axis_off()
def showAnns(self, anns, draw_bbox=False):
"""
Display the specified annotations.
:param anns (array of object): annotations to display
:return: None
"""
if len(anns) == 0:
return 0
if "segmentation" in anns[0] or "keypoints" in anns[0]:
datasetType = "instances"
elif "caption" in anns[0]:
datasetType = "captions"
else:
raise Exception("datasetType not supported")
if datasetType == "instances":
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in anns:
c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
if "segmentation" in ann:
if type(ann["segmentation"]) == list:
# polygon
for seg in ann["segmentation"]:
poly = np.array(seg).reshape((int(len(seg) / 2), 2))
polygons.append(Polygon(poly))
color.append(c)
else:
# mask
t = self.imgs[ann["image_id"]]
if type(ann["segmentation"]["counts"]) == list:
rle = maskUtils.frPyObjects(
[ann["segmentation"]], t["height"], t["width"]
)
else:
rle = [ann["segmentation"]]
m = maskUtils.decode(rle)
img = np.ones((m.shape[0], m.shape[1], 3))
if ann["iscrowd"] == 1:
color_mask = np.array([2.0, 166.0, 101.0]) / 255
if ann["iscrowd"] == 0:
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:, :, i] = color_mask[i]
ax.imshow(np.dstack((img, m * 0.5)))
if "keypoints" in ann and type(ann["keypoints"]) == list:
# turn skeleton into zero-based index
sks = np.array(self.loadCats(ann["category_id"])[0]["skeleton"]) - 1
kp = np.array(ann["keypoints"])
x = kp[0::3]
y = kp[1::3]
v = kp[2::3]
for sk in sks:
if np.all(v[sk] > 0):
plt.plot(x[sk], y[sk], linewidth=3, color=c)
plt.plot(
x[v > 0],
y[v > 0],
"o",
markersize=8,
markerfacecolor=c,
markeredgecolor="k",
markeredgewidth=2,
)
plt.plot(
x[v > 1],
y[v > 1],
"o",
markersize=8,
markerfacecolor=c,
markeredgecolor=c,
markeredgewidth=2,
)
if draw_bbox:
[bbox_x, bbox_y, bbox_w, bbox_h] = ann["bbox"]
poly = [
[bbox_x, bbox_y],
[bbox_x, bbox_y + bbox_h],
[bbox_x + bbox_w, bbox_y + bbox_h],
[bbox_x + bbox_w, bbox_y],
]
np_poly = np.array(poly).reshape((4, 2))
polygons.append(Polygon(np_poly))
color.append(c)
# p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
# ax.add_collection(p)
p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
ax.add_collection(p)
elif datasetType == "captions":
for ann in anns:
print(ann["caption"])

View file

@ -0,0 +1,100 @@
import os
import random
from typing import List
import torch
def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
"""construct a map such that positive_map[i,j] = True iff box i is associated to token j
Input:
- tokenized:
- input_ids: Tensor[1, ntokens]
- attention_mask: Tensor[1, ntokens]
- token_span: list with length num_boxes.
- each item: [start_idx, end_idx]
"""
positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float)
for j, tok_list in enumerate(token_span):
for (beg, end) in tok_list:
beg_pos = tokenized.char_to_token(beg)
end_pos = tokenized.char_to_token(end - 1)
if beg_pos is None:
try:
beg_pos = tokenized.char_to_token(beg + 1)
if beg_pos is None:
beg_pos = tokenized.char_to_token(beg + 2)
except:
beg_pos = None
if end_pos is None:
try:
end_pos = tokenized.char_to_token(end - 2)
if end_pos is None:
end_pos = tokenized.char_to_token(end - 3)
except:
end_pos = None
if beg_pos is None or end_pos is None:
continue
assert beg_pos is not None and end_pos is not None
if os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE":
positive_map[j, beg_pos] = 1
break
else:
positive_map[j, beg_pos : end_pos + 1].fill_(1)
return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
def build_captions_and_token_span(cat_list, force_lowercase):
"""
Return:
captions: str
cat2tokenspan: dict
{
'dog': [[0, 2]],
...
}
"""
cat2tokenspan = {}
captions = ""
for catname in cat_list:
class_name = catname
if force_lowercase:
class_name = class_name.lower()
if "/" in class_name:
class_name_list: List = class_name.strip().split("/")
class_name_list.append(class_name)
class_name: str = random.choice(class_name_list)
tokens_positive_i = []
subnamelist = [i.strip() for i in class_name.strip().split(" ")]
for subname in subnamelist:
if len(subname) == 0:
continue
if len(captions) > 0:
captions = captions + " "
strat_idx = len(captions)
end_idx = strat_idx + len(subname)
tokens_positive_i.append([strat_idx, end_idx])
captions = captions + subname
if len(tokens_positive_i) > 0:
captions = captions + " ."
cat2tokenspan[class_name] = tokens_positive_i
return captions, cat2tokenspan
def build_id2posspan_and_caption(category_dict: dict):
"""Build id2pos_span and caption from category_dict
Args:
category_dict (dict): category_dict
"""
cat_list = [item["name"].lower() for item in category_dict]
id2catname = {item["id"]: item["name"].lower() for item in category_dict}
caption, cat2posspan = build_captions_and_token_span(cat_list, force_lowercase=True)
id2posspan = {catid: cat2posspan[catname] for catid, catname in id2catname.items()}
return id2posspan, caption

View file

@ -0,0 +1 @@
__version__ = '0.1.0'

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

View file

@ -0,0 +1,401 @@
import math
import clip
import cv2
import groundingdino.datasets.transforms as T
import numpy as np
import torch
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from PIL import Image, ImageDraw
################################## text_localization using ocr #######################
def crop_image(img, position):
def distance(x1, y1, x2, y2):
return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2))
position = position.tolist()
for i in range(4):
for j in range(i + 1, 4):
if position[i][0] > position[j][0]:
tmp = position[j]
position[j] = position[i]
position[i] = tmp
if position[0][1] > position[1][1]:
tmp = position[0]
position[0] = position[1]
position[1] = tmp
if position[2][1] > position[3][1]:
tmp = position[2]
position[2] = position[3]
position[3] = tmp
x1, y1 = position[0][0], position[0][1]
x2, y2 = position[2][0], position[2][1]
x3, y3 = position[3][0], position[3][1]
x4, y4 = position[1][0], position[1][1]
corners = np.zeros((4, 2), np.float32)
corners[0] = [x1, y1]
corners[1] = [x2, y2]
corners[2] = [x4, y4]
corners[3] = [x3, y3]
img_width = distance((x1 + x4) / 2, (y1 + y4) / 2, (x2 + x3) / 2, (y2 + y3) / 2)
img_height = distance((x1 + x2) / 2, (y1 + y2) / 2, (x4 + x3) / 2, (y4 + y3) / 2)
corners_trans = np.zeros((4, 2), np.float32)
corners_trans[0] = [0, 0]
corners_trans[1] = [img_width - 1, 0]
corners_trans[2] = [0, img_height - 1]
corners_trans[3] = [img_width - 1, img_height - 1]
transform = cv2.getPerspectiveTransform(corners, corners_trans)
dst = cv2.warpPerspective(img, transform, (int(img_width), int(img_height)))
return dst
def calculate_size(box):
return (box[2] - box[0]) * (box[3] - box[1])
def order_point(coor):
arr = np.array(coor).reshape([4, 2])
sum_ = np.sum(arr, 0)
centroid = sum_ / arr.shape[0]
theta = np.arctan2(arr[:, 1] - centroid[1], arr[:, 0] - centroid[0])
sort_points = arr[np.argsort(theta)]
sort_points = sort_points.reshape([4, -1])
if sort_points[0][0] > centroid[0]:
sort_points = np.concatenate([sort_points[3:], sort_points[:3]])
sort_points = sort_points.reshape([4, 2]).astype("float32")
return sort_points
def longest_common_substring_length(str1, str2):
m = len(str1)
n = len(str2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(1, m + 1):
for j in range(1, n + 1):
if str1[i - 1] == str2[j - 1]:
dp[i][j] = dp[i - 1][j - 1] + 1
else:
dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
return dp[m][n]
def ocr(image_path, prompt, ocr_detection, ocr_recognition, x, y):
text_data = []
coordinate = []
image = Image.open(image_path)
iw, ih = image.size
image_full = cv2.imread(image_path)
det_result = ocr_detection(image_full)
det_result = det_result["polygons"]
for i in range(det_result.shape[0]):
pts = order_point(det_result[i])
image_crop = crop_image(image_full, pts)
result = ocr_recognition(image_crop)["text"][0]
if result == prompt:
box = [int(e) for e in list(pts.reshape(-1))]
box = [box[0], box[1], box[4], box[5]]
if calculate_size(box) > 0.05 * iw * ih:
continue
text_data.append(
[
int(max(0, box[0] - 10) * x / iw),
int(max(0, box[1] - 10) * y / ih),
int(min(box[2] + 10, iw) * x / iw),
int(min(box[3] + 10, ih) * y / ih),
]
)
coordinate.append(
[
int(max(0, box[0] - 300) * x / iw),
int(max(0, box[1] - 400) * y / ih),
int(min(box[2] + 300, iw) * x / iw),
int(min(box[3] + 400, ih) * y / ih),
]
)
max_length = 0
if len(text_data) == 0:
for i in range(det_result.shape[0]):
pts = order_point(det_result[i])
image_crop = crop_image(image_full, pts)
result = ocr_recognition(image_crop)["text"][0]
if len(result) < 0.3 * len(prompt):
continue
if result in prompt:
now_length = len(result)
else:
now_length = longest_common_substring_length(result, prompt)
if now_length > max_length:
max_length = now_length
box = [int(e) for e in list(pts.reshape(-1))]
box = [box[0], box[1], box[4], box[5]]
text_data = [
[
int(max(0, box[0] - 10) * x / iw),
int(max(0, box[1] - 10) * y / ih),
int(min(box[2] + 10, iw) * x / iw),
int(min(box[3] + 10, ih) * y / ih),
]
]
coordinate = [
[
int(max(0, box[0] - 300) * x / iw),
int(max(0, box[1] - 400) * y / ih),
int(min(box[2] + 300, iw) * x / iw),
int(min(box[3] + 400, ih) * y / ih),
]
]
if len(prompt) <= 10:
if max_length >= 0.8 * len(prompt):
return text_data, coordinate
else:
return [], []
elif (len(prompt) > 10) and (len(prompt) <= 20):
if max_length >= 0.5 * len(prompt):
return text_data, coordinate
else:
return [], []
else:
if max_length >= 0.4 * len(prompt):
return text_data, coordinate
else:
return [], []
else:
return text_data, coordinate
################################## icon_localization using clip #######################
def calculate_iou(box1, box2):
xA = max(box1[0], box2[0])
yA = max(box1[1], box2[1])
xB = min(box1[2], box2[2])
yB = min(box1[3], box2[3])
interArea = max(0, xB - xA) * max(0, yB - yA)
box1Area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2Area = (box2[2] - box2[0]) * (box2[3] - box2[1])
unionArea = box1Area + box2Area - interArea
iou = interArea / unionArea
return iou
def crop(image, box, i, text_data=None):
image = Image.open(image)
if text_data:
draw = ImageDraw.Draw(image)
draw.rectangle(((text_data[0], text_data[1]), (text_data[2], text_data[3])), outline="red", width=5)
# font_size = int((text_data[3] - text_data[1])*0.75)
# font = ImageFont.truetype("arial.ttf", font_size)
# draw.text((text_data[0]+5, text_data[1]+5), str(i), font=font, fill="red")
cropped_image = image.crop(box)
cropped_image.save(f"./temp/{i}.jpg")
def in_box(box, target):
if (box[0] > target[0]) and (box[1] > target[1]) and (box[2] < target[2]) and (box[3] < target[3]):
return True
else:
return False
def crop_for_clip(image, box, i, position):
image = Image.open(image)
w, h = image.size
if position == "left":
bound = [0, 0, w / 2, h]
elif position == "right":
bound = [w / 2, 0, w, h]
elif position == "top":
bound = [0, 0, w, h / 2]
elif position == "bottom":
bound = [0, h / 2, w, h]
elif position == "top left":
bound = [0, 0, w / 2, h / 2]
elif position == "top right":
bound = [w / 2, 0, w, h / 2]
elif position == "bottom left":
bound = [0, h / 2, w / 2, h]
elif position == "bottom right":
bound = [w / 2, h / 2, w, h]
else:
bound = [0, 0, w, h]
if in_box(box, bound):
cropped_image = image.crop(box)
cropped_image.save(f"./temp/{i}.jpg")
return True
else:
return False
def clip_for_icon(clip_model, clip_preprocess, images, prompt):
image_features = []
for image_file in images:
image = clip_preprocess(Image.open(image_file)).unsqueeze(0).to(next(clip_model.parameters()).device)
image_feature = clip_model.encode_image(image)
image_features.append(image_feature)
image_features = torch.cat(image_features)
text = clip.tokenize([prompt]).to(next(clip_model.parameters()).device)
text_features = clip_model.encode_text(text)
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=0).squeeze(0)
_, max_pos = torch.max(similarity, dim=0)
pos = max_pos.item()
return pos
def transform_image(image_pil):
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image, _ = transform(image_pil, None) # 3, h, w
return image
def load_model(model_config_path, model_checkpoint_path, device):
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
print(load_res)
_ = model.eval()
return model
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
caption = caption.lower()
caption = caption.strip()
if not caption.endswith("."):
caption = caption + "."
with torch.no_grad():
outputs = model(image[None], captions=[caption])
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
logits.shape[0]
logits_filt = logits.clone()
boxes_filt = boxes.clone()
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
logits_filt = logits_filt[filt_mask] # num_filt, 256
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
logits_filt.shape[0]
tokenlizer = model.tokenizer
tokenized = tokenlizer(caption)
pred_phrases = []
scores = []
for logit, box in zip(logits_filt, boxes_filt):
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
if with_logits:
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
else:
pred_phrases.append(pred_phrase)
scores.append(logit.max().item())
return boxes_filt, torch.Tensor(scores), pred_phrases
def remove_boxes(boxes_filt, size, iou_threshold=0.5):
boxes_to_remove = set()
for i in range(len(boxes_filt)):
if calculate_size(boxes_filt[i]) > 0.05 * size[0] * size[1]:
boxes_to_remove.add(i)
for j in range(len(boxes_filt)):
if calculate_size(boxes_filt[j]) > 0.05 * size[0] * size[1]:
boxes_to_remove.add(j)
if i == j:
continue
if i in boxes_to_remove or j in boxes_to_remove:
continue
iou = calculate_iou(boxes_filt[i], boxes_filt[j])
if iou >= iou_threshold:
boxes_to_remove.add(j)
boxes_filt = [box for idx, box in enumerate(boxes_filt) if idx not in boxes_to_remove]
return boxes_filt
def det(input_image, text_prompt, groundingdino_model, box_threshold=0.05, text_threshold=0.5):
image = Image.open(input_image)
size = image.size
image_pil = image.convert("RGB")
image = np.array(image_pil)
transformed_image = transform_image(image_pil)
boxes_filt, scores, pred_phrases = get_grounding_output(
groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold
)
H, W = size[1], size[0]
for i in range(boxes_filt.size(0)):
boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
boxes_filt[i][2:] += boxes_filt[i][:2]
boxes_filt = boxes_filt.cpu().int().tolist()
filtered_boxes = remove_boxes(boxes_filt, size) # [:9]
coordinate = []
image_data = []
for box in filtered_boxes:
image_data.append(
[max(0, box[0] - 10), max(0, box[1] - 10), min(box[2] + 10, size[0]), min(box[3] + 10, size[1])]
)
coordinate.append(
[max(0, box[0] - 25), max(0, box[1] - 25), min(box[2] + 25, size[0]), min(box[3] + 25, size[1])]
)
return image_data, coordinate
if __name__ == "__main__":
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
image_ori = "./screenshot/screenshot.png"
image = "./screenshot/screenshot.png"
parameter = "抖音"
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
iw, ih = Image.open(image).size
if iw > ih:
iw, ih = ih, iw
in_coordinate, out_coordinate = ocr(image_ori, parameter, ocr_detection, ocr_recognition, iw, ih)

View file

@ -19,7 +19,7 @@ from metagpt.logs import logger
from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file
class MinecraftEnv(Environment, MinecraftExtEnv):
class MinecraftEnv(MinecraftExtEnv, Environment):
"""MinecraftEnv, including shared memory of cache and information between roles"""
model_config = ConfigDict(arbitrary_types_allowed=True)

View file

@ -0,0 +1,121 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from enum import Enum
from metagpt.const import MESSAGE_ROUTE_TO_ALL
class RoleType(Enum):
VILLAGER = "Villager"
WEREWOLF = "Werewolf"
GUARD = "Guard"
SEER = "Seer"
WITCH = "Witch"
MODERATOR = "Moderator"
class RoleState(Enum):
ALIVE = "alive" # the role is alive
DEAD = "dead" # killed or poisoned
KILLED = "killed" # killed by werewolf or voting
POISONED = "poisoned" # killed by poison
SAVED = "saved" # saved by antidote
PROTECTED = "projected" # projected by guard
class RoleActionRes(Enum):
SAVE = "save"
PASS = "pass" # ignore current action output
empty_set = set()
# the ordered rules by the moderator to announce to everyone each step
STEP_INSTRUCTIONS = {
0: {
"content": "Its dark, everyone close your eyes. I will talk with you/your team secretly at night.",
"send_to": {RoleType.MODERATOR.value}, # for moderator to continue speaking
"restricted_to": empty_set,
},
1: {
"content": "Guard, please open your eyes!",
"send_to": {RoleType.MODERATOR.value}, # for moderator to continue speaking
"restricted_to": empty_set,
},
2: {
"content": """Guard, now tell me who you protect tonight?
You only choose one from the following living options please: {living_players}.
Or you can pass. For example: Protect ...""",
"send_to": {RoleType.GUARD.value},
"restricted_to": {RoleType.MODERATOR.value, RoleType.GUARD.value},
},
3: {"content": "Guard, close your eyes", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set},
4: {
"content": "Werewolves, please open your eyes!",
"send_to": {RoleType.MODERATOR.value},
"restricted_to": empty_set,
},
5: {
"content": """Werewolves, I secretly tell you that {werewolf_players} are
all of the {werewolf_num} werewolves! Keep in mind you are teammates. The rest players are not werewolves.
choose one from the following living options please:
{living_players}. For example: Kill ...""",
"send_to": {RoleType.WEREWOLF.value},
"restricted_to": {RoleType.MODERATOR.value, RoleType.WEREWOLF.value},
},
6: {"content": "Werewolves, close your eyes", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set},
7: {"content": "Witch, please open your eyes!", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set},
8: {
"content": """Witch, tonight {player_hunted} has been killed by the werewolves.
You have a bottle of antidote, would you like to save him/her? If so, say "Save", else, say "Pass".""",
"send_to": {RoleType.WITCH.value},
"restricted_to": {RoleType.MODERATOR.value, RoleType.WITCH.value},
}, # 要先判断女巫是否有解药,再去询问女巫是否使用解药救人
9: {
"content": """Witch, you also have a bottle of poison, would you like to use it to kill one of the living players?
Choose one from the following living options: {living_players}.
If so, say ONLY "Poison PlayerX", replace PlayerX with the actual player name, else, say "Pass".""",
"send_to": {RoleType.WITCH.value},
"restricted_to": {RoleType.MODERATOR.value, RoleType.WITCH.value},
}, #
10: {"content": "Witch, close your eyes", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set},
11: {"content": "Seer, please open your eyes!", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set},
12: {
"content": """Seer, you can check one player's identity. Who are you going to verify its identity tonight?
Choose only one from the following living options:{living_players}.""",
"send_to": {RoleType.SEER.value},
"restricted_to": {RoleType.MODERATOR.value, RoleType.SEER.value},
},
13: {"content": "Seer, close your eyes", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set},
# The 1-st daytime
14: {
"content": """It's daytime. Everyone woke up except those who had been killed.""",
"send_to": {RoleType.MODERATOR.value},
"restricted_to": empty_set,
},
15: {
"content": "{player_current_dead} was killed last night!",
"send_to": {RoleType.MODERATOR.value},
"restricted_to": empty_set,
},
16: {
"content": """Living players: {living_players}, now freely talk about the current situation based on your observation and
reflection with a few sentences. Decide whether to reveal your identity based on your reflection.""",
"send_to": {MESSAGE_ROUTE_TO_ALL}, # send to all to speak in daytime
"restricted_to": empty_set,
},
17: {
"content": """Now vote and tell me who you think is the werewolf. Dont mention your role.
You only choose one from the following living options please:
{living_players}. Say ONLY: I vote to eliminate ...""",
"send_to": {MESSAGE_ROUTE_TO_ALL},
"restricted_to": empty_set,
},
18: {
"content": """{player_current_dead} was eliminated.""",
"send_to": {RoleType.MODERATOR.value},
"restricted_to": empty_set,
},
}

View file

@ -0,0 +1,60 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : werewolf observation/action space and its action definition
from gymnasium import spaces
from pydantic import ConfigDict, Field
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvActionType
from metagpt.environment.werewolf.const import STEP_INSTRUCTIONS
class EnvActionType(BaseEnvActionType):
NONE = 0 # no action to run, just get observation
WOLF_KILL = 1 # wolf kill someone
VOTE_KILL = 2 # vote kill someone
WITCH_POISON = 3 # witch poison someone
WITCH_SAVE = 4 # witch save someone
GUARD_PROTECT = 5 # guard protect someone
PROGRESS_STEP = 6 # step increment
class EnvAction(BaseEnvAction):
model_config = ConfigDict(arbitrary_types_allowed=True)
action_type: int = Field(default=EnvActionType.NONE, description="action type")
player_name: str = Field(default="", description="the name of the player to do the action")
target_player_name: str = Field(default="", description="the name of the player who take the action")
def get_observation_space() -> spaces.Dict:
space = spaces.Dict(
{
"game_setup": spaces.Text(256),
"step_idx": spaces.Discrete(len(STEP_INSTRUCTIONS)),
"living_players": spaces.Tuple(
(spaces.Text(16), spaces.Text(16))
), # TODO should be tuple of variable length
"werewolf_players": spaces.Tuple(
(spaces.Text(16), spaces.Text(16))
), # TODO should be tuple of variable length
"player_hunted": spaces.Text(16),
"player_current_dead": spaces.Tuple((spaces.Text(16))), # TODO should be tuple of variable length
"witch_poison_left": spaces.Discrete(2),
"witch_antidote_left": spaces.Discrete(2),
"winner": spaces.Text(16),
"win_reason": spaces.Text(64),
}
)
return space
def get_action_space() -> spaces.Dict:
space = spaces.Dict(
{
"action_type": spaces.Discrete(len(EnvActionType)),
"player_name": spaces.Text(16), # the player to do the action
"target_player_name": spaces.Text(16), # the target player who take the action
}
)
return space

View file

@ -2,30 +2,40 @@
# -*- coding: utf-8 -*-
# @Desc : MG Werewolf Env
from typing import Iterable
from pydantic import Field
from metagpt.environment.base_env import Environment
from metagpt.environment.werewolf.werewolf_ext_env import WerewolfExtEnv
from metagpt.logs import logger
from metagpt.schema import Message
class WerewolfEnv(Environment, WerewolfExtEnv):
timestamp: int = Field(default=0)
class WerewolfEnv(WerewolfExtEnv, Environment):
round_cnt: int = Field(default=0)
def add_roles(self, roles: Iterable["Role"]):
"""增加一批在当前环境的角色
Add a batch of characters in the current environment
"""
for role in roles:
self.roles[role.name] = role # use name as key here, due to multi-player can have same profile
for role in roles: # setup system message with roles
role.context = self.context
role.set_env(self)
def publish_message(self, message: Message, add_timestamp: bool = True):
"""Post information to the current environment"""
logger.debug(f"publish_message: {message.dump()}")
if add_timestamp:
# Because the content of the message may be repeated, for example, killing the same person in two nights
# Therefore, a unique timestamp prefix needs to be added so that the same message will not be automatically deduplicated when added to the memory.
message.content = f"{self.timestamp} | " + message.content
self.memory.add(message)
self.history += f"\n{message}"
# Therefore, a unique round_cnt prefix needs to be added so that the same message will not be automatically deduplicated when added to the memory.
message.content = f"{self.round_cnt} | " + message.content
super().publish_message(message)
async def run(self, k=1):
"""Process all Role runs by order"""
for _ in range(k):
for role in self.roles.values():
await role.run()
self.timestamp += 1
self.round_cnt += 1

View file

@ -4,110 +4,27 @@
import random
from collections import Counter
from enum import Enum
from typing import Any, Callable, Optional
from pydantic import ConfigDict, Field
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.environment.base_env_space import BaseEnvObsParams
from metagpt.environment.werewolf.const import STEP_INSTRUCTIONS, RoleState, RoleType
from metagpt.environment.werewolf.env_space import EnvAction, EnvActionType
from metagpt.logs import logger
class RoleState(Enum):
ALIVE = "alive" # the role is alive
KILLED = "killed" # the role is killed by werewolf or voting
POISONED = "poisoned" # the role is killed by posion
SAVED = "saved" # the role is saved by antidote
# the ordered rules by the moderator to announce to everyone each step
STEP_INSTRUCTIONS = {
0: {
"content": "Its dark, everyone close your eyes. I will talk with you/your team secretly at night.",
"send_to": "Moderator", # for moderator to continuen speaking
"restricted_to": "",
},
1: {
"content": "Guard, please open your eyes!",
"send_to": "Moderator", # for moderator to continuen speaking
"restricted_to": "",
},
2: {
"content": """Guard, now tell me who you protect tonight?
You only choose one from the following living options please: {living_players}.
Or you can pass. For example: Protect ...""",
"send_to": "Guard",
"restricted_to": "Moderator,Guard",
},
3: {"content": "Guard, close your eyes", "send_to": "Moderator", "restricted_to": ""},
4: {"content": "Werewolves, please open your eyes!", "send_to": "Moderator", "restricted_to": ""},
5: {
"content": """Werewolves, I secretly tell you that {werewolf_players} are
all of the 2 werewolves! Keep in mind you are teammates. The rest players are not werewolves.
choose one from the following living options please:
{living_players}. For example: Kill ...""",
"send_to": "Werewolf",
"restricted_to": "Moderator,Werewolf",
},
6: {"content": "Werewolves, close your eyes", "send_to": "Moderator", "restricted_to": ""},
7: {"content": "Witch, please open your eyes!", "send_to": "Moderator", "restricted_to": ""},
8: {
"content": """Witch, tonight {player_hunted} has been killed by the werewolves.
You have a bottle of antidote, would you like to save him/her? If so, say "Save", else, say "Pass".""",
"send_to": "Witch",
"restricted_to": "Moderator,Witch",
}, # 要先判断女巫是否有解药,再去询问女巫是否使用解药救人
9: {
"content": """Witch, you also have a bottle of poison, would you like to use it to kill one of the living players?
Choose one from the following living options: {living_players}.
If so, say ONLY "Poison PlayerX", replace PlayerX with the actual player name, else, say "Pass".""",
"send_to": "Witch",
"restricted_to": "Moderator,Witch",
}, #
10: {"content": "Witch, close your eyes", "send_to": "Moderator", "restricted_to": ""},
11: {"content": "Seer, please open your eyes!", "send_to": "Moderator", "restricted_to": ""},
12: {
"content": """Seer, you can check one player's identity. Who are you going to verify its identity tonight?
Choose only one from the following living options:{living_players}.""",
"send_to": "Seer",
"restricted_to": "Moderator,Seer",
},
13: {"content": "Seer, close your eyes", "send_to": "Moderator", "restricted_to": ""},
# The 1-st daytime
14: {
"content": """It's daytime. Everyone woke up except those who had been killed.""",
"send_to": "Moderator",
"restricted_to": "",
},
15: {"content": "{player_current_dead} was killed last night!", "send_to": "Moderator", "restricted_to": ""},
16: {
"content": """Living players: {living_players}, now freely talk about the current situation based on your observation and
reflection with a few sentences. Decide whether to reveal your identity based on your reflection.""",
"send_to": "", # send to all to speak in daytime
"restricted_to": "",
},
17: {
"content": """Now vote and tell me who you think is the werewolf. Dont mention your role.
You only choose one from the following living options please:
{living_players}. Say ONLY: I vote to eliminate ...""",
"send_to": "",
"restricted_to": "",
},
18: {"content": """{player_current_dead} was eliminated.""", "send_to": "Moderator", "restricted_to": ""},
}
class WerewolfExtEnv(ExtEnv):
model_config = ConfigDict(arbitrary_types_allowed=True)
players_state: dict[str, tuple[str, RoleState]] = Field(
default=dict(), description="the player's role type and state by player_name"
default_factory=dict, description="the player's role type and state by player_name"
)
round_idx: int = Field(default=0) # the current round
step_idx: int = Field(default=0) # the current step of current round
eval_step_idx: int = Field(default=0)
eval_step_idx: list[int] = Field(default=[])
per_round_steps: int = Field(default=len(STEP_INSTRUCTIONS))
# game global states
@ -115,13 +32,13 @@ class WerewolfExtEnv(ExtEnv):
special_role_players: list[str] = Field(default=[])
winner: Optional[str] = Field(default=None)
win_reason: Optional[str] = Field(default=None)
witch_poison_left: int = Field(default=1)
witch_antidote_left: int = Field(default=1)
witch_poison_left: int = Field(default=1, description="should be 1 or 0")
witch_antidote_left: int = Field(default=1, description="should be 1 or 0")
# game current round states, a round is from closing your eyes to the next time you close your eyes
round_hunts: dict[str, str] = Field(default=dict(), description="nighttime wolf hunt result")
round_hunts: dict[str, str] = Field(default_factory=dict, description="nighttime wolf hunt result")
round_votes: dict[str, str] = Field(
default=dict(), description="daytime all players vote result, key=voteer, value=voted one"
default_factory=dict, description="daytime all players vote result, key=voter, value=voted one"
)
player_hunted: Optional[str] = Field(default=None)
player_protected: Optional[str] = Field(default=None)
@ -135,13 +52,69 @@ class WerewolfExtEnv(ExtEnv):
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""currently unused"""
pass
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
"""currently unused"""
pass
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
pass
def _get_obs(self):
return {
"game_setup": self.game_setup,
"step_idx": self.step_idx,
"living_players": self.living_players,
"werewolf_players": self.werewolf_players, # currently, lack observation isolation
"player_hunted": self.player_hunted,
"player_current_dead": self.player_current_dead,
"witch_poison_left": self.witch_poison_left,
"witch_antidote_left": self.witch_antidote_left,
"winner": self.winner,
"win_reason": self.win_reason,
}
def step(self, action: EnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
action_type = action.action_type
player_name = action.player_name
target_player_name = action.target_player_name
if action_type == EnvActionType.WOLF_KILL:
self.wolf_kill_someone(wolf_name=player_name, player_name=target_player_name)
elif action_type == EnvActionType.VOTE_KILL:
self.vote_kill_someone(voter_name=player_name, player_name=target_player_name)
elif action_type == EnvActionType.WITCH_POISON:
self.witch_poison_someone(witch_name=player_name, player_name=target_player_name)
elif action_type == EnvActionType.WITCH_SAVE:
self.witch_save_someone(witch_name=player_name, player_name=target_player_name)
elif action_type == EnvActionType.GUARD_PROTECT:
self.guard_protect_someone(guard_name=player_name, player_name=target_player_name)
elif action_type == EnvActionType.PROGRESS_STEP:
self.progress_step()
elif action_type == EnvActionType.NONE:
pass
else:
raise ValueError(f"not supported action_type: {action_type}")
self.update_game_states()
terminated = self._check_game_finish()
obs = self._get_obs()
return obs, 1.0, terminated, False, {}
def _check_game_finish(self) -> bool:
"""return True if game finished else False"""
# game's termination condition
terminated = False
living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
living_villagers = [p for p in self.villager_players if p in self.living_players]
living_special_roles = [p for p in self.special_role_players if p in self.living_players]
if not living_werewolf:
self.winner = "good guys"
self.win_reason = "werewolves all dead"
terminated = True
elif not living_villagers or not living_special_roles:
self.winner = "werewolf"
self.win_reason = "villagers all dead" if not living_villagers else "special roles all dead"
terminated = True
return terminated
@property
def living_players(self) -> list[str]:
@ -161,12 +134,12 @@ class WerewolfExtEnv(ExtEnv):
@property
def werewolf_players(self) -> list[str]:
player_names = self._role_type_players(role_type="Werewolf")
player_names = self._role_type_players(role_type=RoleType.WEREWOLF.value)
return player_names
@property
def villager_players(self) -> list[str]:
player_names = self._role_type_players(role_type="Villager")
player_names = self._role_type_players(role_type=RoleType.VILLAGER.value)
return player_names
def _init_players_state(self, players: list["Role"]):
@ -193,14 +166,14 @@ class WerewolfExtEnv(ExtEnv):
"""init players using different roles' num"""
role_objs = []
for role_obj in role_uniq_objs:
if str(role_obj) == "Villager":
if RoleType.VILLAGER.value in str(role_obj):
role_objs.extend([role_obj] * num_villager)
elif str(role_obj) == "Werewolf":
elif RoleType.WEREWOLF.value in str(role_obj):
role_objs.extend([role_obj] * num_werewolf)
else:
role_objs.append(role_obj)
if shuffle:
random.shuffle(len(role_objs))
random.shuffle(role_objs)
if add_human:
assigned_role_idx = random.randint(0, len(role_objs) - 1)
assigned_role = role_objs[assigned_role_idx]
@ -233,10 +206,12 @@ class WerewolfExtEnv(ExtEnv):
roletype_state = self.players_state[player_name]
self.players_state[player_name] = (roletype_state[0], state)
def _check_valid_role(self, player: "Role", role_type: str) -> bool:
return True if role_type in str(player) else False
def _check_valid_role(self, player_name: str, role_type: str) -> bool:
roletype_state = self.players_state.get(player_name)
return True if roletype_state and role_type in roletype_state[0] else False
def _check_player_continue(self, player_name: str, particular_step: int = -1) -> bool:
"""to check if can do the operation to the player"""
step_idx = self.step_idx % self.per_round_steps
if particular_step > 0 and step_idx != particular_step: # step no
# particular_step = 18, not daytime vote time, ignore
@ -253,6 +228,10 @@ class WerewolfExtEnv(ExtEnv):
self.step_idx += 1
return instruction
@mark_as_writeable
def progress_step(self):
self.step_idx += 1
@mark_as_readable
def get_players_state(self, player_names: list[str]) -> dict[str, RoleState]:
players_state = {
@ -263,57 +242,72 @@ class WerewolfExtEnv(ExtEnv):
return players_state
@mark_as_writeable
def vote_kill_someone(self, voteer: "Role", player_name: str = None):
def vote_kill_someone(self, voter_name: str, player_name: str = None):
"""player vote result at daytime
player_name: if it's None, regard as abstaining from voting
"""
if not self._check_player_continue(voteer.name, particular_step=18): # 18=step no
if not self._check_player_continue(voter_name, particular_step=18): # 18=step no
return
self.round_votes[voteer.name] = player_name
self.round_votes[voter_name] = player_name
# check if all living players finish voting, then get the dead one
if list(self.round_votes.keys()) == self.living_players:
voted_all = list(self.round_votes.values()) # TODO in case of tie vote, check who was voted first
voted_all = [item for item in voted_all if item]
self.player_current_dead = Counter(voted_all).most_common()[0][0]
self._update_players_state([self.player_current_dead])
self.player_current_dead = [Counter(voted_all).most_common()[0][0]]
self._update_players_state(self.player_current_dead)
@mark_as_writeable
def wolf_kill_someone(self, wolf: "Role", player_name: str):
if not self._check_valid_role(wolf, "Werewolf"):
def wolf_kill_someone(self, wolf_name: str, player_name: str):
if not self._check_valid_role(wolf_name, RoleType.WEREWOLF.value):
return
if not self._check_player_continue(wolf.name, particular_step=5): # 5=step no
if not self._check_player_continue(wolf_name, particular_step=6): # 5=step no
return
self.round_hunts[wolf.name] = player_name
living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
self.round_hunts[wolf_name] = player_name
# living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
# check if all living wolfs finish hunting, then get the hunted one
if list(self.round_hunts.keys()) == living_werewolf:
hunted_all = list(self.round_hunts.values())
self.player_hunted = Counter(hunted_all).most_common()[0][0]
# if list(self.round_hunts.keys()) == living_werewolf:
# hunted_all = list(self.round_hunts.values())
# self.player_hunted = Counter(hunted_all).most_common()[0][0]
self.player_hunted = player_name
@mark_as_writeable
def witch_poison_someone(self, witch: "Role", player_name: str = None):
if not self._check_valid_role(witch, "Witch"):
def _witch_poison_or_save_someone(
self, witch_name: str, player_name: str = None, state: RoleState = RoleState.POISONED
):
if not self._check_valid_role(witch_name, RoleType.WITCH.value):
return
if not self._check_player_continue(player_name):
return
self._update_players_state([player_name], RoleState.POISONED)
self.player_poisoned = player_name
assert state in [RoleState.POISONED, RoleState.SAVED]
self._update_players_state([player_name], state)
if state == RoleState.POISONED:
self.player_poisoned = player_name
self.witch_poison_left -= 1
else:
# self.player_protected = player_name
self.is_hunted_player_saved = True
self.witch_antidote_left -= 1
@mark_as_writeable
def witch_save_someone(self, witch: "Role", player_name: str = None):
if not self._check_valid_role(witch, "Witch"):
def witch_poison_someone(self, witch_name: str, player_name: str = None):
self._witch_poison_or_save_someone(witch_name, player_name, RoleState.POISONED)
@mark_as_writeable
def witch_save_someone(self, witch_name: str, player_name: str = None):
self._witch_poison_or_save_someone(witch_name, player_name, RoleState.SAVED)
@mark_as_writeable
def guard_protect_someone(self, guard_name: str, player_name: str = None):
if not self._check_valid_role(guard_name, RoleType.GUARD.value):
return
if not self._check_player_continue(player_name):
return
self._update_players_state([player_name], RoleState.SAVED)
self.player_protected = player_name
@mark_as_writeable
def update_game_states(self, memories: list):
def update_game_states(self):
step_idx = self.step_idx % self.per_round_steps
if step_idx not in [15, 18] or self.step_idx in self.eval_step_idx:
return
@ -329,22 +323,12 @@ class WerewolfExtEnv(ExtEnv):
if self.player_poisoned:
self.player_current_dead.append(self.player_poisoned)
self._update_players_state([self.player_current_dead])
self._update_players_state(self.player_current_dead)
# reset
self.player_hunted = None
self.player_protected = None
self.is_hunted_player_saved = False
self.player_poisoned = None
# game's termination condition
living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
living_villagers = [p for p in self.villager_players if p in self.living_players]
living_special_roles = [p for p in self.special_role_players if p in self.living_players]
if not living_werewolf:
self.winner = "good guys"
self.win_reason = "werewolves all dead"
elif not living_villagers or not living_special_roles:
self.winner = "werewolf"
self.win_reason = "villagers all dead" if not living_villagers else "special roles all dead"
if self.winner is not None:
self._record_all_experiences() # TODO
elif step_idx == 18:
# updated use vote_kill_someone
pass

View file

@ -49,6 +49,7 @@ def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
def get_embedding(text, model: str = "text-embedding-ada-002"):
text = text.replace("\n", " ")
embedding = None
if not text:
text = "this is blank"
for idx in range(3):
@ -56,7 +57,8 @@ def get_embedding(text, model: str = "text-embedding-ada-002"):
embedding = (
OpenAI(api_key=config.llm.api_key).embeddings.create(input=[text], model=model).data[0].embedding
)
except Exception:
except Exception as exp:
logger.info(f"get_embedding failed, exp: {exp}, will retry.")
time.sleep(5)
if not embedding:
raise ValueError("get_embedding failed")

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from metagpt.ext.werewolf.actions.werewolf_actions import Hunt, Impersonate
from metagpt.ext.werewolf.actions.guard_actions import Protect
from metagpt.ext.werewolf.actions.seer_actions import Verify
from metagpt.ext.werewolf.actions.witch_actions import Save, Poison
from metagpt.ext.werewolf.actions.common_actions import Speak, NighttimeWhispers, Reflect
from metagpt.ext.werewolf.actions.experience_operation import AddNewExperiences, RetrieveExperiences
from metagpt.ext.werewolf.actions.moderator_actions import InstructSpeak
ACTIONS = {
"Speak": Speak,
"Hunt": Hunt,
"Protect": Protect,
"Verify": Verify,
"Save": Save,
"Poison": Poison,
"Impersonate": Impersonate,
}
__all__ = ["NighttimeWhispers", "Reflect", "AddNewExperiences", "RetrieveExperiences", "InstructSpeak"]

View file

@ -0,0 +1,240 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import json
from tenacity import retry, stop_after_attempt, wait_fixed
from metagpt.actions import Action
from metagpt.logs import logger
from metagpt.utils.common import parse_json_code_block
def log_and_parse_json(name: str, rsp: str) -> dict:
rsp = rsp.replace("\n", " ")
logger.debug(f"{name} result: {rsp}")
json_blocks = parse_json_code_block(rsp)
rsp_json = json.loads(json_blocks[0])
return rsp_json
class Speak(Action):
"""Action: Any speak action in a game"""
PROMPT_TEMPLATE: str = """
{
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__. Note that villager, seer, guard and witch are all in villager side, they have the same objective. Werewolves can collectively hunt ONE player at night."
,"HISTORY": "You have knowledge to the following conversation: __context__"
,"ATTENTION": "You can NOT VOTE a player who is NOT ALIVE now!"
,"REFLECTION": "__reflection__"
,"STRATEGY": __strategy__
,"PAST_EXPERIENCES": "__experiences__"
,"MODERATOR_INSTRUCTION": __latest_instruction__,
,"RULE": "Please follow the moderator's latest instruction, figure out if you need to speak your opinion or directly to vote:
1. If the instruction is to SPEAK, speak in 200 words. Remember the goal of your role and try to achieve it using your speech;
2. If the instruction is to VOTE, you MUST vote and ONLY say 'I vote to eliminate PlayerX', replace PlayerX with the actual player name, DO NOT include any other words."
,"OUTPUT_FORMAT":
{
"ROLE": "Your role, in this case, __profile__"
,"PLAYER_NAME": "Your name, in this case, __name__"
,"LIVING_PLAYERS": "List living players based on MODERATOR_INSTRUCTION. Return a json LIST datatype."
,"THOUGHTS": "Based on `MODERATOR_INSTRUCTION` and `RULE`, carefully think about what to say or vote so that your chance of win as __profile__ maximizes.
If you find similar situation in `PAST_EXPERIENCES`, you may draw lessons from them to refine your strategy, take better vote action, or improve your speech.
Give your step-by-step thought process, you should think no more than 3 steps. For example: My step-by-step thought process:..."
,"RESPONSE": "Based on `MODERATOR_INSTRUCTION`, `RULE`, and the 'THOUGHTS' you had, express your opinion or cast a vote."
}
}
"""
STRATEGY: str = """
Decide whether to reveal your identity based on benefits vs. risks, provide useful information, and vote to eliminate the most suspicious.
If you have special abilities, pay attention to those who falsely claims your role, for they are probably werewolves.
"""
name: str = "Speak"
@retry(stop=stop_after_attempt(2), wait=wait_fixed(1))
async def run(
self,
profile: str,
name: str,
context: str,
latest_instruction: str,
reflection: str = "",
experiences: str = "",
):
prompt = (
self.PROMPT_TEMPLATE.replace("__context__", context)
.replace("__profile__", profile)
.replace("__name__", name)
.replace("__latest_instruction__", latest_instruction)
.replace("__strategy__", self.STRATEGY)
.replace("__reflection__", reflection)
.replace("__experiences__", experiences)
)
rsp = await self._aask(prompt)
rsp_json = log_and_parse_json(self.name, rsp)
return rsp_json["RESPONSE"]
class NighttimeWhispers(Action):
"""
Action: nighttime whispers with thinking processes
Usage Example:
class Hunt(NighttimeWhispers):
def __init__(self, name="Hunt", context=None, llm=None):
super().__init__(name, context, llm)
class Protect(NighttimeWhispers):
def __init__(self, name="Protect", context=None, llm=None):
super().__init__(name, context, llm)
class Verify(NighttimeWhispers):
def __init__(self, name="Verify", context=None, llm=None):
super().__init__(name, context, llm)
class Save(NighttimeWhispers):
def __init__(self, name="Save", context=None, llm=None):
super().__init__(name, context, llm)
def _update_prompt_json(self, prompt_json: dict, profile: str, name: str, context: str, **kwargs):
del prompt_json['ACTION']
del prompt_json['ATTENTION']
prompt_json["OUTPUT_FORMAT"]["THOUGHTS"] = "It is night time. Return the thinking steps of your decision of whether to save the player JUST be killed at this night."
prompt_json["OUTPUT_FORMAT"]["RESPONSE"] = "Follow the Moderator's instruction, decide whether you want to save that person or not. Return SAVE or PASS."
return prompt_json
class Poison(NighttimeWhispers):
def __init__(self, name="Poison", context=None, llm=None):
super().__init__(name, context, llm)
def _update_prompt_json(self, prompt_json: dict, profile: str, name: str, context: str, **kwargs):
prompt_json["OUTPUT_FORMAT"]["RESPONSE"] += "Or if you want to PASS, return PASS."
return prompt_json
"""
PROMPT_TEMPLATE: str = """
{
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__. Note that villager, seer, guard and witch are all in villager side, they have the same objective. Werewolves can collectively hunt ONE player at night."
,"HISTORY": "You have knowledge to the following conversation: __context__"
,"ACTION": "Choose one living player to __action__."
,"ATTENTION": "1. You can only __action__ a player who is alive this night! And you can not __action__ a player who is dead this night! 2. `HISTORY` is all the information you observed, DONT hallucinate other player actions!"
,"REFLECTION": "__reflection__"
,"STRATEGY": "__strategy__"
,"PAST_EXPERIENCES": "__experiences__"
,"OUTPUT_FORMAT":
{
"ROLE": "Your role, in this case, __profile__"
,"PLAYER_NAME": "Your name, in this case, __name__"
,"LIVING_PLAYERS": "List the players who is alive based on moderator's latest instruction. Return a json LIST datatype."
,"THOUGHTS": "Choose one living player from `LIVING_PLAYERS` to __action__ this night. Return the reason why you choose to __action__ this player. If you observe nothing at first night, DONT imagine unexisting player actions! If you find similar situation in `PAST_EXPERIENCES`, you may draw lessons from them to refine your strategy and take better actions. Give your step-by-step thought process, you should think no more than 3 steps. For example: My step-by-step thought process:..."
,"RESPONSE": "As a __profile__, you should choose one living player from `LIVING_PLAYERS` to __action__ this night according to the THOUGHTS you have just now. Return the player name ONLY."
}
}
"""
STRATEGY: str = """
Decide which player is most threatening to you or most needs your support, take your action correspondingly.
"""
name: str = "NightTimeWhispers"
def _construct_prompt_json(
self, role_profile: str, role_name: str, context: str, reflection: str, experiences: str, **kwargs
):
prompt_template = self.PROMPT_TEMPLATE
def replace_string(prompt_json: dict):
k: str
for k in prompt_json.keys():
if isinstance(prompt_json[k], dict):
prompt_json[k] = replace_string(prompt_json[k])
continue
prompt_json[k] = prompt_json[k].replace("__profile__", role_profile)
prompt_json[k] = prompt_json[k].replace("__name__", role_name)
prompt_json[k] = prompt_json[k].replace("__context__", context)
prompt_json[k] = prompt_json[k].replace("__action__", self.name)
prompt_json[k] = prompt_json[k].replace("__strategy__", self.STRATEGY)
prompt_json[k] = prompt_json[k].replace("__reflection__", reflection)
prompt_json[k] = prompt_json[k].replace("__experiences__", experiences)
return prompt_json
prompt_json: dict = json.loads(prompt_template)
prompt_json = replace_string(prompt_json)
prompt_json: dict = self._update_prompt_json(
prompt_json, role_profile, role_name, context, reflection, experiences, **kwargs
)
assert isinstance(prompt_json, dict)
prompt: str = json.dumps(prompt_json, indent=4, ensure_ascii=False)
return prompt
def _update_prompt_json(
self, prompt_json: dict, role_profile: str, role_name: str, context: str, reflection: str, experiences: str
) -> dict:
# one can modify the prompt_json dictionary here
return prompt_json
@retry(stop=stop_after_attempt(2), wait=wait_fixed(1))
async def run(self, context: str, profile: str, name: str, reflection: str = "", experiences: str = ""):
prompt = self._construct_prompt_json(
role_profile=profile, role_name=name, context=context, reflection=reflection, experiences=experiences
)
rsp = await self._aask(prompt)
rsp_json = log_and_parse_json(self.name, rsp)
return f"{self.name} " + rsp_json["RESPONSE"]
class Reflect(Action):
PROMPT_TEMPLATE: str = """
{
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__. Note that villager, seer, guard and witch are all in villager side, they have the same objective. Werewolves can collectively hunt ONE player at night."
,"HISTORY": "You have knowledge to the following conversation: __context__"
,"MODERATOR_INSTRUCTION": __latest_instruction__,
,"OUTPUT_FORMAT" (a json):
{
"ROLE": "Your role, in this case, __profile__"
,"PLAYER_NAME": "Your name, in this case, __name__"
"GAME_STATES": "You are about to follow `MODERATOR_INSTRUCTION`, but before taking any action, analyze each player, including the living and the dead, and summarize the game states.
For each player, your reflection should be a ONE-LINE json covering the following dimension, return a LIST of jsons (return an empty LIST for the first night):
[
{"TARGET": "the player you will analyze, if the player is yourself or your werewolf partner, indicate it" ,"STATUS": "living or dead, if dead, how was he/she possibly killed?", "CLAIMED_ROLE": "claims a role or not, if so, what role, any contradiction to others? If there is no claim, return 'None'", "SIDE_WITH": "sides with which players? If none, return 'None'", "ACCUSE": "accuses which players? If none, return 'None'"}
,{...}
,...
]"
,"REFLECTION": "Based on the whole `GAME_STATES`, return a json (return an empty string for the first night):
{
"Player1": "the true role (werewolf / special role / villager, living or dead) you infer about him/her, and why is this role? If the player is yourself or your werewolf partner, indicate it."
,...
,"Player7": "the true role (werewolf / special role / villager, living or dead) you infer about him/her, and why is this role? If the player is yourself or your werewolf partner, indicate it."
,"GAME_STATE_SUMMARIZATION": "summarize the current situation from your standpoint in one sentence, your summarization should catch the most important information from your reflection, such as conflicts, number of living werewolves, special roles, and villagers."
}"
}
}
"""
name: str = "Reflect"
@retry(stop=stop_after_attempt(2), wait=wait_fixed(1))
async def run(self, profile: str, name: str, context: str, latest_instruction: str):
prompt = (
self.PROMPT_TEMPLATE.replace("__context__", context)
.replace("__profile__", profile)
.replace("__name__", name)
.replace("__latest_instruction__", latest_instruction)
)
rsp = await self._aask(prompt)
rsp_json = log_and_parse_json(self.name, rsp)
return json.dumps(rsp_json["REFLECTION"])

View file

@ -0,0 +1,162 @@
import json
from typing import Optional
import chromadb
from pydantic import model_validator
from metagpt.actions import Action
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.environment.werewolf.const import RoleType
from metagpt.ext.werewolf.schema import RoleExperience
from metagpt.logs import logger
from metagpt.rag.engines.simple import SimpleEngine
from metagpt.rag.schema import ChromaIndexConfig, ChromaRetrieverConfig
from metagpt.utils.common import read_json_file, write_json_file
DEFAULT_COLLECTION_NAME = "role_reflection" # FIXME: some hard code for now
PERSIST_PATH = DEFAULT_WORKSPACE_ROOT.joinpath("werewolf_game/chroma")
PERSIST_PATH.mkdir(parents=True, exist_ok=True)
class AddNewExperiences(Action):
name: str = "AddNewExperience"
collection_name: str = DEFAULT_COLLECTION_NAME
delete_existing: bool = False
engine: Optional[SimpleEngine] = None
@model_validator(mode="after")
def validate_collection(self):
if self.engine:
return
if self.delete_existing:
try:
# implement engine `DELETE` method later
chromadb.PersistentClient(PERSIST_PATH.as_posix()).delete_collection(self.collection_name)
except Exception as exp:
logger.error(f"delete chroma collection: {self.collection_name} failed, exp: {exp}")
self.engine = SimpleEngine.from_objs(
retriever_configs=[
ChromaRetrieverConfig(
persist_path=PERSIST_PATH, collection_name=self.collection_name, metadata={"hnsw:space": "cosine"}
)
]
)
def run(self, experiences: list[RoleExperience]):
if not experiences:
return
for i, exp in enumerate(experiences):
exp.id = f"{exp.profile}-{exp.name}-step{i}-round_{exp.round_id}"
AddNewExperiences._record_experiences_local(experiences)
self.engine.add_objs(experiences)
def add_from_file(self, file_path):
experiences = read_json_file(file_path)
experiences = [RoleExperience.model_validate(item) for item in experiences]
experiences = [exp for exp in experiences if len(exp.reflection) > 2] # not "" or not '""'
self.engine.add_objs(experiences)
@staticmethod
def _record_experiences_local(experiences: list[RoleExperience]):
round_id = experiences[0].round_id
version = experiences[0].version
version = "test" if not version else version
experiences = [exp.model_dump() for exp in experiences]
experience_path = DEFAULT_WORKSPACE_ROOT.joinpath(f"werewolf_game/experiences/{version}")
experience_path.mkdir(parents=True, exist_ok=True)
save_path = f"{experience_path}/{round_id}.json"
write_json_file(save_path, experiences)
logger.info(f"experiences saved to {save_path}")
class RetrieveExperiences(Action):
name: str = "RetrieveExperiences"
collection_name: str = DEFAULT_COLLECTION_NAME
has_experiences: bool = True
engine: Optional[SimpleEngine] = None
topk: int = 10
@model_validator(mode="after")
def validate_collection(self):
if self.engine:
return
try:
self.engine = SimpleEngine.from_index(
index_config=ChromaIndexConfig(
persist_path=PERSIST_PATH, collection_name=self.collection_name, metadata={"hnsw:space": "cosine"}
),
retriever_configs=[
ChromaRetrieverConfig(
similarity_top_k=self.topk,
persist_path=PERSIST_PATH,
collection_name=self.collection_name,
metadata={"hnsw:space": "cosine"},
)
],
)
except Exception as exp:
logger.warning(f"No experience pool: {self.collection_name}, exp: {exp}")
def run(self, query: str, profile: str, excluded_version: str = "", verbose: bool = False) -> str:
"""_summary_
Args:
query (str): 用当前的reflection作为query去检索过去相似的reflection
profile (str): _description_
Returns:
_type_: _description_
"""
if not self.engine or len(query) <= 2: # not "" or not '""'
logger.warning("engine is None or query too short")
return ""
# ablation experiment logic
if profile == RoleType.WEREWOLF.value: # role werewolf as baseline, don't use experiences
logger.warning("Disable werewolves' experiences")
return ""
results = self.engine.retrieve(query)
logger.info(f"retrieve {profile}'s experiences")
experiences = [res.metadata["obj"] for res in results]
past_experiences = [] # currently use post-process to filter, and later add `filters` in rag
for exp in experiences:
if exp.profile == profile and exp.version != excluded_version:
past_experiences.append(exp)
if verbose and results:
logger.info("past_experiences: {}".format("\n\n".join(past_experiences)))
distances = results[0].score
logger.info(f"distances: {distances}")
template = """
{
"Situation __i__": "__situation__"
,"Moderator's instruction": "__instruction__"
,"Your action or speech during that time": "__response__"
,"Reality": "In fact, it turned out the true roles are __game_step__",
,"Outcome": "You __outcome__ in the end"
}
"""
past_experiences = [
(
template.replace("__i__", str(i))
.replace("__situation__", exp.reflection)
.replace("__instruction__", exp.instruction)
.replace("__response__", exp.response)
.replace("__game_step__", exp.game_setup.replace("0 | Game setup:\n", "").replace("\n", " "))
.replace("__outcome__", exp.outcome)
)
for i, exp in enumerate(past_experiences)
]
logger.info("past_experiences: {}".format("\n".join(past_experiences)))
logger.info("retrieval done")
return json.dumps(past_experiences)

View file

@ -0,0 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from metagpt.ext.werewolf.actions.common_actions import NighttimeWhispers
class Protect(NighttimeWhispers):
name: str = "Protect"

View file

@ -0,0 +1,39 @@
from metagpt.actions import Action
from metagpt.environment.werewolf.const import STEP_INSTRUCTIONS
class InstructSpeak(Action):
name: str = "InstructSpeak"
async def run(self, step_idx, living_players, werewolf_players, player_hunted, player_current_dead):
instruction_info = STEP_INSTRUCTIONS.get(
step_idx, {"content": "Unknown instruction.", "send_to": {}, "restricted_to": {}}
)
content = instruction_info["content"]
if "{living_players}" in content and "{werewolf_players}" in content:
content = content.format(
living_players=living_players, werewolf_players=werewolf_players, werewolf_num=len(werewolf_players)
)
if "{living_players}" in content:
content = content.format(living_players=living_players)
if "{werewolf_players}" in content:
content = content.format(werewolf_players=werewolf_players)
if "{player_hunted}" in content:
content = content.format(player_hunted=player_hunted)
if "{player_current_dead}" in content:
player_current_dead = "No one" if not player_current_dead else player_current_dead
content = content.format(player_current_dead=player_current_dead)
return content, instruction_info["send_to"], instruction_info["restricted_to"]
class ParseSpeak(Action):
name: str = "ParseSpeak"
async def run(self):
pass
class AnnounceGameResult(Action):
async def run(self, winner: str, win_reason: str):
return f"Game over! {win_reason}. The winner is the {winner}"

View file

@ -0,0 +1,5 @@
from metagpt.ext.werewolf.actions.common_actions import NighttimeWhispers
class Verify(NighttimeWhispers):
name: str = "Verify"

View file

@ -0,0 +1,17 @@
from metagpt.ext.werewolf.actions.common_actions import NighttimeWhispers, Speak
class Hunt(NighttimeWhispers):
name: str = "Hunt"
class Impersonate(Speak):
"""Action: werewolf impersonating a good guy in daytime speak"""
STRATEGY: str = """
Try continuously impersonating a role, such as Seer, Guard, Villager, etc., in order to mislead
other players, make them trust you, and thus hiding your werewolf identity. However, pay attention to what your werewolf partner said,
DONT claim the same role as your werewolf partner. Remmber NOT to reveal your real identity as a werewolf!
"""
name: str = "Impersonate"

View file

@ -0,0 +1,47 @@
from metagpt.environment.werewolf.const import RoleActionRes
from metagpt.ext.werewolf.actions.common_actions import NighttimeWhispers
class Save(NighttimeWhispers):
name: str = "Save"
def _update_prompt_json(
self, prompt_json: dict, role_profile: str, role_name: str, context: str, reflection: str, experiences: str
) -> dict:
del prompt_json["ACTION"]
del prompt_json["ATTENTION"]
prompt_json["OUTPUT_FORMAT"][
"THOUGHTS"
] = "It is night time. Return the thinking steps of your decision of whether to save the player JUST killed this night."
prompt_json["OUTPUT_FORMAT"][
"RESPONSE"
] = "Follow the Moderator's instruction, decide whether you want to save that person or not. Return SAVE or PASS."
return prompt_json
async def run(self, *args, **kwargs):
rsp = await super().run(*args, **kwargs)
action_name, rsp = rsp.split()
return rsp # 只需回复SAVE或PASS不需要带上action名
class Poison(NighttimeWhispers):
STRATEGY: str = """
Only poison a player if you are confident he/she is a werewolf. Don't poison a player randomly or at first night.
If someone claims to be the witch, poison him/her, because you are the only witch, he/she can only be a werewolf.
"""
name: str = "Poison"
def _update_prompt_json(
self, prompt_json: dict, role_profile: str, role_name: str, context: str, reflection: str, experiences: str
) -> dict:
prompt_json["OUTPUT_FORMAT"]["RESPONSE"] += "Or if you want to PASS, return PASS."
return prompt_json
async def run(self, *args, **kwargs):
rsp = await super().run(*args, **kwargs)
if RoleActionRes.PASS.value in rsp.lower():
action_name, rsp = rsp.split() # 带PASS只需回复PASS不需要带上action名否则是Poison PlayerX无需改动
return rsp

View file

@ -0,0 +1,13 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from metagpt.ext.werewolf.roles.base_player import BasePlayer
from metagpt.ext.werewolf.roles.guard import Guard
from metagpt.ext.werewolf.roles.seer import Seer
from metagpt.ext.werewolf.roles.villager import Villager
from metagpt.ext.werewolf.roles.werewolf import Werewolf
from metagpt.ext.werewolf.roles.witch import Witch
from metagpt.ext.werewolf.roles.moderator import Moderator
__all__ = ["BasePlayer", "Guard", "Moderator", "Seer", "Villager", "Witch", "Werewolf"]

View file

@ -0,0 +1,176 @@
import re
from pydantic import Field, SerializeAsAny, model_validator
from metagpt.actions.action import Action
from metagpt.environment.werewolf.const import RoleState, RoleType
from metagpt.ext.werewolf.actions import (
ACTIONS,
AddNewExperiences,
InstructSpeak,
NighttimeWhispers,
Reflect,
RetrieveExperiences,
Speak,
)
from metagpt.ext.werewolf.schema import RoleExperience, WwMessage
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.utils.common import any_to_str
class BasePlayer(Role):
name: str = "PlayerXYZ"
profile: str = "BasePlayer"
special_action_names: list[str] = []
use_reflection: bool = True
use_experience: bool = False
use_memory_selection: bool = False
new_experience_version: str = ""
status: RoleState = RoleState.ALIVE
special_actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
experiences: list[RoleExperience] = []
def __init__(self, **kwargs):
super().__init__(**kwargs)
# 技能和监听配置
self._watch([InstructSpeak]) # 监听Moderator的指令以做行动
special_actions = [ACTIONS[action_name] for action_name in self.special_action_names]
capable_actions = [Speak] + special_actions
self.set_actions(capable_actions) # 给角色赋予行动技能
self.special_actions = special_actions
if not self.use_reflection and self.use_experience:
logger.warning("You must enable use_reflection before using experience")
self.use_experience = False
@model_validator(mode="after")
def check_addresses(self):
if not self.addresses:
self.addresses = {any_to_str(self), self.name, self.profile} if self.name else {any_to_str(self)}
return self
async def _observe(self, ignore_memory=False) -> int:
if self.status != RoleState.ALIVE:
# 死者不再参与游戏
return 0
news = []
if not news:
news = self.rc.msg_buffer.pop_all()
old_messages = [] if ignore_memory else self.rc.memory.get()
for m in news:
if len(m.restricted_to) and self.profile not in m.restricted_to and self.name not in m.restricted_to:
# if the msg is not send to the whole audience ("") nor this role (self.profile or self.name),
# then this role should not be able to receive it and record it into its memory
continue
self.rc.memory.add(m)
self.rc.news = [
n for n in news if (n.cause_by in self.rc.watch or self.profile in n.send_to) and n not in old_messages
]
# TODO to delete
# await super()._observe()
# # 只有发给全体的(""或发给自己的self.profile消息需要走下面的_react流程
# # 其他的收听到即可,不用做动作
# self.rc.news = [msg for msg in self.rc.news if msg.send_to in ["", self.profile]]
return len(self.rc.news)
async def _think(self):
news = self.rc.news[0]
assert news.cause_by == any_to_str(InstructSpeak) # 消息为来自Moderator的指令时才去做动作
if not news.restricted_to:
# 消息接收范围为全体角色的,做公开发言(发表投票观点也算发言)
self.rc.todo = Speak()
elif self.profile in news.restricted_to:
# FIXME: hard code to split, restricted为"Moderator"或"Moderator, 角色profile"
# Moderator加密发给自己的意味着要执行角色的特殊动作
self.rc.todo = self.special_actions[0]()
async def _act(self):
# todo为_think时确定的有两种情况Speak或Protect
todo = self.rc.todo
logger.info(f"{self._setting}: ready to {str(todo)}")
# 可以用这个函数获取该角色的全部记忆和最新的instruction
memories = self.get_all_memories()
latest_instruction = self.get_latest_instruction()
reflection = (
await Reflect().run(
profile=self.profile, name=self.name, context=memories, latest_instruction=latest_instruction
)
if self.use_reflection
else ""
)
experiences = (
RetrieveExperiences().run(
query=reflection, profile=self.profile, excluded_version=self.new_experience_version
)
if self.use_experience
else ""
)
# 根据自己定义的角色Action对应地去runrun的入参可能不同
if isinstance(todo, Speak):
rsp = await todo.run(
profile=self.profile,
name=self.name,
context=memories,
latest_instruction=latest_instruction,
reflection=reflection,
experiences=experiences,
)
restricted_to = set()
elif isinstance(todo, NighttimeWhispers):
rsp = await todo.run(
profile=self.profile, name=self.name, context=memories, reflection=reflection, experiences=experiences
)
restricted_to = {RoleType.MODERATOR.value, self.profile} # 给Moderator发送使用特殊技能的加密消息
msg = WwMessage(
content=rsp,
role=self.profile,
sent_from=self.name,
cause_by=type(todo),
send_to={},
restricted_to=restricted_to,
)
self.experiences.append(
RoleExperience(
name=self.name,
profile=self.profile,
reflection=reflection,
instruction=latest_instruction,
response=rsp,
version=self.new_experience_version,
)
)
logger.info(f"{self._setting}: {rsp}")
return msg
def get_all_memories(self) -> str:
memories = self.rc.memory.get()
time_stamp_pattern = r"[0-9]+ \| "
# NOTE: 除Moderator外其他角色使用memory只能用m.sent_from玩家名不能用m.role玩家角色因为他们不知道说话者的身份
memories = [f"{m.sent_from}: {re.sub(time_stamp_pattern, '', m.content)}" for m in memories] # regex去掉时间戳
memories = "\n".join(memories)
return memories
def get_latest_instruction(self) -> str:
return self.rc.important_memory[-1].content # 角色监听着Moderator的InstructSpeak是其重要记忆直接获取即可
def set_status(self, new_status: RoleState):
self.status = new_status
def record_experiences(self, round_id: str, outcome: str, game_setup: str):
experiences = [exp for exp in self.experiences if len(exp.reflection) > 2] # not "" or not '""'
for exp in experiences:
exp.round_id = round_id
exp.outcome = outcome
exp.game_setup = game_setup
AddNewExperiences().run(experiences)

View file

@ -0,0 +1,8 @@
from metagpt.environment.werewolf.const import RoleType
from metagpt.ext.werewolf.roles.base_player import BasePlayer
class Guard(BasePlayer):
name: str = RoleType.GUARD.value
profile: str = RoleType.GUARD.value
special_action_names: list[str] = ["Protect"]

View file

@ -0,0 +1,45 @@
from metagpt.environment.werewolf.const import RoleType
from metagpt.ext.werewolf.actions import Speak
from metagpt.ext.werewolf.roles import BasePlayer
from metagpt.ext.werewolf.schema import WwMessage
from metagpt.logs import logger
async def _act(self):
todo = self.rc.todo
memories = self.get_all_memories()
input_instruction = f"""
## As a reminder, you have access to the following game history:
{memories}
## You are {self.name}({self.profile})
## Guidance:
1. If you are performing a special action or exercising a vote,
end your response with "PlayerX", replace PlayerX with the actual player name, e.g., "..., kill/protect/poison/.../vote Player1".
2. If it is a daytime free speech, you can speak in whatever format.
Now, please speak:
"""
rsp = input(input_instruction) # wait for human input
msg_cause_by = type(todo)
msg_restricted_to = {} if isinstance(todo, Speak) else {RoleType.MODERATOR.value, self.profile}
msg = WwMessage(
content=rsp,
role=self.profile,
sent_from=self.name,
cause_by=msg_cause_by,
send_to={},
restricted_to=msg_restricted_to, # 给Moderator及自身阵营发送加密消息
)
logger.info(f"{self._setting}: {rsp}")
return msg
def prepare_human_player(player_class: BasePlayer):
# Dynamically define a human player class that inherits from a certain role class
HumanPlayer = type("HumanPlayer", (player_class,), {"_act": _act})
return HumanPlayer

View file

@ -0,0 +1,251 @@
import re
from datetime import datetime
from typing import Union
from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import DEFAULT_WORKSPACE_ROOT, MESSAGE_ROUTE_TO_ALL
from metagpt.environment.werewolf.const import (
STEP_INSTRUCTIONS,
RoleActionRes,
RoleState,
RoleType,
)
from metagpt.environment.werewolf.env_space import EnvAction, EnvActionType
from metagpt.ext.werewolf.actions import Hunt, Poison, Protect, Save, Verify
from metagpt.ext.werewolf.actions.moderator_actions import (
AnnounceGameResult,
InstructSpeak,
ParseSpeak,
)
from metagpt.ext.werewolf.roles.base_player import BasePlayer
from metagpt.ext.werewolf.schema import WwMessage
from metagpt.logs import logger
from metagpt.utils.common import any_to_str
class Moderator(BasePlayer):
name: str = RoleType.MODERATOR.value
profile: str = RoleType.MODERATOR.value
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._watch([UserRequirement, InstructSpeak, ParseSpeak])
self.set_actions([InstructSpeak, ParseSpeak, AnnounceGameResult])
# game states
self.step_idx = 0
self.game_setup = ""
self.werewolf_players = []
self.winner = None
self.win_reason = None
self.witch_poison_left = 1
self.witch_antidote_left = 1
def update_player_status(self, player_names: list[str]):
if not player_names:
return
roles_in_env = self.rc.env.get_roles()
for role_setting, role in roles_in_env.items():
for player_name in player_names:
if player_name in role_setting:
role.set_status(new_status=RoleState.DEAD) # 更新为死亡
def _record_all_experiences(self):
logger.info(f"The winner of the game: {self.winner}, start to record roles' experiences")
roles_in_env = self.rc.env.get_roles()
timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
for _, role in roles_in_env.items():
if role == self:
continue
if self.winner == "werewolf":
outcome = "won" if role.profile in RoleType.WEREWOLF.value else "lost"
else:
outcome = "won" if role.profile not in RoleType.WEREWOLF.value else "lost"
role.record_experiences(round_id=timestamp, outcome=outcome, game_setup=self.game_setup)
async def _parse_speak(self, memories):
latest_msg = memories[-1]
latest_msg_content = latest_msg.content
match = re.search(r"Player[0-9]+", latest_msg_content[-10:]) # FIXME: hard code truncation
target = match.group(0) if match else ""
# default return
msg_content = "Understood"
restricted_to = set()
msg_cause_by = latest_msg.cause_by
if msg_cause_by == any_to_str(Hunt):
self.rc.env.step(
EnvAction(
action_type=EnvActionType.WOLF_KILL, player_name=latest_msg.sent_from, target_player_name=target
)
)
elif msg_cause_by == any_to_str(Protect):
self.rc.env.step(
EnvAction(
action_type=EnvActionType.GUARD_PROTECT, player_name=latest_msg.sent_from, target_player_name=target
)
)
elif msg_cause_by == any_to_str(Verify):
if target in self.werewolf_players:
msg_content = f"{target} is a werewolf"
else:
msg_content = f"{target} is a good guy"
restricted_to = {RoleType.MODERATOR.value, RoleType.SEER.value}
elif msg_cause_by == any_to_str(Save):
if RoleActionRes.PASS.value in latest_msg_content.lower():
# the role ignore to response, answer `pass`
pass
elif not self.witch_antidote_left:
msg_content = "You have no antidote left and thus can not save the player"
restricted_to = {RoleType.MODERATOR.value, RoleType.WITCH.value}
else:
self.rc.env.step(
EnvAction(
action_type=EnvActionType.WITCH_SAVE,
player_name=latest_msg.sent_from,
target_player_name=target,
)
)
elif msg_cause_by == any_to_str(Poison):
if RoleActionRes.PASS.value in latest_msg_content.lower():
pass
elif not self.witch_poison_left:
msg_content = "You have no poison left and thus can not poison the player"
restricted_to = {RoleType.MODERATOR.value, RoleType.WITCH.value}
else:
self.rc.env.step(
EnvAction(
action_type=EnvActionType.WITCH_POISON,
player_name=latest_msg.sent_from,
target_player_name=target,
)
)
return msg_content, restricted_to
def _update_player_status(self, step_idx: int, player_current_dead: list[str]):
"""update dead player's status"""
if step_idx in [15, 18]:
self.update_player_status(player_current_dead)
def _record_game_history(self, step_idx: int):
if step_idx and step_idx % len(STEP_INSTRUCTIONS) == 0 or self.winner:
logger.info("a night and day cycle completed, examine all history")
logger.debug(f"all_memories: {self.get_all_memories()}")
with open(DEFAULT_WORKSPACE_ROOT / "werewolf_transcript.txt", "w") as f:
f.write(self.get_all_memories())
async def _observe(self, ignore_memory=False) -> int:
news = []
if not news:
news = self.rc.msg_buffer.pop_all()
old_messages = [] if ignore_memory else self.rc.memory.get()
for m in news:
if len(m.restricted_to) and self.profile not in m.restricted_to and self.name not in m.restricted_to:
# if the msg is not send to the whole audience ("") nor this role (self.profile or self.name),
# then this role should not be able to receive it and record it into its memory
continue
self.rc.memory.add(m)
# add `MESSAGE_ROUTE_TO_ALL in n.send_to` make it to run `ParseSpeak`
self.rc.news = [
n
for n in news
if (n.cause_by in self.rc.watch or self.profile in n.send_to or MESSAGE_ROUTE_TO_ALL in n.send_to)
and n not in old_messages
]
return len(self.rc.news)
async def _think(self):
if self.winner:
self.rc.todo = AnnounceGameResult()
return
latest_msg = self.rc.memory.get()[-1]
if latest_msg.role in ["User", "Human", self.profile]:
# 1. 上一轮消息是用户指令,解析用户指令,开始游戏
# 2.1. 上一轮消息是Moderator自己的指令继续发出指令一个事情可以分几条消息来说
# 2.2. 上一轮消息是Moderator自己的解析消息一个阶段结束发出新一个阶段的指令
self.rc.todo = InstructSpeak()
else:
# 上一轮消息是游戏角色的发言,解析角色的发言
self.rc.todo = ParseSpeak()
def _init_fields_from_obj(self, obs: dict[str, Union[int, str, list[str]]]):
self.game_setup = obs.get("game_setup", "")
self.step_idx = obs.get("step_idx", 0)
self.winner = obs.get("winner")
self.win_reason = obs.get("win_reason")
self.werewolf_players = obs.get("werewolf_players", [])
self.witch_poison_left = obs.get("witch_poison_left", 0)
self.witch_antidote_left = obs.get("witch_antidote_left", 0)
async def _act(self):
todo = self.rc.todo
logger.info(f"{self._setting} ready to {todo}")
memories = self.get_all_memories(mode="msg")
obs, _, _, _, _ = self.rc.env.step(action=EnvAction(action_type=EnvActionType.NONE))
living_players = obs["living_players"]
werewolf_players = obs["werewolf_players"]
player_hunted = obs["player_hunted"]
player_current_dead = obs["player_current_dead"]
self._init_fields_from_obj(obs)
# 若进行完一夜一日的循环,打印和记录一次完整发言历史
self._record_game_history(self.step_idx)
# 若一晚或一日周期结束,对当晚或当日的死者进行总结,并更新玩家状态
self._update_player_status(self.step_idx, player_current_dead)
if self.winner:
self._record_all_experiences()
# 根据_think的结果执行InstructSpeak还是ParseSpeak, 并将结果返回
if isinstance(todo, InstructSpeak):
msg_content, msg_to_send_to, msg_restricted_to = await InstructSpeak().run(
self.step_idx,
living_players=living_players,
werewolf_players=werewolf_players,
player_hunted=player_hunted,
player_current_dead=player_current_dead,
)
# msg_content = f"Step {self.step_idx}: {msg_content}" # HACK: 加一个unique的step_idx避免记忆的自动去重
msg = WwMessage(
content=msg_content,
role=self.profile,
sent_from=self.name,
cause_by=InstructSpeak,
send_to=msg_to_send_to,
restricted_to=msg_restricted_to,
)
logger.info(f"current step_idx: {self.step_idx}")
self.rc.env.step(EnvAction(action_type=EnvActionType.PROGRESS_STEP)) # to update step_idx
elif isinstance(todo, ParseSpeak):
msg_content, msg_restricted_to = await self._parse_speak(memories)
# msg_content = f"Step {self.step_idx}: {msg_content}" # HACK: 加一个unique的step_idx避免记忆的自动去重
msg = WwMessage(
content=msg_content,
role=self.profile,
sent_from=self.name,
cause_by=ParseSpeak,
send_to={},
restricted_to=msg_restricted_to,
)
elif isinstance(todo, AnnounceGameResult):
msg_content = await AnnounceGameResult().run(winner=self.winner, win_reason=self.win_reason)
msg = WwMessage(content=msg_content, role=self.profile, sent_from=self.name, cause_by=AnnounceGameResult)
logger.info(f"{self._setting}: {msg_content}")
return msg
def get_all_memories(self, mode="str") -> str:
memories = self.rc.memory.get()
if mode == "str":
memories = [f"{m.sent_from}({m.role}): {m.content}" for m in memories]
memories = "\n".join(memories)
return memories

View file

@ -0,0 +1,8 @@
from metagpt.environment.werewolf.const import RoleType
from metagpt.ext.werewolf.roles.base_player import BasePlayer
class Seer(BasePlayer):
name: str = RoleType.SEER.value
profile: str = RoleType.SEER.value
special_action_names: list[str] = ["Verify"]

View file

@ -0,0 +1,8 @@
from metagpt.environment.werewolf.const import RoleType
from metagpt.ext.werewolf.roles.base_player import BasePlayer
class Villager(BasePlayer):
name: str = RoleType.VILLAGER.value
profile: str = RoleType.VILLAGER.value
special_action_names: list[str] = []

View file

@ -0,0 +1,15 @@
from metagpt.environment.werewolf.const import RoleType
from metagpt.ext.werewolf.actions import Impersonate, Speak
from metagpt.ext.werewolf.roles.base_player import BasePlayer
class Werewolf(BasePlayer):
name: str = RoleType.WEREWOLF.value
profile: str = RoleType.WEREWOLF.value
special_action_names: list[str] = ["Hunt"]
async def _think(self):
"""狼人白天发言时需要伪装与其他角色不同因此需要重写_think"""
await super()._think()
if isinstance(self.rc.todo, Speak):
self.rc.todo = Impersonate()

View file

@ -0,0 +1,28 @@
from metagpt.environment.werewolf.const import RoleType
from metagpt.ext.werewolf.actions import InstructSpeak, Poison, Save, Speak
from metagpt.ext.werewolf.roles.base_player import BasePlayer
from metagpt.utils.common import any_to_str
class Witch(BasePlayer):
name: str = RoleType.WITCH.value
profile: str = RoleType.WITCH.value
special_action_names: list[str] = ["Save", "Poison"]
async def _think(self):
"""女巫涉及两个特殊技能因此在此需要改写_think进行路由"""
news = self.rc.news[0]
assert news.cause_by == any_to_str(InstructSpeak) # 消息为来自Moderator的指令时才去做动作
if not news.restricted_to:
# 消息接收范围为全体角色的,做公开发言(发表投票观点也算发言)
self.rc.todo = Speak()
elif self.profile in news.restricted_to:
# FIXME: hard code to split, restricted为"Moderator"或"Moderator,角色profile"
# Moderator加密发给自己的意味着要执行角色的特殊动作
# 这里用关键词进行动作的选择需要Moderator侧的指令进行配合
if "save" in news.content.lower():
self.rc.todo = Save()
elif "poison" in news.content.lower():
self.rc.todo = Poison()
else:
raise ValueError("Moderator's instructions must include save or poison keyword")

View file

@ -0,0 +1,33 @@
from typing import Any
from pydantic import BaseModel, Field, field_validator
from metagpt.schema import Message
from metagpt.utils.common import any_to_str_set
class RoleExperience(BaseModel):
id: str = ""
name: str = ""
profile: str
reflection: str
instruction: str = ""
response: str
outcome: str = ""
round_id: str = ""
game_setup: str = ""
version: str = ""
def rag_key(self) -> str:
"""For search"""
return self.reflection
class WwMessage(Message):
# Werewolf Message
restricted_to: set[str] = Field(default=set(), validate_default=True)
@field_validator("restricted_to", mode="before")
@classmethod
def check_restricted_to(cls, restricted_to: Any):
return any_to_str_set(restricted_to if restricted_to else set())

View file

@ -0,0 +1,28 @@
from typing import Any, Optional
from metagpt.actions.add_requirement import UserRequirement
from metagpt.context import Context
from metagpt.environment.werewolf.werewolf_env import WerewolfEnv
from metagpt.ext.werewolf.schema import WwMessage
from metagpt.team import Team
class WerewolfGame(Team):
"""Use the "software company paradigm" to hold a werewolf game"""
env: Optional[WerewolfEnv] = None
def __init__(self, context: Context = None, **data: Any):
super(Team, self).__init__(**data)
ctx = context or Context()
if not self.env:
self.env = WerewolfEnv(context=ctx)
else:
self.env.context = ctx # The `env` object is allocated by deserialization
def run_project(self, idea):
"""Run a project from user instruction."""
self.idea = idea
self.env.publish_message(
WwMessage(role="User", content=idea, cause_by=UserRequirement, restricted_to={"Moderator"})
)

View file

@ -161,6 +161,13 @@ class SimpleEngine(RetrieverQueryEngine):
"""Inplement tools.SearchInterface"""
return await self.aquery(content)
def retrieve(self, query: QueryType) -> list[NodeWithScore]:
query_bundle = QueryBundle(query) if isinstance(query, str) else query
nodes = super().retrieve(query_bundle)
self._try_reconstruct_obj(nodes)
return nodes
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Allow query to be str."""
query_bundle = QueryBundle(query) if isinstance(query, str) else query

View file

@ -48,7 +48,7 @@ class RAGIndexFactory(ConfigBasedFactory):
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)

View file

@ -69,7 +69,7 @@ class RetrieverFactory(ConfigBasedFactory):
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)

View file

@ -1,8 +1,9 @@
"""RAG schemas."""
from pathlib import Path
from typing import Any, ClassVar, Literal, Union
from typing import Any, ClassVar, Literal, Optional, Union
from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
@ -59,6 +60,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
collection_name: str = Field(default="metagpt", description="The name of the collection.")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
class ElasticsearchStoreConfig(BaseModel):
@ -144,6 +148,9 @@ class ChromaIndexConfig(VectorIndexConfig):
"""Config for chroma-based index."""
collection_name: str = Field(default="metagpt", description="The name of the collection.")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
class BM25IndexConfig(BaseIndexConfig):

View file

@ -335,6 +335,11 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
self.llm.cost_manager = self.context.cost_manager
self.set_actions(self.actions) # reset actions to update llm and prefix
@property
def name(self):
"""Get the role name"""
return self._setting.name
def _get_prefix(self):
"""Get the role prefix"""
if self.desc:

View file

@ -722,7 +722,10 @@ def list_files(root: str | Path) -> List[Path]:
def parse_json_code_block(markdown_text: str) -> List[str]:
json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL)
json_blocks = (
re.findall(r"```json(.*?)```", markdown_text, re.DOTALL) if "```json" in markdown_text else [markdown_text]
)
return [v.strip() for v in json_blocks]

View file

@ -2,33 +2,34 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of WerewolfExtEnv
from metagpt.environment.werewolf.werewolf_ext_env import RoleState, WerewolfExtEnv
from metagpt.environment.werewolf.const import RoleState, RoleType
from metagpt.environment.werewolf.werewolf_ext_env import WerewolfExtEnv
from metagpt.roles.role import Role
class Werewolf(Role):
profile: str = "Werewolf"
profile: str = RoleType.WEREWOLF.value
class Villager(Role):
profile: str = "Villager"
profile: str = RoleType.VILLAGER.value
class Witch(Role):
profile: str = "Witch"
profile: str = RoleType.WITCH.value
class Guard(Role):
profile: str = "Guard"
profile: str = RoleType.GUARD.value
def test_werewolf_ext_env():
players_state = {
"Player0": ("Werewolf", RoleState.ALIVE),
"Player1": ("Werewolf", RoleState.ALIVE),
"Player2": ("Villager", RoleState.ALIVE),
"Player3": ("Witch", RoleState.ALIVE),
"Player4": ("Guard", RoleState.ALIVE),
"Player0": (RoleType.WEREWOLF.value, RoleState.ALIVE),
"Player1": (RoleType.WEREWOLF.value, RoleState.ALIVE),
"Player2": (RoleType.VILLAGER.value, RoleState.ALIVE),
"Player3": (RoleType.WITCH.value, RoleState.ALIVE),
"Player4": (RoleType.GUARD.value, RoleState.ALIVE),
}
ext_env = WerewolfExtEnv(players_state=players_state, step_idx=4, special_role_players=["Player3", "Player4"])
@ -41,9 +42,9 @@ def test_werewolf_ext_env():
assert "Werewolves, please open your eyes" in curr_instr["content"]
# current step_idx = 5
ext_env.wolf_kill_someone(wolf=Role(name="Player10"), player_name="Player4")
ext_env.wolf_kill_someone(wolf=Werewolf(name="Player0"), player_name="Player4")
ext_env.wolf_kill_someone(wolf=Werewolf(name="Player1"), player_name="Player4")
ext_env.wolf_kill_someone(wolf_name="Player10", player_name="Player4")
ext_env.wolf_kill_someone(wolf_name="Player0", player_name="Player4")
ext_env.wolf_kill_someone(wolf_name="Player1", player_name="Player4")
assert ext_env.player_hunted == "Player4"
assert len(ext_env.living_players) == 5 # hunted but can be saved by witch
@ -52,11 +53,11 @@ def test_werewolf_ext_env():
# current step_idx = 18
assert ext_env.step_idx == 18
ext_env.vote_kill_someone(voteer=Werewolf(name="Player0"), player_name="Player2")
ext_env.vote_kill_someone(voteer=Werewolf(name="Player1"), player_name="Player3")
ext_env.vote_kill_someone(voteer=Villager(name="Player2"), player_name="Player3")
ext_env.vote_kill_someone(voteer=Witch(name="Player3"), player_name="Player4")
ext_env.vote_kill_someone(voteer=Guard(name="Player4"), player_name="Player2")
ext_env.vote_kill_someone(voter_name="Player0", player_name="Player2")
ext_env.vote_kill_someone(voter_name="Player1", player_name="Player3")
ext_env.vote_kill_someone(voter_name="Player2", player_name="Player3")
ext_env.vote_kill_someone(voter_name="Player3", player_name="Player4")
ext_env.vote_kill_someone(voter_name="Player4", player_name="Player2")
assert ext_env.player_current_dead == "Player2"
assert len(ext_env.living_players) == 4

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,164 @@
import json
import pytest
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.ext.werewolf.actions import AddNewExperiences, RetrieveExperiences
from metagpt.ext.werewolf.schema import RoleExperience
from metagpt.logs import logger
class TestExperiencesOperation:
collection_name = "test"
test_round_id = "test_01"
version = "test"
samples_to_add = [
RoleExperience(
profile="Witch",
reflection="The game is intense with two players claiming to be the Witch and one claiming to be the Seer. "
"Player4's behavior is suspicious.",
response="",
outcome="",
round_id=test_round_id,
version=version,
),
RoleExperience(
profile="Witch",
reflection="The game is in a critical state with only three players left, "
"and I need to make a wise decision to save Player7 or not.",
response="",
outcome="",
round_id=test_round_id,
version=version,
),
RoleExperience(
profile="Seer",
reflection="Player1, who is a werewolf, falsely claimed to be a Seer, and Player6, who might be a Witch, "
"sided with him. I, as the real Seer, am under suspicion.",
response="",
outcome="",
round_id=test_round_id,
version=version,
),
RoleExperience(
profile="TestRole",
reflection="Some test reflection1",
response="",
outcome="",
round_id=test_round_id,
version=version + "_01-10",
),
RoleExperience(
profile="TestRole",
reflection="Some test reflection2",
response="",
outcome="",
round_id=test_round_id,
version=version + "_11-20",
),
RoleExperience(
profile="TestRole",
reflection="Some test reflection3",
response="",
outcome="",
round_id=test_round_id,
version=version + "_21-30",
),
]
@pytest.mark.asyncio
async def test_add(self):
saved_file = DEFAULT_WORKSPACE_ROOT.joinpath(
f"werewolf_game/experiences/{self.version}/{self.test_round_id}.json"
)
if saved_file.exists():
saved_file.unlink()
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
action.run(self.samples_to_add)
# test insertion
inserted = action.engine.retriever._index._vector_store._collection.get()
assert len(inserted["documents"]) == len(self.samples_to_add)
# test if we record the samples correctly to local file
# & test if we could recover a embedding db from the file
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
action.add_from_file(saved_file)
inserted = action.engine.retriever._index._vector_store._collection.get()
assert len(inserted["documents"]) == len(self.samples_to_add)
@pytest.mark.asyncio
async def test_retrieve(self):
action = RetrieveExperiences(collection_name=self.collection_name)
query = "one player claimed to be Seer and the other Witch"
results = action.run(query, profile="Witch")
results = json.loads(results)
assert len(results) == 2, "Witch should have 2 experiences"
assert "The game is intense with two players" in results[0]
@pytest.mark.asyncio
async def test_retrieve_filtering(self):
action = RetrieveExperiences(collection_name=self.collection_name)
query = "some test query"
profile = "TestRole"
excluded_version = ""
results = action.run(query, profile=profile, excluded_version=excluded_version)
results = json.loads(results)
assert len(results) == 3
excluded_version = self.version + "_21-30"
results = action.run(query, profile=profile, excluded_version=excluded_version)
results = json.loads(results)
assert len(results) == 2
class TestActualRetrieve:
collection_name = "role_reflection"
@pytest.mark.asyncio
async def test_check_experience_pool(self):
logger.info("check experience pool")
action = RetrieveExperiences(collection_name=self.collection_name)
if action.engine:
all_experiences = action.engine.retriever._index._vector_store._collection.get()
logger.info(f"{len(all_experiences['metadatas'])=}")
@pytest.mark.asyncio
async def test_retrieve_werewolf_experience(self):
action = RetrieveExperiences(collection_name=self.collection_name)
query = "there are conflicts"
logger.info(f"test retrieval with {query=}")
action.run(query, "Werewolf")
@pytest.mark.asyncio
async def test_retrieve_villager_experience(self):
action = RetrieveExperiences(collection_name=self.collection_name)
query = "there are conflicts"
logger.info(f"test retrieval with {query=}")
results = action.run(query, "Seer")
assert "conflict" not in results # 相似局面应该需要包含conflict关键词
@pytest.mark.asyncio
async def test_retrieve_villager_experience_filtering(self):
action = RetrieveExperiences(collection_name=self.collection_name)
query = "there are conflicts"
excluded_version = "01-10"
logger.info(f"test retrieval with {excluded_version=}")
results_01_10 = action.run(query, profile="Seer", excluded_version=excluded_version, verbose=True)
excluded_version = "11-20"
logger.info(f"test retrieval with {excluded_version=}")
results_11_20 = action.run(query, profile="Seer", excluded_version=excluded_version, verbose=True)
assert results_01_10 == results_11_20

View file

@ -29,7 +29,7 @@ def test_add_role(env: Environment):
name="Alice", profile="product manager", goal="create a new product", constraints="limited resources"
)
env.add_role(role)
assert env.get_role(role.profile) == role
assert env.get_role(str(role._setting)) == role
def test_get_roles(env: Environment):