Use a weighted list of reward functions
Separate the calculation of advantage from the calculation of rewards.
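To illustrate the weighting scheme, here is a minimal standalone sketch (the toy format_reward / accuracy_reward functions, the example weights, and the sample data are all made up for illustration): each reward function scores every completion, zero-weight functions are skipped as NaN, and the per-function columns are combined with a weighted nansum, mirroring the logic added below.

import torch

def format_reward(prompts, completions, **kwargs):
    # 1.0 when the completion contains a boxed answer, else 0.0
    return [1.0 if "\\boxed{" in c else 0.0 for c in completions]

def accuracy_reward(prompts, completions, solution=None, **kwargs):
    # 1.0 when the reference solution appears in the completion, else 0.0
    return [1.0 if sol in c else 0.0 for c, sol in zip(completions, solution)]

reward_funcs = [format_reward, accuracy_reward]
reward_weights = torch.tensor([0.2, 1.0])

prompts = ["What is 2+2?", "What is 3*3?"]
completions = ["The answer is \\boxed{4}.", "9"]
solution = ["4", "9"]

rewards_per_func = torch.full((len(prompts), len(reward_funcs)), float("nan"))
for i, (func, weight) in enumerate(zip(reward_funcs, reward_weights)):
    if abs(weight.item()) < 1e-6:
        continue  # zero-weight functions stay NaN and drop out of the nansum
    rewards_per_func[:, i] = torch.tensor(
        func(prompts=prompts, completions=completions, solution=solution),
        dtype=torch.float32,
    )

# Weighted combination across reward functions, ignoring NaN columns
rewards = (rewards_per_func * reward_weights.unsqueeze(0)).nansum(dim=1)
print(rewards)  # tensor([1.2000, 1.0000])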
src/retool_trainer.py  CHANGED  +89 -12
@@ -7,7 +7,7 @@ import datasets
 import torch
 import torch.utils.data
 import transformers
-from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
+#from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
 from datasets import Dataset, IterableDataset
 from packaging import version
 from torch import nn

@@ -163,19 +163,96 @@ class ReToolTrainer(Trainer): # Change this line
             return self._check_equivalence(predicted, ground_truth)
         return False
 
-    def
-        """
-
-        #
-
-
-
-
-
-        #
-
+    def _compute_rewards(self, inputs, prompts, completions, completion_ids_list=None):
+        """Calculate rewards for completions and combine them according to weights."""
+        device = self.device  # Your device might be set differently
+        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+
+        # Extract additional arguments from inputs if needed
+        reward_kwargs = {}
+        if isinstance(inputs, list) and len(inputs) > 0 and isinstance(inputs[0], dict):
+            keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
+            reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
 
+        # Add correct_answers to kwargs if present (common in math reasoning tasks)
+        if "correct_answers" in reward_kwargs:
+            reward_kwargs["solution"] = reward_kwargs["correct_answers"]  # Alias for compatibility
+
+        # Calculate rewards for each function with non-zero weight
+        for i, (reward_func, func_name) in enumerate(zip(self.reward_funcs, self.reward_func_names)):
+            # Skip computation if weight is zero
+            if abs(self.reward_weights[i].item()) < 1e-6:
+                rewards_per_func[:, i] = float('nan')
+                if self.verbose:
+                    print(f"Skipping reward '{func_name}' (zero weight)")
+                continue
+
+            # Calculate reward
+            try:
+                # Call the reward function with appropriate arguments
+                rewards = reward_func(
+                    prompts=prompts,
+                    completions=completions,
+                    completion_ids=completion_ids_list if completion_ids_list is not None else None,
+                    **reward_kwargs
+                )
+
+                # Convert None values to NaN and ensure it's a tensor
+                rewards = [r if r is not None else float('nan') for r in rewards]
+                rewards_per_func[:, i] = torch.tensor(rewards, dtype=torch.float32, device=device)
+
+                # Log reward statistics if verbose
+                if self.verbose:
+                    valid_rewards = [r for r in rewards if not (r is None or (isinstance(r, float) and math.isnan(r)))]
+                    if valid_rewards:
+                        print(f"Reward '{func_name}': min={min(valid_rewards):.4f}, max={max(valid_rewards):.4f}, "
+                              f"mean={sum(valid_rewards)/len(valid_rewards):.4f}")
+            except Exception as e:
+                print(f"Error in reward function '{func_name}': {e}")
+                rewards_per_func[:, i] = float('nan')
+
+        # Combine rewards using weights
+        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
 
+        # Convert to list for easier handling
+        final_rewards = rewards.cpu().tolist()
+
+        return final_rewards
+
+
+    def compute_rewards_and_advantages(self, inputs, prompts, completions, completion_ids_list=None):
+        """Calculate rewards and compute advantages based on those rewards."""
+        # First calculate rewards
+        rewards = self._compute_rewards(inputs, prompts, completions, completion_ids_list)
+
+        # Convert to tensor if not already
+        if not isinstance(rewards, torch.Tensor):
+            rewards = torch.tensor(rewards, dtype=torch.float32, device=self.device)
+
+        # Default: simple advantage calculation (advantages = rewards)
+        advantages = rewards.clone()
+
+        # Optional GRPO-style grouped advantage calculation
+        if self.use_grouped_advantages:
+            # Reshape rewards into groups (assuming self.num_generations is set)
+            grouped_rewards = rewards.view(-1, self.num_generations)
+
+            # Calculate statistics per group
+            mean_grouped_rewards = grouped_rewards.mean(dim=1)
+            std_grouped_rewards = grouped_rewards.std(dim=1)
+
+            # Expand means and stds to match original shape
+            mean_expanded = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+            std_expanded = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+
+            # Compute advantages: reward - baseline
+            advantages = rewards - mean_expanded
+
+            # Optionally normalize advantages
+            if self.normalize_advantages:
+                # Avoid division by zero
+                std_expanded = torch.clamp(std_expanded, min=1e-8)
+                advantages = advantages / std_expanded
 
         return advantages
 
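As a usage note on the new split: advantages can now be derived in a second step from any reward vector. The standalone sketch below reproduces the grouped (GRPO-style) branch with made-up reward values and an assumed num_generations of 4; it corresponds to the path where use_grouped_advantages and normalize_advantages are both enabled, rather than calling the trainer itself.

import torch

num_generations = 4  # assumed: completions sampled per prompt

# Stand-in for _compute_rewards output: one combined reward per completion,
# laid out as [prompt0_gen0, ..., prompt0_gen3, prompt1_gen0, ..., prompt1_gen3].
rewards = torch.tensor([1.2, 1.0, 0.0, 1.2, 0.2, 1.0, 1.0, 0.0])

# Group rewards per prompt and subtract the group mean as a baseline.
grouped = rewards.view(-1, num_generations)
mean = grouped.mean(dim=1).repeat_interleave(num_generations)
std = grouped.std(dim=1).repeat_interleave(num_generations)

advantages = rewards - mean
# Optionally normalize by the group std, clamped to avoid division by zero.
advantages = advantages / torch.clamp(std, min=1e-8)
print(advantages)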