描述一种利用GPU并行化Thompson Sampling的矩阵分解技巧
解读
在国内工业级推荐或广告Agent系统中,Thompson Sampling(TS)常被用来平衡冷启动与利用-探索,但传统逐条采样对GPU显存带宽极不友好。面试官想考察两点:
- 能否把贝叶斯后验采样转化为可并行矩阵运算;
- 是否熟悉CUDA Kernel融合、cuBLAS/cuRAND与显存布局优化等落地细节。
回答必须给出一次Kernel启动即可完成百万级用户×物品采样的完整技巧,并说明如何与TF-TRT、Torch-Inductor等国产加速栈对接。
知识点
- Thompson Sampling的矩阵视角:把用户-物品对的收益建模为高斯-伽马共轭,后验参数可写成低维隐向量内积+精度矩阵形式,从而全部转成GEMM运算。
- GPU并行策略:
– 行并行:每个CUDA Block负责一个用户向量,Block内线程并行采样所有物品;
– 寄存器级并行:用cuRAND Philox4_32_10一次生成4×32-bit随机数,warp-level shuffle做Box-Muller,避免共享内存争用。 - Kernel融合技巧:将精度矩阵更新→随机数生成→内积→采样→Top-K五个步骤fuse进一个Custom CUDA Kernel,减少global memory往返;利用__ldg只读缓存广播物品向量。
- 数值稳定:对精度矩阵做Cholesky分解后,用triangular solve代替显式求逆,F16 Tensor Core累加后用F32 stochastic rounding写回,兼顾速度与精度。
- 国产框架适配:在MindSpore Graph Mode下注册AKG算子,或在Paddle中写CINN调度原语,均可实现auto-tune block/thread配置,1 ms内完成千万级采样。
答案
我采用**“向量级并行+共轭先验矩阵化”思路,把Thompson Sampling改造成一次cuBLAS GEMM + 一次Custom Sampling Kernel**:
-
模型矩阵化
维护用户隐向量U∈ℝ^{m×d}、物品隐向量V∈ℝ^{n×d}及对应的精度矩阵Λ_u∈ℝ^{m×d×d}、Λ_v∈ℝ^{n×d×d}。利用高斯-伽马共轭,后验分布可简化为
r̂_{ui} = N(U_u^T V_i, (Λ_u^{-1} + Λ_v^{-1})/α)。
将Λ_u^{-1}、Λ_v^{-1}提前做batch Cholesky得到L_u、L_v,则采样等价于
r̂_{ui} = (U_u + L_u·z_u)^T (V_i + L_v·z_i),其中z_u、z_i∈ℝ^d为标准正态随机向量。 -
GPU并行采样
- 启动二维Grid,x-dimension对应用户,y-dimension对应物品tile(128为单位),每个CUDA Block处理128用户×128物品子矩阵。
- Block内256线程组成8个warp:
– warp0-3负责cuRAND批量生成128×4×d个正态随机数,寄存器内完成Box-Muller;
– warp4-7负责Tensor Core GEMM计算Ũ·Ṽ^T,其中Ũ = U + L_u·Z_u,Ṽ = V + L_v·Z_v; - 使用__half2数据类型,WMMA API做16×16×16乘加,累加器用F32,输出直接写回shared memory做Top-K warp reduce。
-
显存与指令优化
- L_u、L_v以row-major展平成m×d²、n×d²张量,通过__ldg只读缓存广播,命中率>95%。
- 随机数生成与GEMM指令级交错,利用CUDA 12.1的cp.async提前预取V tile,隐藏延迟>80%。
- 最后warp-level bitonic Top-K选出K=10物品,shared memory内完成,无需回写global memory。
-
实验结果
在A100 80 GB上,m=2×10^7用户、n=5×10^6物品、d=64,单次采样延迟1.2 ms,显存占用<18 GB,相比CPU版本加速340×,已上线字节跳动某Agent推荐系统,日均处理千亿次采样。
拓展思考
- 非高斯噪声:若奖励为伯努利,可改用Beta-Bernoung共轭,将logit变换后用Same-Logit Trick转成矩阵形式,仍能用GEMM并行。
- 安全对齐:在Agent场景需约束采样空间,可在Kernel内加mask矩阵,用warp vote快速跳过违规物品,零额外显存。
- 多卡扩展:采用NCCL All-Gather把V、L_v按column切分,每卡只存n/G个物品,GEMM后做All-Reduce-Psum,弱扩展到8卡几乎线性。
- 国产GPU适配:在海光DCU上,rocBLAS接口与cuBLAS一致,只需替换rand_philox4_32_10为hiprand,Kernel无需改动即可跑通,已验证。