diff --git a/client.cpp b/client.cpp index 144d91e..8f322d3 100644 --- a/client.cpp +++ b/client.cpp @@ -4,6 +4,7 @@ #include "client.h" #include "models.h" +#include "curlu_wrapper.h" #include "hiredis_wrapper.h" MastodonClient mastodon_client; @@ -85,8 +86,13 @@ std::optional MastodonClient::get_account_by_username(std::string host, username.erase(username.size() - host.size() - 1); } + CurlUrl url; + url.set(CURLUPART_SCHEME, "https"); + url.set(CURLUPART_HOST, host); + url.set(CURLUPART_PATH, "/api/v1/accounts/lookup"); + url.set(CURLUPART_QUERY, "acct="s + url_encode(username)); try { - Account account = this->_send_request("coyote:"s + host + ":@" + username, "https://"s + host + "/api/v1/accounts/lookup?acct=" + url_encode(username)); + Account account = this->_send_request("coyote:"s + host + ":@" + username, url); account.same_server = host == account.server; return account; } catch (const MastodonException& e) { @@ -102,7 +108,12 @@ std::vector MastodonClient::get_pinned_posts(std::string host, const std:: using namespace std::string_literals; lowercase(host); - std::vector posts = this->_send_request("coyote:"s + host + ':' + account_id + ":pinned", "https://"s + host + "/api/v1/accounts/" + account_id + "/statuses?pinned=true"); + CurlUrl url; + url.set(CURLUPART_SCHEME, "https"); + url.set(CURLUPART_HOST, host); + url.set(CURLUPART_PATH, "/api/v1/accounts/"s + url_encode(account_id) + "/statuses"); + url.set(CURLUPART_QUERY, "pinned=true"); + std::vector posts = this->_send_request("coyote:"s + host + ':' + account_id + ":pinned", url); for (Post& post : posts) { handle_post_server(post, host); @@ -113,22 +124,17 @@ std::vector MastodonClient::get_pinned_posts(std::string host, const std:: std::vector MastodonClient::get_posts(const std::string& host, const std::string& account_id, PostSortingMethod sorting_method, std::optional max_id) { using namespace std::string_literals; - const char* sorting_parameters[3] = {"exclude_replies=true", "", "only_media=true"}; - std::string query = sorting_parameters[sorting_method]; + + CurlUrl url; + url.set(CURLUPART_SCHEME, "https"); + url.set(CURLUPART_HOST, host); + url.set(CURLUPART_PATH, "/api/v1/accounts/"s + url_encode(account_id) + "/statuses"); + url.set(CURLUPART_QUERY, sorting_parameters[sorting_method]); if (max_id) { - if (!query.empty()) { - query += '&'; - } - query += "max_id="; - query += url_encode(std::move(*max_id)); + url.set(CURLUPART_QUERY, "max_id="s + std::move(*max_id), CURLU_URLENCODE | CURLU_APPENDQUERY); } - std::string url = "https://"s + host + "/api/v1/accounts/" + account_id + "/statuses"; - if (!query.empty()) { - url += '?'; - url += query; - } std::vector posts = this->_send_request(std::nullopt, url); for (Post& post : posts) { @@ -141,8 +147,12 @@ std::vector MastodonClient::get_posts(const std::string& host, const std:: std::optional MastodonClient::get_post(const std::string& host, std::string id) { using namespace std::string_literals; + CurlUrl url; + url.set(CURLUPART_SCHEME, "https"); + url.set(CURLUPART_HOST, host); + url.set(CURLUPART_PATH, "/api/v1/statuses/"s + url_encode(std::move(id))); try { - Post post = this->_send_request(std::nullopt, "https://"s + host + "/api/v1/statuses/" + url_encode(std::move(id))); + Post post = this->_send_request(std::nullopt, url); handle_post_server(post, host); return post; } catch (const MastodonException& e) { @@ -157,7 +167,11 @@ std::optional MastodonClient::get_post(const std::string& host, std::strin PostContext MastodonClient::get_post_context(const std::string& host, std::string id) { using namespace std::string_literals; - PostContext context = this->_send_request(std::nullopt, "https://"s + host + "/api/v1/statuses/" + url_encode(std::move(id)) + "/context"); + CurlUrl url; + url.set(CURLUPART_SCHEME, "https"); + url.set(CURLUPART_HOST, host); + url.set(CURLUPART_PATH, "/api/v1/statuses/"s + url_encode(std::move(id)) + "/context"); + PostContext context = this->_send_request(std::nullopt, url); for (Post& post : context.ancestors) { handle_post_server(post, host); @@ -172,10 +186,12 @@ PostContext MastodonClient::get_post_context(const std::string& host, std::strin std::vector MastodonClient::get_tag_timeline(const std::string& host, const std::string& tag, std::optional max_id) { using namespace std::string_literals; - std::string url = "https://"s + host + "/api/v1/timelines/tag/" + url_encode(tag); + CurlUrl url; + url.set(CURLUPART_SCHEME, "https"); + url.set(CURLUPART_HOST, host); + url.set(CURLUPART_PATH, "/api/v1/timelines/tag/"s + url_encode(tag)); if (max_id) { - url += "?max_id="; - url += url_encode(std::move(*max_id)); + url.set(CURLUPART_QUERY, "max_id="s + std::move(*max_id), CURLU_URLENCODE | CURLU_APPENDQUERY); } std::vector posts = this->_send_request(std::nullopt, url); @@ -189,7 +205,11 @@ Instance MastodonClient::get_instance(std::string host) { using namespace std::string_literals; lowercase(host); - Instance instance = this->_send_request("coyote:"s + host + ":instance", "https://"s + host + "/api/v2/instance"); + CurlUrl url; + url.set(CURLUPART_SCHEME, "https"); + url.set(CURLUPART_HOST, host); + url.set(CURLUPART_PATH, "/api/v2/instance"); + Instance instance = this->_send_request("coyote:"s + host + ":instance", url); instance.contact_account.same_server = instance.contact_account.server == host; return instance; } @@ -198,7 +218,11 @@ blankie::html::HTMLString MastodonClient::get_extended_description(std::string h using namespace std::string_literals; lowercase(host); - nlohmann::json j = this->_send_request("coyote:"s + host + ":desc", "https://"s + host + "/api/v1/instance/extended_description"); + CurlUrl url; + url.set(CURLUPART_SCHEME, "https"); + url.set(CURLUPART_HOST, host); + url.set(CURLUPART_PATH, "/api/v1/instance/extended_description"); + nlohmann::json j = this->_send_request("coyote:"s + host + ":desc", url); return blankie::html::HTMLString(j.at("content").get()); } @@ -234,7 +258,7 @@ CURL* MastodonClient::_get_easy() { return curl; } -nlohmann::json MastodonClient::_send_request(std::optional cache_key, const std::string& url) { +nlohmann::json MastodonClient::_send_request(std::optional cache_key, const CurlUrl& url) { std::optional cached; if (redis && cache_key && (cached = redis->get(*cache_key))) { return nlohmann::json::parse(std::move(*cached)); @@ -243,10 +267,11 @@ nlohmann::json MastodonClient::_send_request(std::optional cache_ke std::string res; CURL* curl = this->_get_easy(); - setopt(curl, CURLOPT_URL, url.c_str()); + setopt(curl, CURLOPT_CURLU, url.get()); setopt(curl, CURLOPT_WRITEFUNCTION, curl_write_cb); setopt(curl, CURLOPT_WRITEDATA, &res); CURLcode code = curl_easy_perform(curl); + setopt(curl, CURLOPT_CURLU, nullptr); if (code) { throw CurlException(code); } diff --git a/client.h b/client.h index da1928b..09bb9f0 100644 --- a/client.h +++ b/client.h @@ -8,6 +8,7 @@ #include #include "models.h" +class CurlUrl; // forward declaration from curlu_wrapper.h class CurlException : public std::exception { public: @@ -78,7 +79,7 @@ public: private: CURL* _get_easy(); - nlohmann::json _send_request(std::optional cache_key, const std::string& url); + nlohmann::json _send_request(std::optional cache_key, const CurlUrl& url); long _response_status_code(); std::mutex _share_locks[CURL_LOCK_DATA_LAST]; diff --git a/curlu_wrapper.h b/curlu_wrapper.h new file mode 100644 index 0000000..011b2ef --- /dev/null +++ b/curlu_wrapper.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +class CurlUrlException : public std::exception { +public: + CurlUrlException(CURLUcode code_) : code(code_) { +#if !CURL_AT_LEAST_VERSION(7, 80, 0) + snprintf(this->_id_buf, 64, "curl url error %d", this->code); +#endif + } + + const char* what() const noexcept { +#if CURL_AT_LEAST_VERSION(7, 80, 0) + return curl_url_strerror(this->code); +#else + return this->_id_buf; +#endif + } + + CURLUcode code; + +private: +#if !CURL_AT_LEAST_VERSION(7, 80, 0) + char _id_buf[64]; +#endif +}; + +using CurlStr = std::unique_ptr; +class CurlUrl { +public: + CurlUrl(const CurlUrl&) = delete; + CurlUrl& operator=(const CurlUrl&) = delete; + + CurlUrl() { + this->_ptr = curl_url(); + if (!this->_ptr) { + throw std::bad_alloc(); + } + } + ~CurlUrl() { + curl_url_cleanup(this->_ptr); + } + + constexpr CURLU* get() const noexcept { + return this->_ptr; + } + + CurlStr get(CURLUPart part, unsigned int flags = 0) const { + char* content; + CURLUcode code = curl_url_get(this->_ptr, part, &content, flags); + if (code) { + throw CurlUrlException(code); + } + + return CurlStr(content, curl_free); + } + + void set(CURLUPart part, const char* content, unsigned int flags = 0) { + CURLUcode code = curl_url_set(this->_ptr, part, content, flags); + if (code) { + throw CurlUrlException(code); + } + } + + void set(CURLUPart part, const std::string& content, unsigned int flags = 0) { + this->set(part, content.c_str(), flags); + } + +private: + CURLU* _ptr; +}; diff --git a/servehelper.cpp b/servehelper.cpp index 1c55e37..8a526d3 100644 --- a/servehelper.cpp +++ b/servehelper.cpp @@ -10,6 +10,7 @@ #include "timeutils.h" #include "servehelper.h" #include "lxb_wrapper.h" +#include "curlu_wrapper.h" #include "routes/routes.h" #include "blankie/escape.h" @@ -27,30 +28,6 @@ static Element serialize_post(const httplib::Request& req, const std::string& se static inline Element serialize_media(const Media& media); static inline Element serialize_poll(const httplib::Request& req, const Poll& poll); -class CurlUrlException : public std::exception { -public: - CurlUrlException(CURLUcode code_) : code(code_) { -#if !CURL_AT_LEAST_VERSION(7, 80, 0) - snprintf(this->_id_buf, 64, "curl url error %d", this->code); -#endif - } - - const char* what() const noexcept { -#if CURL_AT_LEAST_VERSION(7, 80, 0) - return curl_url_strerror(this->code); -#else - return this->_id_buf; -#endif - } - - CURLUcode code; - -private: -#if !CURL_AT_LEAST_VERSION(7, 80, 0) - char _id_buf[64]; -#endif -}; - void serve(const httplib::Request& req, httplib::Response& res, std::string title, Element element, Nodes extra_head) { using namespace std::string_literals; @@ -118,6 +95,21 @@ void serve_redirect(const httplib::Request& req, httplib::Response& res, std::st +bool starts_with(const CurlUrl& url, const CurlUrl& base) { + if (strcmp(url.get(CURLUPART_SCHEME).get(), base.get(CURLUPART_SCHEME).get()) != 0) { + return false; + } + if (strcmp(url.get(CURLUPART_HOST).get(), base.get(CURLUPART_HOST).get()) != 0) { + return false; + } + + CurlStr url_path = url.get(CURLUPART_PATH); + CurlStr base_path = base.get(CURLUPART_PATH); + size_t base_path_len = strlen(base_path.get()); + return memcpy(url_path.get(), base_path.get(), base_path_len) == 0 + && (url_path.get()[base_path_len] == '/' || url_path.get()[base_path_len] == '\0'); +} + std::string get_origin(const httplib::Request& req) { if (req.has_header("X-Canonical-Origin")) { return req.get_header_value("X-Canonical-Origin"); @@ -139,43 +131,28 @@ std::string get_origin(const httplib::Request& req) { } std::string proxy_mastodon_url(const httplib::Request& req, const std::string& url_str) { - using CurlStr = std::unique_ptr; + CurlUrl url; + url.set(CURLUPART_URL, url_str.c_str()); - std::unique_ptr url(curl_url(), curl_url_cleanup); - if (!url) { - throw std::bad_alloc(); - } + std::string new_url = get_origin(req) + '/' + url.get(CURLUPART_HOST).get() + url.get(CURLUPART_PATH).get(); - // in a block to avoid a (potential) gcc bug where it thinks the lambda below - // shadows `code`, even if i do `[&url](...) { ... }` - { - CURLUcode code = curl_url_set(url.get(), CURLUPART_URL, url_str.c_str(), 0); - if (code) { - throw CurlUrlException(code); - } - } - - auto get_part = [&](CURLUPart part, CURLUcode ignore = CURLUE_OK) { - char* content = nullptr; - CURLUcode code = curl_url_get(url.get(), part, &content, 0); - if (code && code != ignore) { - throw CurlUrlException(code); - } - return CurlStr(content, curl_free); - }; - CurlStr host = get_part(CURLUPART_HOST); - CurlStr path = get_part(CURLUPART_PATH); - CurlStr query = get_part(CURLUPART_QUERY, CURLUE_NO_QUERY); - CurlStr fragment = get_part(CURLUPART_FRAGMENT, CURLUE_NO_FRAGMENT); - - std::string new_url = get_origin(req) + '/' + host.get() + path.get(); - if (query) { + try { + CurlStr query = url.get(CURLUPART_QUERY); new_url += '?'; new_url += query.get(); + } catch (const CurlUrlException& e) { + if (e.code != CURLUE_NO_QUERY) { + throw; + } } - if (fragment) { + try { + CurlStr fragment = url.get(CURLUPART_FRAGMENT); new_url += '#'; new_url += fragment.get(); + } catch (const CurlUrlException& e) { + if (e.code != CURLUE_NO_FRAGMENT) { + throw; + } } return new_url; } @@ -263,10 +240,14 @@ static inline void preprocess_link(const httplib::Request& req, const std::strin const lxb_char_t* cls_c = lxb_dom_element_class(element, &cls_c_len); std::string cls = cls_c ? std::string(reinterpret_cast(cls_c), cls_c_len) : ""; - std::string instance_url_base = "https://"s + domain_name; + CurlUrl href_url; + href_url.set(CURLUPART_URL, href); + CurlUrl instance_url_base; + instance_url_base.set(CURLUPART_SCHEME, "https"); + instance_url_base.set(CURLUPART_HOST, domain_name); // .mention is used in note and posts // Instance base is used for link fields - if (std::regex_search(cls, mention_class_re) || href.starts_with(instance_url_base + '/') || href == instance_url_base) { + if (std::regex_search(cls, mention_class_re) || starts_with(href_url, instance_url_base)) { // Proxy this instance's URLs to Coyote href = proxy_mastodon_url(req, std::move(href)); diff --git a/servehelper.h b/servehelper.h index d3448e4..3e5cb48 100644 --- a/servehelper.h +++ b/servehelper.h @@ -6,6 +6,7 @@ #include "blankie/serializer.h" struct Post; // forward declaration from models.h struct Emoji; // forward declaration from models.h +class CurlUrl; // forward declaration from curlu_wrapper.h using Element = blankie::html::Element; using Node = blankie::html::Node; @@ -16,6 +17,7 @@ void serve_error(const httplib::Request& req, httplib::Response& res, std::string title, std::optional subtitle = std::nullopt, std::optional info = std::nullopt); void serve_redirect(const httplib::Request& req, httplib::Response& res, std::string url, bool permanent = false); +bool starts_with(const CurlUrl& url, const CurlUrl& base); std::string get_origin(const httplib::Request& req); std::string proxy_mastodon_url(const httplib::Request& req, const std::string& url_str); bool should_send_304(const httplib::Request& req, uint64_t hash);