Upload ConstBERT
modeling.py  CHANGED  (+4, -0)
@@ -178,6 +178,8 @@ class ConstBERT(BertPreTrainedModel):
         return D
 
     def encode_query(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
+        if type(queries) == str:
+            queries = [queries]
         if bsize:
             batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
             batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
@@ -187,6 +189,8 @@ class ConstBERT(BertPreTrainedModel):
         return self.query(input_ids, attention_mask)
 
     def encode_document(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
+        if type(docs) == str:
+            docs = [docs]
         assert keep_dims in [True, False, 'flatten']
 
         if bsize:
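
In effect, this commit lets encode_query and encode_document accept a single string as well as a list of strings: a bare str input is wrapped in a one-element list before tokenization and batching. A minimal usage sketch under assumptions not taken from this commit (the checkpoint path is a placeholder, and the loading call assumes the standard transformers from_pretrained inherited via BertPreTrainedModel; the actual repo may require its own tokenizer setup):

    # Hypothetical usage sketch; only encode_query / encode_document come from this diff.
    model = ConstBERT.from_pretrained("path/to/constbert-checkpoint")  # placeholder path / loader

    # After this change, a bare string behaves like a one-element list.
    q_single = model.encode_query("what is multi-vector retrieval?")
    q_list   = model.encode_query(["what is multi-vector retrieval?"])

    # Same for documents; batching and keep_dims options are unchanged.
    d = model.encode_document("ConstBERT encodes each passage into a fixed number of vectors.",
                              bsize=32)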