Spaces:
Runtime error
Runtime error
| import os | |
| import time | |
| import json | |
| import datetime as datetime | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torch.distributed as dist | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms | |
| from dataloaders.train_datasets import DAVIS2017_Train, YOUTUBEVOS_Train, StaticTrain, TEST | |
| import dataloaders.video_transforms as tr | |
| from utils.meters import AverageMeter | |
| from utils.image import label2colormap, masked_image, save_image | |
| from utils.checkpoint import load_network_and_optimizer, load_network, save_network | |
| from utils.learning import adjust_learning_rate, get_trainable_params | |
| from utils.metric import pytorch_iou | |
| from utils.ema import ExponentialMovingAverage, get_param_buffer_for_ema | |
| from networks.models import build_vos_model | |
| from networks.engines import build_engine | |
class Trainer(object):
    """Training driver for a video object segmentation (VOS) model.

    One instance is created per process/GPU. The constructor builds the
    model and engine, optionally wraps the engine in DistributedDataParallel,
    sets up an EMA copy of the weights on rank 0, builds the optimizer and
    data loader, and loads resume/pretrained checkpoints.
    `sequential_training` then runs the optimisation loop.
    """

    def __init__(self, rank, cfg, enable_amp=True):
        """Build every component needed for training.

        Args:
            rank: Index of this process; the GPU used is
                `rank + cfg.DIST_START_GPU`.
            cfg: Configuration object; its attributes (`TRAIN_*`, `MODEL_*`,
                `DATA_*`, `DIR_*`, `DIST_*`, ...) are read throughout.
            enable_amp: When true, a `torch.cuda.amp.GradScaler` is created
                and the forward pass runs under autocast.
        """
        self.gpu = rank + cfg.DIST_START_GPU
        self.gpu_num = cfg.TRAIN_GPUS
        self.rank = rank
        self.cfg = cfg

        # EMA is best-effort (see the try/except below). Defining the
        # attributes up front lets the training loop test for `None` instead
        # of failing with AttributeError when EMA creation did not succeed.
        self.ema = None
        self.ema_params = None

        self.print_log("Exp {}:".format(cfg.EXP_NAME))
        self.print_log(json.dumps(cfg.__dict__, indent=4, sort_keys=True))

        print("Use GPU {} for training VOS.".format(self.gpu))
        torch.cuda.set_device(self.gpu)
        # cudnn autotuning is enabled only for square crops and non-swin
        # encoders.
        torch.backends.cudnn.benchmark = bool(
            cfg.DATA_RANDOMCROP[0] == cfg.DATA_RANDOMCROP[1]
            and 'swin' not in cfg.MODEL_ENCODER)

        self.print_log('Build VOS model.')

        self.model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(self.gpu)
        self.model_encoder = self.model.encoder
        self.engine = build_engine(
            cfg.MODEL_ENGINE,
            'train',
            aot_model=self.model,
            gpu_id=self.gpu,
            long_term_mem_gap=cfg.TRAIN_LONG_TERM_MEM_GAP)

        if cfg.MODEL_FREEZE_BACKBONE:
            for param in self.model_encoder.parameters():
                param.requires_grad = False

        if cfg.DIST_ENABLE:
            dist.init_process_group(backend=cfg.DIST_BACKEND,
                                    init_method=cfg.DIST_URL,
                                    world_size=cfg.TRAIN_GPUS,
                                    rank=rank,
                                    timeout=datetime.timedelta(seconds=300))

            self.model.encoder = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model.encoder).cuda(self.gpu)

            self.dist_engine = torch.nn.parallel.DistributedDataParallel(
                self.engine,
                device_ids=[self.gpu],
                output_device=self.gpu,
                find_unused_parameters=True,
                broadcast_buffers=False)
        else:
            self.dist_engine = self.engine

        self.use_frozen_bn = False
        if 'swin' in cfg.MODEL_ENCODER:
            self.print_log('Use LN in Encoder!')
        elif not cfg.MODEL_FREEZE_BN:
            if cfg.DIST_ENABLE:
                self.print_log('Use Sync BN in Encoder!')
            else:
                self.print_log('Use BN in Encoder!')
        else:
            self.use_frozen_bn = True
            self.print_log('Use Frozen BN in Encoder!')

        if self.rank == 0:
            # Only rank 0 keeps an EMA copy. A failure here is logged and
            # training continues without EMA.
            try:
                total_steps = float(cfg.TRAIN_TOTAL_STEPS)
                ema_decay = 1. - 1. / (total_steps * cfg.TRAIN_EMA_RATIO)
                self.ema_params = get_param_buffer_for_ema(
                    self.model, update_buffer=(not cfg.MODEL_FREEZE_BN))
                self.ema = ExponentialMovingAverage(self.ema_params,
                                                    decay=ema_decay)
                self.ema_dir = cfg.DIR_EMA_CKPT
            except Exception as inst:
                self.print_log(inst)
                self.print_log('Error: failed to create EMA model!')

        self.print_log('Build optimizer.')

        trainable_params = get_trainable_params(
            model=self.dist_engine,
            base_lr=cfg.TRAIN_LR,
            use_frozen_bn=self.use_frozen_bn,
            weight_decay=cfg.TRAIN_WEIGHT_DECAY,
            exclusive_wd_dict=cfg.TRAIN_WEIGHT_DECAY_EXCLUSIVE,
            no_wd_keys=cfg.TRAIN_WEIGHT_DECAY_EXEMPTION)

        if cfg.TRAIN_OPT == 'sgd':
            self.optimizer = optim.SGD(trainable_params,
                                       lr=cfg.TRAIN_LR,
                                       momentum=cfg.TRAIN_SGD_MOMENTUM,
                                       nesterov=True)
        else:
            self.optimizer = optim.AdamW(trainable_params,
                                         lr=cfg.TRAIN_LR,
                                         weight_decay=cfg.TRAIN_WEIGHT_DECAY)

        self.enable_amp = enable_amp
        self.scaler = torch.cuda.amp.GradScaler() if enable_amp else None

        # The data loader must exist before checkpoints are processed:
        # resuming derives the epoch from `len(self.train_loader)`.
        self.prepare_dataset()
        self.process_pretrained_model()

        if cfg.TRAIN_TBLOG and self.rank == 0:
            from tensorboardX import SummaryWriter
            self.tblogger = SummaryWriter(cfg.DIR_TB_LOG)

    @staticmethod
    def _latest_checkpoint_step(file_names):
        """Return the largest step parsed from checkpoint file names.

        The step is the integer between the last '_' and the following '.'
        (e.g. 'save_step_1000.pth' -> 1000). Names that do not follow this
        pattern are ignored. Returns None when no name can be parsed.
        """
        steps = []
        for name in file_names:
            try:
                steps.append(int(name.split('_')[-1].split('.')[0]))
            except ValueError:
                # Stray files in the checkpoint directory must not abort
                # auto-resume.
                continue
        return max(steps) if steps else None

    def _backup_dir(self, sub_dir):
        """Return './backup/<EXP_NAME>/<STAGE_NAME>/<sub_dir>'."""
        cfg = self.cfg
        result_dir = './backup/{}/{}'.format(cfg.EXP_NAME, cfg.STAGE_NAME)
        return os.path.join(result_dir, sub_dir)

    def process_pretrained_model(self):
        """Initialise `self.step` / `self.epoch` and load checkpoints.

        With `cfg.TRAIN_AUTO_RESUME`, the newest checkpoint in `cfg.DIR_CKPT`
        decides whether to resume (the `cfg.TRAIN_RESUME*` fields are
        overwritten). When resuming, rank 0 first tries to rebuild the EMA
        state from the EMA checkpoint (best-effort), then every rank loads
        model, optimizer and scaler state. Each load falls back to the
        './backup/...' location if the primary path fails. Without resume,
        `cfg.PRETRAIN` selects loading a full pretrained model or only the
        encoder.

        Raises:
            SystemExit: when the resume step already reaches
                `cfg.TRAIN_TOTAL_STEPS`.
        """
        cfg = self.cfg

        self.step = cfg.TRAIN_START_STEP
        self.epoch = 0

        if cfg.TRAIN_AUTO_RESUME:
            latest_step = self._latest_checkpoint_step(
                os.listdir(cfg.DIR_CKPT))
            if latest_step is not None:
                cfg.TRAIN_RESUME = True
                cfg.TRAIN_RESUME_CKPT = latest_step
                cfg.TRAIN_RESUME_STEP = latest_step
            else:
                cfg.TRAIN_RESUME = False

        if cfg.TRAIN_RESUME:
            ckpt_name = 'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT)

            if self.rank == 0:
                try:
                    try:
                        ema_ckpt_dir = os.path.join(self.ema_dir, ckpt_name)
                        ema_model, removed_dict = load_network(
                            self.model, ema_ckpt_dir, self.gpu)
                    except Exception as inst:
                        self.print_log(inst)
                        self.print_log('Try to use backup EMA checkpoint.')
                        ema_ckpt_dir = os.path.join(
                            self._backup_dir('ema_ckpt'), ckpt_name)
                        ema_model, removed_dict = load_network(
                            self.model, ema_ckpt_dir, self.gpu)

                    if len(removed_dict) > 0:
                        self.print_log(
                            'Remove {} from EMA model.'.format(removed_dict))

                    # Rebuild the EMA tracker from the loaded EMA weights,
                    # keeping the decay computed in __init__.
                    ema_decay = self.ema.decay
                    ema_params = get_param_buffer_for_ema(
                        ema_model, update_buffer=(not cfg.MODEL_FREEZE_BN))
                    self.ema = ExponentialMovingAverage(ema_params,
                                                        decay=ema_decay)
                    self.ema.num_updates = cfg.TRAIN_RESUME_CKPT
                except Exception as inst:
                    self.print_log(inst)
                    self.print_log('Error: EMA model not found!')

            try:
                resume_ckpt = os.path.join(cfg.DIR_CKPT, ckpt_name)
                self.model, self.optimizer, removed_dict = \
                    load_network_and_optimizer(self.model,
                                               self.optimizer,
                                               resume_ckpt,
                                               self.gpu,
                                               scaler=self.scaler)
            except Exception as inst:
                self.print_log(inst)
                self.print_log('Try to use backup checkpoint.')
                resume_ckpt = os.path.join(self._backup_dir('ckpt'),
                                           ckpt_name)
                self.model, self.optimizer, removed_dict = \
                    load_network_and_optimizer(self.model,
                                               self.optimizer,
                                               resume_ckpt,
                                               self.gpu,
                                               scaler=self.scaler)

            if len(removed_dict) > 0:
                self.print_log(
                    'Remove {} from checkpoint.'.format(removed_dict))

            self.step = cfg.TRAIN_RESUME_STEP
            if cfg.TRAIN_TOTAL_STEPS <= self.step:
                self.print_log("Your training has finished!")
                raise SystemExit(0)
            self.epoch = int(np.ceil(self.step / len(self.train_loader)))

            self.print_log('Resume from step {}'.format(self.step))

        elif cfg.PRETRAIN:
            if cfg.PRETRAIN_FULL:
                try:
                    self.model, removed_dict = load_network(
                        self.model, cfg.PRETRAIN_MODEL, self.gpu)
                except Exception as inst:
                    self.print_log(inst)
                    self.print_log('Try to use backup EMA checkpoint.')
                    backup_model = os.path.join(
                        self._backup_dir('ema_ckpt'),
                        cfg.PRETRAIN_MODEL.split('/')[-1])
                    self.model, removed_dict = load_network(
                        self.model, backup_model, self.gpu)

                if len(removed_dict) > 0:
                    self.print_log('Remove {} from pretrained model.'.format(
                        removed_dict))
                self.print_log('Load pretrained VOS model from {}.'.format(
                    cfg.PRETRAIN_MODEL))
            else:
                _, removed_dict = load_network(self.model_encoder,
                                               cfg.PRETRAIN_MODEL, self.gpu)
                if len(removed_dict) > 0:
                    self.print_log('Remove {} from pretrained model.'.format(
                        removed_dict))
                self.print_log(
                    'Load pretrained backbone model from {}.'.format(
                        cfg.PRETRAIN_MODEL))

    def prepare_dataset(self):
        """Build the augmentation pipeline, datasets, sampler and loader.

        Sets `self.enable_prev_frame`, `self.train_sampler` and
        `self.train_loader`. The datasets named in `cfg.DATASETS`
        ('static', 'davis2017', 'youtubevos', 'test') are concatenated.

        Raises:
            NotImplementedError: if `cfg.TRAIN_AUG_TYPE` is not 'v1' or 'v2'.
            SystemExit: if no known dataset is selected.
        """
        cfg = self.cfg
        self.enable_prev_frame = cfg.TRAIN_ENABLE_PREV_FRAME

        self.print_log('Process dataset...')

        aug_type = cfg.TRAIN_AUG_TYPE
        if aug_type not in ('v1', 'v2'):
            # The original `assert NotImplementedError` could never fail.
            raise NotImplementedError(
                'Unsupported TRAIN_AUG_TYPE: {}'.format(aug_type))

        pipeline = [
            tr.RandomScale(cfg.DATA_MIN_SCALE_FACTOR,
                           cfg.DATA_MAX_SCALE_FACTOR,
                           cfg.DATA_SHORT_EDGE_LEN),
            tr.BalancedRandomCrop(cfg.DATA_RANDOMCROP,
                                  max_obj_num=cfg.MODEL_MAX_OBJ_NUM),
        ]
        if aug_type == 'v2':
            # 'v2' adds photometric augmentation on top of 'v1'.
            pipeline += [
                tr.RandomColorJitter(),
                tr.RandomGrayScale(),
                tr.RandomGaussianBlur(),
            ]
        pipeline += [
            tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
            tr.Resize(cfg.DATA_RANDOMCROP, use_padding=True),
            tr.ToTensor(),
        ]
        composed_transforms = transforms.Compose(pipeline)

        train_datasets = []
        if 'static' in cfg.DATASETS:
            pretrain_vos_dataset = StaticTrain(
                cfg.DIR_STATIC,
                cfg.DATA_RANDOMCROP,
                seq_len=cfg.DATA_SEQ_LEN,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM,
                aug_type=cfg.TRAIN_AUG_TYPE)
            train_datasets.append(pretrain_vos_dataset)
            # Using the static dataset disables the separate previous frame
            # for all datasets built below.
            self.enable_prev_frame = False

        if 'davis2017' in cfg.DATASETS:
            train_davis_dataset = DAVIS2017_Train(
                root=cfg.DIR_DAVIS,
                full_resolution=cfg.TRAIN_DATASET_FULL_RESOLUTION,
                transform=composed_transforms,
                repeat_time=cfg.DATA_DAVIS_REPEAT,
                seq_len=cfg.DATA_SEQ_LEN,
                rand_gap=cfg.DATA_RANDOM_GAP_DAVIS,
                rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                enable_prev_frame=self.enable_prev_frame,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM)
            train_datasets.append(train_davis_dataset)

        if 'youtubevos' in cfg.DATASETS:
            train_ytb_dataset = YOUTUBEVOS_Train(
                root=cfg.DIR_YTB,
                transform=composed_transforms,
                seq_len=cfg.DATA_SEQ_LEN,
                rand_gap=cfg.DATA_RANDOM_GAP_YTB,
                rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                enable_prev_frame=self.enable_prev_frame,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM)
            train_datasets.append(train_ytb_dataset)

        if 'test' in cfg.DATASETS:
            test_dataset = TEST(transform=composed_transforms,
                                seq_len=cfg.DATA_SEQ_LEN)
            train_datasets.append(test_dataset)

        if len(train_datasets) > 1:
            train_dataset = torch.utils.data.ConcatDataset(train_datasets)
        elif len(train_datasets) == 1:
            train_dataset = train_datasets[0]
        else:
            self.print_log('No dataset!')
            raise SystemExit(0)

        self.train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset) if cfg.DIST_ENABLE else None

        loader_kwargs = dict(
            batch_size=int(cfg.TRAIN_BATCH_SIZE / cfg.TRAIN_GPUS),
            # The distributed sampler does its own shuffling.
            shuffle=not cfg.DIST_ENABLE,
            num_workers=cfg.DATA_WORKERS,
            pin_memory=True,
            sampler=self.train_sampler,
            drop_last=True)
        if cfg.DATA_WORKERS > 0:
            # DataLoader rejects `prefetch_factor` when loading happens in
            # the main process (num_workers == 0).
            loader_kwargs['prefetch_factor'] = 4
        self.train_loader = DataLoader(train_dataset, **loader_kwargs)

        self.print_log('Done!')

    def sequential_training(self):
        """Run the training loop until `cfg.TRAIN_TOTAL_STEPS` is reached.

        Each iteration moves a batch to the GPU, runs the engine on the
        concatenated reference/previous/current frames, back-propagates
        (through the GradScaler when AMP is enabled), clips gradients and
        steps the optimizer. Per-frame loss and IoU are averaged across
        processes; rank 0 updates the EMA weights, writes logs and
        periodically saves regular and EMA checkpoints.
        """
        cfg = self.cfg

        if self.enable_prev_frame:
            frame_names = ['Ref', 'Prev']
        else:
            frame_names = ['Ref(Prev)']
        frame_names.extend('Curr{}'.format(i + 1)
                           for i in range(cfg.DATA_SEQ_LEN - 1))
        seq_len = len(frame_names)

        running_losses = [AverageMeter() for _ in range(seq_len)]
        running_ious = [AverageMeter() for _ in range(seq_len)]
        batch_time = AverageMeter()
        avg_obj = AverageMeter()

        optimizer = self.optimizer
        model = self.dist_engine
        train_sampler = self.train_sampler
        train_loader = self.train_loader

        step = self.step
        epoch = self.epoch
        max_itr = cfg.TRAIN_TOTAL_STEPS
        start_seq_training_step = int(cfg.TRAIN_SEQ_TRAINING_START_RATIO *
                                      max_itr)
        use_prev_prob = cfg.MODEL_USE_PREV_PROB
        # Stays None until the first LR computation, so logging never sees an
        # unbound name when training starts between LR update steps.
        now_lr = None

        self.print_log('Start training:')
        model.train()
        while step < cfg.TRAIN_TOTAL_STEPS:
            if cfg.DIST_ENABLE:
                train_sampler.set_epoch(epoch)
            epoch += 1
            last_time = time.time()
            for sample in train_loader:
                # NOTE(review): `>` (not `>=`) lets one iteration run with
                # step == TRAIN_TOTAL_STEPS after the final checkpoint; kept
                # as in the original - confirm whether this is intended.
                if step > cfg.TRAIN_TOTAL_STEPS:
                    break

                tf_board = bool(step % cfg.TRAIN_TBLOG_STEP == 0
                                and self.rank == 0 and cfg.TRAIN_TBLOG)

                if step >= start_seq_training_step:
                    use_prev_pred = True
                    freeze_params = cfg.TRAIN_SEQ_TRAINING_FREEZE_PARAMS
                else:
                    use_prev_pred = False
                    freeze_params = []

                if now_lr is None or step % cfg.TRAIN_LR_UPDATE_STEP == 0:
                    now_lr = adjust_learning_rate(
                        optimizer=optimizer,
                        base_lr=cfg.TRAIN_LR,
                        p=cfg.TRAIN_LR_POWER,
                        itr=step,
                        max_itr=max_itr,
                        restart=cfg.TRAIN_LR_RESTART,
                        warm_up_steps=cfg.TRAIN_LR_WARM_UP_RATIO * max_itr,
                        is_cosine_decay=cfg.TRAIN_LR_COSINE_DECAY,
                        min_lr=cfg.TRAIN_LR_MIN,
                        encoder_lr_ratio=cfg.TRAIN_LR_ENCODER_RATIO,
                        freeze_params=freeze_params)

                ref_imgs = sample['ref_img']  # batch_size * 3 * h * w
                prev_imgs = sample['prev_img']
                curr_imgs = sample['curr_img']
                ref_labels = sample['ref_label']  # batch_size * 1 * h * w
                prev_labels = sample['prev_label']
                curr_labels = sample['curr_label']
                obj_nums = sample['meta']['obj_num']
                bs, _, h, w = curr_imgs[0].size()

                ref_imgs = ref_imgs.cuda(self.gpu, non_blocking=True)
                prev_imgs = prev_imgs.cuda(self.gpu, non_blocking=True)
                curr_imgs = [
                    curr_img.cuda(self.gpu, non_blocking=True)
                    for curr_img in curr_imgs
                ]
                ref_labels = ref_labels.cuda(self.gpu, non_blocking=True)
                prev_labels = prev_labels.cuda(self.gpu, non_blocking=True)
                curr_labels = [
                    curr_label.cuda(self.gpu, non_blocking=True)
                    for curr_label in curr_labels
                ]
                obj_nums = [int(obj_num) for obj_num in obj_nums]

                batch_size = ref_imgs.size(0)

                # Frames and labels are stacked along the batch dimension:
                # reference, previous, then every current frame.
                all_frames = torch.cat([ref_imgs, prev_imgs] + curr_imgs,
                                       dim=0)
                all_labels = torch.cat([ref_labels, prev_labels] +
                                       curr_labels,
                                       dim=0)

                self.engine.restart_engine(batch_size, True)
                optimizer.zero_grad(set_to_none=True)

                with torch.cuda.amp.autocast(enabled=bool(self.enable_amp)):
                    loss, all_pred, all_loss, boards = model(
                        all_frames,
                        all_labels,
                        batch_size,
                        use_prev_pred=use_prev_pred,
                        obj_nums=obj_nums,
                        step=step,
                        tf_board=tf_board,
                        enable_prev_frame=self.enable_prev_frame,
                        use_prev_prob=use_prev_prob)
                    loss = torch.mean(loss)

                if self.enable_amp:
                    self.scaler.scale(loss).backward()
                    # Gradients must be unscaled before clipping so the
                    # threshold applies to their true magnitude.
                    self.scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   cfg.TRAIN_CLIP_GRAD_NORM)
                    self.scaler.step(optimizer)
                    self.scaler.update()
                else:
                    # Clipping has to follow backward(): before it the
                    # gradients are still None and clipping does nothing.
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   cfg.TRAIN_CLIP_GRAD_NORM)
                    optimizer.step()

                for idx in range(seq_len):
                    now_pred = all_pred[idx].detach()
                    now_label = all_labels[idx * bs:(idx + 1) * bs].detach()
                    now_loss = torch.mean(all_loss[idx].detach())
                    now_iou = pytorch_iou(now_pred.unsqueeze(1), now_label,
                                          obj_nums) * 100
                    if cfg.DIST_ENABLE:
                        # Average the metrics over all processes.
                        dist.all_reduce(now_loss)
                        dist.all_reduce(now_iou)
                        now_loss /= self.gpu_num
                        now_iou /= self.gpu_num
                    if self.rank == 0:
                        running_losses[idx].update(now_loss.item())
                        running_ious[idx].update(now_iou.item())

                if self.rank == 0:
                    if self.ema is not None:
                        self.ema.update(self.ema_params)

                    avg_obj.update(sum(obj_nums) / float(len(obj_nums)))
                    curr_time = time.time()
                    batch_time.update(curr_time - last_time)
                    last_time = curr_time

                    if step % cfg.TRAIN_TBLOG_STEP == 0:
                        # `now_label` / `now_pred` hold the values of the
                        # last frame from the metrics loop above.
                        all_f = [ref_imgs, prev_imgs] + curr_imgs
                        self.process_log(ref_imgs, all_f[-2], all_f[-1],
                                         ref_labels, all_pred[-2], now_label,
                                         now_pred, boards, running_losses,
                                         running_ious, now_lr, step)

                    if step % cfg.TRAIN_LOG_STEP == 0:
                        strs = ('I:{}, LR:{:.5f}, T:{:.1f}({:.1f})s, '
                                'Obj:{:.1f}({:.1f})').format(
                                    step, now_lr, batch_time.val,
                                    batch_time.moving_avg, avg_obj.val,
                                    avg_obj.moving_avg)
                        batch_time.reset()
                        avg_obj.reset()
                        for idx in range(seq_len):
                            strs += (', {}: L {:.3f}({:.3f}) '
                                     'IoU {:.1f}({:.1f})%').format(
                                         frame_names[idx],
                                         running_losses[idx].val,
                                         running_losses[idx].moving_avg,
                                         running_ious[idx].val,
                                         running_ious[idx].moving_avg)
                            running_losses[idx].reset()
                            running_ious[idx].reset()

                        self.print_log(strs)

                step += 1

                if step % cfg.TRAIN_SAVE_STEP == 0 and self.rank == 0:
                    max_mem = torch.cuda.max_memory_allocated(
                        device=self.gpu) / (1024.**3)
                    eta = str(
                        datetime.timedelta(
                            seconds=int(batch_time.moving_avg *
                                        (cfg.TRAIN_TOTAL_STEPS - step))))
                    self.print_log('ETA: {}, Max Mem: {:.2f}G.'.format(
                        eta, max_mem))

                    self.print_log('Save CKPT (Step {}).'.format(step))
                    save_network(self.model,
                                 optimizer,
                                 step,
                                 cfg.DIR_CKPT,
                                 cfg.TRAIN_MAX_KEEP_CKPT,
                                 backup_dir='./backup/{}/{}/ckpt'.format(
                                     cfg.EXP_NAME, cfg.STAGE_NAME),
                                 scaler=self.scaler)
                    try:
                        torch.cuda.empty_cache()
                        # Keep the live training weights so they can be put
                        # back after the EMA weights have been written.
                        self.ema.store(self.ema_params)
                        try:
                            # Temporarily swap the EMA weights into the
                            # model and save that as the EMA checkpoint.
                            self.ema.copy_to(self.ema_params)
                            save_network(
                                self.model,
                                optimizer,
                                step,
                                self.ema_dir,
                                cfg.TRAIN_MAX_KEEP_CKPT,
                                backup_dir='./backup/{}/{}/ema_ckpt'.format(
                                    cfg.EXP_NAME, cfg.STAGE_NAME),
                                scaler=self.scaler)
                        finally:
                            # Always restore, even if saving failed, so
                            # training never continues on the EMA weights.
                            self.ema.restore(self.ema_params)
                    except Exception as inst:
                        self.print_log(inst)
                        self.print_log('Error: failed to save EMA model!')

        self.print_log('Stop training!')

    def print_log(self, string):
        """Print `string`, but only on the rank-0 process."""
        if self.rank == 0:
            print(string)

    def process_log(self, ref_imgs, prev_imgs, curr_imgs, ref_labels,
                    prev_labels, curr_labels, curr_pred, boards,
                    running_losses, running_ious, now_lr, step):
        """Write visualisations and scalars for the current step.

        Only the first sample of each batch is visualised. Overlays are
        saved as JPEG files when `cfg.TRAIN_IMG_LOG` is set and sent to
        TensorBoard, together with the running loss/IoU and learning rate,
        when `cfg.TRAIN_TBLOG` is set.

        Args:
            ref_imgs, prev_imgs, curr_imgs: Image batches (tensors).
            ref_labels, prev_labels, curr_labels: Label batches; the caller
                in this file passes a prediction as `prev_labels`.
            curr_pred: Prediction batch for the current frame.
            boards: Dict with 'image' and 'scalar' sub-dicts of tensors.
            running_losses, running_ious: Per-frame meters; `.avg` is read.
            now_lr: Current learning rate.
            step: Training step used as the logging index and file prefix.
        """
        cfg = self.cfg

        # Undo input normalisation (presumably the ImageNet mean/std - TODO
        # confirm against the dataloader's normalisation).
        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])

        show_ref_img, show_prev_img, show_curr_img = [
            img.cpu().numpy()[0] * sigma + mean
            for img in [ref_imgs, prev_imgs, curr_imgs]
        ]

        show_gt, show_prev_gt, show_ref_gt, show_preds_s = [
            label.cpu()[0].squeeze(0).numpy()
            for label in [curr_labels, prev_labels, ref_labels, curr_pred]
        ]

        # Colourised masks, transposed to channel-first layout.
        show_gtf, show_prev_gtf, show_ref_gtf, show_preds_sf = [
            label2colormap(label).transpose((2, 0, 1))
            for label in [show_gt, show_prev_gt, show_ref_gt, show_preds_s]
        ]

        if cfg.TRAIN_IMG_LOG or cfg.TRAIN_TBLOG:

            show_ref_img = masked_image(show_ref_img, show_ref_gtf,
                                        show_ref_gt)
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    show_ref_img,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_ref_img.jpeg' % (step)))

            show_prev_img = masked_image(show_prev_img, show_prev_gtf,
                                         show_prev_gt)
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    show_prev_img,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_prev_img.jpeg' % (step)))

            # The prediction overlay is built from the unmasked current
            # frame, before the ground-truth overlay replaces it below.
            show_img_pred = masked_image(show_curr_img, show_preds_sf,
                                         show_preds_s)
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    show_img_pred,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_prediction.jpeg' % (step)))

            show_curr_img = masked_image(show_curr_img, show_gtf, show_gt)
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    show_curr_img,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_groundtruth.jpeg' % (step)))

            if cfg.TRAIN_TBLOG:
                for seq_step, (running_loss, running_iou) in enumerate(
                        zip(running_losses, running_ious)):
                    self.tblogger.add_scalar('S{}/Loss'.format(seq_step),
                                             running_loss.avg, step)
                    self.tblogger.add_scalar('S{}/IoU'.format(seq_step),
                                             running_iou.avg, step)

                self.tblogger.add_scalar('LR', now_lr, step)
                self.tblogger.add_image('Ref/Image', show_ref_img, step)
                self.tblogger.add_image('Ref/GT', show_ref_gtf, step)

                self.tblogger.add_image('Prev/Image', show_prev_img, step)
                self.tblogger.add_image('Prev/GT', show_prev_gtf, step)

                self.tblogger.add_image('Curr/Image_GT', show_curr_img, step)
                self.tblogger.add_image('Curr/Image_Pred', show_img_pred,
                                        step)

                self.tblogger.add_image('Curr/Mask_GT', show_gtf, step)
                self.tblogger.add_image('Curr/Mask_Pred', show_preds_sf,
                                        step)

                # NOTE(review): the 'S{}/' prefix is never formatted, so the
                # tags contain a literal '{}'. Kept unchanged so existing
                # TensorBoard runs keep their tag names - confirm intent.
                for key in boards['image'].keys():
                    tmp = boards['image'][key].cpu().numpy()
                    self.tblogger.add_image('S{}/' + key, tmp, step)
                for key in boards['scalar'].keys():
                    tmp = boards['scalar'][key].cpu().numpy()
                    self.tblogger.add_scalar('S{}/' + key, tmp, step)

                # The writer only exists when TensorBoard logging is on.
                self.tblogger.flush()

        del boards