
TVM traverser design: functors (function objects)

Usage

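Its job is to provide a base class that performs depth-first traversal over a dataflow graph whose nodes are Expr objects, and that dispatches polymorphically, so that a different handler runs for each Expr subclass.
The base traverser class, ExprFunctor, is implemented as follows: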

template <typename R, typename... Args>
class ExprFunctor<R(const Expr& n, Args...)> {
 private:
  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~ExprFunctor() {}
  /*!
   * \brief Same as call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); }
  /*!
   * \brief The functor call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitExpr(const Expr& n, Args... args) {
    ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
                           "have generated invalid data.";
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overridden by subclass
  virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExprDefault_(const Object* op, Args...) {
    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
    throw;
  }

 private:
  // initialize the vtable.
  static FType InitVTable() {
    FType vtable;
    // Set dispatch
    RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
    return vtable;
  }
};

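This is the base class of all traversers and visits nodes polymorphically. Each concrete subclass, for example the depth-first graph traverser used by the constant-folding pass, overrides the per-node visit functions and adds the modifications it needs for specific nodes.
Dispatch goes through the class's static variable vtable. vtable is itself a functor (a NodeFunctor), and it is filled in by the following macro: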

#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)                                                    \
  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
    return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
  });

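The vtable.template set_dispatch<OP>(...) syntax calls a member function template of vtable. Because the type of vtable depends on the template parameters R and Args, the compiler cannot know that set_dispatch names a template, so the template keyword is required; without it, the < would be parsed as a less-than operator. For the rules on when a call to a member template must be prefixed with template, see sections 9.3.2 and 9.3.3 of the book C++ Templates.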
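A minimal, self-contained sketch of why the template keyword is needed. Table, set and Register are hypothetical names standing in for NodeFunctor, set_dispatch and InitVTable; the only point carried over is that the object's type depends on a template parameter.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical table whose set<T>() is a member function template,
// playing the role of NodeFunctor::set_dispatch<TNode>().
struct Table {
  std::string last_name;
  std::size_t last_size = 0;
  template <typename T>
  Table& set(const std::string& name) {
    last_name = name;
    last_size = sizeof(T);
    return *this;
  }
};

// Inside a template, t's type depends on TTable, so the compiler cannot know
// that t.set is a template. Writing t.set<int>("int") here would be parsed as
// a less-than comparison and fail to compile; `template` disambiguates.
template <typename TTable>
void Register(TTable& t) {
  t.template set<int>("int");
}
```

Outside a template (for example, calling set<int> on a variable whose type is the concrete Table), the keyword is unnecessary, which is why it only shows up inside ExprFunctor's class template.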

template <typename R, typename... Args>
class NodeFunctor<R(const ObjectRef& n, Args...)> {
 private:
  /*! \brief internal function pointer type */
  typedef R (*FPointer)(const ObjectRef& n, Args...);
  /*! \brief refer to itself. */
  using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>;
  /*! \brief internal function table */
  std::vector<FPointer> func_;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*!
   * \brief Whether the functor can dispatch the corresponding Node
   * \param n The node to be dispatched
   * \return Whether dispatching function is registered for n's type.
   */
  bool can_dispatch(const ObjectRef& n) const {
    uint32_t type_index = n->type_index();
    return type_index < func_.size() && func_[type_index] != nullptr;
  }
  /*!
   * \brief invoke the functor, dispatch on type of n
   * \param n The Node argument
   * \param args The additional arguments
   * \return The result.
   */
  R operator()(const ObjectRef& n, Args... args) const {
    ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type "
                            << n->GetTypeKey();
    return (*func_[n->type_index()])(n, std::forward<Args>(args)...);
  }
  /*!
   * \brief set the dispatcher for type TNode
   * \param f The function to be set.
   * \tparam TNode the type of Node to be dispatched.
   * \return reference to self.
   */
  template <typename TNode>
  TSelf& set_dispatch(FPointer f) {  // NOLINT(*)
    uint32_t tindex = TNode::RuntimeTypeIndex();
    if (func_.size() <= tindex) {
      func_.resize(tindex + 1, nullptr);
    }
    ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set";
    func_[tindex] = f;
    return *this;
  }
  // ... (remaining members omitted)
};

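NodeFunctor is a functor over nodes: it takes (const ObjectRef& n, Args...) and returns R. It is polymorphic in the sense that a function can be registered for each node type, and a call dispatches to the registered function according to the runtime type of the ObjectRef argument, looked up through n->type_index().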
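To make the mechanism concrete, here is a heavily simplified stand-in that uses nothing from TVM: Object, IntNode, StrNode, ToyNodeFunctor and Describe are hypothetical names, and assert replaces ICHECK. As in NodeFunctor, each node class has a distinct type index, the table is a std::vector of plain function pointers indexed by it, and set_dispatch returns *this so registrations can be chained.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for tvm::Object and two node classes.
struct Object {
  uint32_t type_index;
  explicit Object(uint32_t t) : type_index(t) {}
  virtual ~Object() = default;
};
struct IntNode : Object {
  static constexpr uint32_t kTypeIndex = 1;
  int value;
  explicit IntNode(int v) : Object(kTypeIndex), value(v) {}
};
struct StrNode : Object {
  static constexpr uint32_t kTypeIndex = 2;
  std::string value;
  explicit StrNode(std::string v) : Object(kTypeIndex), value(std::move(v)) {}
};

// Toy NodeFunctor: a vector of function pointers indexed by type index.
template <typename R, typename... Args>
class ToyNodeFunctor {
 public:
  using FPointer = R (*)(const Object& n, Args...);

  bool can_dispatch(const Object& n) const {
    return n.type_index < func_.size() && func_[n.type_index] != nullptr;
  }
  R operator()(const Object& n, Args... args) const {
    assert(can_dispatch(n) && "calls un-registered function");
    return func_[n.type_index](n, std::forward<Args>(args)...);
  }
  template <typename TNode>
  ToyNodeFunctor& set_dispatch(FPointer f) {
    const uint32_t tindex = TNode::kTypeIndex;
    if (func_.size() <= tindex) func_.resize(tindex + 1, nullptr);
    assert(func_[tindex] == nullptr && "dispatch is already set");
    func_[tindex] = f;
    return *this;  // allows chained set_dispatch<...>() calls
  }

 private:
  std::vector<FPointer> func_;
};

// Usage: the table is built once inside a function-local static.
// Captureless lambdas convert implicitly to FPointer.
inline std::string Describe(const Object& n) {
  static const ToyNodeFunctor<std::string> table = [] {
    ToyNodeFunctor<std::string> t;
    t.set_dispatch<IntNode>([](const Object& o) {
       return "int " + std::to_string(static_cast<const IntNode&>(o).value);
     }).set_dispatch<StrNode>([](const Object& o) {
       return "str " + static_cast<const StrNode&>(o).value;
     });
    return t;
  }();
  return table(n);
}
```

The design trades virtual functions on the node classes for a lookup table owned by the caller, so new operations over existing node types can be added without touching those types.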

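Inside ExprFunctor, the private InitVTable() uses this macro to fill the static vtable, binding each node type to the matching ExprFunctor::VisitExpr_ overload. When the functor is called from outside, the call path is:

R VisitExpr()
Fetches the class's static vtable. Because it is a function-local static, it is initialized once, on the first call, by InitVTable(); the macro expansions inside InitVTable() register the functions into NodeFunctor::func_ at that point.

NodeFunctor::operator()(objref, args...)
Calls the function stored in func_ at the slot for objref's concrete subclass; that function static_casts the node and calls self->VisitExpr_.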
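The same call path can be reproduced in miniature. In this sketch, all names are hypothetical (Node, ConstNode, AddNode, ToyExprFunctor, Evaluator): VisitExpr looks up a function-local static table built once by InitVTable(), and each entry is a captureless lambda that static_casts the node and calls the virtual VisitExpr_ overload on self, mirroring what RELAY_EXPR_FUNCTOR_DISPATCH generates. A plain std::vector replaces NodeFunctor, and assert replaces ICHECK.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical mini-IR: every node stores the type index of its concrete class.
struct Node {
  uint32_t type_index;
  explicit Node(uint32_t t) : type_index(t) {}
  virtual ~Node() = default;
};
using NodeRef = std::shared_ptr<const Node>;

struct ConstNode : Node {
  static constexpr uint32_t kIndex = 0;
  int value;
  explicit ConstNode(int v) : Node(kIndex), value(v) {}
};
struct AddNode : Node {
  static constexpr uint32_t kIndex = 1;
  NodeRef a, b;
  AddNode(NodeRef x, NodeRef y) : Node(kIndex), a(std::move(x)), b(std::move(y)) {}
};

template <typename R>
class ToyExprFunctor {
 private:
  // Plain function pointer; `self` is passed explicitly because it cannot be captured.
  using FPointer = R (*)(const Node& n, ToyExprFunctor* self);

 public:
  virtual ~ToyExprFunctor() = default;
  virtual R VisitExpr(const NodeRef& n) {
    assert(n != nullptr && "null node while traversing");
    // Built once per ToyExprFunctor<R> instantiation, on the first call.
    static const std::vector<FPointer> vtable = InitVTable();
    return vtable.at(n->type_index)(*n, this);
  }
  // Overridden by subclasses, like ExprFunctor::VisitExpr_.
  virtual R VisitExpr_(const ConstNode* op) = 0;
  virtual R VisitExpr_(const AddNode* op) = 0;

 private:
  static std::vector<FPointer> InitVTable() {
    std::vector<FPointer> v(2, nullptr);
    // Each trampoline downcasts and makes a virtual call on self.
    v[ConstNode::kIndex] = [](const Node& n, ToyExprFunctor* self) {
      return self->VisitExpr_(static_cast<const ConstNode*>(&n));
    };
    v[AddNode::kIndex] = [](const Node& n, ToyExprFunctor* self) {
      return self->VisitExpr_(static_cast<const AddNode*>(&n));
    };
    return v;
  }
};

// A concrete subclass that evaluates the expression depth-first.
class Evaluator : public ToyExprFunctor<int> {
 public:
  int VisitExpr_(const ConstNode* op) override { return op->value; }
  int VisitExpr_(const AddNode* op) override { return VisitExpr(op->a) + VisitExpr(op->b); }
};
```

Because the table is a static local of the class template, there is one table per instantiation, shared by every subclass and every object; the per-subclass behavior still comes from the virtual call through self.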

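This part may look confusing: the traverser wraps another layer, static FType vtable = InitVTable(), around the dispatch. Polymorphic VisitExpr calls alone would already be enough for depth-first traversal, so the vtable design looks redundant. It is not; my understanding is given at the end of the article.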

The base traverser has already laid out the basic framework through polymorphic dispatch.
Below are several traverser subclasses:

class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
  ...
 protected:
  // Number of times each node has been visited.
  std::unordered_map<const Object*, size_t> visit_counter_;
};

class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
  ...
 protected:
  // Memoized rewrite result for each visited node.
  std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo_;
};

/*!
 * \brief A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
 *
 * MixedModeVisitor treats Expr as dataflow graph, and visits in post-DFS order
 *
 * MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
 * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
 * of the graph and processes them iteratively to prevent stack overflows
 */
class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
  ...
};

class MixedModeMutator : public ::tvm::relay::ExprMutator {
  ...
};
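Why keep visit_counter_ and memo_? Expressions form a DAG, so a shared subexpression is reachable along several paths. The sketch below uses a hypothetical mini-IR (an Expr with optional lhs/rhs children, not TVM's types) and keys both maps on the raw node pointer, in the spirit of ObjectPtrHash/ObjectPtrEqual. The visitor expands each node only once and afterwards just bumps its counter; the mutator caches the rewritten node so that a shared input maps to one shared output.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <unordered_map>

// Hypothetical mini-IR (not TVM's Expr): a leaf value or a node with two children.
struct Expr {
  int value = 0;
  std::shared_ptr<const Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<const Expr>;

// Like ExprVisitor: count every arrival at a node, but expand it only once.
struct CountingVisitor {
  std::unordered_map<const Expr*, std::size_t> visit_counter_;
  std::size_t expanded = 0;  // number of distinct nodes actually expanded

  void Visit(const ExprPtr& e) {
    if (visit_counter_[e.get()]++ > 0) return;  // seen before: skip the subtree
    ++expanded;
    if (e->lhs) Visit(e->lhs);
    if (e->rhs) Visit(e->rhs);
  }
};

// Like ExprMutator: memoize the rewritten node so shared inputs stay shared.
// The rewrite itself (doubling every leaf value) is purely illustrative.
struct DoubleLeaves {
  std::unordered_map<const Expr*, ExprPtr> memo_;

  ExprPtr Mutate(const ExprPtr& e) {
    auto it = memo_.find(e.get());
    if (it != memo_.end()) return it->second;
    auto out = std::make_shared<Expr>(*e);
    if (!e->lhs && !e->rhs) out->value = e->value * 2;
    if (e->lhs) out->lhs = Mutate(e->lhs);
    if (e->rhs) out->rhs = Mutate(e->rhs);
    memo_[e.get()] = out;
    return out;
  }
};
```

Without memo_, the mutator would produce two distinct copies of a shared subexpression, silently turning the DAG into a tree and duplicating work.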


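The point of NodeFunctor is that whenever a class of your own needs to operate on ObjectRef polymorphically, it can register one function per node type in a table, which keeps the operation easy to extend.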

The doc comment of TVM_STATIC_IR_FUNCTOR ("Useful macro to set NodeFunctor dispatch in a global static field") contains the following example:

// Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern.
// vtable allows easy patch of new Node types, without changing
// interface of ReprPrinter.
class ReprPrinter {
 public:
  std::ostream& stream;
  // the dispatch function.
  void print(Expr e) {
    const static FType& f = vtable();
    f(e, this);
  }

  using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;
  // function to return global function table
  static FType& vtable();
};
// in cpp/cc file
ReprPrinter::FType& ReprPrinter::vtable() {  // NOLINT(*)
  static FType inst; return inst;
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {
  auto* n = static_cast<const Add*>(ref.get());
  p->print(n->a);
  p->stream << '+';
  p->print(n->b);
});

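The printer class above could be implemented in two ways:

  1. Write many overloads of print, such as print(Add e), one per node type.
  2. Use a static vtable and register a function for each node type (Add, Min, ...). The printer then exposes a single entry point, print(Expr e); this relies on the inheritance chain ObjectRef <- Expr <- Add, etc. Moreover, the registration can be done in any .cc file, because ReprPrinter::vtable() is static.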
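A runnable sketch of approach 2, assuming a toy node hierarchy instead of TVM's: Node, VarNode, AddNode, DispatchTable and the reg_* variables are hypothetical. ReprPrinter exposes only print(); the per-type printers are registered through namespace-scope static variables that call ReprPrinter::vtable().set_dispatch<...>(). Here everything sits in one file, but the registrations could equally be spread across several .cc files.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy node hierarchy, one type index per class.
struct Node {
  uint32_t type_index;
  explicit Node(uint32_t t) : type_index(t) {}
  virtual ~Node() = default;
};
using NodeRef = std::shared_ptr<const Node>;
struct VarNode : Node {
  static constexpr uint32_t kIndex = 0;
  char name;
  explicit VarNode(char c) : Node(kIndex), name(c) {}
};
struct AddNode : Node {
  static constexpr uint32_t kIndex = 1;
  NodeRef a, b;
  AddNode(NodeRef x, NodeRef y) : Node(kIndex), a(std::move(x)), b(std::move(y)) {}
};

// Simplified stand-in for NodeFunctor: function pointers indexed by type index.
template <typename F>
struct DispatchTable {
  std::vector<F> func_;
  template <typename TNode>
  DispatchTable& set_dispatch(F f) {
    if (func_.size() <= TNode::kIndex) func_.resize(TNode::kIndex + 1, nullptr);
    func_[TNode::kIndex] = f;
    return *this;
  }
};

class ReprPrinter {
 public:
  using FType = DispatchTable<void (*)(const Node&, ReprPrinter*)>;
  std::ostringstream stream;
  // Single entry point; the concrete printer is looked up by type index.
  void print(const NodeRef& n) { vtable().func_.at(n->type_index)(*n, this); }
  // Function-local static: constructed by whichever caller reaches it first.
  static FType& vtable() {
    static FType inst;
    return inst;
  }
};

// Registrations; in a real project these could live in separate .cc files.
[[maybe_unused]] static auto& reg_var = ReprPrinter::vtable().set_dispatch<VarNode>(
    [](const Node& n, ReprPrinter* p) { p->stream << static_cast<const VarNode&>(n).name; });
[[maybe_unused]] static auto& reg_add = ReprPrinter::vtable().set_dispatch<AddNode>(
    [](const Node& n, ReprPrinter* p) {
      const auto& add = static_cast<const AddNode&>(n);
      p->print(add.a);
      p->stream << '+';
      p->print(add.b);
    });
```

Adding a new node type means adding one more registration; ReprPrinter's interface does not change, which is the advantage over approach 1.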

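The class ReprPrinter contains
using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;
This means every function registered in vtable takes (const ObjectRef&, ReprPrinter*) and returns void, and dispatch happens on the ObjectRef argument.

static FType& vtable(); returns a reference to the static vtable: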

ReprPrinter::FType& ReprPrinter::vtable() {
  static FType inst;
  return inst;
}

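TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<Add> ... expands to

static TVM_ATTRIBUTE_UNUSED auto& __make_functor_ReprPrinter__count__ = ReprPrinter::vtable().set_dispatch<Add> ...

where __count__ stands for the value of __COUNTER__. __COUNTER__ is a predefined compiler macro (supported by GCC, Clang and MSVC) that expands to an integer starting at 0 and incremented on every use, so each registration gets a uniquely named variable.
In other words, the macro defines a static variable __make_functor_ReprPrinter__count__ and initializes it with ReprPrinter::vtable().set_dispatch<Add> ...; the function is registered as a side effect of that initialization. Because set_dispatch returns a reference to the table, the variable simply binds to vtable itself.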
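The naming trick can be reproduced with a hypothetical macro, TOY_STATIC_FUNCTOR, modeled on (but not copied from) TVM_STATIC_IR_FUNCTOR; Table and Printer are also hypothetical. The extra TOY_CONCAT indirection makes __COUNTER__ expand to a number before ## pastes it; pasting directly would yield the literal identifier ..._ __COUNTER__ for every use and a redefinition error. Assumes C++17 (for [[maybe_unused]]) and a compiler that supports __COUNTER__.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical chainable table standing in for NodeFunctor.
struct Table {
  std::vector<std::string> entries;
  Table& set(const std::string& s) {
    entries.push_back(s);
    return *this;  // chaining, like set_dispatch
  }
};

struct Printer {
  static Table& vtable() {
    static Table inst;
    return inst;
  }
};

#define TOY_CONCAT_IMPL(a, b) a##b
#define TOY_CONCAT(a, b) TOY_CONCAT_IMPL(a, b)
// Expands to: [[maybe_unused]] static auto& toy_reg_<Cls>_<N> = Cls::FField()
#define TOY_STATIC_FUNCTOR(ClsName, FField) \
  [[maybe_unused]] static auto& TOY_CONCAT(toy_reg_##ClsName##_, __COUNTER__) = ClsName::FField()

// Two registrations in one file: __COUNTER__ gives them distinct variable names.
TOY_STATIC_FUNCTOR(Printer, vtable).set("Add").set("Sub");
TOY_STATIC_FUNCTOR(Printer, vtable).set("Mul");
```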

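Prerequisites

1. Every Object stores a type_index_ (read through an ObjectRef as n->type_index()), and different subclasses have different type indices. This property gives the vtable its index.

2. Static variables are initialized before the program's main() runs; they usually live in the .bss or .data segment. Their initializer may also call a function: static int aa = func1(); is legal (this is dynamic initialization).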
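A small sketch of point 2; func1, aa and Log are hypothetical names. aa is a namespace-scope static whose initializer calls a function, so it is dynamically initialized. The standard allows that to be deferred until something in the same file is first used, but GCC and Clang run it before main() in practice. Log() itself uses the function-local static idiom, so it is safe to call from another static initializer.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical event log, itself a function-local static.
inline std::vector<std::string>& Log() {
  static std::vector<std::string> log;
  return log;
}

inline int func1() {
  Log().push_back("func1");
  return 42;
}

// Namespace-scope static with a non-constant initializer: dynamic initialization,
// which GCC and Clang perform before main() is entered.
static int aa = func1();
```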


Note:
The table must be returned from a static function; declaring it as a static data member instead causes an error.

class ReprPrinter {
  ...
  static FType vtable;  // compiles, but fails at link time
};

An equivalent minimal example:

#include <iostream>
#include <vector>

class CC {
 public:
  std::vector<int> v;
  CC& push_back(int in) {
    v.push_back(in);
    return *this;
  }
};

class BB {
 public:
  int kk;
  CC v;
};

class AA {
 public:
  static BB& get() {
    static BB b;
    return b;
  }
  static BB bin;  // declaration only; never defined
};

static auto& ll = AA::get().v.push_back(123);
// Uses AA::bin, which has no definition -> link error:
// test.cpp:(.text+0xe4): undefined reference to `AA::bin'
static auto& ll2 = AA::bin.v.push_back(123);

int main() {
  std::cout << ll.v.size() << std::endl;
  std::cout << ll2.v.size() << std::endl;
}

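The link error happens because static BB bin; inside the class only declares the static data member; the program must also define it in exactly one translation unit (BB AA::bin;), or, since C++17, declare it inline static. Even with a definition, registering into a static data member from other .cc files would run into the static initialization order problem: the relative order of dynamic initialization across translation units is unspecified, so a registration could run before the table is constructed. A function-local static, as returned by ReprPrinter::vtable(), is constructed on first use and avoids both issues.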
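The missing piece in the example is a definition of AA::bin. The sketch below, assuming a C++17 compiler, keeps the original classes, adds the out-of-class definition, and shows an inline static member (bin2, a hypothetical name) as an alternative; all three registrations then link and run.

```cpp
#include <cassert>
#include <vector>

class CC {
 public:
  std::vector<int> v;
  CC& push_back(int in) {
    v.push_back(in);
    return *this;
  }
};

class BB {
 public:
  int kk;
  CC v;
};

class AA {
 public:
  static BB& get() {
    static BB b;
    return b;
  }
  static BB bin;          // declaration only...
  inline static BB bin2;  // C++17: an inline static member is also a definition
};

BB AA::bin;  // ...so exactly one .cc file must provide this definition

static auto& ll = AA::get().v.push_back(123);
static auto& ll2 = AA::bin.v.push_back(123);  // links now
static auto& ll3 = AA::bin2.v.push_back(123);
```

Within this single file the definitions precede their uses, so initialization order is well defined; across files, only the AA::get() form is safe without further care.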

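Takeaways

1. When static variables are initialized, and under what conditions.
2. Building a vtable out of a static variable plus per-subclass type tags to make polymorphic calls, similar to the vtable behind C++ virtual functions; it is essentially a lookup table.
3. Multi-way dispatch on the node type.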