源码解析:借鉴 TVM 的 Python 和 C++ 的调用机制

⚠ 转载请注明出处:作者:ZobinHuang,更新日期:July 11 2022


知识共享许可协议

    本作品ZobinHuang 采用 知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议 进行许可,在进行使用或分享前请查看权限要求。若发现侵权行为,会采取法律手段维护作者正当合法权益,谢谢配合。


目录

有特定需要的内容直接跳转到相关章节查看即可。

正在加载目录...

前言

    在 DGL 官方给出的为数不多的开发者文档 dgl_ffi 中,里面说到他们在大部分 DGL 的代码中使用 python 予以相关逻辑的实现,在性能关键的地方就是用 C++ 予以实现。因此在代码中将经常遇到需要使用 Python 代码对使用 C++ 代码定义的函数进行调用的情况。这篇文章我们将分析 DGL 实现这层机制的底层原理。

    DGL 所使用的 Foreign Function Interface 机制的实现实际上参考了 TVM 项目中 FFI 功能模块 tvm_official 的实现方法 dgl_ffi。TVM 是一个用于优化 ML 模型编译以及实现通用化模型部署的编译器栈,这里我们说到的 FFI 只是它其中的一个功能模块。下面我们先对 TVM 的 FFI 机制的实现进行分析,然后回到 DGL 来看其具体实现。

TVM FFI 机制

    TVM FFI 机制本质上是使用了 Python 提供的 ctypes 模块来调用 C++ 代码提供的 ctypes 是 Python 内建的可以用于调用 C/C++ 动态链接库 函数的功能模块。这意味着我们首先需要把使用 C++ 编写的函数到一个动态链接库文件中,然后再在 Python 中使用 ctypes 模块来加载动态链接库中的函数。在 ctypes 的官方文档中+是这么介绍 ctypes 的:

ctypes is a foreign function library for Python. It provides C compatible data types, and allows calling functions in DLLs or shared libraries. It can be used to wrap these libraries in pure Python.

    下面我们首先从 C++ 侧来看 TVM 如何定义具体的函数功能接口,然后再来看 TVM 如何在 Python 侧实现对这些函数的加载和调用。

C++ 底层数据结构支撑

    在 C++ 侧,PackedFunc 这个类用于描述和存储作为 FFI 的 C++ 函数,它是 Python 和 C++ 互相调用的桥梁。这个类是在 include/tvm/runtime/packed_func.h tvmsrc_packedfunc_h 中定义的。在阅读这个类的代码之前,我们首先对该类所使用到的关键数据结构进行分析,然后再回过头来分析 PackedFunc

    img_data_structure 展示了我们下面要进行分析的几个数据结构之间的关系,这几个数据结构本质上是对 PackedFunc 所使用的数据的封装,它们的功能如下所示:

  • TVMValue: 对基本数据类型进行统一;
  • TVMPODValue_: 封装了对数据类型进行强制类型转换的运算符;
  • TVMArgValue: 用于封装传给 PackedFunc一个 参数;
  • TVMArgs: 用于封装传给 PackedFunc所有 参数;
  • TVMRetValue: 用于作为存放调用 PackedFunc 返回值的容器;

    下面我们分别对它们进行分析。

TVMValue

1
2
3
4
5
6
7
8
9
10
11
12
/*!
* \brief Union type of values
* being passed through API and function calls.
*/
typedef union {
int64_t v_int64;
double v_float64;
void* v_handle;
const char* v_str;
DLDataType v_type;
DLDevice v_device;
} TVMValue;

    首先是最底层的 TVMValue,其是在 tvm/include/tvm/runtime/c_runtime_api.h tvmsrc_c_runtime_api_h 中定义的,是一个 union。它的功能主要是为了封装 C++ 和其它语言交互时所支持的几种类型的数据。

TVMPODValue_

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/*!
* \brief Internal base class to
* handle conversion to POD values.
*/
class TVMPODValue_ {
public:
operator double() const { return value_.v_float64; }
operator int64_t() const { return value_.v_int64; }
operator void*() const { return value_.v_handle; }
template <typename T>
T* ptr() const { return static_cast<T*>(value_.v_handle); }

protected:
TVMValue value_;
int type_code_;
};

    然后是 TVMPODValue_,其是在 include/tvm/runtime/packed_func.h tvmsrc_packedfunc_h 中定义的。这个类的实现核心是「强制类型转换运算符」重载。在 C++中,类型的名字,包括类的名字本身也是一种运算符,即类型强制转换运算符。如上面摘抄的定义所示,该类重载了「强制类型转换运算符」,可以实现对 TVMValue 存储的数据的类型转换。

TVMArgValue

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
/*!
* \brief A single argument value to PackedFunc.
* Containing both type_code and TVMValue
*
* Provides utilities to do type cast into other types.
*/
class TVMArgValue : public TVMPODValue_ {
public:
/*! \brief default constructor */
TVMArgValue() {}
/*!
* \brief constructor
* \param value of the function
* \param type_code The type code.
*/
TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {}
// reuse converter from parent
using TVMPODValue_::operator double;
using TVMPODValue_::operator int64_t;
using TVMPODValue_::operator uint64_t;
using TVMPODValue_::operator int;
using TVMPODValue_::operator bool;
using TVMPODValue_::operator void*;
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Device;
using TVMPODValue_::operator Module;
using TVMPODValue_::operator PackedFunc;
using TVMPODValue_::AsObjectRef;
using TVMPODValue_::IsObjectRef;

// conversion operator.
operator std::string() const {
if (type_code_ == kTVMDataType) {
return DLDataType2String(operator DLDataType());
} else if (type_code_ == kTVMBytes) {
TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
return std::string(arr->data, arr->size);
} else if (type_code_ == kTVMStr) {
return std::string(value_.v_str);
} else {
ICHECK(IsObjectRef<tvm::runtime::String>())
<< "Could not convert TVM object of type " << runtime::Object::TypeIndex2Key(type_code_)
<< " to a string.";
return AsObjectRef<tvm::runtime::String>().operator std::string();
}
}
template <typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
const TVMValue& value() const { return value_; }

template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
inline operator T() const;
inline operator DLDataType() const;
inline operator DataType() const;
};

    然后是 TVMArgValue,其是在 include/tvm/runtime/packed_func.h tvmsrc_packedfunc_h 中定义的。这个类继承自前面介绍的 TVMPODValue_ 类,用作表示 PackedFunc 的一个参数,它和 TVMPODValue_ 的区别是扩充了一些数据类型的支持,比如 stringPackedFuncTypedPackedFunc 等,其中对后两个的支持是在 C++ 代码中能够调用 Python 函数的根本原因。这个类只使用所保存的 underlying data,而不会去做释放。

TVMArgs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*! \brief Arguments into TVM functions. */
class TVMArgs {
public:
const TVMValue* values;
const int* type_codes;
int num_args;
/*!
* \brief constructor
* \param values The argument values
* \param type_codes The argument type codes
* \param num_args number of arguments.
*/
TVMArgs(const TVMValue* values, const int* type_codes, int num_args)
: values(values), type_codes(type_codes), num_args(num_args) {}
/*! \return size of the arguments */
inline int size() const;
/*!
* \brief Get i-th argument
* \param i the index.
* \return the ith argument.
*/
inline TVMArgValue operator[](int i) const;
};

    然后是 TVMArgs,它是在 include/tvm/runtime/packed_func.h tvmsrc_packedfunc_h 中定义的。这个类主要是为了封装传给代表着 C++ FFI 的 PackedFunc 的所有参数,这个类也比较简单原始,主要基于 TVMValue、参数类型编码、参数个数来实现。值得注意的是 TVMArgs 使用的是 TVMValue 数组来存储传入的所有参数,而对于迭代访问运算符 [],其返回的又是 TVMArgValue 类型的值。因此上面介绍的 TVMArgValue 实际上并不存储传入的参数本身,而用于对传入的参数进行处理使用。

TVMRetValue

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class TVMRetValue : public TVMPODValue_ {
public:
// ctor and dtor, dtor will release related buffer
TVMRetValue() {}
~TVMRetValue() { this->Clear(); }

// conversion operators
operator std::string() const { return *ptr<std::string>(); }
operator DLDataType() const { return value_.v_type; }
operator PackedFunc() const { return *ptr<PackedFunc>(); }

// Assign operators
TVMRetValue& operator=(double value) {}
TVMRetValue& operator=(void* value) {}
TVMRetValue& operator=(int64_t value) {}
TVMRetValue& operator=(std::string value) {}
TVMRetValue& operator=(PackedFunc f) {}

private:
// judge type_code_, release underlying data
void Clear() {
if (type_code_ == kTVMStr || type_code_ == kTVMBytes) {
delete ptr<std::string>();
} else if(type_code_ == kTVMPackedFuncHandle) {
delete ptr<PackedFunc>();
} else if(type_code_ == kTVMNDArrayHandle) {
NDArray::FFIDecRef(
static_cast<TVMArrayHandle>(value_.v_handle));
} else if(type_code_ == kTVMModuleHandle
|| type_code_ == kTVMObjectHandle ) {
static_cast<Object*>(value_.v_handle)->DecRef();
}
type_code_ = kTVMNullptr;
}
};

    最后是 TVMRetValue,它是在 include/tvm/runtime/packed_func.h tvmsrc_packedfunc_h 中定义的。这个类也是继承自 TVMPODValue_ 类,主要作用是作为存放调用 PackedFunc 返回值的容器,它和 TVMArgValue 的区别是,它会管理所保存的 underlying data,会对其做释放。这个类主要由四部分构成:

  1. 构造和析构函数;
  2. 对强制类型转换运算符重载的扩展;
  3. 对赋值运算符的重载;
  4. 辅助函数,包括释放资源的 Clear 函数

TVMArgsSetter

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class TVMArgsSetter {
public:
TVMArgsSetter(TVMValue* values, int* type_codes)
: values_(values), type_codes_(type_codes) {}

void operator()(size_t i, double value) const {
values_[i].v_float64 = value;
type_codes_[i] = kDLFloat;
}
void operator()(size_t i, const string& value) const {
values_[i].v_str = value.c_str();
type_codes_[i] = kTVMStr;
}
void operator()(size_t i, const PackedFunc& value) const {
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kTVMPackedFuncHandle;
}

private:
TVMValue* values_;
int* type_codes_;
};

    另外,TVMArgsSetter 是一个用于给 TVMValue 对象赋值的辅助类,主要通过重载函数调用运算符来实现。

PackedFunc 的具体实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
/*!
* \brief Packed function is a type-erased function.
* The arguments are passed by packed format.
*
* This is an useful unified interface to call generated functions,
* It is the unified function function type of TVM.
* It corresponds to TVMFunctionHandle in C runtime API.
*/
class PackedFunc : public ObjectRef {
public:
/*!
* \brief Constructor from null
* 构造函数:传入空指针的构造函数,初始化自身也初始化父类
*/
PackedFunc(std::nullptr_t null) : ObjectRef(nullptr) {} // NOLINT(*)

/*!
* \brief Constructing a packed function from a callable type
* whose signature is consistent with `PackedFunc`
* 构造函数:基于传入的函数进行初始化,传入的函数类型必须要能转换成
* void(TVMArgs, TVMRetValue*) 的形式
* \param data the internal container of packed function.
*/
template <typename TCallable,
typename = std::enable_if_t<
std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
!std::is_base_of<TCallable, PackedFunc>::value>>
explicit PackedFunc(TCallable data) {
using ObjType = PackedFuncSubObj<TCallable>;
data_ = make_object<ObjType>(std::forward<TCallable>(data));
}
/*!
* \brief Call packed function by directly passing in unpacked format.
* \param args Arguments to be passed.
* \tparam Args arguments to be passed.
*
* 重载函数调用符号
*
* \code
* // Example code on how to call packed function
* void CallPacked(PackedFunc f) {
* // call like normal functions by pass in arguments
* // return value is automatically converted back
* int rvalue = f(1, 2.0);
* }
* \endcode
*/
template <typename... Args>
inline TVMRetValue operator()(Args&&... args) const;
/*!
* \brief Call the function in packed format.
* \param args The arguments
* \param rv The return value.
*/
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const;
/*! \return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const { return data_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }

TVM_DEFINE_OBJECT_REF_METHODS(PackedFunc, ObjectRef, PackedFuncObj);
};

    有了前面所述的数据结构作为基础,再来看 PackedFunc 的实现,它是在 include/tvm/runtime/packed_func.h tvmsrc_packedfunc_h 中定义的。PackedFunc 代表了用户自定义的一个 C++ 函数,最终将被编译和注册到动态链接库中,以供 Python 程序调用。PackedFunc 使用了一个储存函数指针的变量 data_ (p.s. 继承自 ObjectRef),再通过重载函数调用运算符 () 来调用这个函数指针所指向的函数。重载 () 运算符的定义是在同个文件下实现的,定义如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
template <typename... Args>
inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;

// 初始化两个数组存放,用于存放转化用户传入的参数后的若干 TVMValue
// 以及它们的 type_codes
TVMValue values[kArraySize];
int type_codes[kArraySize];

// 使用 TVMArgsSetter 将用户传入的参数初始化为 TVMValue,放入数组中
detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
TVMRetValue rv;

// 调用在 data_ 中存储的函数
(static_cast<PackedFuncObj*>(data_.get()))
->CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv);

// 返回由 TVMRetValue rv 存储的函数返回值
return rv;
}

    上面代码的相关注释以及解释了整个调用流程。我们可以发现在调用的时候,TVM 先把 转化为 PackedFuncObj 类实例,然后再调用其 CallPacked 函数以完成对函数的执行。

基于 PackedFunc 的函数注册机

    在完成对 PackedFunc 的理解后,下面我们对代码是如何将各个 PackedFunc 实例注册到动态链接库 (以供 Python 程序调用) 进行分析。

TVM_REGISTER_GLOBAL

    注册过程中最常使用的宏是 TVM_REGISTER_GLOBAL,调用这个宏将返回一个 Registry 对象,基于这个对象,我们可以调用诸如 set_body 等注册接口,将函数进行注册。下面是一个例子:

1
2
3
4
5
// src/topi/nn.cc
TVM_REGISTER_GLOBAL("topi.nn.relu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = relu<float>(args[0]);
});

    值得注意的是,set_body 等注册接口可以接收普通函数,也可以接收 Lambda 表达式 cpp_lambda,如上所示的例子接收的就是 Lambda 表达式。

    下面我们对 TVM_REGISTER_GLOBAL 宏定义进行分析。这个宏是在 include/tvm/runtime/registry.h tvmsrc_registry_h 中定义的。如下所示:

1
2
3
4
5
6
7
8
#define TVM_REGISTER_GLOBAL(OpName) \
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::runtime::Registry::Register(OpName)

// 辅助 macro
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#define TVM_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_##TVM
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)

    可以发现,TVM_REGISTER_GLOBAL 实际上就是基于它接收的参数 OpName,在调用宏的位置定义了一个 static 的引用变量,引用到注册机内部 new 出来的一个新的 Registry 对象。上面的代码中使用到了 __COUNTER 编译器扩展宏,GCC 的文档里对该宏的描述如下所示:

This macro expands to sequential integral values starting from 0. In conjunction with the ## operator, this provides a convenient means to generate unique identifiers. Care must be taken to ensure that __COUNTER__ is not expanded prior to inclusion of precompiled headers which use it. Otherwise, the precompiled headers will not be used.

    将上面的宏进行展开,实际上就是在调用宏的地方,运行了这么一句代码:

1
2
3
static TVM_ATTRIBUTE_UNUSED \
::tvm::runtime::Registry& __mk_##TVM[编译器扩展宏 __COUNTER__] =
::tvm::runtime::Registry::Register(OpName);

Registry

    下面我们对 Registry 这个关键的类进行分析,它是在 include/tvm/runtime/registry.h tvmsrc_registry_h 中被定义的,这个类的每一个实例代表着被注册的每一个函数。Registry 有一个友元结构体 Manager,它用于存储全局所有的 Registry 实例。我们首先对它进行分析,下面是它在 src/runtime/registry.cc tvmsrc_registry_cc 中的具体定义:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
struct Registry::Manager {
// map storing the functions.
// We deliberately used raw pointer.
// This is because PackedFunc can contain callbacks into the host language (Python) and the
// resource can become invalid because of indeterministic order of destruction and forking.
// The resources will only be recycled during program exit.
std::unordered_map<std::string, Registry*> fmap;
// mutex
std::mutex mutex;

Manager() {}

static Manager* Global() {
// We deliberately leak the Manager instance, to avoid leak sanitizers
// complaining about the entries in Manager::fmap being leaked at program
// exit.
static Manager* inst = new Manager();
return inst;
}
};

    从上面定义的 Global 函数可以发现,友元结构体 Manager 实际上是一种单例的设计模式 cpp_single_instance,因为被声明为 static 的局部变量只会在第一次被访问的时候被初始化 cpp_new_static_local,上面 Global 函数中的 inst 就是这样一种变量。全局的其他函数可以通过结构体 ManagerGlobal 函数访问到全局唯一的一个 Manager 实例。

    另外,我们还可以发现 Manager 使用的是 std::unordered_map 来存储注册的信息,注册的对象是 Registry 指针。

    最后我们来到了 Register 这个类,这个类是在 include/tvm/runtime/registry.h tvmsrc_registry_h 中被定义的。相关定义摘抄如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Registry {
public:
Registry& set_body(PackedFunc f);
Registry& set_body_typed(FLambda f);
Registry& set_body_method(R (T::*f)(Args...));

static Registry& Register(const std::string& name);
static const PackedFunc* Get(const std::string& name);
static std::vector ListNames();

protected:
std::string name_;
PackedFunc func_;
friend struct Manager;
};

    下面进行分析:

  • 首先 Registry 类首先提供了若干个 set_body 注册接口,这些接口将会基于用户传入的 PackedFunc 初始化 Registry 实例,具体定义代码如下所示:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    // include/tvm/runtime/registry.h
    class Registry {
    public:
    /*!
    * \brief set the body of the function to be f
    * \param f The body of the function.
    */
    TVM_DLL Registry& set_body(PackedFunc f); // NOLINT(*)

    /*!
    * \brief set the body of the function to be f
    * \param f The body of the function.
    */
    template <typename TCallable,
    typename = typename std::enable_if_t<
    std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
    !std::is_base_of<PackedFunc, TCallable>::value>>
    Registry& set_body(TCallable f) { // NOLINT(*)
    return set_body(PackedFunc(f));
    }

    /* ... */

    protected:
    /*! \brief internal packed function */
    PackedFunc func_;

    /* ... */
    }

    // include/tvm/runtime/registry.cc
    Registry& Registry::set_body(PackedFunc f) { // NOLINT(*)
    func_ = f;
    return *this;
    }

        可以发现 set_body 等接口实际上就是把用户传入的函数存储到了 Registry 的私有变量 func_ 中。

  • Registry 类还提供了用于创建 Registry 对象的 Register 静态接口,这个接口具体是在 tvm/src/runtime/registry.cc tvmsrc_registry_cc 中被定义的,具体代码如下所示:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    Registry& Registry::Register(const std::string& name, bool can_override) {  // NOLINT(*)
    Manager* m = Manager::Global();
    std::lock_guard<std::mutex> lock(m->mutex);
    if (m->fmap.count(name)) {
    ICHECK(can_override) << "Global PackedFunc " << name << " is already registered";
    }

    Registry* r = new Registry();
    r->name_ = name;
    m->fmap[name] = r;
    return *r;
    }

        可以发现 Register 函数实际上就是实例化出了一个 Registry 对象,并且将它存储在我们上面介绍的全局唯一的 Managerstd::unordered_map 中。

        这里穿插一句题外话,我们从上面 RegistryManager 的代码中可以发现,TVM prefer 使用静态类成员函数来实例化一个类对象,Registry 使用 Register 函数; Manager 使用 Global 函数,这样子写的好处是可以将类对象的初始化逻辑包括在类的相关定义中,给外界提供一个更加简洁的接口。

  • Registry 还提供了一个根据名称来获取注册函数的接口 Get,这个接口具体是在 tvm/src/runtime/registry.cc tvmsrc_registry_cc 中被定义的,具体代码如下所示:
    1
    2
    3
    4
    5
    6
    7
    const PackedFunc* Registry::Get(const std::string& name) {
    Manager* m = Manager::Global();
    std::lock_guard<std::mutex> lock(m->mutex);
    auto it = m->fmap.find(name);
    if (it == m->fmap.end()) return nullptr;
    return &(it->second->func_);
    }

        这个接口代码比较简单,不再过多赘述。

Python 的加载和调用细节

    在理解了 PackedFuncRegistry 后,现在让我们来看 Python 程序是如何调用 C++ 函数接口的。

C++ 侧暴露的函数接口

    我们上面看到了,所有的在 C++ 端定义的函数实际上都会存储在全局唯一的 Manager 下的 std::unordered_map 中存储的各个 Registry 对象实例中。基于这层结构,在 tvm/include/tvm/runtime/c_runtime_api.h tvmsrc_c_runtime_api_h 中,TVM 定义了动态链接库侧暴露给 Python 程序的接口。代码摘抄如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#ifdef __cplusplus
extern "C" {
#endif

/*!
* \brief List all the globally registered function name
* \param out_size The number of functions
* \param out_array The array of function names.
* \return 0 when success, nonzero when failure happens
*/
TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array);

/*!
* \brief Get a global function.
*
* \param name The name of the function.
* \param out the result function pointer, NULL if it does not exist.
*
* \note The function handle of global function is managed by TVM runtime,
* So TVMFuncFree is should not be called when it get deleted.
*/
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);

/*!
* \brief Call a Packed TVM Function.
*
* \param func node handle of the function.
* \param arg_values The arguments
* \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
*
* \param ret_val The return value.
* \param ret_type_code the type code of return value.
*
* \return 0 when success, nonzero when failure happens
* \note TVM calls always exchanges with type bits=64, lanes=1
*
* \note API calls always exchanges with type bits=64, lanes=1
* If API call returns container handles (e.g. FunctionHandle)
* these handles should be managed by the front-end.
* The front-end need to call free function (e.g. TVMFuncFree)
* to free these handles.
*/
TVM_DLL int TVMFuncCall(TVMFunctionHandle func, TVMValue* arg_values, int* type_codes, int num_args,
TVMValue* ret_val, int* ret_type_code);

#ifdef __cplusplus
} // TVM_EXTERN_C
#endif

    在上面的代码中,首先,对于动态链接库提供的 API,需要使用符合 C 语言编译和链接约定的 API,因为 Python 的 ctype 只和 C 兼容,而 C++ 编译器会对函数和变量名进行 name mangling,所以需要使用 __cplusplus 宏和 extern "C" 来得到符合 C 语言编译和链接约定的 API。

    其次,我们观察到上面的三个接口都是用了 TVM_DLL 加以修饰。TVM_DLL 的定义如下所示:

1
#define TVM_DLL __attribute__((visibility("default")))

    对于 Linux 的动态链接文件 (.so) 来说,GCC 的 visibility 属性可以控制共享库文件导出的符号。当我们为某个函数、变量、模板或者 C++ 类声明加上 __attribute__((visibility("default"))) 的编译属性,并且在编译该动态链接库的时候基于 -fvisibility=hidden 参数进行编译时,形成的动态链接文件的动态链接表中就只会包含附带该属性的函数、变量、模板或者 C++ 类 gcc_attribute_visibility

    因此,回到 TVM 的代码中,使用 TVM_DLL 宏进行修饰的函数接口将会被最终呈现在动态链接表中,可以被 Python 程序所使用。

Python 侧加载动态链接库

    TVM 的 Python 代码从 python/tvm/__init__.py tvmsrc_python_init_py 中开始真正执行:

1
from ._ffi.base import TVMError, __version__, _RUNTIME_ONLY

    在运行到这行代码时,参考我的另一篇博客 Python 循环 Import (Circular Import) 陷阱原理,此时 Python 解释器将会执行 python/tvm/_ffi/__init__.py tvmsrc_python_ffi_init_py:

1
2
3
4
from . import _pyversion
from .base import register_error
from .registry import register_object, register_func, register_extension
from .registry import _init_api, get_global_func

    上面的第一句,又会导致 python/tvm/_ffi/base.py tvmsrc_python_ffi_base_py 中的下面代码被执行:

1
2
3
4
5
6
7
8
9
10
11
12
13
def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
# The dll search path need to be added explicitly in
# windows after python 3.8
if sys.platform.startswith("win32") and sys.version_info >= (3, 8):
for path in libinfo.get_dll_directories():
os.add_dll_directory(path)
lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
lib.TVMGetLastError.restype = ctypes.c_char_p
return lib, os.path.basename(lib_path[0])

_LIB, _LIB_NAME = _load_lib()

    此时,Python 运行时将可以获得 _LIB_LIB_NAME 两个变量。其中,_LIB 可以理解为动态链接库在 Python 侧的操作句柄,_LIB_NAME 则是 "libtvm.so" 这个字符串,代表着动态链接库的名字。后续在 Python 中,TVM 将通过 _LIB 这个桥梁不断地和 C++ 的部分进行交互。

Python 侧关联 C++ 侧的 PackedFunc

    在有了 _LIB 这个 Python 和动态链接库之间的桥梁后,现在让我们来看 Python 侧是如何基于动态链接库暴露出来的接口拿到 TVM 使用 PackedFunc 封装的函数的。

    Python 中获取 TVM C++ 使用 PackedFunc 封装的 API 的底层函数是 _get_global_func,它是在 python/tvm/_ffi/_ctypes/packed_func.py tvmsrc_python_ffi_packed_func_py 中定义的,具体定义如下:

1
2
3
4
5
6
7
8
9
10
11
def _get_global_func(name, allow_missing=False):
handle = PackedFuncHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))

if handle.value:
return _make_packed_func(handle, False)

if allow_missing:
return None

raise ValueError("Cannot find global function %s" % name)

    从上面的代码可以发现,Python 侧的 _get_global_func 函数实际上调用了动态链接库暴露的 TVMFuncGetGlobal 接口。下面是 C++ 侧关于 TVMFuncGetGlobal 接口的实现,这部分代码位于 src/runtime/registry.cc tvmsrc_registry_cc 中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
API_BEGIN();
const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name);
if (fp != nullptr) {
tvm::runtime::TVMRetValue ret;
ret = *fp;
TVMValue val;
int type_code;
ret.MoveToCHost(&val, &type_code);
*out = val.v_handle;
} else {
*out = nullptr;
}
API_END();
}

    可以发现 TVMFuncGetGlobal 接口实际上就是调用了我们上面介绍过的 RegistryGet 接口,来根据给出的 API 名称获得对应的 PackedFunc* 函数指针。如果无法基于给定的名称找到对应的函数,则返回空指针。

    回到 Python 侧,在调用 _LIB.TVMFuncGetGlobal 获取函数指针后,函数指针被保存在了 handle 变量中。我们可以看到 Python 侧随后调用了 _make_packed_func API 来在 Python 侧创建一个 Python 侧的函数封装对象。_make_packed_func 是在 python/tvm/_ffi/_ctypes/packed_func.py tvmsrc_python_ffi_packed_func_py 中定义的,具体如下所示:

1
2
3
4
5
6
def _make_packed_func(handle, is_global):
"""Make a packed function class"""
obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
obj.is_global = is_global
obj.handle = handle
return obj

    可以看见 _make_packed_func 接口基于传入的 handle 参数,初始化了一个 PackedFunc 类实例,这是一个空实现的类,在 python/tvm/runtime/packed_func.py tvmsrc_python_runtime_packed_func_py 中予以实现,其继承自 PackedFuncBase 类,后者是在 python/tvm/_ffi/_ctypes/packed_func.py tvmsrc_python_ffi_ctypes_packed_func_py 中被定义的。

Python 侧调用 C++ 侧的 PackedFunc

    当 Python 侧想要调用动态链接库里的函数时,其是通过调用 PackedFunc 类实现的,调用 PackedFunc 类间接调用了 PackedFuncBase 父类,后者中实现了一个 __cal__ 方法以供调用,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class PackedFuncBase(object):
def __call__(self, *args):
"""Call the function with positional arguments

args : list
The positional arguments to the function call.
"""
temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
if (
_LIB.TVMFuncCall(
self.handle,
values,
tcodes,
ctypes.c_int(num_args),
ctypes.byref(ret_val),
ctypes.byref(ret_tcode),
)
!= 0
):
raise get_last_ffi_error()
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)

    从上面可以看出,__call__ 函数调用了 C++ 侧的 TVMFuncCall 这个 API,把前面保存有 C++ PackedFunc 对象地址的 handle 以及相关的函数参数传了进去。在 C++ 侧,TVMFuncCall 是在 src/runtime/c_runtime_api.cc tvmsrc_src_runtime_c_runtime_api_cc 中定义的,其完成了对函数的调用,以及函数返回值的处理,主体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args,
TVMValue* ret_val, int* ret_type_code) {
API_BEGIN();

TVMRetValue rv;
(static_cast<const PackedFuncObj*>(func))
->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv);

// handle return string.
if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
if (rv.type_code() != kTVMDataType) {
e->ret_str = *rv.ptr<std::string>();
} else {
e->ret_str = rv.operator std::string();
}
if (rv.type_code() == kTVMBytes) {
e->ret_bytes.data = e->ret_str.c_str();
e->ret_bytes.size = e->ret_str.length();
*ret_type_code = kTVMBytes;
ret_val->v_handle = &(e->ret_bytes);
} else {
*ret_type_code = kTVMStr;
ret_val->v_str = e->ret_str.c_str();
}
} else {
rv.MoveToCHost(ret_val, ret_type_code);
}
API_END();
}

将动态链接库定义的函数接口关联到各个 Python 模块

    上面我们分析了 Python 程序如何检索和调用动态链接库中提供的函数接口,最后让我们来看动态链接库中的函数接口如何被绑定到各个 Python 模块上去的。这里的 "绑定" 指的是在完成动态链接库构建,以及 Python 和动态链接库之间实现可访问之后,使用 Python 模块的函数实现对底层动态链接库函数接口的封装。

    首先让我们来看位于 python/tvm/_ffi/registry.py tvmsrc_python_ffi_registry_py 中的函数 _init_api_init_api_prefix 函数,这是实现模块绑定的两个关键 API。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def _init_api(namespace, target_module_name=None):
"""Initialize api for a given module name

namespace : str
The namespace of the source registry

target_module_name : str
The target module name if different from namespace
"""
target_module_name = target_module_name if target_module_name else namespace
if namespace.startswith("tvm."):
_init_api_prefix(target_module_name, namespace[4:])
else:
_init_api_prefix(target_module_name, namespace)

def _init_api_prefix(module_name, prefix):
module = sys.modules[module_name]

for name in list_global_func_names():
if not name.startswith(prefix):
continue

fname = name[len(prefix) + 1 :]
target_module = module

if fname.find(".") != -1:
continue
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
ff.__doc__ = "TVM PackedFunc %s. " % fname
setattr(target_module, ff.__name__, ff)

    从函数的定义可以看出,_init_api 实际上是 _init_api_prefix 的一个 wrapper,我们直接对后者进行分析。

    后者调用 sys.modules 列表获取指定 Python 模块名的 Python 模块,然后调用 list_global_func_names 函数获取所有的在动态链接库中的 API,然后基于传入的名称前缀 (prefix) 对这些 API 进行过滤,对于符合前缀的 API 名称,就使用我们上面介绍过的 _get_global_func 函数获得封装有动态链接库函数指针的 PackedFunc 实例 (p.s. 上面代码中使用的 get_global_func 是对 _get_global_func 的直接封装),然后调用 setattr 函数 runoob_python_setattr 将这个实例绑定到指定的模块下。

    基于上面介绍的 _init_api 函数,TVM 在 Python 程序的一些模块中对 _init_api 进行了调用,就完成了这些模块和动态链接库中的函数接口之间的关联。代码中的几处示例如下所示:

1
2
3
4
5
6
7
8
# python/tvm/runtime/_ffi_api.py
tvm._ffi._init_api("runtime", __name__)

# python/tvm/relay/op/op.py
tvm._ffi._init_api("relay.op", __name__)

# python/tvm/relay/backend/_backend.py
tvm._ffi._init_api("relay.backend", __name__)

Python 关联 C++ 函数接口例子

    下面我们以 TVM 中求绝对值的函数 abs 为例,这个函数实现在 tir 模块,函数的功能很简单,不会造成额外的理解负担,我们只关注从 Python 调用是怎么映射到 C++ 中的,先看在 C++ 中 abs 函数的定义和注册:

1
2
3
4
5
6
// src/tir/op/op.cc
// 函数定义
PrimExpr abs(PrimExpr x, Span span) { ... }

// 函数注册
TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);

    再看 Python 端对动态链接库中该函数接口的封装过程:

1
2
3
4
5
6
7
8
9
10
# python/tvm/tir/_ffi_api.py
# 把c++ tir中注册的函数以python PackedFunc
# 对象的形式关联到了_ffi_api这个模块
tvm._ffi._init_api("tir", __name__)

# python/tvm/tir/op.py
# 定义了abs的python函数,其实内部调用了前面
# 关联到_ffi_api这个模块的python PackedFunc对象
def abs(x, span=None):
return _ffi_api.abs(x, span)

    完成封装后,用户就可以基于 tir 这个 Python 模块实现对动态链接库中的函数接口的调用:

1
2
3
4
5
import tvm
from tvm import tir

rlt = tir.abs(-100)
print("abs(-100) = %d" % (rlt)

DGL FFI 机制

    在理解了 TVM 的 FFI 机制后,我们来到 DGL 的代码,会发现基本上框架是一样的。举个例子来说,比如创建一个异构图的结构,DGL 在 src/graph/heterograph_capi.cc dglsrc_src_graph_heterograph_capi_cc 中有如下的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t nvtypes = args[0];
int64_t num_src = args[1];
int64_t num_dst = args[2];
IdArray row = args[3];
IdArray col = args[4];
List<Value> formats = args[5];
bool row_sorted = args[6];
bool col_sorted = args[7];
std::vector<SparseFormat> formats_vec;
for (Value val : formats) {
std::string fmt = val->data;
formats_vec.push_back(ParseSparseFormat(fmt));
}
const auto code = SparseFormatsToCode(formats_vec);
auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col,
row_sorted, col_sorted, code);
*rv = HeteroGraphRef(hgptr);
});

    同样是通过调用宏 DGL_REGISTER_GLOBAL 注册的方式,DGL 将函数注册进了注册机中。

    在 Python 侧,在 python/dgl/heterograph_index.py dglsrc_python_dgl_heterograph_index_py 中,通过调用 _init_api 的方式,将动态链接库中的函数接口绑定到了 heterograph_index 模块中:

1
2
# python/dgl/heterograph_index.py
_init_api("dgl.heterograph_index")

    这样一来,在这个模块中的代码就可以直接使用 C++ 函数接口了:

1
2
3
4
5
6
7
8
9
# python/dgl/heterograph_index.py
def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col,
formats, row_sorted=False, col_sorted=False):
if isinstance(formats, str):
formats = [formats]
return _CAPI_DGLHeteroCreateUnitGraphFromCOO(
int(num_ntypes), int(num_src), int(num_dst),
F.to_dgl_nd(row), F.to_dgl_nd(col),
formats, row_sorted, col_sorted)