diff options
| author | ilotterytea <iltsu@alright.party> | 2024-12-09 14:12:52 +0500 |
|---|---|---|
| committer | ilotterytea <iltsu@alright.party> | 2024-12-09 14:12:52 +0500 |
| commit | 69eb8e57770240637158df5b5578084f61366c5f (patch) | |
| tree | b78afe5c433c40adcf6081a52f53c0d44f9135d4 /bot | |
| parent | 42cbe6ab4f4e7018786bfab59f54fb8422bcaa87 (diff) | |
feat: markov responses
Diffstat (limited to 'bot')
| -rw-r--r-- | bot/src/handlers.cpp | 63 | ||||
| -rw-r--r-- | bot/src/handlers.hpp | 6 |
2 files changed, 66 insertions, 3 deletions
diff --git a/bot/src/handlers.cpp b/bot/src/handlers.cpp index e764652..38ff28b 100644 --- a/bot/src/handlers.cpp +++ b/bot/src/handlers.cpp @@ -1,5 +1,6 @@ #include "handlers.hpp" +#include <algorithm> #include <exception> #include <optional> #include <pqxx/pqxx> @@ -10,9 +11,14 @@ #include "commands/command.hpp" #include "commands/request.hpp" #include "commands/request_util.hpp" +#include "cpr/api.h" +#include "cpr/multipart.h" +#include "cpr/response.h" #include "irc/message.hpp" #include "localization/line_id.hpp" #include "logger.hpp" +#include "nlohmann/json.hpp" +#include "schemas/channel.hpp" #include "utils/string.hpp" namespace bot::handlers { @@ -57,23 +63,74 @@ namespace bot::handlers { pqxx::work work(conn); pqxx::result channels = - work.exec("SELECT id FROM channels WHERE alias_id = " + + work.exec("SELECT * FROM channels WHERE alias_id = " + std::to_string(message.source.id)); if (!channels.empty()) { - int channel_id = channels[0][0].as<int>(); + schemas::Channel channel(channels[0]); + + pqxx::result channel_preferences = work.exec( + "SELECT * FROM channel_preferences WHERE " + "channel_id = " + + std::to_string(channel.get_id())); + + schemas::ChannelPreferences preference(channel_preferences[0]); + pqxx::result cmds = work.exec("SELECT message FROM custom_commands WHERE name = '" + message.message + "' AND channel_id = '" + - std::to_string(channel_id) + "'"); + std::to_string(channel.get_id()) + "'"); if (!cmds.empty()) { std::string msg = cmds[0][0].as<std::string>(); bundle.irc_client.say(message.source.login, msg); + } else { + make_markov_response(bundle, message, channel, preference); } } work.commit(); } + + void make_markov_response( + const InstanceBundle &bundle, + const irc::Message<irc::MessageType::Privmsg> &message, + const schemas::Channel &channel, + const schemas::ChannelPreferences &preference) { + bool is_markov_responses_enabled = + std::any_of(preference.get_features().begin(), + preference.get_features().end(), [](const int &x) { + return (schemas::ChannelFeature)x == + schemas::ChannelFeature::MARKOV_RESPONSES; + }); + + if (!is_markov_responses_enabled) return; + + std::string prefix = "@" + bundle.irc_client.get_bot_username() + ","; + + if (message.message.substr(0, prefix.length()) != prefix) return; + + cpr::Response response = cpr::Post( + cpr::Url{"https://markov.ilotterytea.kz/api/v1/generate"}, + cpr::Multipart{ + {"question", + message.message.substr(prefix.length(), message.message.length())}, + {"max_length", 200}}); + + if (response.status_code != 200) return; + + nlohmann::json j = nlohmann::json::parse(response.text); + + std::string answer; + auto answer_field = j["data"]["answer"]; + + if (answer_field.is_null()) + answer = "..."; + else + answer = answer_field; + + bundle.irc_client.say(message.source.login, + message.sender.login + ": " + answer); + } } diff --git a/bot/src/handlers.hpp b/bot/src/handlers.hpp index a143f76..a78a529 100644 --- a/bot/src/handlers.hpp +++ b/bot/src/handlers.hpp @@ -11,4 +11,10 @@ namespace bot::handlers { const command::CommandLoader &command_loader, const irc::Message<irc::MessageType::Privmsg> &message, pqxx::connection &conn); + + void make_markov_response( + const InstanceBundle &bundle, + const irc::Message<irc::MessageType::Privmsg> &message, + const schemas::Channel &channel, + const schemas::ChannelPreferences &preference); } |
