Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
|
@@ -26,8 +26,9 @@ class Model:
|
|
| 26 |
self.task_name = ""
|
| 27 |
self.vq_model = self.load_vq()
|
| 28 |
self.t5_model = self.load_t5()
|
| 29 |
-
self.gpt_model_edge = self.load_gpt(condition_type='edge')
|
| 30 |
-
self.gpt_model_depth = self.load_gpt(condition_type='depth')
|
|
|
|
| 31 |
self.preprocessor = Preprocessor()
|
| 32 |
|
| 33 |
def to(self, device):
|
|
@@ -45,7 +46,7 @@ class Model:
|
|
| 45 |
return vq_model
|
| 46 |
|
| 47 |
def load_gpt(self, condition_type='edge'):
|
| 48 |
-
gpt_ckpt = models[condition_type]
|
| 49 |
# precision = torch.bfloat16
|
| 50 |
precision = torch.float32
|
| 51 |
latent_size = 512 // 16
|
|
@@ -56,12 +57,19 @@ class Model:
|
|
| 56 |
condition_type=condition_type,
|
| 57 |
adapter_size='base',
|
| 58 |
).to(device='cpu', dtype=precision)
|
| 59 |
-
model_weight = load_file(gpt_ckpt)
|
| 60 |
-
gpt_model.load_state_dict(model_weight, strict=False)
|
| 61 |
-
gpt_model.eval()
|
| 62 |
-
print("gpt model is loaded")
|
| 63 |
return gpt_model
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
def load_t5(self):
|
| 66 |
# precision = torch.bfloat16
|
| 67 |
precision = torch.float32
|
|
@@ -92,7 +100,8 @@ class Model:
|
|
| 92 |
preprocessor_name: str,
|
| 93 |
) -> list[PIL.Image.Image]:
|
| 94 |
self.t5_model.model.to('cuda').to(torch.bfloat16)
|
| 95 |
-
self.
|
|
|
|
| 96 |
self.vq_model.to('cuda')
|
| 97 |
if isinstance(image, np.ndarray):
|
| 98 |
image = Image.fromarray(image)
|
|
@@ -114,10 +123,10 @@ class Model:
|
|
| 114 |
condition_img = condition_img.resize((512,512))
|
| 115 |
W, H = condition_img.size
|
| 116 |
|
| 117 |
-
condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(
|
| 118 |
condition_img = condition_img.to(self.device)
|
| 119 |
condition_img = 2*(condition_img/255 - 0.5)
|
| 120 |
-
prompts = [prompt] *
|
| 121 |
caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
|
| 122 |
|
| 123 |
print(f"processing left-padding...")
|
|
@@ -137,7 +146,7 @@ class Model:
|
|
| 137 |
t1 = time.time()
|
| 138 |
print(caption_embs.device)
|
| 139 |
index_sample = generate(
|
| 140 |
-
self.
|
| 141 |
c_indices,
|
| 142 |
(H // 16) * (W // 16),
|
| 143 |
c_emb_masks,
|
|
|
|
| 26 |
self.task_name = ""
|
| 27 |
self.vq_model = self.load_vq()
|
| 28 |
self.t5_model = self.load_t5()
|
| 29 |
+
# self.gpt_model_edge = self.load_gpt(condition_type='edge')
|
| 30 |
+
# self.gpt_model_depth = self.load_gpt(condition_type='depth')
|
| 31 |
+
self.gpt_model = self.load_gpt()
|
| 32 |
self.preprocessor = Preprocessor()
|
| 33 |
|
| 34 |
def to(self, device):
|
|
|
|
| 46 |
return vq_model
|
| 47 |
|
| 48 |
def load_gpt(self, condition_type='edge'):
|
| 49 |
+
# gpt_ckpt = models[condition_type]
|
| 50 |
# precision = torch.bfloat16
|
| 51 |
precision = torch.float32
|
| 52 |
latent_size = 512 // 16
|
|
|
|
| 57 |
condition_type=condition_type,
|
| 58 |
adapter_size='base',
|
| 59 |
).to(device='cpu', dtype=precision)
|
| 60 |
+
# model_weight = load_file(gpt_ckpt)
|
| 61 |
+
# gpt_model.load_state_dict(model_weight, strict=False)
|
| 62 |
+
# gpt_model.eval()
|
| 63 |
+
# print("gpt model is loaded")
|
| 64 |
return gpt_model
|
| 65 |
|
| 66 |
+
def load_gpt_weight(self, condition_type='edge'):
|
| 67 |
+
gpt_ckpt = models[condition_type]
|
| 68 |
+
model_weight = load_file(gpt_ckpt)
|
| 69 |
+
self.gpt_model.load_state_dict(model_weight, strict=False)
|
| 70 |
+
self.gpt_model.eval()
|
| 71 |
+
# print("gpt model is loaded")
|
| 72 |
+
|
| 73 |
def load_t5(self):
|
| 74 |
# precision = torch.bfloat16
|
| 75 |
precision = torch.float32
|
|
|
|
| 100 |
preprocessor_name: str,
|
| 101 |
) -> list[PIL.Image.Image]:
|
| 102 |
self.t5_model.model.to('cuda').to(torch.bfloat16)
|
| 103 |
+
self.load_gpt_weight('edge')
|
| 104 |
+
self.gpt_model.to('cuda').to(torch.bfloat16)
|
| 105 |
self.vq_model.to('cuda')
|
| 106 |
if isinstance(image, np.ndarray):
|
| 107 |
image = Image.fromarray(image)
|
|
|
|
| 123 |
condition_img = condition_img.resize((512,512))
|
| 124 |
W, H = condition_img.size
|
| 125 |
|
| 126 |
+
condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(3,1,1,1)
|
| 127 |
condition_img = condition_img.to(self.device)
|
| 128 |
condition_img = 2*(condition_img/255 - 0.5)
|
| 129 |
+
prompts = [prompt] * 3
|
| 130 |
caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
|
| 131 |
|
| 132 |
print(f"processing left-padding...")
|
|
|
|
| 146 |
t1 = time.time()
|
| 147 |
print(caption_embs.device)
|
| 148 |
index_sample = generate(
|
| 149 |
+
self.gpt_model,
|
| 150 |
c_indices,
|
| 151 |
(H // 16) * (W // 16),
|
| 152 |
c_emb_masks,
|