"""Simple utilities to measure elapsed time."""
# =============================================================================
# Copyright 2021 Henrique Morimitsu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import logging
import time
from typing import Union, Tuple
try:
import torch
except NameError:
torch = None
[docs]
class Timer(object):
"""Utility to count the total elapsed time.
Every time toc() is called, the elapsed time since the last tic() is added to the total time. Call reset() to zero the
total time.
Attributes
----------
name : str
A string name that will be printed to identify this timer.
indent_level : int, default 0
The level of indentation when printing this timer. Useful for creating better prints. E.g., inner parts may have a
larger indentation.
Examples
--------
>>> t1 = Timer('parent_op', 0)
>>> t1.tic()
>>> ...
>>> t2 = Timer('inner_op', 1)
>>> t2.tic()
>>> ...
>>> t2.toc()
>>> t1.toc()
>>> print(t1)
parent_op: 2000.0 (2000.0) ms
>>> print(t2)
inner_op: 1000.0 (1000.0) ms
"""
[docs]
def __init__(self, name: str, indent_level: int = 0) -> None:
"""Initialize the Timer.
Parameters
----------
name : str
A string name that will be printed to identify this timer.
indent_level : int, optional
The level of indentation when printing this timer. Useful for creating better prints. E.g., inner parts may have a
larger indentation.
"""
self.name = name
self.indent_level = indent_level
self.reset()
self.num_tocs = 0
self.num_global_tocs = 0
[docs]
def reset(self) -> None:
"""Zero the total time counter."""
self.total_time = 0.0
[docs]
def tic(self) -> None:
"""Start to count the elapsed time."""
self.has_tic = True
if torch is not None and torch.cuda.is_available():
torch.cuda.synchronize()
self.start = time.perf_counter()
[docs]
def toc(self) -> None:
"""Count the elapsed time since the last tic() and add it to the total time."""
if torch is not None and torch.cuda.is_available():
torch.cuda.synchronize()
self.end = time.perf_counter()
assert self.has_tic, "toc called without tic"
if self.num_tocs > 0:
self.total_time += self.end - self.start
self.has_tic = False
self.num_tocs += 1
[docs]
def mean(self) -> float:
"""Return the average time (total time divided by the number of tocs).
Returns
-------
float
The average time in milliseconds.
"""
num_tocs = self.num_global_tocs if self.num_global_tocs > 0 else self.num_tocs
return self.total() / max(1, num_tocs - 1)
[docs]
def total(self) -> float:
"""Return the total time since the last reset().
Returns
-------
float
The total time in milliseconds.
"""
return self.total_time
def __repr__(self) -> str:
return f'{" "*self.indent_level}{self.name}: {1000 * self.total():.1f} ({1000 * self.mean():.1f}) ms'
def __str__(self) -> str:
return self.__repr__()
[docs]
class TimerManager(object):
"""Utility to handle multiple timers.
Timers can be accessed using a dict-like call (see Usage below). The timers can be either printed to the default output
or to a log file.
Attributes
----------
timers : dict[Union[str, tuple[str, int]], Timer]
The timers to be managed. The dict key can be either a single string representing the name of the timer, or a tuple
(name, indentation_level)
log_id : str, default 'timer'
A string representing the name of this manager.
log_path : str, default 'timer_log.txt'
Path to where the log file will be saved (if log is used).
logger : logging.Logger
A hander for the logger.
Examples
--------
>>> tm = TimerManager()
>>> tm['op1'].tic()
>>> ...
>>> tm['op1'].toc()
>>> # You may pass a tuple (str, int) as a key, which will be interpreted
>>> # as the (name, indent_level) for the timer (see Timer above):
>>> tm[('op2', 1)].tic()
>>> ...
>>> tm['op2'].toc()
>>> print(tm) # Prints all timers to default output
op1: 2000.0 (2000.0) ms
op2: 1000.0 (1000.0) ms
>>> tm.write_to_log('Some header message (optional)') # Write timers to a log file
See Also
--------
Timer : The timers that are managed.
"""
[docs]
def __init__(self, log_id: str = "timer", log_path: str = "timer_log.txt") -> None:
"""Initialize the TimerManager.
Parameters
----------
log_id : str, default 'timer'
A string representing the name of this manager.
log_path : str, default 'timer_log.txt'
Path to where the log file will be saved (if log is used).
"""
self.timers = {}
self.log_id = log_id
self.log_path = log_path
self.logger = None
self.num_global_tocs = 0
def global_toc(self):
self.num_global_tocs += 1
for t in self.timers.values():
t.num_global_tocs = self.num_global_tocs
[docs]
def clear(self) -> None:
"""Remove all timers."""
self.num_global_tocs = 0
self.timers = {}
[docs]
def reset(self) -> None:
"""Restart the total time counter of all timers."""
self.num_global_tocs = 0
for _, t in self.timers.items():
t.reset()
[docs]
def write_to_log(self, header: str = "") -> None:
"""Write the timers to the log file.
Parameters
----------
header : str
An optional string to be added to the top of the log file.
"""
if self.logger is None:
self._init_logger()
if len(header) > 0:
self.logger.info(header)
self.logger.info(self.__repr__())
def _init_logger(self) -> None:
self.logger = logging.getLogger(self.log_id)
self.logger.setLevel(logging.INFO)
fh = logging.FileHandler(self.log_path, mode="w")
fh.setLevel(logging.INFO)
self.logger.addHandler(fh)
def __getitem__(self, key: Union[str, Tuple[str, int]]) -> None:
indent_level = 0
if isinstance(key, tuple) or isinstance(key, list):
indent_level = key[1]
key = key[0]
if self.timers.get(key) is None:
self.timers[key] = Timer(key, indent_level)
return self.timers[key]
def __repr__(self) -> str:
ret = ""
for _, t in self.timers.items():
ret += t.__repr__() + "\n"
return ret
def __str__(self) -> str:
return self.__repr__()