mlc-python

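MLC-Python is a Python-first toolkit that streamlines the development of AI compilers, runtimes, and compound AI systems. It provides Pythonic dataclasses, structure-aware tooling, and Python-based text formats.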

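Beyond pure Python, MLC natively supports zero-copy interoperation with C++ plugins. This gives a smooth engineering path from pure Python to hybrid and eventually Python-free development.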

📥 Installation

pip install -U mlc-python

🔑 Key Features

🏗️ Define IRs with MLC Dataclasses

MLC provides Pythonic dataclasses:

import mlc.dataclasses as mlcd

@mlcd.py_class("demo.MyClass")
class MyClass(mlcd.PyClass):
  a: int
  b: str
  c: float | None

instance = MyClass(12, "test", c=None)
instance
demo.MyClass(a=12, b='test', c=None)

Type safety. MLC dataclasses enforce strict type checking through Cython and C++.

instance.c = 10
instance
demo.MyClass(a=12, b='test', c=10.0)
instance.c = "wrong type"
TypeError: must be real number, not str
instance.non_exist = 1
AttributeError: 'MyClass' object has no attribute 'non_exist'
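
MLC performs these checks inside its Cython/C++ layer. The sketch below is a rough pure-Python analogy of the behavior shown above and is not MLC's implementation. It uses a data descriptor that stores real numbers as float and rejects other types, plus __slots__ to reject undeclared attributes. FloatField and the plain-Python MyClass are hypothetical names used only for illustration.

```python
import numbers
from typing import Any, Optional


class FloatField:
    """Hypothetical descriptor (not part of MLC): stores real numbers as float, or None."""

    def __set_name__(self, owner: type, name: str) -> None:
        self.slot = "_" + name  # name of the __slots__ entry that backs this field

    def __get__(self, obj: Any, objtype: Optional[type] = None) -> Any:
        if obj is None:  # accessed on the class itself
            return self
        return getattr(obj, self.slot)

    def __set__(self, obj: Any, value: Any) -> None:
        if value is not None:
            if not isinstance(value, numbers.Real):
                raise TypeError(f"must be real number, not {type(value).__name__}")
            value = float(value)  # an int such as 10 is stored as 10.0
        setattr(obj, self.slot, value)


class MyClass:
    """Plain-Python stand-in for the MLC dataclass above (hypothetical)."""

    # No per-instance __dict__, so assigning an undeclared attribute fails.
    __slots__ = ("a", "b", "_c")
    c = FloatField()

    def __init__(self, a: int, b: str, c: Optional[float] = None) -> None:
        self.a = a
        self.b = b
        self.c = c


instance = MyClass(12, "test", c=None)
instance.c = 10
print(instance.c)  # 10.0
try:
    instance.c = "wrong type"
except TypeError as err:
    print("TypeError:", err)  # TypeError: must be real number, not str
try:
    instance.non_exist = 1
except AttributeError as err:
    print("AttributeError:", err)
```

The descriptor handles per-field conversion, and __slots__ fixes the set of attributes. Together they reproduce the two failure modes shown above, although in pure Python rather than compiled code.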

Serialization. MLC dataclasses support both pickle and JSON serialization.

MyClass.from_json(instance.json())
demo.MyClass(a=12, b='test', c=10.0)
import pickle
pickle.loads(pickle.dumps(instance))
demo.MyClass(a=12, b='test', c=10.0)

🐍 Design Python-Based Text Formats for IRs

Printer. MLC looks up the method __ir_print__ to convert an IR node into a Python AST:

[Example]. Copy the toy IR definition into a REPL, then create the Func node below:

from mlc.testing.toy_ir.ir import Var, Func, Assign, Add
a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e")
f = Func(
    "f", [a, b, c],
    stmts=[
        Assign(lhs=d, rhs=Add(a, b)),  # d = a + b
        Assign(lhs=e, rhs=Add(d, c)),  # e = d + c
    ],
    ret=e
)

The method mlc.printer.to_python() converts an IR node into Python-based text:

import mlc
print(mlc.printer.to_python(f))  # stringify to Python
def f(a, b, c):
  d = a + b
  e = d + c
  return e

The method mlc.printer.print_python() additionally renders the text with syntax highlighting.

mlc.printer.print_python(f)  # syntax highlighting
def f(a, b, c):
  d = a + b
  e = d + c
  return e
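
In MLC, the printer is driven by __ir_print__. To show the general idea of lowering an IR to a Python AST and then to text, here is a stdlib-only sketch. It converts a hypothetical dataclass version of the toy IR (Var, Add, Assign, Func below are stand-ins, not the classes in mlc.testing.toy_ir.ir) into an ast.Module and renders it with ast.unparse, which requires Python 3.9+. Note that ast.unparse indents with four spaces, while the MLC output above uses two.

```python
import ast
import sys
from dataclasses import dataclass
from typing import List, Union


# Hypothetical stand-ins for the toy IR (not mlc.testing.toy_ir.ir).
@dataclass
class Var:
    name: str


@dataclass
class Add:
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Var, Add]


@dataclass
class Assign:
    lhs: Var
    rhs: Expr


@dataclass
class Func:
    name: str
    args: List[Var]
    stmts: List[Assign]
    ret: Expr


def expr_to_ast(e: Expr) -> ast.expr:
    if isinstance(e, Var):
        return ast.Name(id=e.name, ctx=ast.Load())
    if isinstance(e, Add):
        return ast.BinOp(left=expr_to_ast(e.lhs), op=ast.Add(), right=expr_to_ast(e.rhs))
    raise TypeError(f"unsupported expression: {e!r}")


def func_to_ast(f: Func) -> ast.Module:
    args = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg=v.name, annotation=None, type_comment=None) for v in f.args],
        vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[],
    )
    body: List[ast.stmt] = [
        ast.Assign(targets=[ast.Name(id=s.lhs.name, ctx=ast.Store())],
                   value=expr_to_ast(s.rhs), type_comment=None)
        for s in f.stmts
    ]
    body.append(ast.Return(value=expr_to_ast(f.ret)))
    # FunctionDef gained a type_params field in Python 3.12.
    extra = {"type_params": []} if sys.version_info >= (3, 12) else {}
    fn = ast.FunctionDef(name=f.name, args=args, body=body, decorator_list=[],
                         returns=None, type_comment=None, **extra)
    # Printing and compiling need line numbers; fill in placeholder ones.
    return ast.fix_missing_locations(ast.Module(body=[fn], type_ignores=[]))


a, b, c, d, e = (Var(n) for n in "abcde")
f = Func("f", [a, b, c],
         stmts=[Assign(lhs=d, rhs=Add(a, b)), Assign(lhs=e, rhs=Add(d, c))],
         ret=e)
print(ast.unparse(func_to_ast(f)))
```

Building a real ast tree, rather than concatenating strings, means the same result can also be passed to compile(). It also leaves operator precedence and parenthesization to the standard library.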

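AST parser. MLC provides a concise set of APIs for implementing parsers on top of Python's ast module, including:

  • an inspection API that obtains the source code of a Python class or function together with the variables it captures;

  • a variable management API that helps handle scoping correctly;

  • an AST fragment evaluation API;

  • an error rendering API.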

[Example]. With the MLC APIs, a parser for the Python text format defined above by __ir_print__ can be implemented in 100 lines of code.
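
The MLC parser APIs are not shown here. As a hedged, stdlib-only sketch of the same approach, the code below parses the text format back into a hypothetical dataclass version of the toy IR. It uses ast.parse, a name-to-Var dictionary for scoping, and line numbers in its error messages. The names parse_func, Var, Add, Assign, and Func are illustrative assumptions, and the sketch covers only the constructs in the example: parameters, x = y + z, and return.

```python
import ast
from dataclasses import dataclass
from typing import Dict, List, Optional, Union


# Hypothetical stand-ins for the toy IR (not mlc.testing.toy_ir.ir).
@dataclass
class Var:
    name: str


@dataclass
class Add:
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Var, Add]


@dataclass
class Assign:
    lhs: Var
    rhs: Expr


@dataclass
class Func:
    name: str
    args: List[Var]
    stmts: List[Assign]
    ret: Expr


def parse_func(source: str) -> Func:
    tree = ast.parse(source)
    if len(tree.body) != 1 or not isinstance(tree.body[0], ast.FunctionDef):
        raise SyntaxError("expected exactly one function definition")
    fn = tree.body[0]
    scope: Dict[str, Var] = {}  # names visible so far -> the Var they refer to
    params = [Var(p.arg) for p in fn.args.args]
    scope.update((v.name, v) for v in params)

    def parse_expr(node: ast.expr) -> Expr:
        if isinstance(node, ast.Name):
            if node.id not in scope:
                raise NameError(f"line {node.lineno}: undefined variable {node.id!r}")
            return scope[node.id]  # reuse the bound Var object, not a copy
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
            return Add(parse_expr(node.left), parse_expr(node.right))
        raise SyntaxError(f"line {node.lineno}: unsupported expression")

    stmts: List[Assign] = []
    ret: Optional[Expr] = None
    for stmt in fn.body:
        if (isinstance(stmt, ast.Assign) and len(stmt.targets) == 1
                and isinstance(stmt.targets[0], ast.Name)):
            rhs = parse_expr(stmt.value)  # parse the RHS before the new name is in scope
            lhs = Var(stmt.targets[0].id)
            scope[lhs.name] = lhs
            stmts.append(Assign(lhs=lhs, rhs=rhs))
        elif isinstance(stmt, ast.Return) and stmt.value is not None:
            ret = parse_expr(stmt.value)
        else:
            raise SyntaxError(f"line {stmt.lineno}: unsupported statement")
    if ret is None:
        raise SyntaxError("function has no return statement")
    return Func(fn.name, params, stmts, ret)


g = parse_func("def f(a, b, c):\n  d = a + b\n  e = d + c\n  return e")
print(g.stmts[1])  # Assign(lhs=Var(name='e'), rhs=Add(lhs=Var(name='d'), rhs=Var(name='c')))
```

Resolving every use-site through the scope dictionary makes uses share the same Var object as their definition. That identity is what the structure-aware tools in the next section rely on.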

🎯 Test IRs with MLC Structure-Aware Tooling

By annotating IR definitions with structure, MLC supports structural equality and structural hashing, which detect structural equivalence between IRs:

Define a toy IR with structure:

import mlc.dataclasses as mlcd

@mlcd.py_class
class Expr(mlcd.PyClass):
    def __add__(self, other):
        return Add(a=self, b=other)

@mlcd.py_class(structure="nobind")
class Add(Expr):
    a: Expr
    b: Expr

@mlcd.py_class(structure="var")
class Var(Expr):
    name: str = mlcd.field(structure=None) # excludes `name` from defined structure

@mlcd.py_class(structure="bind")
class Let(Expr):
    rhs: Expr
    lhs: Var = mlcd.field(structure="bind") # `Let.lhs` is the def-site
    body: Expr

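Structural equality. The member method eq_s checks whether two IRs represented by MLC structured dataclasses are structurally equal (i.e., alpha-equivalent). With assert_mode=True, a mismatch raises a ValueError that reports the path at which the check failed.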

x, y, z = Var("x"), Var("y"), Var("z")
L1 = Let(rhs=x + y, lhs=z, body=z)  # let z = x + y; z
L2 = Let(rhs=y + z, lhs=x, body=x)  # let x = y + z; x
L3 = Let(rhs=x + x, lhs=z, body=z)  # let z = x + x; z
L1.eq_s(L2)
True
L1.eq_s(L3, assert_mode=True)
ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent binding. RHS has been bound to a different node while LHS is not bound
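
The following hedged pure-Python sketch makes the binding rule behind this error concrete. It works on a hypothetical stdlib-dataclass copy of the toy IR and is not MLC's C++ implementation. It keeps two maps between left and right variables and binds a pair the first time both sides are seen unbound, which treats free variables as bindable. It fails when only one side is already bound. Visiting Let.rhs before Let.lhs follows the field order of the toy IR above, so the failure is reported at the same {root}.rhs.b path. AlphaEq and alpha_equal are illustrative names.

```python
from dataclasses import dataclass
from typing import Dict, Union


# Hypothetical stand-ins for the toy IR; eq=False keeps identity-based hashing,
# so each Var object is a distinct variable regardless of its name.
@dataclass(eq=False)
class Var:
    name: str


@dataclass(eq=False)
class Add:
    a: "Expr"
    b: "Expr"


@dataclass(eq=False)
class Let:
    rhs: "Expr"
    lhs: Var
    body: "Expr"


Expr = Union[Var, Add, Let]


class AlphaEq:
    def __init__(self) -> None:
        self.l2r: Dict[Var, Var] = {}  # left variable -> right variable
        self.r2l: Dict[Var, Var] = {}  # right variable -> left variable

    def match_var(self, l: Var, r: Var, path: str) -> None:
        bound_l, bound_r = self.l2r.get(l), self.r2l.get(r)
        if bound_l is None and bound_r is None:
            # First encounter on both sides: bind them to each other.
            self.l2r[l], self.r2l[r] = r, l
        elif bound_l is not r or bound_r is not l:
            raise ValueError(f"Structural equality check failed at {path}: inconsistent binding")

    def check(self, l: Expr, r: Expr, path: str = "{root}") -> None:
        if type(l) is not type(r):
            raise ValueError(f"Structural equality check failed at {path}: type mismatch")
        if isinstance(l, Var):
            self.match_var(l, r, path)
        elif isinstance(l, Add):
            self.check(l.a, r.a, path + ".a")
            self.check(l.b, r.b, path + ".b")
        elif isinstance(l, Let):
            self.check(l.rhs, r.rhs, path + ".rhs")
            # Def-site (simplified: shadowing is not treated specially here).
            self.match_var(l.lhs, r.lhs, path + ".lhs")
            self.check(l.body, r.body, path + ".body")


def alpha_equal(l: Expr, r: Expr) -> bool:
    try:
        AlphaEq().check(l, r)
        return True
    except ValueError:
        return False


x, y, z = Var("x"), Var("y"), Var("z")
L1 = Let(rhs=Add(x, y), lhs=z, body=z)  # let z = x + y; z
L2 = Let(rhs=Add(y, z), lhs=x, body=x)  # let x = y + z; x
L3 = Let(rhs=Add(x, x), lhs=z, body=z)  # let z = x + x; z
print(alpha_equal(L1, L2))  # True
print(alpha_equal(L1, L3))  # False
```

For L1 versus L3, x is first bound to x. At {root}.rhs.b the right-hand x is already bound but the left-hand y is not, which is the inconsistency reported above.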

Structural hashing. The structure of an MLC dataclass can be hashed with hash_s, which guarantees that two alpha-equivalent dataclasses have the same structural hash:

L1_hash, L2_hash, L3_hash = L1.hash_s(), L2.hash_s(), L3.hash_s()
assert L1_hash == L2_hash
assert L1_hash != L3_hash
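
One way to obtain a hash with this guarantee is sketched below in pure Python, again over hypothetical stand-in classes. It is an assumption for illustration, not MLC's algorithm. The idea is to canonicalize the tree by replacing each variable with the index of its first occurrence in a fixed traversal order, then hash the canonical form. Alpha-equivalent trees produce the same canonical form and therefore the same hash. Different canonical forms almost always hash differently, though no hash can rule out collisions entirely. canonical and structural_hash are hypothetical names.

```python
import hashlib
from dataclasses import dataclass
from typing import Dict, Tuple, Union


# Hypothetical stand-ins for the toy IR; eq=False keeps identity-based hashing,
# so each Var object is a distinct variable regardless of its name.
@dataclass(eq=False)
class Var:
    name: str


@dataclass(eq=False)
class Add:
    a: "Expr"
    b: "Expr"


@dataclass(eq=False)
class Let:
    rhs: "Expr"
    lhs: Var
    body: "Expr"


Expr = Union[Var, Add, Let]


def canonical(node: Expr, index: Dict[Var, int]) -> Tuple:
    """Rewrite the tree with variables numbered by first occurrence."""
    if isinstance(node, Var):
        if node not in index:
            index[node] = len(index)  # names are ignored entirely
        return ("Var", index[node])
    if isinstance(node, Add):
        return ("Add", canonical(node.a, index), canonical(node.b, index))
    if isinstance(node, Let):
        # Same visiting order as Let's fields: rhs, lhs, body.
        return ("Let", canonical(node.rhs, index),
                canonical(node.lhs, index), canonical(node.body, index))
    raise TypeError(f"unsupported node: {node!r}")


def structural_hash(node: Expr) -> str:
    """Hypothetical helper: SHA-256 of the canonical form (stable across runs)."""
    return hashlib.sha256(repr(canonical(node, {})).encode()).hexdigest()


x, y, z = Var("x"), Var("y"), Var("z")
L1 = Let(rhs=Add(x, y), lhs=z, body=z)  # let z = x + y; z
L2 = Let(rhs=Add(y, z), lhs=x, body=x)  # let x = y + z; x
L3 = Let(rhs=Add(x, x), lhs=z, body=z)  # let z = x + x; z
assert structural_hash(L1) == structural_hash(L2)
assert structural_hash(L1) != structural_hash(L3)
```

SHA-256 over the repr is used here instead of the built-in hash() because string hashing in Python is randomized per process, and a structural hash is more useful when it is stable across runs.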

⚡ Migrate Gradually to C++ with MLC Plugins

(🚧 Under construction)

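MLC seamlessly supports zero-copy, bidirectional interoperability with C++ plugins and requires no extra dependencies. By migrating classes and methods step by step, a pure-Python prototype can transition to hybrid or Python-free development.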

⛽ Development

⚙️ Editable Build

pip install --verbose --editable ".[dev]"
pre-commit install

🎡 Building Wheels

This project uses cibuildwheel to build cross-platform wheels. See .github/workflows/wheels.yml for details.