summaryrefslogtreecommitdiff
path: root/bot/src/database.hpp
diff options
context:
space:
mode:
authorilotterytea <iltsu@alright.party>2025-07-02 03:31:54 +0500
committerilotterytea <iltsu@alright.party>2025-07-02 03:31:54 +0500
commit79be89ad0491bfdd110b2c612e21a0f28c29fa87 (patch)
tree589daab514de11cc87e424b62d87f8cac12494ab /bot/src/database.hpp
parentfea3c12d6b621796bb239cebb57a5a5014dfe350 (diff)
feat: MARIADB SUPPORT!!!!
Diffstat (limited to 'bot/src/database.hpp')
-rw-r--r--bot/src/database.hpp263
1 files changed, 263 insertions, 0 deletions
diff --git a/bot/src/database.hpp b/bot/src/database.hpp
new file mode 100644
index 0000000..2920438
--- /dev/null
+++ b/bot/src/database.hpp
@@ -0,0 +1,263 @@
+#pragma once
+
+#ifdef USE_POSTGRES
+#include <pqxx/pqxx>
+#elif defined(USE_MARIADB)
+#include <mysql/mysql.h>
+#endif
+
+#include <cstring>
+#include <map>
+#include <memory>
+#include <regex>
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "config.hpp"
+
+namespace bot::db {
+ using DatabaseRow = std::map<std::string, std::string>;
+ using DatabaseRows = std::vector<DatabaseRow>;
+
+ struct BaseDatabase {
+ public:
+ virtual ~BaseDatabase() = default;
+
+ template <typename T>
+ std::vector<T> query_all(const std::string &query) {
+ return this->query_all<T>(query, {});
+ }
+
+ template <typename T>
+ std::vector<T> query_all(const std::string &query,
+ const std::vector<std::string> &params) {
+ std::vector<T> results;
+
+ for (DatabaseRow &row : this->exec(query, params)) {
+ results.push_back(T(row));
+ }
+
+ return results;
+ }
+
+ virtual DatabaseRows exec(const std::string &sql) = 0;
+
+ virtual DatabaseRows exec(const std::string &sql,
+ const std::vector<std::string> &parameters) = 0;
+
+ virtual void close() = 0;
+ };
+
+#ifdef USE_POSTGRES
+ struct PostgresDatabase : public BaseDatabase {
+ public:
+ pqxx::connection conn;
+
+ PostgresDatabase(const std::string &credentials) : conn(credentials) {}
+
+ DatabaseRows exec(const std::string &sql) override {
+ pqxx::work work(conn);
+ pqxx::result r = work.exec(sql);
+ work.commit();
+
+ std::vector<std::map<std::string, std::string>> rows;
+ for (auto const &row : r) {
+ std::map<std::string, std::string> m;
+ for (auto const &f : row) {
+ m[f.name()] = f.c_str() ? f.c_str() : "";
+ }
+ rows.push_back(m);
+ }
+ return rows;
+ }
+
+ DatabaseRows exec(const std::string &sql,
+ const std::vector<std::string> &parameters) override {
+ pqxx::work work(conn);
+ pqxx::result r = work.exec(sql, parameters);
+ work.commit();
+
+ std::vector<std::map<std::string, std::string>> rows;
+ for (auto const &row : r) {
+ std::map<std::string, std::string> m;
+ for (auto const &f : row) {
+ m[f.name()] = f.c_str() ? f.c_str() : "";
+ }
+ rows.push_back(m);
+ }
+ return rows;
+ }
+
+ void close() override { conn.close(); }
+ };
+#endif
+
+#ifdef USE_MARIADB
+ struct MariaDatabase : public BaseDatabase {
+ public:
+ MYSQL *conn = nullptr;
+
+ MariaDatabase(const Configuration &cfg) : conn(mysql_init(nullptr)) {
+ if (conn == nullptr) {
+ throw std::runtime_error("mysql_init() failed");
+ }
+
+ if (!mysql_real_connect(
+ conn, cfg.database.host.c_str(), cfg.database.user.c_str(),
+ cfg.database.password.c_str(), cfg.database.name.c_str(),
+ std::stoi(cfg.database.port), nullptr, 0)) {
+ mysql_close(conn);
+ throw std::runtime_error("mysql_real_connect() failed");
+ }
+ }
+
+ ~MariaDatabase() { this->close(); }
+
+ DatabaseRows exec(const std::string &sql) override {
+ std::regex regex(R"(\$[0-9]+)");
+ std::string query = std::regex_replace(sql, regex, "?");
+
+ if (mysql_query(conn, query.c_str())) {
+ std::string err = std::string(mysql_error(conn));
+ mysql_close(conn);
+ throw std::runtime_error("Query has failed. Error: " + err);
+ }
+
+ MYSQL_RES *res = mysql_store_result(conn);
+ if (res == nullptr) {
+ std::string err = std::string(mysql_error(conn));
+ mysql_close(conn);
+ throw std::runtime_error("mysql_store_result() has failed. Error: " +
+ err);
+ }
+
+ int num_fields = mysql_num_fields(res);
+ MYSQL_FIELD *fields = mysql_fetch_fields(res);
+ MYSQL_ROW row;
+
+ std::vector<std::map<std::string, std::string>> rows;
+
+ while ((row = mysql_fetch_row(res))) {
+ std::map<std::string, std::string> m;
+
+ for (int i = 0; i < num_fields; i++) {
+ m[fields[i].name] = row[i] == nullptr ? "" : row[i];
+ }
+
+ rows.push_back(std::move(m));
+ }
+
+ mysql_free_result(res);
+
+ return rows;
+ }
+
+ DatabaseRows exec(const std::string &sql,
+ const std::vector<std::string> &parameters) override {
+ std::regex regex(R"(\$[0-9]+)");
+ std::string query = std::regex_replace(sql, regex, "?");
+
+ MYSQL_STMT *stmt = mysql_stmt_init(conn);
+
+ if (mysql_stmt_prepare(stmt, query.c_str(), query.length())) {
+ std::string err = std::string(mysql_error(conn));
+ mysql_stmt_close(stmt);
+ throw std::runtime_error("Prepared query has failed. Error: " + err);
+ }
+
+ // binding input params
+ std::vector<MYSQL_BIND> bind_params(parameters.size());
+ std::vector<unsigned long> lengths(parameters.size());
+ for (int i = 0; i < parameters.size(); i++) {
+ memset(&bind_params[i], 0, sizeof(MYSQL_BIND));
+ lengths[i] = parameters[i].size();
+ bind_params[i].buffer_type = MYSQL_TYPE_STRING;
+ bind_params[i].buffer = (void *)parameters[i].c_str();
+ bind_params[i].buffer_length = lengths[i];
+ bind_params[i].length = &lengths[i];
+ bind_params[i].is_null = 0;
+ }
+
+ if (!parameters.empty() &&
+ mysql_stmt_bind_param(stmt, bind_params.data())) {
+ std::string err = std::string(mysql_error(conn));
+ mysql_stmt_close(stmt);
+ throw std::runtime_error(
+ "mysql_stmt_bind_param() has failed. Error: " + err);
+ }
+
+ if (mysql_stmt_execute(stmt)) {
+ std::string err = std::string(mysql_error(conn));
+ mysql_stmt_close(stmt);
+ throw std::runtime_error(
+ "Prepared query execution has failed. Error: " + err);
+ }
+
+ // metadata
+ MYSQL_RES *meta = mysql_stmt_result_metadata(stmt);
+ if (!meta) {
+ mysql_stmt_close(stmt);
+ return {};
+ }
+
+ int num_fields = mysql_num_fields(meta);
+ MYSQL_FIELD *fields = mysql_fetch_fields(meta);
+
+ // bind output
+ std::vector<MYSQL_BIND> bind_res(num_fields);
+ std::vector<std::string> bufs(num_fields);
+ std::vector<unsigned long> lengths_out(num_fields);
+ std::vector<my_bool> is_null(num_fields);
+
+ for (int i = 0; i < num_fields; i++) {
+ bufs[i].resize(1024);
+ memset(&bind_res[i], 0, sizeof(MYSQL_BIND));
+ bind_res[i].buffer_type = MYSQL_TYPE_STRING;
+ bind_res[i].buffer = bufs[i].data();
+ bind_res[i].buffer_length = bufs[i].size();
+ bind_res[i].length = &lengths_out[i];
+ bind_res[i].is_null = &is_null[i];
+ }
+
+ if (mysql_stmt_bind_result(stmt, bind_res.data())) {
+ std::string err = std::string(mysql_error(conn));
+ mysql_free_result(meta);
+ mysql_stmt_close(stmt);
+ throw std::runtime_error(
+ "mysql_stmt_bind_result() has failed. Error: " + err);
+ }
+
+ std::vector<std::map<std::string, std::string>> rows;
+
+ while (mysql_stmt_fetch(stmt) == 0) {
+ std::map<std::string, std::string> m;
+
+ for (int i = 0; i < num_fields; i++) {
+ m[fields[i].name] =
+ bufs[i].data() == nullptr
+ ? ""
+ : std::string(bufs[i].data(), *bind_res[i].length);
+ }
+
+ rows.push_back(std::move(m));
+ }
+
+ mysql_free_result(meta);
+ mysql_stmt_close(stmt);
+
+ return rows;
+ }
+
+ void close() override {
+ if (!conn) return;
+
+ mysql_close(conn);
+ conn = nullptr;
+ }
+ };
+#endif
+
+ std::unique_ptr<BaseDatabase> create_connection(const Configuration &cfg);
+} \ No newline at end of file