上手Cython,这一篇就足够了

Page content

上手Cython,这一篇就足够了

写在前面

An Introduction to Just Enough Cython to be Useful | Peter Baumgartner

本文部分翻译自上文。

本文以Day 1 - Advent of Code 2021为切入问题。问题内容是计算一个整形列表的深度,即前一个元素小于后一个的次数。纯python代码如下:

# solution.py 
from typing import List

def count_increases(depths: List[int]) -> int:
    current_depth = depths[0]
    increase_counter: int = 0
    for depth in depths[1:]:
        if depth > current_depth:
            increase_counter += 1
        current_depth = depth
    return increase_counter

对Cython的一个简单解释:

  1. 如果你愿意添加编译步骤,你的python代码速度会提升2-3倍
  2. 如果做到了1,并且在你的代码里面设置函数和变量的类型,代码速度会提示10x倍
  3. 如果你同时做到1和2,并且花时间想你的代码和一些计算机科学的问题,你可以提升50x或者更多的加速效果

开始

安装cython

  1. pip install cython
  2. 复制一份你的代码,并重命名为.pyx后缀
  3. 创建编译配置文件 setup.py
  4. 编译

对于步骤3:

# setup.py

from distutils.core import setup
from Cython.Build import cythonize

setup(
	ext_modules=cythonize(
    	"solution.pyx", compiler_directives={"language_level": "3"}
    )
)

创建好上述文件之后,使用如下命令进行编译:

python setup.py build_ext --inplace

运行之后会在目录下生成 .so文件,此时我们可以直接在别的python脚本中 import solution来导入这个文件,solution.count_increase(input_depths)来调用函数。

以上过程加速了python代码2.22x倍。

windows环境中编译后生成了test.cp39-win_amd64.pyd文件。测试发现,只要保证这个文存在,就可以导入编写好的模块。test.c test.py可以不携带。


使用C类型

Cython中所有可用的类型:Language Basics — Cython 3.0.0a10 documentation

# solution.pyx

cpdef int count_increases_cy(list depths):
    cdef int increase_counter, current_depth, depth
    current_depth = depths[0]
    increase_counter = 0
    for depth in depths[1:]:
        if depth > current_depth:
            increase_counter += 1
        current_depth = depth
    return increase_counter

cdef用于使用python语法定义c函数,给函数创建一个python包装器,可以被python调用。cdef定义的函数,只能被当前c文件内的函数调用,不能被外部python调用。亦可用于定义局部变量。

int,申明函数的返回类型,注意这块定义类型的顺序是和c语言一样的,类型在前,变量名/函数名在后。而python中通常是反过来的。

函数第一行定义了要用到的变量。虽然这块也声明了depth,但是输入的类型depths是列表,所以具体到列表里面是什么类型是不知道的,只知道是python对象的集合。

文章称上面代码加速了11倍。

我这里用如下代码测试:

 import test
import timeit
from typing import List

def count_increases(depths: List[int]) -> int:
    current_depth = depths[0]
    increase_counter: int = 0
    for depth in depths[1:]:
        if depth > current_depth:
            increase_counter += 1
        current_depth = depth
    return increase_counter

if __name__=="__main__":
    a = timeit.timeit("test.count_increases_cy([1,2,2,3,4,5,3,2])", setup="from __main__ import test") 
    b = timeit.timeit("count_increases([1,2,2,3,4,5,3,2])", setup="from __main__ import count_increases") 
    print(a, b)
    print(b/a)
    
# >python imss.py
# 0.17307470000000003 0.5409483
# 3.125519212224548

Cython性能调优

编译cython的时候,调用cythonize时,可以选择参数annotate = True,编译过程中会生成一个可视化的HTML报告,黄色高亮的部分是代码和python交互的部分,即比较慢的部分。

示例代码中,由于传入的列表是python对象,所以影响代码速度的原因出在了传入的列表这块。

image-20220316095445949


最终的优化

numpy中的Arrays特点可以解决速度慢的问题,arrays的元素定长且同类型。Cython中直接使用也很方便。在使用numpy的时候,也可以使用memoryview的东西,它是numpy的内存视图,这样在访问元素的时候,就不需要创建矩阵的副本。

修改setup.py文件导入numpy

# setup.py

from distutils.core import setup
from Cython.Build import cythonize
import numpy

setup(
    ext_modules=cythonize(
        "solution_a_cy.pyx", compiler_directives={"language_level": "3"}, annotate=True
    ),
    include_dirs=[numpy.get_include()],
)

下面对代码进行修改:

# solution.pyx

cpdef int count_increases_cy_array(int[:] depths):
    cdef int increase_counter, current_depth, depth, length, i
    length = depths.shape[0]
    current_depth = depths[0]
    increase_counter = 0
    for i in range(1, length):
        if depths[i] > current_depth:
            increase_counter += 1
        current_depth = depths[i]
    return increase_counter

改动说明:

第一个区别是传入的参数类型是int[:],它是数组的memoryview的语法。其实本质传入的是一个numpy数组。创建了一些额外的变量,length和i用于数组循环和索引。

import test
import timeit
import numpy
from typing import List

def count_increases(depths: List[int]) -> int:
    current_depth = depths[0]
    increase_counter: int = 0
    for depth in depths[1:]:
        if depth > current_depth:
            increase_counter += 1
        current_depth = depth
    return increase_counter

if __name__=="__main__":
    foo = """\
from __main__ import test
import numpy
test.count_increases_cy_array(numpy.array([1,2,2,3,4,5,3,2]))
"""
    # 注意上面的缩进
    a = timeit.timeit(stmt=foo) 
    b = timeit.timeit("count_increases([1,2,2,3,4,5,3,2])", setup="from __main__ import count_increases") 
    print(a, b)
    print(a/b)
    
# 3.2162735 0.5433743
# 5.91907548811197

总结

并不是所有的代码都需要进行这样的优化,一来我们需要在可维护性和性能之间做平衡,二来实际项目只需要对调用最频繁的部分进行这种优化。

一般只在以下这些情况使用cython:

  1. 已经充分评估代码,确定了较慢的部分
  2. 这部分代码主要使用的是python内置类型和numpy数组
  3. 代码简单,不需要大量思考如何cython化
  4. 代码未来不会经常被改动