import math


def adjust_learning_rate(optimizer,
                         base_lr,
                         p,
                         itr,
                         max_itr,
                         restart=1,
                         warm_up_steps=1000,
                         is_cosine_decay=False,
                         min_lr=1e-5,
                         encoder_lr_ratio=1.0,
                         freeze_params=()):
    """Set per-group learning rates with linear warm-up followed by
    cosine or polynomial decay, and return the current base rate."""
    if restart > 1:
        # With restarts, the schedule repeats over equal slices of max_itr.
        each_max_itr = int(math.ceil(float(max_itr) / restart))
        itr = itr % each_max_itr
        warm_up_steps /= restart
        max_itr = each_max_itr

    if itr < warm_up_steps:
        # Linear warm-up from min_lr to base_lr.
        now_lr = min_lr + (base_lr - min_lr) * itr / warm_up_steps
    else:
        itr = itr - warm_up_steps
        max_itr = max_itr - warm_up_steps
        if is_cosine_decay:
            # Cosine annealing from base_lr down to min_lr.
            now_lr = min_lr + (base_lr - min_lr) * (
                math.cos(math.pi * itr / (max_itr + 1)) + 1.) * 0.5
        else:
            # Polynomial ("poly") decay with exponent p.
            now_lr = min_lr + (base_lr - min_lr) * (1 - itr /
                                                    (max_itr + 1))**p

    for param_group in optimizer.param_groups:
        if encoder_lr_ratio != 1.0 and "encoder." in param_group["name"]:
            # Scale the encoder's learning rate relative to the rest.
            param_group['lr'] = (now_lr - min_lr) * encoder_lr_ratio + min_lr
        else:
            param_group['lr'] = now_lr

        for freeze_param in freeze_params:
            if freeze_param in param_group["name"]:
                # Frozen groups get zero learning rate and no weight decay.
                param_group['lr'] = 0
                param_group['weight_decay'] = 0
                break

    return now_lr
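

# Usage sketch (an illustrative assumption, not part of the original file):
# drive the scheduler above with a toy model and arbitrary hyper-parameters.
# The optimizer's param groups must carry a "name" key, as produced by
# get_trainable_params below.
def _demo_schedule():
    import torch
    net = torch.nn.Linear(8, 2)
    groups = [{'params': [v], 'lr': 0.01, 'name': k}
              for k, v in net.named_parameters()]
    optimizer = torch.optim.SGD(groups, lr=0.01)
    # Warm-up for the first 1000 steps, then cosine decay towards min_lr.
    for step in (0, 500, 1000, 5000, 9999):
        lr = adjust_learning_rate(optimizer, base_lr=0.01, p=0.9, itr=step,
                                  max_itr=10000, warm_up_steps=1000,
                                  is_cosine_decay=True)
        print('step {}: lr = {:.6f}'.format(step, lr))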


def get_trainable_params(model,
                         base_lr,
                         weight_decay,
                         use_frozen_bn=False,
                         exclusive_wd_dict=None,
                         no_wd_keys=()):
    """Build named per-parameter optimizer groups, deduplicating shared
    tensors and assigning per-group weight decay."""
    if exclusive_wd_dict is None:
        exclusive_wd_dict = {}
    params = []
    memo = set()
    total_param = 0
    for key, value in model.named_parameters():
        if value in memo:  # skip tensors shared between modules
            continue
        memo.add(value)
        total_param += value.numel()
        if not value.requires_grad:
            continue

        wd = weight_decay
        for exclusive_key in exclusive_wd_dict.keys():
            if exclusive_key in key:
                wd = exclusive_wd_dict[exclusive_key]
                break
        if len(value.shape) == 1:  # 1-D params: biases and norm weights
            if 'bias' in key:  # biases get no weight decay
                wd = 0.
            elif not use_frozen_bn:  # without frozen BN, norm weights get none either
                wd = 0.
            elif 'encoder.' not in key:  # with frozen BN, only the encoder's norm weights keep weight decay
                wd = 0.
        else:
            for no_wd_key in no_wd_keys:
                if no_wd_key in key:
                    wd = 0.
                    break
        params += [{
            "params": [value],
            "lr": base_lr,
            "weight_decay": wd,
            "name": key
        }]

    print('Total Param: {:.2f}M'.format(total_param / 1e6))
    return params
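

# A minimal sketch, assuming a toy two-part model whose submodule names
# ("encoder"/"head") stand in for a real network: build named groups with
# get_trainable_params, hand them to AdamW, and let adjust_learning_rate
# scale the encoder groups via encoder_lr_ratio. The no_wd_keys entry is a
# hypothetical example key.
def _demo_param_groups():
    import torch
    from collections import OrderedDict
    net = torch.nn.Sequential(OrderedDict([
        ('encoder', torch.nn.Linear(16, 8)),
        ('head', torch.nn.Linear(8, 2)),
    ]))
    params = get_trainable_params(net, base_lr=2e-4, weight_decay=0.01,
                                  no_wd_keys=['relative_emb'])
    optimizer = torch.optim.AdamW(params, lr=2e-4, weight_decay=0.01)
    # Each group carries its parameter name, so the schedule can give the
    # encoder a 10x smaller learning rate than the rest of the model.
    adjust_learning_rate(optimizer, base_lr=2e-4, p=0.9, itr=0,
                         max_itr=100000, encoder_lr_ratio=0.1)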


def freeze_params(module):
    """Stop gradient updates for every parameter of the given module."""
    for p in module.parameters():
        p.requires_grad = False
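

# Sketch of the interaction with get_trainable_params (toy model assumed):
# freezing a submodule removes its parameters from the returned groups,
# since parameters with requires_grad == False are skipped.
def _demo_freeze():
    import torch
    from collections import OrderedDict
    net = torch.nn.Sequential(OrderedDict([
        ('encoder', torch.nn.Linear(16, 8)),
        ('head', torch.nn.Linear(8, 2)),
    ]))
    freeze_params(net.encoder)
    params = get_trainable_params(net, base_lr=1e-3, weight_decay=1e-4)
    print('{} trainable param groups'.format(len(params)))  # head.weight, head.bias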


def calculate_params(state_dict):
    """Print the total number of parameters in a (checkpoint) state dict."""
    memo = set()
    total_param = 0
    for key, value in state_dict.items():
        if value in memo:  # skip tensors shared between keys
            continue
        memo.add(value)
        total_param += value.numel()
    print('Total Param: {:.2f}M'.format(total_param / 1e6))
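

if __name__ == '__main__':
    # Quick self-check under an assumed toy model: counting parameters from
    # the live model and from its state dict should print the same total.
    import torch
    net = torch.nn.Linear(128, 64)
    get_trainable_params(net, base_lr=1e-3, weight_decay=1e-4)
    calculate_params(net.state_dict())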