ianpan commited on
Commit
9df570f
·
verified ·
1 Parent(s): 0b22878

Upload model

Browse files
Files changed (1) hide show
  1. modeling.py +19 -1
modeling.py CHANGED
@@ -115,7 +115,9 @@ class TotalClassifierModel(PreTrainedModel):
115
  mask: torch.Tensor | None = None,
116
  return_logits: bool = False,
117
  return_as_dict: bool = False,
 
118
  return_as_df: bool = False,
 
119
  ) -> torch.Tensor:
120
  if return_as_df:
121
  assert (
@@ -138,8 +140,9 @@ class TotalClassifierModel(PreTrainedModel):
138
  # return raw logits
139
  return logits
140
  probas = logits.sigmoid()
 
141
  if return_as_dict or return_as_df:
142
- # list_of_dictionaries
143
  batch_list = []
144
  for i in range(probas.shape[0]):
145
  dict_for_batch = {}
@@ -157,6 +160,21 @@ class TotalClassifierModel(PreTrainedModel):
157
  else:
158
  batch_list.append(dict_for_batch)
159
  return batch_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  return probas
161
 
162
  def normalize(self, x: torch.Tensor) -> torch.Tensor:
 
115
  mask: torch.Tensor | None = None,
116
  return_logits: bool = False,
117
  return_as_dict: bool = False,
118
+ return_as_list: bool = False,
119
  return_as_df: bool = False,
120
+ threshold: float = 0.5, # only used for return_as_list=True
121
  ) -> torch.Tensor:
122
  if return_as_df:
123
  assert (
 
140
  # return raw logits
141
  return logits
142
  probas = logits.sigmoid()
143
+
144
  if return_as_dict or return_as_df:
145
+ # list of dictionaries
146
  batch_list = []
147
  for i in range(probas.shape[0]):
148
  dict_for_batch = {}
 
160
  else:
161
  batch_list.append(dict_for_batch)
162
  return batch_list
163
+
164
+ if return_as_list:
165
+ # list of lists
166
+ batch_list = []
167
+ for i in range(probas.shape[0]):
168
+ probas_i = probas[i]
169
+ batch_list.append(
170
+ [
171
+ self.index2label[each_class]
172
+ for each_class in range(probas_i.shape[1])
173
+ if probas_i[:, each_class] >= threshold
174
+ ]
175
+ )
176
+ return batch_list
177
+
178
  return probas
179
 
180
  def normalize(self, x: torch.Tensor) -> torch.Tensor: