見出し画像

Pythonライブラリ (超解像|画質向上):DiffBIR


1.緒言

 低画質の画像を高画質に変える技術である”超解像”として「DiffBIR」を紹介します。結論として、GPUでの実装まではできなかったため、CPUで時間かけても良い人向けとなります。

 ご参考までに同様の技術としてReal-ESRGANも紹介していますのでご確認ください。

1-1.DiffBIRの概要

 下図の通り、ぼやけた画像を鮮明にしたりできます。他の技術との優位性比較をしており、見たところ高性能のように見えます。

(a) Visual comparison of blind image super-resolution (BSR) methods on real-world low-quality images.
(b) Visual comparison of blind face restoration (BFR) methods on real-world low-quality face images.

1-2.DiffBIRのアーキテクチャ

 DiffBIRのアーキテクチャは下記の通りです。詳細は理解できませんが、2段のパイプラインで処理しており、Stable DiffusionやVAEエンコーダを使用しており、GANとは異なっていそうです。

https://0x3f3f3f3fun.github.io/projects/diffbir/
https://arxiv.org/pdf/2308.15070.pdf

2.環境構築:Windows編

 Windowsで環境構築する場合を記載します。

2-1.Windowsとの相性確認

 大前提として公式はLinux OS以外の使用を推奨していません

 Windowsの場合は「tritonのインストール時に躓くため、GPUでなくCPUで動かせる。CUDAを使うならissue#24で確認して」とあります。

https://github.com/XPixelGroup/DiffBIR/blob/main/assets/docs/installation_xOS.md

 よって環境構築の方針は下記2点で対応しました。

  1. CPUのみで使用できるようにする

  2. 頑張ってGPUも使えるようにする(tritonをインストールする

2-2.仮想環境の作成

 仮想環境をCondaで作成します。ここでのポイントは「公式はPython3.9を推奨しているが、GPUを使いたいならPython3.10を使用」です。
 環境が出来たら仮想環境を有効化します。

[Terminal:CPUだけでよい人]
conda create -n diffbir python=3.9
conda activate diffbir
[Terminal:GPU使いたい人]
conda create -n diffbir python=3.10
conda activate diffbir

2-3.ライブラリインストール

 ライブラリをインストールします。前述の通りWindowsではそのままだとtritonを入れられないためエラーが発生しますので少し対応します。

 2-3-1.CPUのみ使用用

 GPUを使わない人は”requirements.txt”を開いてtritonをコメントアウトしてから下記コマンド実行します。

[Terminal]
pip install -r requirements.txt

【参考:コメントアウトしない場合】
 tritonのコメントアウトをしない場合は、公式が説明している通りエラーが発生しました。

 2-3-2.GPUで実行用

 下記issueにHuggingfaceが提供するカスタムビルドされたTritonのバイナリが紹介されています。このバイナリを使用してTritonをインストールします。

https://github.com/openai/triton/issues/1057

 注意点としてPython3.9だと下記のようなエラーが発生します。必ず仮想環境構築時はPython3.10系を使用してください。

 実装方法としてはtritonをコメントアウトした後にpip installを実行して、次に別途tritonをインストールしました。

[Terminal]
pip install -r requirements.txt
pip install https://huggingface.co/r4ziel/xformers_pre_built/resolve/main/triton-2.0.0-cp310-cp310-win_amd64.whl

 私の環境では問題なくインストールできました。

2-4.学習済みモデルの重みを取得

 weightsフォルダを作成して、公式の下記からダウンロードしたファイルをweightsフォルダ内に保存します。

https://github.com/XPixelGroup/DiffBIR

3.環境構築:Google Colab編

 推奨はLinux OSのみのためGoogle Colabを使用していきます。

 現状では環境構築が出来なかったため、うまくいったら記事を修正します。参考までにPythonのVersionは3.10系です。

【パターン1:公式通り】
 Google Colab内で仮想環境を作るのは難しいため、仮想環境の作成を除いて公式通りに進めていきます。
 まずはレポジトリをクローンしてライブラリをインストールします。

[Terminal]
!git clone https://github.com/XPixelGroup/DiffBIR.git
%cd DiffBIR

!pip install -r requirements.txt

 次にweightsフォルダを作成して、その中に必要な学習済みモデルをインストールします。

[Terminal]
!mkdir weights
!wget -O weights/general_full_v1.ckpt https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt
!wget -O weights/general_swinir_v1.ckpt https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt

 最後にコードを実行します。するとPytorchに関するエラーが発生しました。torchdataと記載がありますがおそらくpytorch lightning関連の互換性のエラーだと予想されます。
 修復がムリゲーだったためあきらめました。

[Terminal]
!python inference.py --input inputs/demo/general --output results/demo/general --config configs/model/cldm.yaml --ckpt weights/general_full_v1.ckpt --device cuda
[OUT]
Traceback (most recent call last):
  File "/content/DiffBIR/DiffBIR/DiffBIR/inference.py", line 9, in <module>
    import pytorch_lightning as pl
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/__init__.py", line 20, in <module>
    from pytorch_lightning import metrics  # noqa: E402
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/metrics/__init__.py", line 15, in <module>
    from pytorch_lightning.metrics.classification import (  # noqa: F401
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/metrics/classification/__init__.py", line 14, in <module>
    from pytorch_lightning.metrics.classification.accuracy import Accuracy  # noqa: F401
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/metrics/classification/accuracy.py", line 18, in <module>
    from pytorch_lightning.metrics.utils import deprecated_metrics, void
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/metrics/utils.py", line 29, in <module>
    from pytorch_lightning.utilities import rank_zero_deprecation
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/__init__.py", line 18, in <module>
    from pytorch_lightning.utilities.apply_func import move_data_to_device  # noqa: F401
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/apply_func.py", line 34, in <module>
    from torchtext.legacy.data import Batch
  File "/usr/local/lib/python3.10/dist-packages/torchtext/__init__.py", line 12, in <module>
    from . import data, datasets, prototype, functional, models, nn, transforms, utils, vocab, experimental
  File "/usr/local/lib/python3.10/dist-packages/torchtext/datasets/__init__.py", line 3, in <module>
    from .ag_news import AG_NEWS
  File "/usr/local/lib/python3.10/dist-packages/torchtext/datasets/ag_news.py", line 12, in <module>
    from torchdata.datapipes.iter import FileOpener, IterableWrapper
  File "/usr/local/lib/python3.10/dist-packages/torchdata/__init__.py", line 9, in <module>
    from . import datapipes
  File "/usr/local/lib/python3.10/dist-packages/torchdata/datapipes/__init__.py", line 9, in <module>
    from . import iter, map, utils
  File "/usr/local/lib/python3.10/dist-packages/torchdata/datapipes/iter/__init__.py", line 124, in <module>
    from torchdata.datapipes.iter.util.sharding import (
  File "/usr/local/lib/python3.10/dist-packages/torchdata/datapipes/iter/util/sharding.py", line 9, in <module>
    from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
ModuleNotFoundError: No module named 'torch.utils.data.datapipes.iter.sharding'

【パターン2:公式のColab】
 公式GitHubから”Open in Colab”を開いてコマンドを実行しました。

https://github.com/XPixelGroup/DiffBIR
[DiffBIR_colab.ipynb]
%cd /content
!git clone -b dev https://github.com/camenduru/DiffBIR
%cd /content/DiffBIR

!pip install -q einops pytorch_lightning gradio omegaconf transformers lpips # opencv-python
!pip install -q https://download.pytorch.org/whl/cu118/xformers-0.0.22.post4%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl
!pip install -q git+https://github.com/mlfoundations/open_clip@v2.20.0

!apt -y install -qq aria2
# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt -d /content/models -o face_full_v1.ckpt
# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt -d /content/models -o face_swinir_v1.ckpt
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt -d /content/models -o general_full_v1.ckpt
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt -d /content/models -o general_swinir_v1.ckpt

# !python gradio_diffbir.py --ckpt /content/models/face_full_v1.ckpt --config /content/DiffBIR/configs/model/cldm.yaml --reload_swinir --swinir_ckpt /content/models/face_swinir_v1.ckpt
!python gradio_diffbir.py --ckpt /content/models/general_full_v1.ckpt --config /content/DiffBIR/configs/model/cldm.yaml --reload_swinir --swinir_ckpt /content/models/general_swinir_v1.ckpt

 inputsに画像を手動で入れて、モデルの場所を移動させても指定したフォルダ(--output results)に結果が出力されません。エラー文も出ないため原因不明です。

[Terminal]
import os 
if not os.path.exists('inputs'):
    os.mkdir('inputs')

!mkdir weights
!wget -O weights/general_full_v1.ckpt https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt
!wget -O weights/general_swinir_v1.ckpt https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt

!python inference.py --input inputs --config configs/model/cldm.yaml --ckpt weights/general_full_v1.ckpt --reload_swinir --swinir_ckpt weights/general_swinir_v1.ckpt --steps 50 --sr_scale 4 --output results 

4.実行編:Windows

 それでは実際に実行して超解像を試してみます。

4-1.CPUのみ:--device cpu

 コマンドの"--device"オプションをCPUに指定します。その他のオプションは下記の通りです(inference.pyのdef parse_args()参照)。

  1. --input inputs/demo/general: 入力画像が保存ディレクトリを指定

    • この場合”inputs/demo/general”ディレクトリ内の画像を処理

  2. --config configs/model/cldm.yaml:モデルの設定ファイルのパスを指定

  3. --ckpt weights/general_full_v1.ckpt:モデルの重みのファイルパスを指定

    • general_full_v1.ckpt は公式GitHubからダウンロードしてweightsに保存

  4. --reload_swinir: SwinIRモデルの重みをリロード

  5. --swinir_ckpt weights/general_swinir_v1.ckpt: SwinIRモデル重みのパス

    • general_swinir_v1.ckpt は公式からダウンロードしてweightsに保存。

  6. --steps 50: 処理の際に実行するステップ数を指定

  7. --sr_scale 4:高解像度化のスケールを指定(画像のサイズを4倍に拡大)

  8. --color_fix_type wavelet:色補正のタイプを指定

  9. --output results/demo/general:出力結果を保存するディレクトリを指定

  10. --device cuda:処理に使用するデバイスを指定

  11. --tiled, --tile_size 512, --tile_stride 256:画像をタイルに分割して処理?

    • 実施したけどエラーが発生(※原因不明

[Terminal]
python inference.py \
--input inputs/demo/general \
--config configs/model/cldm.yaml \
--ckpt weights/general_full_v1.ckpt \
--reload_swinir --swinir_ckpt weights/general_swinir_v1.ckpt \
--steps 50 \
--sr_scale 4 \
--color_fix_type wavelet \
--output results/demo/general \
--device cpu
[Terminal※1列Ver.]
python inference.py --input inputs/demo/general --config configs/model/cldm.yaml --ckpt weights/general_full_v1.ckpt --reload_swinir --swinir_ckpt weights/general_swinir_v1.ckpt --steps 50 --sr_scale 4 --color_fix_type wavelet --output results/demo/general --device cpu

 問題なく実行されると出力先が”save to <保存ディレクトリ>”で表示(--outputオプションで指定した場所と同じ)されており、画像名は接尾語として"_0"がついております。

 CPU(AMD Ryzen 7 3700X 8-Core Processor 3.59 GHz)だけだと、1枚の画像で1hくらいかかりました。

【結果の確認】
 結果は下記の通りです。

4-2.GPUで処理

 GPUで処理したい場合は"--device cuda"とします。デフォルトはcudaのため、おそらく記載なしでもいけると思います。

[Terminal]
python inference.py --input inputs/demo/general --config configs/model/cldm.yaml --ckpt weights/general_full_v1.ckpt --reload_swinir --swinir_ckpt weights/general_swinir_v1.ckpt --steps 50 --sr_scale 4 --color_fix_type wavelet --output results/demo/general --device cuda
[OUT]
torch.cuda.OutOfMemoryError: CUDA out of memory. 
Tried to allocate 468.00 MiB (GPU 0; 8.00 GiB total capacity; 6.94 GiB already allocated;
0 bytes free; 7.01 GiB reserved in total by PyTorch)
If reserved memory is >> allocated memory try setting max_split_size_mb to avoid 
fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

 私のPCはNVIDIA GeForce RTX 2060(8GB)を積んでますが、メモリが全然足りず動きませんでした。


参考資料

参考資料1:実装

参考資料2:技術

参考資料3:その他

あとがき

 そろそろちゃんと実装したい・・・・
 あと、メモリの大きいGPUが無いとなんもできないけど金がねえ・・・



この記事が気に入ったらサポートをしてみませんか?