从训练到部署
模型训练完成后,将其部署到生产环境提供服务是整个 AI 流程中的关键一步。部署面临的挑战包括:性能优化、资源约束、延迟要求以及跨平台兼容性。
典型的部署流程:
- 训练模型(PyTorch / TensorFlow)
- 导出为中间格式(ONNX)
- 针对目标硬件优化(TensorRT、CoreML)
- 部署到服务器或边缘设备
- 持续监控与迭代
ONNX
ONNX(Open Neural Network Exchange)是微软和 Facebook 联合推出的开放神经网络交换格式。它的目标是让模型可以在不同的深度学习框架之间自由迁移。
ONNX 的核心优势
- 框架互操作性 — PyTorch 训练的模型可以导出为 ONNX,再导入到 TensorRT 或 CoreML
- 推理优化 — ONNX Runtime 提供了跨平台的推理引擎
- 算子标准化 — 定义了统一的算子集合
import torch
import torch.onnx
# 导出 PyTorch 模型到 ONNX
model = torch.load("model.pth")
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
"model.onnx",
export_params=True,
opset_version=17,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}}
)
TensorRT
TensorRT 是 NVIDIA 推出的深度学习推理优化引擎,专门针对 NVIDIA GPU 进行极致优化。
主要优化技术
- 量化(Quantization) — 将 FP32 精度降为 FP16 或 INT8,大幅减少计算量和显存
- 层融合(Layer Fusion) — 合并相邻的层,减少 kernel 启动开销
- 内核自动调优(Kernel Auto-Tuning) — 针对具体 GPU 架构选择最快的 kernel
- 张量内存复用 — 优化推理过程中的内存分配
import tensorrt as trt
# 构建 TensorRT 引擎
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
with open("model.onnx", "rb") as f:
parser.parse(f.read())
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16) # 启用 FP16
config.set_memory_pool_limit(
trt.MemoryPoolType.WORKSPACE, 1 << 30
) # 1GB
serialized_engine = builder.build_serialized_network(network, config)
with open("model.trt", "wb") as f:
f.write(serialized_engine)
推理优化的一般方法
模型层面
- 剪枝(Pruning) — 移除不重要的连接或通道
- 蒸馏(Knowledge Distillation) — 用大模型指导小模型训练
- 量化(Quantization) — 降低数值精度
- 轻量化架构 — 使用 MobileNet、EfficientNet-Lite 等
系统层面
- 批处理(Batching) — 合并多个请求一起推理
- 异步推理 — 推理与 IO 操作重叠
- 模型分片 — 超大模型分布在多张 GPU 上
边缘部署
边缘部署指的是将 AI 模型运行在资源受限的设备上,如手机、摄像头、IoT 设备。
主流边缘推理框架
- TensorFlow Lite — Google 出品,支持 Android / iOS / 嵌入式 Linux
- CoreML — Apple 生态专用
- ONNX Runtime Mobile — 跨平台移动端推理
- NVIDIA Jetson — 边缘端 GPU 平台,配合 TensorRT 使用
- OpenVINO — Intel CPU / VPU 专用优化
# 使用 ONNX Runtime Mobile 进行边缘推理
import onnxruntime as ort
providers = [
"CoreMLExecutionProvider", # iOS
"CPUExecutionProvider"
]
session = ort.InferenceSession(
"model.onnx", providers=providers
)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
result = session.run([output_name], {input_name: input_data})
模型服务框架
- Triton Inference Server — NVIDIA 出品,支持多模型、多框架、动态批处理
- TorchServe — PyTorch 官方模型服务框架
- MLflow — 涵盖训练、部署、管理的全生命周期平台
选择合适的部署方案需要综合考虑硬件成本、延迟要求、吞吐量和开发维护成本。