Spaces:
Runtime error
Runtime error
Update demo_gradio.py
Browse files- demo_gradio.py +6 -5
demo_gradio.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import spaces
|
| 2 |
import huggingface_hub
|
| 3 |
|
|
@@ -26,7 +27,7 @@ import glob
|
|
| 26 |
import torch
|
| 27 |
import cv2
|
| 28 |
import argparse
|
| 29 |
-
|
| 30 |
import DPT.util.io
|
| 31 |
|
| 32 |
from torchvision.transforms import Compose
|
|
@@ -38,7 +39,7 @@ from DPT.dpt.transforms import Resize, NormalizeImage, PrepareForNet
|
|
| 38 |
"""
|
| 39 |
Get ZeST Ready
|
| 40 |
"""
|
| 41 |
-
base_model_path = "
|
| 42 |
image_encoder_path = "models/image_encoder"
|
| 43 |
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
|
| 44 |
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
|
|
@@ -55,7 +56,7 @@ pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
|
|
| 55 |
add_watermarker=False,
|
| 56 |
).to(device)
|
| 57 |
pipe.unet = register_cross_attention_hook(pipe.unet)
|
| 58 |
-
|
| 59 |
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)
|
| 60 |
|
| 61 |
|
|
@@ -161,7 +162,7 @@ def greet(input_image, material_exemplar):
|
|
| 161 |
|
| 162 |
|
| 163 |
num_samples = 1
|
| 164 |
-
images = ip_model.generate(pil_image=ip_image, image=init_img, control_image=depth_map, mask_image=mask, controlnet_conditioning_scale=0.9, num_samples=num_samples, num_inference_steps=
|
| 165 |
|
| 166 |
return images[0]
|
| 167 |
|
|
@@ -193,4 +194,4 @@ with gr.Blocks(css=css) as demo:
|
|
| 193 |
output_image = gr.Image(label="transfer result")
|
| 194 |
submit_btn.click(fn=greet, inputs=[input_image, input_image2], outputs=[output_image])
|
| 195 |
|
| 196 |
-
demo.queue().launch()
|
|
|
|
| 1 |
+
|
| 2 |
import spaces
|
| 3 |
import huggingface_hub
|
| 4 |
|
|
|
|
| 27 |
import torch
|
| 28 |
import cv2
|
| 29 |
import argparse
|
| 30 |
+
from diffusers.models.attention_processor import AttnProcessor2_0
|
| 31 |
import DPT.util.io
|
| 32 |
|
| 33 |
from torchvision.transforms import Compose
|
|
|
|
| 39 |
"""
|
| 40 |
Get ZeST Ready
|
| 41 |
"""
|
| 42 |
+
base_model_path = "Lykon/dreamshaper-xl-lightning"
|
| 43 |
image_encoder_path = "models/image_encoder"
|
| 44 |
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
|
| 45 |
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
|
|
|
|
| 56 |
add_watermarker=False,
|
| 57 |
).to(device)
|
| 58 |
pipe.unet = register_cross_attention_hook(pipe.unet)
|
| 59 |
+
pipe.unet.set_attn_processor(AttnProcessor2_0())
|
| 60 |
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)
|
| 61 |
|
| 62 |
|
|
|
|
| 162 |
|
| 163 |
|
| 164 |
num_samples = 1
|
| 165 |
+
images = ip_model.generate(guidance_scale=2, pil_image=ip_image, image=init_img, control_image=depth_map, mask_image=mask, controlnet_conditioning_scale=0.9, num_samples=num_samples, num_inference_steps=4, seed=42)
|
| 166 |
|
| 167 |
return images[0]
|
| 168 |
|
|
|
|
| 194 |
output_image = gr.Image(label="transfer result")
|
| 195 |
submit_btn.click(fn=greet, inputs=[input_image, input_image2], outputs=[output_image])
|
| 196 |
|
| 197 |
+
demo.queue().launch()
|