Source code for summary_metrics

"""Create a reduced version of the validation metrics table.

The reduced version is created by selecting a subset of columns.

It can also create a scatter plot comparing two chosen metrics.

The parsing in this script is tightly coupled to how the results are output by validate.py.
"""

# =============================================================================
# Copyright 2021 Henrique Morimitsu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import argparse
from pathlib import Path

import pandas as pd
import plotly.express as px


def _init_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--metrics_path",
        type=str,
        default=str(Path("docs/source/results/metrics_all.csv")),
        help=("Path to the csv file containing the validation metrics."),
    )
    parser.add_argument(
        "--chosen_metrics",
        type=str,
        nargs="+",
        default=("epe", "outlier"),
        help=(
            "Names of the metrics to keep in the summarized results. Each chosen name must match the end of a "
            "column name in the csv file. If exactly two metrics are chosen, a plot comparing them will also be "
            "generated."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=str(Path("outputs/metrics")),
        help=("Path to the directory where the outputs will be saved."),
    )
    parser.add_argument(
        "--sort_by",
        type=str,
        default="model",
        help=(
            "Name of the column to use to sort the output table. The name must exactly match a column name from "
            "the metrics csv file."
        ),
    )
    parser.add_argument(
        "--plot_ignore_models",
        type=str,
        nargs="*",
        default=None,
        help=("Name of the models that should not be included in the plot."),
    )
    parser.add_argument(
        "--drop_checkpoints",
        type=str,
        nargs="*",
        default=None,
        help=(
            "Names of checkpoints to exclude from the final outputs. The names must be substrings of the values "
            "in the file from --metrics_path."
        ),
    )

    return parser


def load_summarized_table(args: argparse.Namespace) -> pd.DataFrame:
    """Load the metrics DataFrame and keep only the columns matching the selected metrics.

    Parameters
    ----------
    args : argparse.Namespace
        Arguments to control the loading.

    Returns
    -------
    pd.DataFrame
        The summarized DataFrame.
    """
    df = pd.read_csv(args.metrics_path)
    # The first two columns (model and checkpoint) are always kept.
    keep_cols = list(df.columns)[:2]
    for col in df.columns[2:]:
        for cmet in args.chosen_metrics:
            if col.endswith(cmet):
                keep_cols.append(col)
    summ_df = df[keep_cols]
    summ_df = summ_df.sort_values(args.sort_by)
    summ_df = summ_df.round(3)
    return summ_df
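
# A sketch of the suffix matching above, with illustrative column names (not
# taken from a real metrics file): the first two columns are always kept, and
# any later column whose name ends with a chosen metric is kept as well.
#
#     cols = ["model", "checkpoint",
#             "sintel-clean-val/epe", "sintel-clean-val/outlier", "sintel-clean-val/flall"]
#     chosen = ("epe", "outlier")
#     keep = cols[:2] + [c for c in cols[2:] if any(c.endswith(m) for m in chosen)]
#     # keep == ["model", "checkpoint", "sintel-clean-val/epe", "sintel-clean-val/outlier"]
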
def save_plots(args: argparse.Namespace, df: pd.DataFrame) -> None:
    """Generate and save the plots to disk.

    Parameters
    ----------
    args : argparse.Namespace
        Arguments to control the plot.
    df : pd.DataFrame
        A DataFrame with the validation metrics.
    """
    if args.plot_ignore_models is not None:
        for name in args.plot_ignore_models:
            df = df[df[df.columns[0]] != name]

    # Pair up the two chosen metrics for each dataset.
    metric_pairs = {}
    for col in df.columns[2:]:
        for cmet in args.chosen_metrics:
            if col.endswith(cmet):
                dataset_name = "_".join(col.split("-")[:2])
                if metric_pairs.get(dataset_name) is None:
                    metric_pairs[dataset_name] = {}
                metric_pairs[dataset_name][cmet] = col

    ckpt_groups = ["all", "chairs", "kitti", "sintel", "things", "others"]
    assert ckpt_groups[-1] == "others"  # others must be the last element
    not_others = None
    for cgroup in ckpt_groups:
        group_df = df
        if cgroup == "all":
            color = group_df.columns[0]
            symbol = group_df.columns[1]
        elif cgroup == "others":
            group_df = df[~not_others]
            color = group_df.columns[0]
            symbol = group_df.columns[1]
        else:
            belong_to_group = df[df.columns[1]].str.contains(cgroup)
            if not_others is None:
                not_others = belong_to_group
            else:
                not_others = not_others | belong_to_group
            group_df = df[belong_to_group]
            color = group_df.columns[0]
            symbol = group_df.columns[0]

        for dataset_name, col_pair_dict in metric_pairs.items():
            col1, col2 = col_pair_dict.values()
            fig = px.scatter(
                group_df,
                x=col1,
                y=col2,
                color=color,
                symbol=symbol,
                title=(
                    f"{dataset_name} - {args.chosen_metrics[0]} x {args.chosen_metrics[1]}"
                    f" - checkpoint: {cgroup}"
                ),
            )
            fig.update_traces(
                marker={"size": 20, "line": {"width": 2, "color": "DarkSlateGrey"}},
                selector={"mode": "markers"},
            )
            fig.update_layout(title_font_size=30)
            file_name = f"{dataset_name}_{args.chosen_metrics[0]}_{args.chosen_metrics[1]}_{cgroup}"
            if args.drop_checkpoints is not None and len(args.drop_checkpoints) > 0:
                file_name += f'-drop_{"_".join(args.drop_checkpoints)}'
            fig.write_html(args.output_dir / (file_name + ".html"))
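
# A sketch of the grouping logic above, with hypothetical checkpoint names:
# the masks for the named dataset groups are OR-ed together as the loop runs,
# and "others" is plotted from their complement.
#
#     ckpts = pd.Series(["chairs", "things", "sintel-ft", "custom"])
#     not_others = None
#     for cgroup in ["chairs", "kitti", "sintel", "things"]:
#         mask = ckpts.str.contains(cgroup)
#         not_others = mask if not_others is None else (not_others | mask)
#     # ~not_others is True only for "custom", which lands in the "others" plot.
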
def summarize(args: argparse.Namespace) -> None:
    """Summarize the results and save them to disk.

    Parameters
    ----------
    args : argparse.Namespace
        Arguments required to control the process.
    """
    args.output_dir = Path(args.output_dir)
    df = load_summarized_table(args)
    df = _shorten_columns_names(df, len(args.chosen_metrics) > 1)
    if args.drop_checkpoints is not None and len(args.drop_checkpoints) > 0:
        # Drop rows whose checkpoint name contains any of the given substrings.
        ignore_idx = [
            i
            for i in df.index
            if any(c in df.loc[i, "checkpoint"] for c in args.drop_checkpoints)
        ]
        df = df.drop(ignore_idx)

    args.output_dir.mkdir(parents=True, exist_ok=True)
    file_name = f'summarized_metrics-{"_".join(args.chosen_metrics)}'
    if args.drop_checkpoints is not None and len(args.drop_checkpoints) > 0:
        file_name += f'-drop_{"_".join(args.drop_checkpoints)}'
    df.to_csv(args.output_dir / f"{file_name}.csv", index=False)
    with open(args.output_dir / f"{file_name}.md", "w") as f:
        df.to_markdown(f)

    if len(args.chosen_metrics) == 2:
        save_plots(args, df)
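
# summarize() can also be called programmatically; a minimal sketch, assuming
# a Namespace carrying the same fields that _init_parser() defines:
#
#     args = argparse.Namespace(
#         metrics_path="docs/source/results/metrics_all.csv",
#         chosen_metrics=["epe", "outlier"],
#         output_dir="outputs/metrics",
#         sort_by="model",
#         plot_ignore_models=None,
#         drop_checkpoints=None,
#     )
#     summarize(args)
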
def _shorten_columns_names(df: pd.DataFrame, keep_metric_name: bool) -> pd.DataFrame:
    # Shortens names such as "sintel-clean-val/epe" to "sintel-clean"
    # (or to "sintel-clean-epe" when keep_metric_name is True).
    change_dict = {}
    for col in df.columns:
        tokens = col.split("-")
        if len(tokens) > 1:
            metric_name = tokens[-1].split("/")[1]
            new_col_name = f"{tokens[0]}-{tokens[1]}"
            if keep_metric_name:
                new_col_name += f"-{metric_name}"
            change_dict[col] = new_col_name
    df = df.rename(columns=change_dict)
    return df


if __name__ == "__main__":
    parser = _init_parser()
    args = parser.parse_args()
    summarize(args)
    print(f"Results saved to {str(args.output_dir)}.")