Jamal的博客

Python-函数装饰器和闭包

装饰器是Python的一个重要特性,Python本身提供了一些装饰器property,classmethod等,第三方组件如Django等使用装饰器管理缓存和权限。

基础知识

装饰器是可调用对象(实现了call方法),其参数是另一个函数(被装饰的函数)。装饰器可能会处理被装饰的函数然后将其返回,或者是直接将其替换成另一个函数或者是可调用对象。
举个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
>>> def deco(func):
... def inner():
... print('running inner')
... return inner
...
>>> @deco
... def target():
... print('running target')
...
>>> target()
running inner
>>> target
<function deco.<locals>.inner at 0x1017ae268>
>>> deco(target)
<function deco.<locals>.inner at 0x1017ae1e0>

这两个函数实现的功能和最终的结果是一致的,但是这两个代码执行完毕后返回的函数不是target,而是target对inner的引用,同时我们发现使用@deco和deco(target)的返回值是一致的。
总结,装饰器有两个特性:

  • 能把被装饰的函数替换成其他函数
  • 装饰器在加载模块时立即执行

装饰器何时执行

前面我们说过,装饰器的一大特性是在加载模块是立即执行,也就是说在Python加载模块的时候就开始执行,如下代码所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
registry = []
def register(func):
print('running register({})'.format(func))
registry.append(func)
return func
@register
def f1():
print('running f1()'.format())
@register
def f2():
print('running f2()'.format())
def f3():
print('running f3()'.format())
def main():
print('running main()'.format())
print('register -> {}'.format(registry))
f1()
f2()
f3()
if __name__ == '__main__':
main()

执行结果:

1
2
3
4
5
6
7
running register(<function f1 at 0x101716d08>)
running register(<function f2 at 0x101716d90>)
running main()
register -> [<function f1 at 0x101716d08>, <function f2 at 0x101716d90>]
running f1()
running f2()
running f3()

执行结果中,在模块加载的时候register执行两次,然后开始执行main,模块加载完成后registry中有两个被装饰函数的引用,这两个函数只有main明确调用他们的时候才执行。如果在其他模块中导入这个模块,如下所示:

1
2
running register(<function f1 at 0x101716d08>)
running register(<function f2 at 0x101716d90>)

发现装饰器在导入模块的时候立即执行,而被装饰的函数一直要到调用的时候才执行。
这个示例中装饰器返回的函数和通过参数传入的相同,前面我们说过装饰器能把被装饰的函数替换成其他函数,而实际上,大多数装饰器会在内部定义一个函数,然后将其返回

使用装饰器改进策略模式

Python设计模式-策略模式(Strategy pattern)这篇文档中,我们提到最大的问题在于best_promo判断哪个折扣最大的时候promos列表中有全量的策略函数名称,这就可能导致我们在新增策略函数的之后忘记策略函数添加到promos列表中,并且系统可以正常运行,为系统引入了不易察觉的缺陷。下面我们使用装饰器解决这个问题,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
promos = []
def promotion(promo_func):
promos.append(promo_func)
return promo_func
@promotion
def fidelity_promo(order):
"""为积分1000以上的提供5%折扣"""
return order.total() * 0.05 if order.customer.fidelity >= 1000 else 0
@promotion
def bulk_item_promo(order):
"""单个商品为20个或以上时提供10%折扣"""
discount = 0
for item in order.cart:
if item.quantity >= 20:
discount += item.total() * 0.1
return discount
@promotion
def large_order_promo(order):
"""订单中的不同商品达到10个或以上,享受7%折扣"""
discount_item = {item.product for item in order.cart}
if len(discount_item) >= 10:
return order.total() * 0.07
return 0
def best_promo(order):
"""选择最佳折扣"""
return max(promo(order) for promo in promos)

使用装饰器装饰后,有以下优点:

  • promo函数无需使用特殊的名称
  • @promotion装饰器突出了被装饰的函数的作用,还便于临时禁用某个促销策略:只需要把装饰器注释掉
  • 促销折扣策略可以在其他模块中定义,只要有@promotion装饰即可

变量作用域和闭包

前面我们说到,在多数装饰器中我们会定义一个内部函数,然后将其返回,替换被装饰的函数。使用内部函数的代码几乎都要依靠闭包才能正常运行。想要理解闭包,就必须先理解变量作用域。
我们看一下下面的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
>>> def f1(a):
... print(a)
... print(b)
...
>>> f1(3)
3
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 3, in f1
NameError: name 'b' is not defined
>>> b = 6
>>> f1(3)
3
6
>>> def f2(a):
... print(a)
... print(b)
... b = 9
...
>>> f2(3)
3
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 3, in f2
UnboundLocalError: local variable 'b' referenced before assignment

f1的报错,是由于在f1中没有给全局变量b赋值导致报错,在赋值了全局变量b之后就没问题了。但是在f2中,即使给全局变量b赋值了,依然报错了。
这是因为python在编译f2的时候,会先判断b在函数中是局部变量,因为在函数中给b赋值了,因此在后面调用的时候,f2会尝试获取局部变量b的值,但是发现b并没有绑定值。
如果想要解释器把b当成是全局变量,就需要使用global声明。
了解了作用域之后,下面就可以开始讨论闭包了。
闭包指延伸了作用域的函数,其中包含函数定义体中引用,但是不在定义体重定义的非全局变量,它能够访问定义体之外定义的非全局变量。
我们看下面的一个示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
>>> def make_avearger():
... series = []
... def averager(new_value):
... series.append(new_value)
... total = sum(series)
... return total/len(series)
... return averager
...
>>> avg = make_averager()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NameError: name 'make_averager' is not defined
>>> avg = make_avearger()
>>> avg(10)
10.0
>>> avg(11)
10.5
>>> avg(12)
11.0
>>> avg.__code__
<code object averager at 0x1017946f0, file "<stdin>", line 3>
>>> avg.__code__.co_varnames
('new_value', 'total')
>>> avg.__code__.co_freevars
('series',)
>>> avg.__closure__
(<cell at 0x1010aa0a8: list object at 0x1017b1fc8>,)
>>> avg.__closure__[0]
<cell at 0x1010aa0a8: list object at 0x1017b1fc8>
>>> avg.__closure__[0].cell_contents
[10, 11, 12]

make_avearger是一个高阶函数,因为在调用make_avearger时,他返回了一个averager函数对象(回顾高阶函数的定义)。每次调用averager的时候,他会把参数添加到series中,然后计算当前的平均值。
这里有一个点很关键,series是make_avearger函数的局部变量,但是在调用avg(10)的时候,make_avearger已经返回了,而他的本地作用域已经没有了,那么,avg是在哪里寻找series的呢?
其实在averager函数中,series有一个术语名称,叫自由变量(free variable),指未在本地作用域中绑定的变量,如图所示:

averager的闭包延伸到了本身的作用域之外,包含自有变量series的绑定,上面的代码中我们查看averager的code属性中保存局部变量和自由变量的名称,发现newvalue、total是函数的局部变量,series是函数的自由变量。
series的绑定在返回的avg函数的closure属性中,avg.closure中的各个元素对应于averager.code.co_freevars中的一个名称。这些元素是cell对象,有cell_contents属性,保存着真正的值,如上述的代码所示。
综上,闭包是一种函数,他会保留定义函数时存在的自由变量的绑定,这样调用函数时,虽然作用域不可用了,但仍然能使用那些绑定(注意,只有嵌套在其他函数中的函数才可能需要处理不在全局作用域中的外部变量)。
上一个示例中我们实现的make_averager效率不高,我们改进一下,只存储目前的总值和元素的个数,然后计算平均值
如代码所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
>>> def make_averager():
... count = 0
... total = 0
... def averager(new_value):
... count += 1
... total += new_value
... return total / count
... return averager
...
>>> avg = make_averager()
>>> avg(10)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 5, in averager
UnboundLocalError: local variable 'count' referenced before assignment

我们发现这个代码在执行的时候报错了,这是因为当count是数字或者是任何不可变类型时,count += 1语句的实际作用其实与count= count + 1一样。因此,我们在averager的定义体中为count赋值了,这会把count变成局部变量,total也有一样的问题。
更上面的代码没有这个问题,因为我们只是调用了他的append方法,而没有给series赋值,也就是说,我们在这里利用了list是可变对象的特性。
但是对数字,字符串,元组等不可变对象来说,就只能读取不能更新了。如果尝试重新绑定,例如count += 1的话,其实会隐式的创建局部变量count。这样,count就不是自由变量了,因此不会保存在闭包中。这时候就需要引入nolocal声明。他的作用是把变量标记为自由变量,即使在函数中为变量赋予新值了,也会变成自由变量,如果赋予了新值,那闭包中保存的绑定会更新。如下代码所示:

1
2
3
4
5
6
7
8
9
>>> def make_average():
... count = 0
... total = 0
... def averager(new_value):
... nonlocal count, total
... count += 1
... total += new_value
... return total / count
... return averager

实现一个简单的装饰器

我们实现一个简单的装饰器,他会在每次调用被装饰的函数时计时,然后把耗时、传入的参数和调用结果打印出来,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import time
def clock(func):
def clocked(*args):
t0 = time.perf_counter()
result = func(*args)
elapsed = time.perf_counter() - t0
name = func.__name__
arg_str = ', '.join(repr(arg) for arg in args)
print('[%0.8fs] %s(%s) -> %r' % (elapsed, name, arg_str, result))
return result
return clocked
@clock
def snooze(seconds):
time.sleep(seconds)
@clock
def factorial(n):
return 1 if n < 2 else n * factorial(n - 1)
if __name__ == '__main__':
print('*' * 40, 'Calling snooze(.123)')
snooze(.123)
print('*' * 40, 'Calling factorial(6)')
print('6! =', factorial(6))

clock的内部函数clocked接受任意个定位参数,其中result = func(*args)中func为自由变量。示例运行结果如下:

1
2
3
4
5
6
7
8
9
10
**************************************** Calling snooze(.123)
[0.12693679s] snooze(0.123) -> None
**************************************** Calling factorial(6)
[0.00000312s] factorial(1) -> 1
[0.00003997s] factorial(2) -> 2
[0.00005664s] factorial(3) -> 6
[0.00007261s] factorial(4) -> 24
[0.00008674s] factorial(5) -> 120
[0.00010699s] factorial(6) -> 720
6! = 720

在这个示例中,

1
2
@clock
def factorial(n):

等价于:

1
factorial = clock(factorial)

factorial会把func作为参数传递给clock,clock会返回clocked函数,python会把clocked负值给factorial。我们在另一个文件中导入factorial后查看其name属性,会得到结果是clocked。factorial现在保存的是clocked的引用。
在这个代码中,clock装饰器不支持关键字参数,而且遮盖了被装饰函数的namedoc属性,我们使用functools.wraps装饰器修改一下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import time
import functools
def clock(func):
@functools.wraps(func)
def clocked(*args, **kwargs):
t0 = time.time()
result = func(*args, **kwargs)
elapsed = time.time() - t0
name = func.__name__
arg_lst = []
if args:
arg_lst.append(', '.join(repr(arg) for arg in args))
if kwargs:
pairs = ['%s=%r' % (k, w) for k, w in sorted(kwargs.items())]
arg_lst.append(', '.join(pairs))
arg_str = ', '.join(arg_lst)
print('[%0.8fs] %s(%s) -> %r' % (elapsed, name, arg_str, result))
return result
return clocked