mlc-python
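MLC-Python is a Python-first toolkit that streamlines the development of AI compilers, runtimes, and compound AI systems by providing Pythonic dataclasses, structure-aware tooling, and Python-based text formats.
Beyond pure Python, MLC natively supports zero-copy interoperation with C++ plugins, which makes the engineering transition from pure Python to hybrid or even Python-free development smooth.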
📥 Installation
pip install -U mlc-python
🔑 Key Features
🏗️ Defining IRs with MLC Dataclasses
MLC provides Pythonic dataclasses:
import mlc.dataclasses as mlcd
@mlcd.py_class("demo.MyClass")
class MyClass(mlcd.PyClass):
    a: int
    b: str
    c: float | None
instance = MyClass(12, "test", c=None)
instance
demo.MyClass(a=12, b='test', c=None)
Type safety. MLC dataclasses enforce strict type checking through Cython and C++.
instance.c = 10
instance
demo.MyClass(a=12, b='test', c=10.0)
instance.c = "wrong type"
TypeError: must be real number, not str
instance.non_exist = 1
AttributeError: 'MyClass' object has no attribute 'non_exist'
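The behavior above can be illustrated with a plain-Python analogy of field-level type checking. This is not how MLC implements it, since MLC does the checking in Cython and C++. The TypedField descriptor and the check_* helpers are hypothetical names. Every assignment is validated and possibly coerced, and __slots__ makes assigning an undeclared attribute raise AttributeError.

```python
class TypedField:
    """Descriptor that validates (and possibly coerces) a value on assignment."""

    def __init__(self, check):
        self.check = check  # callable: value -> stored value; raises TypeError on mismatch

    def __set_name__(self, owner, name):
        self.slot = "_" + name  # the value lives in a private slot, e.g. "_a"

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        return getattr(obj, self.slot)

    def __set__(self, obj, value):
        setattr(obj, self.slot, self.check(value))


def check_int(value):
    if isinstance(value, bool) or not isinstance(value, int):
        raise TypeError(f"must be int, not {type(value).__name__}")
    return value


def check_str(value):
    if not isinstance(value, str):
        raise TypeError(f"must be str, not {type(value).__name__}")
    return value


def check_optional_float(value):
    if value is None:
        return None
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError(f"must be real number, not {type(value).__name__}")
    return float(value)  # ints are widened to float, e.g. 10 -> 10.0


class MyClass:
    __slots__ = ("_a", "_b", "_c")  # no __dict__, so unknown attributes are rejected
    a = TypedField(check_int)
    b = TypedField(check_str)
    c = TypedField(check_optional_float)

    def __init__(self, a, b, c=None):
        self.a, self.b, self.c = a, b, c

    def __repr__(self):
        return f"demo.MyClass(a={self.a!r}, b={self.b!r}, c={self.c!r})"


obj = MyClass(12, "test", c=None)
obj.c = 10
print(obj)  # demo.MyClass(a=12, b='test', c=10.0)
for attempt in (lambda: setattr(obj, "c", "wrong type"),
                lambda: setattr(obj, "non_exist", 1)):
    try:
        attempt()
    except (TypeError, AttributeError) as err:
        print(type(err).__name__, err)  # both attempts fail; obj is unchanged
```

Here the checks are written by hand. An MLC dataclass derives them from the type annotations instead.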
Serialization. MLC dataclasses support both pickle serialization and JSON serialization.
MyClass.from_json(instance.json())
demo.MyClass(a=12, b='test', c=10.0)
import pickle
pickle.loads(pickle.dumps(instance))
demo.MyClass(a=12, b='test', c=10.0)
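A JSON round trip like MyClass.from_json(instance.json()) has to know which class to rebuild. One way to do this with standard-library dataclasses is to store a type key in the payload. This is a minimal sketch under stated assumptions: the payload layout {"type": ..., "fields": ...}, the register decorator, and the method name to_json are all hypothetical, and MLC's actual JSON format is not reproduced here.

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional

_REGISTRY = {}  # hypothetical registry: type key -> class


def register(type_key):
    """Record cls under type_key so a JSON payload can name its own type."""
    def wrap(cls):
        cls.type_key = type_key
        _REGISTRY[type_key] = cls
        return cls
    return wrap


@register("demo.MyClass")
@dataclass
class MyClass:
    a: int
    b: str
    c: Optional[float] = None

    def to_json(self) -> str:
        return json.dumps({"type": self.type_key, "fields": asdict(self)})

    @classmethod
    def from_json(cls, text: str):
        payload = json.loads(text)
        target = _REGISTRY[payload["type"]]  # look up the class by its type key
        if not issubclass(target, cls):
            raise TypeError(f"{payload['type']} is not a {cls.__name__}")
        return target(**payload["fields"])


obj = MyClass(12, "test", c=10.0)
text = obj.to_json()
print(text)  # {"type": "demo.MyClass", "fields": {"a": 12, "b": "test", "c": 10.0}}
print(MyClass.from_json(text) == obj)  # True
```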
🐍 Python-Based Text Formats for IRs
Printer. MLC looks up the method __ir_print__ to convert an IR node into a Python AST:
[Example]. Copy the toy IR definition into the REPL, then create the Func node below:
from mlc.testing.toy_ir.ir import Var, Func, Assign, Add
a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e")
f = Func(
    "f", [a, b, c],
    stmts=[
        Assign(lhs=d, rhs=Add(a, b)),  # d = a + b
        Assign(lhs=e, rhs=Add(d, c)),  # e = d + c
    ],
    ret=e
)
The method mlc.printer.to_python() converts an IR node into Python-based text:
import mlc
print(mlc.printer.to_python(f))  # stringify to Python
def f(a, b, c):
  d = a + b
  e = d + c
  return e
The method mlc.printer.print_python goes one step further and renders the text with syntax highlighting.
mlc.printer.print_python(f)  # syntax highlighting
def f(a, b, c):
  d = a + b
  e = d + c
  return e
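The printer's job has two steps: map each IR node to a Python AST node, then render that AST as text. Both steps can be sketched with the standard ast module alone. The dataclasses below are hypothetical stand-ins for the toy IR. Their to_ast method only plays the role that __ir_print__ plays in MLC and does not reproduce MLC's printer API. ast.unparse requires Python 3.9+ and indents with four spaces.

```python
import ast
from dataclasses import dataclass


@dataclass
class Var:
    name: str

    def to_ast(self, ctx=None):
        return ast.Name(id=self.name, ctx=ctx or ast.Load())


@dataclass
class Add:
    a: object
    b: object

    def to_ast(self):
        return ast.BinOp(left=self.a.to_ast(), op=ast.Add(), right=self.b.to_ast())


@dataclass
class Assign:
    lhs: Var
    rhs: object

    def to_ast(self):
        # The assignment target is a Name in Store context.
        return ast.Assign(targets=[self.lhs.to_ast(ast.Store())],
                          value=self.rhs.to_ast(), type_comment=None)


@dataclass
class Func:
    name: str
    args: list
    stmts: list
    ret: object

    def to_ast(self):
        params = ast.arguments(
            posonlyargs=[], args=[ast.arg(arg=v.name, annotation=None) for v in self.args],
            vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[])
        body = [s.to_ast() for s in self.stmts] + [ast.Return(value=self.ret.to_ast())]
        return ast.FunctionDef(name=self.name, args=params, body=body, decorator_list=[],
                               returns=None, type_comment=None, type_params=[])


def to_python(node) -> str:
    """IR -> Python AST -> source text."""
    module = ast.Module(body=[node.to_ast()], type_ignores=[])
    return ast.unparse(ast.fix_missing_locations(module))


a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e")
f = Func("f", [a, b, c],
         stmts=[Assign(lhs=d, rhs=Add(a, b)), Assign(lhs=e, rhs=Add(d, c))],
         ret=e)
print(to_python(f))
```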
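AST parser. MLC provides a concise set of APIs for implementing parsers on top of Python's ast module, including:
- an inspection API that obtains the source code of a Python class or function together with the variables it captures;
- variable management APIs that help handle scoping correctly;
- AST fragment evaluation APIs;
- error rendering APIs.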
[Example]. With the MLC APIs, a parser for the Python text format defined above by __ir_print__ takes about 100 lines of code.
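The same idea can be shown in stripped-down form using only the standard ast module. The parser reads the text, walks the tree, and keeps a name-to-Var scope, so every use of d or e resolves to the Var object created at its assignment. The class names, ToyParser, and parse are hypothetical, and this is not MLC's parser API. It also omits the inspection and fragment-evaluation pieces listed above and reduces error rendering to messages that carry a line number.

```python
import ast
from dataclasses import dataclass


@dataclass(eq=False)
class Var:
    name: str


@dataclass(eq=False)
class Add:
    a: object
    b: object


@dataclass(eq=False)
class Assign:
    lhs: Var
    rhs: object


@dataclass(eq=False)
class Func:
    name: str
    args: list
    stmts: list
    ret: object


class ToyParser:
    def __init__(self):
        self.scope = {}  # name -> the Var currently bound to that name

    def expr(self, node):
        if isinstance(node, ast.Name):
            if node.id not in self.scope:
                raise NameError(f"line {node.lineno}: undefined variable {node.id!r}")
            return self.scope[node.id]
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
            return Add(self.expr(node.left), self.expr(node.right))
        raise SyntaxError(f"line {node.lineno}: unsupported expression {type(node).__name__}")

    def func(self, node):
        args = [Var(arg.arg) for arg in node.args.args]
        self.scope = {v.name: v for v in args}  # parameters open the function scope
        stmts, ret = [], None
        for stmt in node.body:
            if (isinstance(stmt, ast.Assign) and len(stmt.targets) == 1
                    and isinstance(stmt.targets[0], ast.Name)):
                rhs = self.expr(stmt.value)  # resolve the RHS before binding the LHS
                lhs = Var(stmt.targets[0].id)
                self.scope[lhs.name] = lhs
                stmts.append(Assign(lhs, rhs))
            elif isinstance(stmt, ast.Return) and stmt.value is not None:
                ret = self.expr(stmt.value)
            else:
                raise SyntaxError(f"line {stmt.lineno}: unsupported statement {type(stmt).__name__}")
        return Func(node.name, args, stmts, ret)


def parse(source: str) -> Func:
    node = ast.parse(source).body[0]
    if not isinstance(node, ast.FunctionDef):
        raise SyntaxError("expected a single function definition")
    return ToyParser().func(node)


f = parse("def f(a, b, c):\n    d = a + b\n    e = d + c\n    return e\n")
print([s.lhs.name for s in f.stmts])  # ['d', 'e']
print(f.ret is f.stmts[1].lhs)  # True: `e` resolves to the Var bound by its assignment
```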
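🎯 Testing IRs with MLC Structure-Aware Tooling
By annotating IR definitions with structure, MLC supports structural equality and structural hashing, which detect structural equivalence between IRs:
Define a toy IR with structure: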
import mlc.dataclasses as mlcd
@mlcd.py_class
class Expr(mlcd.PyClass):
    def __add__(self, other):
        return Add(a=self, b=other)

@mlcd.py_class(structure="nobind")
class Add(Expr):
    a: Expr
    b: Expr

@mlcd.py_class(structure="var")
class Var(Expr):
    name: str = mlcd.field(structure=None)  # excludes `name` from defined structure

@mlcd.py_class(structure="bind")
class Let(Expr):
    rhs: Expr
    lhs: Var = mlcd.field(structure="bind")  # `Let.lhs` is the def-site
    body: Expr
Structural equality. The member method eq_s checks whether two IRs represented by MLC structured dataclasses are structurally equal (i.e., alpha-equivalent).
x, y, z = Var("x"), Var("y"), Var("z")
L1 = Let(rhs=x + y, lhs=z, body=z) # let z = x + y; z
L2 = Let(rhs=y + z, lhs=x, body=x) # let x = y + z; x
L3 = Let(rhs=x + x, lhs=z, body=z) # let z = x + x; z
L1.eq_s(L2)
True
L1.eq_s(L3, assert_mode=True)
ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent binding. RHS has been bound to a different node while LHS is not bound
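This kind of check can be approximated by traversing both IRs at once while maintaining a one-to-one mapping between their variables. The sketch below is a simplification and not MLC's algorithm. It ignores Var.name (as structure=None does) and binds any pair of previously unseen variables on first encounter, whether they are free or bound. It also ignores shadowing. It reports the first mismatching path in the same {root}.field style as the error message above.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # identity-based hashing, so Var objects can key the maps
class Var:
    name: str  # not compared, mirroring structure=None


@dataclass(eq=False)
class Add:
    a: object
    b: object


@dataclass(eq=False)
class Let:
    rhs: object
    lhs: Var
    body: object


def eq_s(lhs, rhs):
    """Return (True, None) if alpha-equivalent, else (False, first mismatching path)."""
    l2r, r2l = {}, {}  # the variable bijection built so far

    def visit(x, y, path):
        if type(x) is not type(y):
            return path
        if isinstance(x, Var):
            if x not in l2r and y not in r2l:  # both unseen: bind them to each other
                l2r[x], r2l[y] = y, x
                return None
            if l2r.get(x) is y and r2l.get(y) is x:  # consistent with earlier binding
                return None
            return path  # inconsistent binding
        if isinstance(x, Add):
            return visit(x.a, y.a, path + ".a") or visit(x.b, y.b, path + ".b")
        if isinstance(x, Let):
            return (visit(x.rhs, y.rhs, path + ".rhs")
                    or visit(x.lhs, y.lhs, path + ".lhs")
                    or visit(x.body, y.body, path + ".body"))
        raise TypeError(f"unsupported node {type(x).__name__}")

    mismatch = visit(lhs, rhs, "{root}")
    return mismatch is None, mismatch


x, y, z = Var("x"), Var("y"), Var("z")
L1 = Let(rhs=Add(x, y), lhs=z, body=z)  # let z = x + y; z
L2 = Let(rhs=Add(y, z), lhs=x, body=x)  # let x = y + z; x
L3 = Let(rhs=Add(x, x), lhs=z, body=z)  # let z = x + x; z
print(eq_s(L1, L2))  # (True, None)
print(eq_s(L1, L3))  # (False, '{root}.rhs.b')
```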
Structural hashing. The structure of an MLC dataclass can be hashed with hash_s, which guarantees that alpha-equivalent dataclasses have the same structural hash:
L1_hash, L2_hash, L3_hash = L1.hash_s(), L2.hash_s(), L3.hash_s()
assert L1_hash == L2_hash
assert L1_hash != L3_hash
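One way to obtain this guarantee is to hash a canonical encoding in which each variable is replaced by the position of its first occurrence, so renaming variables cannot change the result. The following is a minimal sketch of that idea, not MLC's hash function. The toy classes and the encoding format are hypothetical, and hashlib.sha256 is used only to get a digest that is stable across processes.

```python
import hashlib
from dataclasses import dataclass


@dataclass(eq=False)
class Var:
    name: str  # excluded from the hash, like structure=None


@dataclass(eq=False)
class Add:
    a: object
    b: object


@dataclass(eq=False)
class Let:
    rhs: object
    lhs: Var
    body: object


def hash_s(node) -> str:
    numbering = {}  # Var -> index of its first occurrence in traversal order

    def encode(n):
        if isinstance(n, Var):
            if n not in numbering:
                numbering[n] = len(numbering)
            return f"v{numbering[n]}"
        if isinstance(n, Add):
            return f"Add({encode(n.a)},{encode(n.b)})"  # fields visited left to right
        if isinstance(n, Let):
            return f"Let({encode(n.rhs)},{encode(n.lhs)},{encode(n.body)})"
        raise TypeError(f"unsupported node {type(n).__name__}")

    return hashlib.sha256(encode(node).encode()).hexdigest()


x, y, z = Var("x"), Var("y"), Var("z")
L1 = Let(rhs=Add(x, y), lhs=z, body=z)  # encodes as Let(Add(v0,v1),v2,v2)
L2 = Let(rhs=Add(y, z), lhs=x, body=x)  # encodes as Let(Add(v0,v1),v2,v2)
L3 = Let(rhs=Add(x, x), lhs=z, body=z)  # encodes as Let(Add(v0,v0),v1,v1)
print(hash_s(L1) == hash_s(L2), hash_s(L1) == hash_s(L3))  # True False
```

Equal encodings imply equal hashes, so alpha-equivalent terms always collide. The reverse direction (different terms never collide) holds only up to the usual hash-collision caveat.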
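⚡ Migrating Gradually to C++ with MLC Plugins
(🚧 Under construction)
MLC supports zero-copy, bidirectional interoperability with C++ plugins out of the box, with no extra dependencies. By migrating classes and methods one at a time, a pure Python prototype can move to hybrid or Python-free development.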
⛽ Development
⚙️ Editable Build
pip install --verbose --editable ".[dev]"
pre-commit install
🎡 Building Wheels
This project uses cibuildwheel to build cross-platform wheels. See .github/workflows/wheels.yml for more details.