diff --git a/CMakeLists.txt b/CMakeLists.txt index d4114c2..0f28aba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,10 +5,14 @@ project(coyote C CXX) find_package(nlohmann_json REQUIRED) find_package(CURL REQUIRED) +find_package(OpenSSL REQUIRED) + set(HTTPLIB_REQUIRE_OPENSSL ON) add_subdirectory(thirdparty/httplib) + set(LEXBOR_BUILD_SHARED OFF) add_subdirectory(thirdparty/lexbor) + find_package(PkgConfig REQUIRED) pkg_check_modules(HIREDIS REQUIRED hiredis) @@ -29,7 +33,7 @@ list(APPEND FLAGS -Werror -Wall -Wextra -Wshadow -Wpedantic -Wno-gnu-anonymous-s add_link_options(${FLAGS}) -add_executable(${PROJECT_NAME} main.cpp numberhelper.cpp hex.cpp config.cpp settings.cpp models.cpp client.cpp servehelper.cpp htmlhelper.cpp timeutils.cpp hiredis_wrapper.cpp +add_executable(${PROJECT_NAME} main.cpp numberhelper.cpp hex.cpp config.cpp settings.cpp models.cpp client.cpp servehelper.cpp htmlhelper.cpp timeutils.cpp openssl_wrapper.cpp hiredis_wrapper.cpp routes/home.cpp routes/css.cpp routes/user.cpp routes/status.cpp routes/tags.cpp routes/about.cpp routes/user_settings.cpp blankie/serializer.cpp blankie/escape.cpp) set_target_properties(${PROJECT_NAME} @@ -39,6 +43,6 @@ set_target_properties(${PROJECT_NAME} CXX_EXTENSIONS NO ) target_include_directories(${PROJECT_NAME} PRIVATE thirdparty ${HIREDIS_INCLUDE_DIRS}) -target_link_libraries(${PROJECT_NAME} PRIVATE nlohmann_json::nlohmann_json CURL::libcurl httplib::httplib lexbor_static ${HIREDIS_LINK_LIBRARIES}) +target_link_libraries(${PROJECT_NAME} PRIVATE nlohmann_json::nlohmann_json CURL::libcurl OpenSSL::Crypto httplib::httplib lexbor_static ${HIREDIS_LINK_LIBRARIES}) target_compile_definitions(${PROJECT_NAME} PRIVATE ${DEFINITIONS}) target_compile_options(${PROJECT_NAME} PRIVATE ${FLAGS}) diff --git a/RUNNING.md b/RUNNING.md index f16b78b..c1c3010 100644 --- a/RUNNING.md +++ b/RUNNING.md @@ -15,6 +15,7 @@ Copy `example_config.json` to a file with any name you like liking. Here's a list of what they are: - `bind_host` (string): What address to bind to - `bind_port` (zero or positive integer): What port to bind to +- `hmac_key` (hex string): A secret key to be used; generate with `head -c32 /dev/urandom | basenc --base16` - `canonical_origin` (string or null): A fallback canonical origin if set, useful if you're, say, running Coyote behind Ngrok - `redis` (object) - `enabled` (boolean) diff --git a/config.cpp b/config.cpp index abd8370..7c9cd64 100644 --- a/config.cpp +++ b/config.cpp @@ -1,6 +1,7 @@ #include #include +#include "hex.h" #include "file.h" #include "config.h" @@ -47,6 +48,7 @@ void from_json(const nlohmann::json& j, Config& conf) { throw std::invalid_argument("Invalid port to bind to: "s + std::to_string(conf.bind_port)); } + conf.hmac_key = hex_decode(j.at("hmac_key").get_ref()); if (j.at("canonical_origin").is_string()) { conf.canonical_origin = j["canonical_origin"].get(); } diff --git a/config.h b/config.h index 88d55d5..34ea77f 100644 --- a/config.h +++ b/config.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -22,6 +23,7 @@ struct RedisConfig { struct Config { std::string bind_host = "127.0.0.1"; int bind_port = 8080; + std::vector hmac_key; std::optional canonical_origin; std::optional redis_config; }; diff --git a/example_config.json b/example_config.json index 2da2e00..f6bc363 100644 --- a/example_config.json +++ b/example_config.json @@ -1,6 +1,7 @@ { "bind_host": "127.0.0.1", "bind_port": 8080, + "hmac_key": "AA", "canonical_origin": null, "redis": { "enabled": true, diff --git a/openssl_wrapper.cpp b/openssl_wrapper.cpp new file mode 100644 index 0000000..99352ca --- /dev/null +++ b/openssl_wrapper.cpp @@ -0,0 +1,35 @@ +#include +#include + +#include +#include +#include "openssl_wrapper.h" + +std::vector secure_random_bytes(int num) { + if (num < 0) { + throw std::invalid_argument("secure_random_bytes(): num variable out of range (num < 0)"); + } + + std::vector bytes(static_cast(num), 0); + if (RAND_bytes(reinterpret_cast(bytes.data()), num) == 1) { + return bytes; + } else { + throw OpenSSLException(ERR_get_error()); + } +} + +std::array hmac_sha3_256(const std::vector& key, const std::vector& data) { + char hmac[32]; + unsigned int md_len; + + std::unique_ptr md(EVP_MD_fetch(nullptr, "SHA3-256", nullptr), EVP_MD_free); + if (HMAC(md.get(), key.data(), static_cast(key.size()), reinterpret_cast(data.data()), data.size(), reinterpret_cast(hmac), &md_len)) { + if (md_len != 32) { + throw std::runtime_error("hmac_sha3_256(): HMAC() returned an unexpected size"); + } + + return std::to_array(hmac); + } else { + throw OpenSSLException(ERR_get_error()); + } +} diff --git a/openssl_wrapper.h b/openssl_wrapper.h new file mode 100644 index 0000000..daf92b2 --- /dev/null +++ b/openssl_wrapper.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +#include + +#include + +class OpenSSLException : public std::exception { +public: + OpenSSLException(unsigned long e) { + ERR_error_string_n(e, this->_str, 1024); + } + + const char* what() const noexcept { + return this->_str; + } + +private: + char _str[1024]; +}; + +std::vector secure_random_bytes(int num); +std::array hmac_sha3_256(const std::vector& key, const std::vector& data); diff --git a/routes/user_settings.cpp b/routes/user_settings.cpp index 1c28438..7380c1a 100644 --- a/routes/user_settings.cpp +++ b/routes/user_settings.cpp @@ -1,23 +1,58 @@ #include "routes.h" +#include "../hex.h" +#include "../config.h" #include "../servehelper.h" #include "../settings.h" #include "../timeutils.h" #include "../curlu_wrapper.h" +#include "../openssl_wrapper.h" -static void set_cookie(const httplib::Request& req, httplib::Response& res, const char* key, std::string_view value); +static inline std::string generate_csrf_token(void); +static inline bool validate_csrf_token(const httplib::Request& req, httplib::Response& res, std::string_view csrf_token, std::string_view query_csrf_token); +static void set_cookie(const httplib::Request& req, httplib::Response& res, const char* key, std::string_view value, bool session = false); +static bool safe_memcmp(const char* s1, const char* s2, size_t n); void user_settings_route(const httplib::Request& req, httplib::Response& res) { UserSettings settings; + Cookies cookies = parse_cookies(req); + std::string csrf_token; if (req.method == "POST") { + if (!cookies.contains("csrf-token")) { + res.status = 400; + serve_error(req, res, "400: Bad Request", "Missing CSRF token cookie, are cookies enabled?"); + return; + } + csrf_token = cookies["csrf-token"]; + + auto query_csrf_token = req.params.find("csrf-token"); + if (query_csrf_token == req.params.end()) { + res.status = 400; + serve_error(req, res, "400: Bad Request", "Missing CSRF token query parameter"); + return; + } + + if (!validate_csrf_token(req, res, csrf_token, query_csrf_token->second)) { + return; + } + for (const auto& i : req.params) { settings.set(i.first, i.second); } set_cookie(req, res, "auto-open-cw", settings.auto_open_cw ? "true" : "false"); } else { - settings.load_from_cookies(req); + for (auto &[name, value] : cookies) { + settings.set(name, value); + } + + if (cookies.contains("csrf-token")) { + csrf_token = cookies["csrf-token"]; + } else { + csrf_token = generate_csrf_token(); + set_cookie(req, res, "csrf-token", csrf_token, true); + } } Element auto_open_cw_checkbox("input", {{"type", "checkbox"}, {"name", "auto-open-cw"}, {"value", "true"}}, {}); @@ -33,6 +68,7 @@ void user_settings_route(const httplib::Request& req, httplib::Response& res) { }), Element("br"), + Element("input", {{"type", "hidden"}, {"name", "csrf-token"}, {"value", csrf_token}}, {}), Element("input", {{"type", "submit"}, {"value", "Save"}}, {}), }), Element("form", {{"class", "user_settings_page-form"}, {"method", "get"}, {"action", get_origin(req)}}, { @@ -49,16 +85,69 @@ void user_settings_route(const httplib::Request& req, httplib::Response& res) { } -static void set_cookie(const httplib::Request& req, httplib::Response& res, const char* key, std::string_view value) { +static inline std::string generate_csrf_token(void) { + std::vector raw_token = secure_random_bytes(32); + std::array raw_token_hmac = hmac_sha3_256(config.hmac_key, raw_token); + + return hex_encode(raw_token) + '.' + hex_encode(raw_token_hmac.data(), raw_token_hmac.size()); +} + +static inline bool validate_csrf_token(const httplib::Request& req, httplib::Response& res, std::string_view csrf_token, std::string_view query_csrf_token) { + if (csrf_token.size() != query_csrf_token.size() || !safe_memcmp(csrf_token.data(), query_csrf_token.data(), csrf_token.size())) { + res.status = 400; + serve_error(req, res, "400: Bad Request", "CSRF token cookie and CSRF token query parameter do not match"); + return false; + } + + if (csrf_token.size() != 64 + 1 + 64 || csrf_token[64] != '.') { + res.status = 400; + serve_error(req, res, "400: Bad Request", "CSRF token is in an unknown format"); + return false; + } + + std::vector raw_token, raw_token_hmac; + try { + raw_token = hex_decode(csrf_token.substr(0, 64)); + raw_token_hmac = hex_decode(csrf_token.substr(64 + 1, 64)); + } catch (const std::exception& e) { + res.status = 400; + serve_error(req, res, "400: Bad Request", "Failed to parse CSRF token", e.what()); + return false; + } + + std::array our_raw_token_hmac = hmac_sha3_256(config.hmac_key, raw_token); + if (!safe_memcmp(raw_token_hmac.data(), our_raw_token_hmac.data(), 32)) { + res.status = 400; + serve_error(req, res, "400: Bad Request", "CSRF token HMAC is not correct"); + return false; + } + + return true; +} + +static void set_cookie(const httplib::Request& req, httplib::Response& res, const char* key, std::string_view value, bool session) { CurlUrl origin; origin.set(CURLUPART_URL, get_origin(req)); std::string header = std::string(key) + '=' + std::string(value) - + "; HttpOnly; SameSite=Lax; Domain=" + origin.get(CURLUPART_HOST).get() + "; Path=" + origin.get(CURLUPART_PATH).get() - + "; Expires=" + to_web_date(current_time() + 365 * 24 * 60 * 60); + + "; HttpOnly; SameSite=Lax; Domain=" + origin.get(CURLUPART_HOST).get() + "; Path=" + origin.get(CURLUPART_PATH).get(); + if (!session) { + header += "; Expires="; + header += to_web_date(current_time() + 365 * 24 * 60 * 60); + } if (strcmp(origin.get(CURLUPART_SCHEME).get(), "https") == 0) { header += "; Secure"; } res.set_header("Set-Cookie", header); } + +static bool safe_memcmp(const char* s1, const char* s2, size_t n) { + bool equal = true; + + for (size_t i = 0; i < n; i++) { + equal &= s1[i] == s2[i]; + } + + return equal; +} diff --git a/servehelper.cpp b/servehelper.cpp index e2bc89f..3460dd6 100644 --- a/servehelper.cpp +++ b/servehelper.cpp @@ -9,6 +9,11 @@ #include "curlu_wrapper.h" #include "routes/routes.h" +static inline void parse_cookies(std::string_view str, Cookies& cookies); +static inline bool lowercase_compare(std::string_view lhs, std::string_view rhs); + + + void serve(const httplib::Request& req, httplib::Response& res, std::string title, Element element, Nodes extra_head) { using namespace std::string_literals; @@ -146,3 +151,60 @@ bool should_send_304(const httplib::Request& req, uint64_t hash) { size_t pos = header.find(std::string(1, '"') + std::to_string(hash) + '"'); return pos != std::string::npos && (pos == 0 || header[pos - 1] != '/'); } + + +Cookies parse_cookies(const httplib::Request& req) { + Cookies cookies; + + for (const auto& i : req.headers) { + if (lowercase_compare(i.first, "cookie")) { + parse_cookies(i.second, cookies); + } + } + + return cookies; +} + + + +static inline void parse_cookies(std::string_view str, Cookies& cookies) { + using namespace std::string_literals; + + size_t offset = 0; + size_t new_offset = 0; + const char* delimiter = "; "; + size_t delimiter_len = strlen(delimiter); + + while (offset < str.size()) { + new_offset = str.find(delimiter, offset); + + std::string_view item = str.substr(offset, new_offset != std::string_view::npos ? new_offset - offset : std::string_view::npos); + size_t equal_offset = item.find('='); + if (equal_offset == std::string_view::npos) { + throw std::invalid_argument("invalid user setting item: "s + std::string(item)); + } + cookies.insert({std::string(item.substr(0, equal_offset)), std::string(item.substr(equal_offset + 1))}); + + if (new_offset == std::string_view::npos) { + break; + } + offset = new_offset + delimiter_len; + } +} + +static inline bool lowercase_compare(std::string_view lhs, std::string_view rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + + auto lower = [](char c) { + return c >= 'A' && c <= 'Z' ? c - 'A' + 'a' : c; + }; + for (size_t i = 0; i < lhs.size(); i++) { + if (lower(lhs[i]) != lower(rhs[i])) { + return false; + } + } + + return true; +} diff --git a/servehelper.h b/servehelper.h index 4cd9774..b707bf1 100644 --- a/servehelper.h +++ b/servehelper.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include "blankie/serializer.h" @@ -9,6 +10,7 @@ class CurlUrl; // forward declaration from curlu_wrapper.h using Element = blankie::html::Element; using Node = blankie::html::Node; using Nodes = std::vector; +using Cookies = std::unordered_map; void serve(const httplib::Request& req, httplib::Response& res, std::string title, Element element, Nodes extra_head = {}); void serve_error(const httplib::Request& req, httplib::Response& res, @@ -19,3 +21,5 @@ bool starts_with(const CurlUrl& url, const CurlUrl& base); std::string get_origin(const httplib::Request& req); std::string proxy_mastodon_url(const httplib::Request& req, const std::string& url_str); bool should_send_304(const httplib::Request& req, uint64_t hash); + +Cookies parse_cookies(const httplib::Request& req); diff --git a/settings.cpp b/settings.cpp index 9890c2d..1126ce9 100644 --- a/settings.cpp +++ b/settings.cpp @@ -1,9 +1,8 @@ #include #include #include "settings.h" +#include "servehelper.h" -static void set_settings(std::string_view str, const char* delimiter, UserSettings& settings); -static inline bool lowercase_compare(std::string_view lhs, std::string_view rhs); static bool parse_bool(std::string_view value); @@ -14,54 +13,14 @@ void UserSettings::set(std::string_view key, std::string_view value) { } void UserSettings::load_from_cookies(const httplib::Request& req) { - for (const auto& i : req.headers) { - if (lowercase_compare(i.first, "cookie")) { - set_settings(i.second, "; ", *this); - } + Cookies cookies = parse_cookies(req); + + for (auto &[name, value] : cookies) { + this->set(name, value); } } -static void set_settings(std::string_view str, const char* delimiter, UserSettings& settings) { - using namespace std::string_literals; - size_t offset = 0; - size_t new_offset = 0; - size_t delimiter_len = strlen(delimiter); - - while (offset < str.size()) { - new_offset = str.find(delimiter, offset); - - std::string_view item = str.substr(offset, new_offset != std::string_view::npos ? new_offset - offset : std::string_view::npos); - size_t equal_offset = item.find('='); - if (equal_offset == std::string_view::npos) { - throw std::invalid_argument("invalid user setting item: "s + std::string(item)); - } - settings.set(item.substr(0, equal_offset), item.substr(equal_offset + 1)); - - if (new_offset == std::string_view::npos) { - break; - } - offset = new_offset + delimiter_len; - } -} - -static inline bool lowercase_compare(std::string_view lhs, std::string_view rhs) { - if (lhs.size() != rhs.size()) { - return false; - } - - auto lower = [](char c) { - return c >= 'A' && c <= 'Z' ? c - 'A' + 'a' : c; - }; - for (size_t i = 0; i < lhs.size(); i++) { - if (lower(lhs[i]) != lower(rhs[i])) { - return false; - } - } - - return true; -} - static bool parse_bool(std::string_view value) { using namespace std::string_literals;