Source code for ptlflow.data.datasets

"""Handle common datasets used in optical flow estimation."""

# =============================================================================
# Copyright 2021 Henrique Morimitsu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import logging
import math
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from ptlflow.utils import flow_utils
from ptlflow.utils.utils import config_logging

config_logging()

THIS_DIR = Path(__file__).resolve().parent


class BaseFlowDataset(Dataset):
    """Manage optical flow dataset loading.

    This class can be used as the parent for any concrete dataset. It is structured to be able to read most types of
    inputs used in optical flow estimation.

    Classes inheriting from this one should implement the __init__() method and properly load the input paths from the
    chosen dataset. This should be done by populating the lists defined in the attributes below.

    Attributes
    ----------
    img_paths : list[list[str]]
        Paths of the images. Each element of the main list is a list of paths. Typically, the inner list will have two
        elements, corresponding to the paths of two consecutive images, which will be used to estimate the optical flow.
        More than two paths can also be added in case the model is able to use more images for estimating the flow.
    flow_paths : list[list[str]]
        Similar structure to img_paths. However, the inner list must have exactly one element less than img_paths.
        For example, if an entry of img_paths is composed of two paths, then an entry of flow_paths should be a list
        with a single path, corresponding to the optical flow from the first image to the second.
    occ_paths : list[list[str]]
        Paths to the occlusion masks, follows the same structure as flow_paths. It can be left empty if not available.
    mb_paths : list[list[str]]
        Paths to the motion boundary masks, follows the same structure as flow_paths. It can be left empty if not available.
    flow_b_paths : list[list[str]]
        The same as flow_paths, but it corresponds to the backward flow. This list must be in the same order as flow_paths.
        For example, flow_b_paths[i] must be the backward flow of flow_paths[i]. It can be left empty if backward flows are
        not available.
    occ_b_paths : list[list[str]]
        Backward occlusion mask paths, read occ_paths and flow_b_paths above.
    mb_b_paths : list[list[str]]
        Backward motion boundary mask paths, read mb_paths and flow_b_paths above.
    metadata : list[Any]
        Some metadata for each input. It can include anything. A good recommendation would be to put a dict with the
        metadata.
    """
    def __init__(
        self,
        dataset_name: str,
        split_name: str = "",
        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
        max_flow: float = 10000.0,
        get_valid_mask: bool = True,
        get_occlusion_mask: bool = True,
        get_motion_boundary_mask: bool = True,
        get_backward: bool = True,
        get_meta: bool = True,
    ) -> None:
        """Initialize BaseFlowDataset.

        Parameters
        ----------
        dataset_name : str
            A string representing the dataset name. It is just used to be stored as metadata, so it can have any value.
        split_name : str, optional
            A string representing the split of the data. It is just used to be stored as metadata, so it can have any value.
        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
            Transform to be applied on the inputs.
        max_flow : float, default 10000.0
            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also
            marked as zero in the valid mask.
        get_valid_mask : bool, default True
            Whether to get or generate valid masks.
        get_occlusion_mask : bool, default True
            Whether to get occlusion masks.
        get_motion_boundary_mask : bool, default True
            Whether to get motion boundary masks.
        get_backward : bool, default True
            Whether to get the backward version of the inputs.
        get_meta : bool, default True
            Whether to get metadata.
        """
        self.dataset_name = dataset_name
        self.split_name = split_name
        self.transform = transform
        self.max_flow = max_flow
        self.get_valid_mask = get_valid_mask
        self.get_occlusion_mask = get_occlusion_mask
        self.get_motion_boundary_mask = get_motion_boundary_mask
        self.get_backward = get_backward
        self.get_meta = get_meta

        self.img_paths = []
        self.flow_paths = []
        self.occ_paths = []
        self.mb_paths = []
        self.flow_b_paths = []
        self.occ_b_paths = []
        self.mb_b_paths = []
        self.metadata = []
    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:  # noqa: C901
        """Retrieve and return one input.

        Parameters
        ----------
        index : int
            The index of the entry on the input lists.

        Returns
        -------
        Dict[str, torch.Tensor]
            The retrieved input. This dict may contain the following keys, depending on the initialization choices:
            ['images', 'flows', 'mbs', 'occs', 'valids', 'flows_b', 'mbs_b', 'occs_b', 'valids_b', 'meta'].
            Except for 'meta', all the values are 4D tensors with shape NCHW. Notice that N does not correspond to the
            batch size, but rather to the number of images of a given key. For example, typically 'images' will have N=2,
            and 'flows' will have N=1, and so on. Therefore, a batch of these inputs will be a 5D tensor BNCHW.
        """
        inputs = {}

        inputs["images"] = [cv2.imread(str(path)) for path in self.img_paths[index]]

        if index < len(self.flow_paths):
            inputs["flows"], valids = self._get_flows_and_valids(self.flow_paths[index])
            if self.get_valid_mask:
                inputs["valids"] = valids

        if self.get_occlusion_mask:
            if index < len(self.occ_paths):
                inputs["occs"] = [
                    cv2.imread(str(path), 0)[:, :, None]
                    for path in self.occ_paths[index]
                ]
            elif self.dataset_name.startswith("KITTI"):
                noc_paths = [
                    str(p).replace("flow_occ", "flow_noc")
                    for p in self.flow_paths[index]
                ]
                _, valids_noc = self._get_flows_and_valids(noc_paths)
                inputs["occs"] = [valids[i] - valids_noc[i] for i in range(len(valids))]

        if self.get_motion_boundary_mask and index < len(self.mb_paths):
            inputs["mbs"] = [
                cv2.imread(str(path), 0)[:, :, None] for path in self.mb_paths[index]
            ]

        if self.get_backward:
            if index < len(self.flow_b_paths):
                inputs["flows_b"], valids_b = self._get_flows_and_valids(
                    self.flow_b_paths[index]
                )
                if self.get_valid_mask:
                    inputs["valids_b"] = valids_b
            if self.get_occlusion_mask and index < len(self.occ_b_paths):
                inputs["occs_b"] = [
                    cv2.imread(str(path), 0)[:, :, None]
                    for path in self.occ_b_paths[index]
                ]
            if self.get_motion_boundary_mask and index < len(self.mb_b_paths):
                inputs["mbs_b"] = [
                    cv2.imread(str(path), 0)[:, :, None]
                    for path in self.mb_b_paths[index]
                ]

        if self.transform is not None:
            inputs = self.transform(inputs)

        if self.get_meta:
            inputs["meta"] = {
                "dataset_name": self.dataset_name,
                "split_name": self.split_name,
            }
            if index < len(self.metadata):
                inputs["meta"].update(self.metadata[index])

        return inputs

    def __len__(self) -> int:
        return len(self.img_paths)

    def _get_flows_and_valids(
        self, flow_paths: Sequence[str]
    ) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]:
        flows = []
        valids = []
        for path in flow_paths:
            flow = flow_utils.flow_read(path)

            nan_mask = np.isnan(flow)
            flow[nan_mask] = self.max_flow + 1

            if self.get_valid_mask:
                valid = (np.abs(flow) < self.max_flow).astype(np.uint8) * 255
                valid = np.minimum(valid[:, :, 0], valid[:, :, 1])
                valids.append(valid[:, :, None])

            flow[nan_mask] = 0
            flow = np.clip(flow, -self.max_flow, self.max_flow)
            flows.append(flow)
        return flows, valids

    def _log_status(self) -> None:
        if self.__len__() == 0:
            logging.warning(
                "No samples were found for %s dataset. Be sure to update the dataset path in datasets.yml, "
                "or provide the path by the argument --[dataset_name]_root_dir.",
                self.dataset_name,
            )
        else:
            logging.info(
                "Loading %d samples from %s dataset.", self.__len__(), self.dataset_name
            )

    def _extend_paths_list(
        self,
        paths_list: List[Union[str, Path]],
        sequence_length: int,
        sequence_position: str,
    ):
        if sequence_position == "first":
            begin_pad = 0
            end_pad = sequence_length - 2
        elif sequence_position == "middle":
            begin_pad = sequence_length // 2
            end_pad = int(math.ceil(sequence_length / 2.0)) - 2
        elif sequence_position == "last":
            begin_pad = sequence_length - 2
            end_pad = 0
        else:
            raise ValueError(
                f"Invalid sequence_position. Must be one of ('first', 'middle', 'last'). Received: {sequence_position}"
            )
        for _ in range(begin_pad):
            paths_list.insert(0, paths_list[0])
        for _ in range(end_pad):
            paths_list.append(paths_list[-1])
        return paths_list
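

# --- Illustrative sketch (not part of the original module) ---
# A minimal example of how a concrete dataset can build on BaseFlowDataset.
# The folder layout ("imgs/*.png", "flows/*.flo") and the class name are
# hypothetical assumptions for illustration only.
class _ExampleTinyFlowDataset(BaseFlowDataset):
    def __init__(self, root_dir: str) -> None:
        super().__init__(
            dataset_name="ExampleTiny",
            split_name="trainval",
            get_occlusion_mask=False,
            get_motion_boundary_mask=False,
            get_backward=False,
        )
        frame_paths = sorted((Path(root_dir) / "imgs").glob("*.png"))
        # Pair consecutive frames; each pair has exactly one forward flow.
        self.img_paths = [
            [frame_paths[i], frame_paths[i + 1]] for i in range(len(frame_paths) - 1)
        ]
        self.flow_paths = [
            [Path(root_dir) / "flows" / f"{pair[0].stem}.flo"]
            for pair in self.img_paths
        ]
        self.metadata = [
            {"image_paths": [str(p) for p in pair], "is_val": False}
            for pair in self.img_paths
        ]
        self._log_status()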
class AutoFlowDataset(BaseFlowDataset):
    """Handle the AutoFlow dataset."""
    def __init__(
        self,
        root_dir: str,
        split: str = "train",
        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
        max_flow: float = 10000.0,
        get_valid_mask: bool = True,
        get_meta: bool = True,
    ) -> None:
        """Initialize AutoFlowDataset.

        Parameters
        ----------
        root_dir : str
            path to the root directory of the AutoFlow dataset.
        split : str, default 'train'
            Which split of the dataset should be loaded. It can be one of {'train', 'val', 'trainval'}.
        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
            Transform to be applied on the inputs.
        max_flow : float, default 10000.0
            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also
            marked as zero in the valid mask.
        get_valid_mask : bool, default True
            Whether to get or generate valid masks.
        get_meta : bool, default True
            Whether to get metadata.
        """
        super().__init__(
            dataset_name="AutoFlow",
            split_name=split,
            transform=transform,
            max_flow=max_flow,
            get_valid_mask=get_valid_mask,
            get_occlusion_mask=False,
            get_motion_boundary_mask=False,
            get_backward=False,
            get_meta=get_meta,
        )
        self.root_dir = root_dir
        self.split_file = THIS_DIR / "AutoFlow_val.txt"

        # Read data from disk
        parts_dirs = [f"static_40k_png_{i+1}_of_4" for i in range(4)]
        sample_paths = []
        for pdir in parts_dirs:
            sample_paths.extend(
                [p for p in (Path(root_dir) / pdir).glob("*") if p.is_dir()]
            )

        with open(self.split_file, "r") as f:
            val_names = f.read().strip().splitlines()

        if split == "trainval":
            remove_names = []
        elif split == "train":
            remove_names = val_names
        elif split == "val":
            remove_names = [p.stem for p in sample_paths if p.stem not in val_names]

        # Keep only data from the correct split
        self.img_paths = [
            [p / "im0.png", p / "im1.png"]
            for p in sample_paths
            if p.stem not in remove_names
        ]
        self.flow_paths = [
            [p / "forward.flo"] for p in sample_paths if p.stem not in remove_names
        ]
        assert len(self.img_paths) == len(
            self.flow_paths
        ), f"{len(self.img_paths)} vs {len(self.flow_paths)}"

        self.metadata = [
            {
                "image_paths": [str(p) for p in paths],
                "is_val": paths[0].stem in val_names,
                "misc": "",
                "is_seq_start": True,
            }
            for paths in self.img_paths
        ]

        self._log_status()
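

# --- Illustrative sketch (not part of the original module) ---
# Inspecting one AutoFlow sample; the root path is hypothetical. Without a
# transform, each key holds a list of HWC numpy arrays rather than stacked
# tensors.
def _example_autoflow_usage():
    dataset = AutoFlowDataset(root_dir="/datasets/AutoFlow", split="train")
    sample = dataset[0]
    print(len(sample["images"]), sample["images"][0].shape)  # 2 (H, W, 3)
    print(len(sample["flows"]), sample["flows"][0].shape)  # 1 (H, W, 2)
    print(sample["meta"]["dataset_name"])  # AutoFlow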
class FlyingChairsDataset(BaseFlowDataset):
    """Handle the FlyingChairs dataset."""
    def __init__(
        self,
        root_dir: str,
        split: str = "train",
        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
        max_flow: float = 10000.0,
        get_valid_mask: bool = True,
        get_meta: bool = True,
    ) -> None:
        """Initialize FlyingChairsDataset.

        Parameters
        ----------
        root_dir : str
            path to the root directory of the FlyingChairs dataset.
        split : str, default 'train'
            Which split of the dataset should be loaded. It can be one of {'train', 'val', 'trainval'}.
        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
            Transform to be applied on the inputs.
        max_flow : float, default 10000.0
            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also
            marked as zero in the valid mask.
        get_valid_mask : bool, default True
            Whether to get or generate valid masks.
        get_meta : bool, default True
            Whether to get metadata.
        """
        super().__init__(
            dataset_name="FlyingChairs",
            split_name=split,
            transform=transform,
            max_flow=max_flow,
            get_valid_mask=get_valid_mask,
            get_occlusion_mask=False,
            get_motion_boundary_mask=False,
            get_backward=False,
            get_meta=get_meta,
        )
        self.root_dir = root_dir
        self.split_file = THIS_DIR / "FlyingChairs_val.txt"

        # Read data from disk
        img1_paths = sorted((Path(self.root_dir) / "data").glob("*img1.ppm"))
        img2_paths = sorted((Path(self.root_dir) / "data").glob("*img2.ppm"))
        flow_paths = sorted((Path(self.root_dir) / "data").glob("*flow.flo"))

        # Sanity check
        assert len(img1_paths) == len(
            img2_paths
        ), f"{len(img1_paths)} vs {len(img2_paths)}"
        assert len(img1_paths) == len(
            flow_paths
        ), f"{len(img1_paths)} vs {len(flow_paths)}"

        with open(self.split_file, "r") as f:
            val_names = f.read().strip().splitlines()

        if split == "trainval":
            remove_names = []
        elif split == "train":
            remove_names = val_names
        elif split == "val":
            remove_names = [
                p.stem.split("_")[0]
                for p in img1_paths
                if p.stem.split("_")[0] not in val_names
            ]

        # Keep only data from the correct split
        self.img_paths = [
            [img1_paths[i], img2_paths[i]]
            for i in range(len(img1_paths))
            if img1_paths[i].stem.split("_")[0] not in remove_names
        ]
        self.flow_paths = [
            [flow_paths[i]]
            for i in range(len(flow_paths))
            if flow_paths[i].stem.split("_")[0] not in remove_names
        ]
        assert len(self.img_paths) == len(
            self.flow_paths
        ), f"{len(self.img_paths)} vs {len(self.flow_paths)}"

        self.metadata = [
            {
                "image_paths": [str(p) for p in paths],
                "is_val": paths[0].stem in val_names,
                "misc": "",
                "is_seq_start": True,
            }
            for paths in self.img_paths
        ]

        self._log_status()
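

# --- Illustrative sketch (not part of the original module) ---
# Batching: once a transform stacks each key into an NCHW tensor, a standard
# DataLoader produces the 5D BNCHW batches described in __getitem__. The root
# path is hypothetical, and ToTensor from ptlflow.data.flow_transforms is
# assumed to be available.
def _example_flying_chairs_batching():
    from torch.utils.data import DataLoader

    from ptlflow.data.flow_transforms import ToTensor  # assumed import

    dataset = FlyingChairsDataset(
        root_dir="/datasets/FlyingChairs_release", transform=ToTensor()
    )
    loader = DataLoader(dataset, batch_size=4)
    batch = next(iter(loader))
    print(batch["images"].shape)  # torch.Size([4, 2, 3, H, W]) -> BNCHW
    print(batch["flows"].shape)  # torch.Size([4, 1, 2, H, W])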
class FlyingChairs2Dataset(BaseFlowDataset):
    """Handle the FlyingChairs 2 dataset."""
    def __init__(
        self,
        root_dir: str,
        split: str = "train",
        add_reverse: bool = True,
        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
        max_flow: float = 1000.0,
        get_valid_mask: bool = True,
        get_occlusion_mask: bool = True,
        get_motion_boundary_mask: bool = True,
        get_backward: bool = True,
        get_meta: bool = True,
    ) -> None:
        """Initialize FlyingChairs2Dataset.

        Parameters
        ----------
        root_dir : str
            path to the root directory of the FlyingChairs2 dataset.
        split : str, default 'train'
            Which split of the dataset should be loaded. It can be one of {'train', 'val', 'trainval'}.
        add_reverse : bool, default True
            If True, double the number of samples by appending the backward samples as additional samples.
        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
            Transform to be applied on the inputs.
        max_flow : float, default 1000.0
            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also
            marked as zero in the valid mask.
        get_valid_mask : bool, default True
            Whether to get or generate valid masks.
        get_occlusion_mask : bool, default True
            Whether to get occlusion masks.
        get_motion_boundary_mask : bool, default True
            Whether to get motion boundary masks.
        get_backward : bool, default True
            Whether to get the backward version of the inputs.
        get_meta : bool, default True
            Whether to get metadata.
        """
        super().__init__(
            dataset_name="FlyingChairs2",
            split_name=split,
            transform=transform,
            max_flow=max_flow,
            get_valid_mask=get_valid_mask,
            get_occlusion_mask=get_occlusion_mask,
            get_motion_boundary_mask=get_motion_boundary_mask,
            get_backward=get_backward,
            get_meta=get_meta,
        )
        self.root_dir = root_dir
        self.add_reverse = add_reverse

        if split == "train":
            dir_names = ["train"]
        elif split == "val":
            dir_names = ["val"]
        else:
            dir_names = ["train", "val"]

        for dname in dir_names:
            # Read data from disk
            img1_paths = sorted((Path(self.root_dir) / dname).glob("*img_0.png"))
            img2_paths = sorted((Path(self.root_dir) / dname).glob("*img_1.png"))
            self.img_paths.extend(
                [[img1_paths[i], img2_paths[i]] for i in range(len(img1_paths))]
            )
            self.flow_paths.extend(
                [
                    [x]
                    for x in sorted((Path(self.root_dir) / dname).glob("*flow_01.flo"))
                ]
            )
            self.occ_paths.extend(
                [[x] for x in sorted((Path(self.root_dir) / dname).glob("*occ_01.png"))]
            )
            self.mb_paths.extend(
                [[x] for x in sorted((Path(self.root_dir) / dname).glob("*mb_01.png"))]
            )
            if self.get_backward:
                self.flow_b_paths.extend(
                    [
                        [x]
                        for x in sorted(
                            (Path(self.root_dir) / dname).glob("*flow_10.flo")
                        )
                    ]
                )
                self.occ_b_paths.extend(
                    [
                        [x]
                        for x in sorted(
                            (Path(self.root_dir) / dname).glob("*occ_10.png")
                        )
                    ]
                )
                self.mb_b_paths.extend(
                    [
                        [x]
                        for x in sorted(
                            (Path(self.root_dir) / dname).glob("*mb_10.png")
                        )
                    ]
                )
            if self.add_reverse:
                self.img_paths.extend(
                    [[img2_paths[i], img1_paths[i]] for i in range(len(img1_paths))]
                )
                self.flow_paths.extend(
                    [
                        [x]
                        for x in sorted(
                            (Path(self.root_dir) / dname).glob("*flow_10.flo")
                        )
                    ]
                )
                self.occ_paths.extend(
                    [
                        [x]
                        for x in sorted(
                            (Path(self.root_dir) / dname).glob("*occ_10.png")
                        )
                    ]
                )
                self.mb_paths.extend(
                    [
                        [x]
                        for x in sorted(
                            (Path(self.root_dir) / dname).glob("*mb_10.png")
                        )
                    ]
                )
                if self.get_backward:
                    self.flow_b_paths.extend(
                        [
                            [x]
                            for x in sorted(
                                (Path(self.root_dir) / dname).glob("*flow_01.flo")
                            )
                        ]
                    )
                    self.occ_b_paths.extend(
                        [
                            [x]
                            for x in sorted(
                                (Path(self.root_dir) / dname).glob("*occ_01.png")
                            )
                        ]
                    )
                    self.mb_b_paths.extend(
                        [
                            [x]
                            for x in sorted(
                                (Path(self.root_dir) / dname).glob("*mb_01.png")
                            )
                        ]
                    )

        self.metadata = [
            {
                "image_paths": [str(p) for p in paths],
                "is_val": False,
                "misc": "",
                "is_seq_start": True,
            }
            for paths in self.img_paths
        ]

        # Sanity check
        assert len(img1_paths) == len(
            img2_paths
        ), f"{len(img1_paths)} vs {len(img2_paths)}"
        assert len(self.img_paths) == len(
            self.flow_paths
        ), f"{len(self.img_paths)} vs {len(self.flow_paths)}"
        assert len(self.img_paths) == len(
            self.occ_paths
        ), f"{len(self.img_paths)} vs {len(self.occ_paths)}"
        assert len(self.img_paths) == len(
            self.mb_paths
        ), f"{len(self.img_paths)} vs {len(self.mb_paths)}"
        if self.get_backward:
            assert len(self.img_paths) == len(
                self.flow_b_paths
            ), f"{len(self.img_paths)} vs {len(self.flow_b_paths)}"
            assert len(self.img_paths) == len(
                self.occ_b_paths
            ), f"{len(self.img_paths)} vs {len(self.occ_b_paths)}"
            assert len(self.img_paths) == len(
                self.mb_b_paths
            ), f"{len(self.img_paths)} vs {len(self.mb_b_paths)}"

        self._log_status()
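

# --- Illustrative sketch (not part of the original module) ---
# With add_reverse, each (img1, img2) pair is also appended as (img2, img1);
# with get_backward, every sample additionally carries "flows_b", "occs_b",
# and "mbs_b". The root path is hypothetical.
def _example_flying_chairs2_usage():
    dataset = FlyingChairs2Dataset(
        root_dir="/datasets/FlyingChairs2",
        split="train",
        add_reverse=True,
        get_backward=True,
    )
    sample = dataset[0]
    print(sorted(sample.keys()))
    # ['flows', 'flows_b', 'images', 'mbs', 'mbs_b', 'meta', 'occs', 'occs_b',
    #  'valids', 'valids_b']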
class FlyingThings3DDataset(BaseFlowDataset):
    """Handle the FlyingThings3D dataset.

    Note that this only works for the complete FlyingThings3D dataset. For the subset version, use
    FlyingThings3DSubsetDataset.
    """
    def __init__(  # noqa: C901
        self,
        root_dir: str,
        split: str = "train",
        pass_names: Union[str, List[str]] = "clean",
        side_names: Union[str, List[str]] = "left",
        add_reverse: bool = True,
        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
        max_flow: float = 1000.0,
        get_valid_mask: bool = True,
        get_occlusion_mask: bool = True,
        get_motion_boundary_mask: bool = True,
        get_backward: bool = True,
        get_meta: bool = True,
        sequence_length: int = 2,
        sequence_position: str = "first",
    ) -> None:
        """Initialize FlyingThings3DDataset.

        Parameters
        ----------
        root_dir : str
            path to the root directory of the FlyingThings3D dataset.
        split : str, default 'train'
            Which split of the dataset should be loaded. It can be one of {'train', 'val', 'trainval'}.
        pass_names : Union[str, List[str]], default 'clean'
            Which passes should be loaded. It can be one of {'clean', 'final', ['clean', 'final']}.
        side_names : Union[str, List[str]], default 'left'
            Samples from which side view should be loaded. It can be one of {'left', 'right', ['left', 'right']}.
        add_reverse : bool, default True
            If True, double the number of samples by appending the backward samples as additional samples.
        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
            Transform to be applied on the inputs.
        max_flow : float, default 1000.0
            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also
            marked as zero in the valid mask.
        get_valid_mask : bool, default True
            Whether to get or generate valid masks.
        get_occlusion_mask : bool, default True
            Whether to get occlusion masks.
        get_motion_boundary_mask : bool, default True
            Whether to get motion boundary masks.
        get_backward : bool, default True
            Whether to get the backward version of the inputs.
        get_meta : bool, default True
            Whether to get metadata.
        sequence_length : int, default 2
            How many consecutive images are loaded per sample. More than two images can be used for models which
            exploit more temporal information.
        sequence_position : str, default "first"
            Only used when sequence_length > 2.
            Determines the position where the main image frame will be in the sequence. It can be one of three values:
            - "first": the main frame will be the first one of the sequence,
            - "middle": the main frame will be in the middle of the sequence (at position sequence_length // 2),
            - "last": the main frame will be the penultimate in the sequence.
        """
        super().__init__(
            dataset_name="FlyingThings3D",
            split_name=split,
            transform=transform,
            max_flow=max_flow,
            get_valid_mask=get_valid_mask,
            get_occlusion_mask=get_occlusion_mask,
            get_motion_boundary_mask=get_motion_boundary_mask,
            get_backward=get_backward,
            get_meta=get_meta,
        )
        self.root_dir = root_dir
        self.add_reverse = add_reverse
        self.pass_names = pass_names
        self.sequence_length = sequence_length
        self.sequence_position = sequence_position
        if isinstance(self.pass_names, str):
            self.pass_names = [self.pass_names]
        self.side_names = side_names
        if isinstance(self.side_names, str):
            self.side_names = [self.side_names]

        if split == "val":
            split_dir_names = ["TEST"]
        elif split == "train":
            split_dir_names = ["TRAIN"]
        else:
            split_dir_names = ["TRAIN", "TEST"]

        pass_dirs = [f"frames_{p}pass" for p in self.pass_names]

        directions = [("into_future", "into_past")]
        reverts = [False]
        if self.add_reverse:
            directions.append(("into_past", "into_future"))
            reverts.append(True)

        # Read paths from disk
        for passd in pass_dirs:
            for split in split_dir_names:
                split_path = Path(self.root_dir) / passd / split
                for letter_path in split_path.glob("*"):
                    for seq_path in letter_path.glob("*"):
                        for direcs, rev in zip(directions, reverts):
                            for side in self.side_names:
                                image_paths = sorted(
                                    (seq_path / side).glob("*.png"), reverse=rev
                                )
                                image_paths = self._extend_paths_list(
                                    image_paths, sequence_length, sequence_position
                                )

                                flow_paths = sorted(
                                    (
                                        Path(
                                            str(seq_path).replace(passd, "optical_flow")
                                        )
                                        / direcs[0]
                                        / side
                                    ).glob("*.pfm"),
                                    reverse=rev,
                                )
                                flow_paths = self._extend_paths_list(
                                    flow_paths, sequence_length, sequence_position
                                )

                                occ_paths = []
                                if (Path(self.root_dir) / "occlusions").exists():
                                    occ_paths = sorted(
                                        (
                                            Path(
                                                str(seq_path).replace(
                                                    passd, "occlusions"
                                                )
                                            )
                                            / direcs[0]
                                            / side
                                        ).glob("*.png"),
                                        reverse=rev,
                                    )
                                    occ_paths = self._extend_paths_list(
                                        occ_paths, sequence_length, sequence_position
                                    )
                                mb_paths = []
                                if (Path(self.root_dir) / "motion_boundaries").exists():
                                    mb_paths = sorted(
                                        (
                                            Path(
                                                str(seq_path).replace(
                                                    passd, "motion_boundaries"
                                                )
                                            )
                                            / direcs[0]
                                            / side
                                        ).glob("*.png"),
                                        reverse=rev,
                                    )
                                    mb_paths = self._extend_paths_list(
                                        mb_paths, sequence_length, sequence_position
                                    )

                                flow_b_paths = []
                                occ_b_paths = []
                                mb_b_paths = []
                                if self.get_backward:
                                    flow_b_paths = sorted(
                                        (
                                            Path(
                                                str(seq_path).replace(
                                                    passd, "optical_flow"
                                                )
                                            )
                                            / direcs[1]
                                            / side
                                        ).glob("*.pfm"),
                                        reverse=rev,
                                    )
                                    flow_b_paths = self._extend_paths_list(
                                        flow_b_paths, sequence_length, sequence_position
                                    )
                                    if (Path(self.root_dir) / "occlusions").exists():
                                        occ_b_paths = sorted(
                                            (
                                                Path(
                                                    str(seq_path).replace(
                                                        passd, "occlusions"
                                                    )
                                                )
                                                / direcs[1]
                                                / side
                                            ).glob("*.png"),
                                            reverse=rev,
                                        )
                                        occ_b_paths = self._extend_paths_list(
                                            occ_b_paths,
                                            sequence_length,
                                            sequence_position,
                                        )
                                    if (
                                        Path(self.root_dir) / "motion_boundaries"
                                    ).exists():
                                        mb_b_paths = sorted(
                                            (
                                                Path(
                                                    str(seq_path).replace(
                                                        passd, "motion_boundaries"
                                                    )
                                                )
                                                / direcs[1]
                                                / side
                                            ).glob("*.png"),
                                            reverse=rev,
                                        )
                                        mb_b_paths = self._extend_paths_list(
                                            mb_b_paths,
                                            sequence_length,
                                            sequence_position,
                                        )

                                for i in range(
                                    len(image_paths) - self.sequence_length + 1
                                ):
                                    self.img_paths.append(
                                        image_paths[i : i + self.sequence_length]
                                    )
                                    if len(flow_paths) > 0:
                                        self.flow_paths.append(
                                            flow_paths[i : i + self.sequence_length - 1]
                                        )
                                    if len(occ_paths) > 0:
                                        self.occ_paths.append(
                                            occ_paths[i : i + self.sequence_length - 1]
                                        )
                                    if len(mb_paths) > 0:
                                        self.mb_paths.append(
                                            mb_paths[i : i + self.sequence_length - 1]
                                        )
                                    self.metadata.append(
                                        {
                                            "image_paths": [
                                                str(p)
                                                for p in image_paths[
                                                    i : i + self.sequence_length
                                                ]
                                            ],
                                            "is_val": False,
                                            "misc": "",
                                            "is_seq_start": i == 0,
                                        }
                                    )
                                    if self.get_backward:
                                        if len(flow_b_paths) > 0:
                                            self.flow_b_paths.append(
                                                flow_b_paths[
                                                    i + 1 : i + self.sequence_length
                                                ]
                                            )
                                        if len(occ_b_paths) > 0:
                                            self.occ_b_paths.append(
                                                occ_b_paths[
                                                    i + 1 : i + self.sequence_length
                                                ]
                                            )
                                        if len(mb_b_paths) > 0:
                                            self.mb_b_paths.append(
                                                mb_b_paths[
                                                    i + 1 : i + self.sequence_length
                                                ]
                                            )

        assert len(self.img_paths) == len(
            self.flow_paths
        ), f"{len(self.img_paths)} vs {len(self.flow_paths)}"
        assert len(self.occ_paths) == 0 or len(self.img_paths) == len(
            self.occ_paths
        ), f"{len(self.img_paths)} vs {len(self.occ_paths)}"
        assert len(self.mb_paths) == 0 or len(self.img_paths) == len(
            self.mb_paths
        ), f"{len(self.img_paths)} vs {len(self.mb_paths)}"
        if self.get_backward:
            assert len(self.img_paths) == len(
                self.flow_b_paths
            ), f"{len(self.img_paths)} vs {len(self.flow_b_paths)}"
            assert len(self.occ_b_paths) == 0 or len(self.img_paths) == len(
                self.occ_b_paths
            ), f"{len(self.img_paths)} vs {len(self.occ_b_paths)}"
            assert len(self.mb_b_paths) == 0 or len(self.img_paths) == len(
                self.mb_b_paths
            ), f"{len(self.img_paths)} vs {len(self.mb_b_paths)}"

        self._log_status()
class FlyingThings3DSubsetDataset(BaseFlowDataset):
    """Handle the FlyingThings3D subset dataset.

    Note that this only works for the FlyingThings3D subset dataset. For the complete version, use
    FlyingThings3DDataset.
    """
    def __init__(  # noqa: C901
        self,
        root_dir: str,
        split: str = "train",
        pass_names: Union[str, List[str]] = "clean",
        side_names: Union[str, List[str]] = "left",
        add_reverse: bool = True,
        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
        max_flow: float = 1000.0,
        get_valid_mask: bool = True,
        get_occlusion_mask: bool = True,
        get_motion_boundary_mask: bool = True,
        get_backward: bool = True,
        get_meta: bool = True,
        sequence_length: int = 2,
        sequence_position: str = "first",
    ) -> None:
        """Initialize FlyingThings3DSubsetDataset.

        Parameters
        ----------
        root_dir : str
            path to the root directory of the FlyingThings3D dataset.
        split : str, default 'train'
            Which split of the dataset should be loaded. It can be one of {'train', 'val', 'trainval'}.
        pass_names : Union[str, List[str]], default 'clean'
            Which passes should be loaded. It can be one of {'clean', 'final', ['clean', 'final']}.
        side_names : Union[str, List[str]], default 'left'
            Samples from which side view should be loaded. It can be one of {'left', 'right', ['left', 'right']}.
        add_reverse : bool, default True
            If True, double the number of samples by appending the backward samples as additional samples.
        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
            Transform to be applied on the inputs.
        max_flow : float, default 1000.0
            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also
            marked as zero in the valid mask.
        get_valid_mask : bool, default True
            Whether to get or generate valid masks.
        get_occlusion_mask : bool, default True
            Whether to get occlusion masks.
        get_motion_boundary_mask : bool, default True
            Whether to get motion boundary masks.
        get_backward : bool, default True
            Whether to get the backward version of the inputs.
        get_meta : bool, default True
            Whether to get metadata.
        sequence_length : int, default 2
            How many consecutive images are loaded per sample. More than two images can be used for models which
            exploit more temporal information.
        sequence_position : str, default "first"
            Only used when sequence_length > 2.
            Determines the position where the main image frame will be in the sequence. It can be one of three values:
            - "first": the main frame will be the first one of the sequence,
            - "middle": the main frame will be in the middle of the sequence (at position sequence_length // 2),
            - "last": the main frame will be the penultimate in the sequence.
        """
        super().__init__(
            dataset_name="FlyingThings3DSubset",
            split_name=split,
            transform=transform,
            max_flow=max_flow,
            get_valid_mask=get_valid_mask,
            get_occlusion_mask=get_occlusion_mask,
            get_motion_boundary_mask=get_motion_boundary_mask,
            get_backward=get_backward,
            get_meta=get_meta,
        )
        self.root_dir = root_dir
        self.add_reverse = add_reverse
        self.pass_names = pass_names
        self.sequence_length = sequence_length
        self.sequence_position = sequence_position
        if isinstance(self.pass_names, str):
            self.pass_names = [self.pass_names]
        self.side_names = side_names
        if isinstance(self.side_names, str):
            self.side_names = [self.side_names]

        if split == "train" or split == "val":
            split_dir_names = [split]
        else:
            split_dir_names = ["train", "val"]

        directions = [("into_future", "into_past")]
        reverts = [False]
        if self.add_reverse:
            directions.append(("into_past", "into_future"))
            reverts.append(True)

        # Read paths from disk
        for split in split_dir_names:
            for pass_name in self.pass_names:
                for side in self.side_names:
                    for direcs, rev in zip(directions, reverts):
                        flow_dir = (
                            Path(self.root_dir) / split / "flow" / side / direcs[0]
                        )
                        flow_paths = sorted(flow_dir.glob("*.flo"), reverse=rev)

                        # Create groups to separate different sequences
                        flow_groups_paths = [[flow_paths[0]]]
                        prev_idx = int(flow_paths[0].stem)
                        for path in flow_paths[1:]:
                            idx = int(path.stem)
                            if (idx - 1) == prev_idx:
                                flow_groups_paths[-1].append(path)
                            else:
                                flow_groups_paths.append([path])
                            prev_idx = idx

                        for flow_group in flow_groups_paths:
                            flow_group = self._extend_paths_list(
                                flow_group, sequence_length, sequence_position
                            )
                            for i in range(len(flow_group) - self.sequence_length + 2):
                                flow_paths = flow_group[
                                    i : i + self.sequence_length - 1
                                ]
                                self.flow_paths.append(flow_paths)

                                img_dir = (
                                    Path(self.root_dir)
                                    / split
                                    / f"image_{pass_name}"
                                    / side
                                )
                                img_paths = [
                                    img_dir / (fp.stem + ".png") for fp in flow_paths
                                ]
                                if rev:
                                    idx = int(img_paths[0].stem) - 1
                                else:
                                    idx = int(img_paths[-1].stem) + 1
                                img_paths.append(img_dir / f"{idx:07d}.png")
                                self.img_paths.append(img_paths)

                                if (
                                    Path(self.root_dir) / split / "flow_occlusions"
                                ).exists():
                                    occ_paths = [
                                        Path(
                                            str(fp)
                                            .replace("flow", "flow_occlusions")
                                            .replace(".flo", ".png")
                                        )
                                        for fp in flow_paths
                                    ]
                                    self.occ_paths.append(occ_paths)
                                if (
                                    Path(self.root_dir) / split / "motion_boundaries"
                                ).exists():
                                    mb_paths = [
                                        Path(
                                            str(fp)
                                            .replace("flow", "motion_boundaries")
                                            .replace(".flo", ".png")
                                        )
                                        for fp in flow_paths
                                    ]
                                    self.mb_paths.append(mb_paths)

                                self.metadata.append(
                                    {
                                        "image_paths": [str(p) for p in img_paths],
                                        "is_val": False,
                                        "misc": "",
                                        "is_seq_start": i == 0,
                                    }
                                )

                        if self.get_backward:
                            flow_dir = (
                                Path(self.root_dir) / split / "flow" / side / direcs[1]
                            )
                            flow_paths = sorted(flow_dir.glob("*.flo"), reverse=rev)

                            # Create groups to separate different sequences
                            flow_groups_paths = [[flow_paths[0]]]
                            prev_idx = int(flow_paths[0].stem)
                            for path in flow_paths[1:]:
                                idx = int(path.stem)
                                if (idx - 1) == prev_idx:
                                    flow_groups_paths[-1].append(path)
                                else:
                                    flow_groups_paths.append([path])
                                prev_idx = idx

                            for flow_group in flow_groups_paths:
                                flow_group = self._extend_paths_list(
                                    flow_group, sequence_length, sequence_position
                                )
                                for i in range(
                                    len(flow_group) - self.sequence_length + 2
                                ):
                                    flow_paths = flow_group[
                                        i : i + self.sequence_length - 1
                                    ]
                                    self.flow_b_paths.append(flow_paths)

                                    if (
                                        Path(self.root_dir) / split / "flow_occlusions"
                                    ).exists():
                                        occ_paths = [
                                            Path(
                                                str(fp)
                                                .replace("flow", "flow_occlusions")
                                                .replace(".flo", ".png")
                                            )
                                            for fp in flow_paths
                                        ]
                                        self.occ_b_paths.append(occ_paths)
                                    if (
                                        Path(self.root_dir)
                                        / split
                                        / "motion_boundaries"
                                    ).exists():
                                        mb_paths = [
                                            Path(
                                                str(fp)
                                                .replace("flow", "motion_boundaries")
                                                .replace(".flo", ".png")
                                            )
                                            for fp in flow_paths
                                        ]
                                        self.mb_b_paths.append(mb_paths)

        assert len(self.img_paths) == len(
            self.flow_paths
        ), f"{len(self.img_paths)} vs {len(self.flow_paths)}"
        assert len(self.occ_paths) == 0 or len(self.img_paths) == len(
            self.occ_paths
        ), f"{len(self.img_paths)} vs {len(self.occ_paths)}"
        assert len(self.mb_paths) == 0 or len(self.img_paths) == len(
            self.mb_paths
        ), f"{len(self.img_paths)} vs {len(self.mb_paths)}"
        if self.get_backward:
            assert len(self.img_paths) == len(
                self.flow_b_paths
            ), f"{len(self.img_paths)} vs {len(self.flow_b_paths)}"
            assert len(self.occ_b_paths) == 0 or len(self.img_paths) == len(
                self.occ_b_paths
            ), f"{len(self.img_paths)} vs {len(self.occ_b_paths)}"
            assert len(self.mb_b_paths) == 0 or len(self.img_paths) == len(
                self.mb_b_paths
            ), f"{len(self.img_paths)} vs {len(self.mb_b_paths)}"

        self._log_status()
class Hd1kDataset(BaseFlowDataset):
    """Handle the HD1K dataset."""
    def __init__(  # noqa: C901
        self,
        root_dir: str,
        split: str = "train",
        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
        max_flow: float = 512.0,
        get_valid_mask: bool = True,
        get_meta: bool = True,
        sequence_length: int = 2,
        sequence_position: str = "first",
    ) -> None:
        """Initialize Hd1kDataset.

        Parameters
        ----------
        root_dir : str
            path to the root directory of the HD1K dataset.
        split : str, default 'train'
            Which split of the dataset should be loaded. It can be one of {'train', 'val', 'trainval', 'test'}.
        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
            Transform to be applied on the inputs.
        max_flow : float, default 512.0
            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also
            marked as zero in the valid mask.
        get_valid_mask : bool, default True
            Whether to get or generate valid masks.
        get_meta : bool, default True
            Whether to get metadata.
        sequence_length : int, default 2
            How many consecutive images are loaded per sample. More than two images can be used for models which
            exploit more temporal information.
        sequence_position : str, default "first"
            Only used when sequence_length > 2.
            Determines the position where the main image frame will be in the sequence. It can be one of three values:
            - "first": the main frame will be the first one of the sequence,
            - "middle": the main frame will be in the middle of the sequence (at position sequence_length // 2),
            - "last": the main frame will be the penultimate in the sequence.
        """
        super().__init__(
            dataset_name="HD1K",
            split_name=split,
            transform=transform,
            max_flow=max_flow,
            get_valid_mask=get_valid_mask,
            get_occlusion_mask=False,
            get_motion_boundary_mask=False,
            get_backward=False,
            get_meta=get_meta,
        )
        self.root_dir = root_dir
        self.split = split
        self.sequence_length = sequence_length
        self.sequence_position = sequence_position

        if split == "test":
            split_dir = "hd1k_challenge"
        else:
            split_dir = "hd1k_input"

        img_paths = sorted((Path(root_dir) / split_dir / "image_2").glob("*.png"))
        img_names = [p.stem for p in img_paths]

        # Group paths by sequence
        img_names_grouped = {}
        for n in img_names:
            seq_name = n.split("_")[0]
            if img_names_grouped.get(seq_name) is None:
                img_names_grouped[seq_name] = []
            img_names_grouped[seq_name].append(n)

        val_names = []
        split_file = THIS_DIR / "Hd1k_val.txt"
        with open(split_file, "r") as f:
            val_names = f.read().strip().splitlines()

        # Remove names that do not belong to the chosen split
        for seq_name, seq_img_names in img_names_grouped.items():
            if split == "train":
                img_names_grouped[seq_name] = [
                    n for n in seq_img_names if n not in val_names
                ]
            elif split == "val":
                img_names_grouped[seq_name] = [
                    n for n in seq_img_names if n in val_names
                ]

        for seq_img_names in img_names_grouped.values():
            seq_img_names = self._extend_paths_list(
                seq_img_names, sequence_length, sequence_position
            )
            for i in range(len(seq_img_names) - self.sequence_length + 1):
                self.img_paths.append(
                    [
                        Path(root_dir) / split_dir / "image_2" / (n + ".png")
                        for n in seq_img_names[i : i + self.sequence_length]
                    ]
                )
                if split != "test":
                    self.flow_paths.append(
                        [
                            Path(root_dir) / "hd1k_flow_gt" / "flow_occ" / (n + ".png")
                            for n in seq_img_names[i : i + self.sequence_length - 1]
                        ]
                    )
                self.metadata.append(
                    {
                        "image_paths": [str(p) for p in self.img_paths[-1]],
                        "is_val": (seq_img_names[i] in val_names),
                        "misc": "",
                        "is_seq_start": True,
                    }
                )

        if split != "test":
            assert len(self.img_paths) == len(
                self.flow_paths
            ), f"{len(self.img_paths)} vs {len(self.flow_paths)}"

        self._log_status()
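

# --- Illustrative sketch (not part of the original module) ---
# HD1K frame names encode the sequence in the prefix before the underscore,
# which is what the grouping loop above exploits. The stems below follow the
# HD1K naming pattern but are made up:
def _example_hd1k_grouping():
    stems = ["000000_0010", "000000_0011", "000001_0010"]
    grouped = {}
    for n in stems:
        grouped.setdefault(n.split("_")[0], []).append(n)
    print(grouped)
    # {'000000': ['000000_0010', '000000_0011'], '000001': ['000001_0010']}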
class KittiDataset(BaseFlowDataset):
    """Handle the KITTI dataset."""
    def __init__(  # noqa: C901
        self,
        root_dir_2012: Optional[str] = None,
        root_dir_2015: Optional[str] = None,
        split: str = "train",
        versions: Union[str, List[str]] = "2015",
        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
        max_flow: float = 512.0,
        get_valid_mask: bool = True,
        get_occlusion_mask: bool = False,
        get_meta: bool = True,
    ) -> None:
        """Initialize KittiDataset.

        Parameters
        ----------
        root_dir_2012 : str, optional
            Path to the root directory of the KITTI 2012 dataset, if available.
        root_dir_2015 : str, optional
            Path to the root directory of the KITTI 2015 dataset, if available.
        split : str, default 'train'
            Which split of the dataset should be loaded. It can be one of {'train', 'val', 'trainval', 'test'}.
        versions : Union[str, List[str]], default '2015'
            Which version should be loaded. It can be one of {'2012', '2015', ['2012', '2015']}.
        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
            Transform to be applied on the inputs.
        max_flow : float, default 512.0
            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also
            marked as zero in the valid mask.
        get_valid_mask : bool, default True
            Whether to get or generate valid masks.
        get_occlusion_mask : bool, default False
            Whether to get occlusion masks.
        get_meta : bool, default True
            Whether to get metadata.
        """
        if isinstance(versions, str):
            versions = [versions]
        super().__init__(
            dataset_name=f'KITTI_{"_".join(versions)}',
            split_name=split,
            transform=transform,
            max_flow=max_flow,
            get_valid_mask=get_valid_mask,
            get_occlusion_mask=get_occlusion_mask,
            get_motion_boundary_mask=False,
            get_backward=False,
            get_meta=get_meta,
        )
        self.root_dir = {"2012": root_dir_2012, "2015": root_dir_2015}
        self.versions = versions
        self.split = split

        if split == "test":
            split_dir = "testing"
        else:
            split_dir = "training"

        for ver in versions:
            if self.root_dir[ver] is None:
                continue

            if ver == "2012":
                image_dir = "colored_0"
            else:
                image_dir = "image_2"

            img1_paths = sorted(
                (Path(self.root_dir[ver]) / split_dir / image_dir).glob("*_10.png")
            )
            img2_paths = sorted(
                (Path(self.root_dir[ver]) / split_dir / image_dir).glob("*_11.png")
            )
            assert len(img1_paths) == len(
                img2_paths
            ), f"{len(img1_paths)} vs {len(img2_paths)}"
            flow_paths = []
            if split != "test":
                flow_paths = sorted(
                    (Path(self.root_dir[ver]) / split_dir / "flow_occ").glob("*_10.png")
                )
                assert len(img1_paths) == len(
                    flow_paths
                ), f"{len(img1_paths)} vs {len(flow_paths)}"

            split_file = THIS_DIR / f"Kitti{ver}_val.txt"
            with open(split_file, "r") as f:
                val_names = f.read().strip().splitlines()

            if split == "trainval" or split == "test":
                remove_names = []
            elif split == "train":
                remove_names = val_names
            elif split == "val":
                remove_names = [p.stem for p in img1_paths if p.stem not in val_names]

            self.img_paths.extend(
                [
                    [img1_paths[i], img2_paths[i]]
                    for i in range(len(img1_paths))
                    if img1_paths[i].stem not in remove_names
                ]
            )
            if split != "test":
                self.flow_paths.extend(
                    [
                        [flow_paths[i]]
                        for i in range(len(flow_paths))
                        if flow_paths[i].stem not in remove_names
                    ]
                )
            self.metadata.extend(
                [
                    {
                        "image_paths": [str(img1_paths[i]), str(img2_paths[i])],
                        "is_val": img1_paths[i].stem in val_names,
                        "misc": ver,
                        "is_seq_start": True,
                    }
                    for i in range(len(img1_paths))
                    if img1_paths[i].stem not in remove_names
                ]
            )

        if split != "test":
            assert len(self.img_paths) == len(
                self.flow_paths
            ), f"{len(self.img_paths)} vs {len(self.flow_paths)}"

        self._log_status()
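

# --- Illustrative sketch (not part of the original module) ---
# KITTI occlusion masks are not stored on disk: as __getitem__ shows, they are
# derived by subtracting the "flow_noc" valid mask (non-occluded pixels only)
# from the "flow_occ" valid mask (valid including occluded pixels). A sketch
# with tiny synthetic masks:
def _example_kitti_occlusion():
    valid_occ = np.array([[255, 255], [255, 0]], dtype=np.uint8)  # flow_occ
    valid_noc = np.array([[255, 0], [255, 0]], dtype=np.uint8)  # flow_noc
    occlusion = valid_occ - valid_noc  # 255 only where the pixel is occluded
    print(occlusion)  # [[0, 255], [0, 0]]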
class SintelDataset(BaseFlowDataset):
    """Handle the MPI Sintel dataset."""
    def __init__(  # noqa: C901
        self,
        root_dir: str,
        split: str = "train",
        pass_names: Union[str, List[str]] = "clean",
        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
        max_flow: float = 10000.0,
        get_valid_mask: bool = True,
        get_occlusion_mask: bool = True,
        get_meta: bool = True,
        sequence_length: int = 2,
        sequence_position: str = "first",
    ) -> None:
        """Initialize SintelDataset.

        Parameters
        ----------
        root_dir : str
            path to the root directory of the MPI Sintel dataset.
        split : str, default 'train'
            Which split of the dataset should be loaded. It can be one of {'train', 'val', 'trainval', 'test'}.
        pass_names : Union[str, List[str]], default 'clean'
            Which passes should be loaded. It can be one of {'clean', 'final', ['clean', 'final']}.
        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
            Transform to be applied on the inputs.
        max_flow : float, default 10000.0
            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also
            marked as zero in the valid mask.
        get_valid_mask : bool, default True
            Whether to get or generate valid masks.
        get_occlusion_mask : bool, default True
            Whether to get occlusion masks.
        get_meta : bool, default True
            Whether to get metadata.
        sequence_length : int, default 2
            How many consecutive images are loaded per sample. More than two images can be used for models which
            exploit more temporal information.
        sequence_position : str, default "first"
            Only used when sequence_length > 2.
            Determines the position where the main image frame will be in the sequence. It can be one of three values:
            - "first": the main frame will be the first one of the sequence,
            - "middle": the main frame will be in the middle of the sequence (at position sequence_length // 2),
            - "last": the main frame will be the penultimate in the sequence.
        """
        if isinstance(pass_names, str):
            pass_names = [pass_names]
        super().__init__(
            dataset_name=f'Sintel_{"_".join(pass_names)}',
            split_name=split,
            transform=transform,
            max_flow=max_flow,
            get_valid_mask=get_valid_mask,
            get_occlusion_mask=get_occlusion_mask,
            get_motion_boundary_mask=False,
            get_backward=False,
            get_meta=get_meta,
        )
        self.root_dir = root_dir
        self.split = split
        self.pass_names = pass_names
        self.sequence_length = sequence_length
        self.sequence_position = sequence_position

        # Get sequence names for the given split
        if split == "test":
            split_dir = "test"
        else:
            split_dir = "training"

        split_file = THIS_DIR / "Sintel_val.txt"
        with open(split_file, "r") as f:
            val_seqs = f.read().strip().splitlines()

        sequence_names = sorted(
            [p.stem for p in (Path(root_dir) / split_dir / "clean").glob("*")]
        )
        if split == "train" or split == "val":
            if split == "train":
                sequence_names = [s for s in sequence_names if s not in val_seqs]
            else:
                sequence_names = val_seqs

        # Read paths from disk
        for passd in pass_names:
            for seq_name in sequence_names:
                image_paths = sorted(
                    (Path(self.root_dir) / split_dir / passd / seq_name).glob("*.png")
                )
                image_paths = self._extend_paths_list(
                    image_paths, sequence_length, sequence_position
                )
                flow_paths = []
                occ_paths = []
                if split != "test":
                    flow_paths = sorted(
                        (Path(self.root_dir) / split_dir / "flow" / seq_name).glob(
                            "*.flo"
                        )
                    )
                    flow_paths = self._extend_paths_list(
                        flow_paths, sequence_length, sequence_position
                    )
                    assert len(image_paths) - 1 == len(
                        flow_paths
                    ), f"{passd}, {seq_name}: {len(image_paths)-1} vs {len(flow_paths)}"
                    if (Path(self.root_dir) / split_dir / "occlusions").exists():
                        occ_paths = sorted(
                            (
                                Path(self.root_dir)
                                / split_dir
                                / "occlusions"
                                / seq_name
                            ).glob("*.png")
                        )
                        occ_paths = self._extend_paths_list(
                            occ_paths, sequence_length, sequence_position
                        )
                        assert len(occ_paths) == len(flow_paths)
                for i in range(len(image_paths) - self.sequence_length + 1):
                    self.img_paths.append(image_paths[i : i + self.sequence_length])
                    if len(flow_paths) > 0:
                        self.flow_paths.append(
                            flow_paths[i : i + self.sequence_length - 1]
                        )
                    if len(occ_paths) > 0:
                        self.occ_paths.append(
                            occ_paths[i : i + self.sequence_length - 1]
                        )
                    self.metadata.append(
                        {
                            "image_paths": [
                                str(p)
                                for p in image_paths[i : i + self.sequence_length]
                            ],
                            "is_val": seq_name in val_seqs,
                            "misc": seq_name,
                            "is_seq_start": i == 0,
                        }
                    )

        # Sanity check
        if split != "test":
            assert len(self.img_paths) == len(
                self.flow_paths
            ), f"{len(self.img_paths)} vs {len(self.flow_paths)}"
            if len(self.occ_paths) > 0:
                assert len(self.img_paths) == len(
                    self.occ_paths
                ), f"{len(self.img_paths)} vs {len(self.occ_paths)}"

        self._log_status()
class SpringDataset(BaseFlowDataset):
    """Handle the Spring dataset."""
    def __init__(  # noqa: C901
        self,
        root_dir: str,
        split: str = "train",
        side_names: Union[str, List[str]] = "left",
        add_reverse: bool = True,
        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
        max_flow: float = 10000.0,
        get_valid_mask: bool = True,
        get_backward: bool = False,
        get_meta: bool = True,
        sequence_length: int = 2,
        sequence_position: str = "first",
        reverse_only: bool = False,
    ) -> None:
        """Initialize SpringDataset.

        Parameters
        ----------
        root_dir : str
            path to the root directory of the Spring dataset.
        split : str, default 'train'
            Which split of the dataset should be loaded. It can be one of {'train', 'val', 'trainval', 'test'}.
        side_names : Union[str, List[str]], default 'left'
            Samples from which side view should be loaded. It can be one of {'left', 'right', ['left', 'right']}.
        add_reverse : bool, default True
            If True, double the number of samples by appending the backward samples as additional samples.
        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
            Transform to be applied on the inputs.
        max_flow : float, default 10000.0
            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also
            marked as zero in the valid mask.
        get_valid_mask : bool, default True
            Whether to get or generate valid masks.
        get_backward : bool, default False
            Whether to get the backward version of the inputs.
        get_meta : bool, default True
            Whether to get metadata.
        sequence_length : int, default 2
            How many consecutive images are loaded per sample. More than two images can be used for models which
            exploit more temporal information.
        sequence_position : str, default "first"
            Only used when sequence_length > 2.
            Determines the position where the main image frame will be in the sequence. It can be one of three values:
            - "first": the main frame will be the first one of the sequence,
            - "middle": the main frame will be in the middle of the sequence (at position sequence_length // 2),
            - "last": the main frame will be the penultimate in the sequence.
        reverse_only : bool, default False
            If True, only uses the backward samples, discarding the forward ones.
        """
        if isinstance(side_names, str):
            side_names = [side_names]
        super().__init__(
            dataset_name="Spring",
            split_name=split,
            transform=transform,
            max_flow=max_flow,
            get_valid_mask=get_valid_mask,
            get_occlusion_mask=False,
            get_motion_boundary_mask=False,
            get_backward=get_backward,
            get_meta=get_meta,
        )
        self.root_dir = root_dir
        self.split = split
        self.side_names = side_names
        self.sequence_length = sequence_length
        self.sequence_position = sequence_position

        # Get sequence names for the given split
        if split == "test":
            split_dir = "test"
        else:
            split_dir = "train"

        sequence_names = sorted(
            [p.stem for p in (Path(root_dir) / split_dir).glob("*")]
        )

        if reverse_only:
            directions = [("BW", "FW")]
        else:
            directions = [("FW", "BW")]
            if add_reverse:
                directions.append(("BW", "FW"))

        # Read paths from disk
        for seq_name in sequence_names:
            for side in side_names:
                for direcs in directions:
                    rev = direcs[0] == "BW"
                    image_paths = sorted(
                        (
                            Path(self.root_dir)
                            / split_dir
                            / seq_name
                            / f"frame_{side}"
                        ).glob("*.png"),
                        reverse=rev,
                    )
                    image_paths = self._extend_paths_list(
                        image_paths, sequence_length, sequence_position
                    )
                    flow_paths = []
                    flow_b_paths = []
                    if split != "test":
                        flow_paths = sorted(
                            (
                                Path(self.root_dir)
                                / split_dir
                                / seq_name
                                / f"flow_{direcs[0]}_{side}"
                            ).glob("*.flo5"),
                            reverse=rev,
                        )
                        flow_paths = self._extend_paths_list(
                            flow_paths, sequence_length, sequence_position
                        )
                        assert len(image_paths) - 1 == len(
                            flow_paths
                        ), f"{seq_name}, {side}: {len(image_paths)-1} vs {len(flow_paths)}"
                        if self.get_backward:
                            flow_b_paths = sorted(
                                (
                                    Path(self.root_dir)
                                    / split_dir
                                    / seq_name
                                    / f"flow_{direcs[1]}_{side}"
                                ).glob("*.flo5"),
                                reverse=rev,
                            )
                            flow_b_paths = self._extend_paths_list(
                                flow_b_paths, sequence_length, sequence_position
                            )
                            assert len(image_paths) - 1 == len(
                                flow_b_paths
                            ), f"{seq_name}, {side}: {len(image_paths)-1} vs {len(flow_b_paths)}"

                    for i in range(len(image_paths) - self.sequence_length + 1):
                        self.img_paths.append(
                            image_paths[i : i + self.sequence_length]
                        )
                        if len(flow_paths) > 0:
                            self.flow_paths.append(
                                flow_paths[i : i + self.sequence_length - 1]
                            )
                        if self.get_backward and len(flow_b_paths) > 0:
                            self.flow_b_paths.append(
                                flow_b_paths[i : i + self.sequence_length - 1]
                            )
                        self.metadata.append(
                            {
                                "image_paths": [
                                    str(p)
                                    for p in image_paths[i : i + self.sequence_length]
                                ],
                                "is_val": False,
                                "misc": seq_name,
                                "is_seq_start": i == 0,
                            }
                        )

        # Sanity check
        if split != "test":
            assert len(self.img_paths) == len(
                self.flow_paths
            ), f"{len(self.img_paths)} vs {len(self.flow_paths)}"

        self._log_status()
class MiddleburyDataset(BaseFlowDataset):
    """Handle the Middlebury dataset."""
    def __init__(  # noqa: C901
        self,
        root_dir: str,
        split: str = "train",
        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
        max_flow: float = 10000.0,
        get_valid_mask: bool = True,
        get_meta: bool = True,
    ) -> None:
        """Initialize MiddleburyDataset.

        Parameters
        ----------
        root_dir : str
            path to the root directory of the Middlebury dataset.
        split : str, default 'train'
            Which split of the dataset should be loaded. It can be one of {'train', 'val', 'trainval', 'test'}.
        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
            Transform to be applied on the inputs.
        max_flow : float, default 10000.0
            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also
            marked as zero in the valid mask.
        get_valid_mask : bool, default True
            Whether to get or generate valid masks.
        get_meta : bool, default True
            Whether to get metadata.
        """
        super().__init__(
            dataset_name="Middlebury",
            split_name=split,
            transform=transform,
            max_flow=max_flow,
            get_valid_mask=get_valid_mask,
            get_occlusion_mask=False,
            get_motion_boundary_mask=False,
            get_backward=False,
            get_meta=get_meta,
        )
        self.root_dir = root_dir
        self.split = split
        self.sequence_length = 2

        # Get sequence names for the given split
        if split == "test":
            split_dir = "eval"
        else:
            split_dir = "other"

        sequence_names = sorted(
            [p.stem for p in (Path(root_dir) / f"{split_dir}-gt-flow").glob("*")]
        )

        # Read paths from disk
        for seq_name in sequence_names:
            image_paths = sorted(
                (Path(self.root_dir) / f"{split_dir}-data" / seq_name).glob("*.png")
            )
            flow_paths = []
            if split != "test":
                flow_paths = sorted(
                    (Path(self.root_dir) / f"{split_dir}-gt-flow" / seq_name).glob(
                        "*.flo"
                    )
                )
                assert len(image_paths) - 1 == len(
                    flow_paths
                ), f"{seq_name}: {len(image_paths)-1} vs {len(flow_paths)}"
            for i in range(len(image_paths) - self.sequence_length + 1):
                self.img_paths.append(image_paths[i : i + self.sequence_length])
                if len(flow_paths) > 0:
                    self.flow_paths.append(
                        flow_paths[i : i + self.sequence_length - 1]
                    )
                self.metadata.append(
                    {
                        "image_paths": [
                            str(p) for p in image_paths[i : i + self.sequence_length]
                        ],
                        "is_val": False,
                        "misc": seq_name,
                        "is_seq_start": True,
                    }
                )

        # Sanity check
        if split != "test":
            assert len(self.img_paths) == len(
                self.flow_paths
            ), f"{len(self.img_paths)} vs {len(self.flow_paths)}"

        self._log_status()
class MonkaaDataset(BaseFlowDataset):
    """Handle the Monkaa dataset."""
    def __init__(  # noqa: C901
        self,
        root_dir: str,
        pass_names: Union[str, List[str]] = "clean",
        side_names: Union[str, List[str]] = "left",
        add_reverse: bool = True,
        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
        max_flow: float = 1000.0,
        get_valid_mask: bool = True,
        get_backward: bool = True,
        get_meta: bool = True,
        sequence_length: int = 2,
        sequence_position: str = "first",
    ) -> None:
        """Initialize MonkaaDataset.

        Parameters
        ----------
        root_dir : str
            path to the root directory of the Monkaa dataset.
        pass_names : Union[str, List[str]], default 'clean'
            Which passes should be loaded. It can be one of {'clean', 'final', ['clean', 'final']}.
        side_names : Union[str, List[str]], default 'left'
            Samples from which side view should be loaded. It can be one of {'left', 'right', ['left', 'right']}.
        add_reverse : bool, default True
            If True, double the number of samples by appending the backward samples as additional samples.
        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
            Transform to be applied on the inputs.
        max_flow : float, default 1000.0
            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also
            marked as zero in the valid mask.
        get_valid_mask : bool, default True
            Whether to get or generate valid masks.
        get_backward : bool, default True
            Whether to get the backward version of the inputs.
        get_meta : bool, default True
            Whether to get metadata.
        sequence_length : int, default 2
            How many consecutive images are loaded per sample. More than two images can be used for models which
            exploit more temporal information.
        sequence_position : str, default "first"
            Only used when sequence_length > 2.
            Determines the position where the main image frame will be in the sequence. It can be one of three values:
            - "first": the main frame will be the first one of the sequence,
            - "middle": the main frame will be in the middle of the sequence (at position sequence_length // 2),
            - "last": the main frame will be the penultimate in the sequence.
        """
        super().__init__(
            dataset_name="Monkaa",
            split_name="trainval",
            transform=transform,
            max_flow=max_flow,
            get_valid_mask=get_valid_mask,
            get_occlusion_mask=False,
            get_motion_boundary_mask=False,
            get_backward=get_backward,
            get_meta=get_meta,
        )
        self.root_dir = root_dir
        self.add_reverse = add_reverse
        self.pass_names = pass_names
        self.sequence_length = sequence_length
        self.sequence_position = sequence_position
        if isinstance(self.pass_names, str):
            self.pass_names = [self.pass_names]
        self.side_names = side_names
        if isinstance(self.side_names, str):
            self.side_names = [self.side_names]

        pass_dirs = [f"frames_{p}pass" for p in self.pass_names]

        directions = [("into_future", "into_past")]
        reverts = [False]
        if self.add_reverse:
            directions.append(("into_past", "into_future"))
            reverts.append(True)

        # Read paths from disk
        for passd in pass_dirs:
            pass_path = Path(self.root_dir) / passd
            for seq_path in pass_path.glob("*"):
                for direcs, rev in zip(directions, reverts):
                    for side in self.side_names:
                        image_paths = sorted(
                            (seq_path / side).glob("*.png"), reverse=rev
                        )
                        image_paths = self._extend_paths_list(
                            image_paths, sequence_length, sequence_position
                        )
                        flow_paths = sorted(
                            (
                                Path(str(seq_path).replace(passd, "optical_flow"))
                                / direcs[0]
                                / side
                            ).glob("*.pfm"),
                            reverse=rev,
                        )
                        flow_paths = self._extend_paths_list(
                            flow_paths, sequence_length, sequence_position
                        )
                        flow_b_paths = []
                        if self.get_backward:
                            flow_b_paths = sorted(
                                (
                                    Path(str(seq_path).replace(passd, "optical_flow"))
                                    / direcs[1]
                                    / side
                                ).glob("*.pfm"),
                                reverse=rev,
                            )
                            flow_b_paths = self._extend_paths_list(
                                flow_b_paths, sequence_length, sequence_position
                            )

                        for i in range(len(image_paths) - self.sequence_length + 1):
                            self.img_paths.append(
                                image_paths[i : i + self.sequence_length]
                            )
                            if len(flow_paths) > 0:
                                self.flow_paths.append(
                                    flow_paths[i : i + self.sequence_length - 1]
                                )
                            self.metadata.append(
                                {
                                    "image_paths": [
                                        str(p)
                                        for p in image_paths[
                                            i : i + self.sequence_length
                                        ]
                                    ],
                                    "is_val": False,
                                    "misc": "",
                                    "is_seq_start": i == 0,
                                }
                            )
                            if self.get_backward:
                                if len(flow_b_paths) > 0:
                                    self.flow_b_paths.append(
                                        flow_b_paths[i + 1 : i + self.sequence_length]
                                    )

        assert len(self.img_paths) == len(
            self.flow_paths
        ), f"{len(self.img_paths)} vs {len(self.flow_paths)}"
        assert len(self.occ_paths) == 0 or len(self.img_paths) == len(
            self.occ_paths
        ), f"{len(self.img_paths)} vs {len(self.occ_paths)}"
        assert len(self.mb_paths) == 0 or len(self.img_paths) == len(
            self.mb_paths
        ), f"{len(self.img_paths)} vs {len(self.mb_paths)}"
        if self.get_backward:
            assert len(self.img_paths) == len(
                self.flow_b_paths
            ), f"{len(self.img_paths)} vs {len(self.flow_b_paths)}"
            assert len(self.occ_b_paths) == 0 or len(self.img_paths) == len(
                self.occ_b_paths
            ), f"{len(self.img_paths)} vs {len(self.occ_b_paths)}"
            assert len(self.mb_b_paths) == 0 or len(self.img_paths) == len(
                self.mb_b_paths
            ), f"{len(self.img_paths)} vs {len(self.mb_b_paths)}"

        self._log_status()