Пару дней назад коллега попросил сделать логгирующий сам себя итератор поверх enumerate. Я попробовал наследоваться напрямую и потерпел неудачу. Я абсолютно забыл как работает магический метод __new__. Поскольку я был занят, я пообещал себе разобраться с этой проблемой позже. А ларчик открывался очень просто. 18 строк кода и у меня появилась нужная функциональность.

Изначальная задача

Сначала стоит объяснить зачем нам вообще понадобился подобный итератор, почему нам не хватило обычного enumerate. Все дело в том, что у нас в проекте очень много задач, построенных по такому шаблону:

  1. Получить пачку объектов из базы
  2. Написать в лог сколько объектов получили
  3. С каждым полученным объектом сделать что-либо, отчитываясь о ходе работы после каждого Х-го объекта
  4. Написать в лог о завершении задачи

В питоне это выглядит как-то так:

iterable = get_bunch()
total = len(iterable)

print("total: {}".format(total))

for i, item in enumerate(iterable, start=1):
    try:
        func(item)
    except Exception as e:
        print("catch exception: {}".format(e))

    if not i % 100:
        print("done {} of {}".format(i, total))

print("Done!")

Вся разница между этими задачами в теле функции func, да в сообщениях в лог. Вся эта структура копипастилась раз за разом. Так что мы решили избавиться от этого сделав свой итератор, который бы прятал все лишнее.

Реализация

Ок. Давайте сделаем класс, который будет наследником enumerate. Как я говорил выше, нам придется переопределить метод __new__, так как enumerate делает это. Согласно документации, если __new__() возвращает истанс класса, тогда метод __init__() нового инстанса будет вызываться с теми же аргументами. Так что у меня получилась такая реализация:

class LogEnumerate(enumerate):
    def __new__(cls, iterable, start=1, *args, **kwargs):
        return super(LogEnumerate, cls).__new__(cls, iterable, start)
    def __init__(self, iterable, start=1, step=10,
                 start_message='', progress_message='', stop_message=''):
        self.progress_message = progress_message
        self.stop_message = stop_message
        self.step = step
        self.total = len(iterable)
        print(start_message.format(start_message))
    def __next__(self):
        try:
            i, item = super().__next__()
            if not i % self.step:
                print(self.progress_message.format(i, self.total))
            return item
        except StopIteration:
            print(self.stop_message)
            raise


c