#pragma once #ifdef USE_POSTGRES #include #elif defined(USE_MARIADB) #include #endif #include #include #include #include #include #include #include #include #include "config.hpp" namespace bot::db { using DatabaseRow = std::map; using DatabaseRows = std::vector; struct BaseDatabase { public: virtual ~BaseDatabase() = default; template std::vector query_all(const std::string &query) { return this->query_all(query, {}); } template std::vector query_all(const std::string &query, const std::vector ¶ms) { std::vector results; for (DatabaseRow &row : this->exec(query, params)) { results.push_back(T(row)); } return results; } virtual DatabaseRows exec(const std::string &sql) = 0; virtual DatabaseRows exec(const std::string &sql, const std::vector ¶meters) = 0; virtual void close() = 0; }; #ifdef USE_POSTGRES struct PostgresDatabase : public BaseDatabase { public: pqxx::connection conn; PostgresDatabase(const std::string &credentials) : conn(credentials) {} DatabaseRows exec(const std::string &sql) override { pqxx::work work(conn); pqxx::result r = work.exec(sql); work.commit(); std::vector> rows; for (auto const &row : r) { std::map m; for (auto const &f : row) { m[f.name()] = f.c_str() ? f.c_str() : ""; } rows.push_back(m); } return rows; } DatabaseRows exec(const std::string &sql, const std::vector ¶meters) override { pqxx::work work(conn); pqxx::result r = work.exec(sql, parameters); work.commit(); std::vector> rows; for (auto const &row : r) { std::map m; for (auto const &f : row) { m[f.name()] = f.c_str() ? f.c_str() : ""; } rows.push_back(m); } return rows; } void close() override { conn.close(); } }; #endif #ifdef USE_MARIADB struct MariaDatabase : public BaseDatabase { public: MYSQL *conn = nullptr; MariaDatabase(const Configuration &cfg) : conn(mysql_init(nullptr)) { if (conn == nullptr) { throw std::runtime_error("mysql_init() failed"); } if (!mysql_real_connect( conn, cfg.database.host.c_str(), cfg.database.user.c_str(), cfg.database.password.c_str(), cfg.database.name.c_str(), std::stoi(cfg.database.port), nullptr, 0)) { mysql_close(conn); throw std::runtime_error("mysql_real_connect() failed"); } } ~MariaDatabase() { this->close(); } DatabaseRows exec(const std::string &sql) override { std::regex regex(R"(\$[0-9]+)"); std::string query = std::regex_replace(sql, regex, "?"); if (mysql_query(conn, query.c_str())) { std::string err = std::string(mysql_error(conn)); mysql_close(conn); throw std::runtime_error("Query has failed. Error: " + err); } MYSQL_RES *res = mysql_store_result(conn); if (res == nullptr) { std::string err = std::string(mysql_error(conn)); mysql_close(conn); throw std::runtime_error("mysql_store_result() has failed. Error: " + err); } int num_fields = mysql_num_fields(res); MYSQL_FIELD *fields = mysql_fetch_fields(res); MYSQL_ROW row; std::vector> rows; while ((row = mysql_fetch_row(res))) { std::map m; for (int i = 0; i < num_fields; i++) { m[fields[i].name] = row[i] == nullptr ? "" : row[i]; } rows.push_back(std::move(m)); } mysql_free_result(res); return rows; } DatabaseRows exec(const std::string &sql, const std::vector ¶meters) override { std::regex regex(R"(\$[0-9]+)"); std::string query = std::regex_replace(sql, regex, "?"); MYSQL_STMT *stmt = mysql_stmt_init(conn); if (mysql_stmt_prepare(stmt, query.c_str(), query.length())) { std::string err = std::string(mysql_error(conn)); mysql_stmt_close(stmt); throw std::runtime_error("Prepared query has failed. Error: " + err); } // binding input params std::vector bind_params(parameters.size()); std::vector lengths(parameters.size()); for (int i = 0; i < parameters.size(); i++) { memset(&bind_params[i], 0, sizeof(MYSQL_BIND)); lengths[i] = parameters[i].size(); bind_params[i].buffer_type = MYSQL_TYPE_STRING; bind_params[i].buffer = (void *)parameters[i].c_str(); bind_params[i].buffer_length = lengths[i]; bind_params[i].length = &lengths[i]; bind_params[i].is_null = 0; } if (!parameters.empty() && mysql_stmt_bind_param(stmt, bind_params.data())) { std::string err = std::string(mysql_error(conn)); mysql_stmt_close(stmt); throw std::runtime_error( "mysql_stmt_bind_param() has failed. Error: " + err); } if (mysql_stmt_execute(stmt)) { std::string err = std::string(mysql_error(conn)); mysql_stmt_close(stmt); throw std::runtime_error( "Prepared query execution has failed. Error: " + err); } // metadata MYSQL_RES *meta = mysql_stmt_result_metadata(stmt); if (!meta) { mysql_stmt_close(stmt); return {}; } int num_fields = mysql_num_fields(meta); MYSQL_FIELD *fields = mysql_fetch_fields(meta); // bind output std::vector bind_res(num_fields); std::vector bufs(num_fields); std::vector lengths_out(num_fields); #if MYSQL_VERSION_ID >= 80000 std::vector is_null(num_fields); #else std::vector is_null(num_fields); #endif for (int i = 0; i < num_fields; i++) { bufs[i].resize(1024); memset(&bind_res[i], 0, sizeof(MYSQL_BIND)); bind_res[i].buffer_type = MYSQL_TYPE_STRING; bind_res[i].buffer = bufs[i].data(); bind_res[i].buffer_length = bufs[i].size(); bind_res[i].length = &lengths_out[i]; #if MYSQL_VERSION_ID >= 80000 bind_res[i].is_null = reinterpret_cast(&is_null[i]); #else bind_res[i].is_null = &is_null[i]; #endif } if (mysql_stmt_bind_result(stmt, bind_res.data())) { std::string err = std::string(mysql_error(conn)); mysql_free_result(meta); mysql_stmt_close(stmt); throw std::runtime_error( "mysql_stmt_bind_result() has failed. Error: " + err); } std::vector> rows; while (mysql_stmt_fetch(stmt) == 0) { std::map m; for (int i = 0; i < num_fields; i++) { m[fields[i].name] = bufs[i].data() == nullptr ? "" : std::string(bufs[i].data(), *bind_res[i].length); } rows.push_back(std::move(m)); } mysql_free_result(meta); mysql_stmt_close(stmt); return rows; } void close() override { if (!conn) return; mysql_close(conn); conn = nullptr; } }; #endif std::unique_ptr create_connection(const Configuration &cfg); }