diff options
Diffstat (limited to 'bot')
| -rw-r--r-- | bot/CMakeLists.txt | 43 | ||||
| -rw-r--r-- | bot/src/bundle.hpp | 4 | ||||
| -rw-r--r-- | bot/src/commands/command.cpp | 35 | ||||
| -rw-r--r-- | bot/src/commands/lua.cpp | 61 | ||||
| -rw-r--r-- | bot/src/commands/request.cpp | 170 | ||||
| -rw-r--r-- | bot/src/commands/request.hpp | 15 | ||||
| -rw-r--r-- | bot/src/commands/request_util.cpp | 208 | ||||
| -rw-r--r-- | bot/src/commands/request_util.hpp | 13 | ||||
| -rw-r--r-- | bot/src/database.cpp | 13 | ||||
| -rw-r--r-- | bot/src/database.hpp | 263 | ||||
| -rw-r--r-- | bot/src/emotes.cpp | 79 | ||||
| -rw-r--r-- | bot/src/github.cpp | 68 | ||||
| -rw-r--r-- | bot/src/handlers.cpp | 40 | ||||
| -rw-r--r-- | bot/src/handlers.hpp | 4 | ||||
| -rw-r--r-- | bot/src/main.cpp | 49 | ||||
| -rw-r--r-- | bot/src/schemas/channel.hpp | 62 | ||||
| -rw-r--r-- | bot/src/schemas/user.hpp | 26 | ||||
| -rw-r--r-- | bot/src/stream.cpp | 60 | ||||
| -rw-r--r-- | bot/src/timer.cpp | 44 |
19 files changed, 725 insertions, 532 deletions
diff --git a/bot/CMakeLists.txt b/bot/CMakeLists.txt index cd74eaf..85eb870 100644 --- a/bot/CMakeLists.txt +++ b/bot/CMakeLists.txt @@ -43,6 +43,40 @@ file(GLOB_RECURSE SOURCE_FILES "src/*.cpp" "src/*.h" "src/*.hpp") target_sources(Bot PRIVATE ${SOURCE_FILES}) target_include_directories(Bot PRIVATE src) +# DATABASE +option(USE_POSTGRES "Use PostgreSQL backend" OFF) +option(USE_MARIADB "Use MariaDB backend" ON) + +if (USE_POSTGRES) + FetchContent_Declare( + pqxx + GIT_REPOSITORY https://github.com/jtv/libpqxx.git + GIT_TAG 7.10.1 + ) + FetchContent_MakeAvailable(pqxx) + target_compile_definitions(Bot PRIVATE USE_POSTGRES) + target_link_libraries(Bot PRIVATE pqxx) +endif() + +# ev&doe it is mysql +if (USE_MARIADB) + target_compile_definitions(Bot PRIVATE USE_MARIADB) + + # searching for mysql + find_program(MYSQL_CONFIG_EXECUTABLE mysql_config REQUIRED) + + execute_process(COMMAND ${MYSQL_CONFIG_EXECUTABLE} --cflags + OUTPUT_VARIABLE MYSQL_CFLAGS + OUTPUT_STRIP_TRAILING_WHITESPACE) + + execute_process(COMMAND ${MYSQL_CONFIG_EXECUTABLE} --libs + OUTPUT_VARIABLE MYSQL_LIBS + OUTPUT_STRIP_TRAILING_WHITESPACE) + + target_compile_options(Bot PRIVATE ${MYSQL_CFLAGS}) + target_link_libraries(Bot PRIVATE ${MYSQL_LIBS}) +endif() + # Getting libraries # json @@ -60,14 +94,6 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(cpr) -# postgresql -FetchContent_Declare( - pqxx - GIT_REPOSITORY https://github.com/jtv/libpqxx.git - GIT_TAG 7.10.1 -) -FetchContent_MakeAvailable(pqxx) - # websockets FetchContent_Declare( ixwebsocket @@ -96,7 +122,6 @@ FetchContent_MakeAvailable(emotespp) target_link_libraries(Bot PRIVATE ixwebsocket::ixwebsocket - pqxx nlohmann_json::nlohmann_json cpr::cpr lua diff --git a/bot/src/bundle.hpp b/bot/src/bundle.hpp index a343416..5ee0c1f 100644 --- a/bot/src/bundle.hpp +++ b/bot/src/bundle.hpp @@ -5,6 +5,10 @@ namespace bot { class CommandLoader; } + namespace loc { + class Localization; + } + class InstanceBundle; } diff --git a/bot/src/commands/command.cpp b/bot/src/commands/command.cpp index 2ecb80f..533376a 100644 --- a/bot/src/commands/command.cpp +++ b/bot/src/commands/command.cpp @@ -6,15 +6,16 @@ #include <fstream> #include <memory> #include <optional> -#include <pqxx/pqxx> #include <sol/state.hpp> #include <sol/types.hpp> #include <stdexcept> #include <string> +#include <vector> #include "../bundle.hpp" #include "../utils/chrono.hpp" #include "commands/lua.hpp" +#include "database.hpp" #include "request.hpp" #include "response.hpp" @@ -84,17 +85,18 @@ namespace bot { return std::nullopt; } - pqxx::work work(request.conn); + std::unique_ptr<db::BaseDatabase> conn = + db::create_connection(bundle.configuration); - pqxx::result action_query = work.exec( - "SELECT sent_at FROM actions WHERE user_id = " + - std::to_string(request.user.get_id()) + - " AND channel_id = " + std::to_string(request.channel.get_id()) + - " AND command = '" + request.command_id + "' ORDER BY sent_at DESC"); + db::DatabaseRows actions = conn->exec( + "SELECT sent_at FROM actions WHERE user_id = $1 AND channel_id = $2 " + "AND command = $3 ORDER BY sent_at DESC", + {std::to_string(request.user.get_id()), + std::to_string(request.channel.get_id()), request.command_id}); - if (!action_query.empty()) { - auto last_sent_at = utils::chrono::string_to_time_point( - action_query[0][0].as<std::string>()); + if (!actions.empty()) { + auto last_sent_at = + utils::chrono::string_to_time_point(actions[0]["sent_at"]); auto now = std::chrono::system_clock::now(); auto now_time_it = std::chrono::system_clock::to_time_t(now); @@ -119,15 +121,12 @@ namespace bot { arguments += request.message.value(); } - work.exec( + conn->exec( "INSERT INTO actions(user_id, channel_id, command, arguments, " - "full_message) VALUES (" + - std::to_string(request.user.get_id()) + ", " + - std::to_string(request.channel.get_id()) + ", '" + - request.command_id + "', '" + arguments + "', '" + - request.irc_message.message + "')"); - - work.commit(); + "full_message) VALUES ($1, $2, $3, $4, $5)", + {std::to_string(request.user.get_id()), + std::to_string(request.channel.get_id()), request.command_id, + arguments, request.irc_message.message}); return (*command)->run(bundle, request); } diff --git a/bot/src/commands/lua.cpp b/bot/src/commands/lua.cpp index cb00887..a8db207 100644 --- a/bot/src/commands/lua.cpp +++ b/bot/src/commands/lua.cpp @@ -28,6 +28,7 @@ #include "cpr/cprtypes.h" #include "cpr/multipart.h" #include "cpr/response.h" +#include "database.hpp" #include "schemas/channel.hpp" #include "schemas/stream.hpp" #include "schemas/user.hpp" @@ -464,30 +465,31 @@ namespace bot::command::lua { state->set_function("db_execute", [state, cfg]( const std::string &query, const sol::table ¶meters) { - pqxx::connection conn(GET_DATABASE_CONNECTION_URL(cfg)); - pqxx::params p; + std::unique_ptr<db::BaseDatabase> conn = db::create_connection(cfg); + + std::vector<std::string> params; for (const auto &kv : parameters) { auto v = kv.second; switch (v.get_type()) { case sol::type::lua_nil: { - p.append(nullptr); + params.push_back("NULL"); break; } case sol::type::string: { - p.append(v.as<std::string>()); + params.push_back(v.as<std::string>()); break; } case sol::type::boolean: { - p.append(v.as<bool>()); + params.push_back(std::to_string(v.as<bool>())); break; } case sol::type::number: { double num = v.as<double>(); if (std::floor(num) == num) { - p.append(static_cast<long long>(num)); + params.push_back(std::to_string(static_cast<long long>(num))); } else { - p.append(num); + params.push_back(std::to_string(num)); } break; } @@ -496,41 +498,37 @@ namespace bot::command::lua { } } - pqxx::work work(conn); - - work.exec_params(query, p); - - work.commit(); - conn.close(); + conn->exec(query, params); }); state->set_function("db_query", [state, cfg]( const std::string &query, const sol::table ¶meters) { - pqxx::connection conn(GET_DATABASE_CONNECTION_URL(cfg)); - pqxx::params p; + std::unique_ptr<db::BaseDatabase> conn = db::create_connection(cfg); + + std::vector<std::string> params; for (const auto &kv : parameters) { auto v = kv.second; switch (v.get_type()) { case sol::type::lua_nil: { - p.append(nullptr); + params.push_back("NULL"); break; } case sol::type::string: { - p.append(v.as<std::string>()); + params.push_back(v.as<std::string>()); break; } case sol::type::boolean: { - p.append(v.as<bool>()); + params.push_back(std::to_string(v.as<bool>())); break; } case sol::type::number: { double num = v.as<double>(); if (std::floor(num) == num) { - p.append(static_cast<long long>(num)); + params.push_back(std::to_string(static_cast<long long>(num))); } else { - p.append(num); + params.push_back(std::to_string(num)); } break; } @@ -539,33 +537,26 @@ namespace bot::command::lua { } } - pqxx::work work(conn); - pqxx::result res = work.exec_params(query, p); + db::DatabaseRows rows = conn->exec(query, params); sol::table o = state->create_table(); - for (const auto &row : res) { + for (const db::DatabaseRow &row : rows) { sol::table r = state->create_table(); - for (int i = 0; i < row.size(); i++) { - auto v = row[i]; - - sol::object obj; - if (v.is_null()) { - obj = sol::make_object(*state, sol::lua_nil); + for (const auto &[k, v] : row) { + sol::object val; + if (v.empty()) { + val = sol::make_object(*state, sol::lua_nil); } else { - obj = sol::make_object(*state, v.as<std::string>()); + val = sol::make_object(*state, v); } - - r[res.column_name(i)] = obj; + r[k] = val; } o.add(r); } - work.commit(); - conn.close(); - return o; }); } diff --git a/bot/src/commands/request.cpp b/bot/src/commands/request.cpp index 7acc107..d7f1f83 100644 --- a/bot/src/commands/request.cpp +++ b/bot/src/commands/request.cpp @@ -1,6 +1,15 @@ #include "commands/request.hpp" +#include <algorithm> +#include <optional> #include <sol/types.hpp> +#include <string> +#include <vector> + +#include "constants.hpp" +#include "database.hpp" +#include "schemas/channel.hpp" +#include "utils/string.hpp" namespace bot::command { sol::table Request::as_lua_table(std::shared_ptr<sol::state> luaState) const { @@ -25,4 +34,165 @@ namespace bot::command { return o; } + + std::optional<Request> generate_request( + const command::CommandLoader &command_loader, + const irc::Message<irc::MessageType::Privmsg> &irc_message, + std::unique_ptr<db::BaseDatabase> &conn) { + // fetching channel + std::vector<schemas::Channel> chans = conn->query_all<schemas::Channel>( + "SELECT * FROM channels WHERE alias_id = $1", + {std::to_string(irc_message.source.id)}); + + if (chans.empty()) { + conn->exec( + "INSERT INTO channels(alias_id, alias_name) VALUES ($1, $2)", + {std::to_string(irc_message.source.id), irc_message.source.login}); + + chans = conn->query_all<schemas::Channel>( + "SELECT * FROM channels WHERE alias_id = $1", + {std::to_string(irc_message.source.id)}); + } + + schemas::Channel channel = chans[0]; + if (channel.get_opted_out_at().has_value()) { + return std::nullopt; + } + + // fetching channel preference + std::vector<schemas::ChannelPreferences> prefs = + conn->query_all<schemas::ChannelPreferences>( + "SELECT * FROM channel_preferences WHERE id = $1", + {std::to_string(channel.get_id())}); + + if (prefs.empty()) { + conn->exec( + "INSERT INTO channel_preferences(id, prefix, locale) VALUES ($1, " + "$2, $3)", + {std::to_string(channel.get_id()), DEFAULT_PREFIX, + DEFAULT_LOCALE_ID}); + + prefs = conn->query_all<schemas::ChannelPreferences>( + "SELECT * FROM channel_preferences WHERE id = $1", + {std::to_string(channel.get_id())}); + } + + schemas::ChannelPreferences pref = prefs[0]; + + // fetching channel preference + std::vector<schemas::User> users = conn->query_all<schemas::User>( + "SELECT * FROM users WHERE alias_id = $1", + {std::to_string(irc_message.sender.id)}); + + if (users.empty()) { + conn->exec( + "INSERT INTO users(alias_id, alias_name) VALUES ($1, " + "$2)", + {std::to_string(irc_message.sender.id), irc_message.sender.login}); + + users = conn->query_all<schemas::User>( + "SELECT * FROM users WHERE alias_id = $1", + {std::to_string(irc_message.sender.id)}); + } + + schemas::User user = users[0]; + + // updating username + if (user.get_alias_name() != irc_message.sender.login) { + conn->exec("UPDATE users SET alias_name = $1 WHERE id = $2", + {irc_message.sender.login, std::to_string(user.get_id())}); + + user.set_alias_name(irc_message.sender.login); + } + + // setting permissions + schemas::PermissionLevel level = schemas::PermissionLevel::USER; + const auto &badges = irc_message.sender.badges; + + if (user.get_alias_id() == channel.get_alias_id()) { + level = schemas::PermissionLevel::BROADCASTER; + } else if (std::any_of(badges.begin(), badges.end(), [&](const auto &x) { + return x.first == "moderator"; + })) { + level = schemas::PermissionLevel::MODERATOR; + } else if (std::any_of(badges.begin(), badges.end(), + [&](const auto &x) { return x.first == "vip"; })) { + level = schemas::PermissionLevel::VIP; + } + + std::vector<schemas::UserRights> user_rights = + conn->query_all<schemas::UserRights>( + "SELECT * FROM user_rights WHERE user_id = $1 AND channel_id = $2", + {std::to_string(user.get_id()), std::to_string(channel.get_id())}); + + if (user_rights.empty()) { + conn->exec( + "INSERT INTO user_rights(user_id, channel_id, level) VALUES ($1, " + "$2, $3)", + {std::to_string(user.get_id()), std::to_string(channel.get_id()), + std::to_string(level)}); + + user_rights = conn->query_all<schemas::UserRights>( + "SELECT * FROM user_rights WHERE user_id = $1 AND channel_id = $2", + {std::to_string(user.get_id()), std::to_string(channel.get_id())}); + } + + schemas::UserRights user_right = user_rights[0]; + + if (user_right.get_level() != level) { + conn->exec("UPDATE user_rights SET level = $1 WHERE id = $2", + {std::to_string(level), std::to_string(user_right.get_id())}); + + user_right.set_level(level); + } + + // --- FETCHING MESSAGES + std::string fullmsg = irc_message.message; + const std::string &prefix = pref.get_prefix(); + + if (fullmsg.empty() || fullmsg.substr(0, prefix.length()) != prefix) { + return std::nullopt; + } + + fullmsg = fullmsg.substr(prefix.length()); + + std::vector<std::string> parts = utils::string::split_text(fullmsg, ' '); + + std::string command_id = parts[0]; + + auto cmd = std::find_if( + command_loader.get_commands().begin(), + command_loader.get_commands().end(), + [&command_id](const auto &c) { return c->get_name() == command_id; }); + + if (cmd == command_loader.get_commands().end()) { + return std::nullopt; + } + + parts.erase(parts.begin()); + + Request req{command_id, std::nullopt, std::nullopt, irc_message, + channel, pref, user, user_right}; + + if (parts.empty()) { + return req; + } + + std::optional<std::string> scid = parts[0]; + auto scids = (*cmd)->get_subcommand_ids(); + + if (std::any_of(scids.begin(), scids.end(), + [&](const auto &x) { return x == scid.value(); })) { + parts.erase(parts.begin()); + } else { + scid = std::nullopt; + } + + req.subcommand_id = scid; + + std::optional<std::string> message = utils::string::join_vector(parts, ' '); + req.message = message; + + return req; + } }
\ No newline at end of file diff --git a/bot/src/commands/request.hpp b/bot/src/commands/request.hpp index b6ed534..9822fc8 100644 --- a/bot/src/commands/request.hpp +++ b/bot/src/commands/request.hpp @@ -2,7 +2,6 @@ #include <memory> #include <optional> -#include <pqxx/pqxx> #include <sol/state.hpp> #include <sol/table.hpp> #include <string> @@ -12,6 +11,13 @@ #include "../schemas/user.hpp" namespace bot::command { + struct Request; +} + +#include "commands/command.hpp" +#include "database.hpp" + +namespace bot::command { struct Request { std::string command_id; std::optional<std::string> subcommand_id; @@ -23,8 +29,11 @@ namespace bot::command { schemas::User user; schemas::UserRights user_rights; - pqxx::connection &conn; - sol::table as_lua_table(std::shared_ptr<sol::state> luaState) const; }; + + std::optional<Request> generate_request( + const command::CommandLoader &command_loader, + const irc::Message<irc::MessageType::Privmsg> &irc_message, + std::unique_ptr<db::BaseDatabase> &conn); } diff --git a/bot/src/commands/request_util.cpp b/bot/src/commands/request_util.cpp deleted file mode 100644 index ad8a174..0000000 --- a/bot/src/commands/request_util.cpp +++ /dev/null @@ -1,208 +0,0 @@ -#include "request_util.hpp" - -#include <algorithm> -#include <optional> -#include <pqxx/pqxx> -#include <string> - -#include "../constants.hpp" -#include "../irc/message.hpp" -#include "../schemas/channel.hpp" -#include "command.hpp" -#include "request.hpp" - -namespace bot::command { - std::optional<Request> generate_request( - const command::CommandLoader &command_loader, - const irc::Message<irc::MessageType::Privmsg> &irc_message, - pqxx::connection &conn) { - pqxx::work *work; - - work = new pqxx::work(conn); - - std::vector<std::string> parts = - utils::string::split_text(irc_message.message, ' '); - - pqxx::result query = work->exec("SELECT * FROM channels WHERE alias_id = " + - std::to_string(irc_message.source.id)); - - // Create new channel data in the database if it didn't exist b4 - if (query.empty()) { - work->exec("INSERT INTO channels (alias_id, alias_name) VALUES (" + - std::to_string(irc_message.source.id) + ", '" + - irc_message.source.login + "')"); - - work->commit(); - - delete work; - work = new pqxx::work(conn); - - query = work->exec("SELECT * FROM channels WHERE alias_id = " + - std::to_string(irc_message.source.id)); - } - - schemas::Channel channel(query[0]); - - if (channel.get_opted_out_at().has_value()) { - delete work; - return std::nullopt; - } - - query = work->exec("SELECT * FROM channel_preferences WHERE channel_id = " + - std::to_string(channel.get_id())); - - // Create new channel preference data in the database if it didn't exist b4 - if (query.empty()) { - work->exec( - "INSERT INTO channel_preferences (channel_id, prefix, locale) VALUES " - "(" + - std::to_string(channel.get_id()) + ", '" + DEFAULT_PREFIX + "', '" + - DEFAULT_LOCALE_ID + "')"); - - work->commit(); - - delete work; - work = new pqxx::work(conn); - - query = - work->exec("SELECT * FROM channel_preferences WHERE channel_id = " + - std::to_string(channel.get_id())); - } - - schemas::ChannelPreferences channel_preferences(query[0]); - - query = work->exec("SELECT * FROM users WHERE alias_id = " + - std::to_string(irc_message.sender.id)); - - // Create new user data in the database if it didn't exist before - if (query.empty()) { - work->exec("INSERT INTO users (alias_id, alias_name) VALUES (" + - std::to_string(irc_message.sender.id) + ", '" + - irc_message.sender.login + "')"); - - work->commit(); - - delete work; - work = new pqxx::work(conn); - - query = work->exec("SELECT * FROM users WHERE alias_id = " + - std::to_string(irc_message.sender.id)); - } - - schemas::User user(query[0]); - - if (user.get_alias_name() != irc_message.sender.login) { - work->exec("UPDATE users SET alias_name = '" + irc_message.sender.login + - "' WHERE id = " + std::to_string(user.get_id())); - work->commit(); - - delete work; - work = new pqxx::work(conn); - - user.set_alias_name(irc_message.sender.login); - } - - schemas::PermissionLevel level = schemas::PermissionLevel::USER; - const auto &badges = irc_message.sender.badges; - - if (user.get_alias_id() == channel.get_alias_id()) { - level = schemas::PermissionLevel::BROADCASTER; - } else if (std::any_of(badges.begin(), badges.end(), [&](const auto &x) { - return x.first == "moderator"; - })) { - level = schemas::PermissionLevel::MODERATOR; - } else if (std::any_of(badges.begin(), badges.end(), - [&](const auto &x) { return x.first == "vip"; })) { - level = schemas::PermissionLevel::VIP; - } - - query = work->exec("SELECT * FROM user_rights WHERE user_id = " + - std::to_string(user.get_id()) + - " AND channel_id = " + std::to_string(channel.get_id())); - - if (query.empty()) { - work->exec( - "INSERT INTO user_rights (user_id, channel_id, level) VALUES (" + - std::to_string(user.get_id()) + ", " + - std::to_string(channel.get_id()) + ", " + std::to_string(level) + - ")"); - - work->commit(); - - delete work; - work = new pqxx::work(conn); - - query = work->exec("SELECT * FROM user_rights WHERE user_id = " + - std::to_string(user.get_id()) + " AND channel_id = " + - std::to_string(channel.get_id())); - } - - schemas::UserRights user_rights(query[0]); - - if (user_rights.get_level() != level) { - work->exec("UPDATE user_rights SET level = " + std::to_string(level) + - " WHERE id = " + std::to_string(query[0][0].as<int>())); - - work->commit(); - - user_rights.set_level(level); - } - - // Checking if the user has sent a command - std::string command_id = parts[0]; - - const std::string &prefix = channel_preferences.get_prefix(); - - if (command_id.substr(0, prefix.length()) != prefix) { - delete work; - return std::nullopt; - } - - command_id = - command_id.substr(prefix.length(), command_id.length()); - - auto cmd = std::find_if( - command_loader.get_commands().begin(), - command_loader.get_commands().end(), - [&](const auto &command) { return command->get_name() == command_id; }); - - if (cmd == command_loader.get_commands().end()) { - delete work; - return std::nullopt; - } - - parts.erase(parts.begin()); - - delete work; - - if (parts.empty()) { - Request req{command_id, std::nullopt, std::nullopt, - irc_message, channel, channel_preferences, - user, user_rights, conn}; - - return req; - } - - std::optional<std::string> subcommand_id = parts[0]; - auto subcommand_ids = (*cmd)->get_subcommand_ids(); - - if (std::any_of( - subcommand_ids.begin(), subcommand_ids.end(), - [&](const auto &x) { return x == subcommand_id.value(); })) { - parts.erase(parts.begin()); - } else { - subcommand_id = std::nullopt; - } - - std::optional<std::string> message = utils::string::join_vector(parts, ' '); - - if (message->empty()) { - message = std::nullopt; - } - - Request req{command_id, subcommand_id, message, - irc_message, channel, channel_preferences, - user, user_rights, conn}; - return req; - } -} diff --git a/bot/src/commands/request_util.hpp b/bot/src/commands/request_util.hpp deleted file mode 100644 index dea6e12..0000000 --- a/bot/src/commands/request_util.hpp +++ /dev/null @@ -1,13 +0,0 @@ -#include <optional> -#include <pqxx/pqxx> - -#include "../irc/message.hpp" -#include "command.hpp" -#include "request.hpp" - -namespace bot::command { - std::optional<Request> generate_request( - const command::CommandLoader &command_loader, - const irc::Message<irc::MessageType::Privmsg> &irc_message, - pqxx::connection &conn); -} diff --git a/bot/src/database.cpp b/bot/src/database.cpp new file mode 100644 index 0000000..dcb7dae --- /dev/null +++ b/bot/src/database.cpp @@ -0,0 +1,13 @@ +#include "database.hpp" + +#include <memory> + +namespace bot::db { + std::unique_ptr<BaseDatabase> create_connection(const Configuration &cfg) { +#if USE_POSTGRES + return std::make_unique<PostgresDatabase>(GET_DATABASE_CONNECTION_URL(cfg)); +#elif defined(USE_MARIADB) + return std::make_unique<MariaDatabase>(cfg); +#endif + } +}
\ No newline at end of file diff --git a/bot/src/database.hpp b/bot/src/database.hpp new file mode 100644 index 0000000..2920438 --- /dev/null +++ b/bot/src/database.hpp @@ -0,0 +1,263 @@ +#pragma once + +#ifdef USE_POSTGRES +#include <pqxx/pqxx> +#elif defined(USE_MARIADB) +#include <mysql/mysql.h> +#endif + +#include <cstring> +#include <map> +#include <memory> +#include <regex> +#include <stdexcept> +#include <string> +#include <utility> +#include <vector> + +#include "config.hpp" + +namespace bot::db { + using DatabaseRow = std::map<std::string, std::string>; + using DatabaseRows = std::vector<DatabaseRow>; + + struct BaseDatabase { + public: + virtual ~BaseDatabase() = default; + + template <typename T> + std::vector<T> query_all(const std::string &query) { + return this->query_all<T>(query, {}); + } + + template <typename T> + std::vector<T> query_all(const std::string &query, + const std::vector<std::string> ¶ms) { + std::vector<T> results; + + for (DatabaseRow &row : this->exec(query, params)) { + results.push_back(T(row)); + } + + return results; + } + + virtual DatabaseRows exec(const std::string &sql) = 0; + + virtual DatabaseRows exec(const std::string &sql, + const std::vector<std::string> ¶meters) = 0; + + virtual void close() = 0; + }; + +#ifdef USE_POSTGRES + struct PostgresDatabase : public BaseDatabase { + public: + pqxx::connection conn; + + PostgresDatabase(const std::string &credentials) : conn(credentials) {} + + DatabaseRows exec(const std::string &sql) override { + pqxx::work work(conn); + pqxx::result r = work.exec(sql); + work.commit(); + + std::vector<std::map<std::string, std::string>> rows; + for (auto const &row : r) { + std::map<std::string, std::string> m; + for (auto const &f : row) { + m[f.name()] = f.c_str() ? f.c_str() : ""; + } + rows.push_back(m); + } + return rows; + } + + DatabaseRows exec(const std::string &sql, + const std::vector<std::string> ¶meters) override { + pqxx::work work(conn); + pqxx::result r = work.exec(sql, parameters); + work.commit(); + + std::vector<std::map<std::string, std::string>> rows; + for (auto const &row : r) { + std::map<std::string, std::string> m; + for (auto const &f : row) { + m[f.name()] = f.c_str() ? f.c_str() : ""; + } + rows.push_back(m); + } + return rows; + } + + void close() override { conn.close(); } + }; +#endif + +#ifdef USE_MARIADB + struct MariaDatabase : public BaseDatabase { + public: + MYSQL *conn = nullptr; + + MariaDatabase(const Configuration &cfg) : conn(mysql_init(nullptr)) { + if (conn == nullptr) { + throw std::runtime_error("mysql_init() failed"); + } + + if (!mysql_real_connect( + conn, cfg.database.host.c_str(), cfg.database.user.c_str(), + cfg.database.password.c_str(), cfg.database.name.c_str(), + std::stoi(cfg.database.port), nullptr, 0)) { + mysql_close(conn); + throw std::runtime_error("mysql_real_connect() failed"); + } + } + + ~MariaDatabase() { this->close(); } + + DatabaseRows exec(const std::string &sql) override { + std::regex regex(R"(\$[0-9]+)"); + std::string query = std::regex_replace(sql, regex, "?"); + + if (mysql_query(conn, query.c_str())) { + std::string err = std::string(mysql_error(conn)); + mysql_close(conn); + throw std::runtime_error("Query has failed. Error: " + err); + } + + MYSQL_RES *res = mysql_store_result(conn); + if (res == nullptr) { + std::string err = std::string(mysql_error(conn)); + mysql_close(conn); + throw std::runtime_error("mysql_store_result() has failed. Error: " + + err); + } + + int num_fields = mysql_num_fields(res); + MYSQL_FIELD *fields = mysql_fetch_fields(res); + MYSQL_ROW row; + + std::vector<std::map<std::string, std::string>> rows; + + while ((row = mysql_fetch_row(res))) { + std::map<std::string, std::string> m; + + for (int i = 0; i < num_fields; i++) { + m[fields[i].name] = row[i] == nullptr ? "" : row[i]; + } + + rows.push_back(std::move(m)); + } + + mysql_free_result(res); + + return rows; + } + + DatabaseRows exec(const std::string &sql, + const std::vector<std::string> ¶meters) override { + std::regex regex(R"(\$[0-9]+)"); + std::string query = std::regex_replace(sql, regex, "?"); + + MYSQL_STMT *stmt = mysql_stmt_init(conn); + + if (mysql_stmt_prepare(stmt, query.c_str(), query.length())) { + std::string err = std::string(mysql_error(conn)); + mysql_stmt_close(stmt); + throw std::runtime_error("Prepared query has failed. Error: " + err); + } + + // binding input params + std::vector<MYSQL_BIND> bind_params(parameters.size()); + std::vector<unsigned long> lengths(parameters.size()); + for (int i = 0; i < parameters.size(); i++) { + memset(&bind_params[i], 0, sizeof(MYSQL_BIND)); + lengths[i] = parameters[i].size(); + bind_params[i].buffer_type = MYSQL_TYPE_STRING; + bind_params[i].buffer = (void *)parameters[i].c_str(); + bind_params[i].buffer_length = lengths[i]; + bind_params[i].length = &lengths[i]; + bind_params[i].is_null = 0; + } + + if (!parameters.empty() && + mysql_stmt_bind_param(stmt, bind_params.data())) { + std::string err = std::string(mysql_error(conn)); + mysql_stmt_close(stmt); + throw std::runtime_error( + "mysql_stmt_bind_param() has failed. Error: " + err); + } + + if (mysql_stmt_execute(stmt)) { + std::string err = std::string(mysql_error(conn)); + mysql_stmt_close(stmt); + throw std::runtime_error( + "Prepared query execution has failed. Error: " + err); + } + + // metadata + MYSQL_RES *meta = mysql_stmt_result_metadata(stmt); + if (!meta) { + mysql_stmt_close(stmt); + return {}; + } + + int num_fields = mysql_num_fields(meta); + MYSQL_FIELD *fields = mysql_fetch_fields(meta); + + // bind output + std::vector<MYSQL_BIND> bind_res(num_fields); + std::vector<std::string> bufs(num_fields); + std::vector<unsigned long> lengths_out(num_fields); + std::vector<my_bool> is_null(num_fields); + + for (int i = 0; i < num_fields; i++) { + bufs[i].resize(1024); + memset(&bind_res[i], 0, sizeof(MYSQL_BIND)); + bind_res[i].buffer_type = MYSQL_TYPE_STRING; + bind_res[i].buffer = bufs[i].data(); + bind_res[i].buffer_length = bufs[i].size(); + bind_res[i].length = &lengths_out[i]; + bind_res[i].is_null = &is_null[i]; + } + + if (mysql_stmt_bind_result(stmt, bind_res.data())) { + std::string err = std::string(mysql_error(conn)); + mysql_free_result(meta); + mysql_stmt_close(stmt); + throw std::runtime_error( + "mysql_stmt_bind_result() has failed. Error: " + err); + } + + std::vector<std::map<std::string, std::string>> rows; + + while (mysql_stmt_fetch(stmt) == 0) { + std::map<std::string, std::string> m; + + for (int i = 0; i < num_fields; i++) { + m[fields[i].name] = + bufs[i].data() == nullptr + ? "" + : std::string(bufs[i].data(), *bind_res[i].length); + } + + rows.push_back(std::move(m)); + } + + mysql_free_result(meta); + mysql_stmt_close(stmt); + + return rows; + } + + void close() override { + if (!conn) return; + + mysql_close(conn); + conn = nullptr; + } + }; +#endif + + std::unique_ptr<BaseDatabase> create_connection(const Configuration &cfg); +}
\ No newline at end of file diff --git a/bot/src/emotes.cpp b/bot/src/emotes.cpp index dbae407..5b1e122 100644 --- a/bot/src/emotes.cpp +++ b/bot/src/emotes.cpp @@ -4,13 +4,13 @@ #include <chrono> #include <exception> #include <map> +#include <memory> #include <optional> -#include <pqxx/pqxx> #include <string> #include <thread> #include <vector> -#include "config.hpp" +#include "database.hpp" #include "logger.hpp" #include "schemas/stream.hpp" #include "utils/string.hpp" @@ -61,48 +61,43 @@ namespace bot::emotes { return; } - pqxx::connection conn(GET_DATABASE_CONNECTION_URL(bundle.configuration)); - pqxx::work work(conn); + std::unique_ptr<db::BaseDatabase> conn = + db::create_connection(bundle.configuration); - pqxx::result events = work.exec_params( - "SELECT e.id, e.message, array_to_json(e.flags) AS " - "flags, c.alias_name AS channel_aname, c.alias_id AS channel_aid FROM " + db::DatabaseRows events = conn->exec( + "SELECT e.id, e.message, is_massping, c.alias_name AS channel_aname, " + "c.alias_id AS channel_aid FROM " "events e " "INNER JOIN channels c ON c.id = e.channel_id " "WHERE e.event_type = $1 AND e.name = $2", - pqxx::params{static_cast<int>(event_type), c_name}); + {std::to_string(static_cast<int>(event_type)), c_name}); - for (const auto &event : events) { + for (const db::DatabaseRow &event : events) { std::vector<std::string> names; - bool massping_enabled = false; - if (!event[2].is_null()) { - nlohmann::json j = nlohmann::json::parse(event[2].as<std::string>()); - massping_enabled = std::any_of(j.begin(), j.end(), [](const auto &x) { - return static_cast<int>(x) == static_cast<int>(schemas::MASSPING); - }); - } + bool massping_enabled = std::stoi(event.at("is_massping")); if (massping_enabled) { auto chatters = bundle.helix_client.get_chatters( - event[4].as<int>(), bundle.irc_client.get_bot_id()); + std::stoi(event.at("channel_aid")), bundle.irc_client.get_bot_id()); std::for_each(chatters.begin(), chatters.end(), [&names](const auto &x) { names.push_back(x.login); }); } else { - pqxx::result subs = work.exec_params( + db::DatabaseRows subs = conn->exec( "SELECT u.alias_name FROM users u " "INNER JOIN events e ON e.id = $1 " "INNER JOIN event_subscriptions es ON es.event_id = e.id " "WHERE u.id = es.user_id", - pqxx::params{event[0].as<int>()}); + {event.at("id")}); - std::for_each(subs.begin(), subs.end(), [&names](const pqxx::row &x) { - names.push_back(x[0].as<std::string>()); - }); + std::for_each(subs.begin(), subs.end(), + [&names](const db::DatabaseRow &x) { + names.push_back(x.at("alias_name")); + }); } - std::string base = prefix + " " + event[1].as<std::string>(); + std::string base = prefix + " " + event.at("message"); if (!names.empty()) { base.append(" · "); } @@ -126,26 +121,23 @@ namespace bot::emotes { utils::string::separate_by_length(base, names, "@", " ", 500); for (const auto &msg : msgs) { - bundle.irc_client.say(event[3].as<std::string>(), base + msg); + bundle.irc_client.say(event.at("channel_aname"), base + msg); } } - - work.commit(); - conn.close(); } void check_seventv_emotesets(const EmoteEventBundle *bundle, - pqxx::work &work) { - pqxx::result events = work.exec( + std::unique_ptr<db::BaseDatabase> &conn) { + db::DatabaseRows events = conn->exec( "SELECT name FROM events WHERE event_type >= 10 AND event_type <= " "12 GROUP BY name"); auto &ids = bundle->stv_ws_client.get_ids(); std::vector<std::string> names; - std::for_each(events.begin(), events.end(), [&names](const pqxx::row &r) { - names.push_back(r[0].as<std::string>()); - }); + std::for_each( + events.begin(), events.end(), + [&names](const db::DatabaseRow &r) { names.push_back(r.at("name")); }); // adding new emote sets for (const std::string &name : names) { @@ -189,8 +181,9 @@ namespace bot::emotes { } #ifdef BUILD_BETTERTTV - void check_betterttv_users(const EmoteEventBundle *bundle, pqxx::work &work) { - pqxx::result events = work.exec( + void check_betterttv_users(const EmoteEventBundle *bundle, + std::unique_ptr<db::BaseDatabase> &conn) { + db::DatabaseRows events = conn->exec( "SELECT name FROM events WHERE event_type >= 13 AND event_type <= " "15 GROUP BY name"); @@ -198,9 +191,10 @@ namespace bot::emotes { bundle->bttv_ws_client.get_ids(); std::vector<std::string> names; - std::for_each(events.begin(), events.end(), [&names](const pqxx::row &r) { - names.push_back("twitch:" + r[0].as<std::string>()); - }); + std::for_each(events.begin(), events.end(), + [&names](const db::DatabaseRow &r) { + names.push_back("twitch:" + r.at("name")); + }); // adding new users for (const std::string &name : names) { @@ -228,21 +222,20 @@ namespace bot::emotes { log::info("emotes/thread", "Started emote thread."); while (true) { - pqxx::connection conn(GET_DATABASE_CONNECTION_URL(bundle->configuration)); - pqxx::work work(conn); + std::unique_ptr<db::BaseDatabase> conn = + db::create_connection(bundle->configuration); try { - check_seventv_emotesets(bundle, work); + check_seventv_emotesets(bundle, conn); #ifdef BUILD_BETTERTTV - check_betterttv_users(bundle, work); + check_betterttv_users(bundle, conn); #endif } catch (std::exception ex) { log::error("emotes/thread", "Error occurred in emote thread: " + std::string(ex.what())); } - work.commit(); - conn.close(); + conn->close(); std::this_thread::sleep_for(std::chrono::seconds(30)); } diff --git a/bot/src/github.cpp b/bot/src/github.cpp index 13825cb..fb96136 100644 --- a/bot/src/github.cpp +++ b/bot/src/github.cpp @@ -3,7 +3,7 @@ #include <algorithm> #include <chrono> #include <iterator> -#include <pqxx/pqxx> +#include <memory> #include <string> #include <thread> #include <unordered_map> @@ -14,10 +14,10 @@ #include "cpr/api.h" #include "cpr/cprtypes.h" #include "cpr/response.h" +#include "database.hpp" #include "irc/client.hpp" #include "logger.hpp" #include "nlohmann/json.hpp" -#include "pqxx/internal/statement_parameters.hxx" #include "schemas/stream.hpp" #include "utils/string.hpp" @@ -60,15 +60,15 @@ namespace bot { } void GithubListener::check_for_listeners() { - pqxx::connection conn(GET_DATABASE_CONNECTION_URL(this->configuration)); - pqxx::work work(conn); + std::unique_ptr<db::BaseDatabase> conn = + db::create_connection(this->configuration); - pqxx::result repos = - work.exec("SELECT name FROM events WHERE event_type = 40"); + db::DatabaseRows repos = + conn->exec("SELECT name FROM events WHERE event_type = 40"); // Adding new repos for (const auto &repo : repos) { - std::string id = repo[0].as<std::string>(); + std::string id = repo.at("name"); if (std::any_of(this->ids.begin(), this->ids.end(), [&id](const auto &x) { return x == id; })) continue; @@ -81,9 +81,9 @@ namespace bot { std::vector<std::string> names_to_delete; for (const std::string &id : this->ids) { - if (std::any_of(repos.begin(), repos.end(), [&id](const pqxx::row &x) { - return x[0].as<std::string>() == id; - })) + if (std::any_of( + repos.begin(), repos.end(), + [&id](const db::DatabaseRow &x) { return x.at("name") == id; })) continue; names_to_delete.push_back(id); @@ -94,16 +94,10 @@ namespace bot { this->ids.erase(id_pos); this->commits.erase(name); } - - work.commit(); - conn.close(); } std::unordered_map<std::string, std::vector<Commit>> GithubListener::check_new_commits() { - pqxx::connection conn(GET_DATABASE_CONNECTION_URL(this->configuration)); - pqxx::work work(conn); - std::unordered_map<std::string, std::vector<Commit>> new_commits; for (const std::string &id : this->ids) { @@ -144,60 +138,53 @@ namespace bot { std::this_thread::sleep_for(std::chrono::seconds(2)); } - work.commit(); - conn.close(); - return new_commits; } void GithubListener::notify_about_commits( const std::unordered_map<std::string, std::vector<Commit>> &new_commits) { - pqxx::connection conn(GET_DATABASE_CONNECTION_URL(this->configuration)); - pqxx::work work(conn); + std::unique_ptr<db::BaseDatabase> conn = + db::create_connection(this->configuration); for (const auto &pair : new_commits) { // don't notify on startup if (this->commits.at(pair.first).size() == 0) continue; - pqxx::result events = work.exec_params( - "SELECT e.id, e.message, array_to_json(e.flags) AS flags, " + db::DatabaseRows events = conn->exec( + "SELECT e.id, e.message, is_massping, " "c.alias_name AS " "channel_name, c.alias_id AS channel_aid " "FROM events e " "INNER JOIN channels c ON c.id = e.channel_id " "WHERE e.name = $1 AND e.event_type = 40", - pqxx::params{pair.first}); + {pair.first}); for (const auto &event : events) { std::vector<std::string> names; - bool massping_enabled = false; - if (!event[2].is_null()) { - nlohmann::json j = nlohmann::json::parse(event[2].as<std::string>()); - massping_enabled = std::any_of(j.begin(), j.end(), [](const auto &x) { - return static_cast<int>(x) == static_cast<int>(schemas::MASSPING); - }); - } + bool massping_enabled = std::stoi(event.at("is_massping")); if (massping_enabled) { auto chatters = this->helix_client.get_chatters( - event[4].as<int>(), this->irc_client.get_bot_id()); + std::stoi(event.at("channel_aid")), + this->irc_client.get_bot_id()); std::for_each(chatters.begin(), chatters.end(), [&names](const auto &u) { names.push_back(u.login); }); } else { - pqxx::result subs = work.exec_params( + db::DatabaseRows subs = conn->exec( "SELECT u.alias_name FROM users u INNER JOIN event_subscriptions " "es ON es.user_id = u.id WHERE es.event_id = $1", - pqxx::params{event[0].as<int>()}); + {event.at("id")}); - std::for_each(subs.begin(), subs.end(), [&names](const pqxx::row &x) { - names.push_back(x[0].as<std::string>()); - }); + std::for_each(subs.begin(), subs.end(), + [&names](const db::DatabaseRow &x) { + names.push_back(x.at("alias_name")); + }); } for (const Commit &commit : pair.second) { - std::string message = event[1].as<std::string>(); + std::string message = event.at("message"); message = "🧑💻 " + message; // Replacing SHA placeholder @@ -223,14 +210,13 @@ namespace bot { std::for_each(parts.begin(), parts.end(), [&message, &event, this](const std::string &part) { - this->irc_client.say(event[3].as<std::string>(), + this->irc_client.say(event.at("channel_name"), message + part); }); } } } - work.commit(); - conn.close(); + conn->close(); } }
\ No newline at end of file diff --git a/bot/src/handlers.cpp b/bot/src/handlers.cpp index f8b911d..c7bdc36 100644 --- a/bot/src/handlers.cpp +++ b/bot/src/handlers.cpp @@ -2,8 +2,8 @@ #include <algorithm> #include <exception> +#include <memory> #include <optional> -#include <pqxx/pqxx> #include <random> #include <string> #include <vector> @@ -11,12 +11,12 @@ #include "bundle.hpp" #include "commands/command.hpp" #include "commands/request.hpp" -#include "commands/request_util.hpp" #include "commands/response_error.hpp" #include "constants.hpp" #include "cpr/api.h" #include "cpr/multipart.h" #include "cpr/response.h" +#include "database.hpp" #include "irc/message.hpp" #include "localization/line_id.hpp" #include "logger.hpp" @@ -27,14 +27,9 @@ namespace bot::handlers { void handle_private_message( const InstanceBundle &bundle, command::CommandLoader &command_loader, - const irc::Message<irc::MessageType::Privmsg> &message, - pqxx::connection &conn) { - if (utils::string::string_contains_sql_injection(message.message)) { - log::warn("PrivateMessageHandler", - "Received the message in #" + message.source.login + - " with SQL injection: " + message.message); - return; - } + const irc::Message<irc::MessageType::Privmsg> &message) { + std::unique_ptr<db::BaseDatabase> conn = + db::create_connection(bundle.configuration); std::optional<command::Request> request = command::generate_request(command_loader, message, conn); @@ -60,28 +55,27 @@ namespace bot::handlers { } } - pqxx::work work(conn); - pqxx::result channels = - work.exec("SELECT * FROM channels WHERE alias_id = " + - std::to_string(message.source.id)); + db::DatabaseRows channels = + conn->exec("SELECT * FROM channels WHERE alias_id = $1", + {std::to_string(message.source.id)}); if (!channels.empty()) { schemas::Channel channel(channels[0]); - pqxx::result channel_preferences = work.exec( + db::DatabaseRows channel_preferences = conn->exec( "SELECT * FROM channel_preferences WHERE " - "channel_id = " + - std::to_string(channel.get_id())); + "id = $1", + {std::to_string(channel.get_id())}); schemas::ChannelPreferences preference(channel_preferences[0]); - pqxx::result cmds = - work.exec("SELECT message FROM custom_commands WHERE name = '" + - message.message + "' AND channel_id = '" + - std::to_string(channel.get_id()) + "'"); + db::DatabaseRows cmds = conn->exec( + "SELECT message FROM custom_commands WHERE name = $1 AND channel_id " + "= $2", + {message.message, std::to_string(channel.get_id())}); if (!cmds.empty()) { - std::string msg = cmds[0][0].as<std::string>(); + std::string msg = cmds[0].at("message"); bundle.irc_client.say(message.source.login, msg); } else { @@ -89,7 +83,7 @@ namespace bot::handlers { } } - work.commit(); + conn->close(); } void make_markov_response( diff --git a/bot/src/handlers.hpp b/bot/src/handlers.hpp index 2046bbe..f136fa1 100644 --- a/bot/src/handlers.hpp +++ b/bot/src/handlers.hpp @@ -3,13 +3,11 @@ #include "bundle.hpp" #include "commands/command.hpp" #include "irc/message.hpp" -#include "pqxx/pqxx" namespace bot::handlers { void handle_private_message( const InstanceBundle &bundle, command::CommandLoader &command_loader, - const irc::Message<irc::MessageType::Privmsg> &message, - pqxx::connection &conn); + const irc::Message<irc::MessageType::Privmsg> &message); void make_markov_response( const InstanceBundle &bundle, diff --git a/bot/src/main.cpp b/bot/src/main.cpp index fe0d3d1..a9ed193 100644 --- a/bot/src/main.cpp +++ b/bot/src/main.cpp @@ -1,7 +1,7 @@ #include <emotespp/seventv.hpp> +#include <map> #include <memory> #include <optional> -#include <pqxx/pqxx> #include <sol/state.hpp> #include <string> #include <thread> @@ -14,6 +14,7 @@ #include "commands/lua.hpp" #include "commands/response.hpp" #include "config.hpp" +#include "database.hpp" #include "emotes.hpp" #include "github.hpp" #include "handlers.hpp" @@ -74,48 +75,41 @@ int main(int argc, char *argv[]) { client.join(client.get_bot_username()); - pqxx::connection conn(GET_DATABASE_CONNECTION_URL(cfg)); - pqxx::work *work = new pqxx::work(conn); + std::unique_ptr<bot::db::BaseDatabase> conn = bot::db::create_connection(cfg); - pqxx::result rows = work->exec( - "SELECT alias_id FROM channels WHERE opted_out_at is null AND alias_id " - "!= " + - std::to_string(client.get_bot_id())); + bot::db::DatabaseRows rows = conn->exec( + "SELECT alias_id FROM channels WHERE opted_out_At IS NULL AND alias_id " + "!= " + "$1", + {std::to_string(client.get_bot_id())}); std::vector<int> ids; - for (const auto &row : rows) { - ids.push_back(row[0].as<int>()); + for (const bot::db::DatabaseRow &row : rows) { + ids.push_back(std::stoi(row.at("alias_id"))); } auto helix_channels = helix_client.get_users(ids); // it could be optimized for (const auto &helix_channel : helix_channels) { - auto channel = - work->exec("SELECT id, alias_name FROM channels WHERE alias_id = " + - std::to_string(helix_channel.id)); + std::vector<std::map<std::string, std::string>> channels = + conn->exec("SELECT id, alias_name FROM channels WHERE alias_id = $1", + {std::to_string(helix_channel.id)}); - if (!channel.empty()) { - std::string name = channel[0][1].as<std::string>(); + if (!channels.empty()) { + std::string name = channels[0]["alias_name"]; if (name != helix_channel.login) { - work->exec("UPDATE channels SET alias_name = '" + helix_channel.login + - "' WHERE id = " + std::to_string(channel[0][0].as<int>())); - work->commit(); - - delete work; - work = new pqxx::work(conn); + conn->exec("UPDATE channels SET alias_name = $1 WHERE id = $2", + {helix_channel.login, channels[0][0]}); } client.join(helix_channel.login); } } - work->commit(); - delete work; - - conn.close(); + conn->close(); bot::stream::StreamListenerClient stream_listener_client( helix_client, kick_api_client, client, cfg); @@ -206,12 +200,7 @@ int main(int argc, char *argv[]) { bot::InstanceBundle bundle{client, helix_client, kick_api_client, localization, cfg, command_loader}; - pqxx::connection conn(GET_DATABASE_CONNECTION_URL(cfg)); - - bot::handlers::handle_private_message(bundle, command_loader, message, - conn); - - conn.close(); + bot::handlers::handle_private_message(bundle, command_loader, message); }); client.run(); diff --git a/bot/src/schemas/channel.hpp b/bot/src/schemas/channel.hpp index d2f13eb..a8979ec 100644 --- a/bot/src/schemas/channel.hpp +++ b/bot/src/schemas/channel.hpp @@ -1,30 +1,30 @@ #pragma once +#include <algorithm> #include <chrono> #include <optional> -#include <pqxx/pqxx> #include <sol/sol.hpp> #include <string> #include <vector> #include "../constants.hpp" #include "../utils/chrono.hpp" -#include "../utils/string.hpp" +#include "database.hpp" namespace bot::schemas { class Channel { public: - Channel(const pqxx::row &row) { - this->id = row[0].as<int>(); - this->alias_id = row[1].as<int>(); - this->alias_name = row[2].as<std::string>(); + Channel(const db::DatabaseRow &row) { + this->id = std::stoi(row.at("id")); + this->alias_id = std::stoi(row.at("alias_id")); + this->alias_name = row.at("alias_name"); this->joined_at = - utils::chrono::string_to_time_point(row[3].as<std::string>()); + utils::chrono::string_to_time_point(row.at("joined_at")); - if (!row[4].is_null()) { + if (!row.at("opted_out_at").empty()) { this->opted_out_at = - utils::chrono::string_to_time_point(row[4].as<std::string>()); + utils::chrono::string_to_time_point(row.at("opted_out_at")); } } @@ -51,6 +51,8 @@ namespace bot::schemas { }; enum ChannelFeature { MARKOV_RESPONSES, RANDOM_MARKOV_RESPONSES }; + const std::vector<ChannelFeature> FEATURES = {MARKOV_RESPONSES, + RANDOM_MARKOV_RESPONSES}; std::optional<ChannelFeature> string_to_channel_feature( const std::string &value); std::optional<std::string> channelfeature_to_string( @@ -58,32 +60,22 @@ namespace bot::schemas { class ChannelPreferences { public: - ChannelPreferences(const pqxx::row &row) { - this->channel_id = row[0].as<int>(); - - if (!row[1].is_null()) { - this->prefix = row[1].as<std::string>(); - } else { - this->prefix = DEFAULT_PREFIX; - } - - if (!row[2].is_null()) { - this->locale = row[2].as<std::string>(); - } else { - this->locale = DEFAULT_LOCALE_ID; - } - - if (!row[3].is_null()) { - std::string x = row[3].as<std::string>(); - x = x.substr(1, x.length() - 2); - std::vector<std::string> split_text = - utils::string::split_text(x, ','); - - for (const std::string &part : split_text) { - ChannelFeature feature = (ChannelFeature)std::stoi(part); - this->features.push_back(feature); - } - } + ChannelPreferences(const db::DatabaseRow &row) { + this->channel_id = std::stoi(row.at("id")); + this->prefix = + row.at("prefix").empty() ? DEFAULT_PREFIX : row.at("prefix"); + this->locale = + row.at("locale").empty() ? DEFAULT_LOCALE_ID : row.at("locale"); + + std::for_each( + FEATURES.begin(), FEATURES.end(), + [this, &row](const ChannelFeature &f) { + std::optional<std::string> feature = channelfeature_to_string(f); + if (feature.has_value() && row.find(*feature) != row.end() && + row.at(*feature) == "1") { + this->features.push_back(f); + } + }); } ~ChannelPreferences() = default; diff --git a/bot/src/schemas/user.hpp b/bot/src/schemas/user.hpp index e9d7e0f..ee9bd10 100644 --- a/bot/src/schemas/user.hpp +++ b/bot/src/schemas/user.hpp @@ -2,26 +2,26 @@ #include <chrono> #include <optional> -#include <pqxx/pqxx> #include <sol/sol.hpp> #include <string> #include "../utils/chrono.hpp" +#include "database.hpp" namespace bot::schemas { class User { public: - User(const pqxx::row &row) { - this->id = row[0].as<int>(); - this->alias_id = row[1].as<int>(); - this->alias_name = row[2].as<std::string>(); + User(const db::DatabaseRow &row) { + this->id = std::stoi(row.at("id")); + this->alias_id = std::stoi(row.at("alias_id")); + this->alias_name = row.at("alias_name"); this->joined_at = - utils::chrono::string_to_time_point(row[3].as<std::string>()); + utils::chrono::string_to_time_point(row.at("joined_at")); - if (!row[4].is_null()) { + if (!row.at("opted_out_at").empty()) { this->opted_out_at = - utils::chrono::string_to_time_point(row[4].as<std::string>()); + utils::chrono::string_to_time_point(row.at("opted_out_at")); } } @@ -54,11 +54,11 @@ namespace bot::schemas { class UserRights { public: - UserRights(const pqxx::row &row) { - this->id = row[0].as<int>(); - this->user_id = row[1].as<int>(); - this->channel_id = row[2].as<int>(); - this->level = static_cast<PermissionLevel>(row[3].as<int>()); + UserRights(const db::DatabaseRow &row) { + this->id = std::stoi(row.at("id")); + this->user_id = std::stoi(row.at("user_id")); + this->channel_id = std::stoi(row.at("channel_id")); + this->level = static_cast<PermissionLevel>(std::stoi(row.at("level"))); } ~UserRights() = default; diff --git a/bot/src/stream.cpp b/bot/src/stream.cpp index 107b883..ec10b71 100644 --- a/bot/src/stream.cpp +++ b/bot/src/stream.cpp @@ -2,14 +2,14 @@ #include <algorithm> #include <chrono> -#include <pqxx/pqxx> +#include <memory> #include <string> #include <thread> #include <vector> #include "api/kick.hpp" #include "api/twitch/schemas/stream.hpp" -#include "config.hpp" +#include "database.hpp" #include "logger.hpp" #include "nlohmann/json.hpp" #include "schemas/stream.hpp" @@ -209,48 +209,44 @@ namespace bot::stream { void StreamListenerClient::handler(const schemas::EventType &type, const api::twitch::schemas::Stream &stream, const StreamerData &data) { - pqxx::connection conn(GET_DATABASE_CONNECTION_URL(this->configuration)); - pqxx::work work(conn); + std::unique_ptr<db::BaseDatabase> conn = + db::create_connection(this->configuration); - pqxx::result events = work.exec_params( - "SELECT e.id, e.message, array_to_json(e.flags) AS " - "flags, c.alias_name AS channel_aname, c.alias_id AS channel_aid FROM " + db::DatabaseRows events = conn->exec( + "SELECT e.id, e.message, is_massping, c.alias_name AS channel_aname, " + "c.alias_id AS channel_aid FROM " "events e " "INNER JOIN channels c ON c.id = e.channel_id " "WHERE e.event_type = $1 AND e.name = $2", - pqxx::params{static_cast<int>(type), stream.get_user_id()}); + {std::to_string(static_cast<int>(type)), + std::to_string(stream.get_user_id())}); for (const auto &event : events) { std::vector<std::string> names; - bool massping_enabled = false; - if (!event[2].is_null()) { - nlohmann::json j = nlohmann::json::parse(event[2].as<std::string>()); - massping_enabled = std::any_of(j.begin(), j.end(), [](const auto &x) { - return static_cast<int>(x) == static_cast<int>(schemas::MASSPING); - }); - } + bool massping_enabled = std::stoi(event.at("is_massping")); if (massping_enabled) { auto chatters = this->helix_client.get_chatters( - event[4].as<int>(), this->irc_client.get_bot_id()); + std::stoi(event.at("channel_aid")), this->irc_client.get_bot_id()); std::for_each(chatters.begin(), chatters.end(), [&names](const auto &x) { names.push_back(x.login); }); } else { - pqxx::result subs = work.exec_params( + db::DatabaseRows subs = conn->exec( "SELECT u.alias_name FROM users u " "INNER JOIN events e ON e.id = $1 " "INNER JOIN event_subscriptions es ON es.event_id = e.id " "WHERE u.id = es.user_id", - pqxx::params{event[0].as<int>()}); + {event.at("id")}); - std::for_each(subs.begin(), subs.end(), [&names](const pqxx::row &x) { - names.push_back(x[0].as<std::string>()); - }); + std::for_each(subs.begin(), subs.end(), + [&names](const db::DatabaseRow &x) { + names.push_back(x.at("alias_name")); + }); } - std::string base = "⚡ " + event[1].as<std::string>(); + std::string base = "⚡ " + event.at("message"); if (!names.empty()) { base.append(" · "); } @@ -279,24 +275,23 @@ namespace bot::stream { utils::string::separate_by_length(base, names, "@", " ", 500); for (const auto &msg : msgs) { - this->irc_client.say(event[3].as<std::string>(), base + msg); + this->irc_client.say(event.at("channel_aname"), base + msg); } } - work.commit(); - conn.close(); + conn->close(); } void StreamListenerClient::update_channel_ids() { - pqxx::connection conn(GET_DATABASE_CONNECTION_URL(this->configuration)); - pqxx::work work(conn); + std::unique_ptr<db::BaseDatabase> conn = + db::create_connection(this->configuration); - pqxx::result ids = - work.exec("SELECT name, event_type FROM events WHERE event_type < 10"); + db::DatabaseRows ids = + conn->exec("SELECT name, event_type FROM events WHERE event_type < 10"); for (const auto &row : ids) { - int id = row[0].as<int>(); - int event_type = row[1].as<int>(); + int id = std::stoi(row.at("name")); + int event_type = std::stoi(row.at("event_type")); StreamerType type = (event_type >= schemas::EventType::KICK_LIVE && event_type <= schemas::EventType::KICK_GAME) @@ -313,7 +308,6 @@ namespace bot::stream { listen_channel(id, type); } - work.commit(); - conn.close(); + conn->close(); } } diff --git a/bot/src/timer.cpp b/bot/src/timer.cpp index 055dde0..e4f3508 100644 --- a/bot/src/timer.cpp +++ b/bot/src/timer.cpp @@ -1,11 +1,11 @@ #include "timer.hpp" #include <chrono> -#include <pqxx/pqxx> #include <string> #include <thread> #include "config.hpp" +#include "database.hpp" #include "irc/client.hpp" #include "utils/chrono.hpp" @@ -13,22 +13,22 @@ namespace bot { void create_timer_thread(irc::Client *irc_client, Configuration *configuration) { while (true) { - pqxx::connection conn(GET_DATABASE_CONNECTION_URL_POINTER(configuration)); - pqxx::work *work = new pqxx::work(conn); + std::unique_ptr<db::BaseDatabase> conn = + db::create_connection(*configuration); - pqxx::result timers = work->exec( + db::DatabaseRows timers = conn->exec( "SELECT id, interval_sec, message, channel_id, last_executed_at FROM " "timers"); for (const auto &timer : timers) { - int id = timer[0].as<int>(); - int interval_sec = timer[1].as<int>(); - std::string message = timer[2].as<std::string>(); - int channel_id = timer[3].as<int>(); + int id = std::stoi(timer.at("id")); + int interval_sec = std::stoi(timer.at("interval_sec")); + std::string message = timer.at("message"); + int channel_id = std::stoi(timer.at("channel_id")); // it could be done in sql query std::chrono::system_clock::time_point last_executed_at = - utils::chrono::string_to_time_point(timer[4].as<std::string>()); + utils::chrono::string_to_time_point(timer.at("last_executed_at")); auto now = std::chrono::system_clock::now(); auto now_time_it = std::chrono::system_clock::to_time_t(now); auto now_tm = std::gmtime(&now_time_it); @@ -38,31 +38,25 @@ namespace bot { now - last_executed_at); if (difference.count() > interval_sec) { - pqxx::result channels = work->exec( - "SELECT alias_name, opted_out_at FROM channels WHERE id = " + - std::to_string(channel_id)); + db::DatabaseRows channels = conn->exec( + "SELECT alias_name, opted_out_at FROM channels WHERE id = $1", + {std::to_string(channel_id)}); - if (!channels.empty() && channels[0][1].is_null()) { - std::string alias_name = channels[0][0].as<std::string>(); + if (!channels.empty() && channels[0].at("opted_out_at").empty()) { + std::string alias_name = channels[0].at("alias_name"); irc_client->say(alias_name, message); } - work->exec( - "UPDATE timers SET last_executed_at = timezone('utc', now()) " + conn->exec( + "UPDATE timers SET last_executed_at = UTC_TIMESTAMP " "WHERE " - "id = " + - std::to_string(id)); - - work->commit(); - - delete work; - work = new pqxx::work(conn); + "id = $1", + {std::to_string(id)}); } } - delete work; - conn.close(); + conn->close(); std::this_thread::sleep_for(std::chrono::seconds(1)); } |
