Upload model
Browse files- modeling.py +19 -1
modeling.py
CHANGED
|
@@ -115,7 +115,9 @@ class TotalClassifierModel(PreTrainedModel):
|
|
| 115 |
mask: torch.Tensor | None = None,
|
| 116 |
return_logits: bool = False,
|
| 117 |
return_as_dict: bool = False,
|
|
|
|
| 118 |
return_as_df: bool = False,
|
|
|
|
| 119 |
) -> torch.Tensor:
|
| 120 |
if return_as_df:
|
| 121 |
assert (
|
|
@@ -138,8 +140,9 @@ class TotalClassifierModel(PreTrainedModel):
|
|
| 138 |
# return raw logits
|
| 139 |
return logits
|
| 140 |
probas = logits.sigmoid()
|
|
|
|
| 141 |
if return_as_dict or return_as_df:
|
| 142 |
-
#
|
| 143 |
batch_list = []
|
| 144 |
for i in range(probas.shape[0]):
|
| 145 |
dict_for_batch = {}
|
|
@@ -157,6 +160,21 @@ class TotalClassifierModel(PreTrainedModel):
|
|
| 157 |
else:
|
| 158 |
batch_list.append(dict_for_batch)
|
| 159 |
return batch_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
return probas
|
| 161 |
|
| 162 |
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 115 |
mask: torch.Tensor | None = None,
|
| 116 |
return_logits: bool = False,
|
| 117 |
return_as_dict: bool = False,
|
| 118 |
+
return_as_list: bool = False,
|
| 119 |
return_as_df: bool = False,
|
| 120 |
+
threshold: float = 0.5, # only used for return_as_list=True
|
| 121 |
) -> torch.Tensor:
|
| 122 |
if return_as_df:
|
| 123 |
assert (
|
|
|
|
| 140 |
# return raw logits
|
| 141 |
return logits
|
| 142 |
probas = logits.sigmoid()
|
| 143 |
+
|
| 144 |
if return_as_dict or return_as_df:
|
| 145 |
+
# list of dictionaries
|
| 146 |
batch_list = []
|
| 147 |
for i in range(probas.shape[0]):
|
| 148 |
dict_for_batch = {}
|
|
|
|
| 160 |
else:
|
| 161 |
batch_list.append(dict_for_batch)
|
| 162 |
return batch_list
|
| 163 |
+
|
| 164 |
+
if return_as_list:
|
| 165 |
+
# list of lists
|
| 166 |
+
batch_list = []
|
| 167 |
+
for i in range(probas.shape[0]):
|
| 168 |
+
probas_i = probas[i]
|
| 169 |
+
batch_list.append(
|
| 170 |
+
[
|
| 171 |
+
self.index2label[each_class]
|
| 172 |
+
for each_class in range(probas_i.shape[1])
|
| 173 |
+
if probas_i[:, each_class] >= threshold
|
| 174 |
+
]
|
| 175 |
+
)
|
| 176 |
+
return batch_list
|
| 177 |
+
|
| 178 |
return probas
|
| 179 |
|
| 180 |
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|