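Raising photo resolution with TecoGAN
This article introduces TecoGAN, a model that converts photos to a higher resolution.
It may be useful when you have taken a photo whose quality is not quite good enough and you want it to look cleaner. It is not limited to people; it also handles many other kinds of images, such as landscapes and animals. Another characteristic is that the conversion looks natural.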
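How to run inference
You could build a local environment to run TecoGAN, but to keep things simple I used Google Colaboratory. (The Colab code is here.)
First, download the thunil/TecoGAN code from GitHub and upload it to Google Drive. Next, install the required libraries. The install commands are written in the Colab notebook, so please check them there.
Before running inference, you need to download the pretrained model. You could train from scratch, but that takes a long time, so the images here are converted with the pretrained model. Running the following command downloads the sample video frames and the pretrained model.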
python3 runGan.py 0
Run the pretrained model as follows.
python3 runGan.py 1
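In a notebook it can be handy to issue these two steps from Python instead of typing them one by one. The following is a minimal sketch under assumptions: the helper names `rungan_command` and `run_rungan` are hypothetical, the working directory is assumed to be the TecoGAN checkout, and only the two modes used in this article (0 and 1) are accepted.

```python
import subprocess
import sys

# Modes used in this article: 0 downloads the data and the pretrained model,
# 1 runs inference with the pretrained model.
KNOWN_MODES = {0: "download", 1: "inference"}


def rungan_command(mode, python=sys.executable):
    """Build the argument list for `python3 runGan.py <mode>` (hypothetical helper)."""
    if mode not in KNOWN_MODES:
        raise ValueError("unsupported mode: %r" % (mode,))
    return [python, "runGan.py", str(mode)]


def run_rungan(mode, repo_dir="."):
    """Run runGan.py inside repo_dir and raise if it exits with an error."""
    return subprocess.run(rungan_command(mode), cwd=repo_dir, check=True)
```

Calling `run_rungan(0)` and then `run_rungan(1)` from the repository directory would correspond to the two commands above.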
By default, images placed under LR/calendar are upscaled. In this code the input directory is changed to LR/test.
The converted images are saved under results/test.
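Copying your own photos into the input directory can be scripted. This is a minimal sketch, not part of the original notebook: the function name `stage_images` is hypothetical, and the set of accepted extensions is an assumption, since which files are actually picked up depends on lib/dataloader.py.

```python
import shutil
from pathlib import Path

# Extensions treated as images here; this set is an assumption.
IMAGE_EXTS = {".png", ".jpg", ".jpeg"}


def stage_images(src_dir, dst_dir="LR/test"):
    """Copy image files from src_dir into dst_dir and return the copied names."""
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    copied = []
    for path in sorted(Path(src_dir).iterdir()):
        # Skip sub-directories and files that do not look like images.
        if path.is_file() and path.suffix.lower() in IMAGE_EXTS:
            shutil.copy2(path, dst / path.name)
            copied.append(path.name)
    return copied
```

If the data loader in your copy of the code only reads one format, convert the photos to that format before copying them.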
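To find the result for a given input, it helps to know how the file name is derived. The sketch below mirrors the naming logic in the inference branch of the main.py listed later in this article. The function name `output_path` is hypothetical. The defaults for `output_name` and `output_ext` are taken from the flags in that main.py, while `output_dir` and `output_pre` are assumptions based on the results/test location mentioned above; runGan.py may pass different values.

```python
import os


def output_path(input_path, output_dir="results", output_pre="test",
                output_name="output", output_ext="jpg"):
    """Derive the output file path for one input image, as main.py does."""
    # An empty output_pre means images are written directly into output_dir.
    if output_pre == "":
        image_dir = output_dir
    else:
        image_dir = os.path.join(output_dir, output_pre)
    # Input file name without its directory and extension.
    name, _ = os.path.splitext(os.path.basename(input_path))
    filename = output_name + "_" + name
    return os.path.join(image_dir, "%s.%s" % (filename, output_ext))
```

With these assumed defaults, an input named cat.png would be written as output_cat.jpg inside results/test.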
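Converted image
You can see that the image quality has improved slightly. For this image, I think the conversion came out natural. On the other hand, it did not help much with photos taken with a telephoto lens whose quality had been degraded by digital zoom. It is best thought of as a way to clean up fine noise.
Code changes
I changed part of the code to run it on Colab. The file with the largest changes is main.py. Most of the other changes only rename library calls so that they work with TensorFlow v1, so they are omitted here.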
main.py
import os, math, time, collections
import numpy as np
''' TF_CPP_MIN_LOG_LEVEL
0 = all messages are logged (default behavior)
1 = INFO messages are not printed
2 = INFO and WARNING messages are not printed
3 = INFO, WARNING, and ERROR messages are not printed
Disable Logs for now '''
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False
import random as rn
# fix all randomness, except for multi-threading or GPU process
os.environ['PYTHONHASHSEED'] = '0'
np.random.seed(42)
rn.seed(12345)
tf.set_random_seed(1234)
import tensorflow.contrib.slim as slim
import sys, shutil, subprocess
from lib.ops import *
from lib.dataloader import inference_data_loader, frvsr_gpu_data_loader
from lib.frvsr import generator_F, fnet
from lib.Teco import FRVSR, TecoGAN
Flags = tf.app.flags
Flags.DEFINE_integer('rand_seed', 1 , 'random seed' )
# Directories
Flags.DEFINE_string('input_dir_LR', None, 'The directory of the input resolution input data, for inference mode')
Flags.DEFINE_integer('input_dir_len', -1, 'length of the input for inference mode, -1 means all')
Flags.DEFINE_string('input_dir_HR', None, 'The directory of the input resolution input data, for inference mode')
Flags.DEFINE_string('mode', 'inference', 'train, or inference')
Flags.DEFINE_string('output_dir', None, 'The output directory of the checkpoint')
Flags.DEFINE_string('output_pre', '', 'The name of the subfolder for the images')
Flags.DEFINE_string('output_name', 'output', 'The pre name of the outputs')
Flags.DEFINE_string('output_ext', 'jpg', 'The format of the output when evaluating')
Flags.DEFINE_string('summary_dir', None, 'The directory to output the summary')
# Models
Flags.DEFINE_string('checkpoint', None, 'If provided, the weight will be restored from the provided checkpoint')
Flags.DEFINE_integer('num_resblock', 16, 'How many residual blocks are there in the generator')
# Models for training
Flags.DEFINE_boolean('pre_trained_model', False, 'If True, the weight of generator will be loaded as an initial point. '
                     'If False, continue the training')
Flags.DEFINE_string('vgg_ckpt', None, 'path to checkpoint file for the vgg19')
# Machine resources
Flags.DEFINE_string('cudaID', '0', 'CUDA devices')
Flags.DEFINE_integer('queue_thread', 6, 'The threads of the queue (More threads can speedup the training process.')
Flags.DEFINE_integer('name_video_queue_capacity', 512, 'The capacity of the filename queue (suggest large to ensure '
                     'enough random shuffle.')
Flags.DEFINE_integer('video_queue_capacity', 256, 'The capacity of the video queue (suggest large to ensure '
                     'enough random shuffle')
Flags.DEFINE_integer('video_queue_batch', 2, 'shuffle_batch queue capacity')
# Training details
# The data preparing operation
Flags.DEFINE_integer('RNN_N', 10, 'The number of the rnn recurrent length')
Flags.DEFINE_integer('batch_size', 4, 'Batch size of the input batch')
Flags.DEFINE_boolean('flip', True, 'Whether random flip data augmentation is applied')
Flags.DEFINE_boolean('random_crop', True, 'Whether perform the random crop')
Flags.DEFINE_boolean('movingFirstFrame', True, 'Whether use constant moving first frame randomly.')
Flags.DEFINE_integer('crop_size', 32, 'The crop size of the training image')
# Training data settings
Flags.DEFINE_string('input_video_dir', '', 'The directory of the video input data, for training')
Flags.DEFINE_string('input_video_pre', 'scene', 'The pre of the directory of the video input data')
Flags.DEFINE_integer('str_dir', 1000, 'The starting index of the video directory')
Flags.DEFINE_integer('end_dir', 2000, 'The ending index of the video directory')
Flags.DEFINE_integer('end_dir_val', 2050, 'The ending index for validation of the video directory')
Flags.DEFINE_integer('max_frm', 119, 'The ending index of the video directory')
# The loss parameters
Flags.DEFINE_float('vgg_scaling', -0.002, 'The scaling factor for the VGG perceptual loss, disable with negative value')
Flags.DEFINE_float('warp_scaling', 1.0, 'The scaling factor for the warp')
Flags.DEFINE_boolean('pingpang', False, 'use bi-directional recurrent or not')
Flags.DEFINE_float('pp_scaling', 1.0, 'factor of pingpang term, only works when pingpang is True')
# Training parameters
Flags.DEFINE_float('EPS', 1e-12, 'The eps added to prevent nan')
Flags.DEFINE_float('learning_rate', 0.0001, 'The learning rate for the network')
Flags.DEFINE_integer('decay_step', 500000, 'The steps needed to decay the learning rate')
Flags.DEFINE_float('decay_rate', 0.5, 'The decay rate of each decay step')
Flags.DEFINE_boolean('stair', False, 'Whether perform staircase decay. True => decay in discrete interval.')
Flags.DEFINE_float('beta', 0.9, 'The beta1 parameter for the Adam optimizer')
Flags.DEFINE_float('adameps', 1e-8, 'The eps parameter for the Adam optimizer')
Flags.DEFINE_integer('max_epoch', None, 'The max epoch for the training')
Flags.DEFINE_integer('max_iter', 1000000, 'The max iteration of the training')
Flags.DEFINE_integer('display_freq', 20, 'The display frequency of the training process')
Flags.DEFINE_integer('summary_freq', 100, 'The frequency of writing summary')
Flags.DEFINE_integer('save_freq', 10000, 'The frequency of saving images')
# Dst parameters
Flags.DEFINE_float('ratio', 0.01, 'The ratio between content loss and adversarial loss')
Flags.DEFINE_boolean('Dt_mergeDs', True, 'Whether only use a merged Discriminator.')
Flags.DEFINE_float('Dt_ratio_0', 1.0, 'The starting ratio for the temporal adversarial loss')
Flags.DEFINE_float('Dt_ratio_add', 0.0, 'The increasing ratio for the temporal adversarial loss')
Flags.DEFINE_float('Dt_ratio_max', 1.0, 'The max ratio for the temporal adversarial loss')
Flags.DEFINE_float('Dbalance', 0.4, 'An adaptive balancing for Discriminators')
Flags.DEFINE_float('crop_dt', 0.75, 'factor of dt crop') # dt input size = crop_size*crop_dt
Flags.DEFINE_boolean('D_LAYERLOSS', True, 'Whether use layer loss from D')
FLAGS = Flags.FLAGS
# Set CUDA devices correctly if you use multiple gpu system
os.environ["CUDA_VISIBLE_DEVICES"]=FLAGS.cudaID
# Fix randomness
my_seed = FLAGS.rand_seed
rn.seed(my_seed)
np.random.seed(my_seed)
tf.set_random_seed(my_seed)
# Check the output_dir is given
if FLAGS.output_dir is None:
    raise ValueError('The output directory is needed')
# Check the output directory to save the checkpoint
if not os.path.exists(FLAGS.output_dir):
    os.mkdir(FLAGS.output_dir)
# Check the summary directory to save the event
if not os.path.exists(FLAGS.summary_dir):
    os.mkdir(FLAGS.summary_dir)

# custom Logger to write Log to file
class Logger(object):
    def __init__(self):
        self.terminal = sys.stdout
        self.log = open(FLAGS.summary_dir + "logfile.txt", "a")
    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
    def flush(self):
        self.log.flush()

sys.stdout = Logger()
def printVariable(scope, key = tf.GraphKeys.MODEL_VARIABLES):
    print("Scope %s:" % scope)
    variables_names = [ [v.name, v.get_shape().as_list()] for v in tf.get_collection(key, scope=scope)]
    total_sz = 0
    for k in variables_names:
        print ("Variable: " + k[0])
        print ("Shape: " + str(k[1]))
        total_sz += np.prod(k[1])
    print("total size: %d" %total_sz)

def preexec(): # Don't forward signals.
    os.setpgrp()

def testWhileTrain(FLAGS, testno = 0):
    '''
        this function is called during training, Hard-Coded!!
        to try the "inference" mode when a new model is saved.
        The code has to be updated from machine to machine...
        depending on python, and your training settings
    '''
    desstr = os.path.join(FLAGS.output_dir, 'train/') # saving in the ./train/ directory
    cmd1 = ["python3", "main.py", # never tested with python2...
        "--output_dir", desstr,
        "--summary_dir", desstr,
        "--mode","inference",
        "--num_resblock", "%d"%FLAGS.num_resblock,
        "--checkpoint", os.path.join(FLAGS.output_dir, 'model-%d'%testno),
        "--cudaID", FLAGS.cudaID]
    # a folder for short test
    cmd1 += ["--input_dir_LR", "./LR/calendar/", # update the testing sequence
        "--output_pre", "", # saving in train folder directly
        "--output_name", "%09d"%testno, # name
        "--input_dir_len", "10",]
    print('[testWhileTrain] step %d:'%testno)
    print(' '.join(cmd1))
    # ignore signals
    return subprocess.Popen(cmd1, preexec_fn = preexec)

if False: # If you want to take a look of the configuration, True
    print_configuration_op(FLAGS)
# the inference mode (just perform super resolution on the input image)
if FLAGS.mode == 'inference':
    if FLAGS.checkpoint is None:
        raise ValueError('The checkpoint file is needed to performing the test.')

    # Declare the test data reader
    inference_data = inference_data_loader(FLAGS)
    # Shapes are derived from each image, so the graph is built once per image
    for i, input_data in enumerate(inference_data.inputs):
        input_shape = [1,] + list(input_data.shape)
        output_shape = [1,input_shape[1]*4, input_shape[2]*4, 3]
        oh = input_shape[1] - input_shape[1]//8 * 8
        ow = input_shape[2] - input_shape[2]//8 * 8
        paddings = tf.constant([[0,0], [0,oh], [0,ow], [0,0]])
        print("input shape:", input_shape)
        print("output shape:", output_shape)

        # build the graph
        inputs_raw = tf.placeholder(tf.float32, shape=input_shape, name='inputs_raw')

        # state buffers for the previous frame; they are not trained
        pre_inputs = tf.Variable(tf.zeros(input_shape), trainable=False, name='pre_inputs')
        pre_gen = tf.Variable(tf.zeros(output_shape), trainable=False, name='pre_gen')
        pre_warp = tf.Variable(tf.zeros(output_shape), trainable=False, name='pre_warp')

        transpose_pre = tf.space_to_depth(pre_warp, 4)
        inputs_all = tf.concat( (inputs_raw, transpose_pre), axis = -1)
        with tf.variable_scope('generator'):
            gen_output = generator_F(inputs_all, 3, reuse=tf.AUTO_REUSE, FLAGS=FLAGS)
            # Deprocess the images outputed from the model, and assign things for next frame
            with tf.control_dependencies([ tf.assign(pre_inputs, inputs_raw)]):
                outputs = tf.assign(pre_gen, deprocess(gen_output))

        inputs_frames = tf.concat( (pre_inputs, inputs_raw), axis = -1)
        with tf.variable_scope('fnet'):
            gen_flow_lr = fnet( inputs_frames, reuse=tf.AUTO_REUSE)
            gen_flow_lr = tf.pad(gen_flow_lr, paddings, "SYMMETRIC")
            gen_flow = upscale_four(gen_flow_lr*4.0)
            gen_flow.set_shape( output_shape[:-1]+[2] )
        pre_warp_hi = tf.contrib.image.dense_image_warp(pre_gen, gen_flow)
        before_ops = tf.assign(pre_warp, pre_warp_hi)

        print('Finish building the network')

        # In inference time, we only need to restore the weight of the generator
        var_list = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope='generator')
        var_list = var_list + tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope='fnet')

        weight_initiallizer = tf.train.Saver(var_list)

        # Define the initialization operation
        init_op = tf.global_variables_initializer()
        local_init_op = tf.local_variables_initializer()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        if (FLAGS.output_pre == ""):
            image_dir = FLAGS.output_dir
        else:
            image_dir = os.path.join(FLAGS.output_dir, FLAGS.output_pre)
        if not os.path.exists(image_dir):
            os.makedirs(image_dir)

        with tf.Session(config=config) as sess:
            # Load the pretrained model
            sess.run(init_op)
            sess.run(local_init_op)

            print('Loading weights from ckpt model')
            weight_initiallizer.restore(sess, FLAGS.checkpoint)
            if False: # If you want to take a look of the weights, True
                printVariable('generator')
                printVariable('fnet')
            srtime = 0
            # print('Frame evaluation starts!!')
            input_im = np.array([input_data]).astype(np.float32)
            feed_dict={inputs_raw: input_im}
            t0 = time.time()
            # if(i != 0):
            #     sess.run(before_ops, feed_dict=feed_dict)
            output_frame = sess.run(outputs, feed_dict=feed_dict)
            srtime += time.time()-t0

            name, _ = os.path.splitext(os.path.basename(str(inference_data.paths_LR[i])))
            filename = FLAGS.output_name+'_'+name
            print('saving image %s' % filename)
            out_path = os.path.join(image_dir, "%s.%s"%(filename,FLAGS.output_ext))
            save_img(out_path, output_frame[0])
            # if(i >= 5):
            #     name, _ = os.path.splitext(os.path.basename(str(inference_data.paths_LR[i])))
            #     filename = FLAGS.output_name+'_'+name
            #     print('saving image %s' % filename)
            #     out_path = os.path.join(image_dir, "%s.%s"%(filename,FLAGS.output_ext))
            #     save_img(out_path, output_frame[0])
            # else:# First 5 is a hard-coded symmetric frame padding, ignored but time added!
            #     print("Warming up %d"%(5-i))
# The training mode
elif FLAGS.mode == 'train':
    # hard coded save
    filelist = ['main.py','lib/Teco.py','lib/frvsr.py','lib/dataloader.py','lib/ops.py']
    for filename in filelist:
        shutil.copyfile('./' + filename, FLAGS.summary_dir + filename.replace("/","_"))

    useValidat = tf.placeholder_with_default( tf.constant(False, dtype=tf.bool), shape=() )
    rdata = frvsr_gpu_data_loader(FLAGS, useValidat)
    # Data = collections.namedtuple('Data', 'paths_HR, s_inputs, s_targets, image_count, steps_per_epoch')
    print('tData count = %d, steps per epoch %d' % (rdata.image_count, rdata.steps_per_epoch))
    if (FLAGS.ratio>0):
        Net = TecoGAN( rdata.s_inputs, rdata.s_targets, FLAGS )
    else:
        Net = FRVSR( rdata.s_inputs, rdata.s_targets, FLAGS )
    # Network = collections.namedtuple('Network', 'gen_output, train, learning_rate, update_list, '
    #                                  'update_list_name, update_list_avg, image_summary')

    # Add scalar summary
    tf.summary.scalar('learning_rate', Net.learning_rate)
    train_summary = []
    for key, value in zip(Net.update_list_name, Net.update_list_avg):
        # 'map_loss, scale_loss, FrameA_loss, FrameA_loss,...'
        train_summary += [tf.summary.scalar(key, value)]
    train_summary += Net.image_summary
    merged = tf.summary.merge(train_summary)

    validat_summary = [] # val data statistics is not added to average
    uplen = len(Net.update_list)
    for key, value in zip(Net.update_list_name[:uplen], Net.update_list):
        # 'map_loss, scale_loss, FrameA_loss, FrameA_loss,...'
        validat_summary += [tf.summary.scalar("val_" + key, value)]
    val_merged = tf.summary.merge(validat_summary)

    # Define the saver and weight initiallizer
    saver = tf.train.Saver(max_to_keep=50)
    # variable lists
    all_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    tfflag = tf.GraphKeys.MODEL_VARIABLES #tf.GraphKeys.TRAINABLE_VARIABLES

    if (FLAGS.checkpoint is not None) and (FLAGS.pre_trained_model is True):
        model_var_list = tf.get_collection(tfflag, scope='generator') + tf.get_collection(tfflag, scope='fnet')
        assign_ops = get_existing_from_ckpt(FLAGS.checkpoint, model_var_list, rest_zero=True, print_level=1)
        print('Prepare to load %d weights from the pre-trained model for generator and fnet'%len(assign_ops))
        if FLAGS.ratio>0:
            model_var_list = tf.get_collection(tfflag, scope='tdiscriminator')
            dis_list = get_existing_from_ckpt(FLAGS.checkpoint, model_var_list, print_level=0)
            print('Prepare to load %d weights from the pre-trained model for discriminator'%len(dis_list))
            assign_ops += dis_list

    if FLAGS.vgg_scaling > 0.0: # VGG weights are not trainable
        vgg_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='vgg_19')
        vgg_restore = tf.train.Saver(vgg_var_list)

    print('Finish building the network.')
    # Start the session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # init_op = tf.initialize_all_variables() # MonitoredTrainingSession will initialize automatically
    with tf.train.MonitoredTrainingSession(config=config, save_summaries_secs=None, save_checkpoint_secs=None) as sess:
        train_writer = tf.summary.FileWriter(FLAGS.summary_dir, sess.graph)

        printVariable('generator')
        printVariable('fnet')
        if FLAGS.ratio>0:
            printVariable('tdiscriminator')

        if FLAGS.vgg_scaling > 0.0:
            printVariable('vgg_19', tf.GraphKeys.GLOBAL_VARIABLES)
            vgg_restore.restore(sess, FLAGS.vgg_ckpt)
            print('VGG19 restored successfully!!')

        if (FLAGS.checkpoint is not None):
            if (FLAGS.pre_trained_model is False):
                print('Loading everything from the checkpoint to continue the training...')
                saver.restore(sess, FLAGS.checkpoint)
                # this will restore everything, including ADAM training parameters and global_step
            else:
                print('Loading weights from the pre-trained model to start a new training...')
                sess.run(assign_ops) # only restore existing model weights

        print('The first run takes longer time for training data loading...')
        # get the session for save
        _sess = sess
        while type(_sess).__name__ != 'Session':
            # pylint: disable=W0212
            _sess = _sess._sess
        save_sess = _sess

        if 1:
            print('Save initial checkpoint, before any training')
            init_run_no = sess.run(Net.global_step)
            saver.save(save_sess, os.path.join(FLAGS.output_dir, 'model'), global_step=init_run_no)
            testWhileTrain(FLAGS, init_run_no) # make sure that testWhileTrain works

        # Performing the training
        frame_len = (FLAGS.RNN_N*2-1) if FLAGS.pingpang else FLAGS.RNN_N
        max_iter, step, start = FLAGS.max_iter, 0, time.time()
        if max_iter is None:
            if FLAGS.max_epoch is None:
                raise ValueError('one of max_epoch or max_iter should be provided')
            else:
                max_iter = FLAGS.max_epoch * rdata.steps_per_epoch
        try:
            for step in range(max_iter):
                run_step = sess.run(Net.global_step) + 1
                fetches = { "train": Net.train, "learning_rate": Net.learning_rate }

                if (run_step % FLAGS.display_freq) == 0:
                    for key, value in zip(Net.update_list_name, Net.update_list_avg):
                        fetches[str(key)] = value

                if (run_step % FLAGS.summary_freq) == 0:
                    fetches["summary"] = merged

                results = sess.run(fetches)
                if(step == 0):
                    print('Optimization starts!!!(Ctrl+C to stop, will try saving the last model...)')

                if (run_step % FLAGS.summary_freq) == 0:
                    print('Run and Recording summary!!')
                    train_writer.add_summary(results['summary'], run_step)
                    val_fetches = {}
                    for name, value in zip(Net.update_list_name[:uplen], Net.update_list):
                        val_fetches['val_' + name] = value
                    val_fetches['summary'] = val_merged
                    val_results = sess.run(val_fetches, feed_dict={useValidat: True})
                    train_writer.add_summary(val_results['summary'], run_step)
                    print('-----------Validation data scalars-----------')
                    for name in Net.update_list_name[:uplen]:
                        print('val_' + name, val_results['val_' + name])

                if (run_step % FLAGS.display_freq) == 0:
                    train_epoch = math.ceil(run_step / rdata.steps_per_epoch)
                    train_step = (run_step - 1) % rdata.steps_per_epoch + 1
                    rate = (step + 1) * FLAGS.batch_size / (time.time() - start)
                    remaining = (max_iter - step) * FLAGS.batch_size / rate
                    print("progress epoch %d step %d image/sec %0.1fx%02d remaining %dh%dm" %
                        (train_epoch, train_step, rate, frame_len,
                        remaining // 3600, (remaining%3600) // 60))

                    print("global_step", run_step)
                    print("learning_rate", results['learning_rate'])
                    for name in Net.update_list_name:
                        print(name, results[name])

                if (run_step % FLAGS.save_freq) == 0:
                    print('Save the checkpoint')
                    saver.save(save_sess, os.path.join(FLAGS.output_dir, 'model'), global_step=int(run_step))
                    testWhileTrain(FLAGS, run_step)

        except KeyboardInterrupt:
            if step > 1:
                print('main.py: KeyboardInterrupt->saving the checkpoint')
                saver.save(save_sess, os.path.join(FLAGS.output_dir, 'model'), global_step=int(run_step))
                testWhileTrain(FLAGS, run_step).communicate()
                print('main.py: quit')
            exit()

        print('Optimization done!!!!!!!!!!!!')
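The inference branch of main.py above derives three things from each input image: the input shape with a batch dimension of 1, an output shape that is 4 times larger in height and width, and a padding equal to the remainder of the height and width when divided by 8. The padding is applied to the flow returned by fnet, presumably to restore rows and columns lost when the size is not a multiple of 8; that interpretation is an assumption. The sketch below reproduces only the arithmetic in plain Python, and the function name `inference_shapes` is hypothetical.

```python
def inference_shapes(height, width, channels=3, scale=4, multiple=8):
    """Return (input_shape, output_shape, paddings) as computed in main.py."""
    input_shape = [1, height, width, channels]
    # The generator upsamples height and width by a factor of 4.
    output_shape = [1, height * scale, width * scale, 3]
    # Same arithmetic as main.py: the remainder after dividing by 8.
    oh = height - height // multiple * multiple
    ow = width - width // multiple * multiple
    # Pad only at the bottom and right of the spatial dimensions.
    paddings = [[0, 0], [0, oh], [0, ow], [0, 0]]
    return input_shape, output_shape, paddings
```

For a 180 x 321 image this gives an output shape of [1, 720, 1284, 3] and a padding of 4 rows and 1 column, while an image whose sides are multiples of 8 needs no padding.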
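Promotion
Whispon LLC takes on contract development work.
If you would like to bring AI into your own service, or build a service that uses AI, please feel free to contact us.
Whispon LLC
Email address
oshita-n@whispon.com
Phone number
011-350-0092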