diff --git a/config/locomotion.py b/config/locomotion.py deleted file mode 100644 index 4410bb1..0000000 --- a/config/locomotion.py +++ /dev/null @@ -1,70 +0,0 @@ -import socket - -from diffuser.utils import watch - -#------------------------ base ------------------------# - -## automatically make experiment names for planning -## by labelling folders with these args - -diffusion_args_to_watch = [ - ('prefix', ''), - ('horizon', 'H'), - ('n_diffusion_steps', 'T'), -] - -base = { - 'diffusion': { - ## model - 'model': 'models.TemporalUnet', - 'diffusion': 'models.GaussianDiffusion', - 'horizon': 32, - 'n_diffusion_steps': 100, - 'action_weight': 10, - 'loss_weights': None, - 'loss_discount': 1, - 'predict_epsilon': False, - 'dim_mults': (1, 4, 8), - 'renderer': 'utils.MuJoCoRenderer', - - ## dataset - 'loader': 'datasets.SequenceDataset', - 'normalizer': 'LimitsNormalizer', - 'preprocess_fns': [], - 'clip_denoised': True, - 'use_padding': True, - 'max_path_length': 1000, - - ## serialization - 'logbase': 'logs', - 'prefix': 'diffusion/', - 'exp_name': watch(diffusion_args_to_watch), - - ## training - 'n_steps_per_epoch': 10000, - 'loss_type': 'l2', - 'n_train_steps': 1e6, - 'batch_size': 32, - 'learning_rate': 2e-4, - 'gradient_accumulate_every': 2, - 'ema_decay': 0.995, - 'save_freq': 1000, - 'sample_freq': 1000, - 'n_saves': 5, - 'save_parallel': False, - 'n_reference': 8, - 'n_samples': 2, - 'bucket': None, - 'device': 'cuda', - }, -} - -#------------------------ overrides ------------------------# - -## put environment-specific overrides here - -halfcheetah_medium_expert_v2 = { - 'diffusion': { - 'horizon': 16, - }, -} diff --git a/config/maze2d.py b/config/maze2d.py index a06ac7f..0a8d22a 100644 --- a/config/maze2d.py +++ b/config/maze2d.py @@ -34,11 +34,11 @@ base = { 'model': 'models.TemporalUnet', 'diffusion': 'models.GaussianDiffusion', 'horizon': 256, - 'n_diffusion_steps': 256, + 'n_diffusion_steps': 512, 'action_weight': 1, 'loss_weights': None, 'loss_discount': 1, - 'predict_epsilon': False, + 'predict_epsilon': True, 'dim_mults': (1, 4, 8), 'renderer': 'utils.Maze2dRenderer', @@ -57,14 +57,14 @@ base = { 'exp_name': watch(diffusion_args_to_watch), ## training - 'n_steps_per_epoch': 10000, - 'loss_type': 'l2', - 'n_train_steps': 2e6, - 'batch_size': 32, - 'learning_rate': 2e-4, - 'gradient_accumulate_every': 2, + 'n_steps_per_epoch': 60000, + 'loss_type': 'spline', + 'n_train_steps': 6e4, + 'batch_size': 1, + 'learning_rate': 5e-6, + 'gradient_accumulate_every': 8, 'ema_decay': 0.995, - 'save_freq': 1000, + 'save_freq': 2000, 'sample_freq': 1000, 'n_saves': 50, 'save_parallel': False, @@ -89,7 +89,6 @@ base = { 'prefix': 'plans/release', 'exp_name': watch(plan_args_to_watch), 'suffix': '0', - 'conditional': False, ## loading @@ -122,10 +121,10 @@ maze2d_umaze_v1 = { maze2d_large_v1 = { 'diffusion': { 'horizon': 384, - 'n_diffusion_steps': 256, + 'n_diffusion_steps': 16, }, 'plan': { 'horizon': 384, - 'n_diffusion_steps': 256, + 'n_diffusion_steps': 16, }, } diff --git a/diffuser/datasets/buffer.py b/diffuser/datasets/buffer.py index 1ad2106..5991f01 100644 --- a/diffuser/datasets/buffer.py +++ b/diffuser/datasets/buffer.py @@ -9,7 +9,7 @@ class ReplayBuffer: def __init__(self, max_n_episodes, max_path_length, termination_penalty): self._dict = { - 'path_lengths': np.zeros(max_n_episodes, dtype=np.int), + 'path_lengths': np.zeros(max_n_episodes, dtype=np.int_), } self._count = 0 self.max_n_episodes = max_n_episodes diff --git a/diffuser/datasets/sequence.py b/diffuser/datasets/sequence.py index 356c540..73c1b04 100644 --- a/diffuser/datasets/sequence.py +++ b/diffuser/datasets/sequence.py @@ -83,6 +83,7 @@ class SequenceDataset(torch.utils.data.Dataset): actions = self.fields.normed_actions[path_ind, start:end] conditions = self.get_conditions(observations) + trajectories = np.concatenate([actions, observations], axis=-1) batch = Batch(trajectories, conditions) return batch diff --git a/diffuser/models/diffusion.py b/diffuser/models/diffusion.py index fae4cfd..461680a 100644 --- a/diffuser/models/diffusion.py +++ b/diffuser/models/diffusion.py @@ -2,6 +2,7 @@ import numpy as np import torch from torch import nn import pdb +import matplotlib.pyplot as plt import diffuser.utils as utils from .helpers import ( @@ -9,6 +10,7 @@ from .helpers import ( extract, apply_conditioning, Losses, + catmull_rom_spline_with_rotation, ) class GaussianDiffusion(nn.Module): @@ -26,6 +28,7 @@ class GaussianDiffusion(nn.Module): betas = cosine_beta_schedule(n_timesteps) alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) + print(f"Alphas Cumprod: {alphas_cumprod}") alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]]) self.n_timesteps = int(n_timesteps) @@ -73,7 +76,7 @@ class GaussianDiffusion(nn.Module): ''' self.action_weight = action_weight - dim_weights = torch.ones(self.transition_dim, dtype=torch.float32) + dim_weights = torch.ones(self.transition_dim, dtype=torch.float64) ## set loss coefficients for dimensions of observation if weights_dict is None: weights_dict = {} @@ -97,18 +100,16 @@ class GaussianDiffusion(nn.Module): otherwise, model predicts x0 directly ''' if self.predict_epsilon: - return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) + return noise else: return noise def q_posterior(self, x_start, x_t, t): posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t[:, :, self.action_dim:] ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped @@ -129,7 +130,7 @@ class GaussianDiffusion(nn.Module): def p_sample(self, x, cond, t): b, *_, device = *x.shape, x.device model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t) - noise = torch.randn_like(x) + noise = torch.randn_like(x[:, :, self.action_dim:]) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @@ -139,22 +140,59 @@ class GaussianDiffusion(nn.Module): device = self.betas.device batch_size = shape[0] - x = torch.randn(shape, device=device) - x = apply_conditioning(x, cond, self.action_dim) + # x = torch.randn(shape, device=device, dtype=torch.float64) + # Extract known indices and values + known_indices = np.array(list(cond.keys()), dtype=int) + + # candidate_no x batch_size x dim + known_values = np.stack([c.cpu().numpy() for c in cond.values()], axis=0) + known_values = np.moveaxis(known_values, 0, 1) + + # Sort the timepoints + sorted_indices = np.argsort(known_indices) + known_indices = known_indices[sorted_indices] + known_values = known_values[:, sorted_indices] + + # Build the structured spline guess + catmull_spline_trajectory = np.array([ + catmull_rom_spline_with_rotation(known_values[b, :, :-1], known_indices, shape[1]) + for b in range(batch_size) + ]) + catmull_spline_trajectory = torch.tensor( + catmull_spline_trajectory, + dtype=torch.float64, + device=device + ) + + + if self.predict_epsilon: + x = torch.randn((shape[0], shape[1], self.observation_dim), device=device, dtype=torch.float64) + cond_residual = {k: torch.zeros_like(v)[:, :-1] for k, v in cond.items()} + is_cond = torch.zeros((shape[0], shape[1], 1), device=device, dtype=torch.float64) + is_cond[:, known_indices, :] = 1.0 if return_diffusion: diffusion = [x] - progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent() + # progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent() for i in reversed(range(0, self.n_timesteps)): + if self.predict_epsilon: + x = torch.cat([catmull_spline_trajectory, is_cond, x], dim=-1) + timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long) - x = self.p_sample(x, cond, timesteps) - x = apply_conditioning(x, cond, self.action_dim) + x = self.p_sample(x, cond_residual, timesteps) + + x = apply_conditioning(x, cond_residual, 0) - progress.update({'t': i}) + if return_diffusion: diffusion.append(x) - if return_diffusion: diffusion.append(x) + x = catmull_spline_trajectory + x - progress.close() + + + # Normalize the quaternions + # x[:, :, 3:7] = x[:, :, 3:7] / torch.norm(x[:, :, 3:7], dim=-1, keepdim=True) + + # progress.close() if return_diffusion: return x, torch.stack(diffusion, dim=1) @@ -167,7 +205,7 @@ class GaussianDiffusion(nn.Module): conditions : [ (time, state), ... ] ''' device = self.betas.device - batch_size = len(cond[0]) + batch_size = len(next(iter(cond.values()))) horizon = horizon or self.horizon shape = (batch_size, horizon, self.transition_dim) @@ -175,38 +213,106 @@ class GaussianDiffusion(nn.Module): #------------------------------------------ training ------------------------------------------# - def q_sample(self, x_start, t, noise=None): + def q_sample(self, x_start, t, spline=None, noise=None): + x_start_noise = x_start[:, : , :-1] + x_start_is_cond = x_start[:, :, [-1]] + + if spline is None: + spline = torch.randn_like(x_start_noise) if noise is None: - noise = torch.randn_like(x_start) + noise = torch.randn_like(x_start_noise) - sample = ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) + alpha = extract(self.sqrt_alphas_cumprod, t, x_start.shape) + oneminusalpha = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + + # Weighted combination of x_0 and the spline + out = alpha * x_start_noise + oneminusalpha * noise + + # Concatenate the binary feature and the spline as the conditioning + out = torch.cat([spline, x_start_is_cond, out], dim=-1) - return sample + return out def p_losses(self, x_start, cond, t): - noise = torch.randn_like(x_start) + batch_size, horizon, _ = x_start.shape + # Extract known indices and values + known_indices = np.array(list(cond.keys()), dtype=int) + + # candidate_no x batch_size x dim + known_values = np.stack([c.cpu().numpy() for c in cond.values()], axis=0) + known_values = np.moveaxis(known_values, 0, 1) + + # Sort the timepoints + sorted_indices = np.argsort(known_indices) + known_indices = known_indices[sorted_indices] + known_values = known_values[:, sorted_indices] + + # Build your structured guess + catmull_spline_trajectory = np.array([ + catmull_rom_spline_with_rotation(known_values[b, :, :-1], known_indices, horizon) + for b in range(batch_size) + ]) + catmull_spline_trajectory = torch.tensor( + catmull_spline_trajectory, + dtype=torch.float64, + device=x_start.device + ) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - x_noisy = apply_conditioning(x_noisy, cond, self.action_dim) + # Plot the quaternions + # plt.plot(x_start[0, :, 3].cpu().numpy()) + # plt.plot(catmull_spline_trajectory[0, :, 3].cpu().numpy()) + # plt.legend(["x_start", "catmull_spline"]) + # plt.show() + # raise Exception - x_recon = self.model(x_noisy, cond, t) - x_recon = apply_conditioning(x_recon, cond, self.action_dim) - assert noise.shape == x_recon.shape + if not self.predict_epsilon: + # Forward diffuse with the structured trajectory + x_noisy = self.q_sample( + x_start, + t, + spline=catmull_spline_trajectory, + ) + x_noisy = apply_conditioning(x_noisy, cond, self.action_dim) - if self.predict_epsilon: - loss, info = self.loss_fn(x_recon, noise) + # Reverse pass guess + x_recon = self.model(x_noisy, cond, t) + x_recon = apply_conditioning(x_recon, cond, self.action_dim) + + # Then x_recon is the predicted x_0, compare to the true x_0 + loss, info = self.loss_fn(x_recon, x_start, cond) else: - loss, info = self.loss_fn(x_recon, x_start) + residual = x_start.clone() + + residual[:, :, :-1] -= catmull_spline_trajectory + + + cond_residual = {k: torch.zeros_like(v)[:, :-1] for k, v in cond.items()} + + x_noisy = self.q_sample( + residual, + t, + spline=catmull_spline_trajectory, + ) + x_noisy = apply_conditioning(x_noisy, cond_residual, self.action_dim) + + # Reverse pass guess + x_recon = self.model(x_noisy, cond, t) + x_recon = apply_conditioning(x_recon, cond_residual, 0) + + x_recon = x_recon + catmull_spline_trajectory + + loss, info = self.loss_fn(x_recon, x_start[:, :, :-1], cond) return loss, info def loss(self, x, cond): batch_size = len(x) t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() + # t = torch.randint(1, 2, (batch_size,), device=x.device).long() + # x = x.double() + # cond = {k: v.double() for k, v in cond.items()} + # print(f"Time: {t.item()}") return self.p_losses(x, cond, t) def forward(self, cond, *args, **kwargs): diff --git a/diffuser/models/helpers.py b/diffuser/models/helpers.py index d39f35d..9f43ef8 100644 --- a/diffuser/models/helpers.py +++ b/diffuser/models/helpers.py @@ -1,11 +1,11 @@ import math +import json import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import einops from einops.layers.torch import Rearrange -import pdb +from pytorch3d.transforms import quaternion_to_matrix, quaternion_to_axis_angle import diffuser.utils as utils @@ -30,7 +30,7 @@ class SinusoidalPosEmb(nn.Module): class Downsample1d(nn.Module): def __init__(self, dim): super().__init__() - self.conv = nn.Conv1d(dim, dim, 3, 2, 1) + self.conv = nn.Conv1d(dim, dim, 3, 2, 1).to(torch.float64) def forward(self, x): return self.conv(x) @@ -38,7 +38,7 @@ class Downsample1d(nn.Module): class Upsample1d(nn.Module): def __init__(self, dim): super().__init__() - self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1).to(torch.float64) def forward(self, x): return self.conv(x) @@ -52,9 +52,9 @@ class Conv1dBlock(nn.Module): super().__init__() self.block = nn.Sequential( - nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2).to(torch.float64), Rearrange('batch channels horizon -> batch channels 1 horizon'), - nn.GroupNorm(n_groups, out_channels), + nn.GroupNorm(n_groups, out_channels).to(torch.float64), Rearrange('batch channels 1 horizon -> batch channels horizon'), nn.Mish(), ) @@ -72,7 +72,7 @@ def extract(a, t, x_shape): out = a.gather(-1, t) return out.reshape(b, *((1,) * (len(x_shape) - 1))) -def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32): +def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float64): """ cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ @@ -157,9 +157,979 @@ class ValueL2(ValueLoss): def _loss(self, pred, targ): return F.mse_loss(pred, targ, reduction='none') +class GeodesicL2Loss(nn.Module): + def __init__(self, *args): + super().__init__() + pass + + def _loss(self, pred, targ): + # Compute L2 loss for the first three dimensions + l2_loss = F.mse_loss(pred[..., :3], targ[..., :3], reduction='mean') + + # Normalize to unit quaternions for the last four dimensions + pred_quat = pred[..., 3:] / pred[..., 3:].norm(dim=-1, keepdim=True) + targ_quat = targ[..., 3:] / targ[..., 3:].norm(dim=-1, keepdim=True) + + assert not torch.isnan(pred_quat).any(), "Pred Quat has NaNs" + assert not torch.isnan(targ_quat).any(), "Targ Quat has NaNs" + + # Compute dot product for the quaternions + dot_product = torch.sum(pred_quat * targ_quat, dim=-1) + dot_product = torch.clamp(torch.abs(dot_product), -1.0, 1.0) + + # Compute geodesic loss for the quaternions + geodesic_loss = 2 * torch.acos(dot_product).mean() + + assert not torch.isnan(geodesic_loss).any(), "Geodesic Loss has NaNs" + assert not torch.isnan(l2_loss).any(), "L2 Loss has NaNs" + + return l2_loss + geodesic_loss, l2_loss, geodesic_loss + + def forward(self, pred, targ): + loss, l2, geodesic = self._loss(pred, targ) + + info = { + 'l2': l2.item(), + 'geodesic': geodesic.item(), + } + + return loss, info + +class RotationTranslationLoss(nn.Module): + def __init__(self, *args): + super().__init__() + pass + + def _loss(self, pred, targ, cond=None): + + # Make sure the dtype is float64 + pred = pred.to(torch.float64) + targ = targ.to(torch.float64) + + eps = 1e-8 + + pred_trans = pred[..., :3] + pred_quat = pred[..., 3:7] + targ_trans = targ[..., :3] + targ_quat = targ[..., 3:7] + + l2_loss = F.mse_loss(pred_trans, targ_trans, reduction='mean') + + # Calculate the geodesic loss + pred_n = pred_quat.norm(dim=-1, keepdim=True).clamp(min=eps) + targ_n = targ_quat.norm(dim=-1, keepdim=True).clamp(min=eps) + + pred_quat_norm = pred_quat / pred_n + targ_quat_norm = targ_quat / targ_n + + + dot_product = torch.sum(pred_quat_norm * targ_quat_norm, dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps) + quaternion_dist = 1 - (dot_product ** 2).mean() + + # Calculate the rotation error + pred_rot = quaternion_to_matrix(pred_quat_norm).reshape(-1, 3, 3) + targ_rot = quaternion_to_matrix(targ_quat_norm).reshape(-1, 3, 3) + + r2r1 = pred_rot @ targ_rot.permute(0, 2, 1) + trace = torch.diagonal(r2r1, dim1=-2, dim2=-1).sum(-1) + trace = torch.clamp((trace - 1) / 2, -1.0 + eps, 1.0 - eps) + geodesic_loss = torch.acos(trace).mean() + + # Add a smoothness and acceleration term to the positions and quaternions + alpha = 1.0 + smoothness_loss = F.mse_loss(pred[:, 1:, :7].reshape(-1, 7), pred[:, :-1, :7].reshape(-1, 7), reduction='mean') + acceleration_loss = F.mse_loss(pred[:, 2:, :7].reshape(-1, 7), 2 * pred[:, 1:-1, :7].reshape(-1, 7) - pred[:, :-2, :7].reshape(-1, 7), reduction='mean') + + l2_multiplier = 10.0 + + loss = l2_multiplier * l2_loss + quaternion_dist + geodesic_loss + alpha * (smoothness_loss + acceleration_loss) + + dtw = DynamicTimeWarpingLoss() + dtw_loss, _ = dtw.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3)) + + hausdorff = HausdorffDistanceLoss() + hausdorff_loss, _ = hausdorff.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3)) + + frec = FrechetDistanceLoss() + frechet_loss, _ = frec.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3)) + + chamfer = ChamferDistanceLoss() + chamfer_loss, _ = chamfer.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3)) + + return loss, l2_loss, geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss + + + def forward(self, pred, targ, cond=None): + loss, err_t, err_geo, err_r, err_dtw, err_hausdorff, err_frechet, err_chamfer = self._loss(pred, targ, cond) + + info = { + 'rot. error': err_r.item(), + 'geodesic error': err_geo.item(), + 'trans. error': err_t.item(), + 'dtw': err_dtw.item(), + 'hausdorff': err_hausdorff.item(), + 'frechet': err_frechet.item(), + 'chamfer': err_chamfer.item(), + } + + return loss, info + +class SplineLoss(nn.Module): + def __init__(self, *args): + super().__init__() + self.scales = json.load(open('scene_scale.json')) + + def compute_spline_coeffs(self, trans): + p0 = trans[:, :-3, :] + p1 = trans[:, 1:-2, :] + p2 = trans[:, 2:-1, :] + p3 = trans[:, 3:, :] + + # Tangent approximations + m1 = 0.5 * (-p0 + p2) + m2 = 0.5 * (-p1 + p3) + + # Cubic spline coefficients for each dimension + a = (2 * p1 - 2 * p2 + m1 + m2) + b = (-3 * p1 + 3 * p2 - 2 * m1 - m2) + c = (m1) + d = (p1) + + return torch.stack([a, b, c, d], dim=-1) + + def q_normalize(self, q): + return q / q.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-12) + + def q_conjugate(self, q): + w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3] + return torch.stack([w, -x, -y, -z], dim=-1) + + def q_multiply(self, q1, q2): + """ + q1*q2. + """ + w1, x1, y1, z1 = q1.unbind(-1) + w2, x2, y2, z2 = q2.unbind(-1) + w = w1*w2 - x1*x2 - y1*y2 - z1*z2 + x = w1*x2 + x1*w2 + y1*z2 - z1*y2 + y = w1*y2 - x1*z2 + y1*w2 + z1*x2 + z = w1*z2 + x1*y2 - y1*x2 + z1*w2 + return torch.stack([w, x, y, z], dim=-1) + + def q_inverse(self, q): + return self.q_conjugate(self.q_normalize(q)) + + def q_log(self, q): + """ + Quaternion logarithm for a unit quaternion + Only returns the imaginary part + """ + q = self.q_normalize(q) + w = q[..., 0] + xyz = q[..., 1:] # shape [..., 3] + mag_v = xyz.norm(p=2, dim=-1) + eps = 1e-12 + angle = torch.acos(w.clamp(-1.0 + eps, 1.0 - eps)) + + # We do a safe-guard against zero for sin(angle) + small_mask = (mag_v < 1e-12) | (angle < 1e-12) + # Where small_mask is True => near identity => log(q) ~ 0 + log_val = torch.zeros_like(xyz) + + # Normal case + scale = angle / mag_v.clamp(min=1e-12) + normal_case = scale.unsqueeze(-1) * xyz + + log_val = torch.where( + small_mask.unsqueeze(-1), + torch.zeros_like(xyz), + normal_case + ) + return log_val + + def q_exp(self, v): + """ + Quaternion exponential + """ + norm_v = v.norm(p=2, dim=-1) + small_mask = norm_v < 1e-12 + + w = torch.cos(norm_v) + sin_v = torch.sin(norm_v) + scale = torch.where( + small_mask, + torch.zeros_like(norm_v), # if zero, sin(0)/0 => 0 + sin_v / norm_v.clamp(min=1e-12) + ) + xyz = scale.unsqueeze(-1) * v + + # For small angles, we approximate cos(norm_v) ~ 1, sin(norm_v)/norm_v ~ 1 + w = torch.where( + small_mask, + torch.ones_like(w), + w + ) + return torch.cat([w.unsqueeze(-1), xyz], dim=-1) + + def q_slerp(self, q1, q2, t): + """ + Spherical linear interpolation from q1 to q2 at t in [0,1]. + Both q1, q2 assumed normalized. + q1, q2, t can be 1D or broadcastable shapes, but typically 1D. + """ + q1 = self.q_normalize(q1) + q2 = self.q_normalize(q2) + dot = (q1 * q2).sum(dim=-1, keepdim=True) # the dot product + + eps = 1e-12 + dot = dot.clamp(-1.0 + eps, 1.0 - eps) + + flip_mask = dot < 0.0 + if flip_mask.any(): + q2 = torch.where(flip_mask, -q2, q2) + dot = torch.where(flip_mask, -dot, dot) + + # If they're very close, do a simple linear interpolation + close_mask = dot.squeeze(-1) > 0.9995 + # Using an epsilon to avoid potential issues close to 1.0 + + # Branch 1: Very close + # linear LERP + lerp_val = (1.0 - t) * q1 + t * q2 + lerp_val = self.q_normalize(lerp_val) + + # Branch 2: Standard SLERP + theta_0 = torch.acos(dot) + sin_theta_0 = torch.sin(theta_0) + theta = theta_0 * t + s1 = torch.sin(theta_0 - theta) / sin_theta_0.clamp(min=1e-12) + s2 = torch.sin(theta) / sin_theta_0.clamp(min=1e-12) + slerp_val = s1 * q1 + s2 * q2 + slerp_val = self.q_normalize(slerp_val) + + # Combine + return torch.where( + close_mask.unsqueeze(-1), + lerp_val, + slerp_val + ) + + def compute_uniform_tangent(self, q_im1, q_i, q_ip1): + """ + Computes a 'Catmull–Rom-like' tangent T_i for quaternion q_i, + given neighbors q_im1, q_i, q_ip1. + + T_i = q_i * exp( -0.25 * [ log(q_i^-1 q_ip1) + log(q_i^-1 q_im1) ] ) + """ + q_im1 = self.q_normalize(q_im1) + q_i = self.q_normalize(q_i) + q_ip1 = self.q_normalize(q_ip1) + + inv_qi = self.q_inverse(q_i) + r1 = self.q_multiply(inv_qi, q_ip1) + r2 = self.q_multiply(inv_qi, q_im1) + + lr1 = self.q_log(r1) + lr2 = self.q_log(r2) + + m = -0.25 * (lr1 + lr2) + exp_m = self.q_exp(m) + return self.q_multiply(q_i, exp_m) + + def compute_all_uniform_tangents(self, quats): + """ + Vectorized version that computes tangents T_i for all keyframe quaternions at once. + quats shape: [N,4], N >= 2 + Returns shape [N,4]. + """ + q_im1 = torch.cat([quats[[0]], quats[:-1]], dim=0) # q_im1[0] = q0 + q_ip1 = torch.cat([quats[1:], quats[[-1]]], dim=0) # q_ip1[N-1]= q_{N-1} + + return self.compute_uniform_tangent(q_im1, quats, q_ip1) + + def squad(self, q0, a, b, q1, t): + """ + Shoemake's "squad" interpolation for quaternion splines: + squad(q0, a, b, q1; t) = slerp( slerp(q0, q1; t), + slerp(a, b; t), + 2t(1-t) ) + where a, b are tangential control quaternions for q0, q1. + """ + s1 = self.q_slerp(q0, q1, t) + s2 = self.q_slerp(a, b, t) + alpha = 2.0*t*(1.0 - t) + return self.q_slerp(s1, s2, alpha) + + def uniform_cr_spline(self, quats, num_samples_per_segment=10): + """ + Given a list of keyframe quaternions quats (each a torch 1D tensor [4]), + compute a "Uniform Catmull–Rom–like" quaternion spline through them. + + Returns: + A list (Python list) of interpolated quaternions (torch tensors), + including all segment endpoints. + + Each interior qi gets a tangent T_i using neighbors q_{i-1}, q_i, q_{i+1}. + For boundary tangents, we replicate the end quaternions. + """ + n = quats.shape[0] + if n < 2: + return quats.unsqueeze(0) # not enough quats to interpolate + + # Precompute tangents + tangents = self.compute_all_uniform_tangents(quats) + + # Interpolate each segment [qi, q_{i+1}] + q0 = quats[:-1].unsqueeze(1) + q1 = quats[1:].unsqueeze(1) + a = tangents[:-1].unsqueeze(1) + b = tangents[1:].unsqueeze(1) + + t_vals = torch.linspace(0.0, 1.0, num_samples_per_segment, device=quats.device, dtype=quats.dtype) + t_vals = t_vals.view(1, -1, 1) + + out = self.squad(q0, a, b, q1, t_vals) + return out + + + def forward(self, pred, targ, cond=None, scene_id=None, norm_params=None): + loss, err_t, err_smooth, err_geo, err_r, err_dtw, err_hausdorff, err_frechet, err_chamfer = self._loss(pred, targ, cond, scene_id, norm_params) + + info = { + 'trans. error': err_t.item(), + 'smoothness error': err_smooth.item(), + # 'dtw': err_dtw.item(), + # 'hausdorff': err_hausdorff.item(), + # 'frechet': err_frechet.item(), + # 'chamfer': err_chamfer.item(), + 'quat. dist.': err_r.item(), + 'geodesic dist.': err_geo.item(), + } + + return loss, info + + def _loss(self, pred, targ, cond=None, scene_id=None, norm_params=None): + def poly_eval(coeffs, x): + """ + Evaluates a polynomial (with highest-degree term first) at points x. + coeffs: 2D tensor of shape [num_polynomials, degree + 1], highest-degree term first. + x: 1D tensor of points at which to evaluate the polynomial. + Returns: + 2D tensor of shape [num_polynomials, len(x)], containing p(x). + """ + x_powers = torch.stack([x**i for i in range(coeffs.shape[-1] - 1, -1, -1)], dim=-1) + x_powers = x_powers.to(torch.float64).to(coeffs.device) + y = torch.matmul(coeffs, x_powers.T) + return y + + # Make sure the dtype is float64 + pred = pred.to(torch.float64) + targ = targ.to(torch.float64) + + # Rescale the translations + if scene_id is not None and norm_params is not None: + scene_id = scene_id.item() + scene_scale = self.scales[str(scene_id)] + scene_scale = norm_params['scale'][0] * scene_scale + pred[..., :3] = pred[..., :3] * scene_scale + targ[..., :3] = targ[..., :3] * scene_scale + # print(pred[..., :3].max(), targ[..., :3].max()) + + # We only consider interpolated points for loss calculation + candidate_idxs = sorted(cond.keys()) + pred = pred[:, candidate_idxs[0] : candidate_idxs[-1] + 1, :] + targ = targ[:, candidate_idxs[0] : candidate_idxs[-1] + 1, :] + + pred_trans = pred[..., :3] + pred_quat = pred[..., 3:7] + targ_trans = targ[..., :3] + targ_quat = targ[..., 3:7] + + pred_coeffs = self.compute_spline_coeffs(pred_trans) + targ_coeffs = self.compute_spline_coeffs(targ_trans) + + n_points = 2000 + + # Distribute sample points among intervals + dists = torch.norm(targ_trans[:, 1:, :] - targ_trans[:, :-1, :], dim=-1).reshape(-1) + dists_c = torch.zeros(len(candidate_idxs) - 1, device=pred.device) + for i in range(len(candidate_idxs) - 1): + dists_c[i] = dists[candidate_idxs[i]:candidate_idxs[i+1]].sum() + + weights_c = dists_c / dists_c.sum() + scaled_c = weights_c * n_points + points_c = torch.floor(scaled_c).int() + + while points_c.sum() < n_points: + idx = torch.argmax(scaled_c - points_c) + points_c[idx] += 1 + + # Calculate the spline loss + sample_points = 50 + x = torch.linspace(0, 1, sample_points, device=pred.device) + pred_spline = poly_eval(pred_coeffs, x).permute(0, 1, 3, 2).reshape(-1, sample_points, 3) + targ_spline = poly_eval(targ_coeffs, x).permute(0, 1, 3, 2).reshape(-1, sample_points, 3) + + indexes = [] + start_idx = candidate_idxs[0] + for c, (idx_i0, idx_i1) in enumerate(zip(candidate_idxs[:-1], candidate_idxs[1:])): + p = points_c[c] + total_dist = dists_c[c] + dist_arr = dists[idx_i0 - start_idx : idx_i1 - start_idx] + + step_distances = (dist_arr / sample_points).repeat_interleave(sample_points) + cumul_distances = step_distances.cumsum(dim=0) + + dist_per_pick = total_dist / p + pick_targets = torch.arange(1, p + 1, device=dists.device) * dist_per_pick + + pick_idxs = torch.searchsorted(cumul_distances, pick_targets, right=True) + pick_idxs = torch.clamp(pick_idxs, max=len(cumul_distances) - 1) + + + indexes_1d = torch.zeros_like(step_distances) + indexes_1d[pick_idxs] = 1 + + indexes_2d = indexes_1d.view(len(dist_arr), sample_points) + + indexes.append(indexes_2d) + + indexes = torch.cat(indexes)[1: -1] # The first and last candidates don't have spline representations + + indexes_trans = torch.stack([indexes for _ in range(3)], dim=-1) + indexes_quat = torch.stack([indexes for _ in range(4)], dim=-1) + + indexes_trans = indexes_trans.to(torch.bool) + indexes_quat = indexes_quat.to(torch.bool) + + pred_trans_selected_values = pred_spline[indexes_trans] + targ_trans_selected_values = targ_spline[indexes_trans] + + pred_trans_selected_values = pred_trans_selected_values.reshape(-1, 3) + targ_trans_selected_values = targ_trans_selected_values.reshape(-1, 3) + + # Calculate the loss for quaternions + pred_quat = pred_quat / pred_quat.norm(dim=-1, keepdim=True).clamp(min=1e-8) + targ_quat = targ_quat / targ_quat.norm(dim=-1, keepdim=True).clamp(min=1e-8) + + targ_quat_spline = self.uniform_cr_spline(targ_quat.reshape(-1, 4), num_samples_per_segment=sample_points) + pred_quat_spline = self.uniform_cr_spline(pred_quat.reshape(-1, 4), num_samples_per_segment=sample_points) + + + targ_quat_spline = targ_quat_spline[1:-1] + pred_quat_spline = pred_quat_spline[1:-1] + + + pred_quat_selected_values = pred_quat_spline[indexes_quat] + targ_quat_selected_values = targ_quat_spline[indexes_quat] + + pred_quat_selected_values = pred_quat_selected_values.reshape(-1, 4) + targ_quat_selected_values = targ_quat_selected_values.reshape(-1, 4) + + # Calculate the geodesic loss + pred_rot = quaternion_to_matrix(pred_quat_selected_values).reshape(-1, 3, 3) + targ_rot = quaternion_to_matrix(targ_quat_selected_values).reshape(-1, 3, 3) + + eps = 1e-12 + r2r1 = pred_rot @ targ_rot.permute(0, 2, 1) + trace = torch.diagonal(r2r1, dim1=-2, dim2=-1).sum(-1) + trace = torch.clamp((trace - 1) / 2, -1.0 + eps, 1.0 - eps) + geodesic_loss = torch.acos(trace).mean() + + # Calculate the rotation error + dot_product = torch.sum(pred_quat_selected_values * targ_quat_selected_values, dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps) + quaternion_dist = 1 - (dot_product ** 2).mean() + + # Calculate the L2 loss + l2_loss = F.mse_loss(pred_trans_selected_values, targ_trans_selected_values, reduction='mean') + + # Calculate the smoothness loss for translation and quaternion + smoothness_multiplier = 10 ** 2 # Empirically determined multiplier for smoothness loss + weight_acceleration = 0.1 + weight_jerk = 0.05 + + pos_acc = pred_trans_selected_values[2:, :] - 2 * pred_trans_selected_values[1:-1, :] + pred_trans_selected_values[:-2, :] + pos_jerk = pred_trans_selected_values[3:, :] - 3 * pred_trans_selected_values[2:-1, :] + 3 * pred_trans_selected_values[1:-2, :] - pred_trans_selected_values[:-3, :] + + pos_acceleration_loss = torch.mean(pos_acc ** 2) + pos_jerk_loss = torch.mean(pos_jerk ** 2) + + q0 = pred_quat_selected_values[:-1, :] + q1 = pred_quat_selected_values[1:, :] + sign = torch.where((q0 * q1).sum(dim=-1) < 0, -1.0, 1.0) + q1 = sign.unsqueeze(-1) * q1 + + dq = self.q_multiply(q1, self.q_inverse(q0)) + theta = 2 * torch.acos(torch.clamp(dq[..., 0], -1.0 + 1e-8, 1.0 - 1e-8)) + + rot_acc = theta[2:] - 2*theta[1:-1] + theta[:-2] + rot_jerk = theta[3:] - 3*theta[2:-1] + 3*theta[1:-2] - theta[:-3] + + rot_acceleration_loss = torch.mean(rot_acc ** 2) + rot_jerk_loss = torch.mean(rot_jerk ** 2) + + alpha_rot = 0.1 # <-- tune this (e.g. 0.1 … 10) + + + acceleration_loss = pos_acceleration_loss + alpha_rot * rot_acceleration_loss + jerk_loss = pos_jerk_loss + alpha_rot * rot_jerk_loss + + smoothness_loss = ( + weight_acceleration * acceleration_loss + + weight_jerk * jerk_loss + ) * smoothness_multiplier + + + # Calculate the spline loss + l2_multiplier = 10.0 + spline_loss = l2_multiplier * (l2_loss + smoothness_loss) + geodesic_loss + quaternion_dist + + dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss = None, None, None, None + + # Uncomment these lines if you want to use the other losses + ''' + dtw = DynamicTimeWarpingLoss() + dtw_loss, _ = dtw.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3)) + + hausdorff = HausdorffDistanceLoss() + hausdorff_loss, _ = hausdorff.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3)) + + frec = FrechetDistanceLoss() + frechet_loss, _ = frec.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3)) + + chamfer = ChamferDistanceLoss() + chamfer_loss, _ = chamfer.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3)) + ''' + + return spline_loss, l2_multiplier * l2_loss, l2_multiplier * smoothness_loss, geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss + + +class DynamicTimeWarpingLoss(nn.Module): + def __init__(self): + super().__init__() + + def _dtw_distance(self, seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor: + """ + Computes the DTW distance between two 2D tensors (T x D), + where T is sequence length and D is feature dimension. + """ + # seq1, seq2 shapes: (time_steps, feature_dim) + n, m = seq1.size(0), seq2.size(0) + + # Cost matrix (pairwise distances between all elements) + cost = torch.zeros(n, m, device=seq1.device, dtype=seq1.dtype) + for i in range(n): + for j in range(m): + cost[i, j] = torch.norm(seq1[i] - seq2[j], p=2) + + # Accumulated cost matrix + dist = torch.full((n + 1, m + 1), float('inf'), + device=seq1.device, dtype=seq1.dtype) + dist[0, 0] = 0.0 + + # Populate the DP table + for i in range(1, n + 1): + for j in range(1, m + 1): + dist[i, j] = cost[i - 1, j - 1] + torch.min( + torch.min( + dist[i - 1, j], # Insertion + dist[i, j - 1], # Deletion + ), + dist[i - 1, j - 1]# Match + ) + + return dist[n, m] + + def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """ + Compute the average DTW loss over a batch of sequences. + + pred, targ shapes: (batch_size, T, D) + """ + # Ensure shapes match in batch dimension + assert pred.size(0) == targ.size(0), "Batch sizes must match." + + # Compute DTW distance per sample in the batch + distances = [] + for b in range(pred.size(0)): + seq1 = pred[b] + seq2 = targ[b] + dtw_val = self._dtw_distance(seq1, seq2) + distances.append(dtw_val) + + # Stack and take mean to get scalar loss + dtw_loss = torch.stack(distances).mean() + return dtw_loss + + def forward(self, pred: torch.Tensor, targ: torch.Tensor): + """ + Returns a tuple: (loss, info_dict), + where loss is a scalar tensor and info_dict is a dictionary + of extra information (e.g., loss components). + """ + loss = self._loss(pred, targ) + + info = { + 'dtw': loss.item() + } + + return loss, info + +class HausdorffDistanceLoss(nn.Module): + def __init__(self): + super().__init__() + + def _hausdorff_distance(self, set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor: + """ + Computes the Hausdorff distance between two 2D tensors (N x D), + where N is the number of points and D is the feature dimension. + + The Hausdorff distance H(A,B) between two sets A and B is defined as: + H(A, B) = max( h(A, B), h(B, A) ), + where + h(A, B) = max_{a in A} min_{b in B} d(a, b). + + Here, d(a, b) is the Euclidean distance between points a and b. + """ + # set1, set2 shapes: (num_points, feature_dim) + n, m = set1.size(0), set2.size(0) + + # Compute pairwise distances + cost = torch.zeros(n, m, device=set1.device, dtype=set1.dtype) + for i in range(n): + for j in range(m): + cost[i, j] = torch.norm(set1[i] - set2[j], p=2) + + # Forward direction: for each point in set1, find distance to closest point in set2 + forward_min = cost.min(dim=1)[0] # Shape (n,) + forward_hausdorff = forward_min.max() # max over n + + # Backward direction: for each point in set2, find distance to closest point in set1 + backward_min = cost.min(dim=0)[0] # Shape (m,) + backward_hausdorff = backward_min.max() # max over m + + # Hausdorff distance is the max of the two + hausdorff_dist = torch.max(forward_hausdorff, backward_hausdorff) + return hausdorff_dist + + def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """ + Compute the average Hausdorff distance over a batch of point sets. + + pred, targ shapes: (batch_size, N, D) + """ + # Ensure shapes match in batch dimension + assert pred.size(0) == targ.size(0), "Batch sizes must match." + + distances = [] + for b in range(pred.size(0)): + set1 = pred[b] + set2 = targ[b] + h_dist = self._hausdorff_distance(set1, set2) + distances.append(h_dist) + + # Stack and take mean to get scalar loss + hausdorff_loss = torch.stack(distances).mean() + return hausdorff_loss + + def forward(self, pred: torch.Tensor, targ: torch.Tensor): + """ + Returns a tuple: (loss, info_dict), + where loss is a scalar tensor and info_dict is a dictionary + of extra information (e.g., distance components). + """ + loss = self._loss(pred, targ) + + info = { + 'hausdorff': loss.item() + } + + return loss, info + +class FrechetDistanceLoss(nn.Module): + def __init__(self): + super().__init__() + + def _frechet_distance(self, seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor: + """ + Computes the (discrete) Fréchet distance between two 2D tensors (T x D), + where T is the sequence length and D is the feature dimension. + + The Fréchet distance between two curves in discrete form can be computed + by filling in a DP table “ca” where: + + ca[i, j] = max( d(seq1[i], seq2[j]), + min(ca[i-1, j], ca[i, j-1], ca[i-1, j-1]) ) + + with boundary conditions handled appropriately. + Here, d(seq1[i], seq2[j]) is the Euclidean distance. + """ + n, m = seq1.size(0), seq2.size(0) + + # Cost matrix (pairwise distances between all elements) + cost = torch.zeros(n, m, device=seq1.device, dtype=seq1.dtype) + for i in range(n): + for j in range(m): + cost[i, j] = torch.norm(seq1[i] - seq2[j], p=2) + + # DP matrix for the Fréchet distance + ca = torch.full((n, m), float('inf'), device=seq1.device, dtype=seq1.dtype) + ca[0, 0] = cost[0, 0] + + # Initialize first row + for i in range(1, n): + ca[i, 0] = torch.max(ca[i - 1, 0], cost[i, 0]) + + # Initialize first column + for j in range(1, m): + ca[0, j] = torch.max(ca[0, j - 1], cost[0, j]) + + # Populate the DP table + for i in range(1, n): + for j in range(1, m): + ca[i, j] = torch.max( + cost[i, j], + torch.min( + torch.min( + ca[i - 1, j], + ca[i, j - 1], + ), + ca[i - 1, j - 1] + ) + ) + + return ca[n - 1, m - 1] + + def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """ + Compute the average Fréchet distance over a batch of sequences. + + pred, targ shapes: (batch_size, T, D) + """ + # Ensure shapes match in batch dimension + assert pred.size(0) == targ.size(0), "Batch sizes must match." + + distances = [] + for b in range(pred.size(0)): + seq1 = pred[b] + seq2 = targ[b] + fd_val = self._frechet_distance(seq1, seq2) + distances.append(fd_val) + + # Stack and take mean to get scalar loss + frechet_loss = torch.stack(distances).mean() + return frechet_loss + + def forward(self, pred: torch.Tensor, targ: torch.Tensor): + """ + Returns a tuple: (loss, info_dict), + where loss is a scalar tensor and info_dict is a dictionary + of extra information (e.g., distance components). + """ + loss = self._loss(pred, targ) + info = { + 'frechet': loss.item() + } + return loss, info + +class ChamferDistanceLoss(nn.Module): + def __init__(self): + super().__init__() + + def _chamfer_distance(self, set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor: + """ + Computes the symmetrical Chamfer distance between + two 2D tensors (N x D), where N is the number of points + and D is the feature dimension. + + The Chamfer distance between two point sets A and B is often defined as: + + d_chamfer(A, B) = 1/|A| ∑_{a ∈ A} min_{b ∈ B} ‖a - b‖₂ + + 1/|B| ∑_{b ∈ B} min_{a ∈ A} ‖b - a‖₂, + + where ‖·‖₂ is the Euclidean distance. + """ + # set1, set2 shapes: (num_points, feature_dim) + n, m = set1.size(0), set2.size(0) + + cost = torch.zeros(n, m, device=set1.device, dtype=set1.dtype) + for i in range(n): + for j in range(m): + cost[i, j] = torch.norm(set1[i] - set2[j], p=2) + + # For each point in set1, find distance to the closest point in set2 + forward_min = cost.min(dim=1)[0] # shape: (n,) + forward_mean = forward_min.mean() + + # For each point in set2, find distance to the closest point in set1 + backward_min = cost.min(dim=0)[0] # shape: (m,) + backward_mean = backward_min.mean() + + chamfer_dist = forward_mean + backward_mean + return chamfer_dist + + def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """ + Compute the average Chamfer distance over a batch of point sets. + + pred, targ shapes: (batch_size, N, D) + """ + # Ensure shapes match in batch dimension + assert pred.size(0) == targ.size(0), "Batch sizes must match." + + distances = [] + for b in range(pred.size(0)): + set1 = pred[b] + set2 = targ[b] + distance_val = self._chamfer_distance(set1, set2) + distances.append(distance_val) + + # Combine into a single scalar + chamfer_loss = torch.stack(distances).mean() + return chamfer_loss + + def forward(self, pred: torch.Tensor, targ: torch.Tensor): + """ + Returns a tuple: (loss, info_dict), + where 'loss' is a scalar tensor and 'info_dict' is a dictionary + of extra information (e.g., distance components). + """ + loss = self._loss(pred, targ) + info = { + 'chamfer': loss.item() + } + return loss, info + + +def slerp(q1, q2, t): + """Spherical linear interpolation between two quaternions.""" + q1 = q1 / np.linalg.norm(q1) + q2 = q2 / np.linalg.norm(q2) + dot = np.dot(q1, q2) + + if dot < 0.0: + q2 = -q2 + dot = -dot + # If dot is very close to 1, use linear interpolation + + if dot > 0.9995: + result = q1 + t * (q2 - q1) + result = result / np.linalg.norm(result) + return result + + theta_0 = np.arccos(dot) + theta = theta_0 * t + + q3 = q2 - q1 * dot + q3 = q3 / np.linalg.norm(q3) + return q1 * np.cos(theta) + q3 * np.sin(theta) + +def catmull_rom_spline_with_rotation(control_points, timepoints, horizon): + """Compute Catmull-Rom spline for both position and quaternion rotation.""" + spline_points = [] + # Extrapolate the initial points + if timepoints[0] != 0: + for t in range(timepoints[0]): + x = control_points[0][0] + y = control_points[0][1] + z = control_points[0][2] + q = control_points[0][3:7] + spline_points.append(np.concatenate([np.array([x, y, z]), q])) + + #Linear interpolate between 0th and 1th control points + for t in np.linspace(0, 1, timepoints[1] - timepoints[0] + 1): + x = control_points[0][0] + t * (control_points[1][0] - control_points[0][0]) + y = control_points[0][1] + t * (control_points[1][1] - control_points[0][1]) + z = control_points[0][2] + t * (control_points[1][2] - control_points[0][2]) + q = slerp(control_points[0][3:7], control_points[1][3:7], t) + spline_points.append(np.concatenate([np.array([x, y, z]), q])) + + + # Iterate over the control points + for i in range(1, len(control_points) - 2): + P0 = control_points[i-1][:3] + P1 = control_points[i][:3] + P2 = control_points[i+1][:3] + P3 = control_points[i+2][:3] + Q0 = control_points[i-1][3:7] + Q1 = control_points[i][3:7] + Q2 = control_points[i+1][3:7] + Q3 = control_points[i+2][3:7] + + # Interpolate position (using Catmull-Rom spline) + for idx, t in enumerate(np.linspace(0, 1, timepoints[i+1] - timepoints[i] + 1)): + if idx == 0: + continue + + x = 0.5 * ((2 * P1[0]) + (-P0[0] + P2[0]) * t + + (2 * P0[0] - 5 * P1[0] + 4 * P2[0] - P3[0]) * t**2 + + (-P0[0] + 3 * P1[0] - 3 * P2[0] + P3[0]) * t**3) + y = 0.5 * ((2 * P1[1]) + (-P0[1] + P2[1]) * t + + (2 * P0[1] - 5 * P1[1] + 4 * P2[1] - P3[1]) * t**2 + + (-P0[1] + 3 * P1[1] - 3 * P2[1] + P3[1]) * t**3) + z = 0.5 * ((2 * P1[2]) + (-P0[2] + P2[2]) * t + + (2 * P0[2] - 5 * P1[2] + 4 * P2[2] - P3[2]) * t**2 + + (-P0[2] + 3 * P1[2] - 3 * P2[2] + P3[2]) * t**3) + q = slerp(Q1, Q2, t) + spline_points.append(np.concatenate([np.array([x, y, z]), q])) + + #Linear interpolate between 2nd last and last control points + for idx, t in enumerate(np.linspace(0, 1, timepoints[-1] - timepoints[-2] + 1)): + if idx == 0: + continue + x = control_points[-2][0] + t * (control_points[-1][0] - control_points[-2][0]) + y = control_points[-2][1] + t * (control_points[-1][1] - control_points[-2][1]) + z = control_points[-2][2] + t * (control_points[-1][2] - control_points[-2][2]) + q = slerp(control_points[-2][3:7], control_points[-1][3:7], t) + spline_points.append(np.concatenate([np.array([x, y, z]), q])) + + # Extrapolate the rest of the points + if timepoints[-1] != horizon: + for t in range(timepoints[-1] + 1, horizon): + x = control_points[-1][0] + y = control_points[-1][1] + z = control_points[-1][2] + q = control_points[-1][3:7] + spline_points.append(np.concatenate([np.array([x, y, z]), q])) + + stacked_spline_points = np.stack(spline_points, axis=0) + + if control_points.shape[1] != 7: + stacked_spline_points = np.concatenate([stacked_spline_points, np.zeros((stacked_spline_points.shape[0], 1))], axis=1) + + + return stacked_spline_points + +def catmull_rom_loss(trajectories, conditions, loss_fc): + ''' + loss for catmull-rom interpolation + ''' + batch_size, horizon, transition = trajectories.shape + + # Extract known indices and values + known_indices = np.array(list(conditions.keys()), dtype=int) + + # candidate_no x batch_size x dim + known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0) + known_values = np.moveaxis(known_values, 0, 1) + + # Sort the timepoints + sorted_indices = np.argsort(known_indices) + known_indices = known_indices[sorted_indices] + known_values = known_values[:, sorted_indices] + spline_points = np.array([catmull_rom_spline_with_rotation(known_values[b], known_indices, horizon) for b in range(batch_size)]) + + # Convert to tensor and move to the same device as trajectories + spline_points = torch.tensor(spline_points, dtype=torch.float64, device=trajectories.device) + assert spline_points.shape == trajectories.shape, f"Shape mismatch: {spline_points.shape} != {trajectories.shape}" + return loss_fc(spline_points, trajectories) + Losses = { 'l1': WeightedL1, 'l2': WeightedL2, 'value_l1': ValueL1, 'value_l2': ValueL2, + 'geodesic_l2': GeodesicL2Loss, + 'rotation_translation': RotationTranslationLoss, + 'spline': SplineLoss, } diff --git a/diffuser/models/temporal.py b/diffuser/models/temporal.py index e0b9e5c..0f7854a 100644 --- a/diffuser/models/temporal.py +++ b/diffuser/models/temporal.py @@ -17,18 +17,18 @@ class ResidualTemporalBlock(nn.Module): super().__init__() self.blocks = nn.ModuleList([ - Conv1dBlock(inp_channels, out_channels, kernel_size), - Conv1dBlock(out_channels, out_channels, kernel_size), + Conv1dBlock(inp_channels, out_channels, kernel_size).to(dtype=torch.float64), + Conv1dBlock(out_channels, out_channels, kernel_size).to(dtype=torch.float64), ]) self.time_mlp = nn.Sequential( nn.Mish(), - nn.Linear(embed_dim, out_channels), + nn.Linear(embed_dim, out_channels).to(dtype=torch.float64), Rearrange('batch t -> batch t 1'), - ) + ).to(dtype=torch.float64) - self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \ - if inp_channels != out_channels else nn.Identity() + self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1).to(dtype=torch.float64) \ + if inp_channels != out_channels else nn.Identity().to(dtype=torch.float64) def forward(self, x, t): ''' @@ -37,7 +37,8 @@ class ResidualTemporalBlock(nn.Module): returns: out : [ batch_size x out_channels x horizon ] ''' - out = self.blocks[0](x) + self.time_mlp(t) + + out = self.blocks[0](x) + self.time_mlp(t.double()) out = self.blocks[1](out) return out + self.residual_conv(x) @@ -49,11 +50,11 @@ class TemporalUnet(nn.Module): transition_dim, cond_dim, dim=32, - dim_mults=(1, 2, 4, 8), + dim_mults=(1, 2, 4), ): super().__init__() - dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] + dims = [(transition_dim + cond_dim), *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) print(f'[ models/temporal ] Channel dimensions: {in_out}') @@ -100,7 +101,7 @@ class TemporalUnet(nn.Module): self.final_conv = nn.Sequential( Conv1dBlock(dim, dim, kernel_size=5), - nn.Conv1d(dim, transition_dim, 1), + nn.Conv1d(dim, transition_dim, 1).to(dtype=torch.float64), ) def forward(self, x, cond, time): @@ -129,7 +130,6 @@ class TemporalUnet(nn.Module): x = upsample(x) x = self.final_conv(x) - x = einops.rearrange(x, 'b t h -> b h t') return x diff --git a/diffuser/utils/arrays.py b/diffuser/utils/arrays.py index c3a9d24..96a7093 100644 --- a/diffuser/utils/arrays.py +++ b/diffuser/utils/arrays.py @@ -54,7 +54,7 @@ def batchify(batch): 1) converting np arrays to torch tensors and 2) and ensuring that everything has a batch dimension ''' - fn = lambda x: to_torch(x[None]) + fn = lambda x: to_torch(x[None], dtype=torch.float64) batched_vals = [] for field in batch._fields: diff --git a/diffuser/utils/serialization.py b/diffuser/utils/serialization.py index 6cc9db9..039eb64 100644 --- a/diffuser/utils/serialization.py +++ b/diffuser/utils/serialization.py @@ -19,7 +19,7 @@ def mkdir(savepath): return False def get_latest_epoch(loadpath): - states = glob.glob1(os.path.join(*loadpath), 'state_*') + states = glob.glob1(os.path.join(loadpath), 'state_*') latest_epoch = -1 for state in states: epoch = int(state.replace('state_', '').replace('.pt', '')) diff --git a/diffuser/utils/training.py b/diffuser/utils/training.py index be3556e..c21e0f0 100644 --- a/diffuser/utils/training.py +++ b/diffuser/utils/training.py @@ -4,16 +4,24 @@ import numpy as np import torch import einops import pdb +from tqdm import tqdm +import wandb +from pytorch3d.transforms import axis_angle_to_quaternion from .arrays import batch_to_device, to_np, to_device, apply_dict from .timer import Timer from .cloud import sync_logs +from ..models.helpers import catmull_rom_spline_with_rotation def cycle(dl): while True: for data in dl: yield data +def assert_no_nan_weights(model): + for name, param in model.named_parameters(): + assert not torch.isnan(param).any(), f"NaN detected in parameter: {name}" + class EMA(): ''' empirical moving average @@ -71,13 +79,35 @@ class Trainer(object): self.gradient_accumulate_every = gradient_accumulate_every self.dataset = dataset - self.dataloader = cycle(torch.utils.data.DataLoader( - self.dataset, batch_size=train_batch_size, num_workers=1, shuffle=True, pin_memory=True + dataset_size = len(self.dataset) + + # Read the indices from the .txt file + with open(os.path.join(results_folder, 'train_indices.txt'), 'r') as f: + self.train_indices = f.read() + self.train_indices = [int(i) for i in self.train_indices.split('\n') if i] + + with open(os.path.join(results_folder, 'val_indices.txt'), 'r') as f: + self.val_indices = f.read() + self.val_indices = [int(i) for i in self.val_indices.split('\n') if i] + + + self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices) + self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices) + self.train_dataloader = cycle(torch.utils.data.DataLoader( + self.train_dataset, batch_size=train_batch_size, num_workers=1, pin_memory=True, shuffle=False + )) + + self.val_dataloader = cycle(torch.utils.data.DataLoader( + self.val_dataset, batch_size=train_batch_size, num_workers=1, pin_memory=True, shuffle=False )) + self.dataloader_vis = cycle(torch.utils.data.DataLoader( self.dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True )) self.renderer = renderer + + + self.optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=train_lr) self.logdir = results_folder @@ -88,6 +118,8 @@ class Trainer(object): self.reset_parameters() self.step = 0 + self.log_to_wandb = False + def reset_parameters(self): self.ema_model.load_state_dict(self.model.state_dict()) @@ -102,36 +134,129 @@ class Trainer(object): #-----------------------------------------------------------------------------# def train(self, n_train_steps): - + # Save the indices as .txt files + with open(os.path.join(self.logdir, 'train_indices.txt'), 'w') as f: + for idx in self.train_indices: + f.write(f"{idx}\n") + with open(os.path.join(self.logdir, 'val_indices.txt'), 'w') as f: + for idx in self.val_indices: + f.write(f"{idx}\n") + timer = Timer() - for step in range(n_train_steps): + torch.autograd.set_detect_anomaly(True) + + # Setup wandb + if self.log_to_wandb: + wandb.init( + project='trajectory-generation', + config={'lr': self.optimizer.param_groups[0]['lr'], 'batch_size': self.batch_size, 'gradient_accumulate_every': self.gradient_accumulate_every}, + ) + + for step in tqdm(range(n_train_steps)): + + mean_train_loss = 0.0 for i in range(self.gradient_accumulate_every): - batch = next(self.dataloader) + batch = next(self.train_dataloader) batch = batch_to_device(batch) - - loss, infos = self.model.loss(*batch) + + loss, infos = self.model.loss(x=batch.trajectories, cond=batch.conditions) loss = loss / self.gradient_accumulate_every + mean_train_loss += loss.item() loss.backward() + if self.log_to_wandb: + wandb.log({ + 'step': self.step, + 'train/loss': mean_train_loss + }) + + # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer.step() self.optimizer.zero_grad() + assert_no_nan_weights(self.model) + if self.step % self.update_ema_every == 0: self.step_ema() if self.step % self.save_freq == 0: - label = self.step // self.label_freq * self.label_freq + label = self.step + print(f'Saving model at step {self.step}...') self.save(label) if self.step % self.log_freq == 0: - infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()]) - print(f'{self.step}: {loss:8.4f} | {infos_str} | t: {timer():8.4f}') + val_losses = [] + lin_int_losses = [] + + val_infos_list = [] + lin_int_infos_list = [] + + catmull_losses = [] + catmull_infos_list = [] + + for _ in range(len(self.val_indices)): + val_batch = next(self.val_dataloader) + val_batch = batch_to_device(val_batch) + + traj = self.model.forward(val_batch.conditions, horizon=val_batch.trajectories.shape[1]) + val_loss, val_infos = self.model.loss_fn(traj, val_batch.trajectories, cond=val_batch.conditions) + + val_losses.append(val_loss.item()) + val_infos_list.append({key: val for key, val in val_infos.items()}) + + + (lin_int_loss, lin_int_infos), lin_int_traj = self.linear_interpolation_loss( + val_batch.trajectories, val_batch.conditions, self.model.loss_fn + ) + lin_int_losses.append(lin_int_loss.item()) + lin_int_infos_list.append({key: val for key, val in lin_int_infos.items()}) + + (catmull_loss, catmull_infos), catmull_traj = self.catmull_rom_loss( + val_batch.trajectories, val_batch.conditions, self.model.loss_fn + ) + + catmull_losses.append(catmull_loss.item()) + catmull_infos_list.append(catmull_infos) + + avg_val_loss = np.mean(val_losses) + avg_lin_int_loss = np.mean(lin_int_losses) + + val_infos = {key: np.mean([info[key] for info in val_infos_list]) for key in val_infos_list[0].keys()} + lin_int_infos = {key: np.mean([info[key] for info in lin_int_infos_list]) for key in lin_int_infos_list[0].keys()} - if self.step == 0 and self.sample_freq: - self.render_reference(self.n_reference) + avg_catmull_loss = np.mean(catmull_losses) + catmull_infos = {key: np.mean([info[key] for info in catmull_infos_list]) for key in catmull_infos_list[0].keys()} - if self.sample_freq and self.step % self.sample_freq == 0: - self.render_samples(n_samples=self.n_samples) + val_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in val_infos.items()]) + lin_int_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in lin_int_infos.items()]) + catmull_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in catmull_infos.items()]) + + + infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()]) + print("Learning Rate: ", self.optimizer.param_groups[0]['lr']) + print(f'Step {self.step}: {loss * self.gradient_accumulate_every:8.4f} | {infos_str} | t: {timer():8.4f}') + print(f'Validation - {self.step}: {avg_val_loss:8.4f} | {val_infos_str} | t: {timer():8.4f}') + print(f'Linear Interpolation Loss - {self.step}: {avg_lin_int_loss:8.4f} | {lin_int_infos_str} | t: {timer():8.4f}') + print(f'Catmull Rom Loss - {self.step}: {avg_catmull_loss:8.4f} | {catmull_infos_str} | t: {timer():8.4f}') + print() + + if self.log_to_wandb: + wandb.log({ + 'step': self.step, + 'val/loss': avg_val_loss, + 'val/linear_interp/loss': avg_lin_int_loss, + 'val/linear_interp/quaternion dist.': lin_int_infos['quat. dist.'], + 'val/linear_interp/euclidean dist.': lin_int_infos['trans. error'], + 'val/linear_interp/geodesic loss': lin_int_infos['geodesic dist.'], + 'val/catmull_rom/loss': avg_catmull_loss, + 'val/catmull_rom/quaternion dist.': catmull_infos['quat. dist.'], + 'val/catmull_rom/euclidean dist.': catmull_infos['trans. error'], + 'val/catmull_rom/geodesic loss': catmull_infos['geodesic dist.'], + 'val/quaternion dist.': val_infos['quat. dist.'], + 'val/euclidean dist.': val_infos['trans. error'], + 'val/geodesic loss': val_infos['geodesic dist.'], + }) self.step += 1 @@ -186,15 +311,6 @@ class Trainer(object): normed_observations = trajectories[:, :, self.dataset.action_dim:] observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations') - # from diffusion.datasets.preprocessing import blocks_cumsum_quat - # # observations = conditions + blocks_cumsum_quat(deltas) - # observations = conditions + deltas.cumsum(axis=1) - - #### @TODO: remove block-stacking specific stuff - # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka - # observations = blocks_add_kuka(observations) - #### - savepath = os.path.join(self.logdir, f'_sample-reference.png') self.renderer.composite(savepath, observations) @@ -225,9 +341,6 @@ class Trainer(object): # [ 1 x 1 x observation_dim ] normed_conditions = to_np(batch.conditions[0])[:,None] - # from diffusion.datasets.preprocessing import blocks_cumsum_quat - # observations = conditions + blocks_cumsum_quat(deltas) - # observations = conditions + deltas.cumsum(axis=1) ## [ n_samples x (horizon + 1) x observation_dim ] normed_observations = np.concatenate([ @@ -238,10 +351,70 @@ class Trainer(object): ## [ n_samples x (horizon + 1) x observation_dim ] observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations') - #### @TODO: remove block-stacking specific stuff - # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka - # observations = blocks_add_kuka(observations) - #### - savepath = os.path.join(self.logdir, f'sample-{self.step}-{i}.png') self.renderer.composite(savepath, observations) + + def linear_interpolation_loss(self, trajectories, conditions, loss_fc, scene_id=None, norm_params=None): + batch_size, horizon, transition = trajectories.shape + + # Extract known indices and values + known_indices = np.array(list(conditions.keys()), dtype=int) + # candidate_no x batch_size x dim + known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0) + known_values = np.moveaxis(known_values, 0, 1) + + # Create time steps for interpolation + time_steps = np.linspace(0, horizon, num=horizon) + + # Perform interpolation across all dimensions at once + linear_int_arr = np.array([[ + np.interp(time_steps, known_indices, known_values[b, :, dim]) + for dim in range(transition)] + for b in range(batch_size)] + ).T # Transpose to match shape (horizon, transition) + + # Convert to tensor and move to the same device as trajectories + linear_int_arr = np.transpose(linear_int_arr, axes=[2, 0, 1]) + linear_int_tensor = torch.tensor(linear_int_arr, dtype=torch.float64, device=trajectories.device) + + return loss_fc(linear_int_tensor, trajectories, cond=conditions, scene_id=scene_id, norm_params=norm_params), linear_int_tensor + + + def catmull_rom_loss(self, trajectories, conditions, loss_fc, scene_id=None, norm_params=None): + ''' + loss for catmull-rom interpolation + ''' + + batch_size, horizon, transition = trajectories.shape + + # Extract known indices and values + known_indices = np.array(list(conditions.keys()), dtype=int) + # candidate_no x batch_size x dim + known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0) + known_values = np.moveaxis(known_values, 0, 1) + + # Sort the timepoints + sorted_indices = np.argsort(known_indices) + known_indices = known_indices[sorted_indices] + known_values = known_values[:, sorted_indices] + + spline_points = np.array([catmull_rom_spline_with_rotation(known_values[b], known_indices, horizon) for b in range(batch_size)]) + + # Convert to tensor and move to the same device as trajectories + spline_points = torch.tensor(spline_points, dtype=torch.float64, device=trajectories.device) + + assert spline_points.shape == trajectories.shape, f"Shape mismatch: {spline_points.shape} != {trajectories.shape}" + + return loss_fc(spline_points, trajectories, cond=conditions, scene_id=scene_id, norm_params=norm_params), spline_points + + + + + + + + + + + + diff --git a/scripts/train.py b/scripts/train.py index 2c5f299..6728d6f 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -108,6 +108,7 @@ utils.report_parameters(model) print('Testing forward...', end=' ', flush=True) batch = utils.batchify(dataset[0]) + loss, _ = diffusion.loss(*batch) loss.backward() print('✓')