关注

开源CV模型落地实践:cv_resnet101_face-detection_cvpr22papermogface模型版本管理方案

开源CV模型落地实践:cv_resnet101_face-detection_cvpr22papermogface模型版本管理方案

你有没有遇到过这样的场景?项目里用的人脸检测模型,今天跑得好好的,明天突然就报错了。一查才发现,同事更新了模型权重文件,或者环境里的某个依赖库版本变了。这种“昨天还能用,今天就不行”的问题,在计算机视觉项目里太常见了。

今天要聊的,就是如何把 cv_resnet101_face-detection_cvpr22papermogface 这个高性能人脸检测模型,从一个“能用”的工具,变成一个“稳定可靠”的生产级解决方案。核心思路很简单:做好版本管理

很多人觉得,模型部署就是把代码跑起来,能检测出人脸就完事了。但真正要在业务里用起来,你会发现问题一个接一个:模型权重怎么保存?不同版本怎么切换?团队协作时怎么保证大家用的都是同一个版本?今天这篇文章,我就带你一步步解决这些问题。

1. 为什么需要模型版本管理?

在深入技术方案之前,我们先搞清楚一个基本问题:为什么模型也需要版本管理?

想象一下,你开发了一个基于 MogFace 的人脸检测服务,已经稳定运行了三个月。突然有一天,产品经理说:“我们想试试新训练的模型,看看检测精度能不能再提升一点。” 你更新了模型文件,结果发现:

  1. 检测速度变慢了 20%,服务器负载飙升
  2. 某些特定场景(比如戴墨镜的人脸)检测率反而下降了
  3. 团队其他成员本地开发环境全部报错,因为他们的代码还是基于旧模型

这时候,如果你有完善的版本管理,只需要简单回滚到上一个版本,问题就解决了。如果没有,你可能要花一整天时间排查问题,业务还得停摆。

1.1 模型版本管理的核心价值

模型版本管理不是“锦上添花”,而是“雪中送炭”。它能帮你解决四个核心问题:

问题一:可复现性 今天训练好的模型,三个月后还能不能复现同样的效果?如果连模型权重都找不到了,谈何复现。

问题二:团队协作 三个人一起开发,A 用了 v1.0 的权重,B 用了 v1.1,C 本地自己微调了一个版本。最后集成测试时,结果五花八门,谁都不知道问题出在哪。

问题三:A/B 测试 想对比新模型和旧模型的效果,如果没有版本管理,你只能手动切换文件,效率低下还容易出错。

问题四:故障回滚 线上服务出问题了,如果是模型导致的,你能在多快时间内回退到稳定版本?

对于 cv_resnet101_face-detection_cvpr22papermogface 这样的生产级模型,这些问题都是实实在在会遇到的。接下来,我就分享一套经过实践检验的版本管理方案。

2. 模型文件版本管理方案

我们先从最基础的开始:模型权重文件怎么管理?

2.1 文件命名规范

混乱是从命名开始的。看看这些常见的“坏例子”:

  • model_final.pth (最终版?还有更最终的吗?)
  • model_best.pth (最好的?根据什么指标?)
  • model_v2.pth (v1 在哪?v3 什么时候来?)
  • model_20240101.pth (只有日期,没有版本号)

我推荐使用这样的命名规范:

mogface_resnet101_v{主版本}.{次版本}.{修订版本}_{日期}_{备注}.pth

举个例子:

  • mogface_resnet101_v1.0.0_20240101_initial.pth (初始版本)
  • mogface_resnet101_v1.1.0_20240215_optimized_for_small_faces.pth (优化了小脸检测)
  • mogface_resnet101_v2.0.0_20240320_retrained_on_custom_dataset.pth (用自定义数据重新训练)

这样命名的好处很明显:

  1. 一看就知道版本:v1.1.0 比 v1.0.0 新,但比 v2.0.0 旧
  2. 包含关键信息:日期告诉你什么时候发布的,备注告诉你改了啥
  3. 便于自动化:程序可以自动解析版本号,实现自动升级/回滚

2.2 版本存储结构

文件命名规范了,接下来是存储结构。不要把所有模型文件都扔在一个文件夹里,那样很快就会变成一锅粥。

我建议的目录结构是这样的:

models/
├── cv_resnet101_face-detection_cvpr22papermogface/
│   ├── v1.0.0/
│   │   ├── model.pth
│   │   ├── config.json
│   │   ├── README.md
│   │   └── performance_report.pdf
│   ├── v1.1.0/
│   │   ├── model.pth
│   │   ├── config.json
│   │   └── ...
│   ├── v2.0.0/
│   │   └── ...
│   ├── latest -> v2.0.0/  # 符号链接,指向最新版本
│   └── stable -> v1.1.0/  # 符号链接,指向稳定版本
├── other_model_1/
└── other_model_2/

这个结构有几个关键点:

版本隔离 每个版本有自己的文件夹,互不干扰。想用 v1.0.0 就进 v1.0.0 文件夹,想用 v1.1.0 就进 v1.1.0 文件夹。

符号链接 latest 指向最新版本,stable 指向经过充分测试的稳定版本。这样在代码里可以这样引用:

# 开发环境用最新版
model_path = "models/cv_resnet101_face-detection_cvpr22papermogface/latest/model.pth"

# 生产环境用稳定版
model_path = "models/cv_resnet101_face-detection_cvpr22papermogface/stable/model.pth"

配套文件 每个版本文件夹里,除了模型权重(model.pth),还应该包含:

  • config.json:模型配置,比如输入尺寸、均值方差等
  • README.md:版本说明,包括训练数据、性能指标、已知问题等
  • performance_report.pdf:详细的性能测试报告

2.3 版本元数据管理

光有文件还不够,我们还需要记录每个版本的“元数据”——也就是关于版本的信息。

我建议为每个版本创建一个 metadata.json 文件:

{
  "version": "1.1.0",
  "release_date": "2024-02-15",
  "description": "优化了小脸检测,在WIDER FACE Hard子集上的AP提升了3.2%",
  "training_data": {
    "dataset": "WIDER FACE + 自定义数据集",
    "samples": 12500,
    "augmentation": "随机旋转、缩放、颜色抖动"
  },
  "performance": {
    "WIDER_FACE_Easy": {
      "AP": 0.956,
      "AR": 0.972
    },
    "WIDER_FACE_Medium": {
      "AP": 0.942,
      "AR": 0.961
    },
    "WIDER_FACE_Hard": {
      "AP": 0.887,
      "AR": 0.912
    }
  },
  "dependencies": {
    "torch": ">=1.9.0,<2.0.0",
    "opencv-python": ">=4.5.0",
    "modelscope": ">=1.0.0"
  },
  "known_issues": [
    "在极端背光条件下,置信度可能偏低",
    "对于小于20x20像素的人脸,检测率下降明显"
  ],
  "author": "张三",
  "checksum": "a1b2c3d4e5f6...",
  "download_url": "http://your-model-server/models/mogface/v1.1.0/model.pth"
}

这个元数据文件有什么用呢?

第一,追溯历史 三个月后,你想知道 v1.1.0 为什么比 v1.0.0 好,看看 descriptionperformance 就知道了。

第二,环境管理 dependencies 字段明确告诉你了这个版本需要哪些依赖库、什么版本。再也不会有“在我机器上能跑”的问题了。

第三,完整性校验 checksum 可以验证模型文件是否完整、是否被篡改。

第四,自动化部署 download_url 让自动化脚本可以直接下载指定版本的模型。

3. 代码层面的版本控制

模型文件管理好了,接下来是代码。代码怎么知道该加载哪个版本的模型呢?

3.1 配置文件管理

不要在代码里硬编码模型路径!这是最常见的错误做法:

# ❌ 错误做法:硬编码路径
model = load_model("/home/user/models/mogface_v1.0.pth")

应该使用配置文件:

# ✅ 正确做法:从配置文件读取
import json
import os

class ModelConfig:
    def __init__(self, config_path="config/model_config.json"):
        with open(config_path, 'r') as f:
            self.config = json.load(f)
        
        # 获取当前环境(开发/测试/生产)
        self.env = os.getenv("APP_ENV", "development")
        
    def get_model_path(self):
        """根据环境获取模型路径"""
        model_config = self.config["models"]["mogface"]
        
        if self.env == "production":
            # 生产环境用稳定版
            version = model_config["stable_version"]
        elif self.env == "testing":
            # 测试环境用指定版本
            version = model_config.get("test_version", model_config["stable_version"])
        else:
            # 开发环境用最新版或指定版本
            version = model_config.get("dev_version", model_config["latest_version"])
        
        # 构建完整路径
        base_path = model_config["base_path"]
        model_path = os.path.join(base_path, f"v{version}", "model.pth")
        
        return model_path
    
    def get_model_config(self, version=None):
        """获取指定版本的模型配置"""
        if version is None:
            version = self.config["models"]["mogface"]["stable_version"]
        
        config_path = os.path.join(
            self.config["models"]["mogface"]["base_path"],
            f"v{version}",
            "config.json"
        )
        
        with open(config_path, 'r') as f:
            return json.load(f)

# 使用示例
config = ModelConfig()
model_path = config.get_model_path()
print(f"加载模型: {model_path}")

model_config = config.get_model_config()
print(f"输入尺寸: {model_config['input_size']}")

对应的配置文件 config/model_config.json

{
  "models": {
    "mogface": {
      "base_path": "models/cv_resnet101_face-detection_cvpr22papermogface",
      "latest_version": "2.0.0",
      "stable_version": "1.1.0",
      "test_version": "2.0.0",
      "dev_version": "2.0.0",
      "auto_update": false,
      "fallback_version": "1.0.0"
    }
  },
  "detection": {
    "confidence_threshold": 0.5,
    "nms_threshold": 0.3,
    "max_detections": 100
  }
}

这个配置方案的好处:

环境隔离 开发、测试、生产环境可以用不同的模型版本,互不干扰。

灵活切换 想测试新版本?改一下 test_version 就行,不用改代码。

自动回退 如果指定版本的模型文件不存在,可以自动回退到 fallback_version

3.2 模型加载封装

有了配置文件,我们还需要一个健壮的模型加载器:

import torch
import os
import hashlib
import logging
from pathlib import Path

class ModelVersionManager:
    """模型版本管理器"""
    
    def __init__(self, model_name="mogface"):
        self.model_name = model_name
        self.logger = logging.getLogger(__name__)
        self.config = ModelConfig()
        
    def load_model(self, version=None, device=None):
        """
        加载指定版本的模型
        
        Args:
            version: 版本号,如 "1.1.0",为None时使用配置的默认版本
            device: 设备,如 "cuda:0" 或 "cpu"
        
        Returns:
            loaded_model: 加载的模型
            metadata: 模型元数据
        """
        # 1. 确定要加载的版本
        if version is None:
            if os.getenv("APP_ENV") == "production":
                version = self.config.get("models")[self.model_name]["stable_version"]
            else:
                version = self.config.get("models")[self.model_name]["latest_version"]
        
        self.logger.info(f"准备加载 {self.model_name} 版本 {version}")
        
        # 2. 构建模型路径
        model_dir = Path(self.config.get("models")[self.model_name]["base_path"])
        version_dir = model_dir / f"v{version}"
        
        if not version_dir.exists():
            self.logger.warning(f"版本 {version} 不存在,尝试回退")
            fallback = self.config.get("models")[self.model_name]["fallback_version"]
            version_dir = model_dir / f"v{fallback}"
            
            if not version_dir.exists():
                raise FileNotFoundError(f"模型版本 {version} 和回退版本 {fallback} 都不存在")
        
        model_path = version_dir / "model.pth"
        config_path = version_dir / "config.json"
        metadata_path = version_dir / "metadata.json"
        
        # 3. 验证文件完整性
        if not self._validate_model_file(model_path, metadata_path):
            raise ValueError(f"模型文件 {model_path} 验证失败")
        
        # 4. 加载配置和元数据
        with open(config_path, 'r') as f:
            model_config = json.load(f)
        
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        
        # 5. 加载模型
        self.logger.info(f"从 {model_path} 加载模型")
        
        # 根据设备选择加载方式
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        
        if device.startswith("cuda"):
            # GPU加载
            checkpoint = torch.load(model_path, map_location=device)
        else:
            # CPU加载
            checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        
        # 6. 构建模型(这里根据实际模型结构调整)
        model = self._build_model(model_config)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        
        self.logger.info(f"模型加载完成,版本: {version}, 设备: {device}")
        
        return model, metadata
    
    def _validate_model_file(self, model_path, metadata_path):
        """验证模型文件完整性"""
        if not model_path.exists():
            self.logger.error(f"模型文件不存在: {model_path}")
            return False
        
        if not metadata_path.exists():
            self.logger.warning(f"元数据文件不存在: {metadata_path}")
            return True  # 元数据文件不是必须的
        
        # 读取元数据中的校验和
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        
        expected_checksum = metadata.get("checksum", "")
        if not expected_checksum:
            self.logger.warning("元数据中没有校验和,跳过验证")
            return True
        
        # 计算实际文件的校验和
        with open(model_path, 'rb') as f:
            file_hash = hashlib.sha256()
            chunk = f.read(8192)
            while chunk:
                file_hash.update(chunk)
                chunk = f.read(8192)
        
        actual_checksum = file_hash.hexdigest()
        
        if actual_checksum != expected_checksum:
            self.logger.error(f"校验和不匹配: 期望 {expected_checksum[:16]}..., 实际 {actual_checksum[:16]}...")
            return False
        
        self.logger.info("模型文件验证通过")
        return True
    
    def _build_model(self, config):
        """根据配置构建模型"""
        # 这里是实际的模型构建代码
        # 以 MogFace + ResNet101 为例
        from modelscope.pipelines import pipeline
        from modelscope.utils.constant import Tasks
        
        # 创建模型管道
        model_pipeline = pipeline(
            task=Tasks.face_detection,
            model='damo/cv_resnet101_face-detection_cvpr22papermogface'
        )
        
        return model_pipeline
    
    def list_available_versions(self):
        """列出所有可用版本"""
        model_dir = Path(self.config.get("models")[self.model_name]["base_path"])
        
        versions = []
        for item in model_dir.iterdir():
            if item.is_dir() and item.name.startswith('v'):
                version_str = item.name[1:]  # 去掉 'v' 前缀
                metadata_path = item / "metadata.json"
                
                if metadata_path.exists():
                    with open(metadata_path, 'r') as f:
                        metadata = json.load(f)
                    versions.append({
                        'version': version_str,
                        'path': str(item),
                        'release_date': metadata.get('release_date', '未知'),
                        'description': metadata.get('description', '')
                    })
                else:
                    versions.append({
                        'version': version_str,
                        'path': str(item),
                        'release_date': '未知',
                        'description': '无元数据'
                    })
        
        # 按版本号排序
        versions.sort(key=lambda x: [int(num) for num in x['version'].split('.')], reverse=True)
        
        return versions
    
    def switch_version(self, new_version):
        """切换当前使用的版本"""
        available_versions = [v['version'] for v in self.list_available_versions()]
        
        if new_version not in available_versions:
            raise ValueError(f"版本 {new_version} 不可用。可用版本: {available_versions}")
        
        # 更新配置文件
        config_path = "config/model_config.json"
        with open(config_path, 'r') as f:
            config = json.load(f)
        
        # 根据当前环境更新对应的版本配置
        env = os.getenv("APP_ENV", "development")
        
        if env == "production":
            config["models"][self.model_name]["stable_version"] = new_version
        elif env == "testing":
            config["models"][self.model_name]["test_version"] = new_version
        else:
            config["models"][self.model_name]["dev_version"] = new_version
        
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)
        
        self.logger.info(f"已切换 {self.model_name} 版本到 {new_version} (环境: {env})")
        
        return True

# 使用示例
if __name__ == "__main__":
    # 初始化版本管理器
    manager = ModelVersionManager("mogface")
    
    # 列出所有可用版本
    versions = manager.list_available_versions()
    print("可用版本:")
    for v in versions:
        print(f"  - {v['version']}: {v['description']}")
    
    # 加载指定版本的模型
    model, metadata = manager.load_model(version="1.1.0")
    print(f"加载的模型版本: {metadata['version']}")
    print(f"训练数据: {metadata['training_data']['dataset']}")
    
    # 切换到新版本
    manager.switch_version("2.0.0")

这个 ModelVersionManager 类提供了完整的功能:

版本加载 自动根据环境加载合适的版本,支持版本回退。

完整性验证 通过校验和确保模型文件没有被损坏或篡改。

版本列表 可以查看所有可用的版本及其信息。

版本切换 动态切换当前使用的版本,无需重启服务。

元数据管理 加载模型的同时,也加载完整的元数据信息。

3.3 集成到 Streamlit 应用

现在,我们把版本管理集成到你的 Streamlit 应用中:

import streamlit as st
import cv2
import numpy as np
from PIL import Image
import json
import sys
import os

# 添加项目根目录到 Python 路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model_manager import ModelVersionManager

class FaceDetectionApp:
    def __init__(self):
        # 初始化模型管理器
        self.manager = ModelVersionManager("mogface")
        
        # 当前加载的模型和版本
        self.current_model = None
        self.current_version = None
        self.current_metadata = None
        
    def load_model(self, version=None):
        """加载模型(带缓存)"""
        @st.cache_resource
        def _load_cached_model(_manager, _version):
            return _manager.load_model(_version)
        
        try:
            model, metadata = _load_cached_model(self.manager, version)
            self.current_model = model
            self.current_version = metadata['version']
            self.current_metadata = metadata
            return True
        except Exception as e:
            st.error(f"加载模型失败: {str(e)}")
            return False
    
    def run(self):
        st.title("👁️ MogFace 智能人脸检测工具")
        
        # 侧边栏:模型版本管理
        with st.sidebar:
            st.header("⚙️ 模型版本管理")
            
            # 显示当前版本信息
            if self.current_metadata:
                st.subheader("当前版本信息")
                st.write(f"**版本**: v{self.current_metadata['version']}")
                st.write(f"**发布日期**: {self.current_metadata['release_date']}")
                st.write(f"**描述**: {self.current_metadata['description']}")
                
                # 性能指标
                if 'performance' in self.current_metadata:
                    st.subheader("性能指标")
                    perf = self.current_metadata['performance']
                    if 'WIDER_FACE_Hard' in perf:
                        st.metric("WIDER FACE Hard AP", f"{perf['WIDER_FACE_Hard']['AP']*100:.1f}%")
            
            # 版本切换
            st.subheader("版本切换")
            
            # 获取可用版本
            versions = self.manager.list_available_versions()
            version_options = [v['version'] for v in versions]
            
            if version_options:
                selected_version = st.selectbox(
                    "选择模型版本",
                    options=version_options,
                    index=0  # 默认选第一个(最新版)
                )
                
                if st.button("切换版本", type="primary"):
                    with st.spinner("切换版本中..."):
                        if self.manager.switch_version(selected_version):
                            st.success(f"已切换到版本 {selected_version}")
                            st.info("请刷新页面以加载新版本模型")
                            st.rerun()
            else:
                st.warning("未找到可用版本")
            
            # 版本比较
            if len(versions) >= 2:
                st.subheader("版本比较")
                col1, col2 = st.columns(2)
                
                with col1:
                    ver1 = st.selectbox("版本 A", version_options, key="ver1")
                with col2:
                    ver2 = st.selectbox("版本 B", version_options, key="ver2", 
                                       index=min(1, len(version_options)-1))
                
                if ver1 != ver2 and st.button("比较版本"):
                    self.compare_versions(ver1, ver2, versions)
            
            # 显存管理
            st.subheader("系统状态")
            if torch.cuda.is_available():
                gpu_memory = torch.cuda.memory_allocated() / 1024**3
                st.write(f"GPU 显存使用: {gpu_memory:.2f} GB")
                
                if st.button("清理显存", type="secondary"):
                    torch.cuda.empty_cache()
                    st.rerun()
        
        # 主界面
        col1, col2 = st.columns(2)
        
        with col1:
            st.header("📤 图片上传与预览")
            uploaded_file = st.file_uploader(
                "选择图片文件",
                type=['jpg', 'jpeg', 'png'],
                help="支持 JPG、PNG、JPEG 格式"
            )
            
            if uploaded_file is not None:
                image = Image.open(uploaded_file)
                st.image(image, caption="原始图片", use_column_width=True)
                
                # 转换为 OpenCV 格式
                img_array = np.array(image)
                if len(img_array.shape) == 3 and img_array.shape[2] == 3:
                    img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
                
                # 检测按钮
                if st.button("🚀 开始检测", type="primary", use_container_width=True):
                    with st.spinner("检测中..."):
                        # 确保模型已加载
                        if self.current_model is None:
                            if not self.load_model():
                                st.error("模型加载失败,无法进行检测")
                                return
                        
                        # 执行检测
                        results = self.current_model(img_array)
                        
                        # 处理结果
                        self.display_results(col2, img_array, results)
        
        with col2:
            st.header("📥 检测结果展示")
            if 'detection_result' in st.session_state:
                st.image(st.session_state.detection_result, 
                        caption="检测结果", 
                        use_column_width=True)
                
                if 'face_count' in st.session_state:
                    st.metric("检测到的人脸数量", st.session_state.face_count)
                
                if 'detection_data' in st.session_state:
                    with st.expander("📊 查看原始数据"):
                        st.json(st.session_state.detection_data)
    
    def display_results(self, column, original_img, results):
        """显示检测结果"""
        # 绘制检测框
        result_img = original_img.copy()
        face_count = 0
        detection_data = []
        
        if 'boxes' in results:
            boxes = results['boxes']
            scores = results.get('scores', [])
            
            for i, box in enumerate(boxes):
                x1, y1, x2, y2 = map(int, box[:4])
                score = scores[i] if i < len(scores) else 0.0
                
                # 绘制边界框
                cv2.rectangle(result_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
                
                # 绘制置信度
                label = f"{score:.2f}"
                cv2.putText(result_img, label, (x1, y1-10),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
                
                face_count += 1
                detection_data.append({
                    'box': [int(x1), int(y1), int(x2), int(y2)],
                    'score': float(score)
                })
        
        # 转换回 RGB 用于显示
        result_img_rgb = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
        
        # 保存到 session state
        st.session_state.detection_result = result_img_rgb
        st.session_state.face_count = face_count
        st.session_state.detection_data = detection_data
        
        # 显示版本信息
        if self.current_metadata:
            column.info(f"使用模型版本: v{self.current_metadata['version']}")
    
    def compare_versions(self, ver1, ver2, all_versions):
        """比较两个版本的差异"""
        # 获取版本信息
        info1 = next((v for v in all_versions if v['version'] == ver1), None)
        info2 = next((v for v in all_versions if v['version'] == ver2), None)
        
        if info1 and info2:
            st.subheader(f"版本比较: v{ver1} vs v{ver2}")
            
            col1, col2 = st.columns(2)
            
            with col1:
                st.write(f"**v{ver1}**")
                st.write(f"发布日期: {info1['release_date']}")
                st.write(info1['description'])
            
            with col2:
                st.write(f"**v{ver2}**")
                st.write(f"发布日期: {info2['release_date']}")
                st.write(info2['description'])
            
            # 这里可以添加更详细的比较,比如性能指标对比
            st.info("提示:在实际业务中,建议对两个版本进行全面的测试对比")

if __name__ == "__main__":
    # 设置页面配置
    st.set_page_config(
        page_title="MogFace 人脸检测",
        page_icon="👁️",
        layout="wide"
    )
    
    # 运行应用
    app = FaceDetectionApp()
    
    # 初始加载模型
    if 'model_loaded' not in st.session_state:
        with st.spinner("正在加载模型..."):
            if app.load_model():
                st.session_state.model_loaded = True
            else:
                st.error("模型加载失败,请检查配置")
                st.stop()
    
    app.run()

这个增强版的 Streamlit 应用增加了以下功能:

版本管理界面 在侧边栏可以查看当前版本信息、切换版本、比较不同版本。

模型信息展示 显示当前使用的模型版本、发布日期、性能指标等。

版本比较功能 可以对比两个不同版本的信息。

智能加载 根据配置自动加载适合当前环境的模型版本。

4. 自动化部署与持续集成

手动管理模型版本还是太麻烦,我们来看看如何自动化。

4.1 模型版本自动化脚本

创建一个自动化脚本,用于管理模型版本:

#!/usr/bin/env python3
"""
模型版本管理自动化脚本
用法:
  python model_manager.py list                    # 列出所有版本
  python model_manager.py download <version>      # 下载指定版本
  python model_manager.py upload <path>           # 上传新版本
  python model_manager.py set-stable <version>    # 设置稳定版本
  python model_manager.py cleanup                 # 清理旧版本
"""

import argparse
import json
import os
import shutil
import hashlib
from datetime import datetime
from pathlib import Path
import requests
from typing import Dict, List, Optional

class ModelVersionAutomation:
    def __init__(self, model_name: str = "mogface"):
        self.model_name = model_name
        self.base_dir = Path(f"models/cv_resnet101_face-detection_cvpr22papermogface")
        self.base_dir.mkdir(parents=True, exist_ok=True)
        
        # 加载版本索引
        self.index_file = self.base_dir / "version_index.json"
        self.version_index = self._load_index()
    
    def _load_index(self) -> Dict:
        """加载版本索引"""
        if self.index_file.exists():
            with open(self.index_file, 'r') as f:
                return json.load(f)
        return {
            "latest_version": None,
            "stable_version": None,
            "versions": {}
        }
    
    def _save_index(self):
        """保存版本索引"""
        with open(self.index_file, 'w') as f:
            json.dump(self.version_index, f, indent=2)
    
    def _calculate_checksum(self, file_path: Path) -> str:
        """计算文件校验和"""
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(4096), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()
    
    def list_versions(self, detailed: bool = False):
        """列出所有版本"""
        versions = self.version_index.get("versions", {})
        
        if not versions:
            print("没有找到任何版本")
            return
        
        print(f"\n{'='*60}")
        print(f"{self.model_name} 版本列表")
        print(f"{'='*60}")
        
        for version, info in sorted(versions.items(), 
                                   key=lambda x: [int(n) for n in x[0].split('.')], 
                                   reverse=True):
            status = []
            if version == self.version_index.get("latest_version"):
                status.append("最新")
            if version == self.version_index.get("stable_version"):
                status.append("稳定")
            
            status_str = f"[{', '.join(status)}]" if status else ""
            
            print(f"\n版本: v{version} {status_str}")
            print(f"  发布日期: {info.get('release_date', '未知')}")
            print(f"  描述: {info.get('description', '无描述')}")
            
            if detailed:
                print(f"  路径: {info.get('path', '未知')}")
                print(f"  校验和: {info.get('checksum', '未知')[:16]}...")
                if 'performance' in info:
                    print(f"  性能: {json.dumps(info['performance'], indent=4)}")
        
        print(f"\n当前稳定版本: v{self.version_index.get('stable_version', '未设置')}")
        print(f"当前最新版本: v{self.version_index.get('latest_version', '未设置')}")
        print(f"{'='*60}")
    
    def download_version(self, version: str, source_url: Optional[str] = None):
        """下载指定版本的模型"""
        version_dir = self.base_dir / f"v{version}"
        version_dir.mkdir(exist_ok=True)
        
        print(f"正在下载版本 v{version}...")
        
        # 如果提供了源URL,从URL下载
        if source_url:
            print(f"从 {source_url} 下载模型文件...")
            response = requests.get(source_url, stream=True)
            response.raise_for_status()
            
            model_path = version_dir / "model.pth"
            with open(model_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            
            print(f"模型文件已保存到: {model_path}")
        
        # 创建配置文件(如果不存在)
        config_path = version_dir / "config.json"
        if not config_path.exists():
            default_config = {
                "input_size": [640, 640],
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225],
                "confidence_threshold": 0.5,
                "nms_threshold": 0.3
            }
            with open(config_path, 'w') as f:
                json.dump(default_config, f, indent=2)
            print(f"已创建默认配置文件: {config_path}")
        
        # 创建元数据文件(如果不存在)
        metadata_path = version_dir / "metadata.json"
        if not metadata_path.exists():
            checksum = self._calculate_checksum(version_dir / "model.pth")
            
            metadata = {
                "version": version,
                "release_date": datetime.now().strftime("%Y-%m-%d"),
                "description": f"版本 {version} 的模型文件",
                "training_data": {
                    "dataset": "未知",
                    "samples": 0
                },
                "performance": {},
                "dependencies": {
                    "torch": ">=1.9.0",
                    "opencv-python": ">=4.5.0"
                },
                "checksum": checksum,
                "download_url": source_url or "",
                "author": "自动化脚本"
            }
            
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)
            print(f"已创建元数据文件: {metadata_path}")
        
        # 更新版本索引
        self.version_index["versions"][version] = {
            "path": str(version_dir),
            "release_date": datetime.now().strftime("%Y-%m-%d"),
            "description": f"版本 {version}",
            "checksum": checksum
        }
        
        # 更新最新版本
        self.version_index["latest_version"] = version
        
        self._save_index()
        print(f"版本 v{version} 下载完成并已添加到索引")
    
    def upload_version(self, model_path: str, version: str, description: str = ""):
        """上传新版本模型"""
        source_path = Path(model_path)
        if not source_path.exists():
            raise FileNotFoundError(f"模型文件不存在: {model_path}")
        
        version_dir = self.base_dir / f"v{version}"
        version_dir.mkdir(exist_ok=True)
        
        print(f"正在上传版本 v{version}...")
        
        # 复制模型文件
        dest_path = version_dir / "model.pth"
        shutil.copy2(source_path, dest_path)
        print(f"模型文件已复制到: {dest_path}")
        
        # 计算校验和
        checksum = self._calculate_checksum(dest_path)
        
        # 创建元数据
        metadata = {
            "version": version,
            "release_date": datetime.now().strftime("%Y-%m-%d"),
            "description": description or f"版本 {version}",
            "training_data": {
                "dataset": "自定义训练",
                "samples": "未知"
            },
            "performance": {
                "备注": "请在实际数据集上测试并更新性能指标"
            },
            "dependencies": {
                "torch": ">=1.9.0",
                "opencv-python": ">=4.5.0"
            },
            "checksum": checksum,
            "author": os.getenv("USER", "未知用户")
        }
        
        metadata_path = version_dir / "metadata.json"
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        
        # 创建默认配置(如果不存在)
        config_path = version_dir / "config.json"
        if not config_path.exists():
            default_config = {
                "input_size": [640, 640],
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225]
            }
            with open(config_path, 'w') as f:
                json.dump(default_config, f, indent=2)
        
        # 更新版本索引
        self.version_index["versions"][version] = {
            "path": str(version_dir),
            "release_date": metadata["release_date"],
            "description": description,
            "checksum": checksum
        }
        
        # 更新最新版本
        self.version_index["latest_version"] = version
        
        self._save_index()
        print(f"版本 v{version} 上传完成")
        print(f"校验和: {checksum[:16]}...")
    
    def set_stable_version(self, version: str):
        """设置稳定版本"""
        if version not in self.version_index.get("versions", {}):
            raise ValueError(f"版本 v{version} 不存在")
        
        self.version_index["stable_version"] = version
        self._save_index()
        
        # 更新符号链接
        stable_link = self.base_dir / "stable"
        if stable_link.exists():
            stable_link.unlink()
        stable_link.symlink_to(f"v{version}")
        
        print(f"已将 v{version} 设置为稳定版本")
    
    def cleanup_old_versions(self, keep_count: int = 5):
        """清理旧版本,只保留指定数量的最新版本"""
        versions = list(self.version_index.get("versions", {}).keys())
        
        if len(versions) <= keep_count:
            print(f"当前只有 {len(versions)} 个版本,无需清理")
            return
        
        # 按版本号排序
        sorted_versions = sorted(versions, 
                               key=lambda x: [int(n) for n in x.split('.')])
        
        # 确定要删除的版本(保留最新的 keep_count 个)
        versions_to_keep = sorted_versions[-keep_count:]
        versions_to_delete = [v for v in sorted_versions if v not in versions_to_keep]
        
        # 不能删除当前稳定版本和最新版本
        stable_version = self.version_index.get("stable_version")
        latest_version = self.version_index.get("latest_version")
        
        versions_to_delete = [v for v in versions_to_delete 
                            if v != stable_version and v != latest_version]
        
        if not versions_to_delete:
            print("没有可清理的旧版本")
            return
        
        print(f"准备清理以下版本: {', '.join(versions_to_delete)}")
        
        confirm = input("确认删除?(y/n): ")
        if confirm.lower() != 'y':
            print("取消清理")
            return
        
        # 删除版本目录和索引
        for version in versions_to_delete:
            version_dir = self.base_dir / f"v{version}"
            if version_dir.exists():
                shutil.rmtree(version_dir)
                print(f"已删除版本目录: {version_dir}")
            
            if version in self.version_index["versions"]:
                del self.version_index["versions"][version]
                print(f"已从索引中移除版本: v{version}")
        
        self._save_index()
        print(f"清理完成,保留了 {len(versions_to_keep)} 个最新版本")

def main():
    parser = argparse.ArgumentParser(description="模型版本管理工具")
    subparsers = parser.add_subparsers(dest="command", help="可用命令")
    
    # list 命令
    list_parser = subparsers.add_parser("list", help="列出所有版本")
    list_parser.add_argument("--detailed", action="store_true", help="显示详细信息")
    
    # download 命令
    download_parser = subparsers.add_parser("download", help="下载指定版本")
    download_parser.add_argument("version", help="版本号,如 1.0.0")
    download_parser.add_argument("--url", help="下载URL(可选)")
    
    # upload 命令
    upload_parser = subparsers.add_parser("upload", help="上传新版本")
    upload_parser.add_argument("model_path", help="模型文件路径")
    upload_parser.add_argument("version", help="版本号,如 1.1.0")
    upload_parser.add_argument("--description", help="版本描述", default="")
    
    # set-stable 命令
    stable_parser = subparsers.add_parser("set-stable", help="设置稳定版本")
    stable_parser.add_argument("version", help="版本号,如 1.0.0")
    
    # cleanup 命令
    cleanup_parser = subparsers.add_parser("cleanup", help="清理旧版本")
    cleanup_parser.add_argument("--keep", type=int, default=5, 
                               help="保留的最新版本数量,默认5")
    
    args = parser.parse_args()
    
    manager = ModelVersionAutomation()
    
    if args.command == "list":
        manager.list_versions(detailed=args.detailed)
    
    elif args.command == "download":
        manager.download_version(args.version, args.url)
    
    elif args.command == "upload":
        manager.upload_version(args.model_path, args.version, args.description)
    
    elif args.command == "set-stable":
        manager.set_stable_version(args.version)
    
    elif args.command == "cleanup":
        manager.cleanup_old_versions(keep_count=args.keep)
    
    else:
        parser.print_help()

if __name__ == "__main__":
    main()

这个自动化脚本提供了完整的命令行界面:

版本管理

  • python model_manager.py list - 列出所有版本
  • python model_manager.py download 1.1.0 - 下载指定版本
  • python model_manager.py upload ./model.pth 1.2.0 - 上传新版本
  • python model_manager.py set-stable 1.1.0 - 设置稳定版本
  • python model_manager.py cleanup --keep 3 - 清理旧版本(只保留3个最新版本)

自动化优势

  1. 一键操作:复杂的版本管理任务变成简单命令
  2. 减少错误:自动化流程比手动操作更可靠
  3. 团队协作:所有人都用同样的工具,保证一致性
  4. 可集成:可以轻松集成到CI/CD流程中

4.2 CI/CD 集成示例

把版本管理集成到持续集成流程中:

# .github/workflows/model-deployment.yml
name: Model Deployment

on:
  push:
    tags:
      - 'v*'  # 当推送版本标签时触发
  workflow_dispatch:  # 手动触发

jobs:
  test-model:
    runs-on: ubuntu-latest
    
    steps:
    - uses: actions/checkout@v3
    
    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: '3.9'
    
    - name: Install dependencies
      run: |
        pip install torch torchvision
        pip install opencv-python
        pip install modelscope
    
    - name: Download test dataset
      run: |
        # 下载测试数据集
        wget http://example.com/test_images.zip
        unzip test_images.zip -d test_data/
    
    - name: Run model tests
      run: |
        python tests/test_model.py --model-version ${{ github.ref_name }}
    
    - name: Upload test results
      uses: actions/upload-artifact@v3
      with:
        name: test-results
        path: test_results/
  
  deploy-model:
    needs: test-model
    runs-on: ubuntu-latest
    if: success()
    
    steps:
    - uses: actions/checkout@v3
    
    - name: Extract version
      id: extract_version
      run: |
        # 从标签提取版本号,如 v1.2.3 -> 1.2.3
        VERSION="${GITHUB_REF#refs/tags/v}"
        echo "version=$VERSION" >> $GITHUB_OUTPUT
    
    - name: Package model
      run: |
        # 创建版本目录
        mkdir -p models/cv_resnet101_face-detection_cvpr22papermogface/v${{ steps.extract_version.outputs.version }}
        
        # 复制模型文件
        cp trained_models/mogface_final.pth \
           models/cv_resnet101_face-detection_cvpr22papermogface/v${{ steps.extract_version.outputs.version }}/model.pth
        
        # 生成元数据
        python scripts/generate_metadata.py \
          --version ${{ steps.extract_version.outputs.version }} \
          --model-path trained_models/mogface_final.pth \
          --output models/cv_resnet101_face-detection_cvpr22papermogface/v${{ steps.extract_version.outputs.version }}/metadata.json
    
    - name: Update version index
      run: |
        python scripts/model_manager.py upload \
          trained_models/mogface_final.pth \
          ${{ steps.extract_version.outputs.version }} \
          --description "Automated deployment from CI/CD"
    
    - name: Deploy to model server
      run: |
        # 上传到模型服务器
        rsync -avz models/ user@model-server:/var/www/models/
        
        # 更新最新版本符号链接
        ssh user@model-server "cd /var/www/models/cv_resnet101_face-detection_cvpr22papermogface && ln -sfn v${{ steps.extract_version.outputs.version }} latest"
    
    - name: Notify deployment
      run: |
        # 发送部署通知
        curl -X POST -H "Content-Type: application/json" \
          -d '{"version":"${{ steps.extract_version.outputs.version }}","status":"deployed"}' \
          ${{ secrets.DEPLOY_WEBHOOK }}

这个CI/CD流程实现了:

自动化测试 每次有新版本时,自动运行测试确保质量。

自动化打包 自动创建版本目录、复制文件、生成元数据。

自动化部署 自动上传到模型服务器,更新符号链接。

自动化通知 部署完成后自动发送通知。

5. 总结

通过这一整套模型版本管理方案,我们把 cv_resnet101_face-detection_cvpr22papermogface 从一个简单的模型文件,变成了一个可管理、可追溯、可协作的生产级资产。

5.1 方案核心价值回顾

对个人开发者

  • 再也不会因为“手滑”覆盖了模型文件而懊恼
  • 可以轻松对比不同版本的效果
  • 实验记录更加规范,复现结果更容易

对团队

  • 所有人用的都是同一个版本的模型
  • 新成员能快速了解模型的历史和性能
  • 代码和模型的版本对应关系清晰

对项目

  • 模型更新变得可控可管理
  • 出现问题能快速回滚
  • 版本迭代有完整记录

5.2 实际落地建议

如果你正在考虑实施这套方案,我的建议是:

从小处开始 不要一开始就追求完美。先从简单的文件命名规范和目录结构开始,等团队适应了,再逐步引入更复杂的功能。

自动化是关键 手动管理版本很快就会变得繁琐。尽早把常用操作脚本化,比如模型上传、版本切换、清理旧版本等。

文档要跟上 每次发布新版本,都要更新元数据文件。记录清楚:这个版本改了什么地方、为什么改、效果怎么样。

与现有流程集成 看看你现有的开发流程,把模型版本管理自然地融入进去。比如在代码评审时,也要评审模型版本的变更。

定期回顾 每个季度回顾一次版本管理情况:有没有什么问题?哪些地方可以改进?团队用起来顺不顺手?

5.3 最后的思考

模型版本管理,本质上是对“不确定性”的管理。我们不知道哪个模型版本效果最好,不知道什么时候需要回退,不知道团队协作会出什么问题。但通过好的版本管理,我们可以把这些不确定性控制住。

cv_resnet101_face-detection_cvpr22papermogface 是一个很好的模型,但再好的模型,如果没有好的管理,也很难在项目中发挥最大价值。希望这套方案能帮你把人脸检测项目做得更稳、更好。

记住,好的工具不仅要“能用”,还要“好用”。版本管理就是让工具从“能用”到“好用”的关键一步。


获取更多AI镜像

想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。

转载自CSDN-专业IT技术社区

原文链接:https://blog.csdn.net/weixin_35294091/article/details/157244448

评论

赞0

评论列表

微信小程序
QQ小程序

关于作者

点赞数:0
关注数:0
粉丝:0
文章:0
关注标签:0
加入于:--