[Python] scipyを使用した線形・非線形補完ツール
Pythonのscipyを使用して補完ツールを作ってみたので共有します。
1.コード
import csv
import numpy as np
import datetime
from scipy import signal, interpolate
from matplotlib import pylab as plt
#現在の日時を取得する
now = datetime.datetime.now()
#読み込むcsv。ソースと同じフォルダに置く。
fi = 'input/input.csv'
#x間隔の読み込み
with open('input/prm.txt') as f:
prm= f.readlines()
deltax=prm[0]
deltax = deltax[:10]
deltax = float(deltax)
print('x出力間隔=',end='')
print(deltax,end=' ')
#補完法の選択
with open('input/prm.txt') as f:
prm= f.readlines()
flg=prm[1]
flg = flg[:2]
flg = int(flg)
1
#出力するcsvの名前を定義する。末尾にタイムスタンプ入れる。
fo = 'output/test_{0:%Y%m%d%H%M%S}.csv'.format(now)
#csvを読み込んでdata_arrayに格納
with open(fi, mode='r', newline='') as f_in:
reader = csv.reader(f_in)
data_array = [row for row in reader]
#デバッグ用:data_arrayの中身を確認する
#print(data_array)
#print(type(data_array))
#data_arrayを1次元配列(リスト)に分けるための配列を定義する
plot_x = []
plot_y = []
#data_arrayの要素を一個ずつ取り出して、各リストに格納する
for i in data_array:
plot_x.append(float(i[0]))#要素は文字列として扱われているのでfloatに変換する
plot_y.append(float(i[1]))
#linspaceにx軸データ範囲が必要なのでplot_xの最小値/最大値を調べる
xs = min(plot_x)
xm = max(plot_x)
#print("x軸の最小値は",xs)
#print("x軸の最大値は",xm)
#lilnspaceでデータ補間数を決める
tt = np.arange(xs, xm, deltax)
#print(tt)
#線形補間 if flg==1:
print('線形補間')
f = interpolate.interp1d(plot_x, plot_y, kind="linear")
y = f(tt)
#2次スプライン補間 if flg==2:
print('2次スプライン補間')
f = interpolate.interp1d(plot_x, plot_y, kind="quadratic")
y = f(tt)
#3次スプライン補間 if flg==3:
print('3次スプライン補間')
f = interpolate.interp1d(plot_x, plot_y, kind="cubic")
y = f(tt)
#最近傍点補間 if flg==4:
print('最近傍点補間')
f = interpolate.interp1d(plot_x, plot_y, kind="nearest")
y = f(tt)
#ラグランジュ補間 if flg==5:
print('ラグランジュ補間')
f = interpolate.lagrange(plot_x, plot_y)
y = f(tt)
#print(f(tt))
#補間後のx,yをnp.arrayで結合し、エクセルで使いやすいように転置する。
yArry = np.array([tt, y])
yArry_t = yArry.T
#print(yArry)
#print(type(yArry))
#print(yArry_t)
#print(type(yArry_t))
#csvに出力する。
with open(fo, mode='a') as f_out:
csvWriter = csv.writer(f_out, lineterminator = '\n')
csvWriter.writerows(yArry_t)
#データプロット plt.plot(plot_x, plot_y,"r")
plt.plot(tt, y)
plt.show()
2.使い方
ソースコードと同じディレクトリにinputフォルダとoutputフォルダを用意します。inputフォルダには補完したいxyデータ(input.csv)とパラメータファイル(prm.txt)を入れます。サンプルに以下のデータをご使用ください。
パラメータファイルでは出力したい補完後のxの間隔と補完法を選択できます。xが小さすぎるとうまくいかないことがあります。準備が出来てpythonを実行すると以下のような出力がされます。
matplotlibも使用しているので画像の保存等も行えます。上は線形補間ですがパラメータファイルで補完方法を変えていろいろ試していただければと思います。サポートいただければ幸いです。
サポートいただけますとやる気が2.8倍くらい上がって4.2倍面白いものが書けます。