⚠ 转载请注明出处:作者:ZobinHuang,更新日期:July 11 2022
本作品由 ZobinHuang 采用 知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议 进行许可,在进行使用或分享前请查看权限要求。若发现侵权行为,会采取法律手段维护作者正当合法权益,谢谢配合。
目录
有特定需要的内容直接跳转到相关章节查看即可。
前言
在 DGL 官方给出的为数不多的开发者文档 dgl_ffi 中,里面说到他们在大部分 DGL 的代码中使用 python 予以相关逻辑的实现,在性能关键的地方就是用 C++ 予以实现。因此在代码中将经常遇到需要使用 Python 代码对使用 C++ 代码定义的函数进行调用的情况。这篇文章我们将分析 DGL 实现这层机制的底层原理。
DGL 所使用的
TVM FFI 机制
TVM FFI 机制本质上是使用了 Python 提供的 ctypes
模块来调用 C++ 代码提供的 ctypes
是 Python 内建的可以用于调用 ctypes
模块来加载动态链接库中的函数。在 ctypes
的官方文档中+是这么介绍 ctypes
的:
下面我们首先从 C++ 侧来看 TVM 如何定义具体的函数功能接口,然后再来看 TVM 如何在 Python 侧实现对这些函数的加载和调用。
C++ 底层数据结构支撑
在 C++ 侧,PackedFunc
这个类用于描述和存储作为 FFI 的 C++ 函数,它是 Python 和 C++ 互相调用的桥梁。这个类是在 include/tvm/runtime/packed_func.h
tvmsrc_packedfunc_h 中定义的。在阅读这个类的代码之前,我们首先对该类所使用到的关键数据结构进行分析,然后再回过头来分析 PackedFunc
。

PackedFunc
所使用的数据的封装,它们的功能如下所示:
TVMValue
: 对基本数据类型进行统一;TVMPODValue_
: 封装了对数据类型进行强制类型转换的运算符;TVMArgValue
: 用于封装传给PackedFunc
的 一个 参数;TVMArgs
: 用于封装传给PackedFunc
的 所有 参数;TVMRetValue
: 用于作为存放调用PackedFunc
返回值的容器;
下面我们分别对它们进行分析。
TVMValue
1 | /*! |
首先是最底层的 TVMValue
,其是在 tvm/include/tvm/runtime/c_runtime_api.h
tvmsrc_c_runtime_api_h 中定义的,是一个 union。它的功能主要是为了封装 C++ 和其它语言交互时所支持的几种类型的数据。
TVMPODValue_
1 | /*! |
然后是 TVMPODValue_
,其是在 include/tvm/runtime/packed_func.h
tvmsrc_packedfunc_h 中定义的。这个类的实现核心是「强制类型转换运算符」重载。在 C++中,类型的名字,包括类的名字本身也是一种运算符,即类型强制转换运算符。如上面摘抄的定义所示,该类重载了「强制类型转换运算符」,可以实现对 TVMValue
存储的数据的类型转换。
TVMArgValue
1 | /*! |
然后是 TVMArgValue
,其是在 include/tvm/runtime/packed_func.h
tvmsrc_packedfunc_h 中定义的。这个类继承自前面介绍的 TVMPODValue_
类,用作表示 PackedFunc
的一个参数,它和 TVMPODValue_
的区别是扩充了一些数据类型的支持,比如 string
、PackedFunc
、TypedPackedFunc
等,其中对后两个的支持是在 C++ 代码中能够调用 Python 函数的根本原因。这个类只使用所保存的 underlying data,而不会去做释放。
TVMArgs
1 | /*! \brief Arguments into TVM functions. */ |
然后是 TVMArgs
,它是在 include/tvm/runtime/packed_func.h
tvmsrc_packedfunc_h 中定义的。这个类主要是为了封装传给代表着 C++ FFI 的 PackedFunc
的所有参数,这个类也比较简单原始,主要基于 TVMValue
、参数类型编码、参数个数来实现。值得注意的是 TVMArgs
使用的是 TVMValue
数组来存储传入的所有参数,而对于迭代访问运算符 []
,其返回的又是 TVMArgValue
类型的值。因此上面介绍的 TVMArgValue
实际上并不存储传入的参数本身,而用于对传入的参数进行处理使用。
TVMRetValue
1 | class TVMRetValue : public TVMPODValue_ { |
最后是 TVMRetValue
,它是在 include/tvm/runtime/packed_func.h
tvmsrc_packedfunc_h 中定义的。这个类也是继承自 TVMPODValue_
类,主要作用是作为存放调用 PackedFunc
返回值的容器,它和 TVMArgValue
的区别是,它会管理所保存的 underlying data,会对其做释放。这个类主要由四部分构成:
- 构造和析构函数;
- 对强制类型转换运算符重载的扩展;
- 对赋值运算符的重载;
- 辅助函数,包括释放资源的
Clear
函数
TVMArgsSetter
1 | class TVMArgsSetter { |
另外,TVMArgsSetter
是一个用于给 TVMValue
对象赋值的辅助类,主要通过重载函数调用运算符来实现。
PackedFunc
的具体实现
1 | /*! |
有了前面所述的数据结构作为基础,再来看 PackedFunc
的实现,它是在 include/tvm/runtime/packed_func.h
tvmsrc_packedfunc_h 中定义的。PackedFunc
代表了用户自定义的一个 C++ 函数,最终将被编译和注册到动态链接库中,以供 Python 程序调用。PackedFunc
使用了一个储存函数指针的变量 data_
(p.s. 继承自 ObjectRef
),再通过重载函数调用运算符 ()
来调用这个函数指针所指向的函数。重载 ()
运算符的定义是在同个文件下实现的,定义如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21template <typename... Args>
inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
// 初始化两个数组存放,用于存放转化用户传入的参数后的若干 TVMValue
// 以及它们的 type_codes
TVMValue values[kArraySize];
int type_codes[kArraySize];
// 使用 TVMArgsSetter 将用户传入的参数初始化为 TVMValue,放入数组中
detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
TVMRetValue rv;
// 调用在 data_ 中存储的函数
(static_cast<PackedFuncObj*>(data_.get()))
->CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv);
// 返回由 TVMRetValue rv 存储的函数返回值
return rv;
}
上面代码的相关注释以及解释了整个调用流程。我们可以发现在调用的时候,TVM 先把 转化为 PackedFuncObj
类实例,然后再调用其 CallPacked
函数以完成对函数的执行。
基于 PackedFunc
的函数注册机

在完成对 PackedFunc
的理解后,下面我们对代码是如何将各个 PackedFunc
实例注册到动态链接库 (以供 Python 程序调用) 进行分析。
TVM_REGISTER_GLOBAL
宏
注册过程中最常使用的宏是 TVM_REGISTER_GLOBAL
,调用这个宏将返回一个 Registry
对象,基于这个对象,我们可以调用诸如 set_body
等注册接口,将函数进行注册。下面是一个例子:
1
2
3
4
5// src/topi/nn.cc
TVM_REGISTER_GLOBAL("topi.nn.relu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = relu<float>(args[0]);
});
值得注意的是,set_body
等注册接口可以接收普通函数,也可以接收 Lambda 表达式 cpp_lambda,如上所示的例子接收的就是 Lambda 表达式。
下面我们对 TVM_REGISTER_GLOBAL
宏定义进行分析。这个宏是在 include/tvm/runtime/registry.h
tvmsrc_registry_h 中定义的。如下所示:
1
2
3
4
5
6
7
8
// 辅助 macro
可以发现,TVM_REGISTER_GLOBAL
实际上就是基于它接收的参数 OpName
,在调用宏的位置定义了一个 static
的引用变量,引用到注册机内部 new
出来的一个新的 Registry
对象。上面的代码中使用到了 __COUNTER
编译器扩展宏,GCC 的文档里对该宏的描述如下所示:
将上面的宏进行展开,实际上就是在调用宏的地方,运行了这么一句代码:
1
2
3static TVM_ATTRIBUTE_UNUSED \
::tvm::runtime::Registry& __mk_##TVM[编译器扩展宏 __COUNTER__] =
::tvm::runtime::Registry::Register(OpName);
Registry
类
下面我们对 Registry
这个关键的类进行分析,它是在 include/tvm/runtime/registry.h
tvmsrc_registry_h 中被定义的,这个类的每一个实例代表着被注册的每一个函数。Registry
有一个友元结构体 Manager
,它用于存储全局所有的 Registry
实例。我们首先对它进行分析,下面是它在 src/runtime/registry.cc
tvmsrc_registry_cc 中的具体定义:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20struct Registry::Manager {
// map storing the functions.
// We deliberately used raw pointer.
// This is because PackedFunc can contain callbacks into the host language (Python) and the
// resource can become invalid because of indeterministic order of destruction and forking.
// The resources will only be recycled during program exit.
std::unordered_map<std::string, Registry*> fmap;
// mutex
std::mutex mutex;
Manager() {}
static Manager* Global() {
// We deliberately leak the Manager instance, to avoid leak sanitizers
// complaining about the entries in Manager::fmap being leaked at program
// exit.
static Manager* inst = new Manager();
return inst;
}
};
从上面定义的 Global
函数可以发现,友元结构体 Manager
实际上是一种单例的设计模式 cpp_single_instance,因为被声明为 static
的局部变量只会在第一次被访问的时候被初始化 cpp_new_static_local,上面 Global
函数中的 inst
就是这样一种变量。全局的其他函数可以通过结构体 Manager
的 Global
函数访问到全局唯一的一个 Manager
实例。
另外,我们还可以发现 Manager
使用的是 std::unordered_map
来存储注册的信息,注册的对象是 Registry
指针。
最后我们来到了 Register
这个类,这个类是在 include/tvm/runtime/registry.h
tvmsrc_registry_h 中被定义的。相关定义摘抄如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15class Registry {
public:
Registry& set_body(PackedFunc f);
Registry& set_body_typed(FLambda f);
Registry& set_body_method(R (T::*f)(Args...));
static Registry& Register(const std::string& name);
static const PackedFunc* Get(const std::string& name);
static std::vector ListNames();
protected:
std::string name_;
PackedFunc func_;
friend struct Manager;
};
下面进行分析:
-
首先
Registry
类首先提供了若干个set_body
注册接口,这些接口将会基于用户传入的PackedFunc
初始化Registry
实例,具体定义代码如下所示:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35// include/tvm/runtime/registry.h
class Registry {
public:
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
TVM_DLL Registry& set_body(PackedFunc f); // NOLINT(*)
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
template <typename TCallable,
typename = typename std::enable_if_t<
std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
!std::is_base_of<PackedFunc, TCallable>::value>>
Registry& set_body(TCallable f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
/* ... */
protected:
/*! \brief internal packed function */
PackedFunc func_;
/* ... */
}
// include/tvm/runtime/registry.cc
Registry& Registry::set_body(PackedFunc f) { // NOLINT(*)
func_ = f;
return *this;
}可以发现
set_body
等接口实际上就是把用户传入的函数存储到了Registry
的私有变量func_
中。 -
Registry
类还提供了用于创建Registry
对象的Register
静态接口,这个接口具体是在tvm/src/runtime/registry.cc
tvmsrc_registry_cc 中被定义的,具体代码如下所示:1
2
3
4
5
6
7
8
9
10
11
12Registry& Registry::Register(const std::string& name, bool can_override) { // NOLINT(*)
Manager* m = Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex);
if (m->fmap.count(name)) {
ICHECK(can_override) << "Global PackedFunc " << name << " is already registered";
}
Registry* r = new Registry();
r->name_ = name;
m->fmap[name] = r;
return *r;
}可以发现
Register
函数实际上就是实例化出了一个Registry
对象,并且将它存储在我们上面介绍的全局唯一的Manager
的std::unordered_map
中。这里穿插一句题外话,我们从上面
Registry
和Manager
的代码中可以发现,TVM prefer 使用静态类成员函数来实例化一个类对象,Registry
使用Register
函数;Manager
使用Global
函数,这样子写的好处是可以将类对象的初始化逻辑包括在类的相关定义中,给外界提供一个更加简洁的接口。 -
Registry
还提供了一个根据名称来获取注册函数的接口Get
,这个接口具体是在tvm/src/runtime/registry.cc
tvmsrc_registry_cc 中被定义的,具体代码如下所示:1
2
3
4
5
6
7const PackedFunc* Registry::Get(const std::string& name) {
Manager* m = Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex);
auto it = m->fmap.find(name);
if (it == m->fmap.end()) return nullptr;
return &(it->second->func_);
}这个接口代码比较简单,不再过多赘述。
Python 的加载和调用细节

在理解了 PackedFunc
和 Registry
后,现在让我们来看 Python 程序是如何调用 C++ 函数接口的。
C++ 侧暴露的函数接口
我们上面看到了,所有的在 C++ 端定义的函数实际上都会存储在全局唯一的 Manager
下的 std::unordered_map
中存储的各个 Registry
对象实例中。基于这层结构,在 tvm/include/tvm/runtime/c_runtime_api.h
tvmsrc_c_runtime_api_h 中,TVM 定义了动态链接库侧暴露给 Python 程序的接口。代码摘抄如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
extern "C" {
/*!
* \brief List all the globally registered function name
* \param out_size The number of functions
* \param out_array The array of function names.
* \return 0 when success, nonzero when failure happens
*/
TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array);
/*!
* \brief Get a global function.
*
* \param name The name of the function.
* \param out the result function pointer, NULL if it does not exist.
*
* \note The function handle of global function is managed by TVM runtime,
* So TVMFuncFree is should not be called when it get deleted.
*/
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
/*!
* \brief Call a Packed TVM Function.
*
* \param func node handle of the function.
* \param arg_values The arguments
* \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
*
* \param ret_val The return value.
* \param ret_type_code the type code of return value.
*
* \return 0 when success, nonzero when failure happens
* \note TVM calls always exchanges with type bits=64, lanes=1
*
* \note API calls always exchanges with type bits=64, lanes=1
* If API call returns container handles (e.g. FunctionHandle)
* these handles should be managed by the front-end.
* The front-end need to call free function (e.g. TVMFuncFree)
* to free these handles.
*/
TVM_DLL int TVMFuncCall(TVMFunctionHandle func, TVMValue* arg_values, int* type_codes, int num_args,
TVMValue* ret_val, int* ret_type_code);
} // TVM_EXTERN_C
在上面的代码中,首先,对于动态链接库提供的 API,需要使用符合 C 语言编译和链接约定的 API,因为 Python 的 ctype
只和 C 兼容,而 C++ 编译器会对函数和变量名进行 name mangling,所以需要使用 __cplusplus
宏和 extern "C"
来得到符合 C
语言编译和链接约定的 API。
其次,我们观察到上面的三个接口都是用了 TVM_DLL
加以修饰。TVM_DLL
的定义如下所示:
1
对于 Linux 的动态链接文件 (.so) 来说,GCC 的 visibility
属性可以控制共享库文件导出的符号。当我们为某个函数、变量、模板或者 C++ 类声明加上 __attribute__((visibility("default")))
的编译属性,并且在编译该动态链接库的时候基于 -fvisibility=hidden
参数进行编译时,形成的动态链接文件的动态链接表中就只会包含附带该属性的函数、变量、模板或者 C++ 类 gcc_attribute_visibility。
因此,回到 TVM 的代码中,使用 TVM_DLL
宏进行修饰的函数接口将会被最终呈现在动态链接表中,可以被 Python 程序所使用。
Python 侧加载动态链接库
TVM 的 Python 代码从 python/tvm/__init__.py
tvmsrc_python_init_py 中开始真正执行:
1
from ._ffi.base import TVMError, __version__, _RUNTIME_ONLY
在运行到这行代码时,参考我的另一篇博客 Python 循环 Import (Circular Import) 陷阱原理,此时 Python 解释器将会执行 python/tvm/_ffi/__init__.py
tvmsrc_python_ffi_init_py:
1
2
3
4from . import _pyversion
from .base import register_error
from .registry import register_object, register_func, register_extension
from .registry import _init_api, get_global_func
上面的第一句,又会导致 python/tvm/_ffi/base.py
tvmsrc_python_ffi_base_py 中的下面代码被执行:
1
2
3
4
5
6
7
8
9
10
11
12
13def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
# The dll search path need to be added explicitly in
# windows after python 3.8
if sys.platform.startswith("win32") and sys.version_info >= (3, 8):
for path in libinfo.get_dll_directories():
os.add_dll_directory(path)
lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
lib.TVMGetLastError.restype = ctypes.c_char_p
return lib, os.path.basename(lib_path[0])
_LIB, _LIB_NAME = _load_lib()
此时,Python 运行时将可以获得 _LIB
和 _LIB_NAME
两个变量。其中,_LIB
可以理解为动态链接库在 Python 侧的操作句柄,_LIB_NAME
则是 "libtvm.so"
这个字符串,代表着动态链接库的名字。后续在 Python 中,TVM 将通过 _LIB
这个桥梁不断地和 C++ 的部分进行交互。
Python 侧关联 C++ 侧的 PackedFunc
在有了 _LIB
这个 Python 和动态链接库之间的桥梁后,现在让我们来看 Python 侧是如何基于动态链接库暴露出来的接口拿到 TVM 使用 PackedFunc
封装的函数的。
Python 中获取 TVM C++ 使用 PackedFunc
封装的 API 的底层函数是 _get_global_func
,它是在 python/tvm/_ffi/_ctypes/packed_func.py
tvmsrc_python_ffi_packed_func_py 中定义的,具体定义如下:
1
2
3
4
5
6
7
8
9
10
11def _get_global_func(name, allow_missing=False):
handle = PackedFuncHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
if handle.value:
return _make_packed_func(handle, False)
if allow_missing:
return None
raise ValueError("Cannot find global function %s" % name)
从上面的代码可以发现,Python 侧的 _get_global_func
函数实际上调用了动态链接库暴露的 TVMFuncGetGlobal
接口。下面是 C++ 侧关于 TVMFuncGetGlobal
接口的实现,这部分代码位于 src/runtime/registry.cc
tvmsrc_registry_cc 中:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
API_BEGIN();
const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name);
if (fp != nullptr) {
tvm::runtime::TVMRetValue ret;
ret = *fp;
TVMValue val;
int type_code;
ret.MoveToCHost(&val, &type_code);
*out = val.v_handle;
} else {
*out = nullptr;
}
API_END();
}
可以发现 TVMFuncGetGlobal
接口实际上就是调用了我们上面介绍过的 Registry
的 Get
接口,来根据给出的 API 名称获得对应的 PackedFunc*
函数指针。如果无法基于给定的名称找到对应的函数,则返回空指针。
回到 Python 侧,在调用 _LIB.TVMFuncGetGlobal
获取函数指针后,函数指针被保存在了 handle
变量中。我们可以看到 Python 侧随后调用了 _make_packed_func
API 来在 Python 侧创建一个 Python 侧的函数封装对象。_make_packed_func
是在 python/tvm/_ffi/_ctypes/packed_func.py
tvmsrc_python_ffi_packed_func_py 中定义的,具体如下所示:
1
2
3
4
5
6def _make_packed_func(handle, is_global):
"""Make a packed function class"""
obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
obj.is_global = is_global
obj.handle = handle
return obj
可以看见 _make_packed_func
接口基于传入的 handle
参数,初始化了一个 PackedFunc
类实例,这是一个空实现的类,在 python/tvm/runtime/packed_func.py
tvmsrc_python_runtime_packed_func_py 中予以实现,其继承自 PackedFuncBase
类,后者是在 python/tvm/_ffi/_ctypes/packed_func.py
tvmsrc_python_ffi_ctypes_packed_func_py 中被定义的。
Python 侧调用 C++ 侧的 PackedFunc
当 Python 侧想要调用动态链接库里的函数时,其是通过调用 PackedFunc
类实现的,调用 PackedFunc
类间接调用了 PackedFuncBase
父类,后者中实现了一个 __cal__
方法以供调用,如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26class PackedFuncBase(object):
def __call__(self, *args):
"""Call the function with positional arguments
args : list
The positional arguments to the function call.
"""
temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
if (
_LIB.TVMFuncCall(
self.handle,
values,
tcodes,
ctypes.c_int(num_args),
ctypes.byref(ret_val),
ctypes.byref(ret_tcode),
)
!= 0
):
raise get_last_ffi_error()
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
从上面可以看出,__call__
函数调用了 C++ 侧的 TVMFuncCall
这个 API,把前面保存有 C++ PackedFunc
对象地址的 handle
以及相关的函数参数传了进去。在 C++ 侧,TVMFuncCall
是在 src/runtime/c_runtime_api.cc
tvmsrc_src_runtime_c_runtime_api_cc 中定义的,其完成了对函数的调用,以及函数返回值的处理,主体代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args,
TVMValue* ret_val, int* ret_type_code) {
API_BEGIN();
TVMRetValue rv;
(static_cast<const PackedFuncObj*>(func))
->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv);
// handle return string.
if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
if (rv.type_code() != kTVMDataType) {
e->ret_str = *rv.ptr<std::string>();
} else {
e->ret_str = rv.operator std::string();
}
if (rv.type_code() == kTVMBytes) {
e->ret_bytes.data = e->ret_str.c_str();
e->ret_bytes.size = e->ret_str.length();
*ret_type_code = kTVMBytes;
ret_val->v_handle = &(e->ret_bytes);
} else {
*ret_type_code = kTVMStr;
ret_val->v_str = e->ret_str.c_str();
}
} else {
rv.MoveToCHost(ret_val, ret_type_code);
}
API_END();
}
将动态链接库定义的函数接口关联到各个 Python 模块
上面我们分析了 Python 程序如何检索和调用动态链接库中提供的函数接口,最后让我们来看动态链接库中的函数接口如何被绑定到各个 Python 模块上去的。这里的 "绑定" 指的是在完成动态链接库构建,以及 Python 和动态链接库之间实现可访问之后,使用 Python 模块的函数实现对底层动态链接库函数接口的封装。
首先让我们来看位于 python/tvm/_ffi/registry.py
tvmsrc_python_ffi_registry_py 中的函数 _init_api
和 _init_api_prefix
函数,这是实现模块绑定的两个关键 API。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32def _init_api(namespace, target_module_name=None):
"""Initialize api for a given module name
namespace : str
The namespace of the source registry
target_module_name : str
The target module name if different from namespace
"""
target_module_name = target_module_name if target_module_name else namespace
if namespace.startswith("tvm."):
_init_api_prefix(target_module_name, namespace[4:])
else:
_init_api_prefix(target_module_name, namespace)
def _init_api_prefix(module_name, prefix):
module = sys.modules[module_name]
for name in list_global_func_names():
if not name.startswith(prefix):
continue
fname = name[len(prefix) + 1 :]
target_module = module
if fname.find(".") != -1:
continue
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
ff.__doc__ = "TVM PackedFunc %s. " % fname
setattr(target_module, ff.__name__, ff)
从函数的定义可以看出,_init_api
实际上是 _init_api_prefix
的一个 wrapper,我们直接对后者进行分析。
后者调用 sys.modules
列表获取指定 Python 模块名的 Python 模块,然后调用 list_global_func_names
函数获取所有的在动态链接库中的 API,然后基于传入的名称前缀 (prefix) 对这些 API 进行过滤,对于符合前缀的 API 名称,就使用我们上面介绍过的 _get_global_func
函数获得封装有动态链接库函数指针的 PackedFunc
实例 (p.s. 上面代码中使用的 get_global_func
是对 _get_global_func
的直接封装),然后调用 setattr
函数 runoob_python_setattr 将这个实例绑定到指定的模块下。
基于上面介绍的 _init_api
函数,TVM 在 Python 程序的一些模块中对 _init_api
进行了调用,就完成了这些模块和动态链接库中的函数接口之间的关联。代码中的几处示例如下所示:
1
2
3
4
5
6
7
8# python/tvm/runtime/_ffi_api.py
tvm._ffi._init_api("runtime", __name__)
# python/tvm/relay/op/op.py
tvm._ffi._init_api("relay.op", __name__)
# python/tvm/relay/backend/_backend.py
tvm._ffi._init_api("relay.backend", __name__)
Python 关联 C++ 函数接口例子
下面我们以 TVM 中求绝对值的函数 abs
为例,这个函数实现在 tir
模块,函数的功能很简单,不会造成额外的理解负担,我们只关注从 Python 调用是怎么映射到 C++ 中的,先看在 C++ 中 abs
函数的定义和注册:
1
2
3
4
5
6// src/tir/op/op.cc
// 函数定义
PrimExpr abs(PrimExpr x, Span span) { ... }
// 函数注册
TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);
再看 Python 端对动态链接库中该函数接口的封装过程:
1
2
3
4
5
6
7
8
9
10# python/tvm/tir/_ffi_api.py
# 把c++ tir中注册的函数以python PackedFunc
# 对象的形式关联到了_ffi_api这个模块
tvm._ffi._init_api("tir", __name__)
# python/tvm/tir/op.py
# 定义了abs的python函数,其实内部调用了前面
# 关联到_ffi_api这个模块的python PackedFunc对象
def abs(x, span=None):
return _ffi_api.abs(x, span)
完成封装后,用户就可以基于 tir
这个 Python 模块实现对动态链接库中的函数接口的调用:
1
2
3
4
5import tvm
from tvm import tir
rlt = tir.abs(-100)
print("abs(-100) = %d" % (rlt)
DGL FFI 机制
在理解了 TVM 的 FFI 机制后,我们来到 DGL 的代码,会发现基本上框架是一样的。举个例子来说,比如创建一个异构图的结构,DGL 在 src/graph/heterograph_capi.cc
dglsrc_src_graph_heterograph_capi_cc 中有如下的实现:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t nvtypes = args[0];
int64_t num_src = args[1];
int64_t num_dst = args[2];
IdArray row = args[3];
IdArray col = args[4];
List<Value> formats = args[5];
bool row_sorted = args[6];
bool col_sorted = args[7];
std::vector<SparseFormat> formats_vec;
for (Value val : formats) {
std::string fmt = val->data;
formats_vec.push_back(ParseSparseFormat(fmt));
}
const auto code = SparseFormatsToCode(formats_vec);
auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col,
row_sorted, col_sorted, code);
*rv = HeteroGraphRef(hgptr);
});
同样是通过调用宏 DGL_REGISTER_GLOBAL
注册的方式,DGL 将函数注册进了注册机中。
在 Python 侧,在 python/dgl/heterograph_index.py
dglsrc_python_dgl_heterograph_index_py 中,通过调用 _init_api
的方式,将动态链接库中的函数接口绑定到了 heterograph_index
模块中:
1
2# python/dgl/heterograph_index.py
_init_api("dgl.heterograph_index")
这样一来,在这个模块中的代码就可以直接使用 C++ 函数接口了:
1
2
3
4
5
6
7
8
9# python/dgl/heterograph_index.py
def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col,
formats, row_sorted=False, col_sorted=False):
if isinstance(formats, str):
formats = [formats]
return _CAPI_DGLHeteroCreateUnitGraphFromCOO(
int(num_ntypes), int(num_src), int(num_dst),
F.to_dgl_nd(row), F.to_dgl_nd(col),
formats, row_sorted, col_sorted)