YaohuiW committed on
Commit
73669b1
·
verified ·
1 Parent(s): 62c64ac

Update gradio_tabs/img_edit.py

Browse files
Files changed (1) hide show
  1. gradio_tabs/img_edit.py +27 -75
gradio_tabs/img_edit.py CHANGED
@@ -37,111 +37,69 @@ labels_v = [
37
  ]
38
 
39
 
40
- @torch.compiler.allow_in_graph
41
  def load_image(img, size):
42
  img = Image.open(img).convert('RGB')
43
  w, h = img.size
44
  img = img.resize((size, size))
45
  img = np.asarray(img)
46
- img = np.copy(img)
47
  img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
48
 
49
  return img / 255.0, w, h
50
 
51
 
52
- @torch.compiler.allow_in_graph
53
  def img_preprocessing(img_path, size):
54
- img, w, h = load_image(img_path, size) # [0, 1]
55
  img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
56
  imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
57
 
58
  return imgs_norm, w, h
59
 
60
 
61
- # Pre-compile resize transforms for better performance
62
- resize_transform_cache = {}
63
-
64
- def get_resize_transform(size):
65
- """Get cached resize transform - creates once, reuses many times"""
66
- if size not in resize_transform_cache:
67
- # Only create the transform if it doesn't exist in cache
68
- resize_transform_cache[size] = torchvision.transforms.Resize(
69
- size,
70
- interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
71
- antialias=True
72
- )
73
- return resize_transform_cache[size]
74
-
75
-
76
  def resize(img, size):
77
- """Use cached resize transform"""
78
- transform = get_resize_transform((size, size))
 
 
79
  return transform(img)
80
 
81
 
82
  def resize_back(img, w, h):
83
- """Use cached resize transform for back operation"""
84
- transform = get_resize_transform((h, w))
 
 
85
  return transform(img)
86
 
87
 
88
  def img_denorm(img):
89
- img = img.clamp(-1, 1)
90
  img = (img - img.min()) / (img.max() - img.min())
91
 
92
  return img
93
 
94
 
95
- def img_postprocessing(img, w, h):
96
 
97
- img = resize_back(img, w, h)
98
- img = img_denorm(img)
99
- img = img.squeeze(0).permute(1, 2, 0).contiguous() # contiguous() for fast transfer
100
- img_output = (img.cpu().numpy() * 255).astype(np.uint8)
101
 
102
- return img_output
 
 
103
 
104
 
105
  def img_edit(gen, device):
106
 
107
- @torch.compile
108
- def compiled_enc_img(image_tensor, selected_s):
109
- """Compiled version of just the model inference"""
110
- return gen.enc_img(image_tensor, labels_v, selected_s)
111
-
112
- @torch.compile
113
- def compiled_dec_img(z_s2r, alpha_r2s, feat_rgb):
114
- """Compiled version of just the model inference"""
115
- return gen.dec_img(z_s2r, alpha_r2s, feat_rgb)
116
-
117
-
118
- # Pre-warm the compiled model with dummy data to reduce first-run compilation time
119
- def _warmup_model():
120
- """Pre-warm the model compilation with representative shapes"""
121
- print("[img_edit] Pre-warming model compilation...")
122
- dummy_image = torch.randn(1, 3, 512, 512, device=device)
123
- dummy_selected_s = [0.0] * len(labels_v)
124
-
125
- try:
126
- with torch.inference_mode():
127
- z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(dummy_image, dummy_selected_s)
128
- _ = compiled_dec_img(z_s2r, alpha_r2s, feat_rgb)
129
- print("[img_edit] Model pre-warming completed successfully")
130
- except Exception as e:
131
- print(f"[img_edit] Model pre-warming failed (will compile on first use): {e}")
132
-
133
- # Pre-warm the model
134
- _warmup_model()
135
-
136
  @spaces.GPU
137
- @torch.inference_mode()
138
  def edit_img(image, *selected_s):
139
 
140
  image_tensor, w, h = img_preprocessing(image, 512)
141
  image_tensor = image_tensor.to(device)
142
 
143
- z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(image_tensor, selected_s)
144
- edited_image_tensor = compiled_dec_img(z_s2r, alpha_r2s, feat_rgb)
145
 
146
  # de-norm
147
  edited_image = img_postprocessing(edited_image_tensor, w, h)
@@ -178,10 +136,10 @@ def img_edit(gen, device):
178
  with gr.Row():
179
  with gr.Column(scale=1):
180
  with gr.Row(): # Buttons now within a single Row
181
- #edit_btn = gr.Button("Edit")
182
  clear_btn = gr.Button("Clear")
183
- #with gr.Row():
184
- # animate_btn = gr.Button("Generate")
185
 
186
 
187
 
@@ -192,7 +150,7 @@ def img_edit(gen, device):
192
  image_output = gr.Image(label="Output Image", type='numpy', interactive=False, width=512)
193
 
194
 
195
- with gr.Accordion("Control Panel - Using Sliders to Edit Image", open=True):
196
  with gr.Tab("Head"):
197
  with gr.Row():
198
  for k in labels_k[:3]:
@@ -223,18 +181,12 @@ def img_edit(gen, device):
223
  slider = gr.Slider(minimum=-0.2, maximum=0.2, value=0, label=k)
224
  inputs_s.append(slider)
225
 
226
- for slider in inputs_s:
227
- slider.change(
228
  fn=edit_img,
229
  inputs=[image_input] + inputs_s,
230
  outputs=[image_output],
231
-
232
- show_progress='hidden',
233
-
234
- trigger_mode='always_last',
235
-
236
- # currently we have a latency around 450ms
237
- stream_every=0.5
238
  )
239
 
240
  clear_btn.click(
 
37
  ]
38
 
39
 
 
40
  def load_image(img, size):
41
  img = Image.open(img).convert('RGB')
42
  w, h = img.size
43
  img = img.resize((size, size))
44
  img = np.asarray(img)
 
45
  img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
46
 
47
  return img / 255.0, w, h
48
 
49
 
 
50
  def img_preprocessing(img_path, size):
51
+ img, w, h = load_image(img_path, size) # [0, 1]
52
  img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
53
  imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
54
 
55
  return imgs_norm, w, h
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def resize(img, size):
59
+ transform = torchvision.transforms.Compose([
60
+ torchvision.transforms.Resize((size,size), antialias=True),
61
+ ])
62
+
63
  return transform(img)
64
 
65
 
66
  def resize_back(img, w, h):
67
+ transform = torchvision.transforms.Compose([
68
+ torchvision.transforms.Resize((h, w), antialias=True),
69
+ ])
70
+
71
  return transform(img)
72
 
73
 
74
  def img_denorm(img):
75
+ img = img.clamp(-1, 1).cpu()
76
  img = (img - img.min()) / (img.max() - img.min())
77
 
78
  return img
79
 
80
 
81
def img_postprocessing(image, w, h):
    """Turn a generated image tensor into a PNG file and return its path.

    Args:
        image: Batched channel-first tensor (the permute below requires four
            axes); only the first batch element is written out.
        w: Target width in pixels.
        h: Target height in pixels.

    Returns:
        Path of a temporary ``.png`` file. It is created with
        ``delete=False``, so it is not removed automatically.
    """
    image = resize_back(image, w, h)
    image = image.permute(0, 2, 3, 1)  # NCHW -> NHWC
    edited_image = img_denorm(image)  # rescaled to [0, 1], on the CPU
    img_output = (edited_image[0].numpy() * 255).astype(np.uint8)

    # Reserve a unique file name, then close the handle before writing:
    # re-opening a NamedTemporaryFile by name while it is still open is not
    # portable (it fails on Windows).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        output_path = temp_file.name
    imageio.imwrite(output_path, img_output, quality=8)

    return output_path
91
 
92
 
93
  def img_edit(gen, device):
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  @spaces.GPU
96
+ @torch.no_grad()
97
  def edit_img(image, *selected_s):
98
 
99
  image_tensor, w, h = img_preprocessing(image, 512)
100
  image_tensor = image_tensor.to(device)
101
 
102
+ edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)
 
103
 
104
  # de-norm
105
  edited_image = img_postprocessing(edited_image_tensor, w, h)
 
136
  with gr.Row():
137
  with gr.Column(scale=1):
138
  with gr.Row(): # Buttons now within a single Row
139
+ edit_btn = gr.Button("Edit")
140
  clear_btn = gr.Button("Clear")
141
+ with gr.Row():
142
+ animate_btn = gr.Button("Generate")
143
 
144
 
145
 
 
150
  image_output = gr.Image(label="Output Image", type='numpy', interactive=False, width=512)
151
 
152
 
153
+ with gr.Accordion("Control Panel", open=True):
154
  with gr.Tab("Head"):
155
  with gr.Row():
156
  for k in labels_k[:3]:
 
181
  slider = gr.Slider(minimum=-0.2, maximum=0.2, value=0, label=k)
182
  inputs_s.append(slider)
183
 
184
+
185
+ edit_btn.click(
186
  fn=edit_img,
187
  inputs=[image_input] + inputs_s,
188
  outputs=[image_output],
189
+ show_progress=True
 
 
 
 
 
 
190
  )
191
 
192
  clear_btn.click(