Source code for ptlflow.utils.flowpy_torch

"""

This code is a port to PyTorch of the flow to RGB conversion from flowpy.

https://gitlab-research.centralesupelec.fr/2018seznecm/flowpy

"""

# =============================================================================
# Copyright 2021 Henrique Morimitsu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import math
from collections import namedtuple
from typing import Optional, Tuple, Union

import torch

from .external.flowpy import make_colorwheel

DEFAULT_TRANSITIONS = (15, 6, 4, 11, 13, 6)


def flow_to_rgb(
    flow: torch.Tensor,
    flow_max_radius: Optional[Union[float, torch.Tensor]] = None,
    background: str = "bright",
    custom_colorwheel: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Create a RGB representation of an optical flow.

    The flow direction selects the hue (interpolated on a colorwheel) and the
    flow magnitude, normalized by ``flow_max_radius``, selects how far the color
    moves away from the background color.

    Parameters
    ----------
    flow : torch.Tensor
        Flow with at least 3 dimensions in the ...CHW (..., Channels, Height, Width)
        layout, where ... represents any number of dimensions.
        flow[..., 0, h, w] should be the x-displacement
        flow[..., 1, h, w] should be the y-displacement
        The input does not need to be contiguous.
    flow_max_radius : float or torch.Tensor, optional
        Set the radius that gives the maximum color intensity, useful for
        comparing different flows.
        Default: The normalization is based on the input flow maximum radius per
        batch element. In both cases the radius used for the normalization is
        clamped to be at least 1.
    background : str, default 'bright'
        States if zero-valued flow should look 'bright' or 'dark'.
    custom_colorwheel : torch.Tensor, optional
        Use a custom colorwheel for specific hue transition lengths. Both torch
        tensors and NumPy arrays are accepted. The values are divided by 255
        before being used, so they are expected to be in the [0, 255] range.
        By default, the default transition lengths are used.

    Returns
    -------
    torch.Tensor
        The RGB representation of the flow. RGB values are float in the [0, 1]
        interval. The output shape is (..., 3, H, W).

    Raises
    ------
    ValueError
        If the background choice is invalid.

    See Also
    --------
    ptlflow.utils.external.flowpy.make_colorwheel : How the colorwheel can be generated.
    """
    valid_backgrounds = ("bright", "dark")
    if background not in valid_backgrounds:
        raise ValueError(
            f"background should be one the following: {valid_backgrounds}, not {background}"
        )

    wheel = make_colorwheel() if custom_colorwheel is None else custom_colorwheel
    # torch.as_tensor handles NumPy arrays as well as tensors, whereas
    # torch.from_numpy rejects the torch.Tensor promised by the signature.
    wheel = torch.as_tensor(wheel).to(dtype=flow.dtype, device=flow.device) / 255

    # Bring the flow to a 4D (N, 2, H, W) layout; the original shape is
    # restored at the end.
    orig_shape = flow.shape
    if len(orig_shape) == 3:
        flow = flow[None]
    elif len(orig_shape) > 4:
        # reshape (instead of view) so non-contiguous inputs are supported.
        flow = flow.reshape(-1, 2, flow.shape[-2], flow.shape[-1])

    # Represent each flow vector as a complex number to get radius and angle.
    complex_flow = flow[:, 0] + 1j * flow[:, 1]

    # Invalid (NaN) vectors are zeroed for the computations below and painted
    # with the invalid color at the end. complex_flow is a fresh tensor, so the
    # caller's flow is not modified.
    nan_mask = torch.isnan(complex_flow)
    complex_flow[nan_mask] = 0

    radius, angle = torch.abs(complex_flow), torch.angle(complex_flow)

    if flow_max_radius is None:
        # One maximum radius per batch element.
        flow_max_radius = radius.view(radius.shape[0], -1).max(dim=1)[0]
    else:
        # Broadcast the given value to one entry per batch element.
        flow_max_radius = (
            torch.zeros(radius.shape[0]).to(dtype=flow.dtype, device=flow.device)
            + flow_max_radius
        )
    flow_max_radius = torch.clamp(flow_max_radius[:, None, None], 1)
    radius /= flow_max_radius

    ncols = len(wheel)

    # Map the angles from (-pi, pi] to [0, 2pi) to [0, ncols - 1)
    angle[angle < 0] += 2 * math.pi
    angle = angle * ((ncols - 1) / (2 * math.pi))

    # Make the wheel cyclic for interpolation
    wheel = torch.cat((wheel, wheel[:1]), dim=0)

    # Interpolate the hues
    angle_fractional = torch.frac(angle)
    angle_floor = torch.floor(angle)
    angle_ceil = torch.ceil(angle)
    angle_fractional = angle_fractional.view((angle_fractional.shape) + (1,))
    float_hue = (
        wheel[angle_floor.long()] * (1 - angle_fractional)
        + wheel[angle_ceil.long()] * angle_fractional
    )

    ColorizationArgs = namedtuple(
        "ColorizationArgs",
        ["move_hue_valid_radius", "move_hue_oversized_radius", "invalid_color"],
    )

    def _move_hue_on_v_axis(hues: torch.Tensor, factors: torch.Tensor) -> torch.Tensor:
        # Scale the color towards black.
        return hues * torch.unsqueeze(factors, -1)

    def _move_hue_on_s_axis(hues: torch.Tensor, factors: torch.Tensor) -> torch.Tensor:
        # Blend the color towards white.
        return 1.0 - torch.unsqueeze(factors, -1) * (1.0 - hues)

    if background == "dark":
        # Zero flow is black, oversized flow fades to white, invalid is white.
        parameters = ColorizationArgs(
            _move_hue_on_v_axis,
            _move_hue_on_s_axis,
            torch.zeros(3).to(dtype=flow.dtype, device=flow.device) + 1,
        )
    else:
        # Zero flow is white, oversized flow fades to black, invalid is black.
        parameters = ColorizationArgs(
            _move_hue_on_s_axis,
            _move_hue_on_v_axis,
            torch.zeros(3).to(dtype=flow.dtype, device=flow.device),
        )

    colors = parameters.move_hue_valid_radius(float_hue, radius)

    # Radii above the maximum use the opposite axis with the inverse radius.
    oversized_radius_mask = radius > 1
    colors[oversized_radius_mask] = parameters.move_hue_oversized_radius(
        float_hue[oversized_radius_mask], 1 / radius[oversized_radius_mask]
    )
    colors[nan_mask] = parameters.invalid_color

    # Same shape as the input, except that the channel dimension becomes 3.
    output_shape = tuple(
        3 if i == len(orig_shape) - 3 else dim for i, dim in enumerate(orig_shape)
    )
    # (N, H, W, 3) -> (N, 3, H, W) -> (..., 3, H, W)
    colors = colors.permute(0, 3, 1, 2).contiguous()
    colors = colors.view(output_shape)

    return colors
def _replace_nans( array: torch.Tensor, value: int = 0 ) -> Tuple[torch.Tensor, torch.Tensor]: nan_mask = torch.isnan(array) array[nan_mask] = value return array, nan_mask