Update modeling_quiet.py
Browse files- modeling_quiet.py +6 -6
modeling_quiet.py
CHANGED
|
@@ -1311,12 +1311,12 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1311 |
return self.model
|
| 1312 |
|
| 1313 |
def _init_weights(self, module):
|
| 1314 |
-
|
| 1315 |
-
|
| 1316 |
-
|
| 1317 |
-
|
| 1318 |
-
|
| 1319 |
-
|
| 1320 |
|
| 1321 |
@torch.no_grad()
|
| 1322 |
def infer(
|
|
|
|
| 1311 |
return self.model
|
| 1312 |
|
| 1313 |
def _init_weights(self, module):
|
| 1314 |
+
if isinstance(module, nn.Linear):
|
| 1315 |
+
nn.init.xavier_uniform_(module.weight)
|
| 1316 |
+
if module.bias is not None:
|
| 1317 |
+
nn.init.constant_(module.bias, 0)
|
| 1318 |
+
elif isinstance(module, nn.Embedding):
|
| 1319 |
+
nn.init.xavier_uniform_(module.weight)
|
| 1320 |
|
| 1321 |
@torch.no_grad()
|
| 1322 |
def infer(
|