# Source code for ptlflow.utils.flow_metrics

"""Handle and compute metrics for optical flow and related estimations.

This handler is designed according to the torchmetrics specifications. Besides accuracy metrics for optical flow, it can also
compute basic metrics for occlusion, motion boundary and flow confidence estimations.
"""

# =============================================================================
# Copyright 2021 Henrique Morimitsu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from typing import Dict

from torchmetrics import Metric
import torch


class FlowMetrics(Metric):
    """Handler for optical flow and related metrics.

    Attributes
    ----------
    average_mode : str, default 'epoch_mean'
        How the final metric is averaged. It can be either 'epoch_mean' or 'ema' (exponential moving average).
    ema_decay : float, default 0.99
        The decay to be applied if average_mode is 'ema'.
    prefix : str, optional
        A prefix string that will be attached to the metric names.
    f1_mode : str, default 'macro'
        How the f1-scores (occlusion, motion boundary, confidence) are computed. See ``__init__``.
    used_keys : list[tuple[str, str, str]]
        Rebuilt on every ``update()``: tuples of (state name, name of the accumulated value, name of the
        validity mask). ``calculate_metrics()`` reports exactly the states listed here.
    include_occlusion : bool
        Becomes True once an ``update()`` call received occlusion targets.
    """

    full_state_update = True

    def __init__(
        self,
        dist_sync_on_step: bool = False,
        prefix: str = "",
        average_mode: str = "epoch_mean",
        ema_decay: float = 0.99,
        f1_mode: str = "macro",
    ) -> None:
        """Initialize FlowMetrics.

        Parameters
        ----------
        dist_sync_on_step : bool, default False
            Used by torchmetrics to sync metrics between multiple processes.
        prefix : str, optional
            A prefix string that will be attached to the metric names.
        average_mode : str, default 'epoch_mean'
            How the final metric is averaged. It can be either 'epoch_mean' or 'ema' (exponential moving average).
        ema_decay : float, default 0.99
            The decay to be applied if average_mode is 'ema'.
        f1_mode : str, default 'macro'
            How to calculate the f1-score. Accepts one of these options {binary, macro, weighted}. If binary, then
            the f1-score is calculated only for the positive pixels. If macro, then the f1-score is the average of
            positive and negative scores. If weighted, then the average is weighted according to the number of
            positive/negative samples. Any value other than 'binary' or 'macro' is treated as 'weighted'.

        Raises
        ------
        AssertionError
            If average_mode is neither 'epoch_mean' nor 'ema'.
        ZeroDivisionError
            If ema_decay is exactly 1.0.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        assert average_mode in ["epoch_mean", "ema"]
        self.average_mode = average_mode
        self.prefix = prefix
        self.ema_decay = ema_decay
        self.f1_mode = f1_mode
        # Kept for backward compatibility only: calculate_metrics() now always applies the EMA bias
        # correction instead of switching it off after this many steps.
        self.ema_max_count = min(100, int(1.0 / (1.0 - ema_decay)))

        # Every accumulator is a scalar float tensor that torchmetrics sums across processes.
        state_names = (
            "epe", "epe_non_occ", "epe_occ",
            "px1", "px1_non_occ", "px1_occ",
            "px3", "px3_non_occ", "px3_occ",
            "px5", "px5_non_occ", "px5_occ",
            "flall", "flall_non_occ", "flall_occ",
            "wauc", "wauc_non_occ", "wauc_occ",
            "occ_f1", "mb_f1", "conf_f1",
            "sample_count", "step_count",
        )
        for name in state_names:
            self.add_state(name, default=torch.tensor(0).float(), dist_reduce_fx="sum")

        self.include_occlusion = False
        self.used_keys = []

    def update(
        self, preds: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor]
    ) -> None:
        """Compute and update one step of the metrics.

        Parameters
        ----------
        preds : dict[str, torch.Tensor]
            The predictions of the optical flow model. Must contain "flows"; the optional "occs", "mbs" and
            "confs" entries enable the corresponding f1-score metrics.
        targets : dict[str, torch.Tensor]
            The groundtruth of the predictions. Must contain "flows"; "valids", "occs" and "mbs" are optional.
            A 6D "flows" target holds several candidate flows per sample: for each pixel, the candidate closest
            to the prediction is used.
        """
        if self.average_mode == "epoch_mean":
            prev_weight = 1.0
            next_weight = 1.0
        else:
            prev_weight = self.ema_decay
            next_weight = 1.0 - self.ema_decay

        batch_size = self._get_batch_size(targets["flows"])
        flow_pred = self._fix_shape(preds["flows"], batch_size)
        flow_target = self._fix_shape(targets["flows"], batch_size)

        valid_target = targets.get("valids")
        if valid_target is not None:
            valid_target = self._fix_shape(valid_target, batch_size)[:, 0]
        else:
            # All pixels valid; shaped like the EPE map computed below (batch, H, W), which also holds
            # when flow_target carries an extra candidate dimension.
            valid_target = flow_target.new_ones(
                (flow_target.shape[0], *flow_target.shape[-2:])
            )

        occlusion_target = targets.get("occs")
        if occlusion_target is not None:
            occlusion_target = self._fix_shape(occlusion_target, batch_size)

        if flow_target.dim() == 5:
            # Several candidate targets per sample: keep, per pixel, the one closest to the prediction.
            epe = torch.norm(flow_pred[:, None] - flow_target, p=2, dim=2)
            epe, min_idx = epe.min(dim=1)
            target_norm = torch.norm(flow_target, p=2, dim=2)
            target_norm = target_norm.gather(1, min_idx[:, None])[:, 0]
        else:
            epe = torch.norm(flow_pred - flow_target, p=2, dim=1)
            target_norm = torch.norm(flow_target, p=2, dim=1)

        px1_mask = (epe < 1).float()
        px3_mask = (epe < 3).float()
        px5_mask = (epe < 5).float()
        # Outlier: error above 3 px AND above 5% of the target magnitude; scaled to a percentage.
        flall_mask = ((epe > 3) & (epe > (0.05 * target_norm))).float() * 100

        # Tensors referenced by name in self.used_keys.
        named_tensors = {
            "epe": epe,
            "px1_mask": px1_mask,
            "px3_mask": px3_mask,
            "px5_mask": px5_mask,
            "flall_mask": flall_mask,
            "valid_target": valid_target,
        }
        self.used_keys = [
            ("epe", "epe", "valid_target"),
            ("px1", "px1_mask", "valid_target"),
            ("px3", "px3_mask", "valid_target"),
            ("px5", "px5_mask", "valid_target"),
            ("flall", "flall_mask", "valid_target"),
            ("wauc", "epe", "valid_target"),
        ]

        valid_occ = None
        valid_non_occ = None
        if occlusion_target is not None:
            valid_occ = occlusion_target[:, 0] * valid_target
            valid_non_occ = (1 - occlusion_target[:, 0]) * valid_target
            named_tensors["valid_occ"] = valid_occ
            named_tensors["valid_non_occ"] = valid_non_occ
            self.used_keys.extend(
                [
                    ("epe_occ", "epe", "valid_occ"),
                    ("epe_non_occ", "epe", "valid_non_occ"),
                    ("px1_occ", "px1_mask", "valid_occ"),
                    ("px1_non_occ", "px1_mask", "valid_non_occ"),
                    ("px3_occ", "px3_mask", "valid_occ"),
                    ("px3_non_occ", "px3_mask", "valid_non_occ"),
                    ("px5_occ", "px5_mask", "valid_occ"),
                    ("px5_non_occ", "px5_mask", "valid_non_occ"),
                    ("flall_occ", "flall_mask", "valid_occ"),
                    ("flall_non_occ", "flall_mask", "valid_non_occ"),
                    ("wauc_occ", "epe", "valid_occ"),
                    ("wauc_non_occ", "epe", "valid_non_occ"),
                ]
            )
            self.include_occlusion = True
            if preds.get("occs") is not None:
                occlusion_pred = self._fix_shape(preds["occs"], batch_size)
                named_tensors["occ_f1"] = self._f1_score(
                    occlusion_pred, occlusion_target, mode=self.f1_mode
                )
                self.used_keys.append(("occ_f1", "occ_f1", "valid_target"))

        if preds.get("mbs") is not None and targets.get("mbs") is not None:
            mb_pred = self._fix_shape(preds["mbs"], batch_size)
            mb_target = self._fix_shape(targets["mbs"], batch_size)
            named_tensors["mb_f1"] = self._f1_score(mb_pred, mb_target, mode=self.f1_mode)
            self.used_keys.append(("mb_f1", "mb_f1", "valid_target"))

        if preds.get("confs") is not None:
            # Target confidence decays with the squared flow error: exp(-||target - pred||^2).
            conf_target = torch.exp(
                -torch.pow(flow_target - flow_pred, 2).sum(dim=1, keepdim=True)
            )
            conf_pred = self._fix_shape(preds["confs"], batch_size)
            named_tensors["conf_f1"] = self._f1_score(
                conf_pred, conf_target, mode=self.f1_mode
            )
            self.used_keys.append(("conf_f1", "conf_f1", "valid_target"))

        for state_name, value_name, mask_name in self.used_keys:
            if "wauc" in state_name:
                continue  # WAUC uses its own accumulation below.
            total = self._compute_total(named_tensors[value_name], named_tensors[mask_name])
            setattr(
                self,
                state_name,
                prev_weight * getattr(self, state_name) + next_weight * total,
            )

        self.wauc = prev_weight * self.wauc + next_weight * self._compute_total_wauc(
            epe, valid_target
        )
        if occlusion_target is not None:
            self.wauc_occ = (
                prev_weight * self.wauc_occ
                + next_weight * self._compute_total_wauc(epe, valid_occ)
            )
            self.wauc_non_occ = (
                prev_weight * self.wauc_non_occ
                + next_weight * self._compute_total_wauc(epe, valid_non_occ)
            )

        self.sample_count += batch_size
        self.step_count += 1

    def calculate_metrics(self) -> Dict[str, torch.Tensor]:
        """Compute and return the average of all metrics.

        On Pytorch-Lightning < 1.2, compute() automatically calls reset(). Sometimes this is not desirable, so the
        metrics are calculated here in this other function, which can be called externally.

        In 'epoch_mean' mode the accumulated sums are divided by the number of samples seen. In 'ema' mode the
        accumulators start at zero, so after ``t`` updates their total weight is ``1 - ema_decay**t``; dividing
        by it removes that start-up bias at every step. Before the first update this divider is zero.

        Returns
        -------
        Dict[str, torch.Tensor]
            The average of the metrics, keyed by ``prefix`` + state name.
        """
        if self.average_mode == "epoch_mean":
            divider = self.sample_count
        else:
            divider = 1.0 - self.ema_decay**self.step_count
        return {
            self.prefix + state_name: getattr(self, state_name) / divider
            for state_name, _, _ in self.used_keys
        }

    def compute(self) -> Dict[str, torch.Tensor]:
        """Compute and return the average of all metrics.

        Called internally by torchmetrics.

        Returns
        -------
        Dict[str, torch.Tensor]
            The average of the metrics.
        """
        return self.calculate_metrics()

    def _compute_total(
        self, tensor: torch.Tensor, valid_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return the batch total of per-sample masked means of ``tensor``.

        Each sample is averaged over its valid pixels; the divisor is clamped to at least 1 so samples without
        valid pixels contribute 0. The per-sample values are summed ('epoch_mean') or averaged ('ema').
        """
        masked = (tensor * valid_mask).reshape(tensor.shape[0], -1)
        valid_sum = valid_mask.reshape(valid_mask.shape[0], -1).sum(dim=1)
        valid_sum = torch.clamp(valid_sum, 1)
        per_sample = masked.sum(dim=1) / valid_sum
        if self.average_mode == "epoch_mean":
            return per_sample.sum()
        return per_sample.mean()

    def _f1_score(
        self, pred: torch.Tensor, target: torch.Tensor, mode: str = "macro"
    ) -> torch.Tensor:
        """Return the per-sample, per-channel f1-score with shape (batch, channels, 1).

        mode 'binary' scores only the positive class, 'macro' averages positive and negative scores, and any
        other value weights them by the fraction of positive/negative target pixels of each sample.
        """
        f1_pos = self._single_f1_score(pred, target)
        if mode == "binary":
            return f1_pos

        f1_neg = self._single_f1_score(1 - pred, 1 - target)
        if mode == "macro":
            return (f1_pos + f1_neg) / 2.0

        # weighted
        target_pos = (target > 0.5).float()
        target_pos = target_pos.reshape(target_pos.shape[0], target_pos.shape[1], -1)
        target_neg = (target <= 0.5).float()
        target_neg = target_neg.reshape(target_neg.shape[0], target_neg.shape[1], -1)
        # Shape (batch, channels, 1) to line up with f1_pos / f1_neg; a trailing (1, 1) here would
        # broadcast against the batch dimension and mix the scores of different samples.
        w_pos = target_pos.sum(dim=2)[:, :, None] / target_pos.shape[2]
        w_neg = target_neg.sum(dim=2)[:, :, None] / target_neg.shape[2]
        return w_pos * f1_pos + w_neg * f1_neg

    def _single_f1_score(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """Return the f1-score of the positive class, binarizing both inputs at 0.5.

        The last two dimensions are flattened and scored together; the result has shape
        (batch, channels, 1).
        """
        pred_bin = (pred > 0.5).float()
        target_bin = (target > 0.5).float()
        dims = pred_bin.shape
        pred_bin = pred_bin.reshape(*dims[:-2], -1)
        target_bin = target_bin.reshape(*dims[:-2], -1)
        tp = (pred_bin * target_bin).sum(dim=-1)
        fp = (pred_bin * (1 - target_bin)).sum(dim=-1)
        fn = ((1 - pred_bin) * target_bin).sum(dim=-1)
        # torch.finfo only accepts floating dtypes; fall back to the binarized (float) dtype.
        eps_dtype = pred.dtype if pred.is_floating_point() else pred_bin.dtype
        eps = torch.finfo(eps_dtype).eps
        precision = tp / (tp + fp + eps)
        recall = tp / (tp + fn + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)
        return f1[:, :, None]

    def _fix_shape(self, tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Bring ``tensor`` to a layout whose first dimension is the (flattened) batch.

        2D tensors get two leading singleton dims. 3D tensors get a channel dim when their first dim equals
        ``batch_size`` and a batch dim otherwise. The first two dims of 5D and 6D tensors are merged. Other
        ranks are returned unchanged.
        """
        ndim = tensor.dim()
        if ndim == 2:
            return tensor[None, None]
        if ndim == 3:
            return tensor[:, None] if tensor.shape[0] == batch_size else tensor[None]
        if ndim in (5, 6):
            return tensor.reshape(tensor.shape[0] * tensor.shape[1], *tensor.shape[2:])
        return tensor

    def _get_batch_size(self, flow_tensor: torch.Tensor) -> int:
        """Return the number of samples ``flow_tensor`` holds once passed through ``_fix_shape``.

        Raises
        ------
        ValueError
            If ``flow_tensor`` has more than 6 dimensions.
        """
        ndim = flow_tensor.dim()
        if ndim < 4:
            return 1
        if ndim == 4:
            return flow_tensor.shape[0]
        if ndim in (5, 6):
            # _fix_shape merges the first two dims of 5D and 6D tensors into the batch dim, so the sample
            # count must match that merged size.
            return flow_tensor.shape[0] * flow_tensor.shape[1]
        raise ValueError(
            f"Unsupported flow tensor with {ndim} dimensions; expected at most 6."
        )

    def _compute_total_wauc(
        self, epe: torch.Tensor, valid_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return the batch total of per-sample weighted AUC (WAUC) scores, in percent.

        For 100 thresholds ``i / 20`` (i = 1..100) the fraction of valid pixels with EPE below the threshold
        is accumulated with linearly decreasing weights ``1 - (i - 1) / 100``. Invalid pixels are assigned an
        EPE of 100 so they never pass any threshold. The per-sample scores are summed ('epoch_mean') or
        averaged ('ema').
        """
        # Code adapted from https://github.com/cv-stuttgart/springwebsite/blob/main/springeval/management/commands/evaluation.py
        # MIT License
        epe = epe.clone()
        epe[valid_mask < 0.5] = 100
        epe = epe.reshape(epe.shape[0], -1)
        num_valid = valid_mask.reshape(valid_mask.shape[0], -1).sum(dim=1)
        wauc = torch.zeros(epe.shape[0], dtype=epe.dtype, device=epe.device)
        sum_wi = 0
        for i in range(1, 101):
            wi = 1 - ((i - 1) / 100.0)
            deltai = i / 20.0
            num_accurate = (epe <= deltai).sum(dim=1)
            wauc += wi * num_accurate
            sum_wi += wi
        wauc = 100 * wauc / (num_valid * sum_wi + 1e-8)
        if self.average_mode == "epoch_mean":
            return wauc.sum()
        return wauc.mean()