ianpan commited on
Commit
dbdbb0d
·
verified ·
1 Parent(s): 9eaf668

Upload model

Browse files
Files changed (1) hide show
  1. modeling.py +4 -2
modeling.py CHANGED
@@ -167,10 +167,12 @@ class TotalClassifierModel(PreTrainedModel):
167
  # inner list - list of above for each slice
168
  # outer list - list of above for each batch element (studies)
169
  batch_list = []
 
170
  for i in range(probas.shape[0]):
171
  probas_i = probas[i]
 
 
172
  for each_slice in range(probas_i.shape[0]):
173
- list_for_batch = []
174
  for each_class in range(probas_i.shape[1]):
175
  list_for_batch.append(
176
  [
@@ -179,7 +181,7 @@ class TotalClassifierModel(PreTrainedModel):
179
  if probas_i[each_slice, each_class] >= threshold
180
  ]
181
  )
182
- batch_list.append(list_for_batch)
183
  return batch_list
184
 
185
  return probas
 
167
  # inner list - list of above for each slice
168
  # outer list - list of above for each batch element (studies)
169
  batch_list = []
170
+ # probas.shape = (batch_size, num_slices, num_classes)
171
  for i in range(probas.shape[0]):
172
  probas_i = probas[i]
173
+ # probas_i.shape = (num_slices, num_classes)
174
+ list_for_batch = []
175
  for each_slice in range(probas_i.shape[0]):
 
176
  for each_class in range(probas_i.shape[1]):
177
  list_for_batch.append(
178
  [
 
181
  if probas_i[each_slice, each_class] >= threshold
182
  ]
183
  )
184
+ batch_list.append(list_for_batch)
185
  return batch_list
186
 
187
  return probas