"""Implement a callback to log images."""
# =============================================================================
# Copyright 2021 Henrique Morimitsu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

try:
    from neptune.new.types import File as NeptuneFile
except ImportError:
    NeptuneFile = None
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers.comet import CometLogger
from lightning.pytorch.loggers.neptune import NeptuneLogger
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.trainer.trainer import Trainer
import torch
import torch.nn.functional as F
from torchvision.utils import make_grid

try:
    import wandb
except ImportError:
    wandb = None
from ptlflow.models.base_model.base_model import BaseModel
from ptlflow.utils import flow_utils
from ptlflow.utils.utils import config_logging
config_logging()
[docs]
class LoggerCallback(Callback):
    """Callback to collect and log images during training and validation.

    For each dataloader, num_images samples will be collected. The samples are collected by trying to retrieve from both
    inputs and outputs tensors whose keys match the values provided in log_keys.

    num_images samples are uniformly sampled from the whole dataloader.
    """

    def __init__(
        self,
        num_images: int = 5,
        image_size: Tuple[int, int] = (200, 400),
        log_keys: Sequence[str] = ("images", "flows", "occs", "mbs", "confs"),
        epe_clip: float = 5.0,
    ) -> None:
        """Initialize LoggerCallback.

        Parameters
        ----------
        num_images : int, default 5
            Number of images to log during one epoch.
        image_size : Tuple[int, int], default (200, 400)
            The size of the stored images.
        log_keys : Sequence[str], default ('images', 'flows', 'occs', 'mbs', 'confs')
            The keys to use to collect the images from the inputs and outputs of the model. If a key is not found, it is
            ignored.
        epe_clip : float, default 5.0
            The maximum EPE value that is shown on EPE image. All EPE values above this will be clipped.
        """
        super().__init__()
        self.num_images = num_images
        self.image_size = image_size
        self.log_keys = log_keys
        self.epe_clip = epe_clip

        # Batch indices selected for logging and the per-title image accumulators.
        self.train_collect_img_idx = []
        self.train_images = {}

        # Validation state is kept per dataloader name.
        self.val_dataloader_names = []
        self.val_collect_image_idx = {}
        self.val_images = {}

    def log_image(
        self, title: str, image: Optional[torch.Tensor], pl_module: BaseModel
    ) -> None:
        """Log the image in all of the pl_module loggers.

        Note, however, that not all loggers may be able to log images.

        Parameters
        ----------
        title : str
            A title for the image.
        image : Optional[torch.Tensor]
            The image to log. It must be a 3D tensor CHW (typically C=3). If None (e.g., no image was collected during
            the epoch), nothing is logged.
        pl_module : BaseModel
            An instance of the optical flow model to get the logger from.
        """
        # _make_image_grid returns None when nothing was accumulated; there is nothing to log in that case.
        if image is None:
            return

        # HWC numpy copy for the loggers that expect channel-last arrays.
        image_npy = image.permute(1, 2, 0).numpy()
        for logger in pl_module.loggers:
            if isinstance(logger, CometLogger):
                logger.experiment.log_image(image_npy, name=title)
            elif isinstance(logger, NeptuneLogger) and NeptuneFile is not None:
                # NeptuneFile is None when the optional neptune import failed.
                # The channel-last array is passed, as for the other loggers that take arrays.
                logger.experiment[title].log(NeptuneFile.as_image(image_npy))
            elif isinstance(logger, TensorBoardLogger):
                logger.experiment.add_image(title, image, pl_module.global_step)
            elif isinstance(logger, WandbLogger) and wandb is not None:
                title_wb = title.replace("/", "-")
                image_wb = wandb.Image(image_npy)
                logger.experiment.log({title_wb: image_wb})

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: BaseModel,
        outputs: Dict[str, torch.Tensor],
        batch: Dict[str, torch.Tensor],
        batch_idx: int,
        **kwargs,
    ) -> None:
        """Store one image to be logged, if the current batch_idx is in the log selection group.

        Parameters
        ----------
        trainer : Trainer
            An instance of the PyTorch Lightning trainer.
        pl_module : BaseModel
            An instance of the optical flow model.
        outputs : Dict[str, torch.Tensor]
            The outputs of the current training batch.
        batch : Dict[str, torch.Tensor]
            The inputs of the current training batch.
        batch_idx : int
            The counter value of the current batch.
        """
        if batch_idx in self.train_collect_img_idx:
            # The tensors are read from the module's cached last inputs/predictions,
            # not from the outputs/batch hook arguments.
            self._append_images(
                self.train_images, pl_module.last_inputs, pl_module.last_predictions
            )

    def on_train_epoch_start(self, trainer: Trainer, pl_module: BaseModel) -> None:
        """Reset the training log params and accumulators.

        Parameters
        ----------
        trainer : Trainer
            An instance of the PyTorch Lightning trainer.
        pl_module : BaseModel
            An instance of the optical flow model.
        """
        self.train_images = {}
        limit_batches = (
            pl_module.args.limit_train_batches
            if pl_module.args.limit_train_batches is not None
            else 1.0
        )
        self.train_collect_img_idx = self._compute_collect_idx(
            pl_module.train_dataloader_length, limit_batches
        )

    def on_train_epoch_end(
        self, trainer: Trainer, pl_module: BaseModel, **kwargs
    ) -> None:
        """Log the images accumulated during the training.

        Parameters
        ----------
        trainer : Trainer
            An instance of the PyTorch Lightning trainer.
        pl_module : BaseModel
            An instance of the optical flow model.
        """
        img_grid = self._make_image_grid(self.train_images)
        self.log_image("train", img_grid, pl_module)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: BaseModel,
        outputs: Dict[str, torch.Tensor],
        batch: Dict[str, torch.Tensor],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Store one image to be logged, if the current batch_idx is in the log selection group.

        Parameters
        ----------
        trainer : Trainer
            An instance of the PyTorch Lightning trainer.
        pl_module : BaseModel
            An instance of the optical flow model.
        outputs : Dict[str, torch.Tensor]
            The outputs of the current validation batch.
        batch : Dict[str, torch.Tensor]
            The inputs of the current validation batch.
        batch_idx : int
            The counter value of the current batch.
        dataloader_idx : int, default 0
            The index number of the current dataloader. It defaults to 0 so the hook still works when the caller does
            not provide it (presumably the single validation dataloader case - confirm against the Lightning version
            in use).
        """
        dl_name = self.val_dataloader_names[dataloader_idx]
        if batch_idx in self.val_collect_image_idx[dl_name]:
            self._append_images(
                self.val_images[dl_name],
                pl_module.last_inputs,
                pl_module.last_predictions,
            )

    def on_validation_epoch_start(self, trainer: Trainer, pl_module: BaseModel) -> None:
        """Reset the validation log params and accumulators.

        Parameters
        ----------
        trainer : Trainer
            An instance of the PyTorch Lightning trainer.
        pl_module : BaseModel
            An instance of the optical flow model.
        """
        self.val_dataloader_names = pl_module.val_dataloader_names

        # Rebuild the dicts from scratch so that entries of dataloaders that are no longer present
        # are not logged again at the end of the epoch.
        self.val_images = {dl_name: {} for dl_name in self.val_dataloader_names}
        self.val_collect_image_idx = {}

        limit_batches = (
            pl_module.args.limit_val_batches
            if pl_module.args.limit_val_batches is not None
            else 1.0
        )
        for dname, dlen in zip(
            pl_module.val_dataloader_names, pl_module.val_dataloader_lengths
        ):
            self.val_collect_image_idx[dname] = self._compute_collect_idx(
                dlen, limit_batches
            )

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: BaseModel) -> None:
        """Log the images accumulated during the validation.

        Parameters
        ----------
        trainer : Trainer
            An instance of the PyTorch Lightning trainer.
        pl_module : BaseModel
            An instance of the optical flow model.
        """
        for dl_name, dl_images in self.val_images.items():
            img_grid = self._make_image_grid(dl_images)
            self.log_image(f"val/{dl_name}", img_grid, pl_module)

    def _add_title(self, image: torch.Tensor, img_title: str) -> torch.Tensor:
        """Add a title to an image.

        Parameters
        ----------
        image : torch.Tensor
            The image where the title will be added. It is indexed as a 3D CHW tensor; values are scaled by 255 before
            drawing, so they are presumably expected in the [0, 1] range.
        img_title : str
            The title to be added.

        Returns
        -------
        torch.Tensor
            The input image with the title superposed on it.
        """
        size = min(image.shape[1:3])

        # Clip before the uint8 cast, otherwise values outside [0, 1] would wrap around.
        image_npy = np.clip(255 * image.permute(1, 2, 0).numpy(), 0, 255).astype(
            np.uint8
        )
        image_pil = Image.fromarray(image_npy)

        # The font file is loaded from the directory of this module.
        this_dir = Path(__file__).resolve().parent
        title_font = ImageFont.truetype(
            str(this_dir / "RobotoMono-Regular.ttf"),
            max(size // 10, 1),  # a font size of zero is not valid
        )

        draw = ImageDraw.Draw(image_pil)
        # Black background box behind the title, scaled with the image size and title length.
        bb = (
            size // 25,
            size // 25,
            size // 25 + len(img_title) * size // 15,
            size // 25 + size // 8,
        )
        draw.rectangle(bb, fill="black")
        draw.text((size // 20, size // 30), img_title, (237, 230, 211), font=title_font)

        titled = np.array(image_pil)
        return torch.from_numpy(titled.transpose(2, 0, 1)).float() / 255

    def _append_images(  # noqa: C901
        self,
        images: Dict[str, List[torch.Tensor]],
        inputs: Dict[str, torch.Tensor],
        preds: Dict[str, torch.Tensor],
    ) -> None:
        """Append samples to the images accumulator.

        For each key in log_keys, up to three entries are appended: "i_<key>" (from the inputs), "o_<key>" (from the
        predictions) and, for flows, an EPE map clipped at epe_clip. The "i_confs" entry is not read from the inputs;
        it is computed from the predicted and groundtruth flows.

        Parameters
        ----------
        images : Dict[str, List[torch.Tensor]]
            The accumulator where the samples will be appended to.
        inputs : Dict[str, torch.Tensor]
            The inputs of the model.
        preds : Dict[str, torch.Tensor]
            The outputs of the model.
        """
        # NOTE(review): tensors are assumed to be 5D (batch, sequence, channels, H, W), which matches
        # the reductions over dim=2 and the [:1, 0] indexing below - confirm against the model outputs.
        has_gt_flows = "flows" in inputs
        has_pred_flows = "flows" in preds

        for k in self.log_keys:
            log_names = []
            log_sources = []
            if k in inputs or (k == "confs" and k in preds):
                log_names.append(f"i_{k}")
                log_sources.append(inputs)
            if k in preds:
                log_names.append(f"o_{k}")
                log_sources.append(preds)
                # The EPE map needs the groundtruth flow as well.
                if k == "flows" and has_gt_flows:
                    log_names.append(f"epe<{self.epe_clip:.1f}")
                    log_sources.append(None)

            for name, source in zip(log_names, log_sources):
                if name == "i_confs":
                    # The confidence target is derived from both flows; skip it if either is missing.
                    if not (has_gt_flows and has_pred_flows):
                        continue
                    img = self._compute_confidence_gt(preds["flows"], inputs["flows"])
                elif name.startswith("epe"):
                    epe = torch.norm(preds[k] - inputs[k], p=2, dim=2, keepdim=True)
                    # Normalize to [0, 1] after clipping at epe_clip.
                    img = torch.clamp(epe, 0, self.epe_clip) / self.epe_clip
                    if inputs.get("valids") is not None:
                        # Pixels marked as invalid are zeroed in the EPE map.
                        img[inputs["valids"] < 0.5] = 0
                else:
                    img = source[k]

                # Keep only the first sample of the batch and its first element along dim 1.
                img = img[:1, 0].detach().cpu()
                img = F.interpolate(img, self.image_size)
                img = img[0]
                if "images" in name:
                    img = img.flip([0])  # BGR to RGB
                elif "flows" in name:
                    img = flow_utils.flow_to_rgb(img)

                if images.get(name) is None:
                    images[name] = []
                images[name].append(img)

    def _compute_confidence_gt(
        self, pred_flows: torch.Tensor, target_flows: torch.Tensor
    ) -> torch.Tensor:
        """Compute a confidence score for the flow predictions.

        This score was proposed in https://arxiv.org/abs/2007.09319.

        The score is exp(-||pred - target||^2), where the squared error is summed over dim 2.

        Parameters
        ----------
        pred_flows : torch.Tensor
            The predicted optical flow.
        target_flows : torch.Tensor
            The groundtruth optical flow.

        Returns
        -------
        torch.Tensor
            The confidence score for each pixel of the input.
        """
        squared_error = torch.pow(pred_flows - target_flows, 2).sum(
            dim=2, keepdim=True
        )
        return torch.exp(-squared_error)

    def _compute_collect_idx(
        self, dataloader_length: int, limit_batches: Union[float, int]
    ) -> np.ndarray:
        """Select the batch indices at which images will be collected.

        Parameters
        ----------
        dataloader_length : int
            Total size of the dataloader.
        limit_batches : Union[float, int]
            A value that may decrease the samples in the dataloader.

        Returns
        -------
        np.ndarray
            The unique int32 indices, evenly spaced from zero to the maximum range. At most num_images values.
        """
        return np.unique(
            np.linspace(
                0,
                self._compute_max_range(dataloader_length, limit_batches),
                self.num_images,
                dtype=np.int32,
            )
        )

    def _compute_max_range(
        self, dataloader_length: int, limit_batches: Union[float, int]
    ) -> int:
        """Find the maximum number of samples that will be drawn from a dataloader.

        Parameters
        ----------
        dataloader_length : int
            Total size of the dataloader.
        limit_batches : Union[float, int]
            A value that may decrease the samples in the dataloader. See --limit_val_batches or --limit_train_batches from
            PyTorch Lightning for more information.

        Returns
        -------
        int
            The index of the last batch that will be drawn from the dataloader. It is never negative.
        """
        if isinstance(limit_batches, int):
            num_batches = limit_batches
            # An integer limit larger than the dataloader cannot produce more batches than it has.
            # The clamp is skipped for non-positive lengths, which may just mean the length is unknown.
            if dataloader_length > 0:
                num_batches = min(num_batches, dataloader_length)
        else:
            num_batches = int(limit_batches * dataloader_length)
        # Never return a negative index (e.g., empty dataloader or a limit of zero).
        return max(num_batches - 1, 0)

    def _make_image_grid(
        self, dl_images: Dict[str, List[torch.Tensor]]
    ) -> Optional[torch.Tensor]:
        """Transform a bunch of images into a single one by adding them to a grid.

        Parameters
        ----------
        dl_images : Dict[str, List[torch.Tensor]]
            Lists of images, each identified by a title name.

        Returns
        -------
        Optional[torch.Tensor]
            A single 3D tensor image 3HW, or None if there are no images.
        """
        imgs = []
        for img_label, img_list in dl_images.items():
            for j, im in enumerate(img_list):
                if len(im.shape) == 2:
                    im = im[None]
                if im.shape[0] == 1:
                    # Replicate single-channel images to 3 channels.
                    im = im.repeat(3, 1, 1)
                if j == 0:
                    # Only the first image of each list carries the title.
                    im = self._add_title(im, img_label)
                imgs.append(im)

        if not imgs:
            return None

        # One grid row per title, assuming every title has the same number of images.
        images_per_row = max(len(imgs) // len(dl_images), 1)
        return make_grid(imgs, images_per_row)