import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from utils.math import generate_permute_matrix
from utils.image import one_hot_mask

from networks.layers.basic import seq_to_2d
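
# This module implements two engines built on top of an AOT/DeAOT model:
#   * AOTEngine      -- drives a single AOT model over one video clip, keeping
#                       the long-term and short-term memories that the LSTT
#                       blocks attend to, and computing the training losses
#                       frame by frame.
#   * AOTInferEngine -- inference-time wrapper that spawns one AOTEngine per
#                       group of at most `max_aot_obj_num` objects and merges
#                       the per-group predictions into a single output.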


class AOTEngine(nn.Module):
    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 max_len_long_term=9999):
        super().__init__()

        self.cfg = aot_model.cfg
        self.align_corners = aot_model.cfg.MODEL_ALIGN_CORNERS
        self.AOT = aot_model

        self.max_obj_num = aot_model.max_obj_num
        self.gpu_id = gpu_id
        self.long_term_mem_gap = long_term_mem_gap
        self.short_term_mem_skip = short_term_mem_skip
        self.max_len_long_term = max_len_long_term
        self.losses = None

        self.restart_engine()

    def forward(self,
                all_frames,
                all_masks,
                batch_size,
                obj_nums,
                step=0,
                tf_board=False,
                use_prev_pred=False,
                enable_prev_frame=False,
                use_prev_prob=False):  # only used for training
        if self.losses is None:
            self._init_losses()

        self.freeze_id = True if use_prev_pred else False
        aux_weight = self.aux_weight * max(self.aux_step - step,
                                           0.) / self.aux_step

        self.offline_encoder(all_frames, all_masks)

        # Reference frame: build identity embeddings from the ground-truth
        # mask and compute an auxiliary loss on its reconstruction.
        self.add_reference_frame(frame_step=0, obj_nums=obj_nums)

        grad_state = torch.no_grad if aux_weight == 0 else torch.enable_grad
        with grad_state():
            ref_aux_loss, ref_aux_mask = self.generate_loss_mask(
                self.offline_masks[self.frame_step], step)

        aux_losses = [ref_aux_loss]
        aux_masks = [ref_aux_mask]

        curr_losses, curr_masks = [], []
        if enable_prev_frame:
            # Treat frame 1 as an extra memory frame with ground-truth mask.
            self.set_prev_frame(frame_step=1)
            with grad_state():
                prev_aux_loss, prev_aux_mask = self.generate_loss_mask(
                    self.offline_masks[self.frame_step], step)
            aux_losses.append(prev_aux_loss)
            aux_masks.append(prev_aux_mask)
        else:
            # Otherwise frame 1 is predicted and its loss counts as a normal
            # prediction loss.
            self.match_propogate_one_frame()
            curr_loss, curr_mask, curr_prob = self.generate_loss_mask(
                self.offline_masks[self.frame_step], step, return_prob=True)
            self.update_short_term_memory(
                curr_mask if not use_prev_prob else curr_prob,
                None if use_prev_pred else self.assign_identity(
                    self.offline_one_hot_masks[self.frame_step]))
            curr_losses.append(curr_loss)
            curr_masks.append(curr_mask)

        # Propagate through the remaining frames, updating the short-term
        # memory with the previous prediction (or ground truth) each step.
        self.match_propogate_one_frame()
        curr_loss, curr_mask, curr_prob = self.generate_loss_mask(
            self.offline_masks[self.frame_step], step, return_prob=True)
        curr_losses.append(curr_loss)
        curr_masks.append(curr_mask)
        for _ in range(self.total_offline_frame_num - 3):
            self.update_short_term_memory(
                curr_mask if not use_prev_prob else curr_prob,
                None if use_prev_pred else self.assign_identity(
                    self.offline_one_hot_masks[self.frame_step]))
            self.match_propogate_one_frame()
            curr_loss, curr_mask, curr_prob = self.generate_loss_mask(
                self.offline_masks[self.frame_step], step, return_prob=True)
            curr_losses.append(curr_loss)
            curr_masks.append(curr_mask)

        aux_loss = torch.cat(aux_losses, dim=0).mean(dim=0)
        pred_loss = torch.cat(curr_losses, dim=0).mean(dim=0)

        # The auxiliary weight decays linearly to 0 over the first
        # `aux_step` training steps.
        loss = aux_weight * aux_loss + pred_loss

        all_pred_mask = aux_masks + curr_masks
        all_frame_loss = aux_losses + curr_losses

        boards = {'image': {}, 'scalar': {}}

        return loss, all_pred_mask, all_frame_loss, boards

    def _init_losses(self):
        cfg = self.cfg

        from networks.layers.loss import CrossEntropyLoss, SoftJaccordLoss
        bce_loss = CrossEntropyLoss(
            cfg.TRAIN_TOP_K_PERCENT_PIXELS,
            cfg.TRAIN_HARD_MINING_RATIO * cfg.TRAIN_TOTAL_STEPS)
        iou_loss = SoftJaccordLoss()

        losses = [bce_loss, iou_loss]
        loss_weights = [0.5, 0.5]

        self.losses = nn.ModuleList(losses)
        self.loss_weights = loss_weights
        self.aux_weight = cfg.TRAIN_AUX_LOSS_WEIGHT
        self.aux_step = cfg.TRAIN_TOTAL_STEPS * cfg.TRAIN_AUX_LOSS_RATIO + 1e-5

    def encode_one_img_mask(self, img=None, mask=None, frame_step=-1):
        if frame_step == -1:
            frame_step = self.frame_step

        if self.enable_offline_enc:
            curr_enc_embs = self.offline_enc_embs[frame_step]
        elif img is None:
            curr_enc_embs = None
        else:
            curr_enc_embs = self.AOT.encode_image(img)

        if mask is not None:
            curr_one_hot_mask = one_hot_mask(mask, self.max_obj_num)
        elif self.enable_offline_enc:
            curr_one_hot_mask = self.offline_one_hot_masks[frame_step]
        else:
            curr_one_hot_mask = None

        return curr_enc_embs, curr_one_hot_mask

    def offline_encoder(self, all_frames, all_masks=None):
        self.enable_offline_enc = True
        self.offline_frames = all_frames.size(0) // self.batch_size

        # extract backbone features
        self.offline_enc_embs = self.split_frames(
            self.AOT.encode_image(all_frames), self.batch_size)
        self.total_offline_frame_num = len(self.offline_enc_embs)

        if all_masks is not None:
            # extract mask embeddings
            offline_one_hot_masks = one_hot_mask(all_masks, self.max_obj_num)
            self.offline_masks = list(
                torch.split(all_masks, self.batch_size, dim=0))
            self.offline_one_hot_masks = list(
                torch.split(offline_one_hot_masks, self.batch_size, dim=0))

        if self.input_size_2d is None:
            self.update_size(all_frames.size()[2:],
                             self.offline_enc_embs[0][-1].size()[2:])

    def assign_identity(self, one_hot_mask):
        if self.enable_id_shuffle:
            one_hot_mask = torch.einsum('bohw,bot->bthw', one_hot_mask,
                                        self.id_shuffle_matrix)

        id_emb = self.AOT.get_id_emb(one_hot_mask).view(
            self.batch_size, -1, self.enc_hw).permute(2, 0, 1)

        if self.training and self.freeze_id:
            id_emb = id_emb.detach()

        return id_emb

    def split_frames(self, xs, chunk_size):
        new_xs = []
        for x in xs:
            all_x = list(torch.split(x, chunk_size, dim=0))
            new_xs.append(all_x)
        return list(zip(*new_xs))
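
    # Layout note (as implied by `offline_encoder` / `split_frames` above):
    # `all_frames` stacks the T clip frames along the batch dimension, i.e.
    # shape (T * B, C, H, W) with the B samples of frame t stored
    # contiguously.  `encode_image` returns one feature map per backbone
    # scale with the same (T * B, ...) leading dimension, and `split_frames`
    # regroups them into a list of T per-frame tuples, each holding that
    # frame's features at every scale.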

    def add_reference_frame(self,
                            img=None,
                            mask=None,
                            frame_step=-1,
                            obj_nums=None,
                            img_embs=None):
        if self.obj_nums is None and obj_nums is None:
            print('No objects for reference frame!')
            exit()
        elif obj_nums is not None:
            self.obj_nums = obj_nums

        if frame_step == -1:
            frame_step = self.frame_step

        if img_embs is None:
            curr_enc_embs, curr_one_hot_mask = self.encode_one_img_mask(
                img, mask, frame_step)
        else:
            _, curr_one_hot_mask = self.encode_one_img_mask(
                None, mask, frame_step)
            curr_enc_embs = img_embs

        if curr_enc_embs is None:
            print('No image for reference frame!')
            exit()

        if curr_one_hot_mask is None:
            print('No mask for reference frame!')
            exit()

        if self.input_size_2d is None:
            self.update_size(img.size()[2:], curr_enc_embs[-1].size()[2:])

        self.curr_enc_embs = curr_enc_embs
        self.curr_one_hot_mask = curr_one_hot_mask

        if self.pos_emb is None:
            self.pos_emb = self.AOT.get_pos_emb(curr_enc_embs[-1]).expand(
                self.batch_size, -1, -1,
                -1).view(self.batch_size, -1, self.enc_hw).permute(2, 0, 1)

        curr_id_emb = self.assign_identity(curr_one_hot_mask)
        self.curr_id_embs = curr_id_emb

        # self matching and propagation
        self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs,
                                                      None,
                                                      None,
                                                      curr_id_emb,
                                                      pos_emb=self.pos_emb,
                                                      size_2d=self.enc_size_2d)

        lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories = self.curr_lstt_output

        if self.long_term_memories is None:
            self.long_term_memories = lstt_long_memories
        else:
            self.update_long_term_memory(lstt_long_memories)
        self.last_mem_step = self.frame_step

        self.short_term_memories_list = [lstt_short_memories]
        self.short_term_memories = lstt_short_memories

    def set_prev_frame(self, img=None, mask=None, frame_step=1):
        self.frame_step = frame_step
        curr_enc_embs, curr_one_hot_mask = self.encode_one_img_mask(
            img, mask, frame_step)

        if curr_enc_embs is None:
            print('No image for previous frame!')
            exit()

        if curr_one_hot_mask is None:
            print('No mask for previous frame!')
            exit()

        self.curr_enc_embs = curr_enc_embs
        self.curr_one_hot_mask = curr_one_hot_mask

        curr_id_emb = self.assign_identity(curr_one_hot_mask)
        self.curr_id_embs = curr_id_emb

        # self matching and propagation
        self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs,
                                                      None,
                                                      None,
                                                      curr_id_emb,
                                                      pos_emb=self.pos_emb,
                                                      size_2d=self.enc_size_2d)

        lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories = self.curr_lstt_output

        if self.long_term_memories is None:
            self.long_term_memories = lstt_long_memories
        else:
            self.update_long_term_memory(lstt_long_memories)
        self.last_mem_step = frame_step

        self.short_term_memories_list = [lstt_short_memories]
        self.short_term_memories = lstt_short_memories

    def update_long_term_memory(self, new_long_term_memories):
        TOKEN_NUM = new_long_term_memories[0][0].shape[0]
        if self.long_term_memories is None:
            self.long_term_memories = new_long_term_memories
        updated_long_term_memories = []
        for new_long_term_memory, last_long_term_memory in zip(
                new_long_term_memories, self.long_term_memories):
            updated_e = []
            for new_e, last_e in zip(new_long_term_memory,
                                     last_long_term_memory):
                if new_e is None or last_e is None:
                    updated_e.append(None)
                else:
                    # Cap the memory bank: drop the oldest frame's tokens
                    # (stored at the tail) before prepending the new ones.
                    if last_e.shape[0] >= self.max_len_long_term * TOKEN_NUM:
                        last_e = last_e[:(self.max_len_long_term - 1) *
                                        TOKEN_NUM]
                    updated_e.append(torch.cat([new_e, last_e], dim=0))
            updated_long_term_memories.append(updated_e)
        self.long_term_memories = updated_long_term_memories
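
    # Example of the cap above: each update contributes TOKEN_NUM rows (the
    # leading dimension of every memory tensor).  With max_len_long_term = 4,
    # once last_e already holds 4 * TOKEN_NUM rows it is truncated to
    # 3 * TOKEN_NUM before the new rows are prepended, so the bank stays at
    # 4 * TOKEN_NUM rows instead of growing without bound.  The default of
    # 9999 effectively leaves the memory uncapped.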

    def update_short_term_memory(self,
                                 curr_mask,
                                 curr_id_emb=None,
                                 skip_long_term_update=False):
        if curr_id_emb is None:
            if len(curr_mask.size()) == 3 or curr_mask.size()[0] == 1:
                curr_one_hot_mask = one_hot_mask(curr_mask, self.max_obj_num)
            else:
                curr_one_hot_mask = curr_mask
            curr_id_emb = self.assign_identity(curr_one_hot_mask)

        lstt_curr_memories = self.curr_lstt_output[1]
        lstt_curr_memories_2d = []
        for layer_idx in range(len(lstt_curr_memories)):
            curr_k, curr_v = lstt_curr_memories[layer_idx][
                0], lstt_curr_memories[layer_idx][1]
            curr_k, curr_v = self.AOT.LSTT.layers[layer_idx].fuse_key_value_id(
                curr_k, curr_v, curr_id_emb)
            lstt_curr_memories[layer_idx][0], lstt_curr_memories[layer_idx][
                1] = curr_k, curr_v
            lstt_curr_memories_2d.append([
                seq_to_2d(lstt_curr_memories[layer_idx][0], self.enc_size_2d),
                seq_to_2d(lstt_curr_memories[layer_idx][1], self.enc_size_2d)
            ])

        self.short_term_memories_list.append(lstt_curr_memories_2d)
        self.short_term_memories_list = self.short_term_memories_list[
            -self.short_term_mem_skip:]
        self.short_term_memories = self.short_term_memories_list[0]

        if self.frame_step - self.last_mem_step >= self.long_term_mem_gap:
            # skip the update of long-term memory or not
            if not skip_long_term_update:
                self.update_long_term_memory(lstt_curr_memories)
            self.last_mem_step = self.frame_step

    def match_propogate_one_frame(self, img=None, img_embs=None):
        self.frame_step += 1
        if img_embs is None:
            curr_enc_embs, _ = self.encode_one_img_mask(
                img, None, self.frame_step)
        else:
            curr_enc_embs = img_embs
        self.curr_enc_embs = curr_enc_embs

        self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs,
                                                      self.long_term_memories,
                                                      self.short_term_memories,
                                                      None,
                                                      pos_emb=self.pos_emb,
                                                      size_2d=self.enc_size_2d)

    def decode_current_logits(self, output_size=None):
        curr_enc_embs = self.curr_enc_embs
        curr_lstt_embs = self.curr_lstt_output[0]

        pred_id_logits = self.AOT.decode_id_logits(curr_lstt_embs,
                                                   curr_enc_embs)

        if self.enable_id_shuffle:  # reverse shuffle
            pred_id_logits = torch.einsum('bohw,bto->bthw', pred_id_logits,
                                          self.id_shuffle_matrix)

        # remove unused identities
        for batch_idx, obj_num in enumerate(self.obj_nums):
            pred_id_logits[batch_idx, (obj_num + 1):] = \
                -1e+10 if pred_id_logits.dtype == torch.float32 else -1e+4

        self.pred_id_logits = pred_id_logits

        if output_size is not None:
            pred_id_logits = F.interpolate(pred_id_logits,
                                           size=output_size,
                                           mode="bilinear",
                                           align_corners=self.align_corners)

        return pred_id_logits

    def predict_current_mask(self, output_size=None, return_prob=False):
        if output_size is None:
            output_size = self.input_size_2d

        pred_id_logits = F.interpolate(self.pred_id_logits,
                                       size=output_size,
                                       mode="bilinear",
                                       align_corners=self.align_corners)
        pred_mask = torch.argmax(pred_id_logits, dim=1)

        if not return_prob:
            return pred_mask
        else:
            pred_prob = torch.softmax(pred_id_logits, dim=1)
            return pred_mask, pred_prob

    def calculate_current_loss(self, gt_mask, step):
        pred_id_logits = self.pred_id_logits

        pred_id_logits = F.interpolate(pred_id_logits,
                                       size=gt_mask.size()[-2:],
                                       mode="bilinear",
                                       align_corners=self.align_corners)

        label_list = []
        logit_list = []
        for batch_idx, obj_num in enumerate(self.obj_nums):
            now_label = gt_mask[batch_idx].long()
            now_logit = pred_id_logits[batch_idx, :(obj_num + 1)].unsqueeze(0)
            label_list.append(now_label.long())
            logit_list.append(now_logit)

        total_loss = 0
        for loss, loss_weight in zip(self.losses, self.loss_weights):
            total_loss = total_loss + loss_weight * \
                loss(logit_list, label_list, step)

        return total_loss

    def generate_loss_mask(self, gt_mask, step, return_prob=False):
        self.decode_current_logits()
        loss = self.calculate_current_loss(gt_mask, step)
        if return_prob:
            mask, prob = self.predict_current_mask(return_prob=True)
            return loss, mask, prob
        else:
            mask = self.predict_current_mask()
            return loss, mask

    def keep_gt_mask(self, pred_mask, keep_prob=0.2):
        pred_mask = pred_mask.float()
        gt_mask = self.offline_masks[self.frame_step].float().squeeze(1)

        shape = [1 for _ in range(pred_mask.ndim)]
        shape[0] = self.batch_size
        random_tensor = keep_prob + torch.rand(
            shape, dtype=pred_mask.dtype, device=pred_mask.device)
        random_tensor.floor_()  # binarize

        pred_mask = pred_mask * (1 - random_tensor) + gt_mask * random_tensor

        return pred_mask

    def restart_engine(self, batch_size=1, enable_id_shuffle=False):
        self.batch_size = batch_size
        self.frame_step = 0
        self.last_mem_step = -1
        self.enable_id_shuffle = enable_id_shuffle
        self.freeze_id = False

        self.obj_nums = None
        self.pos_emb = None
        self.enc_size_2d = None
        self.enc_hw = None
        self.input_size_2d = None

        self.long_term_memories = None
        self.short_term_memories_list = []
        self.short_term_memories = None

        self.enable_offline_enc = False
        self.offline_enc_embs = None
        self.offline_one_hot_masks = None
        self.offline_frames = -1
        self.total_offline_frame_num = 0

        self.curr_enc_embs = None
        self.curr_memories = None
        self.curr_id_embs = None

        if enable_id_shuffle:
            self.id_shuffle_matrix = generate_permute_matrix(
                self.max_obj_num + 1, batch_size, gpu_id=self.gpu_id)
        else:
            self.id_shuffle_matrix = None

    def update_size(self, input_size, enc_size):
        self.input_size_2d = input_size
        self.enc_size_2d = enc_size
        self.enc_hw = self.enc_size_2d[0] * self.enc_size_2d[1]


class AOTInferEngine(nn.Module):
    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 max_aot_obj_num=None,
                 max_len_long_term=9999):
        super().__init__()

        self.cfg = aot_model.cfg
        self.AOT = aot_model

        if max_aot_obj_num is None or max_aot_obj_num > aot_model.max_obj_num:
            self.max_aot_obj_num = aot_model.max_obj_num
        else:
            self.max_aot_obj_num = max_aot_obj_num

        self.gpu_id = gpu_id
        self.long_term_mem_gap = long_term_mem_gap
        self.short_term_mem_skip = short_term_mem_skip
        self.max_len_long_term = max_len_long_term

        self.aot_engines = []

        self.restart_engine()

    def restart_engine(self):
        del self.aot_engines
        self.aot_engines = []
        self.obj_nums = None

    def separate_mask(self, mask, obj_nums):
        if mask is None:
            # keep the two-value interface consistent for callers that unpack
            return [None] * len(self.aot_engines), \
                [None] * len(self.aot_engines)
        if len(self.aot_engines) == 1:
            return [mask], [obj_nums]

        separated_obj_nums = [
            self.max_aot_obj_num for _ in range(len(self.aot_engines))
        ]
        if obj_nums % self.max_aot_obj_num > 0:
            separated_obj_nums[-1] = obj_nums % self.max_aot_obj_num

        if len(mask.size()) == 3 or mask.size()[0] == 1:
            # label-map input: remap each engine's global IDs to a local range
            separated_masks = []
            for idx in range(len(self.aot_engines)):
                start_id = idx * self.max_aot_obj_num + 1
                end_id = (idx + 1) * self.max_aot_obj_num
                fg_mask = ((mask >= start_id) & (mask <= end_id)).float()
                separated_mask = (fg_mask * mask - start_id + 1) * fg_mask
                separated_masks.append(separated_mask)
            return separated_masks, separated_obj_nums
        else:
            # probability input: rebuild a background channel for this
            # engine's slice
            prob = mask
            separated_probs = []
            for idx in range(len(self.aot_engines)):
                start_id = idx * self.max_aot_obj_num + 1
                end_id = (idx + 1) * self.max_aot_obj_num
                fg_prob = prob[start_id:(end_id + 1)]
                bg_prob = 1. - torch.sum(fg_prob, dim=1, keepdim=True)
                separated_probs.append(torch.cat([bg_prob, fg_prob], dim=1))
            return separated_probs, separated_obj_nums
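
    # Example of the split above: with max_aot_obj_num = 10 and obj_nums = 15,
    # two engines are used.  The first receives global IDs 1-10 unchanged, the
    # second receives global IDs 11-15 relabelled as 1-5, and
    # separated_obj_nums becomes [10, 5].  Pixels outside an engine's ID range
    # are mapped to background (0) for that engine.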

    def min_logit_aggregation(self, all_logits):
        if len(all_logits) == 1:
            return all_logits[0]

        fg_logits = []
        bg_logits = []

        for logit in all_logits:
            bg_logits.append(logit[:, 0:1])
            fg_logits.append(logit[:, 1:1 + self.max_aot_obj_num])

        bg_logit, _ = torch.min(torch.cat(bg_logits, dim=1),
                                dim=1,
                                keepdim=True)
        merged_logit = torch.cat([bg_logit] + fg_logits, dim=1)

        return merged_logit

    def soft_logit_aggregation(self, all_logits):
        if len(all_logits) == 1:
            return all_logits[0]

        fg_probs = []
        bg_probs = []

        for logit in all_logits:
            prob = torch.softmax(logit, dim=1)
            bg_probs.append(prob[:, 0:1])
            fg_probs.append(prob[:, 1:1 + self.max_aot_obj_num])

        bg_prob = torch.prod(torch.cat(bg_probs, dim=1), dim=1, keepdim=True)
        merged_prob = torch.cat([bg_prob] + fg_probs,
                                dim=1).clamp(1e-5, 1 - 1e-5)
        merged_logit = torch.logit(merged_prob)

        return merged_logit
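
    # In formula form, with p_k = softmax(logits of engine k):
    #     p_merged(bg)    = prod_k p_k(bg)
    #     p_merged(obj_i) = p_k(obj_i)   for the engine k owning object i
    # The merged channels are clamped to [1e-5, 1 - 1e-5] and mapped back to
    # logit space with torch.logit; they are not renormalised, which is
    # sufficient for the argmax / softmax applied downstream.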

    def add_reference_frame(self, img, mask, obj_nums, frame_step=-1):
        if isinstance(obj_nums, list):
            obj_nums = obj_nums[0]
        self.obj_nums = obj_nums
        aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1)
        while aot_num > len(self.aot_engines):
            new_engine = AOTEngine(self.AOT, self.gpu_id,
                                   self.long_term_mem_gap,
                                   self.short_term_mem_skip,
                                   self.max_len_long_term)
            new_engine.eval()
            self.aot_engines.append(new_engine)

        separated_masks, separated_obj_nums = self.separate_mask(
            mask, obj_nums)
        img_embs = None
        for aot_engine, separated_mask, separated_obj_num in zip(
                self.aot_engines, separated_masks, separated_obj_nums):
            aot_engine.add_reference_frame(img,
                                           separated_mask,
                                           obj_nums=[separated_obj_num],
                                           frame_step=frame_step,
                                           img_embs=img_embs)
            if img_embs is None:  # reuse image embeddings
                img_embs = aot_engine.curr_enc_embs

        self.update_size()

    def match_propogate_one_frame(self, img=None):
        img_embs = None
        for aot_engine in self.aot_engines:
            aot_engine.match_propogate_one_frame(img, img_embs=img_embs)
            if img_embs is None:  # reuse image embeddings
                img_embs = aot_engine.curr_enc_embs

    def decode_current_logits(self, output_size=None):
        all_logits = []
        for aot_engine in self.aot_engines:
            all_logits.append(aot_engine.decode_current_logits(output_size))
        pred_id_logits = self.soft_logit_aggregation(all_logits)
        return pred_id_logits

    def update_memory(self, curr_mask, skip_long_term_update=False):
        _curr_mask = F.interpolate(curr_mask, size=self.input_size_2d)
        separated_masks, _ = self.separate_mask(_curr_mask, self.obj_nums)
        for aot_engine, separated_mask in zip(self.aot_engines,
                                              separated_masks):
            aot_engine.update_short_term_memory(
                separated_mask, skip_long_term_update=skip_long_term_update)

    def update_size(self):
        self.input_size_2d = self.aot_engines[0].input_size_2d
        self.enc_size_2d = self.aot_engines[0].enc_size_2d
        self.enc_hw = self.aot_engines[0].enc_hw
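

# A minimal, illustrative inference sketch: it only shows the intended call
# order on AOTInferEngine and is never invoked by this module.  It assumes the
# caller has already built `aot_model` and moved it to the right device, that
# `frames` is a list of (1, 3, H, W) float tensors, and that `first_mask` is a
# (1, 1, H, W) integer label map for the annotated first frame.
def _demo_inference(aot_model, frames, first_mask, obj_nums, gpu_id=0):
    engine = AOTInferEngine(aot_model, gpu_id=gpu_id)
    results = []
    with torch.no_grad():
        # initialise the memories from the annotated first frame
        engine.add_reference_frame(frames[0], first_mask,
                                   obj_nums=obj_nums, frame_step=0)
        for frame in frames[1:]:
            # match against the memories, decode logits at full resolution,
            # and feed the prediction back as the new short-term memory
            engine.match_propogate_one_frame(frame)
            pred_logits = engine.decode_current_logits(
                output_size=frame.shape[-2:])
            pred_mask = torch.argmax(pred_logits, dim=1, keepdim=True).float()
            engine.update_memory(pred_mask)
            results.append(pred_mask)
    return results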