+
56
-

如何用python实现Set-of-Mark (SoM) 技术?

如何用python实现Set-of-Mark (SoM) 技术?


网友回复

+
20
-

SoM (Set-of-Mark) 是一种视觉提示(Visual Prompting)技术,通常用于增强多模态大模型(如 GPT-4V, LLaVA)的细粒度空间理解能力。

它的核心逻辑非常直接:

分割(Segmentation):使用 SAM (Segment Anything Model) 等模型将图像分割成不同区域。

标记(Marking):在这些区域上覆盖半透明掩码,并打上数字标签(ID)。

提示(Prompting):将处理后的图发给 GPT-4V,问它:“标签 [5] 的物体是什么?”

下面是一个基于 Python 的完整示例。

准备工作

你需要安装 segment-anything 库以及 PyTorch。

# 1. 安装 PyTorch (根据你的环境选择合适的 CUDA 版本)
pip install torch torchvision

# 2. 安装 SAM 和 OpenCV
pip install opencv-python matplotlib segment-anything

重要提示:你需要下载 SAM 的权重文件(Model Checkpoint)。请下载 sam_vit_h_4b8939.pth (2.4GB) 并放在代码同级目录下。下载地址: Meta AI SAM Repository

1. 完整 SoM 生成代码

这段代码会自动分割图片,并为每个分割区域生成“彩色掩码 + 数字标签”。

import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

class SoMGenerator:
    def __init__(self, checkpoint_path, model_type="vit_h", device="cuda"):
        """
        初始化 SAM 模型
        """
        print(f"正在加载 SAM 模型 ({device})...")
        self.device = device
        self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.sam.to(device=device)

        # 初始化自动掩码生成器
        # points_per_side: 采样点数,越低速度越快但精度降低
        # pred_iou_thresh: 过滤低质量掩码的阈值
        self.mask_generator = SamAutomaticMaskGenerator(
            model=self.sam,
            points_per_side=32,
            pred_iou_thresh=0.86,
            stability_score_thresh=0.92,
            crop_n_layers=0,
            crop_n_points_downscale_factor=1,
            min_mask_region_area=100,  # 忽略太小的区域
        )

    def generate_som_image(self, image_path, output_path=None):
        """
        核心流程:读取图片 -> 分割 -> 绘制标记
        """
        # 1. 读取图片
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError("无法读取图片,请检查路径")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # 2. 生成掩码
        print("正在生成分割掩码 (这可能需要几秒钟)...")
        masks = self.mask_generator.generate(image)
        print(f"检测到 {len(masks)...

点击查看剩余70%

我知道答案,我要回答