Skip to content

Python Hook 设计模式

1. 什么是 Hook

Hook,中文常译为“钩子”,在 C/C++ 中通常对应“回调函数”的概念。钩子方法一般由抽象类或具体类声明并提供默认实现(通常为空实现),子类可根据需要重写或扩展。通过钩子,框架可以在预定时机调用用户自定义逻辑,而用户无需修改框架源码。

从功能角度看,这种机制称为“钩子”;从函数调用时机角度看,则可称为“回调”。在模板方法模式中,面向对象的多态性使得子类方法在运行时覆盖父类方法,子类通过实现钩子方法对父类行为进行约束,从而实现对父类执行流程的反向控制。

通俗理解,钩子是在已有方法上预留的扩展点:方法执行前或执行后,框架根据条件触发用户提供的钩子函数。用户提供数据与逻辑,框架负责流程执行。这样,框架的通用性得以大幅提升,而具体行为可在运行时才确定。

2. Hook 设计三要素

一个完整的 Hook 机制通常包含以下三个要素:

  1. Hook 函数或类:实现具体的自定义操作或功能。
  2. 注册(Register):只有经过注册的 Hook 才会被框架按顺序调用。
  3. 挂载点(Mount Point):由框架预先定义,用户无法修改,用于在特定时机触发 Hook。

3. 基础示例

Hook 是一种与语言无关的编程机制。下面通过一个 Python 示例帮助理解其工作原理。

python
import time


class LazyPerson:
    def __init__(self, name):
        self.name = name
        self.watch_tv_func = None
        self.have_dinner_func = None

    def get_up(self):
        print(f"{self.name} 起床了,时间:{time.time()}")

    def go_to_sleep(self):
        print(f"{self.name} 睡觉了,时间:{time.time()}")

    def register_tv_hook(self, watch_tv_func):
        self.watch_tv_func = watch_tv_func

    def register_dinner_hook(self, have_dinner_func):
        self.have_dinner_func = have_dinner_func

    def enjoy_a_lazy_day(self):
        self.get_up()
        time.sleep(1)

        # 若注册了看电视的 Hook,则执行;否则输出默认提示
        if self.watch_tv_func is not None:
            self.watch_tv_func(self.name)
        else:
            print("没有电视可看")
        time.sleep(1)

        # 若注册了吃饭的 Hook,则执行;否则输出默认提示
        if self.have_dinner_func is not None:
            self.have_dinner_func(self.name)
        else:
            print("晚餐没有东西吃")
        time.sleep(1)

        self.go_to_sleep()


def watch_daydayup(name):
    print(f"{name}:节目《天天向上》很有意思!")


def watch_happyfamily(name):
    print(f"{name}:节目《快乐大本营》有点无聊。")


def eat_meat(name):
    print(f"{name}:这肉真香!")


def eat_hamburger(name):
    print(f"{name}:汉堡也还不错。")


if __name__ == "__main__":
    tom = LazyPerson("Tom")
    jerry = LazyPerson("Jerry")

    tom.register_tv_hook(watch_daydayup)
    tom.register_dinner_hook(eat_meat)

    jerry.register_tv_hook(watch_happyfamily)
    jerry.register_dinner_hook(eat_hamburger)

    tom.enjoy_a_lazy_day()
    jerry.enjoy_a_lazy_day()

运行结果:

text
Tom 起床了,时间:1598599060.6962798
Tom:节目《天天向上》很有意思!
Tom:这肉真香!
Tom 睡觉了,时间:1598599069.701241
Jerry 起床了,时间:1598599069.7012656
Jerry:节目《快乐大本营》有点无聊。
Jerry:汉堡也还不错。
Jerry 睡觉了,时间:1598599078.7097971

4. 实际案例:MMCV 中的 Hook 机制

下面以 MMCV(OpenMMLab 的基础库)中的 Runner 类为例,说明 Hook 机制在训练框架中的实际应用。

4.1 Hook 基类

Hook 是所有 Hook 类的基类,定义了训练流程中的各个挂载点,例如 before_runbefore_epochafter_epochafter_run 等。用户通过继承 Hook 并实现感兴趣的方法,即可在指定时机插入自定义逻辑。

python
class Hook:
    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)

Hook 基类将训练过程中可能需要额外操作的时间点进行了统一抽象,包括训练开始前后、每个 epoch 前后、每个 iteration 前后等,并进一步细分为 train 与 val 阶段(默认两阶段行为一致)。

4.2 Hook 的调用

Runner 内部维护一个有序的 Hook 列表 self._hooks,在每个挂载点通过 call_hook 方法依次调用所有 Hook 对应的方法:

python
def call_hook(self, fn_name):
    for hook in self._hooks:
        getattr(hook, fn_name)(self)

fn_name 为方法名字符串,利用 Python 内置函数 getattr 动态获取 Hook 对象中的方法引用。若用户未实现某方法,则默认调用基类中的空实现。

4.3 Hook 的实现示例:LrUpdaterHook

LrUpdaterHook 封装了学习率调整逻辑。它重写了 before_runbefore_train_epochbefore_train_iter 等方法,在训练流程的特定节点修改优化器学习率。由于 get_lr 方法未在基类中实现,用户需要继承该类并提供具体的学习率衰减策略。

python
class LrUpdaterHook(Hook):
    """MMCV 中的学习率调度 Hook。

    Args:
        by_epoch (bool): 是否按 epoch 调整学习率。
        warmup (str | None): Warmup 类型,可选 None、'constant'、'linear'、'exp'。
        warmup_iters (int): Warmup 持续的 iteration 或 epoch 数。
        warmup_ratio (float): Warmup 起始学习率比例,范围为 (0, 1]。
        warmup_by_epoch (bool): 为 True 时 warmup_iters 表示 epoch 数,否则表示 iteration 数。
    """

    def __init__(self,
                 by_epoch=True,
                 warmup=None,
                 warmup_iters=0,
                 warmup_ratio=0.1,
                 warmup_by_epoch=False):
        if warmup is not None:
            assert warmup in ['constant', 'linear', 'exp']
            assert warmup_iters > 0
            assert 0 < warmup_ratio <= 1.0

        self.by_epoch = by_epoch
        self.warmup = warmup
        self.warmup_iters = warmup_iters
        self.warmup_ratio = warmup_ratio
        self.warmup_by_epoch = warmup_by_epoch

        if self.warmup_by_epoch:
            self.warmup_epochs = self.warmup_iters
            self.warmup_iters = None
        else:
            self.warmup_epochs = None

        self.base_lr = []
        self.regular_lr = []

    def _set_lr(self, runner, lr_groups):
        if isinstance(runner.optimizer, dict):
            for key, optim in runner.optimizer.items():
                for param_group, lr in zip(optim.param_groups, lr_groups[key]):
                    param_group['lr'] = lr
        else:
            for param_group, lr in zip(runner.optimizer.param_groups, lr_groups):
                param_group['lr'] = lr

    def get_lr(self, runner, base_lr):
        raise NotImplementedError

    def get_regular_lr(self, runner):
        if isinstance(runner.optimizer, dict):
            lr_groups = {}
            for key in runner.optimizer.keys():
                lr_group = [
                    self.get_lr(runner, base_lr)
                    for base_lr in self.base_lr[key]
                ]
                lr_groups[key] = lr_group
            return lr_groups
        else:
            return [
                self.get_lr(runner, base_lr)
                for base_lr in self.base_lr
            ]

    def get_warmup_lr(self, cur_iters):
        if self.warmup == 'constant':
            warmup_lr = [lr * self.warmup_ratio for lr in self.regular_lr]
        elif self.warmup == 'linear':
            k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
            warmup_lr = [lr * (1 - k) for lr in self.regular_lr]
        elif self.warmup == 'exp':
            k = self.warmup_ratio ** (1 - cur_iters / self.warmup_iters)
            warmup_lr = [lr * k for lr in self.regular_lr]
        return warmup_lr

    def before_run(self, runner):
        if isinstance(runner.optimizer, dict):
            self.base_lr = {}
            for key, optim in runner.optimizer.items():
                for group in optim.param_groups:
                    group.setdefault('initial_lr', group['lr'])
                self.base_lr[key] = [group['initial_lr'] for group in optim.param_groups]
        else:
            for group in runner.optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
            self.base_lr = [group['initial_lr'] for group in runner.optimizer.param_groups]

    def before_train_epoch(self, runner):
        if self.warmup_iters is None:
            epoch_len = len(runner.data_loader)
            self.warmup_iters = self.warmup_epochs * epoch_len

        if not self.by_epoch:
            return

        self.regular_lr = self.get_regular_lr(runner)
        self._set_lr(runner, self.regular_lr)

    def before_train_iter(self, runner):
        cur_iter = runner.iter
        if not self.by_epoch:
            self.regular_lr = self.get_regular_lr(runner)
            if self.warmup is None or cur_iter >= self.warmup_iters:
                self._set_lr(runner, self.regular_lr)
            else:
                warmup_lr = self.get_warmup_lr(cur_iter)
                self._set_lr(runner, warmup_lr)
        else:
            if self.warmup is None or cur_iter > self.warmup_iters:
                return
            elif cur_iter == self.warmup_iters:
                self._set_lr(runner, self.regular_lr)
            else:
                warmup_lr = self.get_warmup_lr(cur_iter)
                self._set_lr(runner, warmup_lr)

LrUpdaterHook 的核心价值在于:它只负责在正确的时间点修改学习率,而具体的衰减策略由用户自行实现。这种解耦使得训练流程更易于定制。

4.4 Hook 的注册

Hook 的注册过程较为简单,只需按优先级插入有序列表即可。相同优先级的 Hook 按注册顺序执行。

python
def register_hook(self, hook, priority='NORMAL'):
    """将 Hook 注册到 Hook 列表中。

    Hook 会按优先级插入队列;数值越小优先级越高。
    相同优先级的 Hook 按注册顺序触发。

    Args:
        hook (Hook): 待注册的 Hook 对象。
        priority (int | str | Priority): Hook 优先级。
    """
    assert isinstance(hook, Hook)
    if hasattr(hook, 'priority'):
        raise ValueError('"priority" is a reserved attribute for hooks')

    priority = get_priority(priority)
    hook.priority = priority

    inserted = False
    for i in range(len(self._hooks) - 1, -1, -1):
        if priority >= self._hooks[i].priority:
            self._hooks.insert(i + 1, hook)
            inserted = True
            break

    if not inserted:
        self._hooks.insert(0, hook)

4.5 Runner 的训练流程

Runner.run 方法按照 workflow 循环执行训练与验证阶段,并在关键节点调用 Hook:

python
def run(self, data_loaders, workflow, max_epochs, **kwargs):
    """启动训练。

    Args:
        data_loaders (list[DataLoader]): 训练与验证的数据加载器。
        workflow (list[tuple]): 每个 epoch 的执行阶段与次数,
            例如 [('train', 2), ('val', 1)] 表示训练 2 个 epoch、验证 1 个 epoch。
        max_epochs (int): 总训练 epoch 数。
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)

    self._max_epochs = max_epochs
    for i, flow in enumerate(workflow):
        mode, epochs = flow
        if mode == 'train':
            self._max_iters = self._max_epochs * len(data_loaders[i])
            break

    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)

    self.call_hook('before_run')

    while self.epoch < max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if not hasattr(self, mode):
                raise ValueError(
                    f'runner has no method named "{mode}" to run an epoch')
            epoch_runner = getattr(self, mode)

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= max_epochs:
                    break
                epoch_runner(data_loaders[i], **kwargs)

    time.sleep(1)
    self.call_hook('after_run')

train 方法定义了一个 epoch 的具体执行流程:

python
def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)

    self.call_hook('before_train_epoch')
    time.sleep(2)

    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')

        if self.batch_processor is None:
            outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        else:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=True, **kwargs)

        if not isinstance(outputs, dict):
            raise TypeError(
                '"batch_processor()" or "model.train_step()" must return a dict')

        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])

        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1

runtrain 方法在完整训练循环、每个 epoch、每次 iteration 的前后分别触发对应的 Hook,从而让用户能够在不侵入主流程的情况下实现日志记录、学习率调整、模型保存等自定义功能。

5. 设计建议

  • 每个 Hook 应只负责一个职责;若同一 Hook 需要在两个不同时机执行不同优先级的操作,建议拆分为两个 Hook。
  • 自定义 Hook 所需的数据来自 Runner 对象,而 Runner 中的数据由用户创建时传入,因此 Hook 只需要关注“执行者(runner)”与“执行时机(挂载点)”两个要素。

Maintained by Robin