JeffreyXiang commited on
Commit
93b0f21
·
1 Parent(s): 36c86a2
app.py CHANGED
@@ -35,7 +35,6 @@ def end_session(req: gr.Request):
35
  shutil.rmtree(user_dir)
36
 
37
 
38
- @spaces.GPU()
39
  def preprocess_image(image: Image.Image) -> Image.Image:
40
  """
41
  Preprocess the input image.
 
35
  shutil.rmtree(user_dir)
36
 
37
 
 
38
  def preprocess_image(image: Image.Image) -> Image.Image:
39
  """
40
  Preprocess the input image.
trellis2/pipelines/trellis2_image_to_3d.py CHANGED
@@ -1,4 +1,5 @@
1
  from typing import *
 
2
  import torch
3
  import torch.nn as nn
4
  import numpy as np
@@ -113,6 +114,16 @@ class Trellis2ImageTo3DPipeline(Pipeline):
113
  super().to(device)
114
  self.image_cond_model.to(device)
115
  self.rembg_model.to(device)
 
 
 
 
 
 
 
 
 
 
116
 
117
  def preprocess_image(self, input: Image.Image) -> Image.Image:
118
  """
@@ -131,12 +142,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
131
  if has_alpha:
132
  output = input
133
  else:
134
- input = input.convert('RGB')
135
- if self.low_vram:
136
- self.rembg_model.to(self.device)
137
- output = self.rembg_model(input)
138
- if self.low_vram:
139
- self.rembg_model.cpu()
140
  output_np = np.array(output)
141
  alpha = output_np[:, :, 3]
142
  bbox = np.argwhere(alpha > 0.8 * 255)
 
1
  from typing import *
2
+ import spaces
3
  import torch
4
  import torch.nn as nn
5
  import numpy as np
 
114
  super().to(device)
115
  self.image_cond_model.to(device)
116
  self.rembg_model.to(device)
117
+
118
+ @spaces.GPU()
119
+ def remove_background(self, input: Image.Image) -> Image.Image:
120
+ input = input.convert('RGB')
121
+ if self.low_vram:
122
+ self.rembg_model.to(self.device)
123
+ output = self.rembg_model(input)
124
+ if self.low_vram:
125
+ self.rembg_model.cpu()
126
+ return output
127
 
128
  def preprocess_image(self, input: Image.Image) -> Image.Image:
129
  """
 
142
  if has_alpha:
143
  output = input
144
  else:
145
+ output = self.remove_background(input)
 
 
 
 
 
146
  output_np = np.array(output)
147
  alpha = output_np[:, :, 3]
148
  bbox = np.argwhere(alpha > 0.8 * 255)