Model

segment-anything is very capable at object segmentation, and the official repository provides an example of exporting to ONNX, but it only exports the segment part. This article also exports the embedding part to ONNX, then converts the ONNX models to MNN models. Thanks to MNN's built-in cv functions and expression API, the segment-anything model can be run end to end, with MNN as the only dependency, to segment the specified target in an image.
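
Recent MNN pip releases ship OpenCV-style and NumPy-style submodules, which is what lets the pipeline below run with MNN as its only dependency. A minimal sketch of the aliases assumed by the inference code later in this article (the image path is just a placeholder):

import MNN
import MNN.cv as cv2      # OpenCV-style image I/O and processing, backed by MNN
import MNN.numpy as np    # NumPy-style array operations on MNN expressions

image = cv2.imread('example.jpg')        # placeholder path; returns an MNN Var, not an ndarray
image = cv2.resize(image, (1024, 1024))
batch = np.expand_dims(image, 0)         # NumPy-style shape manipulation
print(batch.shape)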

MobileSam optimizes segment-anything, greatly improving inference speed and making it a better fit for mobile deployment.

Model Export

import torch
# sam_model_registry is provided by both packages; pick the one matching your checkpoint
from segment_anything import sam_model_registry   # for MobileSam: from mobile_sam import sam_model_registry

class ImageEmbedding(torch.nn.Module):
    def __init__(self, sam):
        super().__init__()
        self.model = sam
    @torch.no_grad()
    def forward(
        self,
        image: torch.Tensor
    ):
        return self.model.image_encoder(image)

# Pick one checkpoint/model_type pair (the second assignment would otherwise silently override the first):
# MobileSam
# checkpoint = "./weights/mobile_sam.pt"
# model_type = "vit_t"

# segment-anything vit_b
checkpoint = "../sam_vit_b_01ec64.pth"
model_type = "vit_b"

sam = sam_model_registry[model_type](checkpoint=checkpoint)
onnx_embed = ImageEmbedding(sam)
torch.onnx.export(
            onnx_embed,
            (torch.randn(1, 3, 1024, 1024, dtype=torch.float)),
            'embed.onnx',
            export_params=True,
            verbose=False,
            opset_version=15,
            do_constant_folding=True,
            input_names=['image'],
            output_names=['image_embeddings']
)
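
The segment (prompt encoder + mask decoder) part can be exported with the official example. For completeness, a minimal sketch along the lines of the official scripts/export_onnx_model.py, reusing the sam instance created above; it produces the input/output names that the MNN inference code below expects ('sam.onnx' is a placeholder file name):

from segment_anything.utils.onnx import SamOnnxModel

onnx_sam = SamOnnxModel(sam, return_single_mask=False)
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 2, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 2), dtype=torch.float),
    "mask_input": torch.randn(1, 1, 4 * embed_size[0], 4 * embed_size[1], dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1080, 1920], dtype=torch.float),
}
torch.onnx.export(
    onnx_sam,
    tuple(dummy_inputs.values()),
    'sam.onnx',
    export_params=True,
    opset_version=15,
    do_constant_folding=True,
    input_names=list(dummy_inputs.keys()),
    output_names=['masks', 'iou_predictions', 'low_res_masks'],
    dynamic_axes={
        'point_coords': {1: 'num_points'},
        'point_labels': {1: 'num_points'},
    },
)

Both ONNX files are then converted to MNN format. A minimal sketch, assuming MNN's converter is available (the MNNConvert binary built from source, or the mnnconvert entry point installed with the MNN pip package); file names are placeholders:

mnnconvert -f ONNX --modelFile embed.onnx --MNNModel embed.mnn --bizCode biz
mnnconvert -f ONNX --modelFile sam.onnx --MNNModel sam.mnn --bizCode biz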

Inference Pipeline

The pipeline is:

  • image preprocessing -> embedding -> segment -> mask

Inference code:

import time

import MNN
import MNN.cv as cv2      # MNN's built-in OpenCV-style API
import MNN.numpy as np    # MNN's built-in NumPy-style API

def inference(embed, sam, img, precision, backend, thread):
    mask_threshold = 0.0
    # 0. load model
    config = {}
    config['precision'] = precision
    config['backend'] = backend
    config['numThread'] = thread
    rt = MNN.nn.create_runtime_manager((config,))
    embed = MNN.nn.load_module_from_file(embed, [], [], runtime_manager=rt)
    sam = MNN.nn.load_module_from_file(sam,
         ['point_coords', 'point_labels', 'image_embeddings', 'has_mask_input', 'mask_input', 'orig_im_size'],
         ['iou_predictions', 'low_res_masks', 'masks'], runtime_manager=rt)
    # 1. preprocess
    image = cv2.imread(img)
    origin_h, origin_w, _ = image.shape
    length = 1024
    if origin_h > origin_w:
        new_w = round(origin_w * float(length) / origin_h)
        new_h = length
    else:
        new_h = round(origin_h * float(length) / origin_w)
        new_w = length
    scale_w = new_w / origin_w
    scale_h = new_h / origin_h
    # resize the long side to 1024 and normalize with SAM's pixel mean/std (MNN.cv.resize accepts mean/norm arguments)
    input_var = cv2.resize(image, (new_w, new_h), 0., 0., cv2.INTER_LINEAR, -1, [123.675, 116.28, 103.53], [1/58.395, 1/57.12, 1/57.375])
    # pad the short side so the encoder input is 1024x1024
    input_var = np.pad(input_var, [[0, length - new_h], [0, length - new_w], [0, 0]], 'constant')
    input_var = np.expand_dims(input_var, 0)
    # 2. embedding forward
    input_var = MNN.expr.convert(input_var, MNN.expr.NC4HW4)
    t1 = time.time()
    output_var = embed.forward(input_var)
    t2 = time.time()
    print('# 1. embedding times: {} ms'.format((t2 - t1) * 1000))
    image_embedding = MNN.expr.convert(output_var, MNN.expr.NCHW)
    # 3. segment forward
    points = [[500, 375]]
    scales = [scale_w, scale_h]
    input_point = np.array(points) * scales
    input_label = np.array([1])
    point_coords = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    point_labels = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
    orig_im_size = np.array([float(origin_h), float(origin_w)], dtype=np.float32)
    mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    has_mask_input = np.zeros(1, dtype=np.float32)
    t1 = time.time()
    output_vars = sam.onForward([point_coords, point_labels, image_embedding, has_mask_input, mask_input, orig_im_size])
    t2 = time.time()
    print('# 2. segment times: {} ms'.format((t2 - t1) * 1000))
    masks = MNN.expr.convert(output_vars[2], MNN.expr.NCHW)
    masks = masks.squeeze([0])[0]
    # 4. postprocess: draw masks and point
    masks = (masks > mask_threshold).reshape([origin_h, origin_w, 1])
    color = np.array([30, 144, 255]).reshape([1, 1, -1])
    image = (image + masks * color).astype(np.uint8)
    for point in points:
        cv2.circle(image, point, 10, (0, 0, 255), 5)
    cv2.imwrite('res.jpg', image)
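
A hypothetical invocation, with placeholder model/image paths; MNN's runtime config accepts precision values such as 'low'/'high'/'normal' and backends such as 'CPU':

if __name__ == '__main__':
    # embed.mnn / sam.mnn are the converted models from the export section,
    # truck.jpg is a placeholder input image
    inference('embed.mnn', 'sam.mnn', 'truck.jpg', 'low', 'CPU', 4)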

github repo