Commit: Update modeling_quiet.py (1 file changed, +13 lines, -0 lines).
File: modeling_quiet.py (CHANGED)
|
@@ -1283,6 +1283,11 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1283 |
self.talk_head = nn.ModuleList([nn.Sequential(
|
| 1284 |
nn.Linear(talk_input_dim, talk_output_dim, bias=False)
|
| 1285 |
)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1286 |
|
| 1287 |
# Initialize weights and apply final processing
|
| 1288 |
self.post_init()
|
|
@@ -1304,6 +1309,14 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1304 |
|
| 1305 |
def get_decoder(self):
|
| 1306 |
return self.model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1307 |
|
| 1308 |
@torch.no_grad()
|
| 1309 |
def infer(
|
|
|
|
| 1283 |
self.talk_head = nn.ModuleList([nn.Sequential(
|
| 1284 |
nn.Linear(talk_input_dim, talk_output_dim, bias=False)
|
| 1285 |
)])
|
| 1286 |
+
|
| 1287 |
+
self.apply(self._init_weights)
|
| 1288 |
+
|
| 1289 |
+
# Add dropout regularization
|
| 1290 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1291 |
|
| 1292 |
# Initialize weights and apply final processing
|
| 1293 |
self.post_init()
|
|
|
|
| 1309 |
|
| 1310 |
def get_decoder(self):
|
| 1311 |
return self.model
|
| 1312 |
+
|
| 1313 |
+
def _init_weights(self, module):
|
| 1314 |
+
if isinstance(module, nn.Linear):
|
| 1315 |
+
nn.init.xavier_uniform_(module.weight)
|
| 1316 |
+
if module.bias is not None:
|
| 1317 |
+
nn.init.constant_(module.bias, 0)
|
| 1318 |
+
elif isinstance(module, nn.Embedding):
|
| 1319 |
+
nn.init.xavier_uniform_(module.weight)
|
| 1320 |
|
| 1321 |
@torch.no_grad()
|
| 1322 |
def infer(
|