Update modeling_quiet.py
modeling_quiet.py  CHANGED: +40 -6
@@ -1246,6 +1246,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         self.policy_loss_beta = 1e6
         self.embedding_scale = 1e2
+        self.temperature = nn.Parameter(torch.tensor(1.0))
+        self.max_temperature = config.max_temperature
+        self.complexity_factor = config.complexity_factor
         self.reinforce_temperature = 3
         self.base_loss_beta = 1
         self.thinking_usefulness_head = nn.Linear(self.model.config.hidden_size, 1)
@@ -1626,16 +1629,20 @@ class QuietForCausalLM(QuietPreTrainedModel):
         sample_probs_history = []
         action_loglikelihoods_list = []
 
+        complexity_scores = self.compute_complexity_scores(input_ids, attention_mask)
+        temperature = self.temperature * complexity_scores.unsqueeze(-1)
+
         if self.use_end_thought_token or self.use_start_thought_token:
             if not self.use_reparam_for_thought_embeddings:
-                start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale
-                end_embedding = self.end_embedding[0].unsqueeze(0) * self.embedding_scale
+                start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale * temperature
+                end_embedding = self.end_embedding[0].unsqueeze(0) * self.embedding_scale * temperature
             else:
-                start_embedding = self.start_embedding * self.embedding_scale
-                end_embedding = self.end_embedding * self.embedding_scale
+                start_embedding = self.start_embedding * self.embedding_scale * temperature
+                end_embedding = self.end_embedding * self.embedding_scale * temperature
             base_embeddings = self.model.embed_tokens.weight
             if self.train_only_thinking_embedding:
                 base_embeddings = base_embeddings.detach()
+
         # # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
         for ahead_idx in range(fwd_iters):
@@ -1900,9 +1907,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 contains_end = self.use_end_thought_token and (probabilities_2d[..., self.end_token_id].sum() > 0)
                 contains_thought = contains_start or contains_end
 
+
                 if not contains_thought:
                     with torch.set_grad_enabled(not self.train_only_thinking_embedding):
-                        inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype))
+                        inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype) * temperature)
                 else:
                     thought_id = self.start_token_id if contains_start else self.end_token_id
                     cur_thought_embedding = start_embedding if contains_start else end_embedding
@@ -1915,7 +1923,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
                             sampled_end = inputs_embeds.clone().detach()
                     else:
                         inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
-
+                inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
                 inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
 
                 # Predict the usefulness of thinking at each token position
@@ -2127,6 +2135,32 @@ class QuietForCausalLM(QuietPreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+    def compute_complexity_scores(self, input_ids, attention_mask):
+        # Compute complexity scores based on input sequence characteristics
+        # Example: Normalize sequence lengths and consider the presence of rare tokens
+        seq_lengths = torch.sum(attention_mask, dim=-1)
+        max_length = torch.max(seq_lengths)
+        length_scores = seq_lengths / max_length
+
+        # Compute the proportion of rare tokens in each sequence
+        rare_token_ids = self.get_rare_token_ids()
+        rare_token_mask = torch.isin(input_ids, rare_token_ids)
+        rare_token_counts = torch.sum(rare_token_mask, dim=-1)
+        rare_token_scores = rare_token_counts / seq_lengths
+
+        # Combine length scores and rare token scores
+        complexity_scores = self.complexity_factor * length_scores + (1 - self.complexity_factor) * rare_token_scores
+        return complexity_scores
+
+    def get_rare_token_ids(self):
+        # Get the IDs of rare tokens based on a predefined frequency threshold
+        frequency_threshold = 1e-4
+        token_counts = torch.bincount(self.model.embed_tokens.weight.argmax(dim=-1))
+        total_tokens = torch.sum(token_counts)
+        rare_token_mask = token_counts / total_tokens < frequency_threshold
+        rare_token_ids = torch.nonzero(rare_token_mask).squeeze(-1)
+        return rare_token_ids
 
 
     def prepare_inputs_for_generation(
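
The change introduces a per-sequence temperature for the thought embeddings: the learned scalar self.temperature is multiplied by a heuristic complexity score for each input sequence, and the result scales the start/end-of-thought embeddings and the soft token-embedding lookup. Note that config.max_temperature and config.complexity_factor are read in the constructor but are not added to the config class in this diff. The sketch below reproduces the complexity scoring and the scaling step outside the model on toy tensors; it is a minimal illustration assuming a hand-picked rare-token set and a hidden size of 8, not the exact code path above.

import torch
import torch.nn as nn

def compute_complexity_scores(input_ids, attention_mask, rare_token_ids, complexity_factor=0.5):
    # Length term: each sequence's length normalized by the longest in the batch.
    seq_lengths = attention_mask.sum(dim=-1)                    # (batch,)
    length_scores = seq_lengths / seq_lengths.max()
    # Rarity term: fraction of each sequence's tokens that fall in the rare set.
    rare_mask = torch.isin(input_ids, rare_token_ids)           # (batch, seq)
    rare_scores = rare_mask.sum(dim=-1) / seq_lengths
    # Same blend as the patch: complexity_factor weights length vs. rarity.
    return complexity_factor * length_scores + (1 - complexity_factor) * rare_scores

# Toy batch of two sequences; the second is padded.
input_ids = torch.tensor([[5, 17, 903, 42, 7],
                          [5, 6, 7, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])
rare_token_ids = torch.tensor([903, 904])                       # hand-picked "rare" ids for the demo

scores = compute_complexity_scores(input_ids, attention_mask, rare_token_ids)

# Per-sequence temperature, as in the patch: learned scalar * complexity score.
temperature_param = nn.Parameter(torch.tensor(1.0))             # mirrors self.temperature
temperature = temperature_param * scores.unsqueeze(-1)          # (batch, 1)

# A stand-in for a learned start-of-thought embedding (hidden size 8).
start_embedding = torch.randn(1, 8) * 1e2                       # embedding_scale = 1e2
scaled_start = start_embedding * temperature                    # broadcasts to (batch, 8)
print(scores, temperature.shape, scaled_start.shape)

Because temperature has shape (batch_size, 1), the scaled thought embeddings pick up a batch dimension, so each sequence gets its own effective embedding magnitude rather than a single shared one.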
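The new compute_complexity_scores helper depends on get_rare_token_ids, which, lacking corpus statistics, derives a "rare" set from the embedding matrix itself: it takes the argmax of each embedding row over the hidden dimension, bincounts those argmax indices, and flags indices whose relative frequency falls below 1e-4. As written, the returned indices therefore lie in [0, hidden_size) rather than ranging over the full vocabulary. The sketch below is a standalone re-implementation of that heuristic on a small random matrix so its behaviour can be inspected; the matrix shape and threshold are illustrative only.

import torch

def get_rare_token_ids(embedding_weight, frequency_threshold=1e-4):
    # Per-row argmax over the hidden dimension, then a frequency count of
    # those argmax indices, mirroring the patch.
    token_counts = torch.bincount(embedding_weight.argmax(dim=-1))
    total_tokens = token_counts.sum()
    rare_mask = token_counts / total_tokens < frequency_threshold
    return torch.nonzero(rare_mask).squeeze(-1)

# Tiny stand-in for embed_tokens.weight: vocab_size=1000, hidden_size=16.
# Biasing the first column makes most rows argmax to index 0, so the other
# argmax indices become "rare" under the threshold.
weight = torch.randn(1000, 16)
weight[:, 0] += 4.0
rare_ids = get_rare_token_ids(weight, frequency_threshold=0.01)
print(rare_ids)          # indices in [0, 16), i.e. bounded by the hidden size

If actual token rarity is the goal, the same interface could instead be driven by vocabulary counts gathered from a tokenizer or training corpus; the diff keeps the heuristic self-contained inside the model.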