概述

DIS(Dichotomous Image Segmentation)是一种新的图像分割任务,旨在从自然图像中分割出高精度的物体。与传统的图像分割任务相比,DIS更侧重于具有单个或几个目标的图像,因此可以提供更丰富准确的细节。

为了研究DIS任务,研究人员创建了一个名为DIS5K的大规模、可扩展的数据集。DIS5K数据集包含了5,470张高分辨率图像,每张图像都配有高精度的二值分割掩码。这个数据集的建立有助于推动多个应用方向的发展,如图像去背景、艺术设计、模拟视图运动、基于图像的增强现实(AR)应用、基于视频的AR应用、3D视频制作等。

通过研究DIS任务和使用DIS5K数据集,研究人员可以探索新的图像分割方法,并为各种应用领域提供更精确、更可靠的图像分割技术,从而推动分割技术在更广泛的领域中的应用。

官网:https://xuebinqin.github.io/dis/index.html
Github:https://github.com/xuebinqin/DIS

数据集

图像二类分割是将图像分割成两个主要区域:前景和背景。在这种情况下,前景代表图像中的某个类别的物体,而背景则是除了该物体之外的所有内容。
官方公布了算所使用的数据集DIS5K, DIS5K数据集中的每张图像都经过了像素级别的手工标注,标注的真值掩码非常精确,每张图像的标记时间相当长。这种高精度的标注使得数据集中的每个像素都与其相应的类别关联起来,从而为模型提供了可靠的训练数据。这种高精度的标注是实现图像二类分割的关键,因为模型需要能够准确地识别和分割出前景物体。

在DIS5K数据集中,标注对象的类型多样,包括透明和半透明的物体,标注使用单个像素的二值掩码进行。这种精确的标注确保了模型训练的有效性和准确性,并且使得模型能够预测出高精度的物体分割结果。

DIS5K数据集网盘地址:https://pan.baidu.com/s/1umNk2AeBG5aB5kXlHTHdIg
提取码:7qfs

模型训练

模型训练可参考git上的官方的文档

模型推理

模型C++使用onnxruntime进行推理

#include <opencv2/opencv.hpp>
#include <onnxruntime_cxx_api.h>


class DIS
{
public:
	DIS(std::string model_path);
	void inference(cv::Mat& cv_src, cv::Mat& cv_mask);
private:
	std::vector<float> input_image_;
	int inpWidth;
	int inpHeight;
	int outWidth;
	int outHeight;
	const float score_th = 0;

	Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "DIS");
	Ort::Session* ort_session = nullptr;
	Ort::SessionOptions sessionOptions = Ort::SessionOptions();
	std::vector<char*> input_names;
	std::vector<char*> output_names;
	std::vector<std::vector<int64_t>> input_node_dims; // >=1 outputs
	std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
};



DIS::DIS(std::string model_path)
{
	std::wstring widestr = std::wstring(model_path.begin(), model_path.end());
	//OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);
	sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
	ort_session = new Ort::Session(env, widestr.c_str(), sessionOptions);
	size_t numInputNodes = ort_session->GetInputCount();
	size_t numOutputNodes = ort_session->GetOutputCount();
	Ort::AllocatorWithDefaultOptions allocator;
	for (int i = 0; i < numInputNodes; i++)
	{
		input_names.push_back(ort_session->GetInputName(i, allocator));
		Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);
		auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
		auto input_dims = input_tensor_info.GetShape();
		input_node_dims.push_back(input_dims);
	}
	for (int i = 0; i < numOutputNodes; i++)
	{
		output_names.push_back(ort_session->GetOutputName(i, allocator));
		Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);
		auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
		auto output_dims = output_tensor_info.GetShape();
		output_node_dims.push_back(output_dims);
	}
	this->inpHeight = input_node_dims[0][2];
	this->inpWidth = input_node_dims[0][3];
	this->outHeight = output_node_dims[0][2];
	this->outWidth = output_node_dims[0][3];
}


void DIS::inference(cv::Mat& cv_src, cv::Mat& cv_mask)
{
	cv::Mat cv_dst;
	cv::resize(cv_src, cv_dst, cv::Size(this->inpWidth, this->inpHeight));
	this->input_image_.resize(this->inpWidth * this->inpHeight * cv_dst.channels());
	for (int c = 0; c < 3; c++)
	{
		for (int i = 0; i < this->inpHeight; i++)
		{
			for (int j = 0; j < this->inpWidth; j++)
			{
				float pix = cv_dst.ptr<uchar>(i)[j * 3 + 2 - c];
				this->input_image_[c * this->inpHeight * this->inpWidth + i * this->inpWidth + j] = pix / 255.0 - 0.5;
			}
		}
	}
	std::array<int64_t, 4> input_shape_{ 1, 3, this->inpHeight, this->inpWidth };

	auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
	Ort::Value input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info,
		input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());

	std::vector<Ort::Value> ort_outputs = ort_session->Run(Ort::RunOptions{ nullptr }, &input_names[0],
		&input_tensor_, 1, output_names.data(), output_names.size());   // 开始推理
	float* pred = ort_outputs[0].GetTensorMutableData<float>();
	cv::Mat mask(outHeight, outWidth, CV_32FC1, pred);
	double min_value, max_value;
	minMaxLoc(mask, &min_value, &max_value, 0, 0);
	mask = (mask - min_value) / (max_value - min_value);
	cv::resize(mask, cv_mask, cv::Size(cv_src.cols, cv_src.rows));
}

void show_img(std::string name, const cv::Mat& img)
{
	cv::namedWindow(name, 0);
	int max_rows = 500;
	int max_cols = 600;
	if (img.rows >= img.cols && img.rows > max_rows) {
		cv::resizeWindow(name, cv::Size(img.cols * max_rows / img.rows, max_rows));
	}
	else if (img.cols >= img.rows && img.cols > max_cols) {
		cv::resizeWindow(name, cv::Size(max_cols, img.rows * max_cols / img.cols));
	}
	cv::imshow(name, img);
}

cv::Mat replaceBG(const cv::Mat cv_src, cv::Mat& alpha, std::vector<int>& bg_color)
{
	int width = cv_src.cols;
	int height = cv_src.rows;

	cv::Mat cv_matting = cv::Mat::zeros(cv::Size(width, height), CV_8UC3);

	float* alpha_data = (float*)alpha.data;
	for (int i = 0; i < height; i++)
	{
		for (int j = 0; j < width; j++)
		{
			float alpha_ = alpha_data[i * width + j];
			cv_matting.at < cv::Vec3b>(i, j)[0] = cv_src.at < cv::Vec3b>(i, j)[0] * alpha_ + (1 - alpha_) * bg_color[0];
			cv_matting.at < cv::Vec3b>(i, j)[1] = cv_src.at < cv::Vec3b>(i, j)[1] * alpha_ + (1 - alpha_) * bg_color[1];
			cv_matting.at < cv::Vec3b>(i, j)[2] = cv_src.at < cv::Vec3b>(i, j)[2] * alpha_ + (1 - alpha_) * bg_color[2];
		}
	}

	return cv_matting;
}

int main()
{
	DIS dis_net("isnet_general_use_720x1280.onnx");

	std::string path = "images";
	std::vector<std::string> filenames;
	cv::glob(path, filenames, false);

	for (auto file_name : filenames)
	{
		cv::Mat cv_src = cv::imread(file_name);
		//std::vector<cv::Mat> cv_dsts;
		cv::Mat cv_dst, cv_mask;
		dis_net.inference(cv_src, cv_mask);
		std::vector<int> color{255, 0, 0};
		cv_dst=replaceBG(cv_src, cv_mask, color);

		show_img("src", cv_src);
		show_img("mask", cv_mask);
		show_img("dst", cv_dst);

		cv::waitKey(0);
	}
}

python推理代码也依赖onnxruntime

import argparse
import cv2
import numpy as np
import onnxruntime
### onnxruntime load ['isnet_general_use_HxW.onnx', 'isnet_HxW.onnx', 'isnet_Nx3xHxW.onnx']  inference failed
class DIS():
    def __init__(self, modelpath, score_th=None):
        so = onnxruntime.SessionOptions()
        so.log_severity_level = 3
        self.net = onnxruntime.InferenceSession(modelpath, so)
        self.input_height = self.net.get_inputs()[0].shape[2]
        self.input_width = self.net.get_inputs()[0].shape[3]
        self.input_name = self.net.get_inputs()[0].name
        self.output_name = self.net.get_outputs()[0].name
        self.score_th = score_th

    def detect(self, srcimg):
        img = cv2.resize(srcimg, dsize=(self.input_width, self.input_height))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0 - 0.5
        blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)
        outs = self.net.run([self.output_name], {self.input_name: blob})
        
        mask = np.array(outs[0]).squeeze()
        min_value = np.min(mask)
        max_value = np.max(mask)
        mask = (mask - min_value) / (max_value - min_value)
        if self.score_th is not None:
            mask = np.where(mask < self.score_th, 0, 1)
        mask *= 255
        mask = mask.astype('uint8')

        mask = cv2.resize(mask, dsize=(srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_LINEAR)
        return mask

def generate_overlay_image(srcimg, mask):
    overlay_image = np.zeros(srcimg.shape, dtype=np.uint8)
    overlay_image[:] = (255, 255, 255)
    mask = np.stack((mask,) * 3, axis=-1).astype('uint8') 
    mask_image = np.where(mask, srcimg, overlay_image)
    return mask, mask_image

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--imgpath", type=str, default='images/cam_image47.jpg')
    parser.add_argument("--modelpath", type=str, default='weights/isnet_general_use_480x640.onnx')
    args = parser.parse_args()
    
    mynet = DIS(args.modelpath)
    srcimg = cv2.imread(args.imgpath)
    mask = mynet.detect(srcimg)
    mask, overlay_image = generate_overlay_image(srcimg, mask)

    winName = 'Deep learning object detection in onnxruntime'
    cv2.namedWindow(winName, cv2.WINDOW_NORMAL)
    cv2.imshow(winName, np.hstack((srcimg, mask)))
    cv2.waitKey(0)
    cv2.destroyAllWindows()

推理结果
图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)-LMLPHP
图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)-LMLPHP
图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)-LMLPHP
图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)-LMLPHP
图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)-LMLPHP
图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)-LMLPHP
图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)-LMLPHP
图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)-LMLPHP
资源和模型下载地址:https://download.csdn.net/download/matt45m/89024664

03-26 16:40