セミリアルタイム、高速イメージ・アップスケールAPI
real-SERGANを用いたアップスケールAPIサーバとクリアント側の実装例です。リアルタイムまでは届きませんでした。
環境
CUDA 12.2
PyTORCH 2.2.1
Python 3.9
GPU:RTX4090
CPU:Core™ i5-13600K
変換速度
TCP/IP版 16.5fps
FastAPI版 15fps
TCP/IP版
送受信にプロトコルによるオーバーヘッドが少ないと予想される手法です。
データをシリアライズしてTCP/IPパケットにそのまま載せています。受信側はパケット受信後にデシリアライズで元データに復元し、呼び出し側へ結果としてリターンしています。クライアント側は専用の通信関数が必要なのでサーバとクライアントが対になります。
FastAPI版
ASOGフォームで通信を行います。すなわちPOST/GETでデータを送受信できます。通信処理のオーバーヘッドが増えるので、若干パフォーマンスが低下しています。
コード
サーバ部全体のコードはTCP/IP版、FastAPI版共に最後に記載します。
TCP/IP版の実装
アーギュメント処理とイニシャライズ
モデル選択ができるよになっているのでこの部分が長いですが、やっていることは単純です。モデルパスは面倒なので削除しました。modelsディレクトリを作成してダウンロードしたモデルを入れておいてください。
def main():
global upsampler
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
parser.add_argument('-n','--model_name', type=str, default='RealESRGAN_x4plus', help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus | realesr-animevideov3 | realesr-general-x4v3'))
parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
parser.add_argument('-dn','--denoise_strength',type=float, default=0.5, help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. Only used for the realesr-general-x4v3 model'))
parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
parser.add_argument( '-t', '--test', type=bool, default=False, help='excecute test PG if True')
parser.add_argument("--host", type=str, default="0.0.0.0", help="サービスを提供するip アドレスを指定。")
parser.add_argument("--port", type=int, default=50008, help="サービスを提供するポートを指定。")
args = parser.parse_args()
# determine models according to model names
args.model_name = args.model_name.split('.')[0]
if args.model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
elif args.model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
netscale = 4
elif args.model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
netscale = 2
elif args.model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
netscale = 4
elif args.model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
netscale = 4
#+++++++++++++++++++ init +++++++++++++++++++
model_path = "./weights/" + args.model_name +".pth"
print(model_path )
print(netscale)
# use dni to control the denoise strength
dni_weight = None
if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
model_path = [model_path, wdn_model_path]
dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
# restorer
upsampler = RealESRGANer(
scale=netscale,
model_path=model_path,
dni_weight=dni_weight,
model=model,
tile=0,
tile_pad=10,
pre_pad=0,
half=True,
gpu_id=0)
通信モジュールとサーバ
単純に受信待ちで監視し、クライアントからリクエストが来れば、
up_scale(img , scale)関数を呼び、結果をクリアントへ返送するだけの簡単な処理です。
# ++++++++++++++ TCP/IP server ++++++++
if args.test==False:
host=args.host
port=args.port
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # ソケット定義(IPv4,TCPによるソケット)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((host,port))
s.listen(10) # ソケット接続待受(キューの最大数を指定)
while True:
try:
try: # ソケット接続受信待ち
print(host,port,'クライアントからの接続待ち...')
clientsock, client_address = s.accept()
print("Conected client= ", client_address," Conection date_time= ",datetime.now())
except KeyboardInterrupt: # 接続待ちの間に強制終了が入った時の例外処理
clientsock.shutdown(1)# データ送信完了後、送信路を閉じる
break
else: # 接続待ちの間に強制終了なく、クライアントからの接続が来た場合
all_data=b'' # 受信データ保存用変数の初期化
while True: # ソケット接続開始後の処理
data = clientsock.recv(4096*256) # データ受信。受信バッファサイズ1024バイト
if not data: # 全データ受信完了(受信路切断)時に、ループ離脱
break
all_data += data # 受信データを追加し繋げていく
get_data=(pickle.loads(all_data)) #受信データ解析 元の形式にpickle.loadsで復元 get_data[0]=OpenCV-image , get_data[1]=scale
tx_gen_out=up_scale(get_data[0], float(get_data[1]))# 背景削除実行
# 結果をクライアントに送信
tx_dat=pickle.dumps(tx_gen_out,5)
clientsock.send(tx_dat) #pickle.dumpsでシリアライズ
clientsock.shutdown(1)# データ送信完了後、送信路を閉じる
print(client_address," へ送信完了","Tx bytes=",len(tx_dat)/1000,"kB",datetime.now())
except:
print("connection error")
アップスケール関数
ここが、今回のreal-ESERの実行を行う部分です。サーバ化しない場合はこの関数をプログラムから呼び出せば直接実行できます。通信プロトコルが入らないのでパフォーマンスは上がります。モデル関係の初期化は別途必要です。
def up_scale(img , scale):
global upsampler
try:
output, _ = upsampler.enhance(img , outscale=scale)
except RuntimeError as error:
print('Error', error)
print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
return output
TESTコード
アップスケール関数が正しいかどうかを確かめるための処理です。冒頭でif args.test==True: としているので、--test Falseで実行されません。デフォルトは実行されます。
#+++++++++++++++++++ TEST +++++++++++++++++++
if args.test==True:
if os.path.isfile(args.input):
paths = [args.input]
else:
paths = sorted(glob.glob(os.path.join(args.input, '*')))
img_list=[]
for idx, path in enumerate(paths):
imgname, extension = os.path.splitext(os.path.basename(path))
print('Testing', idx, imgname)
cv_img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img_list.append(cv_img)
print("start_time=",datetime.now())
count=len(img_list)
for i in range(0,count):
img=img_list[i]
output = up_scale(img , args.outscale)
if len(img.shape) == 3 and img.shape[2] == 4:
extension = '.png'
else:
extension = '.jpg'
save_path = "./results/" + args.output+ str(i)+extension
cv2.imwrite(save_path, output) #if files are require
print("end_time=",datetime.now())
クライント側
前半はアーギュメントとTESTプログラム、中間にTCP/IPプロトコルがあります。最後の up_scale(img , scale) をアプリにimportして呼び出せばアップスケールされた画像が受け取れます。scaleは倍率で2,4,8が指定出来ます。イニシャライズは不要です。
クライアント側全コード
import argparse
import cv2
import glob
import os
from datetime import datetime
import pickle
import socket
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
parser.add_argument( '-t', '--test', type=bool, default=False, help='excecute test PG if True')
args = parser.parse_args()
#+++++++++++++++++++ TEST +++++++++++++++++++
if args.test==True:
if os.path.isfile(args.input):
paths = [args.input]
else:
paths = sorted(glob.glob(os.path.join(args.input, '*')))
img_list=[]
for idx, path in enumerate(paths):
imgname, extension = os.path.splitext(os.path.basename(path))
print('Testing', idx, imgname)
cv_img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img_list.append(cv_img)
start_time = datetime.now()
count=len(img_list)
for i in range(0,count):
img=img_list[i]
output = up_scale(img , args.outscale) # <<<<<<<<<<<<<<<<<<<<< inference関数 up_scale(img , scale)
if len(img.shape) == 3 and img.shape[2] == 4:
extension = '.png'
else:
extension = '.jpg'
save_path = args.output + "/" + str(i)+extension
print(save_path)
cv2.imwrite(save_path, output) #if files are require
print("No of pictures =",count)
print("Start_time=", start_time)
print("End_time =",datetime.now())
# ++++++++++++++ TCP/IP server ++++++++
def get_out(tx_list):
host="127.0.0.1" # サーバーIPアドレス定義
port=8001 # サーバー待ち受けポート番号定義
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # ソケットクライアント作成
s.connect((host, port)) # 送信先サーバーに接続
all_data=b'' # 受信データ保存用変数の初期化
print(host , port , "del_bkg function TX start",datetime.now())
tx_data=pickle.dumps(tx_list,5)
s.send(tx_data) #pickle.dumpsで送信データをシリアライズしサーバに送信
s.shutdown(1)# データ送信完了後、送信路を閉じる
while True: # ソケット接続開始後の処理
data = s.recv(4096*256) # データ受信。受信バッファサイズ4096*256バイト
if not data: # 全データ受信完了(受信路切断)時に、ループ離脱
break
all_data += data # 受信データを追加し繋げていく
get_out =(pickle.loads(all_data))#元の形式にpickle.loadsで復元
print("upscal function RX done",datetime.now())
return get_out
# ++++++++++++++ up scale ++++++++++++++++
def up_scale(img , scale):
tx_list=[]
tx_list.append(img)
tx_list.append(scale)
try:
image =get_out(tx_list)
except RuntimeError as error:
print('Error', error)
return image
if __name__ == '__main__':
main()
FastAPI版の実装
FastAPIによるASOGフォームの通信を用いた実装です。冒頭にも書いたようにパフォーマンスは若干落ちます。
コード
冒頭に up_scale(img , scale) 関数があります。real-ESERの実行を行う部分です。
# ++++++++++++++ up scale ++++++++++++++++
def up_scale(img , scale):
print("inf_start_time=",datetime.now())
global upsampler
try:
output, _ = upsampler.enhance(img , outscale=scale)
except RuntimeError as error:
print('Error', error)
print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
print("inf_end_time=",datetime.now())
return output
その後に続く部分がアーギュメント処理部分です。
更にTCP/IPと同じテストプログラムがあります。
通信部
FastAPIによる通信処理部です。エンドポイントは
resr_upscalのみです。クライントから受け取るデータはOpenCV形式の画像データ、パラメータはint型のscaleです。
out_img = up_scale(img ,scale)
でアップスケール処理を呼び出して、結果のイメージを返送しています。
形式がjesonではないので注意してください。
# ============= FastAPI ============
app = FastAPI()
@app.post("/resr_upscal/")
async def resr_upscal(file: UploadFile = File(...), scale:int = Form(...)): #file=OpenCV
print("scale=",scale)
scale=float(scale)
file_contents = await file.read()
nparr = np.frombuffer(file_contents, np.uint8) # バイナリデータをNumPy配列に変換
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) # OpenCVで画像として読み込む
out_img = up_scale(img ,scale)
frame_data = pickle.dumps(out_img, 5) # tx_dataはpklデータ、イメージのみ返送
print("send_time=",datetime.now())
return Response(content=frame_data, media_type="application/octet-stream")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8008)
APIクライント
こちらも前半はTCP/IPと同じで、アーギュメント処理とテストプログラム、後半に通信用関数が有ります。
import up_scale
でアプリから使うことができます。イニシャライズは不要です。本番ではテストプログラム部分は不要です。なのでわずかなコードだけです。
up_scale(url , img , scale)関数を直接プログラムに埋め込んでも動きます。
クライアント側全コード
import argparse
import cv2
import glob
import os
from datetime import datetime
import pickle
import requests
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
parser.add_argument('-s', '--outscale', type=str, default=4, help='The final upsampling scale of the image')
parser.add_argument( '-t', '--test', type=bool, default=False, help='excecute test PG if True')
parser.add_argument("--host", type=str, default="0.0.0.0", help="サービスを提供するip アドレスを指定。")
parser.add_argument("--port", type=int, default=50008, help="サービスを提供するポートを指定。")
args = parser.parse_args()
host="0.0.0.0" # サーバーIPアドレス定義
port=8008 # サーバー待ち受けポート番号定義
url="http://" + host + ":" + str(port) + "/resr_upscal/"
#+++++++++++++++ test +++++++++++++++++++
if args.test==True:
if os.path.isfile(args.input):
paths = [args.input]
else:
paths = sorted(glob.glob(os.path.join(args.input, '*')))
img_list=[]
for idx, path in enumerate(paths):
imgname, extension = os.path.splitext(os.path.basename(path))
print('Testing', idx, imgname)
cv_img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img_list.append(cv_img)
start_time = datetime.now()
count=len(img_list)
for i in range(0,count):
img=img_list[i]
output = up_scale(url, img , args.outscale) # <<<<<<<<<<<<<<<<<< up_scale(url , img , scale):
if len(img.shape) == 3 and img.shape[2] == 4:
extension = '.png'
else:
extension = '.jpg'
save_path = args.output + "/" + str(i)+extension
print(save_path)
cv2.imwrite(save_path, output) #if files are require #ファイルへ書き出しをすると遅くなります。
print("start_time=",start_time)
print("end_time=",datetime.now())
# ++++++++++++++ up scale ++++++++++++++++
def up_scale(url , img , scale):
_, img_encoded = cv2.imencode('.jpg', img)
response = requests.post(url, files={"file": ("image.jpg", img_encoded.tobytes(), "image/jpeg"),"scale":(None,scale)})
all_data =response.content
up_data = (pickle.loads(all_data))#元の形式にpickle.loadsで復元
return up_data #形式はimg_mode指定の通り
if __name__ == '__main__':
main()
サーバの全コード
TCP/IP版
import argparse
import cv2
import glob
import os
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from datetime import datetime
import socket
import pickle
def main():
global upsampler
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
parser.add_argument('-n','--model_name', type=str, default='RealESRGAN_x4plus', help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus | realesr-animevideov3 | realesr-general-x4v3'))
parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
parser.add_argument('-dn','--denoise_strength',type=float, default=0.5, help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. Only used for the realesr-general-x4v3 model'))
parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
parser.add_argument( '-t', '--test', type=bool, default=False, help='excecute test PG if True')
parser.add_argument("--host", type=str, default="0.0.0.0", help="サービスを提供するip アドレスを指定。")
parser.add_argument("--port", type=int, default=50008, help="サービスを提供するポートを指定。")
args = parser.parse_args()
# determine models according to model names
args.model_name = args.model_name.split('.')[0]
if args.model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
elif args.model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
netscale = 4
elif args.model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
netscale = 2
elif args.model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
netscale = 4
elif args.model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
netscale = 4
#+++++++++++++++++++ init +++++++++++++++++++
model_path = "./weights/" + args.model_name +".pth"
print(model_path )
print(netscale)
# use dni to control the denoise strength
dni_weight = None
if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
model_path = [model_path, wdn_model_path]
dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
# restorer
upsampler = RealESRGANer(
scale=netscale,
model_path=model_path,
dni_weight=dni_weight,
model=model,
tile=0,
tile_pad=10,
pre_pad=0,
half=True,
gpu_id=0)
#+++++++++++++++++++ TEST +++++++++++++++++++
if args.test==True:
if os.path.isfile(args.input):
paths = [args.input]
else:
paths = sorted(glob.glob(os.path.join(args.input, '*')))
img_list=[]
for idx, path in enumerate(paths):
imgname, extension = os.path.splitext(os.path.basename(path))
print('Testing', idx, imgname)
cv_img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img_list.append(cv_img)
print("start_time=",datetime.now())
count=len(img_list)
for i in range(0,count):
img=img_list[i]
output = up_scale(img , args.outscale)
if len(img.shape) == 3 and img.shape[2] == 4:
extension = '.png'
else:
extension = '.jpg'
save_path = "./results/" + args.output+ str(i)+extension
cv2.imwrite(save_path, output) #if files are require
print("end_time=",datetime.now())
# ++++++++++++++ TCP/IP server ++++++++
if args.test==False:
host=args.host
port=args.port
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # ソケット定義(IPv4,TCPによるソケット)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((host,port))
s.listen(10) # ソケット接続待受(キューの最大数を指定)
while True:
try:
try: # ソケット接続受信待ち
print(host,port,'クライアントからの接続待ち...')
clientsock, client_address = s.accept()
print("Conected client= ", client_address," Conection date_time= ",datetime.now())
except KeyboardInterrupt: # 接続待ちの間に強制終了が入った時の例外処理
clientsock.shutdown(1)# データ送信完了後、送信路を閉じる
break
else: # 接続待ちの間に強制終了なく、クライアントからの接続が来た場合
all_data=b'' # 受信データ保存用変数の初期化
while True: # ソケット接続開始後の処理
data = clientsock.recv(4096*256) # データ受信。受信バッファサイズ1024バイト
if not data: # 全データ受信完了(受信路切断)時に、ループ離脱
break
all_data += data # 受信データを追加し繋げていく
get_data=(pickle.loads(all_data)) #受信データ解析 元の形式にpickle.loadsで復元 get_data[0]=OpenCV-image , get_data[1]=scale
tx_gen_out=up_scale(get_data[0], float(get_data[1]))# 背景削除実行
# 結果をクライアントに送信
tx_dat=pickle.dumps(tx_gen_out,5)
clientsock.send(tx_dat) #pickle.dumpsでシリアライズ
clientsock.shutdown(1)# データ送信完了後、送信路を閉じる
print(client_address," へ送信完了","Tx bytes=",len(tx_dat)/1000,"kB",datetime.now())
except:
print("connection error")
# ++++++++++++++ up scale ++++++++++++++++
def up_scale(img , scale):
global upsampler
try:
output, _ = upsampler.enhance(img , outscale=scale)
except RuntimeError as error:
print('Error', error)
print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
return output
if __name__ == '__main__':
main()
FastAPI版サーバ
import argparse
import cv2
import glob
import os
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from datetime import datetime
import pickle
from fastapi import FastAPI, File, UploadFile, Form
from starlette.responses import Response
from io import BytesIO
import numpy as np
# ++++++++++++++ up scale ++++++++++++++++
def up_scale(img , scale):
print("inf_start_time=",datetime.now())
global upsampler
try:
output, _ = upsampler.enhance(img , outscale=scale)
except RuntimeError as error:
print('Error', error)
print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
print("inf_end_time=",datetime.now())
return output
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
parser.add_argument('-n','--model_name', type=str, default='RealESRGAN_x4plus', help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus | realesr-animevideov3 | realesr-general-x4v3'))
parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
parser.add_argument('-dn','--denoise_strength',type=float, default=0.5, help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. Only used for the realesr-general-x4v3 model'))
parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
parser.add_argument( '-t', '--test', type=bool, default=False, help='excecute test PG if True')
parser.add_argument("--host", type=str, default="0.0.0.0", help="サービスを提供するip アドレスを指定。")
parser.add_argument("--port", type=int, default=50008, help="サービスを提供するポートを指定。")
args = parser.parse_args()
# determine models according to model names
args.model_name = args.model_name.split('.')[0]
if args.model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
elif args.model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
netscale = 4
elif args.model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
netscale = 2
elif args.model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
netscale = 4
elif args.model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
netscale = 4
#+++++++++++++++++++ init +++++++++++++++++++
model_path = "./weights/" + args.model_name +".pth"
print(model_path )
print(netscale)
# use dni to control the denoise strength
dni_weight = None
if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
model_path = [model_path, wdn_model_path]
dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
# restorer
upsampler = RealESRGANer(
scale=netscale,
model_path=model_path,
dni_weight=dni_weight,
model=model,
tile=0,
tile_pad=10,
pre_pad=0,
half=True,
gpu_id=0)
#+++++++++++++++++++ TEST +++++++++++++++++++
if args.test==True:
if os.path.isfile(args.input):
paths = [args.input]
else:
paths = sorted(glob.glob(os.path.join(args.input, '*')))
img_list=[]
for idx, path in enumerate(paths):
imgname, extension = os.path.splitext(os.path.basename(path))
print('Testing', idx, imgname)
cv_img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img_list.append(cv_img)
print("start_time=",datetime.now())
count=len(img_list)
for i in range(0,count):
img=img_list[i]
output = up_scale(img , args.outscale)
if len(img.shape) == 3 and img.shape[2] == 4:
extension = '.png'
else:
extension = '.jpg'
save_path = "./results/" + args.output+ str(i)+extension
cv2.imwrite(save_path, output) #if files are require
print("end_time=",datetime.now())
# ============= FastAPI ============
app = FastAPI()
@app.post("/resr_upscal/")
async def resr_upscal(file: UploadFile = File(...), scale:int = Form(...)): #file=OpenCV
print("scale=",scale)
scale=float(scale)
file_contents = await file.read()
nparr = np.frombuffer(file_contents, np.uint8) # バイナリデータをNumPy配列に変換
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) # OpenCVで画像として読み込む
out_img = up_scale(img ,scale)
frame_data = pickle.dumps(out_img, 5) # tx_dataはpklデータ、イメージのみ返送
print("send_time=",datetime.now())
return Response(content=frame_data, media_type="application/octet-stream")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8008)