commit app
- .gitattributes +1 -0
- .gitignore +3 -0
- app.py +363 -0
- groundingdino_swint_ogc.pth +3 -0
- images/demo1.jpg +3 -0
- images/demo2.jpg +3 -0
- images/demo4.jpg +3 -0
- ram_swin_large_14m.pth +3 -0
- requirements.txt +27 -0
- sam_vit_h_4b8939.pth +3 -0
- tag2text_swin_14m.pth +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+images/*.jpg filter=lfs diff=lfs merge=lfs -text
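The new rule is the line that git lfs track writes into .gitattributes so the demo images are stored as LFS objects. A minimal sketch of how such a rule is typically produced from the repository root (assuming git-lfs is installed; the subprocess calls below are illustrative and not part of this commit):

# Illustrative only: "git lfs track" appends the matching filter line to .gitattributes.
import subprocess

subprocess.run(["git", "lfs", "track", "images/*.jpg"], check=True)
subprocess.run(["git", "add", ".gitattributes"], check=True)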
.gitignore
ADDED
@@ -0,0 +1,3 @@
+__pycache__/
+.vscode/
+gradio_cached_examples/
app.py
ADDED
@@ -0,0 +1,363 @@
+import os
+
+# setup Grounded-Segment-Anything
+os.system("python -m pip install -e 'Grounded-Segment-Anything/segment_anything'")
+os.system("python -m pip install -e 'Grounded-Segment-Anything/GroundingDINO'")
+os.system("pip install --upgrade diffusers[torch]")
+os.system("pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel")
+
+# setup recognize-anything
+os.system("python -m pip install -e 'recognize-anything'")
+
+import random  # noqa: E402
+
+import cv2  # noqa: E402
+import groundingdino.datasets.transforms as T  # noqa: E402
+import numpy as np  # noqa: E402
+import torch  # noqa: E402
+import torchvision  # noqa: E402
+import torchvision.transforms as TS  # noqa: E402
+from groundingdino.models import build_model  # noqa: E402
+from groundingdino.util.slconfig import SLConfig  # noqa: E402
+from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap  # noqa: E402
+from PIL import Image, ImageDraw, ImageFont  # noqa: E402
+from ram import inference_ram  # noqa: E402
+from ram import inference_tag2text  # noqa: E402
+from ram.models import ram  # noqa: E402
+from ram.models import tag2text_caption  # noqa: E402
+from segment_anything import SamPredictor, build_sam  # noqa: E402
+
+
+# args
+config_file = "Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
+ram_checkpoint = "./ram_swin_large_14m.pth"
+tag2text_checkpoint = "./tag2text_swin_14m.pth"
+grounded_checkpoint = "./groundingdino_swint_ogc.pth"
+sam_checkpoint = "./sam_vit_h_4b8939.pth"
+box_threshold = 0.25
+text_threshold = 0.2
+iou_threshold = 0.5
+device = "cpu"
+
+
+def load_model(model_config_path, model_checkpoint_path, device):
+    args = SLConfig.fromfile(model_config_path)
+    args.device = device
+    model = build_model(args)
+    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
+    load_res = model.load_state_dict(
+        clean_state_dict(checkpoint["model"]), strict=False)
+    print(load_res)
+    _ = model.eval()
+    return model
+
+
+def get_grounding_output(model, image, caption, box_threshold, text_threshold, device="cpu"):
+    caption = caption.lower()
+    caption = caption.strip()
+    if not caption.endswith("."):
+        caption = caption + "."
+    model = model.to(device)
+    image = image.to(device)
+    with torch.no_grad():
+        outputs = model(image[None], captions=[caption])
+    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
+    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
+    logits.shape[0]
+
+    # filter output
+    logits_filt = logits.clone()
+    boxes_filt = boxes.clone()
+    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
+    logits_filt = logits_filt[filt_mask]  # num_filt, 256
+    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
+    logits_filt.shape[0]
+
+    # get phrase
+    tokenlizer = model.tokenizer
+    tokenized = tokenlizer(caption)
+    # build pred
+    pred_phrases = []
+    scores = []
+    for logit, box in zip(logits_filt, boxes_filt):
+        pred_phrase = get_phrases_from_posmap(
+            logit > text_threshold, tokenized, tokenlizer)
+        pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
+        scores.append(logit.max().item())
+
+    return boxes_filt, torch.Tensor(scores), pred_phrases
+
+
+def draw_mask(mask, draw, random_color=False):
+    if random_color:
+        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)
+    else:
+        color = (30, 144, 255, 153)
+
+    nonzero_coords = np.transpose(np.nonzero(mask))
+
+    for coord in nonzero_coords:
+        draw.point(coord[::-1], fill=color)
+
+
+def draw_box(box, draw, label):
+    # random color
+    color = tuple(np.random.randint(0, 255, size=3).tolist())
+    line_width = int(max(5, min(25, 0.006 * max(draw.im.size))))  # clamp to [5, 25] px
+    draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=line_width)
+
+    if label:
+        font_path = os.path.join(
+            cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
+        font_size = int(max(15, min(75, 0.02 * max(draw.im.size))))  # clamp to [15, 75] pt
+        font = ImageFont.truetype(font_path, size=font_size)
+        if hasattr(font, "getbbox"):
+            bbox = draw.textbbox((box[0], box[1]), str(label), font)
+        else:
+            w, h = draw.textsize(str(label), font)
+            bbox = (box[0], box[1], w + box[0], box[1] + h)
+        draw.rectangle(bbox, fill=color)
+        draw.text((box[0], box[1]), str(label), fill="white", font=font)
+
+        draw.text((box[0], box[1]), label, font=font)
+
+
+def inference(raw_image, specified_tags, tagging_model_type, tagging_model, grounding_dino_model, sam_model):
+    raw_image = raw_image.convert("RGB")
+
+    # run tagging model
+    normalize = TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    transform = TS.Compose([
+        TS.Resize((384, 384)),
+        TS.ToTensor(),
+        normalize
+    ])
+
+    image = raw_image.resize((384, 384))
+    image = transform(image).unsqueeze(0).to(device)
+
+    # Currently ", " is better for detecting single tags
+    # while ". " is a little worse in some cases
+    if tagging_model_type == "RAM":
+        res = inference_ram(image, tagging_model)
+        tags = res[0].strip(' ').replace('  ', ' ').replace(' |', ',')
+        tags_chinese = res[1].strip(' ').replace('  ', ' ').replace(' |', ',')
+        print("Tags: ", tags)
+        print("图像标签: ", tags_chinese)
+    else:
+        res = inference_tag2text(image, tagging_model, specified_tags)
+        tags = res[0].strip(' ').replace('  ', ' ').replace(' |', ',')
+        caption = res[2]
+        print(f"Tags: {tags}")
+        print(f"Caption: {caption}")
+
+    # run groundingDINO
+    transform = T.Compose([
+        T.RandomResize([800], max_size=1333),
+        T.ToTensor(),
+        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+    ])
+
+    image, _ = transform(raw_image, None)  # 3, h, w
+
+    boxes_filt, scores, pred_phrases = get_grounding_output(
+        grounding_dino_model, image, tags, box_threshold, text_threshold, device=device
+    )
+
+    # run SAM
+    image = np.asarray(raw_image)
+    sam_model.set_image(image)
+
+    size = raw_image.size
+    H, W = size[1], size[0]
+    for i in range(boxes_filt.size(0)):
+        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])  # normalized cx,cy,w,h -> pixels
+        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2  # (cx, cy) -> (x1, y1)
+        boxes_filt[i][2:] += boxes_filt[i][:2]  # (w, h) -> (x2, y2)
+
+    boxes_filt = boxes_filt.cpu()
+    # use NMS to handle overlapped boxes
+    nms_idx = torchvision.ops.nms(
+        boxes_filt, scores, iou_threshold).numpy().tolist()
+    boxes_filt = boxes_filt[nms_idx]
+    pred_phrases = [pred_phrases[idx] for idx in nms_idx]
+
+    transformed_boxes = sam_model.transform.apply_boxes_torch(
+        boxes_filt, image.shape[:2]).to(device)
+
+    masks, _, _ = sam_model.predict_torch(
+        point_coords=None,
+        point_labels=None,
+        boxes=transformed_boxes.to(device),
+        multimask_output=False,
+    )
+
+    # draw output image
+    mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
+
+    mask_draw = ImageDraw.Draw(mask_image)
+    for mask in masks:
+        draw_mask(mask[0].cpu().numpy(), mask_draw, random_color=True)
+
+    image_draw = ImageDraw.Draw(raw_image)
+
+    for box, label in zip(boxes_filt, pred_phrases):
+        draw_box(box, image_draw, label)
+
+    out_image = raw_image.convert('RGBA')
+    out_image.alpha_composite(mask_image)
+
+    # return
+    if tagging_model_type == "RAM":
+        return tags, tags_chinese, out_image
+    else:
+        return tags, caption, out_image
+
+
+if __name__ == "__main__":
+    import gradio as gr
+
+    # load RAM
+    ram_model = ram(pretrained=ram_checkpoint, image_size=384, vit='swin_l')
+    ram_model.eval()
+    ram_model = ram_model.to(device)
+
+    # load Tag2Text
+    delete_tag_index = []  # filter out attributes and action categories which are difficult to ground
+    for i in range(3012, 3429):
+        delete_tag_index.append(i)
+
+    tag2text_model = tag2text_caption(pretrained=tag2text_checkpoint,
+                                      image_size=384,
+                                      vit='swin_b',
+                                      delete_tag_index=delete_tag_index)
+    tag2text_model.threshold = 0.64  # we reduce the threshold to obtain more tags
+    tag2text_model.eval()
+    tag2text_model = tag2text_model.to(device)
+
+    # load groundingDINO
+    grounding_dino_model = load_model(config_file, grounded_checkpoint, device=device)
+
+    # load SAM
+    sam_model = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
+
+    # build GUI
+    def build_gui():
+
+        description = """
+        <center><strong><font size='10'>Recognize Anything Model + Grounded-SAM</font></strong></center>
+        <br>
+        Welcome to the RAM/Tag2Text + Grounded-SAM demo! <br><br>
+        <li>
+        <b>Recognize Anything Model + Grounded-SAM:</b> Upload your image to get the <b>English and Chinese tags</b> (by RAM) and <b>masks and boxes</b> (by Grounded-SAM)!
+        </li>
+        <li>
+        <b>Tag2Text Model + Grounded-SAM:</b> Upload your image to get the <b>tags and caption</b> (by Tag2Text) and <b>masks and boxes</b> (by Grounded-SAM)!
+        (Optional: Specify tags to get the corresponding caption.)
+        </li>
+        """  # noqa
+
+        article = """
+        <p style='text-align: center'>
+        RAM and Tag2Text are trained on open-source datasets, and we are continuing to refine and iterate on them.<br/>
+        Grounded-SAM is a combination of Grounding DINO and SAM aiming to detect and segment anything with text inputs.<br/>
+        <a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a>
+        |
+        <a href='https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a>
+        |
+        <a href='https://github.com/IDEA-Research/Grounded-Segment-Anything' target='_blank'>Grounded-Segment-Anything</a>
+        </p>
+        """  # noqa
+
+        def inference_with_ram(img):
+            return inference(img, None, "RAM", ram_model, grounding_dino_model, sam_model)
+
+        def inference_with_t2t(img, input_tags):
+            return inference(img, input_tags, "Tag2Text", tag2text_model, grounding_dino_model, sam_model)
+
+        with gr.Blocks(title="Recognize Anything Model") as demo:
+            ###############
+            # components
+            ###############
+            gr.HTML(description)
+
+            with gr.Tab(label="Recognize Anything Model"):
+                with gr.Row():
+                    with gr.Column():
+                        ram_in_img = gr.Image(type="pil")
+                        with gr.Row():
+                            ram_btn_run = gr.Button(value="Run")
+                            ram_btn_clear = gr.Button(value="Clear")
+                    with gr.Column():
+                        ram_out_img = gr.Image(type="pil")
+                        ram_out_tag = gr.Textbox(label="Tags")
+                        ram_out_biaoqian = gr.Textbox(label="标签")
+                gr.Examples(
+                    examples=[
+                        ["images/demo1.jpg"],
+                        ["images/demo2.jpg"],
+                        ["images/demo4.jpg"],
+                    ],
+                    fn=inference_with_ram,
+                    inputs=[ram_in_img],
+                    outputs=[ram_out_tag, ram_out_biaoqian, ram_out_img],
+                    cache_examples=True
+                )
+
+            with gr.Tab(label="Tag2Text Model"):
+                with gr.Row():
+                    with gr.Column():
+                        t2t_in_img = gr.Image(type="pil")
+                        t2t_in_tag = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
+                        with gr.Row():
+                            t2t_btn_run = gr.Button(value="Run")
+                            t2t_btn_clear = gr.Button(value="Clear")
+                    with gr.Column():
+                        t2t_out_img = gr.Image(type="pil")
+                        t2t_out_tag = gr.Textbox(label="Tags")
+                        t2t_out_cap = gr.Textbox(label="Caption")
+                gr.Examples(
+                    examples=[
+                        ["images/demo4.jpg", ""],
+                        ["images/demo4.jpg", "power line"],
+                        ["images/demo4.jpg", "track, train"],
+                    ],
+                    fn=inference_with_t2t,
+                    inputs=[t2t_in_img, t2t_in_tag],
+                    outputs=[t2t_out_tag, t2t_out_cap, t2t_out_img],
+                    cache_examples=True
+                )
+
+            gr.HTML(article)
+
+            ###############
+            # events
+            ###############
+            # run inference
+            ram_btn_run.click(
+                fn=inference_with_ram,
+                inputs=[ram_in_img],
+                outputs=[ram_out_tag, ram_out_biaoqian, ram_out_img]
+            )
+            t2t_btn_run.click(
+                fn=inference_with_t2t,
+                inputs=[t2t_in_img, t2t_in_tag],
+                outputs=[t2t_out_tag, t2t_out_cap, t2t_out_img]
+            )
+
+            # clear all
+            def clear_all():
+                return [gr.update(value=None)] * 4 + [gr.update(value="")] * 5
+
+            ram_btn_clear.click(fn=clear_all, inputs=[], outputs=[
+                ram_in_img, ram_out_img, t2t_in_img, t2t_out_img,
+                ram_out_tag, ram_out_biaoqian, t2t_in_tag, t2t_out_tag, t2t_out_cap
+            ])
+            t2t_btn_clear.click(fn=clear_all, inputs=[], outputs=[
+                ram_in_img, ram_out_img, t2t_in_img, t2t_out_img,
+                ram_out_tag, ram_out_biaoqian, t2t_in_tag, t2t_out_tag, t2t_out_cap
+            ])
+
+        return demo
+
+    build_gui().launch(enable_queue=True, share=True)
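Note: both Gradio callbacks are thin wrappers around the inference() function above. A minimal offline-usage sketch, assuming the checkpoints listed in this commit are present and the models are loaded exactly as in the __main__ block of app.py:

# Hypothetical direct call, mirroring inference_with_ram in app.py.
from PIL import Image

img = Image.open("images/demo1.jpg")
tags, tags_chinese, annotated = inference(
    img, None, "RAM", ram_model, grounding_dino_model, sam_model)
print(tags)          # comma-separated English tags from RAM
print(tags_chinese)  # corresponding Chinese tags
annotated.save("annotated.png")  # image with Grounded-SAM boxes and masks composited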
groundingdino_swint_ogc.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3b3ca2563c77c69f651d7bd133e97139c186df06231157a64c507099c52bc799
+size 693997677
images/demo1.jpg
ADDED
Git LFS Details
images/demo2.jpg
ADDED
Git LFS Details
images/demo4.jpg
ADDED
Git LFS Details
ram_swin_large_14m.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:15c729c793af28b9d107c69f85836a1356d76ea830d4714699fb62e55fcc08ed
+size 5625634877
requirements.txt
ADDED
@@ -0,0 +1,27 @@
+timm==0.4.12
+transformers==4.15.0
+fairscale==0.4.4
+pycocoevalcap
+torch
+torchvision
+Pillow
+scipy
+git+https://github.com/openai/CLIP.git
+git+https://github.com/IDEA-Research/Grounded-Segment-Anything.git
+git+https://github.com/xinyu1205/recognize-anything.git
+addict
+diffusers
+gradio
+huggingface_hub
+matplotlib
+numpy
+onnxruntime
+opencv_python
+pycocotools
+PyYAML
+requests
+setuptools
+supervision
+termcolor
+yapf
+nltk
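These pins cover only the PyPI dependencies; app.py additionally installs the vendored Grounded-Segment-Anything and recognize-anything packages in editable mode at startup. A rough local-setup sketch combining both steps (assuming those two repositories are checked out next to app.py, as its paths imply):

# Mirrors the os.system() setup calls at the top of app.py, plus the pinned requirements.
import os

os.system("python -m pip install -r requirements.txt")
os.system("python -m pip install -e 'Grounded-Segment-Anything/segment_anything'")
os.system("python -m pip install -e 'Grounded-Segment-Anything/GroundingDINO'")
os.system("python -m pip install -e 'recognize-anything'")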
sam_vit_h_4b8939.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+size 2564550879
tag2text_swin_14m.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ce96f0ce98f940a6680d567f66a38ccc9ca8c4e638e5f5c5c2e881a0e3502ac
+size 4478705095