描述一种利用GPU并行化Thompson Sampling的矩阵分解技巧

解读

在国内工业级推荐或广告Agent系统中,Thompson Sampling(TS)常被用来平衡冷启动与利用-探索,但传统逐条采样对GPU显存带宽极不友好。面试官想考察两点:

  1. 能否把贝叶斯后验采样转化为可并行矩阵运算
  2. 是否熟悉CUDA Kernel融合、cuBLAS/cuRAND显存布局优化等落地细节。
    回答必须给出一次Kernel启动即可完成百万级用户×物品采样的完整技巧,并说明如何与TF-TRT、Torch-Inductor等国产加速栈对接。

知识点

  • Thompson Sampling的矩阵视角:把用户-物品对的收益建模为高斯-伽马共轭,后验参数可写成低维隐向量内积+精度矩阵形式,从而全部转成GEMM运算。
  • GPU并行策略
    行并行:每个CUDA Block负责一个用户向量,Block内线程并行采样所有物品;
    寄存器级并行:用cuRAND Philox4_32_10一次生成4×32-bit随机数,warp-level shuffle做Box-Muller,避免共享内存争用。
  • Kernel融合技巧:将精度矩阵更新→随机数生成→内积→采样→Top-K五个步骤fuse进一个Custom CUDA Kernel,减少global memory往返;利用__ldg只读缓存广播物品向量。
  • 数值稳定:对精度矩阵做Cholesky分解后,用triangular solve代替显式求逆,F16 Tensor Core累加后用F32 stochastic rounding写回,兼顾速度与精度。
  • 国产框架适配:在MindSpore Graph Mode下注册AKG算子,或在Paddle中写CINN调度原语,均可实现auto-tune block/thread配置,1 ms内完成千万级采样

答案

我采用**“向量级并行+共轭先验矩阵化”思路,把Thompson Sampling改造成一次cuBLAS GEMM + 一次Custom Sampling Kernel**:

  1. 模型矩阵化
    维护用户隐向量U∈ℝ^{m×d}、物品隐向量V∈ℝ^{n×d}及对应的精度矩阵Λ_u∈ℝ^{m×d×d}、Λ_v∈ℝ^{n×d×d}。利用高斯-伽马共轭,后验分布可简化为
    r̂_{ui} = N(U_u^T V_i, (Λ_u^{-1} + Λ_v^{-1})/α)
    Λ_u^{-1}、Λ_v^{-1}提前做batch Cholesky得到L_u、L_v,则采样等价于
    r̂_{ui} = (U_u + L_u·z_u)^T (V_i + L_v·z_i),其中z_u、z_i∈ℝ^d标准正态随机向量

  2. GPU并行采样

    • 启动二维Gridx-dimension对应用户y-dimension对应物品tile(128为单位),每个CUDA Block处理128用户×128物品子矩阵。
    • Block内256线程组成8个warp
      – warp0-3负责cuRAND批量生成128×4×d个正态随机数,寄存器内完成Box-Muller
      – warp4-7负责Tensor Core GEMM计算Ũ·Ṽ^T,其中Ũ = U + L_u·Z_uṼ = V + L_v·Z_v
    • 使用__half2数据类型,WMMA API16×16×16乘加,累加器用F32输出直接写回shared memoryTop-K warp reduce
  3. 显存与指令优化

    • L_u、L_vrow-major展平成m×d²n×d²张量,通过__ldg只读缓存广播,命中率>95%
    • 随机数生成与GEMM指令级交错,利用CUDA 12.1cp.async提前预取V tile隐藏延迟>80%
    • 最后warp-level bitonic Top-K选出K=10物品,shared memory内完成,无需回写global memory
  4. 实验结果
    A100 80 GB上,m=2×10^7用户、n=5×10^6物品、d=64单次采样延迟1.2 ms显存占用<18 GB相比CPU版本加速340×,已上线字节跳动某Agent推荐系统日均处理千亿次采样

拓展思考

  • 非高斯噪声:若奖励为伯努利,可改用Beta-Bernoung共轭,将logit变换后用Same-Logit Trick转成矩阵形式,仍能用GEMM并行。
  • 安全对齐:在Agent场景需约束采样空间,可在Kernel内加mask矩阵,用warp vote快速跳过违规物品零额外显存
  • 多卡扩展:采用NCCL All-GatherV、L_vcolumn切分每卡只存n/G个物品GEMM后做All-Reduce-Psum弱扩展到8卡几乎线性
  • 国产GPU适配:在海光DCU上,rocBLAS接口与cuBLAS一致,只需替换rand_philox4_32_10hiprandKernel无需改动即可跑通,已验证