Python Hook 设计模式
1. 什么是 Hook
Hook,中文常译为“钩子”,在 C/C++ 中通常对应“回调函数”的概念。钩子方法一般由抽象类或具体类声明并提供默认实现(通常为空实现),子类可根据需要重写或扩展。通过钩子,框架可以在预定时机调用用户自定义逻辑,而用户无需修改框架源码。
从功能角度看,这种机制称为“钩子”;从函数调用时机角度看,则可称为“回调”。在模板方法模式中,面向对象的多态性使得子类方法在运行时覆盖父类方法,子类通过实现钩子方法对父类行为进行约束,从而实现对父类执行流程的反向控制。
通俗理解,钩子是在已有方法上预留的扩展点:方法执行前或执行后,框架根据条件触发用户提供的钩子函数。用户提供数据与逻辑,框架负责流程执行。这样,框架的通用性得以大幅提升,而具体行为可在运行时才确定。
2. Hook 设计三要素
一个完整的 Hook 机制通常包含以下三个要素:
- Hook 函数或类:实现具体的自定义操作或功能。
- 注册(Register):只有经过注册的 Hook 才会被框架按顺序调用。
- 挂载点(Mount Point):由框架预先定义,用户无法修改,用于在特定时机触发 Hook。
3. 基础示例
Hook 是一种与语言无关的编程机制。下面通过一个 Python 示例帮助理解其工作原理。
import time
class LazyPerson:
def __init__(self, name):
self.name = name
self.watch_tv_func = None
self.have_dinner_func = None
def get_up(self):
print(f"{self.name} 起床了,时间:{time.time()}")
def go_to_sleep(self):
print(f"{self.name} 睡觉了,时间:{time.time()}")
def register_tv_hook(self, watch_tv_func):
self.watch_tv_func = watch_tv_func
def register_dinner_hook(self, have_dinner_func):
self.have_dinner_func = have_dinner_func
def enjoy_a_lazy_day(self):
self.get_up()
time.sleep(1)
# 若注册了看电视的 Hook,则执行;否则输出默认提示
if self.watch_tv_func is not None:
self.watch_tv_func(self.name)
else:
print("没有电视可看")
time.sleep(1)
# 若注册了吃饭的 Hook,则执行;否则输出默认提示
if self.have_dinner_func is not None:
self.have_dinner_func(self.name)
else:
print("晚餐没有东西吃")
time.sleep(1)
self.go_to_sleep()
def watch_daydayup(name):
print(f"{name}:节目《天天向上》很有意思!")
def watch_happyfamily(name):
print(f"{name}:节目《快乐大本营》有点无聊。")
def eat_meat(name):
print(f"{name}:这肉真香!")
def eat_hamburger(name):
print(f"{name}:汉堡也还不错。")
if __name__ == "__main__":
tom = LazyPerson("Tom")
jerry = LazyPerson("Jerry")
tom.register_tv_hook(watch_daydayup)
tom.register_dinner_hook(eat_meat)
jerry.register_tv_hook(watch_happyfamily)
jerry.register_dinner_hook(eat_hamburger)
tom.enjoy_a_lazy_day()
jerry.enjoy_a_lazy_day()运行结果:
Tom 起床了,时间:1598599060.6962798
Tom:节目《天天向上》很有意思!
Tom:这肉真香!
Tom 睡觉了,时间:1598599069.701241
Jerry 起床了,时间:1598599069.7012656
Jerry:节目《快乐大本营》有点无聊。
Jerry:汉堡也还不错。
Jerry 睡觉了,时间:1598599078.70979714. 实际案例:MMCV 中的 Hook 机制
下面以 MMCV(OpenMMLab 的基础库)中的 Runner 类为例,说明 Hook 机制在训练框架中的实际应用。
4.1 Hook 基类
Hook 是所有 Hook 类的基类,定义了训练流程中的各个挂载点,例如 before_run、before_epoch、after_epoch、after_run 等。用户通过继承 Hook 并实现感兴趣的方法,即可在指定时机插入自定义逻辑。
class Hook:
def before_run(self, runner):
pass
def after_run(self, runner):
pass
def before_epoch(self, runner):
pass
def after_epoch(self, runner):
pass
def before_iter(self, runner):
pass
def after_iter(self, runner):
pass
def before_train_epoch(self, runner):
self.before_epoch(runner)
def before_val_epoch(self, runner):
self.before_epoch(runner)
def after_train_epoch(self, runner):
self.after_epoch(runner)
def after_val_epoch(self, runner):
self.after_epoch(runner)
def before_train_iter(self, runner):
self.before_iter(runner)
def before_val_iter(self, runner):
self.before_iter(runner)
def after_train_iter(self, runner):
self.after_iter(runner)
def after_val_iter(self, runner):
self.after_iter(runner)
def every_n_epochs(self, runner, n):
return (runner.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, runner, n):
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
def every_n_iters(self, runner, n):
return (runner.iter + 1) % n == 0 if n > 0 else False
def end_of_epoch(self, runner):
return runner.inner_iter + 1 == len(runner.data_loader)Hook 基类将训练过程中可能需要额外操作的时间点进行了统一抽象,包括训练开始前后、每个 epoch 前后、每个 iteration 前后等,并进一步细分为 train 与 val 阶段(默认两阶段行为一致)。
4.2 Hook 的调用
Runner 内部维护一个有序的 Hook 列表 self._hooks,在每个挂载点通过 call_hook 方法依次调用所有 Hook 对应的方法:
def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)fn_name 为方法名字符串,利用 Python 内置函数 getattr 动态获取 Hook 对象中的方法引用。若用户未实现某方法,则默认调用基类中的空实现。
4.3 Hook 的实现示例:LrUpdaterHook
LrUpdaterHook 封装了学习率调整逻辑。它重写了 before_run、before_train_epoch、before_train_iter 等方法,在训练流程的特定节点修改优化器学习率。由于 get_lr 方法未在基类中实现,用户需要继承该类并提供具体的学习率衰减策略。
class LrUpdaterHook(Hook):
"""MMCV 中的学习率调度 Hook。
Args:
by_epoch (bool): 是否按 epoch 调整学习率。
warmup (str | None): Warmup 类型,可选 None、'constant'、'linear'、'exp'。
warmup_iters (int): Warmup 持续的 iteration 或 epoch 数。
warmup_ratio (float): Warmup 起始学习率比例,范围为 (0, 1]。
warmup_by_epoch (bool): 为 True 时 warmup_iters 表示 epoch 数,否则表示 iteration 数。
"""
def __init__(self,
by_epoch=True,
warmup=None,
warmup_iters=0,
warmup_ratio=0.1,
warmup_by_epoch=False):
if warmup is not None:
assert warmup in ['constant', 'linear', 'exp']
assert warmup_iters > 0
assert 0 < warmup_ratio <= 1.0
self.by_epoch = by_epoch
self.warmup = warmup
self.warmup_iters = warmup_iters
self.warmup_ratio = warmup_ratio
self.warmup_by_epoch = warmup_by_epoch
if self.warmup_by_epoch:
self.warmup_epochs = self.warmup_iters
self.warmup_iters = None
else:
self.warmup_epochs = None
self.base_lr = []
self.regular_lr = []
def _set_lr(self, runner, lr_groups):
if isinstance(runner.optimizer, dict):
for key, optim in runner.optimizer.items():
for param_group, lr in zip(optim.param_groups, lr_groups[key]):
param_group['lr'] = lr
else:
for param_group, lr in zip(runner.optimizer.param_groups, lr_groups):
param_group['lr'] = lr
def get_lr(self, runner, base_lr):
raise NotImplementedError
def get_regular_lr(self, runner):
if isinstance(runner.optimizer, dict):
lr_groups = {}
for key in runner.optimizer.keys():
lr_group = [
self.get_lr(runner, base_lr)
for base_lr in self.base_lr[key]
]
lr_groups[key] = lr_group
return lr_groups
else:
return [
self.get_lr(runner, base_lr)
for base_lr in self.base_lr
]
def get_warmup_lr(self, cur_iters):
if self.warmup == 'constant':
warmup_lr = [lr * self.warmup_ratio for lr in self.regular_lr]
elif self.warmup == 'linear':
k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
warmup_lr = [lr * (1 - k) for lr in self.regular_lr]
elif self.warmup == 'exp':
k = self.warmup_ratio ** (1 - cur_iters / self.warmup_iters)
warmup_lr = [lr * k for lr in self.regular_lr]
return warmup_lr
def before_run(self, runner):
if isinstance(runner.optimizer, dict):
self.base_lr = {}
for key, optim in runner.optimizer.items():
for group in optim.param_groups:
group.setdefault('initial_lr', group['lr'])
self.base_lr[key] = [group['initial_lr'] for group in optim.param_groups]
else:
for group in runner.optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
self.base_lr = [group['initial_lr'] for group in runner.optimizer.param_groups]
def before_train_epoch(self, runner):
if self.warmup_iters is None:
epoch_len = len(runner.data_loader)
self.warmup_iters = self.warmup_epochs * epoch_len
if not self.by_epoch:
return
self.regular_lr = self.get_regular_lr(runner)
self._set_lr(runner, self.regular_lr)
def before_train_iter(self, runner):
cur_iter = runner.iter
if not self.by_epoch:
self.regular_lr = self.get_regular_lr(runner)
if self.warmup is None or cur_iter >= self.warmup_iters:
self._set_lr(runner, self.regular_lr)
else:
warmup_lr = self.get_warmup_lr(cur_iter)
self._set_lr(runner, warmup_lr)
else:
if self.warmup is None or cur_iter > self.warmup_iters:
return
elif cur_iter == self.warmup_iters:
self._set_lr(runner, self.regular_lr)
else:
warmup_lr = self.get_warmup_lr(cur_iter)
self._set_lr(runner, warmup_lr)LrUpdaterHook 的核心价值在于:它只负责在正确的时间点修改学习率,而具体的衰减策略由用户自行实现。这种解耦使得训练流程更易于定制。
4.4 Hook 的注册
Hook 的注册过程较为简单,只需按优先级插入有序列表即可。相同优先级的 Hook 按注册顺序执行。
def register_hook(self, hook, priority='NORMAL'):
"""将 Hook 注册到 Hook 列表中。
Hook 会按优先级插入队列;数值越小优先级越高。
相同优先级的 Hook 按注册顺序触发。
Args:
hook (Hook): 待注册的 Hook 对象。
priority (int | str | Priority): Hook 优先级。
"""
assert isinstance(hook, Hook)
if hasattr(hook, 'priority'):
raise ValueError('"priority" is a reserved attribute for hooks')
priority = get_priority(priority)
hook.priority = priority
inserted = False
for i in range(len(self._hooks) - 1, -1, -1):
if priority >= self._hooks[i].priority:
self._hooks.insert(i + 1, hook)
inserted = True
break
if not inserted:
self._hooks.insert(0, hook)4.5 Runner 的训练流程
Runner.run 方法按照 workflow 循环执行训练与验证阶段,并在关键节点调用 Hook:
def run(self, data_loaders, workflow, max_epochs, **kwargs):
"""启动训练。
Args:
data_loaders (list[DataLoader]): 训练与验证的数据加载器。
workflow (list[tuple]): 每个 epoch 的执行阶段与次数,
例如 [('train', 2), ('val', 1)] 表示训练 2 个 epoch、验证 1 个 epoch。
max_epochs (int): 总训练 epoch 数。
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
self._max_epochs = max_epochs
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
self.call_hook('before_run')
while self.epoch < max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an epoch')
epoch_runner = getattr(self, mode)
for _ in range(epochs):
if mode == 'train' and self.epoch >= max_epochs:
break
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1)
self.call_hook('after_run')train 方法定义了一个 epoch 的具体执行流程:
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook('before_train_epoch')
time.sleep(2)
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
if self.batch_processor is None:
outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
else:
outputs = self.batch_processor(
self.model, data_batch, train_mode=True, **kwargs)
if not isinstance(outputs, dict):
raise TypeError(
'"batch_processor()" or "model.train_step()" must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1run 与 train 方法在完整训练循环、每个 epoch、每次 iteration 的前后分别触发对应的 Hook,从而让用户能够在不侵入主流程的情况下实现日志记录、学习率调整、模型保存等自定义功能。
5. 设计建议
- 每个 Hook 应只负责一个职责;若同一 Hook 需要在两个不同时机执行不同优先级的操作,建议拆分为两个 Hook。
- 自定义 Hook 所需的数据来自
Runner对象,而Runner中的数据由用户创建时传入,因此 Hook 只需要关注“执行者(runner)”与“执行时机(挂载点)”两个要素。