見出し画像

7割Copilot君に書いてもらったリサンプリングプログラム(仮)

経緯を残しているだけなので、(完)の記事が最新です。



前回のおさらい

環境作ったけど、リサンプラーが終わってた。
リアルタイムで88.2KHz/sinc_len12000でも性能が足らないし、後段のフィルターの品質いまいちだし、事前にアップサンプリングしておくしかない。
とりあえず352.8KHz/384KHz/sinc_len80000オーバーを実現したい。


インストールしとくもの

sudo apt update
sudo apt install libsndfile1 libsndfile1-dev

resample.cppで保存

#include <iostream>
#include <vector>
#include <cmath>
#include <cstring>
#include <sndfile.h>
#include <omp.h>
#include <atomic>
#include <mutex>
#include <thread>
#include <iomanip>
#include <chrono>
#include <immintrin.h>

double kaiserWindow(int n, int N, double beta) {
    return std::cyl_bessel_i(0, beta * std::sqrt(1 - std::pow(2.0 * n / (N - 1) - 1, 2))) / std::cyl_bessel_i(0, beta);
}

std::vector<double> precomputedKaiserWindow(int sincLen, double beta) {
    std::vector<double> kaiserWindowTable(2 * sincLen + 1);

    double collectWindowGain = kaiserWindow(sincLen, 2 * sincLen + 1, beta);

    #pragma omp parallel for
    for (int k = -sincLen; k <= sincLen - 3; k += 4) {
        __m256d k_vec = _mm256_set_pd(k + 3, k + 2, k + 1, k);
        __m256d sincLen_vec = _mm256_set1_pd(sincLen);
        __m256d index_vec = _mm256_add_pd(k_vec, sincLen_vec);

        double kaiser_vals[4];
        for (int i = 0; i < 4; ++i) {
            kaiser_vals[i] = kaiserWindow(static_cast<int>(index_vec[i]), 2 * sincLen + 1, beta) / collectWindowGain;
        }

        _mm256_storeu_pd(&kaiserWindowTable[k + sincLen], _mm256_loadu_pd(kaiser_vals));
    }

    for (int k = sincLen - ((2 * sincLen + 1) % 4 - 1); k <= sincLen; ++k) {
        kaiserWindowTable[k + sincLen] = kaiserWindow(k + sincLen, 2 * sincLen + 1, beta) / collectWindowGain;
    }

    kaiserWindowTable[sincLen] = 1.0;

    return kaiserWindowTable;
}

double sinc(double x) {
    if (x == 0.0) {
        return 1.0;
    }
    return sin(M_PI * x) / (M_PI * x);
}


std::vector<std::vector<double>> precomputedSinc(int sincLen, int oversamplingFactor , std::vector<double> kaiserWindowTable) {

    std::vector<std::vector<double>> sincTable(oversamplingFactor, std::vector<double>(2 * sincLen +1, 0.0));

    #pragma omp parallel for
    for (int i = 1; i < oversamplingFactor; ++i) {
        double resultPos = static_cast<double>(i) / static_cast<double>(oversamplingFactor);
        
        __m256d resultPosVec = _mm256_set1_pd(resultPos);
        
        for (int k = -sincLen; k <= sincLen; k += 4) {
            __m256d targetPosVec = _mm256_set_pd(static_cast<int>(resultPos) + k + 3, static_cast<int>(resultPos) + k + 2, static_cast<int>(resultPos) + k + 1, static_cast<int>(resultPos) + k);
            __m256d sampleGapVec = _mm256_sub_pd(resultPosVec, targetPosVec);

            __m256d sincValVec = _mm256_set_pd(sinc(sampleGapVec[3]), sinc(sampleGapVec[2]), sinc(sampleGapVec[1]), sinc(sampleGapVec[0]));

            __m256d kaiserWindowVec = _mm256_loadu_pd(&kaiserWindowTable[k + sincLen]);
            __m256d resultVec = _mm256_mul_pd(sincValVec, kaiserWindowVec);

            _mm256_storeu_pd(&sincTable[i][k + sincLen], resultVec);
        }


        for (int k = sincLen - ((2 * sincLen +1) % 4 - 1); k <= sincLen; ++k) {
            int targetPos = static_cast<int>(resultPos) + k;
            double sampleGap = resultPos - targetPos;
            sincTable[i][k + sincLen] = sinc(sampleGap) * kaiserWindowTable[k + sincLen];
        }


    }

    return sincTable;
}

double reduce_max(__m256d vec) {
    double result[4];
    _mm256_storeu_pd(result, vec);
    double maxVal = result[0];
    for (int i = 1; i < 4; ++i) {
        if (result[i] > maxVal) {
            maxVal = result[i];
        }
    }
    return maxVal;
}

void resampleAudio(const char *inputFile, const char *outputFile, int oversamplingFactor, int sincLen, int numThreads) {
    auto startTime = std::chrono::high_resolution_clock::now();

    std::cout << "データ読み込み中..." << std::endl;

    SF_INFO sfinfo;
    SNDFILE *infile = sf_open(inputFile, SFM_READ, &sfinfo);
    if (!infile) {
        std::cerr << "入力ファイルを開けませんでした: " << inputFile << std::endl;
        return;
    }

    std::vector<double> inputData(sfinfo.frames * sfinfo.channels);
    sf_read_double(infile, inputData.data(), inputData.size());
    sf_close(infile);

    std::cout << "リサンプリング中..." << std::endl;

    int outputFrames = sfinfo.frames * oversamplingFactor;
    std::vector<double> outputData(outputFrames * sfinfo.channels);

    // 進捗ステータス
    std::atomic<int> progressCounter(0);
    std::atomic<bool> done(false);
    std::mutex progressMutex;
    std::thread progressThread([&]() {
        while (!done) {
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
            {
                std::lock_guard<std::mutex> lock(progressMutex);
                auto currentTime = std::chrono::high_resolution_clock::now();
                std::chrono::duration<double> elapsed = currentTime - startTime;
                double progress = 100.0 * progressCounter / outputFrames;

                double remainingTime;
                if (progressCounter > 0) remainingTime = (elapsed.count() / progressCounter) * (outputFrames - progressCounter);
                else                     remainingTime = std::numeric_limits<double>::infinity();

                std::string timeUnit;
                if (remainingTime >= 86400) {
                    remainingTime /= 86400;
                    timeUnit = "日";
                } else if (remainingTime >= 3600) {
                    remainingTime /= 3600;
                    timeUnit = "時間";
                } else if (remainingTime >= 60) {
                    remainingTime /= 60;
                    timeUnit = "分";
                } else {
                    timeUnit = "秒";
                }

                std::cout << "\rリサンプリング進行中: " << std::fixed << std::setprecision(2) << progress << " %"
                          << " 残り時間: " << std::setprecision(2) << remainingTime << " " << timeUnit << std::flush;
            }
        }
        {
            std::lock_guard<std::mutex> lock(progressMutex);
            std::cout << "\rリサンプリング進行中: 100 %" << std::endl;
        }
    });


    std::vector<double> kaiserWindowTable = precomputedKaiserWindow(sincLen, 23);
    std::vector<std::vector<double>> sincTable = precomputedSinc(sincLen,oversamplingFactor, kaiserWindowTable);

    #pragma omp parallel for
    for (int t = 0; t < outputFrames; ++t) {
        double resultPos = static_cast<double>(t) / static_cast<double>(oversamplingFactor);
        int i = t % oversamplingFactor;
        for (int c = 0; c < sfinfo.channels; ++c) {
            if (i == 0) {
                outputData[t * sfinfo.channels + c] = inputData[static_cast<int>(resultPos) * sfinfo.channels + c];
            } else {
                double sum = 0.0;
                __m256d sumVec = _mm256_setzero_pd();

                for (int k = -sincLen; k <= sincLen; k += 4) {
                    int targetPos1 = static_cast<int>(resultPos) + k;
                    int targetPos2 = static_cast<int>(resultPos) + k + 1;
                    int targetPos3 = static_cast<int>(resultPos) + k + 2;
                    int targetPos4 = static_cast<int>(resultPos) + k + 3;

                    if (targetPos1 >= 0 && targetPos1 < sfinfo.frames &&
                        targetPos2 >= 0 && targetPos2 < sfinfo.frames &&
                        targetPos3 >= 0 && targetPos3 < sfinfo.frames &&
                        targetPos4 >= 0 && targetPos4 < sfinfo.frames) {
                        __m256d inputVec = _mm256_set_pd(
                            inputData[targetPos4 * sfinfo.channels + c],
                            inputData[targetPos3 * sfinfo.channels + c],
                            inputData[targetPos2 * sfinfo.channels + c],
                            inputData[targetPos1 * sfinfo.channels + c]
                        );

                        __m256d sincVec = _mm256_set_pd(
                            sincTable[i][k + 3 + sincLen],
                            sincTable[i][k + 2 + sincLen],
                            sincTable[i][k + 1 + sincLen],
                            sincTable[i][k + sincLen]
                        );

                        __m256d prodVec = _mm256_mul_pd(inputVec, sincVec);
                        sumVec = _mm256_add_pd(sumVec, prodVec);
                    }
                }

                for (int k = sincLen - ((2 * sincLen + 1) % 4 - 1); k <= sincLen; ++k) {
                    int targetPos = static_cast<int>(resultPos) + k;
                    if (targetPos >= 0 && targetPos < sfinfo.frames) {
                        sum += inputData[targetPos * sfinfo.channels + c] * sincTable[i][k + sincLen];
                    }
                }

                double sumArray[4];
                _mm256_storeu_pd(sumArray, sumVec);
                sum += sumArray[0] + sumArray[1] + sumArray[2] + sumArray[3];

                outputData[t * sfinfo.channels + c] = sum;
            }
        }
        // 進捗を更新
        {
            std::lock_guard<std::mutex> lock(progressMutex);
            ++progressCounter;
        }
    }


    done = true;
    progressThread.join();

    std::cout << "正規化中..." << std::endl;

    double maxAmplitude = 0.0;


    #pragma omp parallel
    {
        double localMax = 0.0;

        #pragma omp for
        for (size_t i = 0; i < outputData.size(); i += 4) {
            __m256d samples = _mm256_loadu_pd(&outputData[i]);
            __m256d abs_samples = _mm256_andnot_pd(_mm256_set1_pd(-0.0), samples); // fabsのSIMDバージョン
            double maxVal = reduce_max(abs_samples);
            localMax = std::max(localMax, maxVal);
        }

        #pragma omp critical
        {
            maxAmplitude = std::max(maxAmplitude, localMax);
        }
    }

    size_t remainderStart = (outputData.size() / 4) * 4;
    for (size_t i = remainderStart; i < outputData.size(); ++i) {
        double sample = std::abs(outputData[i]);
        maxAmplitude = std::max(maxAmplitude, sample);
    }

    if (maxAmplitude > 0.0) {
        #pragma omp parallel for
        for (size_t i = 0; i < outputData.size(); i += 4) {
            __m256d samples = _mm256_loadu_pd(&outputData[i]);
            samples = _mm256_div_pd(samples, _mm256_set1_pd(maxAmplitude));
            _mm256_storeu_pd(&outputData[i], samples);
        }

        // 4で割り切れない残りの部分を処理
        for (size_t i = remainderStart; i < outputData.size(); ++i) {
            outputData[i] /= maxAmplitude;
        }
    }


    std::cout << "データ書き込み中..." << std::endl;
    SF_INFO outsfinfo;
    outsfinfo.samplerate = sfinfo.samplerate * oversamplingFactor;
    outsfinfo.channels = sfinfo.channels;
    // outsfinfo.format = sfinfo.format;
    outsfinfo.format = SF_FORMAT_WAV | SF_FORMAT_PCM_32;

    SNDFILE *outfile = sf_open(outputFile, SFM_WRITE, &outsfinfo);
    if (!outfile) {
        std::cerr << "出力ファイルを開けませんでした: " << outputFile << std::endl;
        return;
    }

    sf_write_double(outfile, outputData.data(), outputData.size());
    sf_close(outfile);

    std::cout << "リサンプリング完了!" << std::endl;
}

int main(int argc, char *argv[]) {
    if (argc < 3) {
        std::cerr << "使用方法: " << argv[0] << " input.wav output.wav -o <オーバーサンプリング係数> -l <sinc_len> -t <並列処理数>" << std::endl;
        return 1;
    }

    const char *inputFile = argv[1];
    const char *outputFile = argv[2];
    int oversamplingFactor = 2;
    int sincLen = 128;
    int numThreads = 1;

    for (int i = 3; i < argc; i++) {
        if (std::strcmp(argv[i], "-o") == 0) {
            oversamplingFactor = std::atoi(argv[++i]);
        } else if (std::strcmp(argv[i], "-l") == 0) {
            sincLen = std::atoi(argv[++i]);
        } else if (std::strcmp(argv[i], "-t") == 0) {
            numThreads = std::atoi(argv[++i]);
        }
    }

    omp_set_num_threads(numThreads);
    resampleAudio(inputFile, outputFile, oversamplingFactor, sincLen, numThreads);

    return 0;
}

コンパイル方法

resample.cppを置いてあるディレクトリでターミナルを起動
もしくはcdで飛ぶ

g++ -O3 -march-native -funroll-loops -fopenmp -lsndfile -o resample resample.cpp

実行

変換したいファイルがinput.wavで、出力したいファイルがoutput.wavのとき

./resample input.wav output.wav -l 1024 -o 2 -t 16

-l  sinc_lenを設定(初期値128)
-o  オーバーサンプリング倍率(初期値2)
-t  マルチスレッド数(初期値1)

メモ

AVX2に対応してるCPUなら動くと思う。
アップサンプリング専用でダウンサンプリングはできない。
アンチエイリアスフィルタは載ってない。 必要ないらしい
動作も遅い。
 少し早くなった

Debianで動作確認。

カイザー窓のα=5を使ってる 実質βでした。 β = 23にしました
sinc_len=136608で実行すればリンギングとさよならできるはず(多分)

オーディオファイルはただの配列だった。

アップサンプリングされた音源とNOS DACが流行していないのが理解できなくなった。

雑にオーディオやるならエイリアスが残ってる音源は88.2KHz~96KHzのリアルタイム整数倍オーバーサンプリングがCPUのリアルタイムな負荷としてバランス良さそう?
DSPを通す関係でアップサンプリングしてるだけなので、再生系の歪みを取らなくていい人は基本的にリサンプリングを無効にしてオリジナルのサンプリングレートでビット深度だけ高い環境で再生するのがベスト。
そもそもDSPを通して再生系の歪をある程度抑えないと、他の歪みに埋もれて違いが聞き分けられない可能性高い。

CUDA検討。 別記事にCUDA対応バージョン


いいなと思ったら応援しよう!