From 79be89ad0491bfdd110b2c612e21a0f28c29fa87 Mon Sep 17 00:00:00 2001 From: ilotterytea Date: Wed, 2 Jul 2025 03:31:54 +0500 Subject: feat: MARIADB SUPPORT!!!! --- bot/src/database.hpp | 263 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 bot/src/database.hpp (limited to 'bot/src/database.hpp') diff --git a/bot/src/database.hpp b/bot/src/database.hpp new file mode 100644 index 0000000..2920438 --- /dev/null +++ b/bot/src/database.hpp @@ -0,0 +1,263 @@ +#pragma once + +#ifdef USE_POSTGRES +#include +#elif defined(USE_MARIADB) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "config.hpp" + +namespace bot::db { + using DatabaseRow = std::map; + using DatabaseRows = std::vector; + + struct BaseDatabase { + public: + virtual ~BaseDatabase() = default; + + template + std::vector query_all(const std::string &query) { + return this->query_all(query, {}); + } + + template + std::vector query_all(const std::string &query, + const std::vector ¶ms) { + std::vector results; + + for (DatabaseRow &row : this->exec(query, params)) { + results.push_back(T(row)); + } + + return results; + } + + virtual DatabaseRows exec(const std::string &sql) = 0; + + virtual DatabaseRows exec(const std::string &sql, + const std::vector ¶meters) = 0; + + virtual void close() = 0; + }; + +#ifdef USE_POSTGRES + struct PostgresDatabase : public BaseDatabase { + public: + pqxx::connection conn; + + PostgresDatabase(const std::string &credentials) : conn(credentials) {} + + DatabaseRows exec(const std::string &sql) override { + pqxx::work work(conn); + pqxx::result r = work.exec(sql); + work.commit(); + + std::vector> rows; + for (auto const &row : r) { + std::map m; + for (auto const &f : row) { + m[f.name()] = f.c_str() ? f.c_str() : ""; + } + rows.push_back(m); + } + return rows; + } + + DatabaseRows exec(const std::string &sql, + const std::vector ¶meters) override { + pqxx::work work(conn); + pqxx::result r = work.exec(sql, parameters); + work.commit(); + + std::vector> rows; + for (auto const &row : r) { + std::map m; + for (auto const &f : row) { + m[f.name()] = f.c_str() ? f.c_str() : ""; + } + rows.push_back(m); + } + return rows; + } + + void close() override { conn.close(); } + }; +#endif + +#ifdef USE_MARIADB + struct MariaDatabase : public BaseDatabase { + public: + MYSQL *conn = nullptr; + + MariaDatabase(const Configuration &cfg) : conn(mysql_init(nullptr)) { + if (conn == nullptr) { + throw std::runtime_error("mysql_init() failed"); + } + + if (!mysql_real_connect( + conn, cfg.database.host.c_str(), cfg.database.user.c_str(), + cfg.database.password.c_str(), cfg.database.name.c_str(), + std::stoi(cfg.database.port), nullptr, 0)) { + mysql_close(conn); + throw std::runtime_error("mysql_real_connect() failed"); + } + } + + ~MariaDatabase() { this->close(); } + + DatabaseRows exec(const std::string &sql) override { + std::regex regex(R"(\$[0-9]+)"); + std::string query = std::regex_replace(sql, regex, "?"); + + if (mysql_query(conn, query.c_str())) { + std::string err = std::string(mysql_error(conn)); + mysql_close(conn); + throw std::runtime_error("Query has failed. Error: " + err); + } + + MYSQL_RES *res = mysql_store_result(conn); + if (res == nullptr) { + std::string err = std::string(mysql_error(conn)); + mysql_close(conn); + throw std::runtime_error("mysql_store_result() has failed. Error: " + + err); + } + + int num_fields = mysql_num_fields(res); + MYSQL_FIELD *fields = mysql_fetch_fields(res); + MYSQL_ROW row; + + std::vector> rows; + + while ((row = mysql_fetch_row(res))) { + std::map m; + + for (int i = 0; i < num_fields; i++) { + m[fields[i].name] = row[i] == nullptr ? "" : row[i]; + } + + rows.push_back(std::move(m)); + } + + mysql_free_result(res); + + return rows; + } + + DatabaseRows exec(const std::string &sql, + const std::vector ¶meters) override { + std::regex regex(R"(\$[0-9]+)"); + std::string query = std::regex_replace(sql, regex, "?"); + + MYSQL_STMT *stmt = mysql_stmt_init(conn); + + if (mysql_stmt_prepare(stmt, query.c_str(), query.length())) { + std::string err = std::string(mysql_error(conn)); + mysql_stmt_close(stmt); + throw std::runtime_error("Prepared query has failed. Error: " + err); + } + + // binding input params + std::vector bind_params(parameters.size()); + std::vector lengths(parameters.size()); + for (int i = 0; i < parameters.size(); i++) { + memset(&bind_params[i], 0, sizeof(MYSQL_BIND)); + lengths[i] = parameters[i].size(); + bind_params[i].buffer_type = MYSQL_TYPE_STRING; + bind_params[i].buffer = (void *)parameters[i].c_str(); + bind_params[i].buffer_length = lengths[i]; + bind_params[i].length = &lengths[i]; + bind_params[i].is_null = 0; + } + + if (!parameters.empty() && + mysql_stmt_bind_param(stmt, bind_params.data())) { + std::string err = std::string(mysql_error(conn)); + mysql_stmt_close(stmt); + throw std::runtime_error( + "mysql_stmt_bind_param() has failed. Error: " + err); + } + + if (mysql_stmt_execute(stmt)) { + std::string err = std::string(mysql_error(conn)); + mysql_stmt_close(stmt); + throw std::runtime_error( + "Prepared query execution has failed. Error: " + err); + } + + // metadata + MYSQL_RES *meta = mysql_stmt_result_metadata(stmt); + if (!meta) { + mysql_stmt_close(stmt); + return {}; + } + + int num_fields = mysql_num_fields(meta); + MYSQL_FIELD *fields = mysql_fetch_fields(meta); + + // bind output + std::vector bind_res(num_fields); + std::vector bufs(num_fields); + std::vector lengths_out(num_fields); + std::vector is_null(num_fields); + + for (int i = 0; i < num_fields; i++) { + bufs[i].resize(1024); + memset(&bind_res[i], 0, sizeof(MYSQL_BIND)); + bind_res[i].buffer_type = MYSQL_TYPE_STRING; + bind_res[i].buffer = bufs[i].data(); + bind_res[i].buffer_length = bufs[i].size(); + bind_res[i].length = &lengths_out[i]; + bind_res[i].is_null = &is_null[i]; + } + + if (mysql_stmt_bind_result(stmt, bind_res.data())) { + std::string err = std::string(mysql_error(conn)); + mysql_free_result(meta); + mysql_stmt_close(stmt); + throw std::runtime_error( + "mysql_stmt_bind_result() has failed. Error: " + err); + } + + std::vector> rows; + + while (mysql_stmt_fetch(stmt) == 0) { + std::map m; + + for (int i = 0; i < num_fields; i++) { + m[fields[i].name] = + bufs[i].data() == nullptr + ? "" + : std::string(bufs[i].data(), *bind_res[i].length); + } + + rows.push_back(std::move(m)); + } + + mysql_free_result(meta); + mysql_stmt_close(stmt); + + return rows; + } + + void close() override { + if (!conn) return; + + mysql_close(conn); + conn = nullptr; + } + }; +#endif + + std::unique_ptr create_connection(const Configuration &cfg); +} \ No newline at end of file -- cgit v1.2.3