見出し画像

segment-anythingを使って画像を単純化してみる

segment-anythingを使って、画像をセグメント事に平均色で塗りつぶすスクリプトを作ってみたので公開します。
flatなり、塗りつぶし指示ControlNetなり作れそうな予感(確証はない)。

環境構築

git clone https://github.com/facebookresearch/segment-anything.git

cd segment-anything
python -m venv venv
.\venv\Scripts\activate

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
pip install opencv-python
pip install numpy

以下のscriptを実行。

import numpy as np
import torch
import cv2
import os
import sys
import requests
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

def download_file(url, save_path):
    """指定されたURLからファイルをダウンロードして、指定されたパスに保存する"""
    response = requests.get(url, stream=True)
    with open(save_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:
                f.write(chunk)
    return save_path

def reduce_colors(img, n_colors):
    if img.shape[2] == 4:
        Z = img[:, :, :3].reshape((-1, 3))
    else:
        Z = img.reshape((-1, 3))

    Z = np.float32(Z)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    ret, label, center = cv2.kmeans(Z, n_colors, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
    center = np.uint8(center)
    res = center[label.flatten()]
    img_reduced = res.reshape((img.shape[0], img.shape[1], 3))

    return img_reduced

def mask_color_avg(anns, original_image):
    if len(anns) == 0:
        return None

    sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)
    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4), dtype=np.float32)
    img[:, :, 3] = 0
    global_avg_color = original_image.mean(axis=(0, 1))

    for ann in sorted_anns:
        m = ann['segmentation']
        masked_image = original_image * np.expand_dims(m, axis=-1)
        avg_color = masked_image[m].mean(axis=0)
        color_mask = np.concatenate([avg_color / 255, [1]])
        img[m] = color_mask

    img[img[:, :, 3] == 0] = np.concatenate([global_avg_color / 255, [1]])
    img = img.astype(np.float32)
    return img

def main(image, sam):
    #処理を軽くするためimageリサイズ
    w,h = image.shape[1], image.shape[0]
    image = cv2.resize(image, (512, 512))
    #imageをブラーにかける
    image = cv2.GaussianBlur(image, (7, 7), 0)


    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=32,
        pred_iou_thresh=0.5,
        stability_score_thresh=0.5,
        crop_n_layers=2,
        crop_n_points_downscale_factor=4,
        min_mask_region_area=100,
    )
    masks = mask_generator.generate(image)
    img = mask_color_avg(masks, image)
    img = cv2.resize(img, (w, h))
    img_to_save = (img * 255).astype(np.uint8)

    return img_to_save

if __name__ == "__main__":
    input_dir = "E:/desktop/sfw/webp"
    output_dir = "E:/desktop/flat_collor"
    checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    device = "cuda"

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if not os.path.exists(sam_checkpoint):
        print(f"Checkpoint {sam_checkpoint} not found. Downloading...")
        download_file(checkpoint_url, sam_checkpoint)
        print("Download completed.")

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)       

    for filename in os.listdir(input_dir):
        save_path = os.path.join(output_dir, filename)
        if filename.endswith(('.png', '.webp')) and not os.path.exists(save_path):
            image_path = os.path.join(input_dir, filename)
            image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
            processed_image = main(image, sam)

            if processed_image is not None:
                cv2.imwrite(save_path, processed_image)
                print(f"Processed and saved {filename} to {output_dir}")
            else:
                print(f"Processing of {filename} failed.")
        else:
            print(f"{filename} already exists, skipping...")

points_per_side (type: int, デフォルト: None):
このオプションで指定された整数値は、画像上に生成されるグリッドのポイント数を示します。各辺に対してのポイント数を設定します。デフォルト値では指定されておらず、必要に応じて設定する必要があります。

points_per_batch (type: int, デフォルト: None):
1つのバッチで同時に処理される入力ポイント(またはピクセル)の数を整数値で指定します。デフォルト値では指定されておらず、必要に応じて設定する必要があります。

pred_iou_thresh (type: float, デフォルト: None):
モデルからの予測スコアがこの浮動小数点数の閾値未満のマスクは除外されます。デフォルト値では指定されておらず、必要に応じて設定する必要があります。

stability_score_thresh (type: float, デフォルト: None):
安定性スコアがこの浮動小数点数の閾値未満のマスクは除外されます。デフォルト値では指定されておらず、必要に応じて設定する必要があります。

stability_score_offset (type: float, デフォルト: None):
安定性スコアを測定する際に使用される浮動小数点数のオフセット値です。デフォルト値では指定されておらず、必要に応じて設定する必要があります。

__box_nms_thresh (type: float, デフォルト: None):
重複するマスクを除外するための重なりの浮動小数点数の閾値を指定します。デフォルト値では指定されておらず、必要に応じて設定する必要があります。

crop_n_layers (type: int, デフォルト: None):
画像の小さなクロップ上でマスク生成を実行する場合に使用される整数値です。デフォルト値では指定されておらず、必要に応じて設定する必要があります。

crop_nms_thresh (type: float, デフォルト: None):
異なるクロップ間での重複するマスクを除外するための重なりの浮動小数点数の閾値を指定します。デフォルト値では指定されておらず、必要に応じて設定する必要があります。

crop_overlap_ratio (type: int, デフォルト: None):
この整数値が大きいほど、画像クロップ間のオーバーラップが増加します。デフォルト値では指定されておらず、必要に応じて設定する必要があります。

crop_n_points_downscale_factor (type: int, デフォルト: None):
各クロップの層ごとのポイント数を整数値のファクターで減少させます。デフォルト値では指定されておらず、必要に応じて設定する必要があります。

min_mask_region_area (type: int, デフォルト: None):
ピクセル単位の整数値で、この値未満の面積を持つ非連結のマスク領域または穴は、事後処理によって削除されます。デフォルト値では指定されておらず、必要に応じて設定する必要があります。

ベースのデータセットは前回の記事で紹介したpoloclub/diffusiondのlarge_random_50kです。

こんな感じのことができます。

この記事が気に入ったらサポートをしてみませんか?