#include <cassert>
#include <cstddef>
#include <fstream>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <benchmark/benchmark.h>


/// Sparse matrix in compressed-column form, as filled by getTestDataFromMtxFile().
struct TestData {
    /// Column pointers: the entries of column c occupy indices [Lp[c], Lp[c + 1]).
    std::vector<int> Lp;
    /// Zero-based row index of every stored entry.
    std::vector<int> Li;
    /// Numerical value of every stored entry, parallel to Li.
    std::vector<double> Lx;
};

/// Reads the strictly lower triangle of a square sparse matrix from a Matrix
/// Market coordinate file.
///
/// Lines starting with '%' and blank lines are skipped. The first remaining
/// line must contain "rows cols nnz"; each following line contains a 1-based
/// "row col value" entry (a missing value is read as 0). Entries are expected
/// to be ordered by non-decreasing column; the reader does not sort them.
/// Only entries whose row index is greater than the current column index are
/// stored.
///
/// @param filepath  Path of the Matrix Market file to read.
/// @return The stored entries in compressed-column form.
/// @throws std::runtime_error if the file cannot be opened or does not contain
///         a well-formed size line.
TestData getTestDataFromMtxFile(const std::string& filepath) {
    std::ifstream file(filepath);
    if (!file.is_open()) {
        throw std::runtime_error("could not open Matrix Market file: " + filepath);
    }

    TestData data;
    int num_rows = 0;
    int num_cols = 0;
    int nnz = 0;
    bool size_line_seen = false;
    int current_col = -1;

    std::string line;
    while (std::getline(file, line)) {
        // Skip blank lines (including a lone '\r' from CRLF files) and comments.
        if (line.find_first_not_of(" \t\r") == std::string::npos || line[0] == '%') {
            continue;
        }

        std::istringstream stream(line);

        if (!size_line_seen) {
            if (!(stream >> num_rows >> num_cols >> nnz) || num_cols < 0 || nnz < 0) {
                throw std::runtime_error("malformed size line in Matrix Market file: " + filepath);
            }
            assert(num_rows == num_cols);
            data.Lp.reserve(static_cast<std::size_t>(num_cols) + 1);
            data.Li.reserve(static_cast<std::size_t>(nnz));
            data.Lx.reserve(static_cast<std::size_t>(nnz));
            size_line_seen = true;
            continue;
        }

        int row = 0;
        int col = 0;
        double val = 0.0;
        if (!(stream >> row >> col)) {
            // Not an entry line; ignoring it avoids working with unparsed indices.
            continue;
        }
        if (!(stream >> val)) {
            val = 0.0;
        }
        // Matrix Market indices are 1-based.
        row -= 1;
        col -= 1;

        // Open every column up to and including the one this entry belongs to.
        while (col > current_col) {
            data.Lp.push_back(static_cast<int>(data.Li.size()));
            current_col += 1;
        }
        // Keep only entries strictly below the diagonal of the current column.
        if (row > current_col) {
            data.Li.push_back(row);
            data.Lx.push_back(val);
        }
    }

    if (!size_line_seen) {
        throw std::runtime_error("missing size line in Matrix Market file: " + filepath);
    }

    // Close the remaining (empty) columns and append the final end pointer.
    while (current_col < num_cols) {
        data.Lp.push_back(static_cast<int>(data.Li.size()));
        current_col += 1;
    }
    assert(data.Lp.size() == static_cast<std::size_t>(num_cols) + 1);
    assert(data.Li.size() == static_cast<std::size_t>(nnz));
    assert(data.Lx.size() == static_cast<std::size_t>(nnz));

    return data;
}

/// Straightforward reference kernel for the back-solve benchmark.
struct BacksolveBaseline {
    /// Walks the columns from n - 1 down to 0 and, for every stored entry k of
    /// column c, subtracts Lx[k] * x[Li[k]] from x[c] in place.
    ///
    /// @param Lp  Column pointers, n + 1 values.
    /// @param Li  Row index of each stored entry.
    /// @param Lx  Value of each stored entry.
    /// @param n   Number of columns.
    /// @param x   Vector updated in place.
    ///
    /// The __restrict__ qualifiers promise the compiler that the arrays do not
    /// overlap.
    void operator()(const int* __restrict__ Lp,
                    const int* __restrict__ Li,
                    const double* __restrict__ Lx,
                    const int n,
                    double* __restrict__ x) {
        int col = n;
        while (col > 0) {
            --col;
            const int stop = Lp[col + 1];
            int k = Lp[col];
            while (k < stop) {
                x[col] -= Lx[k] * x[Li[k]];
                ++k;
            }
        }
    }
};

/// Hand-unrolled variant of the back-solve kernel.
struct BacksolveOptimized {
    /// Walks the columns from n - 1 down to 0. For each column the running
    /// value of x[col] is kept in a local accumulator, the stored entries are
    /// consumed from the last to the first, two per iteration, and the result
    /// is written back once at the end.
    ///
    /// @param Lp  Column pointers, n + 1 values.
    /// @param Li  Row index of each stored entry.
    /// @param Lx  Value of each stored entry.
    /// @param n   Number of columns.
    /// @param x   Vector updated in place.
    ///
    /// The __restrict__ qualifiers promise the compiler that the arrays do not
    /// overlap.
    void operator()(const int* __restrict__ Lp,
                    const int* __restrict__ Li,
                    const double* __restrict__ Lx,
                    const int n,
                    double* __restrict__ x) {
        for (int col = n - 1; col >= 0; --col) {
            const int first = Lp[col];
            const int past_last = Lp[col + 1];
            double acc = x[col];
            int k = past_last - 1;

            // With an odd entry count, peel off the last entry so that the
            // remaining ones can be handled strictly in pairs.
            if (((past_last - first) & 1) != 0) {
                acc -= Lx[k] * x[Li[k]];
                k -= 1;
            }

            // Two entries per step; the sum is formed first, then subtracted.
            while (k >= first) {
                acc -= Lx[k] * x[Li[k]] + Lx[k - 1] * x[Li[k - 1]];
                k -= 2;
            }

            x[col] = acc;
        }
    }
};

/// Fills x with consecutive values starting at -size/2 and increasing by one,
/// e.g. {-2, -1, 0, 1} for four elements. Serves as a cheap stand-in for a
/// real right-hand-side computation.
///
/// @param x  Vector to overwrite; its size is left unchanged.
void fillRhs(std::vector<double>& x) {
    // Convert to double BEFORE negating: `-x.size()` negates an unsigned value,
    // which wraps around to a huge positive number instead of going negative.
    const double first_value = -(static_cast<double>(x.size()) / 2.0);
    std::iota(x.begin(), x.end(), first_value);
}

/// Measures fillRhs() on its own, using a vector with one element per column
/// of the test matrix. benchBacksolve() calls fillRhs() inside its timed loop
/// as well, so this timing shows how much of that measurement is setup.
void benchFillRhs(benchmark::State& state) {
    const auto matrix = getTestDataFromMtxFile("../reorientation_6.mtx");
    // Lp holds one pointer per column plus a trailing end pointer.
    std::vector<double> rhs(matrix.Lp.size() - 1);

    for (auto _ : state) {
        fillRhs(rhs);
        // Keep the compiler from discarding the writes to rhs.
        benchmark::DoNotOptimize(rhs.data());
        benchmark::ClobberMemory();
    }
}

/// Times one back-solve kernel on the test matrix.
///
/// @tparam BacksolveFunction  Default-constructible functor callable as
///         f(Lp, Li, Lx, n, x), e.g. BacksolveBaseline or BacksolveOptimized.
///
/// Every timed iteration refills the vector with fillRhs() before running the
/// kernel, so the reported time includes that refill.
template <class BacksolveFunction>
void benchBacksolve(benchmark::State& state) {
    const auto matrix = getTestDataFromMtxFile("../reorientation_6.mtx");
    // Lp holds one pointer per column plus a trailing end pointer.
    const auto num_cols = matrix.Lp.size() - 1;
    std::vector<double> rhs(num_cols);

    const auto* const col_ptrs = matrix.Lp.data();
    const auto* const row_idx = matrix.Li.data();
    const auto* const values = matrix.Lx.data();

    BacksolveFunction kernel;
    for (auto _ : state) {
        fillRhs(rhs);
        kernel(col_ptrs, row_idx, values, num_cols, rhs.data());
        // Keep the compiler from discarding the kernel's writes to rhs.
        benchmark::DoNotOptimize(rhs.data());
        benchmark::ClobberMemory();
    }
}

/// Non-template entry point that runs benchBacksolve() with the
/// BacksolveBaseline kernel, so it can be passed to BENCHMARK() by name.
void benchBacksolveBaseline(benchmark::State& state) {
    benchBacksolve<BacksolveBaseline>(state);
}

/// Non-template entry point that runs benchBacksolve() with the
/// BacksolveOptimized kernel, so it can be passed to BENCHMARK() by name.
void benchBacksolveOptimized(benchmark::State& state) {
    benchBacksolve<BacksolveOptimized>(state);
}

// Register the three benchmark functions with Google Benchmark.
// benchFillRhs measures the setup cost that both back-solve benchmarks also
// include in their timed loop.
BENCHMARK(benchFillRhs);
BENCHMARK(benchBacksolveBaseline);
BENCHMARK(benchBacksolveOptimized);
// Supplies the program's main() via the Google Benchmark macro.
BENCHMARK_MAIN();