import os, sys
import json
import cv2
import math
import shutil
import numpy as np
import random
from PIL import Image
import torch.nn.functional as F
import torch
import os.path as osp
import time
from moviepy.editor import VideoFileClip
from torch.utils.data import Dataset

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from utils.img_utils import resize_with_antialiasing, numpy_to_pt
from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian
from data_loader.video_dataset import tokenize_captions


# For the 2D dilation
blur_kernel = bivariate_Gaussian(99, 10, 10, 0, grid=None, isotropic=True)
def get_thisthat_sam(config, intput_dir, store_dir=None, flip=False, verbose=False):
    '''
        Args:
            config (dict):      The dataset config
            intput_dir (str):   Path to the folder (one video sequence) we need to process
            store_dir (str):    Optional folder to store the condition visualization
            flip (bool):        Whether to horizontally flip the condition for augmentation
            verbose (bool):     Whether to store debug visualizations
    '''

    # Read the point annotation file
    file_path = os.path.join(intput_dir, "data.txt")
    file1 = open(file_path, 'r')
    Lines = file1.readlines()

    # Init the condition maps we want (zero everywhere except the annotated frames)
    thisthat_condition = np.zeros((config["video_seq_length"], config["conditioning_channels"], config["height"], config["width"]), dtype=np.float32)
    # Init the image
    sample_img = cv2.imread(os.path.join(intput_dir, "im_0.jpg"))
    org_height, org_width, _ = sample_img.shape

    # Prepare masking
    controlnet_image_index = []
    coordinate_values = []

    # Iterate all points in the txt file
    for idx in range(len(Lines)):

        # Read points
        frame_idx, horizontal, vertical = Lines[idx].split(' ')
        frame_idx, vertical, horizontal = int(frame_idx), int(float(vertical)), int(float(horizontal))

        # Read the mask frame idx
        controlnet_image_index.append(frame_idx)
        coordinate_values.append((vertical, horizontal))
        # Init the base image (white canvas at the original image size)
        base_img = np.zeros((org_height, org_width, 3)).astype(np.float32)
        base_img.fill(255)

        # Draw a square around the target position
        dot_range = 10  # Half side length of the square mark
        for i in range(-1 * dot_range, dot_range + 1):
            for j in range(-1 * dot_range, dot_range + 1):
                dil_vertical, dil_horizontal = vertical + i, horizontal + j
                if (0 <= dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal < base_img.shape[1]):
                    if idx == 0:
                        base_img[dil_vertical][dil_horizontal] = [0, 0, 255]  # The first point is red
                    else:
                        base_img[dil_vertical][dil_horizontal] = [0, 255, 0]  # Later points are green to distinguish them from the first one

        # Dilate the mark with the Gaussian kernel
        if config["dilate"]:
            base_img = cv2.filter2D(base_img, -1, blur_kernel)
        ##############################################################################################################################
        ### The core processing pipeline is: Dilate -> Resize -> Range Shift -> Transpose Shape -> Store

        # Resize frames. Don't use negative values and don't resize after mapping into [0, 1]
        base_img = cv2.resize(base_img, (config["width"], config["height"]), interpolation=cv2.INTER_CUBIC)

        # Flip the image for augmentation if needed
        if flip:
            base_img = np.fliplr(base_img)

        # Channel transform and range shift
        if config["conditioning_channels"] == 3:
            # Map to the [0, 1] range
            if store_dir is not None and verbose:  # For the first-frame condition visualization
                cv2.imwrite(os.path.join(store_dir, "condition_TT" + str(idx) + ".png"), base_img)
            base_img = base_img / 255.0
        else:
            raise NotImplementedError()

        # Reorganize the shape
        base_img = base_img.transpose(2, 0, 1)  # hwc -> chw

        # Check the min-max value range
        # if verbose:
        #     print("{} min, max range value is {} - {}".format(intput_dir, np.min(base_img), np.max(base_img)))

        # Write base_img to the frame slot given by frame_idx; the other frames stay zero-initialized
        thisthat_condition[frame_idx] = base_img
    ##############################################################################################################################
    if config["motion_bucket_id"] is None:
        # Fall back to the motion bucket statistics collected before
        reflected_motion_bucket_id = 200
    else:
        reflected_motion_bucket_id = config["motion_bucket_id"]
    # print("Motion Bucket ID is ", reflected_motion_bucket_id)

    return (thisthat_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values)
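
# A minimal usage sketch of get_thisthat_sam (not part of the original pipeline). It assumes a
# hypothetical config dict and a folder that contains "im_0.jpg" plus a "data.txt" whose lines
# follow the "<frame_idx> <horizontal> <vertical>" layout parsed above, e.g. "0 128.0 96.0".
# All paths and values below are illustrative only.
#
#   example_config = {
#       "video_seq_length": 14,
#       "conditioning_channels": 3,
#       "height": 256,
#       "width": 384,
#       "dilate": True,
#       "motion_bucket_id": None,
#   }
#   condition, motion_bucket_id, frame_indices, coords = get_thisthat_sam(example_config, "path/to/one_video_folder")
#   # condition has shape (14, 3, 256, 384), non-zero only at the annotated frame indices
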
class Video_ThisThat_Dataset(Dataset):
    '''
        Video dataset that loads sequential frames for training, with the needed pre-processing and the This-That point conditions
    '''
    def __init__(self, config, device, normalize=True, tokenizer=None):

        # Attribute variables
        self.config = config
        self.device = device
        self.normalize = normalize
        self.tokenizer = tokenizer

        # Obtain values
        self.video_seq_length = config["video_seq_length"]
        self.height = config["height"]
        self.width = config["width"]

        # Process data: only keep videos that have a point annotation file
        self.video_lists = []
        for dataset_path in config["dataset_path"]:
            for video_name in sorted(os.listdir(dataset_path)):
                if not os.path.exists(os.path.join(dataset_path, video_name, "data.txt")):
                    continue
                self.video_lists.append(os.path.join(dataset_path, video_name))
        print("length of the dataset is ", len(self.video_lists))


    def __len__(self):
        return len(self.video_lists)
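
    # Assumed on-disk layout for each entry in config["dataset_path"] (inferred from the loaders
    # below; the exact file set of the real Bridge data may differ):
    #
    #   <dataset_path>/<video_name>/
    #       im_0.jpg ... im_{video_seq_length-1}.jpg   # extracted frames
    #       data.txt                                   # "<frame_idx> <horizontal> <vertical>" point lines
    #       lang.txt                                   # language instruction (first line is used)
    #       processed_text.txt                         # optional: action / this / there components for prompt mixing
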
    def _extract_frame_bridge(self, idx, flip=False):
        ''' Extract the video frames at the needed fps from the already extracted frame images
            Args:
                idx (int):      The index of the video folder in the dataset
                flip (bool):    Whether we horizontally flip the frames
            Returns:
                video_frames (numpy): Extracted video frames in numpy format
        '''

        # The naming of the Bridge dataset follows the pattern im_x.jpg, so we read the frames by their index
        video_frame_path = self.video_lists[idx]

        # Find the needed files
        needed_img_path = []
        for frame_idx in range(self.video_seq_length):
            img_path = os.path.join(video_frame_path, "im_" + str(frame_idx) + ".jpg")
            needed_img_path.append(img_path)

        # Read all img_path in order
        video_frames = []
        for img_path in needed_img_path:
            if not os.path.exists(img_path):
                print("We don't have ", img_path)

            frame = cv2.imread(img_path)
            try:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # cv2.imread returns BGR, so convert to RGB
            except Exception:
                print("The exception place is ", img_path)

            # Resize frames
            frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_CUBIC)

            # Flip aug
            if flip:
                frame = np.fliplr(frame)

            # Collect frames (already RGB, so no further conversion is needed)
            video_frames.append(np.expand_dims(frame, axis=0))

        # Concatenate
        video_frames = np.concatenate(video_frames, axis=0)
        assert len(video_frames) == self.video_seq_length

        # Return
        return video_frames
    def __getitem__(self, idx):
        ''' Get an item by idx and pre-process it by resizing and normalizing to [-1, 1]
            Args:
                idx (int): The index of the video folder in the dataset
            Returns:
                return_dict (dict): video_frames (torch.float32) in [-1, 1] and controlnet_condition (torch.float32) in [0, 1]
        '''

        # Prepare the text if needed
        if self.config["use_text"]:
            # Read the file
            file_path = os.path.join(self.video_lists[idx], "lang.txt")
            file = open(file_path, 'r')
            prompt = file.readlines()[0]  # Only read the first line

            if self.config["mix_ambiguous"] and os.path.exists(os.path.join(self.video_lists[idx], "processed_text.txt")):
                # If we don't have this txt file, we skip the mixing
                ######################################################## Mix up prompt ########################################################

                # Read the file
                file_path = os.path.join(self.video_lists[idx], "processed_text.txt")
                file = open(file_path, 'r')
                prompts = file.readlines()  # Read all lines

                # Get the components (strip the trailing newline of each line)
                action = prompts[0][:-1]
                this = prompts[1][:-1]
                there = prompts[2][:-1]

                random_value = random.random()
                # If less than 0.4, we don't mix and just use the most concrete prompt
                if 0.4 <= random_value < 0.6:
                    # Mask the picked object to "this"
                    prompt = action + " this to " + there
                elif 0.6 <= random_value < 0.8:
                    # Mask the place position to "there"
                    prompt = action + " " + this + " to there"
                elif 0.8 <= random_value < 1.0:
                    # Mask both, like "this to there"
                    prompt = action + " this to there"

                # print("New prompt is ", prompt)
                ###################################################################################################################################
            # else:
            #     print("We don't have llama processed prompt at ", self.video_lists[idx])

        else:
            prompt = ""

        # Tokenize the text prompt
        tokenized_prompt = tokenize_captions(prompt, self.tokenizer, self.config)

        # Dataset aug by chance (we need to check whether there is any object position word [left|right] in the prompt text)
        flip = False
        if random.random() < self.config["flip_aug_prob"]:
            if self.config["use_text"]:
                if prompt.find("left") == -1 and prompt.find("right") == -1:  # Cannot have a position word like left or right (up and down are ok)
                    flip = True
            else:
                flip = True
        # Read frames for the target dataset; currently only Bridge is supported by this loader
        if self.config["dataset_name"] == "Bridge":
            video_frames_raw = self._extract_frame_bridge(idx, flip=flip)
        else:
            raise NotImplementedError("We don't support this dataset loader")

        # Scale [0, 255] -> [-1, 1] if needed
        if self.normalize:
            video_frames = video_frames_raw.astype(np.float32) / 127.5 - 1  # Be careful to cast to float32
        else:
            video_frames = video_frames_raw.astype(np.float32)

        # Transform to a PyTorch tensor in the range [-1, 1]
        video_frames = numpy_to_pt(video_frames)

        # Generate the condition pairs we need
        intput_dir = self.video_lists[idx]

        # Get the This-That point information
        controlnet_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values = get_thisthat_sam(self.config, intput_dir, flip=flip)
        controlnet_condition = torch.from_numpy(controlnet_condition)

        # Cast the other values to tensors
        reflected_motion_bucket_id = torch.tensor(reflected_motion_bucket_id, dtype=torch.float32)
        controlnet_image_index = torch.tensor(controlnet_image_index, dtype=torch.int32)
        coordinate_values = torch.tensor(coordinate_values, dtype=torch.int32)

        # The tensors we return are torch.float32 / int32; we don't cast here for mixed-precision training
        return {"video_frames": video_frames,
                "controlnet_condition": controlnet_condition,
                "reflected_motion_bucket_id": reflected_motion_bucket_id,
                "controlnet_image_index": controlnet_image_index,
                "prompt": tokenized_prompt,
                "coordinate_values": coordinate_values,  # Not used for now, but still passed back
                }
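

# A minimal, hypothetical sketch of how this dataset might be wired into a DataLoader.
# The config keys mirror the ones read above; every concrete value (sizes, paths, probabilities)
# is illustrative only, and the tokenizer/device setup depends on the surrounding training code.
#
#   from torch.utils.data import DataLoader
#
#   example_config = {
#       "dataset_name": "Bridge",
#       "dataset_path": ["path/to/bridge_videos"],
#       "video_seq_length": 14,
#       "conditioning_channels": 3,
#       "height": 256,
#       "width": 384,
#       "dilate": True,
#       "motion_bucket_id": None,
#       "use_text": False,
#       "mix_ambiguous": False,
#       "flip_aug_prob": 0.5,
#   }
#   dataset = Video_ThisThat_Dataset(example_config, device="cuda", normalize=True, tokenizer=None)
#   loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
#   batch = next(iter(loader))   # dict with "video_frames", "controlnet_condition", "prompt", ...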