Govind K, curt-park committed
Commit 1d1c937 · 0 Parent(s)

Duplicate from curt-park/segment-anything-with-clip

Co-authored-by: Jinwoo Park <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ flagged
Makefile ADDED
@@ -0,0 +1,8 @@
+ env:
+ 	conda create -n segment-anything python=3.9
+
+ setup:
+ 	pip install -r requirements.txt
+
+ run:
+ 	gradio app.py
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Segment Anything
+ emoji: 🐠
+ colorFrom: green
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.24.1
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: curt-park/segment-anything-with-clip
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
ViT-B-32.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af
+ size 353976522
app.py ADDED
@@ -0,0 +1,193 @@
+ import os
+ from functools import lru_cache
+ from random import randint
+ from typing import Any, Callable, Dict, List, Tuple
+
+ import clip
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import PIL
+ import torch
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
+
+ CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
+ MODEL_TYPE = "default"
+ MAX_WIDTH = MAX_HEIGHT = 800
+ CLIP_WIDTH = CLIP_HEIGHT = 300
+ THRESHOLD = 0.05
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ @lru_cache
+ def load_mask_generator() -> SamAutomaticMaskGenerator:
+     sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device)
+     mask_generator = SamAutomaticMaskGenerator(sam)
+     return mask_generator
+
+
+ @lru_cache
+ def load_clip(
+     name: str = "ViT-B-32.pt",
+ ) -> Tuple[torch.nn.Module, Callable[[PIL.Image.Image], torch.Tensor]]:
+     model_path = os.path.join(".", name)
+     model, preprocess = clip.load(model_path, device=device)
+     return model.to(device), preprocess
+
+
+ def adjust_image_size(image: np.ndarray) -> np.ndarray:
+     height, width = image.shape[:2]
+     if height > width:
+         if height > MAX_HEIGHT:
+             height, width = MAX_HEIGHT, int(MAX_HEIGHT / height * width)
+     else:
+         if width > MAX_WIDTH:
+             height, width = int(MAX_WIDTH / width * height), MAX_WIDTH
+     image = cv2.resize(image, (width, height))
+     return image
+
+
+ @torch.no_grad()
+ def get_scores(crops: List[PIL.Image.Image], query: str) -> torch.Tensor:
+     model, preprocess = load_clip()
+     preprocessed = [preprocess(crop) for crop in crops]
+     preprocessed = torch.stack(preprocessed).to(device)
+     token = clip.tokenize(query).to(device)
+     img_features = model.encode_image(preprocessed)
+     txt_features = model.encode_text(token)
+     img_features /= img_features.norm(dim=-1, keepdim=True)
+     txt_features /= txt_features.norm(dim=-1, keepdim=True)
+     similarity = (100.0 * img_features @ txt_features.T).softmax(dim=0)
+     return similarity
+
+
+ def filter_masks(
+     image: np.ndarray,
+     masks: List[Dict[str, Any]],
+     predicted_iou_threshold: float,
+     stability_score_threshold: float,
+     query: str,
+     clip_threshold: float,
+ ) -> List[Dict[str, Any]]:
+     cropped_masks: List[PIL.Image.Image] = []
+     filtered_masks: List[Dict[str, Any]] = []
+
+     for mask in masks:
+         if (
+             mask["predicted_iou"] < predicted_iou_threshold
+             or mask["stability_score"] < stability_score_threshold
+         ):
+             continue
+
+         filtered_masks.append(mask)
+
+         x, y, w, h = mask["bbox"]
+         crop = image[y: y + h, x: x + w]
+         crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
+         crop = PIL.Image.fromarray(np.uint8(crop * 255)).convert("RGB")
+         crop.resize((CLIP_WIDTH, CLIP_HEIGHT))
+         cropped_masks.append(crop)
+
+     if query and filtered_masks:
+         scores = get_scores(cropped_masks, query)
+         filtered_masks = [
+             filtered_masks[i]
+             for i, score in enumerate(scores)
+             if score > clip_threshold
+         ]
+
+     return filtered_masks
+
+
+ def draw_masks(
+     image: np.ndarray, masks: List[np.ndarray], alpha: float = 0.7
+ ) -> np.ndarray:
+     for mask in masks:
+         color = [randint(127, 255) for _ in range(3)]
+
+         # draw mask overlay
+         colored_mask = np.expand_dims(mask["segmentation"], 0).repeat(3, axis=0)
+         colored_mask = np.moveaxis(colored_mask, 0, -1)
+         masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
+         image_overlay = masked.filled()
+         image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
+
+         # draw contour
+         contours, _ = cv2.findContours(
+             np.uint8(mask["segmentation"]), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+         )
+         cv2.drawContours(image, contours, -1, (255, 0, 0), 2)
+     return image
+
+
+ def segment(
+     predicted_iou_threshold: float,
+     stability_score_threshold: float,
+     clip_threshold: float,
+     image_path: str,
+     query: str,
+ ) -> PIL.ImageFile.ImageFile:
+     mask_generator = load_mask_generator()
+     # reduce the size to save gpu memory
+     image = adjust_image_size(cv2.imread(image_path))
+     masks = mask_generator.generate(image)
+     masks = filter_masks(
+         image,
+         masks,
+         predicted_iou_threshold,
+         stability_score_threshold,
+         query,
+         clip_threshold,
+     )
+     image = draw_masks(image, masks)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
+     return image
+
+
+ demo = gr.Interface(
+     fn=segment,
+     inputs=[
+         gr.Slider(0, 1, value=0.9, label="predicted_iou_threshold"),
+         gr.Slider(0, 1, value=0.8, label="stability_score_threshold"),
+         gr.Slider(0, 1, value=0.05, label="clip_threshold"),
+         gr.Image(type="filepath"),
+         "text",
+     ],
+     outputs="image",
+     allow_flagging="never",
+     title="Segment Anything with CLIP",
+     examples=[
+         [
+             0.9,
+             0.8,
+             0.15,
+             os.path.join(os.path.dirname(__file__), "examples/dog.jpg"),
+             "A dog only",
+         ],
+         [
+             0.9,
+             0.8,
+             0.1,
+             os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
+             "A bridge on the water",
+         ],
+         [
+             0.9,
+             0.8,
+             0.05,
+             os.path.join(os.path.dirname(__file__), "examples/food.jpg"),
+             "",
+         ],
+         [
+             0.9,
+             0.8,
+             0.05,
+             os.path.join(os.path.dirname(__file__), "examples/horse.jpg"),
+             "horse",
+         ],
+     ],
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
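For reference, a minimal sketch of driving this pipeline without the Gradio UI, reusing the threshold values and query from the first entry of the `examples` list above. It assumes `app.py`, both LFS checkpoints, and `examples/dog.jpg` sit in the working directory exactly as committed; the output filename is arbitrary.

```python
# Sketch: call the Space's segment() directly instead of through gr.Interface.
# Assumes sam_vit_h_4b8939.pth, ViT-B-32.pt, and examples/dog.jpg are present
# next to app.py, as committed above.
from app import segment

result = segment(
    predicted_iou_threshold=0.9,   # same values as the first Gradio example
    stability_score_threshold=0.8,
    clip_threshold=0.15,
    image_path="examples/dog.jpg",
    query="A dog only",
)
result.save("dog_segmented.png")  # segment() returns an RGB PIL image
```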
examples/city.jpg ADDED
examples/dog.jpg ADDED
examples/food.jpg ADDED
examples/horse.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio==3.24.1
+ opencv-python==4.7.0.72
+ pycocotools==2.0.6
+ matplotlib==3.7.1
+ git+https://github.com/facebookresearch/segment-anything.git
+ git+https://github.com/openai/CLIP.git
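As a quick sanity check after `make setup` (a sketch, not part of the committed Space), one might confirm that the pinned packages import and that the two LFS checkpoints were actually pulled rather than left as pointer files:

```python
# Sketch: verify the environment and LFS checkpoints before `make run`.
# Checkpoint names match the files committed above; nothing here is part of app.py.
import os

import clip      # installed from the CLIP git URL above
import cv2
import gradio
import torch
from segment_anything import sam_model_registry  # from the segment-anything git URL

for ckpt in ("sam_vit_h_4b8939.pth", "ViT-B-32.pt"):
    size_mb = os.path.getsize(ckpt) / 1e6
    print(f"{ckpt}: {size_mb:.0f} MB")  # a tiny size means only the LFS pointer was fetched

print("gradio", gradio.__version__, "| opencv", cv2.__version__, "| cuda:", torch.cuda.is_available())
```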
sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+ size 2564550879