// standard
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// POSIX (ssize_t, popen/pclose)
#include <sys/types.h>

// pybind11
#include <pybind11/chrono.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

// NVML
#include <nvml.h>

// Wall-clock time point type used for snapshot timestamps. Spelled through the
// public alias: std::chrono::_V2 is a libstdc++-internal inline namespace that
// libc++ and MSVC's standard library do not provide.
using time_point = std::chrono::system_clock::time_point;

// Fallback definitions, used only when nvml.h does not already provide these
// macros (presumably older NVML headers — confirm against the minimum supported
// NVML version).
#ifndef NVML_DEVICE_NAME_V2_BUFFER_SIZE
#define NVML_DEVICE_NAME_V2_BUFFER_SIZE 96
#endif

// The raw CUDA driver version is encoded as major * 1000 + minor * 10.
#ifndef NVML_CUDA_DRIVER_VERSION_MAJOR
#define NVML_CUDA_DRIVER_VERSION_MAJOR(v) ((v) / 1000)
#endif

#ifndef NVML_CUDA_DRIVER_VERSION_MINOR
#define NVML_CUDA_DRIVER_VERSION_MINOR(v) (((v) % 1000) / 10)
#endif

// Buffer size used when reading a process's command line via `ps`.
#ifndef NVML_SYSTEM_PROCESS_NAME_BUFFER_SIZE
#define NVML_SYSTEM_PROCESS_NAME_BUFFER_SIZE 120
#endif

/** Host-wide information shared by every GPU entry of a snapshot. */
struct Header
{
    // Raw value from nvmlSystemGetCudaDriverVersion, encoded as
    // major * 1000 + minor * 10; zero until it has been queried.
    int32_t cuda_version{0};
    std::string driver_version;
    std::string nvml_version;
    // Wall-clock time at which the snapshot was started.
    time_point timestamp;
};

/**
 * Current clock readings of one GPU, as returned by nvmlDeviceGetClockInfo.
 * Members are zero-initialized so a default-constructed Clock never holds
 * indeterminate values.
 */
struct Clock
{
    uint32_t graphics{0};
    uint32_t sm{0};
    uint32_t memory{0};
    uint32_t video{0};
};

/**
 * Raw NVML power readings of one GPU (per the NVML API documentation,
 * milliwatts for usage/limit and millijoules for energy; getPowers() scales
 * all three by 1e-3). Zero-initialized so a default-constructed Power never
 * holds indeterminate values.
 */
struct Power
{
    uint32_t usage{0};
    uint32_t limit{0};
    unsigned long long energy{0};
};

/**
 * Temperature readings of one GPU as reported by NVML: current board (GPU)
 * temperature plus the shutdown and slowdown thresholds. Zero-initialized so a
 * default-constructed Temperature never holds indeterminate values.
 */
struct Temperature
{
    uint32_t board{0};
    uint32_t thresh_shutdown{0};
    uint32_t thresh_slowdown{0};
};

/** One process running on a GPU. */
struct Process
{
    /**
     * @param name   Command line of the process ("Unknown" if it could not be read).
     * @param pid    Process id reported by NVML.
     * @param memory GPU memory used by the process, or -1 when NVML's value is
     *               larger than the device total (i.e. not meaningful).
     */
    Process(std::string name, uint32_t pid, ssize_t memory) : name(std::move(name)), pid(pid), memory(memory)
    {
        // `name` is a by-value sink parameter: move it instead of copying it again.
    }
    std::string name;
    uint32_t pid;
    ssize_t memory;
};

/**
 * Snapshot of a single GPU, filled in by NvmlWrapper::query_single().
 * Scalar and C-struct members are value-initialized so that a Stat that has
 * not been (fully) populated holds zeros rather than indeterminate values.
 */
struct Stat
{
    std::string gpu_name;
    nvmlUtilization_t utilization{};
    nvmlMemory_t memory{};
    Clock clock{};
    Power power{};
    Temperature temperature{};
    uint32_t fan_speed{0};
    std::vector<Process> compute_procs;
    std::vector<Process> graphics_procs;
};

/**
 * A complete snapshot: the host-wide header followed by one Stat entry per
 * monitored GPU, in the order the devices were queried.
 */
struct Stats
{
    Header header;           // versions + snapshot timestamp
    std::vector<Stat> stats; // per-GPU readings
};

class NvmlWrapper
{
  public:
    NvmlWrapper(ssize_t gpu_id)
    {
        /* Initialize NVML */
        throwException(nvmlInitWithFlags(NVML_INIT_FLAG_NO_GPUS), "Failed to initialize nvml");

        /* Get number of gpus */
        throwException(nvmlDeviceGetCount(&num_gpus_), "Failed to get number of gpus");

        if ((-1 < gpu_id) && (gpu_id < num_gpus_))
        {
            gpu_id_ = gpu_id;
            stats_.stats.resize(1);
        }
        else
        {
            gpu_id_ = -1;
            stats_.stats.resize(num_gpus_);
        }

        query(true);
    }

    ~NvmlWrapper()
    {
        /* shutdown NVML */
        throwException(nvmlShutdown(), "Failed to shutdown nvml");
    }

    void query(const bool init = false)
    {
        if (init)
        {
            /* Get cuda version */
            throwException(nvmlSystemGetCudaDriverVersion(&stats_.header.cuda_version), "Failed to get cuda version");

            /* Get driver version */
            char driver_version[NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE];
            throwException(nvmlSystemGetDriverVersion(driver_version, NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE),
                           "Failed to get cuda version");
            stats_.header.driver_version = std::string(driver_version);

            /* Get nvml version */
            char nvml_version[NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE];
            throwException(nvmlSystemGetNVMLVersion(nvml_version, NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE),
                           "Failed to get nvml version");
            stats_.header.nvml_version = std::string(nvml_version);
        }

        stats_.header.timestamp = std::chrono::system_clock::now();

        stats_.stats.clear();

        if (gpu_id_ >= 0)
        {
            stats_.stats.emplace_back(query_single(gpu_id_));
        }
        else
        {
            for (uint32_t i = 0; i < num_gpus_; i++)
            {
                stats_.stats.emplace_back(query_single(i));
            }
        }
    }

    std::string getCudaVersion()
    {
        return std::to_string(NVML_CUDA_DRIVER_VERSION_MAJOR(stats_.header.cuda_version)) + "." +
               std::to_string(NVML_CUDA_DRIVER_VERSION_MINOR(stats_.header.cuda_version));
    }

    std::string getDriverVersion()
    {
        return std::string(stats_.header.driver_version);
    }

    std::string getNvmlVersion()
    {
        return std::string(stats_.header.nvml_version);
    }

    time_point getTimeStamp()
    {
        return stats_.header.timestamp;
    }

    std::vector<std::string> getGpuNames()
    {
        std::vector<std::string> gpu_names;

        for (const auto &stat : stats_.stats)
        {
            gpu_names.emplace_back(stat.gpu_name);
        }

        return gpu_names;
    }

    std::vector<std::map<std::string, uint32_t>> getUtilizations()
    {
        std::vector<std::map<std::string, uint32_t>> utilizations;

        for (const auto &stat : stats_.stats)
        {
            std::map<std::string, uint32_t> utilization{{"gpu", stat.utilization.gpu},
                                                        {"memory", stat.utilization.memory}};
            utilizations.emplace_back(utilization);
        }

        return utilizations;
    }

    std::vector<std::map<std::string, uint64_t>> getMemories()
    {
        std::vector<std::map<std::string, uint64_t>> memories;

        for (const auto &stat : stats_.stats)
        {
            std::map<std::string, uint64_t> memory{
                {"free", stat.memory.free},
                {"used", stat.memory.used},
                {"total", stat.memory.total},
            };

            memories.emplace_back(memory);
        }

        return memories;
    }

    std::vector<std::map<std::string, uint32_t>> getClocks()
    {
        std::vector<std::map<std::string, uint32_t>> clocks;

        for (const auto &stat : stats_.stats)
        {
            std::map<std::string, uint32_t> clock{
                {"graphics", stat.clock.graphics},
                {"memory", stat.clock.memory},
                {"sm", stat.clock.sm},
                {"video", stat.clock.video},
            };

            clocks.emplace_back(clock);
        }

        return clocks;
    }

    std::vector<std::map<std::string, double>> getPowers()
    {
        std::vector<std::map<std::string, double>> powers;

        for (const auto &stat : stats_.stats)
        {
            std::map<std::string, double> power{
                {"limit", stat.power.limit * 1e-03},
                {"usage", stat.power.usage * 1e-03},
                {"energy", stat.power.energy * 1e-03},
            };

            powers.emplace_back(power);
        }

        return powers;
    }

    std::vector<std::map<std::string, uint32_t>> getTemperatures()
    {
        std::vector<std::map<std::string, uint32_t>> temperatures;

        for (const auto &stat : stats_.stats)
        {
            std::map<std::string, uint32_t> temperature{
                {"board", stat.temperature.board},
                {"thresh_shutdown", stat.temperature.thresh_shutdown},
                {"thresh_slowdown", stat.temperature.thresh_slowdown},
            };

            temperatures.emplace_back(temperature);
        }

        return temperatures;
    }

    std::vector<std::map<std::string, uint32_t>> getFanSpeeds()
    {
        std::vector<std::map<std::string, uint32_t>> fan_speeds;

        for (const auto &stat : stats_.stats)
        {
            std::map<std::string, uint32_t> fan_speed{
                {"speed", stat.fan_speed},
            };

            fan_speeds.emplace_back(fan_speed);
        }

        return fan_speeds;
    }

    std::vector<std::vector<std::tuple<std::string, uint32_t, ssize_t>>> getComputeProcs()
    {
        std::vector<std::vector<std::tuple<std::string, uint32_t, ssize_t>>> all_procs;

        for (const auto &stat : stats_.stats)
        {
            std::vector<std::tuple<std::string, uint32_t, ssize_t>> procs;
            for (const auto &proc : stat.compute_procs)
            {
                std::tuple<std::string, uint32_t, ssize_t> proc_tuple{proc.name, proc.pid, proc.memory};
                procs.emplace_back(proc_tuple);
            }

            all_procs.emplace_back(procs);
        }

        return all_procs;
    }

    std::vector<std::vector<std::tuple<std::string, uint32_t, ssize_t>>> getGraphicsProcs()
    {
        std::vector<std::vector<std::tuple<std::string, uint32_t, ssize_t>>> all_procs;

        for (const auto &stat : stats_.stats)
        {
            std::vector<std::tuple<std::string, uint32_t, ssize_t>> procs;
            for (const auto &proc : stat.graphics_procs)
            {
                std::tuple<std::string, uint32_t, ssize_t> proc_tuple{proc.name, proc.pid, proc.memory};
                procs.emplace_back(proc_tuple);
            }

            all_procs.emplace_back(procs);
        }

        return all_procs;
    }

  private:
    ssize_t gpu_id_;
    uint32_t num_gpus_;
    Stats stats_;

    Stat query_single(const int gpu_id)
    {
        Stat stat;

        /* Get device */
        nvmlDevice_t device;
        throwException(nvmlDeviceGetHandleByIndex_v2(gpu_id, &device), "Failed to get device handle");

        /* Get gpu name */
        char gpu_name[NVML_DEVICE_NAME_V2_BUFFER_SIZE];
        throwException(nvmlDeviceGetName(device, gpu_name, NVML_DEVICE_NAME_V2_BUFFER_SIZE),
                        "Failed to get gpu name");
        stat.gpu_name = std::string(gpu_name);

        /* Get GPU utilization */
        throwException(nvmlDeviceGetUtilizationRates(device, &stat.utilization), "Failed to get utilization rates");

        /* Get memory info */
        throwException(nvmlDeviceGetMemoryInfo(device, &stat.memory), "Failed to get device memory info");
        uint64_t total_memory = stat.memory.total;

        /* Get clock info */
        throwException(nvmlDeviceGetClockInfo(device, nvmlClockType_t::NVML_CLOCK_GRAPHICS, &stat.clock.graphics),
                       "Failed to get graphics clock info");
        throwException(nvmlDeviceGetClockInfo(device, nvmlClockType_t::NVML_CLOCK_SM, &stat.clock.sm),
                       "Failed to get device sm info");
        throwException(nvmlDeviceGetClockInfo(device, nvmlClockType_t::NVML_CLOCK_MEM, &stat.clock.memory),
                       "Failed to get device memory info");
        throwException(nvmlDeviceGetClockInfo(device, nvmlClockType_t::NVML_CLOCK_VIDEO, &stat.clock.video),
                       "Failed to get device video info");

        /* Get power usage */
        throwException(nvmlDeviceGetPowerUsage(device, &stat.power.usage), "Failed to get power usage");

        /* Get power limit */
        throwException(nvmlDeviceGetPowerManagementLimit(device, &stat.power.limit), "Failed to get power limit");

        /* Get total energy consumption */
        throwException(nvmlDeviceGetTotalEnergyConsumption(device, &stat.power.energy),
                       "Failed to get total energy consumption");

        /* Get gpu temperature */
        throwException(
            nvmlDeviceGetTemperature(device, nvmlTemperatureSensors_t::NVML_TEMPERATURE_GPU, &stat.temperature.board),
            "Failed to get gpu temperature");

        /* Get gpu thermal throttling temperature */
        throwException(
            nvmlDeviceGetTemperatureThreshold(device, nvmlTemperatureThresholds_t::NVML_TEMPERATURE_THRESHOLD_SHUTDOWN,
                                              &stat.temperature.thresh_shutdown),
            "Failed to get gpu shutdown temperature");

        throwException(
            nvmlDeviceGetTemperatureThreshold(device, nvmlTemperatureThresholds_t::NVML_TEMPERATURE_THRESHOLD_SLOWDOWN,
                                              &stat.temperature.thresh_slowdown),
            "Failed to get gpu slowdown temperature");

        /* Get fan speed */
        throwException(nvmlDeviceGetFanSpeed(device, &stat.fan_speed), "Failed to get fan speed");

        /* Get compute process */
        uint32_t num_compute_procs = 0;
        auto compute_result = nvmlDeviceGetComputeRunningProcesses(device, &num_compute_procs, nullptr);
        if (compute_result == NVML_ERROR_INSUFFICIENT_SIZE)
        {
            num_compute_procs = num_compute_procs * 2 + 5;
            nvmlProcessInfo_t procs[num_compute_procs];
            throwException(nvmlDeviceGetComputeRunningProcesses(device, &num_compute_procs, procs),
                           "Failed to get compute running process");

            for (uint32_t proc_id = 0; proc_id < num_compute_procs; proc_id++)
            {
                ssize_t used_memory = -1;
                if (procs[proc_id].usedGpuMemory <= total_memory)
                {
                    used_memory = static_cast<ssize_t>(procs[proc_id].usedGpuMemory);
                }
                stat.compute_procs.emplace_back(
                    Process(getProcessName(procs[proc_id].pid), procs[proc_id].pid, used_memory));
            }
        }

        /* Get graphics process */
        uint32_t num_graphics_procs = 0;
        auto graphics_result = nvmlDeviceGetGraphicsRunningProcesses(device, &num_graphics_procs, nullptr);
        if (graphics_result == NVML_ERROR_INSUFFICIENT_SIZE)
        {
            num_graphics_procs = num_graphics_procs * 2 + 5;
            nvmlProcessInfo_t procs[num_graphics_procs];
            throwException(nvmlDeviceGetGraphicsRunningProcesses(device, &num_graphics_procs, procs),
                           "Failed to get graphics running process");

            for (uint32_t proc_id = 0; proc_id < num_graphics_procs; proc_id++)
            {
                ssize_t used_memory = -1;
                if (procs[proc_id].usedGpuMemory <= total_memory)
                {
                    used_memory = static_cast<ssize_t>(procs[proc_id].usedGpuMemory);
                }
                stat.graphics_procs.emplace_back(
                    Process(getProcessName(procs[proc_id].pid), procs[proc_id].pid, used_memory));
            }
        }

        return stat;
    }

    void throwException(const nvmlReturn_t &result, const std::string &error_msg)
    {
        if (result != NVML_SUCCESS)
        {
            throw std::runtime_error(error_msg + ": " + nvmlErrorString(result));
        }
    }

    std::string getProcessName(const uint32_t pid)
    {
        FILE *fp = popen(("ps -fp " + std::to_string(pid) + " -o command=").c_str(), "r");
        if (!fp)
        {
            return "Unknown";
        }

        char buffer[NVML_SYSTEM_PROCESS_NAME_BUFFER_SIZE];
        char *line_p = fgets(buffer, sizeof(buffer), fp);
        if (!line_p)
        {
            return "Unknown";
        }

        pclose(fp);

        return std::string(line_p);
    }
};

/* Python module "librichsmi": exposes NvmlWrapper and its snapshot accessors. */
PYBIND11_MODULE(librichsmi, m)
{
    m.doc() = "Python binding of librichsmi";

    pybind11::class_<NvmlWrapper>(m, "NvmlWrapper")
        .def(pybind11::init<ssize_t>())
        // pybind11 cannot see C++ default arguments; declare init=False here so
        // Python callers can write `wrapper.query()` as well as `query(True)`.
        .def("query", &NvmlWrapper::query, "query GPU infos", pybind11::arg("init") = false)
        .def("get_cuda_version", &NvmlWrapper::getCudaVersion, "get cuda version")
        .def("get_driver_version", &NvmlWrapper::getDriverVersion, "get driver version")
        .def("get_nvml_version", &NvmlWrapper::getNvmlVersion, "get nvml version")
        .def("get_timestamp", &NvmlWrapper::getTimeStamp, "get timestamp")
        .def("get_gpu_names", &NvmlWrapper::getGpuNames, "get gpu names")
        .def("get_utilizations", &NvmlWrapper::getUtilizations, "get utilization info")
        .def("get_memories", &NvmlWrapper::getMemories, "get memory info")
        .def("get_clocks", &NvmlWrapper::getClocks, "get clock info")
        .def("get_powers", &NvmlWrapper::getPowers, "get power info")
        .def("get_temperatures", &NvmlWrapper::getTemperatures, "get temperature info")
        .def("get_fan_speeds", &NvmlWrapper::getFanSpeeds, "get fan speed info")
        .def("get_compute_procs", &NvmlWrapper::getComputeProcs, "get compute procs info")
        .def("get_graphics_procs", &NvmlWrapper::getGraphicsProcs, "get graphics procs info");
}
