Use CURLU* for URLs

This commit is contained in:
blankie 2023-11-29 22:36:52 +11:00
parent ab8f5569be
commit e231afb49c
Signed by: blankie
GPG Key ID: CC15FC822C7F61F5
5 changed files with 166 additions and 80 deletions

View File

@ -4,6 +4,7 @@
#include "client.h" #include "client.h"
#include "models.h" #include "models.h"
#include "curlu_wrapper.h"
#include "hiredis_wrapper.h" #include "hiredis_wrapper.h"
MastodonClient mastodon_client; MastodonClient mastodon_client;
@ -85,8 +86,13 @@ std::optional<Account> MastodonClient::get_account_by_username(std::string host,
username.erase(username.size() - host.size() - 1); username.erase(username.size() - host.size() - 1);
} }
CurlUrl url;
url.set(CURLUPART_SCHEME, "https");
url.set(CURLUPART_HOST, host);
url.set(CURLUPART_PATH, "/api/v1/accounts/lookup");
url.set(CURLUPART_QUERY, "acct="s + url_encode(username));
try { try {
Account account = this->_send_request("coyote:"s + host + ":@" + username, "https://"s + host + "/api/v1/accounts/lookup?acct=" + url_encode(username)); Account account = this->_send_request("coyote:"s + host + ":@" + username, url);
account.same_server = host == account.server; account.same_server = host == account.server;
return account; return account;
} catch (const MastodonException& e) { } catch (const MastodonException& e) {
@ -102,7 +108,12 @@ std::vector<Post> MastodonClient::get_pinned_posts(std::string host, const std::
using namespace std::string_literals; using namespace std::string_literals;
lowercase(host); lowercase(host);
std::vector<Post> posts = this->_send_request("coyote:"s + host + ':' + account_id + ":pinned", "https://"s + host + "/api/v1/accounts/" + account_id + "/statuses?pinned=true"); CurlUrl url;
url.set(CURLUPART_SCHEME, "https");
url.set(CURLUPART_HOST, host);
url.set(CURLUPART_PATH, "/api/v1/accounts/"s + url_encode(account_id) + "/statuses");
url.set(CURLUPART_QUERY, "pinned=true");
std::vector<Post> posts = this->_send_request("coyote:"s + host + ':' + account_id + ":pinned", url);
for (Post& post : posts) { for (Post& post : posts) {
handle_post_server(post, host); handle_post_server(post, host);
@ -113,22 +124,17 @@ std::vector<Post> MastodonClient::get_pinned_posts(std::string host, const std::
std::vector<Post> MastodonClient::get_posts(const std::string& host, const std::string& account_id, PostSortingMethod sorting_method, std::optional<std::string> max_id) { std::vector<Post> MastodonClient::get_posts(const std::string& host, const std::string& account_id, PostSortingMethod sorting_method, std::optional<std::string> max_id) {
using namespace std::string_literals; using namespace std::string_literals;
const char* sorting_parameters[3] = {"exclude_replies=true", "", "only_media=true"}; const char* sorting_parameters[3] = {"exclude_replies=true", "", "only_media=true"};
std::string query = sorting_parameters[sorting_method];
CurlUrl url;
url.set(CURLUPART_SCHEME, "https");
url.set(CURLUPART_HOST, host);
url.set(CURLUPART_PATH, "/api/v1/accounts/"s + url_encode(account_id) + "/statuses");
url.set(CURLUPART_QUERY, sorting_parameters[sorting_method]);
if (max_id) { if (max_id) {
if (!query.empty()) { url.set(CURLUPART_QUERY, "max_id="s + std::move(*max_id), CURLU_URLENCODE | CURLU_APPENDQUERY);
query += '&';
}
query += "max_id=";
query += url_encode(std::move(*max_id));
} }
std::string url = "https://"s + host + "/api/v1/accounts/" + account_id + "/statuses";
if (!query.empty()) {
url += '?';
url += query;
}
std::vector<Post> posts = this->_send_request(std::nullopt, url); std::vector<Post> posts = this->_send_request(std::nullopt, url);
for (Post& post : posts) { for (Post& post : posts) {
@ -141,8 +147,12 @@ std::vector<Post> MastodonClient::get_posts(const std::string& host, const std::
std::optional<Post> MastodonClient::get_post(const std::string& host, std::string id) { std::optional<Post> MastodonClient::get_post(const std::string& host, std::string id) {
using namespace std::string_literals; using namespace std::string_literals;
CurlUrl url;
url.set(CURLUPART_SCHEME, "https");
url.set(CURLUPART_HOST, host);
url.set(CURLUPART_PATH, "/api/v1/statuses/"s + url_encode(std::move(id)));
try { try {
Post post = this->_send_request(std::nullopt, "https://"s + host + "/api/v1/statuses/" + url_encode(std::move(id))); Post post = this->_send_request(std::nullopt, url);
handle_post_server(post, host); handle_post_server(post, host);
return post; return post;
} catch (const MastodonException& e) { } catch (const MastodonException& e) {
@ -157,7 +167,11 @@ std::optional<Post> MastodonClient::get_post(const std::string& host, std::strin
PostContext MastodonClient::get_post_context(const std::string& host, std::string id) { PostContext MastodonClient::get_post_context(const std::string& host, std::string id) {
using namespace std::string_literals; using namespace std::string_literals;
PostContext context = this->_send_request(std::nullopt, "https://"s + host + "/api/v1/statuses/" + url_encode(std::move(id)) + "/context"); CurlUrl url;
url.set(CURLUPART_SCHEME, "https");
url.set(CURLUPART_HOST, host);
url.set(CURLUPART_PATH, "/api/v1/statuses/"s + url_encode(std::move(id)) + "/context");
PostContext context = this->_send_request(std::nullopt, url);
for (Post& post : context.ancestors) { for (Post& post : context.ancestors) {
handle_post_server(post, host); handle_post_server(post, host);
@ -172,10 +186,12 @@ PostContext MastodonClient::get_post_context(const std::string& host, std::strin
std::vector<Post> MastodonClient::get_tag_timeline(const std::string& host, const std::string& tag, std::optional<std::string> max_id) { std::vector<Post> MastodonClient::get_tag_timeline(const std::string& host, const std::string& tag, std::optional<std::string> max_id) {
using namespace std::string_literals; using namespace std::string_literals;
std::string url = "https://"s + host + "/api/v1/timelines/tag/" + url_encode(tag); CurlUrl url;
url.set(CURLUPART_SCHEME, "https");
url.set(CURLUPART_HOST, host);
url.set(CURLUPART_PATH, "/api/v1/timelines/tag/"s + url_encode(tag));
if (max_id) { if (max_id) {
url += "?max_id="; url.set(CURLUPART_QUERY, "max_id="s + std::move(*max_id), CURLU_URLENCODE | CURLU_APPENDQUERY);
url += url_encode(std::move(*max_id));
} }
std::vector<Post> posts = this->_send_request(std::nullopt, url); std::vector<Post> posts = this->_send_request(std::nullopt, url);
@ -189,7 +205,11 @@ Instance MastodonClient::get_instance(std::string host) {
using namespace std::string_literals; using namespace std::string_literals;
lowercase(host); lowercase(host);
Instance instance = this->_send_request("coyote:"s + host + ":instance", "https://"s + host + "/api/v2/instance"); CurlUrl url;
url.set(CURLUPART_SCHEME, "https");
url.set(CURLUPART_HOST, host);
url.set(CURLUPART_PATH, "/api/v2/instance");
Instance instance = this->_send_request("coyote:"s + host + ":instance", url);
instance.contact_account.same_server = instance.contact_account.server == host; instance.contact_account.same_server = instance.contact_account.server == host;
return instance; return instance;
} }
@ -198,7 +218,11 @@ blankie::html::HTMLString MastodonClient::get_extended_description(std::string h
using namespace std::string_literals; using namespace std::string_literals;
lowercase(host); lowercase(host);
nlohmann::json j = this->_send_request("coyote:"s + host + ":desc", "https://"s + host + "/api/v1/instance/extended_description"); CurlUrl url;
url.set(CURLUPART_SCHEME, "https");
url.set(CURLUPART_HOST, host);
url.set(CURLUPART_PATH, "/api/v1/instance/extended_description");
nlohmann::json j = this->_send_request("coyote:"s + host + ":desc", url);
return blankie::html::HTMLString(j.at("content").get<std::string>()); return blankie::html::HTMLString(j.at("content").get<std::string>());
} }
@ -234,7 +258,7 @@ CURL* MastodonClient::_get_easy() {
return curl; return curl;
} }
nlohmann::json MastodonClient::_send_request(std::optional<std::string> cache_key, const std::string& url) { nlohmann::json MastodonClient::_send_request(std::optional<std::string> cache_key, const CurlUrl& url) {
std::optional<std::string> cached; std::optional<std::string> cached;
if (redis && cache_key && (cached = redis->get(*cache_key))) { if (redis && cache_key && (cached = redis->get(*cache_key))) {
return nlohmann::json::parse(std::move(*cached)); return nlohmann::json::parse(std::move(*cached));
@ -243,10 +267,11 @@ nlohmann::json MastodonClient::_send_request(std::optional<std::string> cache_ke
std::string res; std::string res;
CURL* curl = this->_get_easy(); CURL* curl = this->_get_easy();
setopt(curl, CURLOPT_URL, url.c_str()); setopt(curl, CURLOPT_CURLU, url.get());
setopt(curl, CURLOPT_WRITEFUNCTION, curl_write_cb); setopt(curl, CURLOPT_WRITEFUNCTION, curl_write_cb);
setopt(curl, CURLOPT_WRITEDATA, &res); setopt(curl, CURLOPT_WRITEDATA, &res);
CURLcode code = curl_easy_perform(curl); CURLcode code = curl_easy_perform(curl);
setopt(curl, CURLOPT_CURLU, nullptr);
if (code) { if (code) {
throw CurlException(code); throw CurlException(code);
} }

View File

@ -8,6 +8,7 @@
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include "models.h" #include "models.h"
class CurlUrl; // forward declaration from curlu_wrapper.h
class CurlException : public std::exception { class CurlException : public std::exception {
public: public:
@ -78,7 +79,7 @@ public:
private: private:
CURL* _get_easy(); CURL* _get_easy();
nlohmann::json _send_request(std::optional<std::string> cache_key, const std::string& url); nlohmann::json _send_request(std::optional<std::string> cache_key, const CurlUrl& url);
long _response_status_code(); long _response_status_code();
std::mutex _share_locks[CURL_LOCK_DATA_LAST]; std::mutex _share_locks[CURL_LOCK_DATA_LAST];

77
curlu_wrapper.h Normal file
View File

@ -0,0 +1,77 @@
#pragma once
#include <cstdio>
#include <string>
#include <memory>
#include <exception>
#include <stdexcept>
#include <curl/curl.h>
class CurlUrlException : public std::exception {
public:
CurlUrlException(CURLUcode code_) : code(code_) {
#if !CURL_AT_LEAST_VERSION(7, 80, 0)
snprintf(this->_id_buf, 64, "curl url error %d", this->code);
#endif
}
const char* what() const noexcept {
#if CURL_AT_LEAST_VERSION(7, 80, 0)
return curl_url_strerror(this->code);
#else
return this->_id_buf;
#endif
}
CURLUcode code;
private:
#if !CURL_AT_LEAST_VERSION(7, 80, 0)
char _id_buf[64];
#endif
};
using CurlStr = std::unique_ptr<char, decltype(&curl_free)>;
class CurlUrl {
public:
CurlUrl(const CurlUrl&) = delete;
CurlUrl& operator=(const CurlUrl&) = delete;
CurlUrl() {
this->_ptr = curl_url();
if (!this->_ptr) {
throw std::bad_alloc();
}
}
~CurlUrl() {
curl_url_cleanup(this->_ptr);
}
constexpr CURLU* get() const noexcept {
return this->_ptr;
}
CurlStr get(CURLUPart part, unsigned int flags = 0) const {
char* content;
CURLUcode code = curl_url_get(this->_ptr, part, &content, flags);
if (code) {
throw CurlUrlException(code);
}
return CurlStr(content, curl_free);
}
void set(CURLUPart part, const char* content, unsigned int flags = 0) {
CURLUcode code = curl_url_set(this->_ptr, part, content, flags);
if (code) {
throw CurlUrlException(code);
}
}
void set(CURLUPart part, const std::string& content, unsigned int flags = 0) {
this->set(part, content.c_str(), flags);
}
private:
CURLU* _ptr;
};

View File

@ -10,6 +10,7 @@
#include "timeutils.h" #include "timeutils.h"
#include "servehelper.h" #include "servehelper.h"
#include "lxb_wrapper.h" #include "lxb_wrapper.h"
#include "curlu_wrapper.h"
#include "routes/routes.h" #include "routes/routes.h"
#include "blankie/escape.h" #include "blankie/escape.h"
@ -27,30 +28,6 @@ static Element serialize_post(const httplib::Request& req, const std::string& se
static inline Element serialize_media(const Media& media); static inline Element serialize_media(const Media& media);
static inline Element serialize_poll(const httplib::Request& req, const Poll& poll); static inline Element serialize_poll(const httplib::Request& req, const Poll& poll);
class CurlUrlException : public std::exception {
public:
CurlUrlException(CURLUcode code_) : code(code_) {
#if !CURL_AT_LEAST_VERSION(7, 80, 0)
snprintf(this->_id_buf, 64, "curl url error %d", this->code);
#endif
}
const char* what() const noexcept {
#if CURL_AT_LEAST_VERSION(7, 80, 0)
return curl_url_strerror(this->code);
#else
return this->_id_buf;
#endif
}
CURLUcode code;
private:
#if !CURL_AT_LEAST_VERSION(7, 80, 0)
char _id_buf[64];
#endif
};
void serve(const httplib::Request& req, httplib::Response& res, std::string title, Element element, Nodes extra_head) { void serve(const httplib::Request& req, httplib::Response& res, std::string title, Element element, Nodes extra_head) {
using namespace std::string_literals; using namespace std::string_literals;
@ -118,6 +95,21 @@ void serve_redirect(const httplib::Request& req, httplib::Response& res, std::st
bool starts_with(const CurlUrl& url, const CurlUrl& base) {
if (strcmp(url.get(CURLUPART_SCHEME).get(), base.get(CURLUPART_SCHEME).get()) != 0) {
return false;
}
if (strcmp(url.get(CURLUPART_HOST).get(), base.get(CURLUPART_HOST).get()) != 0) {
return false;
}
CurlStr url_path = url.get(CURLUPART_PATH);
CurlStr base_path = base.get(CURLUPART_PATH);
size_t base_path_len = strlen(base_path.get());
return memcpy(url_path.get(), base_path.get(), base_path_len) == 0
&& (url_path.get()[base_path_len] == '/' || url_path.get()[base_path_len] == '\0');
}
std::string get_origin(const httplib::Request& req) { std::string get_origin(const httplib::Request& req) {
if (req.has_header("X-Canonical-Origin")) { if (req.has_header("X-Canonical-Origin")) {
return req.get_header_value("X-Canonical-Origin"); return req.get_header_value("X-Canonical-Origin");
@ -139,43 +131,28 @@ std::string get_origin(const httplib::Request& req) {
} }
std::string proxy_mastodon_url(const httplib::Request& req, const std::string& url_str) { std::string proxy_mastodon_url(const httplib::Request& req, const std::string& url_str) {
using CurlStr = std::unique_ptr<char, decltype(&curl_free)>; CurlUrl url;
url.set(CURLUPART_URL, url_str.c_str());
std::unique_ptr<CURLU, decltype(&curl_url_cleanup)> url(curl_url(), curl_url_cleanup); std::string new_url = get_origin(req) + '/' + url.get(CURLUPART_HOST).get() + url.get(CURLUPART_PATH).get();
if (!url) {
throw std::bad_alloc();
}
// in a block to avoid a (potential) gcc bug where it thinks the lambda below try {
// shadows `code`, even if i do `[&url](...) { ... }` CurlStr query = url.get(CURLUPART_QUERY);
{
CURLUcode code = curl_url_set(url.get(), CURLUPART_URL, url_str.c_str(), 0);
if (code) {
throw CurlUrlException(code);
}
}
auto get_part = [&](CURLUPart part, CURLUcode ignore = CURLUE_OK) {
char* content = nullptr;
CURLUcode code = curl_url_get(url.get(), part, &content, 0);
if (code && code != ignore) {
throw CurlUrlException(code);
}
return CurlStr(content, curl_free);
};
CurlStr host = get_part(CURLUPART_HOST);
CurlStr path = get_part(CURLUPART_PATH);
CurlStr query = get_part(CURLUPART_QUERY, CURLUE_NO_QUERY);
CurlStr fragment = get_part(CURLUPART_FRAGMENT, CURLUE_NO_FRAGMENT);
std::string new_url = get_origin(req) + '/' + host.get() + path.get();
if (query) {
new_url += '?'; new_url += '?';
new_url += query.get(); new_url += query.get();
} catch (const CurlUrlException& e) {
if (e.code != CURLUE_NO_QUERY) {
throw;
}
} }
if (fragment) { try {
CurlStr fragment = url.get(CURLUPART_FRAGMENT);
new_url += '#'; new_url += '#';
new_url += fragment.get(); new_url += fragment.get();
} catch (const CurlUrlException& e) {
if (e.code != CURLUE_NO_FRAGMENT) {
throw;
}
} }
return new_url; return new_url;
} }
@ -263,10 +240,14 @@ static inline void preprocess_link(const httplib::Request& req, const std::strin
const lxb_char_t* cls_c = lxb_dom_element_class(element, &cls_c_len); const lxb_char_t* cls_c = lxb_dom_element_class(element, &cls_c_len);
std::string cls = cls_c ? std::string(reinterpret_cast<const char*>(cls_c), cls_c_len) : ""; std::string cls = cls_c ? std::string(reinterpret_cast<const char*>(cls_c), cls_c_len) : "";
std::string instance_url_base = "https://"s + domain_name; CurlUrl href_url;
href_url.set(CURLUPART_URL, href);
CurlUrl instance_url_base;
instance_url_base.set(CURLUPART_SCHEME, "https");
instance_url_base.set(CURLUPART_HOST, domain_name);
// .mention is used in note and posts // .mention is used in note and posts
// Instance base is used for link fields // Instance base is used for link fields
if (std::regex_search(cls, mention_class_re) || href.starts_with(instance_url_base + '/') || href == instance_url_base) { if (std::regex_search(cls, mention_class_re) || starts_with(href_url, instance_url_base)) {
// Proxy this instance's URLs to Coyote // Proxy this instance's URLs to Coyote
href = proxy_mastodon_url(req, std::move(href)); href = proxy_mastodon_url(req, std::move(href));

View File

@ -6,6 +6,7 @@
#include "blankie/serializer.h" #include "blankie/serializer.h"
struct Post; // forward declaration from models.h struct Post; // forward declaration from models.h
struct Emoji; // forward declaration from models.h struct Emoji; // forward declaration from models.h
class CurlUrl; // forward declaration from curlu_wrapper.h
using Element = blankie::html::Element; using Element = blankie::html::Element;
using Node = blankie::html::Node; using Node = blankie::html::Node;
@ -16,6 +17,7 @@ void serve_error(const httplib::Request& req, httplib::Response& res,
std::string title, std::optional<std::string> subtitle = std::nullopt, std::optional<std::string> info = std::nullopt); std::string title, std::optional<std::string> subtitle = std::nullopt, std::optional<std::string> info = std::nullopt);
void serve_redirect(const httplib::Request& req, httplib::Response& res, std::string url, bool permanent = false); void serve_redirect(const httplib::Request& req, httplib::Response& res, std::string url, bool permanent = false);
bool starts_with(const CurlUrl& url, const CurlUrl& base);
std::string get_origin(const httplib::Request& req); std::string get_origin(const httplib::Request& req);
std::string proxy_mastodon_url(const httplib::Request& req, const std::string& url_str); std::string proxy_mastodon_url(const httplib::Request& req, const std::string& url_str);
bool should_send_304(const httplib::Request& req, uint64_t hash); bool should_send_304(const httplib::Request& req, uint64_t hash);