Punica:Multi-LoRA 推理优化
Punica 将 Multi-LoRA 推理中按请求执行的 for-loop 重写为分段聚合的矩阵向量计算(SGMV),在一个融合 kernel 中同时处理多个 LoRA Adapter,从而避免小规模 GEMV 的频繁调度,显著提升 GPU 利用率。
Multi-LoRA 问题
在 LoRA 和 Multi-LoRA:模型参数微调 中,我们将 Multi-LoRA 推理形式化为:
然而在实际推理中,这一表达需要细化到 token 级别:对于 batch 中的每个输入 ,其对应的 LoRA Adapter(即 )可能不同。因此,不同 token 实际使用的权重扰动是不同的,这使得原本可以统一表示的矩阵乘法退化为一组“每个样本使用不同权重”的不规则计算。
这意味着我们需要一个 for-loop 来执行,这里假设每个 token 只激活一个 LoRA:
h_result = W_0 @ x
for token_j in x:
lora_id = routing(token_j) # 每个 token 选择一个 LoRA adapter
h_lora = B[lora_id] @ (A[lora_id] @ token_j)
h_result[j] += h_loraFor-loop 会导致大量小规模矩阵向量乘法与 kernel 调度开销,严重影响 GPU 利用率。
Punica
Punica 将这种逐 token / 逐 adapter 的不规则计算重新组织为一个 fused segmented GEMV(SGMV)操作。
从逻辑上看,可以将 batch 中的 token 按其对应的 LoRA adapter 进行分组,即对于每个 adapter ,收集其对应的输入集合 。然而,Punica 并不会对每个 adapter 显式执行一个 for-loop 并分别调用 GEMM。
相反,Punica 通过一个融合的 SGMV kernel,在单次 kernel 调用中并行处理所有 token。对于每个 token ,根据其对应的 adapter id ,动态选择对应的参数 ,并完成如下计算:
- 和统一的 GEMV 的区别:不同 token 对应的参数不同,需要在 runtime 动态推导对应参数
- 在每次 kernel 调用中,并不会对 进行显式的重排或组合,而是根据每个 token 对应的 LoRA adapter id,在计算过程中动态索引并访问对应的参数。
该过程在 GPU 内部以 token 为粒度并行执行,从而避免了逐 adapter 的多次 kernel 调度,实现了对不规则计算的高效融合。