月之暗面開源FlashKDA,Kimi Linear推理速度提升1.7到2.2倍
MMetaEra
4 月 22 日(UTC+8),據動察 Beating 監測,月之暗面在 GitHub 開源 FlashKDA,一套專門給英偉達 Hopper 系列顯示卡(H100、H20 等)加速模型推理的工具,採用 MIT 協議。
它服務的物件是 KDA,即月之暗面去年在 Kimi Linear 論文裡提出的新注意力機制。大模型讀長文字時,老式注意力的計算量會隨長度平方級膨脹,線性注意力讓這個代價降到線性增長,而 KDA 是這條路線裡的一種改良版。Kimi Linear 模型的結構是 3 層 KDA 搭 1 層老式注意力輪著用。
KDA 之前已經有一份用 Triton 語言寫的版本,掛在開源庫 flash-linear-attention(簡稱 fla)裡。FlashKDA 改用英偉達的底層 GPU 庫 CUTLASS 重寫了一遍,專門榨 Hopper 顯示卡的效能。官方在 H20 上實測,同一次前向計算,FlashKDA 比 Triton 版快 1.7 到 2.2 倍,輸入長度參差不齊、拼批次跑的場景加速尤其明顯。只是官方只跟自家 Triton 版做了對比,沒跟其他線性注意力方案比。
此次只開源了前向計算,意味著只能「跑模型」(推理),還不能「訓模型」,訓練仍得用原來的 Triton 版。使用門檻要求顯示卡需 Hopper 及之後(SM90 架構起步)、CUDA 12.9 以上、PyTorch 2.4 以上。FlashKDA 同時作為新後端合併進了 fla 上游(PR #852),老使用者切過去只要改一行配置。
[ME News]