构建支持 ROCm 的 PyTorch 镜像并运行 Llama 模型

解读

在国内 GPU 资源受限、国产化加速的背景下,AMD ROCm 生态成为不少企业替代 CUDA 的低成本方案。面试官抛出此题,意在验证候选人能否:

  1. 非 CUDA 的 GPU 驱动栈(ROCm)封装进容器,并保证宿主机与容器内核版本、驱动、设备权限一致;
  2. 多阶段构建把 8 GB+ 的 PyTorch+ROCm 依赖压缩到可交付大小;
  3. 解决HIP 可见性、共享内存、IPC 隔离等运行时坑;
  4. 把 7 B/13 B 的 Llama 模型权重以只读层方式挂载,避免镜像膨胀;
  5. 非 root 用户 + seccomp + 只读根文件系统通过国企安全基线扫描;
  6. 给出一键 CI 脚本,让测试同事在 Swarm/K8s 上直接 docker stack deploy 即可拉起分布式推理服务。

知识点

  • ROCm 内核驱动与容器 ABI 兼容性:宿主机需安装 amdgpu-dkms,容器内仅需 rocm-dev 用户态库,内核版本必须 ≥ 5.4 且与驱动一致
  • 设备直通--device=/dev/kfd --device=/dev/dri --group-add video 三个参数缺一不可,缺失 kfd 会导致 HIP 初始化失败
  • 多阶段构建:先用 rocm/pytorch:rocm5.7_ubuntu22.04_py3.10 做编译阶段,再用 ubuntu:22.04 拷贝 whl 与 .so,镜像体积可从 11 GB 降到 3.2 GB
  • 国内源加速:把 pip 源换成清华、apt 源换成中科大,Dockerfile 里必须加 RUN sed -i 's@http://.*.ubuntu.com@http://mirrors.ustc.edu.cn@' /etc/apt/sources.list,否则构建时长翻倍。
  • 安全加固:创建 uid=1001torch 用户,/home/torch/.cache 挂成 tmpfs,防止写权限逃逸;用 --security-opt no-new-privileges 关闭提权。
  • Llama 权重管理:权重放宿主机的 nfs://models/llama-7b,容器内挂为 /data/model:roDockerfile 绝不 COPY 权重,避免镜像分发违规。
  • CI/CD 集成:在 GitLab-CI 中用 kaniko 构建,缓存层放到华为云 S3 兼容对象存储,每次只推送变更层,构建时间从 25 min 降到 4 min
  • 故障排查:若出现 hipErrorNoBinaryForGpu,99% 是GPU 架构与 PYTORCH_ROCM_ARCH 不匹配,需在构建阶段显式 export PYTORCH_ROCM_ARCH=gfx90a

答案

  1. 宿主机准备(CentOS 8 Stream 示例)
sudo dnf install -y amdgpu-dkms rocm-dev
sudo usermod -aG video,render $USER
reboot

验证:/dev/kfd/dev/dri/card* 存在,rocminfo 能看到 GPU。

  1. 项目结构
llama-rocm/
├── Dockerfile
├── docker-compose.yml
├── entrypoint.sh
├── requirements.txt
└── scripts/
    └── download_model.sh   # 从国内镜像站拉取权重
  1. Dockerfile(多阶段 + 国内源 + 非 root)
# 阶段1:编译+安装
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10 AS builder
RUN sed -i 's@http://.*.ubuntu.com@http://mirrors.ustc.edu.cn@' /etc/apt/sources.list && \
    apt-get update && apt-get install -y --no-install-recommends git ninja-build
COPY requirements.txt /tmp/
RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r /tmp/requirements.txt

# 阶段2:运行时
FROM ubuntu:22.04
RUN sed -i 's@http://.*.ubuntu.com@http://mirrors.ustc.edu.cn@' /etc/apt/sources.list && \
    apt-get update && \
    apt-get install -y --no-install-recommends rocm-libs5.7 hip-runtime-amd5.7 python3.10 python3-pip && \
    apt-get clean && rm -rf /var/lib/apt/lists/*
COPY --from=builder /usr/local/lib/python3.10/dist-packages /usr/local/lib/python3.10/dist-packages
COPY --from=builder /opt/rocm-5.7 /opt/rocm-5.7
ENV PATH=/opt/rocm-5.7/bin:$PATH \
    LD_LIBRARY_PATH=/opt/rocm-5.7/lib:$LD_LIBRARY_PATH \
    HIP_VISIBLE_DEVICES=0 \
    PYTORCH_ROCM_ARCH=gfx90a

RUN groupadd -g 1001 torch && useradd -u 1001 -g 1001 -m -s /bin/bash torch
USER torch
WORKDIR /home/torch
COPY entrypoint.sh /home/torch/
ENTRYPOINT ["bash","/home/torch/entrypoint.sh"]
  1. docker-compose.yml(开发环境一键起)
version: "3.9"
services:
  llama:
    build: .
    image: registry.internal.com/llama-rocm:7b-v1
    devices:
      - /dev/kfd
      - /dev/dri
    group_add:
      - video
    security_opt:
      - no-new-privileges:true
    read_only: true
    tmpfs:
      - /tmp
      - /home/torch/.cache
    volumes:
      - /data/models/llama-7b:/data/model:ro
    environment:
      - MODEL_PATH=/data/model
      - MAX_GPU_MEMORY=24GiB
    ports:
      - "8080:8080"
    deploy:
      resources:
        limits:
          memory: 32G
  1. entrypoint.sh
#!/bin/bash
set -e
echo "HIP devices:"
/opt/rocm-5.7/bin/rocminfo | grep -E 'Name|Marketing'
exec python3 -m torch.distributed.run --nproc_per_node=1 server.py
  1. 构建与运行
docker build -t registry.internal.com/llama-rocm:7b-v1 .
docker push registry.internal.com/llama-rocm:7b-v1
docker compose up -d

验证:curl -X POST http://localhost:8080/generate -d '{"prompt":"中国的AI发展"}' 能返回正常文本,显存占用 19 GB 左右即达标。

拓展思考

  • 国产化替代:若客户使用海光 7xxx(兼容 ROCm),只需把 PYTORCH_ROCM_ARCH 改成 gfx906同一套镜像无需重编,体现 Docker 一次构建随处运行的价值。
  • 多卡并行:在 Swarm 里用 docker service create --replicas 2 --generic-resource "gpu=1"把调度器换成支持 AMD GPU 的 HAMi 插件,可实现 2×13 B 模型并行推理。
  • 安全合规:国企审计要求镜像 ≤ 100 个 CVE、不允许 root 启动,可在 CI 里加 trivy 扫描,把 CRITICAL 漏洞修复后再推送到生产 Harbor 仓库
  • 冷启动优化:Llama 7 B 权重 13 GB,stargz-snapshotter 做延迟拉取,节点重启后 8 s 内即可提供服务,满足在线业务 SLA。
  • 混合精度与内存压缩:在 entrypoint 里注入 export PYTORCH_HIP_ALLOC_CONF=expandable_segments:True可把峰值显存再降 15%,单卡即可跑 13 B 模型,节省硬件预算。