Improve interactivity by allowing direct bounding box drawing and tool selection
#1
by
Amordia
- opened
app.py
CHANGED
|
@@ -129,11 +129,9 @@ def viz_pred_mask(img,
|
|
| 129 |
else:
|
| 130 |
cv2.circle(out,(col, row), marker_size, (255,0,0), -1)
|
| 131 |
|
| 132 |
-
if bbox_coords
|
| 133 |
-
for
|
| 134 |
-
cv2.rectangle(out,
|
| 135 |
-
if len(bbox_coords) % 2 == 1:
|
| 136 |
-
cv2.circle(out, tuple(bbox_coords[-1]), marker_size, (255,165,0), -1)
|
| 137 |
|
| 138 |
return out.astype(np.uint8)
|
| 139 |
|
|
@@ -242,7 +240,8 @@ def refresh_predictions(predictor, input_img, output_img, click_coords, click_la
|
|
| 242 |
def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask,
|
| 243 |
click_coords, click_labels, bbox_coords,
|
| 244 |
seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
|
| 245 |
-
output_img, binary_checkbox, multimask_mode, autopredict_checkbox,
|
|
|
|
| 246 |
"""
|
| 247 |
Record user click and update the prediction
|
| 248 |
"""
|
|
@@ -255,6 +254,16 @@ def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask,
|
|
| 255 |
else:
|
| 256 |
raise TypeError("Invalid brush label: {brush_label}")
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
# Only make new prediction if not waiting for additional bounding box click
|
| 259 |
if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:
|
| 260 |
|
|
@@ -402,15 +411,20 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
|
|
| 402 |
)
|
| 403 |
|
| 404 |
with gr.Tab("Clicks/Boxes") as click_tab:
|
|
|
|
| 405 |
click_img = gr.Image(
|
| 406 |
label="Input",
|
| 407 |
type='numpy',
|
| 408 |
value=default_example,
|
|
|
|
|
|
|
| 409 |
show_download_button=True,
|
| 410 |
container=True,
|
| 411 |
height=display_height
|
| 412 |
)
|
| 413 |
with gr.Row():
|
|
|
|
|
|
|
| 414 |
undo_click_button = gr.Button("Undo Last Click")
|
| 415 |
clear_click_button = gr.Button("Clear Clicks/Boxes", variant="stop")
|
| 416 |
|
|
@@ -546,7 +560,8 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
|
|
| 546 |
input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
|
| 547 |
seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
|
| 548 |
output_img, binary_checkbox, multimask_mode, autopredict_checkbox
|
| 549 |
-
|
|
|
|
| 550 |
outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
|
| 551 |
click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
|
| 552 |
api_name = "get_select_coords"
|
|
|
|
| 129 |
else:
|
| 130 |
cv2.circle(out,(col, row), marker_size, (255,0,0), -1)
|
| 131 |
|
| 132 |
+
if bbox_coords:
|
| 133 |
+
for bbox in bbox_coords:
|
| 134 |
+
cv2.rectangle(out, bbox[0], bbox[1], (255,165,0), 2)
|
|
|
|
|
|
|
| 135 |
|
| 136 |
return out.astype(np.uint8)
|
| 137 |
|
|
|
|
| 240 |
def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask,
|
| 241 |
click_coords, click_labels, bbox_coords,
|
| 242 |
seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
|
| 243 |
+
output_img, binary_checkbox, multimask_mode, autopredict_checkbox,
|
| 244 |
+
tool_select, evt: gr.SelectData):
|
| 245 |
"""
|
| 246 |
Record user click and update the prediction
|
| 247 |
"""
|
|
|
|
| 254 |
else:
|
| 255 |
raise TypeError("Invalid brush label: {brush_label}")
|
| 256 |
|
| 257 |
+
if tool_select == "Rectangle" and evt.bounding_box:
|
| 258 |
+
# User drew a rectangle
|
| 259 |
+
bbox_coords.append((evt.bounding_box[0], evt.bounding_box[1]))
|
| 260 |
+
elif tool_select == "Point" and evt.index is not None:
|
| 261 |
+
# User clicked a point
|
| 262 |
+
click_coords.append(evt.index)
|
| 263 |
+
click_labels.append(1 if brush_label == 'Positive (green)' else 0)
|
| 264 |
+
else:
|
| 265 |
+
gr.Error("Invalid selection")
|
| 266 |
+
|
| 267 |
# Only make new prediction if not waiting for additional bounding box click
|
| 268 |
if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:
|
| 269 |
|
|
|
|
| 411 |
)
|
| 412 |
|
| 413 |
with gr.Tab("Clicks/Boxes") as click_tab:
|
| 414 |
+
# Update click_img to be interactive and use the selected tool
|
| 415 |
click_img = gr.Image(
|
| 416 |
label="Input",
|
| 417 |
type='numpy',
|
| 418 |
value=default_example,
|
| 419 |
+
tool="select", # Use 'select' to capture both points and rectangles
|
| 420 |
+
interactive=True,
|
| 421 |
show_download_button=True,
|
| 422 |
container=True,
|
| 423 |
height=display_height
|
| 424 |
)
|
| 425 |
with gr.Row():
|
| 426 |
+
# Add a tool selection radio button
|
| 427 |
+
tool_select = gr.Radio(["Point", "Rectangle"], label="Tool", value="Point")
|
| 428 |
undo_click_button = gr.Button("Undo Last Click")
|
| 429 |
clear_click_button = gr.Button("Clear Clicks/Boxes", variant="stop")
|
| 430 |
|
|
|
|
| 560 |
input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
|
| 561 |
seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
|
| 562 |
output_img, binary_checkbox, multimask_mode, autopredict_checkbox
|
| 563 |
+
tool_select # Add tool_select here
|
| 564 |
+
],
|
| 565 |
outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
|
| 566 |
click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
|
| 567 |
api_name = "get_select_coords"
|