diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..24fb858 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +build +run.sh +.vscode +.cache \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 5d2e07c..69f2916 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,10 +6,9 @@ include(FetchContent) file(GLOB_RECURSE SRC_FILES "src/*.cpp") -set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD 23) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -find_package(CURL) set(SKIP_BUILD_TEST ON) set(CPR_USE_SYSTEM_CURL ON) @@ -22,33 +21,22 @@ target_include_directories(${PROJECT_NAME} PRIVATE include) FetchContent_Declare( cpr GIT_REPOSITORY https://github.com/libcpr/cpr.git - GIT_TAG 1.11.2 -) -FetchContent_Declare( - json - GIT_REPOSITORY https://github.com/nlohmann/json.git - GIT_TAG v3.11.3 + GIT_TAG 1.12.0 ) FetchContent_Declare( tgbot GIT_REPOSITORY https://github.com/reo7sp/tgbot-cpp.git - GIT_TAG v1.9 -) -FetchContent_Declare( - fmt - GIT_REPOSITORY https://github.com/fmtlib/fmt.git - GIT_TAG 11.1.4 + GIT_TAG v1.9.1 ) FetchContent_Declare( libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git - GIT_TAG 7.10.0 + GIT_TAG 7.10.1 ) FetchContent_Declare( pugixml GIT_REPOSITORY https://github.com/zeux/pugixml.git GIT_TAG v1.15 ) -FetchContent_MakeAvailable(cpr json tgbot fmt libpqxx pugixml) - -target_link_libraries(${PROJECT_NAME} cpr nlohmann_json TgBot ${OPENSSL_LIBRARIES} fmt::fmt pugixml pqxx ${CURL_LIBRARIES}) \ No newline at end of file +FetchContent_MakeAvailable(cpr tgbot libpqxx pugixml) +target_link_libraries(${PROJECT_NAME} cpr TgBot pugixml pqxx) \ No newline at end of file diff --git a/include/bot.hpp b/include/bot.hpp index 570bcfd..9c6b310 100644 --- a/include/bot.hpp +++ b/include/bot.hpp @@ -1,27 +1,18 @@ #pragma once -#include -#include -#include -#include -#include -#include #include - #include - -#include #include -#include #include #include #include +#include TgBot::InlineKeyboardMarkup buildInlineKeyboard(const std::vector& buttons); class Bot : public TgBot::Bot { public: - Bot(const std::string& token, unsigned int workers = 15); + Bot(const std::string& token); ~Bot(); void start(); @@ -39,7 +30,8 @@ class Bot : public TgBot::Bot { template void sendMessagef(int64_t id, std::string fmt, Args&&... args) { - const std::string text = fmt::vformat(fmt, fmt::vargs{{args...}}); + auto format_args = std::make_format_args(std::forward(args)...); + const std::string text = std::vformat(fmt, format_args); sendMessage(id, text); } @@ -71,10 +63,11 @@ class Bot : public TgBot::Bot { std::vector> services; - std::shared_ptr commands_pool; - std::shared_ptr thread_pool; + std::unique_ptr commands_pool; + std::unique_ptr thread_pool; + std::unique_ptr sync; - void sendContent(const send_t& send, std::int64_t user_id, std::shared_ptr service); + void sendContent(const postData& post, std::int64_t user_id, std::shared_ptr service); void mainloop(); void update_services(); void command_handler(TgBot::Message::Ptr message); diff --git a/include/commands.hpp b/include/commands.hpp index 557209d..3dc5b54 100644 --- a/include/commands.hpp +++ b/include/commands.hpp @@ -87,12 +87,12 @@ struct BaseCommand : public AbstractBaseCommand { std::string args_str = ""; for (const auto& arg : args) - args_str += fmt::format("<{}> ", arg); + args_str += std::format("<{}> ", arg); - help_str += fmt::format("\t/{} {}", name, args_str); + help_str += std::format("\t/{} {}", name, args_str); if (help.size() > 0) - help_str += fmt::format(":\t{}", help); + help_str += std::format(":\t{}", help); return help_str; } diff --git a/include/db/connectionpool.hpp b/include/db/connectionpool.hpp index a0d7064..6bc6d30 100644 --- a/include/db/connectionpool.hpp +++ b/include/db/connectionpool.hpp @@ -1,17 +1,28 @@ #pragma once +#include #include #include #include #include #include +#include class ConnectionPool { public: - ConnectionPool(const std::string &url, int count) { - for (int i = 0; i < count; i++) - connections.push(std::make_unique(url)); + ConnectionPool(const std::string &url) { + std::vector>> futures; + for (int i = 0; i < CONNECT_COUNT; ++i) { + futures.push_back(std::async(std::launch::async, [url]() { + return std::make_unique(url); + })); + } + + for (auto& f : futures) { + auto conn = f.get(); + connections.push(std::move(conn)); + } LOG_INFO("[DB] Connected to database"); } diff --git a/include/db/db.hpp b/include/db/db.hpp index cd4f774..36a634e 100644 --- a/include/db/db.hpp +++ b/include/db/db.hpp @@ -1,35 +1,28 @@ #pragma once -#include -#include -#include -#include -#include -#include #include #include #include +#include class DB { public: DB(const DB&) = delete; DB& operator=(const DB&) = delete; - bool addTag(std::shared_ptr service, std::string_view tag, std::int64_t user_id); + std::expected addTag(std::shared_ptr service, std::string_view tag, std::int64_t user_id); void addAntiTag(std::shared_ptr service, std::string_view tag, std::int64_t user_id); void rmTag(std::shared_ptr service, std::string_view tag, std::int64_t user_id); void rmAntiTag(std::shared_ptr service, std::string_view tag, std::int64_t user_id); - void addHistory(std::shared_ptr service, int tag_id, const post_data_tv &newhistory, std::int64_t user_id); + void addHistory(const post_data_tv &newhistory, std::int64_t user_id); bool isUserTableEmpty(); bool userExist(std::int64_t tg_id, bool is_admin = false); - std::unordered_map> getUsersByTags(std::shared_ptr service); - std::vector getHistory(std::int64_t user_id, std::shared_ptr service); - std::vector getAntiTagForUserAndSite(std::int64_t user_id, std::shared_ptr service); + post_data_tv getNewPostForUser(std::shared_ptr service, std::int64_t tg_id); + std::vector getUserForTags(std::shared_ptr service); std::string getFormattedTagsAndAntiTags(std::int64_t user_id); void addUser(std::int64_t tg_id, bool is_admin = false); void rmUser(std::int64_t tg_id); - void scoreUpdate(std::shared_ptr service, std::int64_t user_id, std::int64_t score); - std::int64_t getScore(std::shared_ptr service, std::int64_t user_id); + bool userIsAdmin(std::int64_t tg_id); void addSystemValue(std::string_view key, std::string_view value); std::string getSystemValue(std::string_view key); static DB& getInstance(const std::string& connStr = "") { @@ -37,15 +30,21 @@ class DB { return instance; } - std::pair> getCacheTagById(std::string_view site, int tag_id); - std::pair> getCacheTagByName(std::string_view site, std::string_view name); - void addCacheTag(const taginfo_t& info); + std::optional getCacheTagById(int service_id, int tag_id); + std::pair> getCacheTagByName(int service_id, std::string_view name); + void addCacheTag(const Taginfo& info); std::optional getUserTagId(std::int64_t user_id, std::int64_t cache_tag_id); - std::optional getUserTagIdByExternalTagId(int external_tag_id); - std::optional getTagFromCacheOrAddTag(std::shared_ptr service, int tag_id); + std::optional getTagFromCacheOrAddTag(std::shared_ptr service, int tag_id); + std::optional getIdServiceByName(std::string_view name); + std::string getServiceName(int service_id); + + bool addPost(std::shared_ptr service, const post_data_tv& posts); + bool addTagOnPost(std::shared_ptr service, const std::vector& tags, int post_id); private: DB(const std::string &url_db); std::shared_ptr connection_pool; - std::pair> getAndCreate(std::shared_ptr service, std::string_view tag); + + std::optional tagInfoFromSite(std::shared_ptr service, const std::string& tag); + post_data_tv getPostsByTagId(std::shared_ptr service, int tag_id); }; \ No newline at end of file diff --git a/include/db/transaction.hpp b/include/db/transaction.hpp index e7410d3..a071ed3 100644 --- a/include/db/transaction.hpp +++ b/include/db/transaction.hpp @@ -18,7 +18,7 @@ class Transaction { void commit() { txn.commit(); } - + pqxx::work& get() { return txn; } diff --git a/include/keyboard.hpp b/include/keyboard.hpp index de3c63a..a3a9a7c 100644 --- a/include/keyboard.hpp +++ b/include/keyboard.hpp @@ -1,20 +1,14 @@ #pragma once -#include "commands.hpp" -#include "fmt/format.h" -#include "tgbot/types/InlineKeyboardButton.h" +#include #include #include -#include #include -#include #include #include -#include #include #include -#include #include struct BaseAbstractButton : public TgBot::InlineKeyboardButton { @@ -35,7 +29,7 @@ struct BaseButton : public BaseAbstractButton { BaseButton(const std::string& prefix) { text = Name; - callbackData = fmt::format("{} {}", prefix, std::string(Name)); + callbackData = std::format("{} {}", prefix, std::string(Name)); LOG_INFO("Make button: {}, CallData: {}", text, callbackData); } @@ -57,7 +51,7 @@ struct ButtonRow { template static std::vector make_buttons(const std::string& prefix, std::index_sequence) { - return {std::make_shared(fmt::format("{} {}", prefix, Is))...}; + return {std::make_shared(std::format("{} {}", prefix, Is))...}; } void dispatch(int index, CommandContext& ctx) { @@ -125,7 +119,7 @@ struct BaseKeyboard : public AbstractBaseKeyboard { template static std::tuple initialize_rows(std::index_sequence) { - return {Rows(fmt::format("{} {}", std::string(Name), Is))...}; + return {Rows(std::format("{} {}", std::string(Name), Is))...}; } // Recursive helper to dispatch based on the row index. diff --git a/include/log.hpp b/include/log.hpp index 0bcd7e5..71a8e8e 100644 --- a/include/log.hpp +++ b/include/log.hpp @@ -1,22 +1,25 @@ #pragma once -#include -#include #include #include +#include #define RESET "\033[0m" #define RED "\033[31m" #define GREEN "\033[32m" #define YELLOW "\033[33m" #define BOLDRED "\033[1m\033[31m" -#define BLUE "\033[34m" +#define BLUE "\033[34m" -#define LOG_GENERIC(color, level, ...) fmt::format(color "[" #level "][{}]" RESET " {}\n", std::this_thread::get_id(), fmt::format(__VA_ARGS__)) +#define LOG_GENERIC(color, level, ...) \ + std::cout << color "[" #level "][" << std::this_thread::get_id() << "] " << std::format(__VA_ARGS__) << RESET << "\n" -#define LOG_INFO(...) std::cout << LOG_GENERIC(GREEN, INFO, __VA_ARGS__) -#define LOG_WARN(...) std::cout << LOG_GENERIC(YELLOW, WARN, __VA_ARGS__) -#define LOG_DEBUG(...) std::cout << LOG_GENERIC(BLUE, DEBUG, __VA_ARGS__) -#define LOG_ERROR(...) std::cerr << LOG_GENERIC(RED, ERR, __VA_ARGS__) -#define LOG_CRITICAL(...) std::cerr << LOG_GENERIC(BOLDRED, CRIT, __VA_ARGS__) -#define LOG_FATAL(...) std::cerr << LOG_GENERIC(BOLDRED, FATAL, __VA_ARGS__) \ No newline at end of file +#define LOG_GENERIC2(color, level, ...) \ + std::cerr << color "[" #level "][" << std::this_thread::get_id() << "] " << std::format(__VA_ARGS__) << RESET << "\n" + +#define LOG_INFO(...) LOG_GENERIC(GREEN, INFO, __VA_ARGS__) +#define LOG_WARN(...) LOG_GENERIC(YELLOW, WARN, __VA_ARGS__) +#define LOG_DEBUG(...) LOG_GENERIC(BLUE, DEBUG, __VA_ARGS__) +#define LOG_ERROR(...) LOG_GENERIC2(BLUE, ERROR, __VA_ARGS__) +#define LOG_CRITICAL(...) LOG_GENERIC2(BLUE, CRIT, __VA_ARGS__) +#define LOG_FATAL(...) LOG_GENERIC2(BLUE, FATAL, __VA_ARGS__) diff --git a/include/services/gelbooru.hpp b/include/services/gelbooru.hpp index 59526c0..cd92826 100644 --- a/include/services/gelbooru.hpp +++ b/include/services/gelbooru.hpp @@ -1,22 +1,17 @@ #pragma once #include -#include -#include -#include -#include -#include -#include -using json = nlohmann::json; class Gelbooru : public Service, public std::enable_shared_from_this { public: Gelbooru(); post_data_tv parse(int tag_id) override; - std::optional getTagInfo(const std::string& tag) override; - std::optional tagInfoById(int tag_id) override; + std::optional getTagInfo(const std::string& tag) override; + std::optional tagInfoById(int tag_id) override; private: - std::optional parseTagInfoFromXml(std::string_view xml); + std::string apikey; + std::string userid; + std::optional parseTagInfoFromXml(std::string_view xml); }; \ No newline at end of file diff --git a/include/services/rule34.hpp b/include/services/rule34.hpp index 67741fb..b7a2048 100644 --- a/include/services/rule34.hpp +++ b/include/services/rule34.hpp @@ -1,20 +1,14 @@ #pragma once -#include #include -#include -#include -#include -#include -using json = nlohmann::json; class Rule34 : public Service, public std::enable_shared_from_this { public: Rule34(); post_data_tv parse(int tag_id) override; - std::optional getTagInfo(const std::string& tag) override; - std::optional tagInfoById(int tag_id) override; + std::optional getTagInfo(const std::string& tag) override; + std::optional tagInfoById(int tag_id) override; private: - std::optional parseTagInfoFromXml(std::string_view xml); + std::optional parseTagInfoFromXml(std::string_view xml); }; \ No newline at end of file diff --git a/include/services/service.hpp b/include/services/service.hpp index 74037f7..2b19f9b 100644 --- a/include/services/service.hpp +++ b/include/services/service.hpp @@ -2,52 +2,60 @@ #include #include -#include #include #include +#include -typedef struct { - std::string content; - std::string thumbnail; -} content_t; - -typedef struct { - content_t content; - std::string id; - int tag_id; -} send_t; - -typedef struct { +struct Taginfo { int id = -1; std::string name; int count = 1; int type = -1; bool ambiguous = 0; - std::string site; -} taginfo_t; + int site; +}; -typedef struct { - std::vector content; +struct postData { + std::string file_url; + std::string preview_url; std::vector tags; - std::string id; - std::string service; + int id; int score; -} post_data_t; + std::string rating; +}; + +using post_data_tv = std::vector; + +enum class TypeTags { + General, + Character, + Copyright, + Artist, + Meta, + Unknown +}; -typedef std::vector post_data_tv; +const std::unordered_map typeMap = { + {0, TypeTags::General}, + {1, TypeTags::Artist}, + {4, TypeTags::Character}, + {2, TypeTags::Copyright}, + {3, TypeTags::Copyright}, + {5, TypeTags::Meta} +}; class Service { public: const std::string type; const std::string url; const std::string postUrl; + int service_id; Service(std::string type, std::string url, std::string postUrl) : type(type), url(url), postUrl(postUrl) {}; - virtual void refresh() {}; virtual post_data_tv parse(int tag_id) = 0; - virtual std::string buildPostURL(const send_t& send) { - return postUrl + send.id; + virtual std::string buildPostURL(const postData& post) { + return postUrl + std::to_string(post.id); } virtual std::pair request(const std::string& url) { std::string login = std::format("&api_key={}&user_id={}", apikey, userid); @@ -60,17 +68,15 @@ class Service { return std::make_pair(r.text, r.status_code); }; - virtual std::optional getTagInfo(const std::string& tag) { - try { - taginfo_t taginfo; - taginfo.name = std::stoi(tag); - return taginfo; - } catch (...) { - return std::nullopt; - } - } + virtual std::optional getTagInfo(const std::string& tag) = 0; + virtual std::optional tagInfoById(int tag_id) = 0; - virtual std::optional tagInfoById(int tag_id) = 0; + static TypeTags getTypeFromValue(int apiValue) { + auto it = typeMap.find(apiValue); + if (it != typeMap.end()) + return it->second; + return TypeTags::Unknown; + } protected: cpr::ConnectTimeout defaultTimeOut{std::chrono::seconds{5}}; diff --git a/include/settings.hpp b/include/settings.hpp new file mode 100644 index 0000000..78d65ff --- /dev/null +++ b/include/settings.hpp @@ -0,0 +1,6 @@ +#pragma once + +#define LIMIT_COUNT_TAG 5000 +#define BOT_SLEEP_HOURS 2 +#define THREAD_COUNT 6 +#define CONNECT_COUNT THREAD_COUNT \ No newline at end of file diff --git a/include/sync.hpp b/include/sync.hpp new file mode 100644 index 0000000..f4683a0 --- /dev/null +++ b/include/sync.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include +#include +#include + +class Sync { + public: + Sync(std::vector>& services); + + private: + std::vector threads; + + void syncTag(std::shared_ptr service); + void syncPost(std::shared_ptr service); +}; \ No newline at end of file diff --git a/include/threadspool.hpp b/include/threadspool.hpp index 83bbc37..e5865af 100644 --- a/include/threadspool.hpp +++ b/include/threadspool.hpp @@ -15,7 +15,7 @@ class ThreadsPool { void start() { for (int i = 0; i < count; i++) - threads.emplace_back(std::thread(func)); + threads.emplace_back(std::jthread(func)); } void stop() { @@ -31,5 +31,5 @@ class ThreadsPool { private: int count; std::function func; - std::vector threads; + std::vector threads; }; \ No newline at end of file diff --git a/include/utils.hpp b/include/utils.hpp index 8c6ce6b..450321c 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -1,12 +1,8 @@ #pragma once +#include #include #include -#include -#include -#include -#include -#include template struct FixedString { @@ -28,17 +24,13 @@ class Utils { static bool contains(const std::vector& vec, const T& value) { return std::find(vec.begin(), vec.end(), value) != vec.end(); } - template - static bool contains(const std::vector& v1, const std::vector& v2) { + template + static bool contains(const std::vector& v1, const std::vector& v2) { for (const auto& element : v1) { if (std::find(v2.begin(), v2.end(), element) != v2.end()) return true; } return false; } - static std::vector sha256(std::string_view input); - static std::string urlsafe_b64encode(const std::vector& hash); - static std::string generate_urlsafe_token(std::size_t length); - static bool IsAllowFileFormat(std::string_view file); static std::string escapeMarkdownV2(std::string_view text); }; \ No newline at end of file diff --git a/src/bot.cpp b/src/bot.cpp index 61e71cf..0ea8be1 100644 --- a/src/bot.cpp +++ b/src/bot.cpp @@ -1,11 +1,12 @@ -#include "keyboard.hpp" -#include "log.hpp" +#include +#include +#include +#include #include -#include -Bot::Bot(const std::string &token, unsigned int workers) : TgBot::Bot(token) { - commands_pool = std::make_shared(); - thread_pool = std::make_shared(workers, +Bot::Bot(const std::string &token) : TgBot::Bot(token), sync(std::make_unique(services)) { + commands_pool = std::make_unique(); + thread_pool = std::make_unique(THREAD_COUNT, [this]() { thread_work(); }); getEvents().onUnknownCommand([this](TgBot::Message::Ptr message) { @@ -91,15 +92,15 @@ void Bot::command_handler(TgBot::Message::Ptr message) { cmd->dispatch(ctx); } -void Bot::sendContent(const send_t& send, std::int64_t user_id, std::shared_ptr service) { - std::filesystem::path url = send.content.content; - std::string thumbnail = send.content.thumbnail; +void Bot::sendContent(const postData& post, std::int64_t user_id, std::shared_ptr service) { + std::filesystem::path url = post.file_url; + std::string thumbnail = post.preview_url; LOG_INFO("Send: {}", url.string()); - std::string caption = fmt::format( + std::string caption = std::format( "[{}]({})\n[original\\({}\\)]({})", service->type, - service->buildPostURL(send), + service->buildPostURL(post), Utils::escapeMarkdownV2(url.extension().string()), url.string() ); @@ -111,30 +112,37 @@ void Bot::sendContent(const send_t& send, std::int64_t user_id, std::shared_ptr< else getApi().sendPhoto(user_id, url, caption, nullptr, nullptr, mode); } catch (const std::exception& e) { - try { - if (thumbnail.empty()) - throw std::runtime_error("thumbnail is empty"); - - getApi().sendPhoto(user_id, send.content.thumbnail, caption, nullptr, nullptr, mode); - } catch (const std::exception& e) { - getApi().sendMessage(user_id, caption, nullptr, nullptr, nullptr, mode); + if (!thumbnail.empty()) { + getApi().sendPhoto(user_id, thumbnail, caption, nullptr, nullptr, mode); + return; } + + getApi().sendMessage(user_id, caption, nullptr, nullptr, nullptr, mode); } } void Bot::mainloop() { - // Not really good, but also not really bad - // Can cause segfaults and some other memory access errors - auto update_services_threaded = [this]() { - std::thread(&Bot::update_services, this).detach(); - }; - - auto last_update = std::chrono::steady_clock::now(); - update_services_threaded(); + std::condition_variable_any cv; + std::mutex m; + + std::jthread updater([this, &cv, &m](std::stop_token stoken) { + while (!stoken.stop_requested()) { + try { + update_services(); + } catch (const std::exception& e) { + LOG_ERROR("update_services failed: {}", e.what()); + } + + std::unique_lock lk(m); + cv.wait_for(lk, std::chrono::hours(BOT_SLEEP_HOURS), [&]() { + return stoken.stop_requested(); + }); + } + }); getApi().deleteWebhook(); + TgBot::TgLongPoll longPoll(static_cast(*this)); - TgBot::TgLongPoll longPoll((TgBot::Bot&)*this); while (is_running) { try { longPoll.start(); @@ -142,17 +150,6 @@ void Bot::mainloop() { LOG_ERROR("Bot: {}", e.what()); std::this_thread::sleep_for(std::chrono::seconds(5)); } - - auto now = std::chrono::steady_clock::now(); - auto diff = std::chrono::duration_cast(now - last_update).count(); - if (diff >= 2) { - update_services_threaded(); - last_update = now; - LOG_INFO("Sleep 2h"); - } - - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - std::this_thread::yield(); } } @@ -167,54 +164,17 @@ std::string Bot::helpMessage() { void Bot::update_services() { for (const auto& service : services) { - LOG_INFO("Refresh: {}", service->type); - service->refresh(); - - auto repeatTags = DB::getInstance().getUsersByTags(service); - for (const auto& [tag_id, users] : repeatTags) { - post_data_tv posts = service->parse(tag_id); - if (posts.empty()) - continue; - - for (const auto& user: users) { - std::vector history = DB::getInstance().getHistory(user, service); - std::vector antitag = DB::getInstance().getAntiTagForUserAndSite(user, service); - std::int64_t score = DB::getInstance().getScore(service, user); - post_data_tv newhistory; - - for (const auto& post: posts) { - auto tags = post.tags; - if (Utils::contains(antitag, tags) || Utils::contains(history, post.id) || - (post.score < score) && score != 0) - continue; - - for (const auto& content : post.content) { - if (!Utils::IsAllowFileFormat(content.content)) - continue; - - send_t tmp; - tmp.content = content; - tmp.id = post.id; - tmp.tag_id = tag_id; - - sendContent(tmp, user, service); - } - - newhistory.push_back(post); - } - - int internal_tag_id = DB::getInstance().getUserTagIdByExternalTagId(tag_id).value_or(-1); - if (internal_tag_id == -1) { - LOG_ERROR("User tag id not found for user {} and external tag_id {}", user, tag_id); - continue; - } - - DB::getInstance().addHistory(service, internal_tag_id, newhistory, user); - - std::this_thread::sleep_for(std::chrono::seconds(5)); + std::vector users = DB::getInstance().getUserForTags(service); + for (const auto& user : users) { + post_data_tv posts = DB::getInstance().getNewPostForUser(service, user); + post_data_tv newhistory; + + for (const auto& post: posts) { + sendContent(post, user, service); + newhistory.push_back(post); } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + DB::getInstance().addHistory(newhistory, user); } } diff --git a/src/commands.cpp b/src/commands.cpp index d0a4beb..e05c26c 100644 --- a/src/commands.cpp +++ b/src/commands.cpp @@ -7,7 +7,7 @@ static const BaseCommand< std::string help_str = ctx.bot->helpMessage(); help_str += "\n\nAvailable services:\n"; for (const auto& name : ctx.bot->availableServices()) - help_str += fmt::format(" - {}\n", name); + help_str += std::format(" - {}\n", name); ctx.bot->replyMessagef(ctx.message, help_str.c_str()); } > help_command; @@ -45,9 +45,9 @@ static const BaseCommand< static const BaseCommand< "addtag", false, [](CommandContext& ctx, std::shared_ptr service){ - bool success = ctx.db.addTag(service, ctx.args[1], ctx.message->chat->id); - if (!success) { - ctx.bot->replyMessagef(ctx.message, "Tag `{}` already exists/Not valid", ctx.args[1]); + std::expected result = ctx.db.addTag(service, ctx.args[1], ctx.message->chat->id); + if (!result) { + ctx.bot->replyMessagef(ctx.message, "Tag `{}` {}", ctx.args[1], result.error()); return; } @@ -88,13 +88,4 @@ static const BaseCommand< [](CommandContext& ctx){ ctx.bot->replyMessagef(ctx.message, "Tags:\n{}", ctx.db.getFormattedTagsAndAntiTags(ctx.message->chat->id)); } -> taglist_command; - -static const BaseCommand< - "scorelimit", false, - [](CommandContext& ctx, std::shared_ptr service){ - int score = std::stoi(ctx.args[1]); - ctx.db.scoreUpdate(service, ctx.message->chat->id, score); - }, - CommandArgs<"service", "score"> -> scorelimit_command; +> taglist_command; \ No newline at end of file diff --git a/src/db.cpp b/src/db.cpp index 1360589..17c98f3 100644 --- a/src/db.cpp +++ b/src/db.cpp @@ -1,10 +1,8 @@ -#include "log.hpp" -#include "services/service.hpp" +#include +#include #include -#include -#include -DB::DB(const std::string &url_db): connection_pool(std::make_shared(url_db, 12)) { +DB::DB(const std::string &url_db): connection_pool(std::make_shared(url_db)) { try { Transaction txn(connection_pool); txn.exec(R"( @@ -14,17 +12,24 @@ DB::DB(const std::string &url_db): connection_pool(std::make_shared(); + } + + return false; + } catch (const std::exception& e) { + LOG_ERROR("{}", e.what()); + return false; + } +} + void DB::rmUser(std::int64_t id) { try { Transaction txn(connection_pool); @@ -101,63 +155,81 @@ void DB::rmUser(std::int64_t id) { } } -std::pair> DB::getAndCreate(std::shared_ptr service, std::string_view tag) { - auto [cachetag_id, cachetag_info] = getCacheTagByName(service->type, tag); - if (!cachetag_info) { - cachetag_info = service->getTagInfo(std::string(tag)); - if (!cachetag_info) - return { -1, std::nullopt }; +post_data_tv DB::getPostsByTagId(std::shared_ptr service, int tag_id) { + post_data_tv result; + try { + Transaction txn(connection_pool); + pqxx::result r = txn.exec(R"( + SELECT p.* + FROM posts p + JOIN post_tags pt + ON pt.post_id = p.id + AND pt.service_id = p.service_id + WHERE pt.tag_id = $1 + AND p.service_id = $2; + )", pqxx::params{tag_id, service->service_id}); + txn.commit(); - addCacheTag(*cachetag_info); - std::tie(cachetag_id, cachetag_info) = getCacheTagByName(service->type, tag); - if (cachetag_id == -1) - return { -1, std::nullopt }; + if (r.empty()) + return result; + + result.reserve(r.size()); + for (const auto& row : r) { + postData post; + post.id = row["id"].as(); + post.file_url = row["file_url"].c_str(); + post.preview_url = row["preview_url"].c_str(); + + result.emplace_back(std::move(post)); + } + + return result; + } catch (const std::exception& e) { + LOG_ERROR("getNewPostForUser: {}", e.what()); } - return {cachetag_id, cachetag_info}; + return result; } -bool DB::addTag(std::shared_ptr service, std::string_view tag, std::int64_t user_id) { +std::expected DB::addTag(std::shared_ptr service, std::string_view tag, std::int64_t user_id) { try { - auto [cachetag_id, cachetag_info] = getAndCreate(service, tag); - if (cachetag_id == -1) { - return false; - } + auto [cachetag_id, cachetag_info] = getCacheTagByName(service->service_id, tag); + if (cachetag_id == -1) + return std::unexpected("tag not valid in local DB"); - { - Transaction txn(connection_pool); - pqxx::result r = txn.exec(R"( - SELECT 1 FROM tags WHERE user_id = $1 AND tag_id = $2; - )", pqxx::params{user_id, cachetag_id}); + bool is_admin = userIsAdmin(user_id); - if (!r.empty()) { - return false; - } + TypeTags typetag = Service::getTypeFromValue(cachetag_info->type); + if (!is_admin && typetag != TypeTags::Artist) + return std::unexpected("Allowed only artist tag"); + if (!is_admin && cachetag_info->count >= LIMIT_COUNT_TAG) + return std::unexpected(std::format("Limit {} post on tag", LIMIT_COUNT_TAG)); + + { + Transaction txn(connection_pool); txn.exec(R"( INSERT INTO tags (user_id, tag_id) VALUES ($1, $2) ON CONFLICT (user_id, tag_id) DO NOTHING; )", pqxx::params{user_id, cachetag_id}); - txn.commit(); } - auto post = service->parse(cachetag_info->id); - addHistory(service, cachetag_id, post, user_id); - - return true; + auto post = getPostsByTagId(service,cachetag_info->id); + addHistory(post, user_id); + return std::monostate{}; } catch (const std::exception& e) { LOG_ERROR("Add Tag: {}", e.what()); } - return false; + return std::unexpected("Error adding tag"); } void DB::addAntiTag(std::shared_ptr service, std::string_view tag, std::int64_t user_id) { try { - auto [cachetag_id, _] = getAndCreate(service, tag); + auto [cachetag_id, cachetag_info] = getCacheTagByName(service->service_id, tag); if (cachetag_id == -1) { return; } @@ -172,12 +244,10 @@ void DB::addAntiTag(std::shared_ptr service, std::string_view tag, std: void DB::rmTag(std::shared_ptr service, std::string_view tag, std::int64_t user_id) { try { - auto [cachetag_id, _] = getAndCreate(service, tag); + auto [cachetag_id, _] = getCacheTagByName(service->service_id, tag); if (cachetag_id == -1) return; - int tag_entry_id; - Transaction txn(connection_pool); pqxx::result r = txn.exec(R"( SELECT id FROM tags @@ -187,7 +257,7 @@ void DB::rmTag(std::shared_ptr service, std::string_view tag, std::int6 if (r.empty()) return; - tag_entry_id = r[0]["id"].as(); + int tag_entry_id = r[0]["id"].as(); txn.exec(R"( DELETE FROM history WHERE user_id = $1 AND tag_id = $2; @@ -205,7 +275,7 @@ void DB::rmTag(std::shared_ptr service, std::string_view tag, std::int6 void DB::rmAntiTag(std::shared_ptr service, std::string_view tag, std::int64_t user_id) { try { - auto [cachetag_id, _] = getAndCreate(service, tag); + auto [cachetag_id, cachetag_info] = getCacheTagByName(service->service_id, tag); if (cachetag_id == -1) { return; } @@ -218,38 +288,25 @@ void DB::rmAntiTag(std::shared_ptr service, std::string_view tag, std:: } } -void DB::addHistory(std::shared_ptr service, int tag_id, const post_data_tv &newhistory, std::int64_t user_id) { +void DB::addHistory(const post_data_tv &newhistory, std::int64_t user_id) { if (newhistory.empty()) return; try { - auto user_tag_id = getUserTagId(user_id, tag_id); - if (!user_tag_id) { - LOG_ERROR("User tag not found for user {} and tag_id {} (cache)", user_id, tag_id); - return; - } - - Transaction txn(connection_pool); std::stringstream sql; - sql << "INSERT INTO history (history, tag_id, user_id) VALUES "; + sql << "INSERT INTO history (post_id, user_id) VALUES "; bool first = true; for (const auto& item : newhistory) { - const auto& data = item.id; - if (data.empty()) { - continue; - } - if (!first) { sql << ", "; } first = false; sql << "(" - << txn.get().quote(data) << ", " - << *user_tag_id << ", " + << item.id << ", " << user_id << ")"; } @@ -257,12 +314,12 @@ void DB::addHistory(std::shared_ptr service, int tag_id, const post_dat if (first) return; - sql << " ON CONFLICT (history, tag_id, user_id) DO NOTHING;"; + sql << " ON CONFLICT (post_id, user_id) DO NOTHING;"; txn.exec(sql.str()); txn.commit(); } catch (const std::exception& e) { - LOG_ERROR("Error adding history (tag_id: {}, user_id: {}): {}", tag_id, user_id, e.what()); + LOG_ERROR("Error adding history (user_id: {}): {}", user_id, e.what()); } } @@ -292,86 +349,79 @@ bool DB::userExist(std::int64_t user, bool is_admin) { } } -std::unordered_map> DB::getUsersByTags(std::shared_ptr service) { - std::unordered_map> tagToUsers; - +std::vector DB::getUserForTags(std::shared_ptr service) { + std::vector result; try { Transaction txn(connection_pool); - - pqxx::result tagQueryResult = txn.exec(R"( - SELECT tag_cache.tag_id, users.id - FROM users - JOIN tags ON users.id = tags.user_id - JOIN tag_cache ON tags.tag_id = tag_cache.id - WHERE tag_cache.site = $1; - )", pqxx::params{service->type}); - + pqxx::result r = txn.exec(R"( + SELECT DISTINCT t.user_id + FROM tags t + JOIN tag_cache tc ON tc.id = t.tag_id + WHERE tc.service_id = $1; + )", pqxx::params{service->service_id}); txn.commit(); - for (const auto& row : tagQueryResult) { - int external_tag_id = row[0].as(); - std::int64_t user_id = row[1].as(); + if (r.empty()) + return result; - tagToUsers[external_tag_id].push_back(user_id); + result.reserve(r.size()); + for (const auto& row : r) { + int id = row["user_id"].as(); + result.emplace_back(id); } - } catch (const std::exception &e) { - LOG_ERROR("Error fetching users by tags for site '{}': {}", service->type, e.what()); - } - - return tagToUsers; -} - -std::vector DB::getHistory(std::int64_t user_id, std::shared_ptr service) { - std::vector history_entries; - - try { - Transaction txn(connection_pool); - - pqxx::result queryResult = txn.exec(R"( - SELECT history.history - FROM history - JOIN tags ON history.tag_id = tags.id - JOIN tag_cache ON tags.tag_id = tag_cache.id - WHERE history.user_id = $1 AND tag_cache.site = $2; - )", pqxx::params{user_id, service->type}); - - txn.commit(); - - for (const auto& row : queryResult) { - history_entries.push_back(row[0].as()); - } + return result; } catch (const std::exception& e) { - LOG_ERROR("Error retrieving history for user {} and site '{}': {}", user_id, service->type, e.what()); + LOG_ERROR("getUserForTags: {}", e.what()); } - return history_entries; + return result; } - -std::vector DB::getAntiTagForUserAndSite(std::int64_t user_id, std::shared_ptr service) { - std::vector antitag_entries; - +post_data_tv DB::getNewPostForUser(std::shared_ptr service, std::int64_t tg_id) { + post_data_tv result; try { Transaction txn(connection_pool); + pqxx::result r = txn.exec(R"( + SELECT p.id, p.file_url, p.preview_url, p.score, p.rating + FROM posts p + WHERE p.service_id = $2 + AND NOT EXISTS ( + SELECT 1 FROM history h + WHERE h.post_id = p.id + AND h.user_id = $1 + ) + AND NOT EXISTS ( + SELECT 1 + FROM post_tags pt + JOIN anti_tags at + ON at.tag_id = pt.tag_id + AND at.user_id = $1 + WHERE pt.post_id = p.id + AND pt.service_id = $2 + ) + )", pqxx::params{tg_id, service->service_id}); + txn.commit(); - pqxx::result queryResult = txn.exec(R"( - SELECT tag_cache.tag_name - FROM anti_tags - JOIN tag_cache ON anti_tags.tag_id = tag_cache.id - WHERE anti_tags.user_id = $1 AND tag_cache.site = $2; - )", pqxx::params{user_id, service->type}); + if (r.empty()) + return result; - txn.commit(); + result.reserve(r.size()); + for (const auto& row : r) { + postData post; + post.id = row["id"].as(); + post.file_url = row["file_url"].c_str(); + post.preview_url = row["preview_url"].c_str(); - for (const auto& row : queryResult) { - antitag_entries.push_back(row[0].as()); + result.emplace_back(std::move(post)); } + + return result; } catch (const std::exception& e) { - LOG_ERROR("Error retrieving antitag for user {} and site {}: {}", user_id, service->type, e.what()); + LOG_ERROR("getNewPostForUser: {}", e.what()); } - return antitag_entries; + return result; } std::string DB::getFormattedTagsAndAntiTags(std::int64_t user_id) { @@ -381,26 +431,26 @@ std::string DB::getFormattedTagsAndAntiTags(std::int64_t user_id) { Transaction txn(connection_pool); pqxx::result siteResult = txn.exec(R"( - SELECT DISTINCT tag_cache.site + SELECT DISTINCT tag_cache.service_id FROM tags JOIN tag_cache ON tags.tag_id = tag_cache.id WHERE tags.user_id = $1 UNION - SELECT DISTINCT tag_cache.site + SELECT DISTINCT tag_cache.service_id FROM anti_tags JOIN tag_cache ON anti_tags.tag_id = tag_cache.id WHERE anti_tags.user_id = $1 )", pqxx::params{user_id}); for (const auto &siteRow : siteResult) { - std::string site = siteRow[0].as(); + int service_id = siteRow[0].as(); pqxx::result tagResult = txn.exec(R"( SELECT tag_cache.tag_name FROM tags JOIN tag_cache ON tags.tag_id = tag_cache.id - WHERE tags.user_id = $1 AND tag_cache.site = $2 - )", pqxx::params{user_id, site}); + WHERE tags.user_id = $1 AND tag_cache.service_id = $2 + )", pqxx::params{user_id, service_id}); std::vector tags; for (const auto &tagRow : tagResult) { @@ -411,14 +461,18 @@ std::string DB::getFormattedTagsAndAntiTags(std::int64_t user_id) { SELECT tag_cache.tag_name FROM anti_tags JOIN tag_cache ON anti_tags.tag_id = tag_cache.id - WHERE anti_tags.user_id = $1 AND tag_cache.site = $2 - )", pqxx::params{user_id, site}); + WHERE anti_tags.user_id = $1 AND tag_cache.service_id = $2 + )", pqxx::params{user_id, service_id}); std::vector antiTags; for (const auto &antiTagRow : antiTagResult) { antiTags.push_back(antiTagRow[0].as()); } + std::string site = getServiceName(service_id); + if (site.empty()) + site = "Error"; + result << "-----" << site << " Tags-----\n"; for (const auto &tag : tags) { result << tag << "\n"; @@ -441,43 +495,6 @@ std::string DB::getFormattedTagsAndAntiTags(std::int64_t user_id) { return res.empty() ? "Empty" : res; } - -void DB::scoreUpdate(std::shared_ptr service, std::int64_t user_id, std::int64_t score) { - try { - std::string site = service->type; - - Transaction txn(connection_pool); - txn.exec("INSERT INTO settings (user_id, site, score) VALUES ($1, $2, $3) ON CONFLICT (user_id, site) DO UPDATE SET score = EXCLUDED.score;", pqxx::params{user_id, site, score}); - txn.commit(); - } catch (const std::exception& e) { - LOG_ERROR("{}", e.what()); - } -} - -std::int64_t DB::getScore(std::shared_ptr service, std::int64_t user_id) { - try { - std::string site = service->type; - - { - Transaction txn(connection_pool); - pqxx::result r = txn.exec( - "SELECT score FROM settings WHERE user_id = $1 AND site = $2;", - pqxx::params{user_id, site} - ); - txn.commit(); - - if (!r.empty()) - return r[0][0].as(); - } - - scoreUpdate(service, user_id, 0); - } catch (const std::exception& e) { - LOG_ERROR("Failed to get score for user {} and site {}: {}", user_id, service->type, e.what()); - } - - return 0; -} - void DB::addSystemValue(std::string_view key, std::string_view value) { try { Transaction txn(connection_pool); @@ -507,44 +524,42 @@ std::string DB::getSystemValue(std::string_view key) { return ""; } -std::pair> DB::getCacheTagById(std::string_view site, int tag_id) { +std::optional DB::getCacheTagById(int service_id, int tag_id) { try { Transaction txn(connection_pool); pqxx::result res = txn.exec(R"( SELECT * FROM tag_cache - WHERE tag_id = $1 AND site = $2; - )", pqxx::params{tag_id, site}); + WHERE tag_id = $1 AND service_id = $2; + )", pqxx::params{tag_id, service_id}); txn.commit(); if (res.empty()) { - return { -1, std::nullopt }; + return std::nullopt; } const pqxx::row& row = res[0]; - taginfo_t info; + Taginfo info; info.id = row["tag_id"].as(); info.name = row["tag_name"].as(); - info.site = row["site"].as(); + info.site = row["service_id"].as(); info.count = row["count"].as(); - info.type = row["type"].as(); + info.type = row["type"].as(); info.ambiguous = row["ambiguous"].as(); - int db_id = row["id"].as(); - - return { db_id, info }; + return info; } catch (const std::exception& e) { - LOG_ERROR("Failed to get cache tag: {} and site {}", tag_id, site); - return { -1, std::nullopt }; + LOG_ERROR("Failed to get cache tag: {} and service_id {}", tag_id, service_id); + return std::nullopt; } } -std::pair> DB::getCacheTagByName(std::string_view site, std::string_view name) { +std::pair> DB::getCacheTagByName(int service_id, std::string_view name) { try { Transaction txn(connection_pool); pqxx::result res = txn.exec(R"( SELECT * FROM tag_cache - WHERE tag_name = $1 AND site = $2; - )", pqxx::params{name, site}); + WHERE tag_name = $1 AND service_id = $2; + )", pqxx::params{name, service_id}); txn.commit(); if (res.empty()) { return { -1, std::nullopt }; @@ -552,30 +567,30 @@ std::pair> DB::getCacheTagByName(std::string_view const pqxx::row& row = res[0]; - taginfo_t info; + Taginfo info; info.id = row["tag_id"].as(); info.name = row["tag_name"].as(); - info.site = row["site"].as(); + info.site = row["service_id"].as(); info.count = row["count"].as(); - info.type = row["type"].as(); + info.type = row["type"].as(); info.ambiguous = row["ambiguous"].as(); int db_id = row["id"].as(); return { db_id, info }; } catch (const std::exception& e) { - LOG_ERROR("Failed to get cache tag: {} and site {}", name, site); + LOG_ERROR("Failed to get cache tag: {} and site {}", name, service_id); return { -1, std::nullopt }; } } -void DB::addCacheTag(const taginfo_t& info) { +void DB::addCacheTag(const Taginfo& info) { try { Transaction txn(connection_pool); txn.exec(R"( - INSERT INTO tag_cache (tag_id, tag_name, count, type, ambiguous, site) + INSERT INTO tag_cache (tag_id, tag_name, count, type, ambiguous, service_id) VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (site, tag_id) DO UPDATE SET tag_name = EXCLUDED.tag_name; + ON CONFLICT (service_id, tag_id) DO UPDATE SET tag_name = EXCLUDED.tag_name; )", pqxx::params{info.id, info.name, info.count, info.type, info.ambiguous, info.site}); txn.commit(); } catch (const std::exception& e) { @@ -599,33 +614,208 @@ std::optional DB::getUserTagId(std::int64_t user_id, std::int64_t return std::nullopt; } -std::optional DB::getUserTagIdByExternalTagId(int external_tag_id) { - try { - Transaction txn(connection_pool); - pqxx::result res = txn.exec(R"( - SELECT id FROM tag_cache WHERE tag_id = $1; - )", pqxx::params{external_tag_id}); - txn.commit(); - if (!res.empty()) { - return res[0]["id"].as(); - } - } catch (const std::exception& e) { - LOG_ERROR("getCacheIdByExternalTagId failed: {}", e.what()); +std::optional DB::tagInfoFromSite(std::shared_ptr service, const std::string& tag) { + auto info = service->getTagInfo(tag); + if (info) { + DB::getInstance().addCacheTag(*info); + return info; } + return std::nullopt; } -std::optional DB::getTagFromCacheOrAddTag(std::shared_ptr service, int tag_id) { - auto [cache_id, cache] = DB::getInstance().getCacheTagById(service->type, tag_id); +std::optional DB::getTagFromCacheOrAddTag(std::shared_ptr service, int tag_id) { + std::optional cache = DB::getInstance().getCacheTagById(service->service_id, tag_id); if (cache) { return cache; } - auto info = service->getTagInfo(cache->name); + auto info = tagInfoFromSite(service, cache->name); if (info) { - DB::getInstance().addCacheTag(*info); return info; } return std::nullopt; -} \ No newline at end of file +} + +std::optional DB::getIdServiceByName(std::string_view name) { + try { + Transaction txn(connection_pool); + pqxx::result r = txn.exec(R"( + INSERT INTO service (name) + VALUES ($1) + ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name + RETURNING id; + )", pqxx::params{name}); + txn.commit(); + + if (!r.empty()) + return r[0]["id"].as(); + } catch (const std::exception& e) { + LOG_ERROR("Error getIdServiceByName failed: {}", e.what()); + } + return std::nullopt; +} + +std::string DB::getServiceName(int service_id) { + try { + Transaction txn(connection_pool); + pqxx::result res = txn.exec(R"( + SELECT name FROM service WHERE id = $1; + )", pqxx::params{service_id}); + txn.commit(); + + if (!res.empty()) + return res[0]["name"].as(); + } catch (const std::exception& e) { + LOG_ERROR("getServiceName failed: {}", e.what()); + } + + return ""; +} + +bool DB::addPost(std::shared_ptr service, const post_data_tv& posts) { + try { + if (posts.empty()) + return false; + + std::vector inserted_ids; + { + Transaction txn(connection_pool); + std::stringstream sql; + sql << "INSERT INTO posts (id, score, rating, file_url, preview_url, service_id) VALUES "; + + bool first = true; + for (const auto& item : posts) { + if (!first) { + sql << ", "; + } + first = false; + + sql << "(" + << item.id << ", " + << item.score << ", " + << txn.get().quote(item.rating) << ", " + << txn.get().quote(item.file_url) << ", " + << txn.get().quote(item.preview_url) << ", " + << service->service_id + << ")"; + } + + if (first) + return false; + + sql << " ON CONFLICT (id, service_id) DO UPDATE SET score = EXCLUDED.score RETURNING id;"; + pqxx::result r = txn.exec(sql.str()); + txn.commit(); + + for (const auto& row : r) { + inserted_ids.push_back(row[0].as()); + } + } + + for (size_t i = 0; i < inserted_ids.size(); ++i) { + const auto& post = posts[i]; + if (!addTagOnPost(service, post.tags, inserted_ids[i])) { + LOG_ERROR("Не удалось добавить теги к посту {}", inserted_ids[i]); + } + } + + return true; + } catch (const std::exception& e) { + LOG_ERROR("addPost failed: {}", e.what()); + } + + return false; +} + +bool DB::addTagOnPost(std::shared_ptr service, const std::vector& tags, int post_id) { + try { + pqxx::result r; + { + Transaction txn(connection_pool); + r = txn.exec(R"( + SELECT DISTINCT ON (tag_name) * + FROM tag_cache + WHERE tag_name = ANY($1::text[]) + ORDER BY tag_name, updated_at DESC; + )", pqxx::params{tags}); + txn.commit(); + } + + std::unordered_map db_set; + for (const auto& row : r) { + int id = row[0].as(); + std::string name = row[2].as(); + db_set.emplace(name, id); + } + + std::vector missing_tags; + for (const auto &tag : tags) { + if (!db_set.contains(tag)) { + missing_tags.push_back(tag); + } + } + + for (const auto& t : missing_tags) { + LOG_DEBUG("Отсутствует: {}", t); + + auto info = tagInfoFromSite(service, t); + if (!info) { + LOG_ERROR("[DB] tagInfoFromSite service: {} | tag: {}", service->type, t); + continue; + } + + Transaction txn(connection_pool); + pqxx::result res = txn.exec( + "SELECT id, tag_name " + "FROM tag_cache " + "WHERE service_id = $1 AND tag_id = $2 " + "ORDER BY updated_at DESC " + "LIMIT 1;", + pqxx::params{service->service_id, info->id} + ); + txn.commit(); + + if (res.empty()) { + LOG_ERROR("Не найден локальный id для тега: {} (service_id={}, tag_id={})", + info->name, service->service_id, info->id); + continue; + } + + int local_id = res[0][0].as(); + std::string name = res[0][1].as(); + + db_set.emplace(name, local_id); + } + + const size_t batch_size = 100; + std::vector> post_tag_pairs; + for (const auto& item : db_set) { + post_tag_pairs.emplace_back(post_id, item.second); + } + + for (size_t i = 0; i < post_tag_pairs.size(); i += batch_size) { + Transaction txn(connection_pool); + std::stringstream sql; + sql << "INSERT INTO post_tags (post_id, tag_id, service_id) VALUES "; + + bool first = true; + for (size_t j = i; j < std::min(i + batch_size, post_tag_pairs.size()); ++j) { + if (!first) sql << ", "; + first = false; + sql << "(" << post_tag_pairs[j].first << ", " << post_tag_pairs[j].second << ", "<< service->service_id << ")"; + } + + sql << " ON CONFLICT (post_id, tag_id, service_id) DO NOTHING;"; + txn.exec(sql.str()); + txn.commit(); + } + + return true; + } catch (const std::exception& e) { + LOG_ERROR("addTagOnPost failed: {}", e.what()); + } + + return false; +} diff --git a/src/keyboard.cpp b/src/keyboard.cpp index ab89840..35b998f 100644 --- a/src/keyboard.cpp +++ b/src/keyboard.cpp @@ -1,10 +1,9 @@ #include -using start_button = BaseButton<"start", false, - [](CommandContext& ctx){ - LOG_INFO("Start button pressed, mid: {}", ctx.message->messageId); - } ->; +inline constexpr auto start_handler = [](CommandContext& ctx) { + LOG_INFO("Start button pressed, mid: {}", ctx.message->messageId); +}; +using start_button = BaseButton<"start", false, start_handler>; static const BaseKeyboard< "main", diff --git a/src/main.cpp b/src/main.cpp index 99eb272..20b203b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,10 +1,6 @@ -//#include #include #include -//#include -//#include #include -#include std::unique_ptr bot; @@ -34,9 +30,6 @@ int main() { bot->addService(); bot->addService(); - //bot->addService(); - //bot->addService(); - //bot->addService(); bot->start(); diff --git a/src/services/gelbooru.cpp b/src/services/gelbooru.cpp index 8daf4d7..efd8921 100644 --- a/src/services/gelbooru.cpp +++ b/src/services/gelbooru.cpp @@ -1,11 +1,17 @@ -#include "log.hpp" -#include "services/service.hpp" -#include #include -#include -#include - -Gelbooru::Gelbooru() : Service("gelbooru", "https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&tags=", "https://gelbooru.com/index.php?page=post&s=view&id=") { +#include +#include +#include +#include +#include + +Gelbooru::Gelbooru() : Service("gelbooru", "https://gelbooru.com/index.php?page=dapi&s=post&q=index&tags=", "https://gelbooru.com/index.php?page=post&s=view&id=") { + std::optional get_id = DB::getInstance().getIdServiceByName(type); + if (!get_id) + throw std::runtime_error("[Gelbooru] Error reg service"); + + service_id = *get_id; + apikey = DB::getInstance().getSystemValue("gelbooru_api_key"); userid = DB::getInstance().getSystemValue("gelbooru_userid"); @@ -15,26 +21,26 @@ Gelbooru::Gelbooru() : Service("gelbooru", "https://gelbooru.com/index.php?page= LOG_INFO("Gelbooru api key: "); std::cin >> apikey; - LOG_INFO("Rule34 userid: "); + LOG_INFO("Gelbooru userid: "); std::cin >> userid; if (apikey.empty() || userid.empty()) - throw std::runtime_error("[Rule34] apikey/userid is empty"); + throw std::runtime_error("[Gelbooru] apikey/userid is empty"); DB::getInstance().addSystemValue("gelbooru_api_key", apikey); - DB::getInstance().addSystemValue("gelbooru_userid", userid); + DB::getInstance().addSystemValue("gelbooru_userid", userid); }; post_data_tv Gelbooru::parse(int tag_id) { - std::optional tag = DB::getInstance().getTagFromCacheOrAddTag(shared_from_this(), tag_id); + std::optional tag = DB::getInstance().getTagFromCacheOrAddTag(shared_from_this(), tag_id); if (!tag) { LOG_WARN("[Gelbooru] getTagFromCacheOrAddTag is null, tag: {}", tag_id); return {}; } - auto [data, status] = request(url + tag->name); + auto [data, status] = request(url + tag->name + "+-ai_generated"); if (data.empty() || status != 200) { - std::optional info = tagInfoById(tag->id); + std::optional info = tagInfoById(tag->id); if (!info) return {}; @@ -52,18 +58,24 @@ post_data_tv Gelbooru::parse(int tag_id) { post_data_tv tmp; try { - json js = json::parse(data); + pugi::xml_document doc; + pugi::xml_parse_result result = doc.load_string(data.c_str()); + if (!result) + return {}; - for (const auto& item : js.at("post")) { - std::string id = std::to_string(item.at("id").get()); - std::vector content = { - { item.at("file_url").get(), item.at("preview_url").get() } - }; - std::vector tags = Utils::split(item.at("tags").get(), ' '); - int score = item.at("score").get(); + for (pugi::xml_node post = doc.child("posts").child("post"); post; post = post.next_sibling("post")) { + int id = post.child("id").text().as_int(); + std::string file_url = post.child("file_url").text().as_string(); + std::string preview_url = post.child("preview_url").text().as_string(); + std::vector tags = Utils::split(post.child("tags").text().as_string(), ' '); + int score = post.child("score").text().as_int(); + std::string rating = post.child("rating").text().as_string(); - tmp.emplace_back(content, tags, id, type, score); + tmp.emplace_back(file_url, preview_url, tags, id, score, rating); } + + DB::getInstance().addPost(shared_from_this(), tmp); + } catch (const std::exception& e) { LOG_ERROR("[Gelbooru] Type error: {} Tag: {}", e.what(), tag->name); } @@ -71,7 +83,7 @@ post_data_tv Gelbooru::parse(int tag_id) { return tmp; } -std::optional Gelbooru::parseTagInfoFromXml(std::string_view xml) { +std::optional Gelbooru::parseTagInfoFromXml(std::string_view xml) { pugi::xml_document doc; pugi::xml_parse_result result = doc.load_buffer(xml.data(), xml.size()); if (!result) @@ -81,19 +93,19 @@ std::optional Gelbooru::parseTagInfoFromXml(std::string_view xml) { if (!first_tag) return std::nullopt; - taginfo_t taginfo; + Taginfo taginfo; taginfo.id = first_tag.child("id").text().as_int(); taginfo.name = first_tag.child("name").text().as_string(); taginfo.count = first_tag.child("count").text().as_int(); taginfo.type = first_tag.child("type").text().as_int(); taginfo.ambiguous = first_tag.child("ambiguous").text().as_bool(); - taginfo.site = type; + taginfo.site = service_id; return taginfo; } -std::optional Gelbooru::getTagInfo(const std::string& tag) { - auto [id, cached] = DB::getInstance().getCacheTagByName(type, tag); +std::optional Gelbooru::getTagInfo(const std::string& tag) { + auto [id, cached] = DB::getInstance().getCacheTagByName(service_id, tag); if (cached) return cached; @@ -102,13 +114,15 @@ std::optional Gelbooru::getTagInfo(const std::string& tag) { return std::nullopt; auto info = parseTagInfoFromXml(xml); - if (info) + if (info) { DB::getInstance().addCacheTag(*info); + return info; + } - return info; + return std::nullopt; } -std::optional Gelbooru::tagInfoById(int tag_id) { +std::optional Gelbooru::tagInfoById(int tag_id) { auto [xml, status] = request("https://gelbooru.com/index.php?page=dapi&s=tag&q=index&id=" + std::to_string(tag_id)); if (xml.empty() || status != 200) return std::nullopt; diff --git a/src/services/rule34.cpp b/src/services/rule34.cpp index 697f925..782aada 100644 --- a/src/services/rule34.cpp +++ b/src/services/rule34.cpp @@ -1,10 +1,17 @@ -#include "log.hpp" -#include "services/service.hpp" -#include #include +#include #include +#include +#include +#include + +Rule34::Rule34() : Service("rule34", "https://api.rule34.xxx/index.php?page=dapi&s=post&q=index&tags=", "https://rule34.xxx/index.php?page=post&s=view&id=") { + std::optional get_id = DB::getInstance().getIdServiceByName(type); + if (!get_id) + throw std::runtime_error("[Rule34] Error reg service"); + + service_id = *get_id; -Rule34::Rule34() : Service("rule34", "https://api.rule34.xxx/index.php?page=dapi&s=post&q=index&json=1&tags=", "https://rule34.xxx/index.php?page=post&s=view&id=") { apikey = DB::getInstance().getSystemValue("rule34_api_key"); userid = DB::getInstance().getSystemValue("rule34_userid"); @@ -25,15 +32,15 @@ Rule34::Rule34() : Service("rule34", "https://api.rule34.xxx/index.php?page=dapi }; post_data_tv Rule34::parse(int tag_id) { - std::optional tag = DB::getInstance().getTagFromCacheOrAddTag(shared_from_this(), tag_id); + std::optional tag = DB::getInstance().getTagFromCacheOrAddTag(shared_from_this(), tag_id); if (!tag) { LOG_WARN("[Rule34] getTagFromCacheOrAddTag is null, tag: {}", tag_id); return {}; } - auto [data, status] = request(url + tag->name); + auto [data, status] = request(url + tag->name + "+-ai_generated"); if (data.empty() || status != 200) { - std::optional info = tagInfoById(tag->id); + std::optional info = tagInfoById(tag->id); if (!info) return {}; @@ -51,17 +58,24 @@ post_data_tv Rule34::parse(int tag_id) { post_data_tv tmp; try { - json js = json::parse(data); - for (const auto& item : js) { - std::string id = std::to_string(item.at("id").get()); - std::vector content = { - { item.at("file_url").get(), item.at("preview_url").get() } - }; - std::vector tags = Utils::split(item.at("tags").get(), ' '); - int score = item.at("score").get(); - - tmp.emplace_back(content, tags, id, type, score); + pugi::xml_document doc; + pugi::xml_parse_result result = doc.load_string(data.c_str()); + if (!result) + return {}; + + for (pugi::xml_node post = doc.child("posts").child("post"); post; post = post.next_sibling("post")) { + int id = post.attribute("id").as_int(); + std::string file_url = post.attribute("file_url").as_string(); + std::string preview_url = post.attribute("preview_url").as_string(); + std::vector tags = Utils::split(post.attribute("tags").as_string(), ' '); + int score = post.attribute("score").as_int(); + std::string rating = post.attribute("rating").as_string(); + + tmp.emplace_back(file_url, preview_url, tags, id, score, rating); } + + DB::getInstance().addPost(shared_from_this(), tmp); + } catch (const std::exception& e) { LOG_ERROR("[Rule34] Type error: {} Tag: {}", e.what(), tag->name); } @@ -69,7 +83,7 @@ post_data_tv Rule34::parse(int tag_id) { return tmp; } -std::optional Rule34::parseTagInfoFromXml(std::string_view xml) { +std::optional Rule34::parseTagInfoFromXml(std::string_view xml) { pugi::xml_document doc; pugi::xml_parse_result result = doc.load_buffer(xml.data(), xml.size()); if (!result) @@ -79,19 +93,19 @@ std::optional Rule34::parseTagInfoFromXml(std::string_view xml) { if (!first_tag) return std::nullopt; - taginfo_t taginfo; + Taginfo taginfo; taginfo.name = first_tag.attribute("name").as_string(); taginfo.id = first_tag.attribute("id").as_int(); taginfo.count = first_tag.attribute("count").as_int(); taginfo.ambiguous = first_tag.attribute("ambiguous").as_bool(); taginfo.type = first_tag.attribute("type").as_int(); - taginfo.site = type; + taginfo.site = service_id; return taginfo; } -std::optional Rule34::getTagInfo(const std::string& tag) { - auto [id, cached] = DB::getInstance().getCacheTagByName(type, tag); +std::optional Rule34::getTagInfo(const std::string& tag) { + auto [id, cached] = DB::getInstance().getCacheTagByName(service_id, tag); if (cached) return cached; @@ -100,13 +114,15 @@ std::optional Rule34::getTagInfo(const std::string& tag) { return std::nullopt; auto info = parseTagInfoFromXml(xml); - if (info) + if (info) { DB::getInstance().addCacheTag(*info); + return info; + } - return info; + return std::nullopt; } -std::optional Rule34::tagInfoById(int tag_id) { +std::optional Rule34::tagInfoById(int tag_id) { auto [xml, status] = request("https://api.rule34.xxx/index.php?page=dapi&s=tag&q=index&id=" + std::to_string(tag_id)); if (status != 200) return std::nullopt; diff --git a/src/sync.cpp b/src/sync.cpp new file mode 100644 index 0000000..4c057b5 --- /dev/null +++ b/src/sync.cpp @@ -0,0 +1,13 @@ +#include + +Sync::Sync(std::vector>& services) { + for (auto& service : services) { + threads.emplace_back( + std::jthread([service](std::stop_token st) { + //while (!st.stop_requested()) { + + //} + }) + ); + } +} \ No newline at end of file diff --git a/src/utils.cpp b/src/utils.cpp index a4e5133..528d424 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -1,4 +1,5 @@ #include +#include std::vector Utils::split(const std::string& str, char delimiter) { std::vector tokens; @@ -12,76 +13,6 @@ std::vector Utils::split(const std::string& str, char delimiter) { return tokens; } -std::vector Utils::sha256(std::string_view input) { - EVP_MD_CTX* ctx = EVP_MD_CTX_new(); - if (!ctx) { - throw std::runtime_error("Failed to create context"); - } - - if (EVP_DigestInit_ex(ctx, EVP_sha256(), nullptr) != 1) { - EVP_MD_CTX_free(ctx); - throw std::runtime_error("Failed to initialize digest"); - } - - if (EVP_DigestUpdate(ctx, input.data(), input.size()) != 1) { - EVP_MD_CTX_free(ctx); - throw std::runtime_error("Failed to update digest"); - } - - std::vector hash(EVP_MD_size(EVP_sha256())); - unsigned int length = 0; - if (EVP_DigestFinal_ex(ctx, hash.data(), &length) != 1) { - EVP_MD_CTX_free(ctx); - throw std::runtime_error("Failed to finalize digest"); - } - - EVP_MD_CTX_free(ctx); - - return hash; -} - - -std::string Utils::urlsafe_b64encode(const std::vector& hash) { - const std::size_t encoded_size = 4 * ((hash.size() + 2) / 3); - - std::vector encoded(encoded_size + 1); - EVP_EncodeBlock(reinterpret_cast(encoded.data()), hash.data(), hash.size()); - - std::string encoded_str(encoded.data()); - for (auto& ch : encoded_str) { - if (ch == '+') ch = '-'; - else if (ch == '/') ch = '_'; - } - - encoded_str.erase(encoded_str.find_last_not_of('=') + 1); - - return encoded_str; -} - -std::string Utils::generate_urlsafe_token(std::size_t length) { - static const char alphanum[] = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789-_"; - - std::string token; - std::random_device rd; - std::mt19937 generator(rd()); - std::uniform_int_distribution<> dist(0, sizeof(alphanum) - 2); - - for (std::size_t i = 0; i < length; ++i) { - token += alphanum[dist(generator)]; - } - - return token; -} - -bool Utils::IsAllowFileFormat(std::string_view file) { - std::string ext = std::filesystem::path(file).extension().string(); - std::vector allow_ext = {".mp4", ".png", ".jpg", ".jpeg", ".gif", ".webp"}; - return Utils::contains(allow_ext, ext); -} - std::string Utils::escapeMarkdownV2(std::string_view text) { static const std::string special_chars = "_*[]()~`>#+-=|{}.!"; std::string escaped;