Add CSRF protection
This commit is contained in:
parent
973a0eada2
commit
9b21060b3a
|
@ -5,10 +5,14 @@ project(coyote C CXX)
|
|||
|
||||
find_package(nlohmann_json REQUIRED)
|
||||
find_package(CURL REQUIRED)
|
||||
find_package(OpenSSL REQUIRED)
|
||||
|
||||
set(HTTPLIB_REQUIRE_OPENSSL ON)
|
||||
add_subdirectory(thirdparty/httplib)
|
||||
|
||||
set(LEXBOR_BUILD_SHARED OFF)
|
||||
add_subdirectory(thirdparty/lexbor)
|
||||
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(HIREDIS REQUIRED hiredis)
|
||||
|
||||
|
@ -29,7 +33,7 @@ list(APPEND FLAGS -Werror -Wall -Wextra -Wshadow -Wpedantic -Wno-gnu-anonymous-s
|
|||
add_link_options(${FLAGS})
|
||||
|
||||
|
||||
add_executable(${PROJECT_NAME} main.cpp numberhelper.cpp hex.cpp config.cpp settings.cpp models.cpp client.cpp servehelper.cpp htmlhelper.cpp timeutils.cpp hiredis_wrapper.cpp
|
||||
add_executable(${PROJECT_NAME} main.cpp numberhelper.cpp hex.cpp config.cpp settings.cpp models.cpp client.cpp servehelper.cpp htmlhelper.cpp timeutils.cpp openssl_wrapper.cpp hiredis_wrapper.cpp
|
||||
routes/home.cpp routes/css.cpp routes/user.cpp routes/status.cpp routes/tags.cpp routes/about.cpp routes/user_settings.cpp
|
||||
blankie/serializer.cpp blankie/escape.cpp)
|
||||
set_target_properties(${PROJECT_NAME}
|
||||
|
@ -39,6 +43,6 @@ set_target_properties(${PROJECT_NAME}
|
|||
CXX_EXTENSIONS NO
|
||||
)
|
||||
target_include_directories(${PROJECT_NAME} PRIVATE thirdparty ${HIREDIS_INCLUDE_DIRS})
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE nlohmann_json::nlohmann_json CURL::libcurl httplib::httplib lexbor_static ${HIREDIS_LINK_LIBRARIES})
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE nlohmann_json::nlohmann_json CURL::libcurl OpenSSL::Crypto httplib::httplib lexbor_static ${HIREDIS_LINK_LIBRARIES})
|
||||
target_compile_definitions(${PROJECT_NAME} PRIVATE ${DEFINITIONS})
|
||||
target_compile_options(${PROJECT_NAME} PRIVATE ${FLAGS})
|
||||
|
|
|
@ -15,6 +15,7 @@ Copy `example_config.json` to a file with any name you like
|
|||
liking. Here's a list of what they are:
|
||||
- `bind_host` (string): What address to bind to
|
||||
- `bind_port` (zero or positive integer): What port to bind to
|
||||
- `hmac_key` (hex string): A secret key to be used; generate with `head -c32 /dev/urandom | basenc --base16`
|
||||
- `canonical_origin` (string or null): A fallback canonical origin if set, useful if you're, say, running Coyote behind Ngrok
|
||||
- `redis` (object)
|
||||
- `enabled` (boolean)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include <stdexcept>
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "hex.h"
|
||||
#include "file.h"
|
||||
#include "config.h"
|
||||
|
||||
|
@ -47,6 +48,7 @@ void from_json(const nlohmann::json& j, Config& conf) {
|
|||
throw std::invalid_argument("Invalid port to bind to: "s + std::to_string(conf.bind_port));
|
||||
}
|
||||
|
||||
conf.hmac_key = hex_decode(j.at("hmac_key").get_ref<const std::string&>());
|
||||
if (j.at("canonical_origin").is_string()) {
|
||||
conf.canonical_origin = j["canonical_origin"].get<std::string>();
|
||||
}
|
||||
|
|
2
config.h
2
config.h
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <variant>
|
||||
#include <optional>
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
@ -22,6 +23,7 @@ struct RedisConfig {
|
|||
struct Config {
|
||||
std::string bind_host = "127.0.0.1";
|
||||
int bind_port = 8080;
|
||||
std::vector<char> hmac_key;
|
||||
std::optional<std::string> canonical_origin;
|
||||
std::optional<RedisConfig> redis_config;
|
||||
};
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
{
|
||||
"bind_host": "127.0.0.1",
|
||||
"bind_port": 8080,
|
||||
"hmac_key": "AA",
|
||||
"canonical_origin": null,
|
||||
"redis": {
|
||||
"enabled": true,
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
#include <memory>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <openssl/hmac.h>
|
||||
#include <openssl/rand.h>
|
||||
#include "openssl_wrapper.h"
|
||||
|
||||
std::vector<char> secure_random_bytes(int num) {
|
||||
if (num < 0) {
|
||||
throw std::invalid_argument("secure_random_bytes(): num variable out of range (num < 0)");
|
||||
}
|
||||
|
||||
std::vector<char> bytes(static_cast<size_t>(num), 0);
|
||||
if (RAND_bytes(reinterpret_cast<unsigned char*>(bytes.data()), num) == 1) {
|
||||
return bytes;
|
||||
} else {
|
||||
throw OpenSSLException(ERR_get_error());
|
||||
}
|
||||
}
|
||||
|
||||
std::array<char, 32> hmac_sha3_256(const std::vector<char>& key, const std::vector<char>& data) {
|
||||
char hmac[32];
|
||||
unsigned int md_len;
|
||||
|
||||
std::unique_ptr<EVP_MD, decltype(&EVP_MD_free)> md(EVP_MD_fetch(nullptr, "SHA3-256", nullptr), EVP_MD_free);
|
||||
if (HMAC(md.get(), key.data(), static_cast<int>(key.size()), reinterpret_cast<const unsigned char*>(data.data()), data.size(), reinterpret_cast<unsigned char*>(hmac), &md_len)) {
|
||||
if (md_len != 32) {
|
||||
throw std::runtime_error("hmac_sha3_256(): HMAC() returned an unexpected size");
|
||||
}
|
||||
|
||||
return std::to_array(hmac);
|
||||
} else {
|
||||
throw OpenSSLException(ERR_get_error());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
#include <exception>
|
||||
|
||||
#include <openssl/err.h>
|
||||
|
||||
class OpenSSLException : public std::exception {
|
||||
public:
|
||||
OpenSSLException(unsigned long e) {
|
||||
ERR_error_string_n(e, this->_str, 1024);
|
||||
}
|
||||
|
||||
const char* what() const noexcept {
|
||||
return this->_str;
|
||||
}
|
||||
|
||||
private:
|
||||
char _str[1024];
|
||||
};
|
||||
|
||||
std::vector<char> secure_random_bytes(int num);
|
||||
std::array<char, 32> hmac_sha3_256(const std::vector<char>& key, const std::vector<char>& data);
|
|
@ -1,23 +1,58 @@
|
|||
#include "routes.h"
|
||||
#include "../hex.h"
|
||||
#include "../config.h"
|
||||
#include "../servehelper.h"
|
||||
#include "../settings.h"
|
||||
#include "../timeutils.h"
|
||||
#include "../curlu_wrapper.h"
|
||||
#include "../openssl_wrapper.h"
|
||||
|
||||
static void set_cookie(const httplib::Request& req, httplib::Response& res, const char* key, std::string_view value);
|
||||
static inline std::string generate_csrf_token(void);
|
||||
static inline bool validate_csrf_token(const httplib::Request& req, httplib::Response& res, std::string_view csrf_token, std::string_view query_csrf_token);
|
||||
static void set_cookie(const httplib::Request& req, httplib::Response& res, const char* key, std::string_view value, bool session = false);
|
||||
static bool safe_memcmp(const char* s1, const char* s2, size_t n);
|
||||
|
||||
|
||||
void user_settings_route(const httplib::Request& req, httplib::Response& res) {
|
||||
UserSettings settings;
|
||||
Cookies cookies = parse_cookies(req);
|
||||
std::string csrf_token;
|
||||
|
||||
if (req.method == "POST") {
|
||||
if (!cookies.contains("csrf-token")) {
|
||||
res.status = 400;
|
||||
serve_error(req, res, "400: Bad Request", "Missing CSRF token cookie, are cookies enabled?");
|
||||
return;
|
||||
}
|
||||
csrf_token = cookies["csrf-token"];
|
||||
|
||||
auto query_csrf_token = req.params.find("csrf-token");
|
||||
if (query_csrf_token == req.params.end()) {
|
||||
res.status = 400;
|
||||
serve_error(req, res, "400: Bad Request", "Missing CSRF token query parameter");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!validate_csrf_token(req, res, csrf_token, query_csrf_token->second)) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto& i : req.params) {
|
||||
settings.set(i.first, i.second);
|
||||
}
|
||||
|
||||
set_cookie(req, res, "auto-open-cw", settings.auto_open_cw ? "true" : "false");
|
||||
} else {
|
||||
settings.load_from_cookies(req);
|
||||
for (auto &[name, value] : cookies) {
|
||||
settings.set(name, value);
|
||||
}
|
||||
|
||||
if (cookies.contains("csrf-token")) {
|
||||
csrf_token = cookies["csrf-token"];
|
||||
} else {
|
||||
csrf_token = generate_csrf_token();
|
||||
set_cookie(req, res, "csrf-token", csrf_token, true);
|
||||
}
|
||||
}
|
||||
|
||||
Element auto_open_cw_checkbox("input", {{"type", "checkbox"}, {"name", "auto-open-cw"}, {"value", "true"}}, {});
|
||||
|
@ -33,6 +68,7 @@ void user_settings_route(const httplib::Request& req, httplib::Response& res) {
|
|||
}),
|
||||
|
||||
Element("br"),
|
||||
Element("input", {{"type", "hidden"}, {"name", "csrf-token"}, {"value", csrf_token}}, {}),
|
||||
Element("input", {{"type", "submit"}, {"value", "Save"}}, {}),
|
||||
}),
|
||||
Element("form", {{"class", "user_settings_page-form"}, {"method", "get"}, {"action", get_origin(req)}}, {
|
||||
|
@ -49,16 +85,69 @@ void user_settings_route(const httplib::Request& req, httplib::Response& res) {
|
|||
}
|
||||
|
||||
|
||||
static void set_cookie(const httplib::Request& req, httplib::Response& res, const char* key, std::string_view value) {
|
||||
static inline std::string generate_csrf_token(void) {
|
||||
std::vector<char> raw_token = secure_random_bytes(32);
|
||||
std::array<char, 32> raw_token_hmac = hmac_sha3_256(config.hmac_key, raw_token);
|
||||
|
||||
return hex_encode(raw_token) + '.' + hex_encode(raw_token_hmac.data(), raw_token_hmac.size());
|
||||
}
|
||||
|
||||
static inline bool validate_csrf_token(const httplib::Request& req, httplib::Response& res, std::string_view csrf_token, std::string_view query_csrf_token) {
|
||||
if (csrf_token.size() != query_csrf_token.size() || !safe_memcmp(csrf_token.data(), query_csrf_token.data(), csrf_token.size())) {
|
||||
res.status = 400;
|
||||
serve_error(req, res, "400: Bad Request", "CSRF token cookie and CSRF token query parameter do not match");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (csrf_token.size() != 64 + 1 + 64 || csrf_token[64] != '.') {
|
||||
res.status = 400;
|
||||
serve_error(req, res, "400: Bad Request", "CSRF token is in an unknown format");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<char> raw_token, raw_token_hmac;
|
||||
try {
|
||||
raw_token = hex_decode(csrf_token.substr(0, 64));
|
||||
raw_token_hmac = hex_decode(csrf_token.substr(64 + 1, 64));
|
||||
} catch (const std::exception& e) {
|
||||
res.status = 400;
|
||||
serve_error(req, res, "400: Bad Request", "Failed to parse CSRF token", e.what());
|
||||
return false;
|
||||
}
|
||||
|
||||
std::array<char, 32> our_raw_token_hmac = hmac_sha3_256(config.hmac_key, raw_token);
|
||||
if (!safe_memcmp(raw_token_hmac.data(), our_raw_token_hmac.data(), 32)) {
|
||||
res.status = 400;
|
||||
serve_error(req, res, "400: Bad Request", "CSRF token HMAC is not correct");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void set_cookie(const httplib::Request& req, httplib::Response& res, const char* key, std::string_view value, bool session) {
|
||||
CurlUrl origin;
|
||||
origin.set(CURLUPART_URL, get_origin(req));
|
||||
|
||||
std::string header = std::string(key) + '=' + std::string(value)
|
||||
+ "; HttpOnly; SameSite=Lax; Domain=" + origin.get(CURLUPART_HOST).get() + "; Path=" + origin.get(CURLUPART_PATH).get()
|
||||
+ "; Expires=" + to_web_date(current_time() + 365 * 24 * 60 * 60);
|
||||
+ "; HttpOnly; SameSite=Lax; Domain=" + origin.get(CURLUPART_HOST).get() + "; Path=" + origin.get(CURLUPART_PATH).get();
|
||||
if (!session) {
|
||||
header += "; Expires=";
|
||||
header += to_web_date(current_time() + 365 * 24 * 60 * 60);
|
||||
}
|
||||
if (strcmp(origin.get(CURLUPART_SCHEME).get(), "https") == 0) {
|
||||
header += "; Secure";
|
||||
}
|
||||
|
||||
res.set_header("Set-Cookie", header);
|
||||
}
|
||||
|
||||
static bool safe_memcmp(const char* s1, const char* s2, size_t n) {
|
||||
bool equal = true;
|
||||
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
equal &= s1[i] == s2[i];
|
||||
}
|
||||
|
||||
return equal;
|
||||
}
|
||||
|
|
|
@ -9,6 +9,11 @@
|
|||
#include "curlu_wrapper.h"
|
||||
#include "routes/routes.h"
|
||||
|
||||
static inline void parse_cookies(std::string_view str, Cookies& cookies);
|
||||
static inline bool lowercase_compare(std::string_view lhs, std::string_view rhs);
|
||||
|
||||
|
||||
|
||||
void serve(const httplib::Request& req, httplib::Response& res, std::string title, Element element, Nodes extra_head) {
|
||||
using namespace std::string_literals;
|
||||
|
||||
|
@ -146,3 +151,60 @@ bool should_send_304(const httplib::Request& req, uint64_t hash) {
|
|||
size_t pos = header.find(std::string(1, '"') + std::to_string(hash) + '"');
|
||||
return pos != std::string::npos && (pos == 0 || header[pos - 1] != '/');
|
||||
}
|
||||
|
||||
|
||||
Cookies parse_cookies(const httplib::Request& req) {
|
||||
Cookies cookies;
|
||||
|
||||
for (const auto& i : req.headers) {
|
||||
if (lowercase_compare(i.first, "cookie")) {
|
||||
parse_cookies(i.second, cookies);
|
||||
}
|
||||
}
|
||||
|
||||
return cookies;
|
||||
}
|
||||
|
||||
|
||||
|
||||
static inline void parse_cookies(std::string_view str, Cookies& cookies) {
|
||||
using namespace std::string_literals;
|
||||
|
||||
size_t offset = 0;
|
||||
size_t new_offset = 0;
|
||||
const char* delimiter = "; ";
|
||||
size_t delimiter_len = strlen(delimiter);
|
||||
|
||||
while (offset < str.size()) {
|
||||
new_offset = str.find(delimiter, offset);
|
||||
|
||||
std::string_view item = str.substr(offset, new_offset != std::string_view::npos ? new_offset - offset : std::string_view::npos);
|
||||
size_t equal_offset = item.find('=');
|
||||
if (equal_offset == std::string_view::npos) {
|
||||
throw std::invalid_argument("invalid user setting item: "s + std::string(item));
|
||||
}
|
||||
cookies.insert({std::string(item.substr(0, equal_offset)), std::string(item.substr(equal_offset + 1))});
|
||||
|
||||
if (new_offset == std::string_view::npos) {
|
||||
break;
|
||||
}
|
||||
offset = new_offset + delimiter_len;
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool lowercase_compare(std::string_view lhs, std::string_view rhs) {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto lower = [](char c) {
|
||||
return c >= 'A' && c <= 'Z' ? c - 'A' + 'a' : c;
|
||||
};
|
||||
for (size_t i = 0; i < lhs.size(); i++) {
|
||||
if (lower(lhs[i]) != lower(rhs[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <httplib/httplib.h>
|
||||
#include "blankie/serializer.h"
|
||||
|
@ -9,6 +10,7 @@ class CurlUrl; // forward declaration from curlu_wrapper.h
|
|||
using Element = blankie::html::Element;
|
||||
using Node = blankie::html::Node;
|
||||
using Nodes = std::vector<Node>;
|
||||
using Cookies = std::unordered_map<std::string, std::string>;
|
||||
|
||||
void serve(const httplib::Request& req, httplib::Response& res, std::string title, Element element, Nodes extra_head = {});
|
||||
void serve_error(const httplib::Request& req, httplib::Response& res,
|
||||
|
@ -19,3 +21,5 @@ bool starts_with(const CurlUrl& url, const CurlUrl& base);
|
|||
std::string get_origin(const httplib::Request& req);
|
||||
std::string proxy_mastodon_url(const httplib::Request& req, const std::string& url_str);
|
||||
bool should_send_304(const httplib::Request& req, uint64_t hash);
|
||||
|
||||
Cookies parse_cookies(const httplib::Request& req);
|
||||
|
|
51
settings.cpp
51
settings.cpp
|
@ -1,9 +1,8 @@
|
|||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include "settings.h"
|
||||
#include "servehelper.h"
|
||||
|
||||
static void set_settings(std::string_view str, const char* delimiter, UserSettings& settings);
|
||||
static inline bool lowercase_compare(std::string_view lhs, std::string_view rhs);
|
||||
static bool parse_bool(std::string_view value);
|
||||
|
||||
|
||||
|
@ -14,54 +13,14 @@ void UserSettings::set(std::string_view key, std::string_view value) {
|
|||
}
|
||||
|
||||
void UserSettings::load_from_cookies(const httplib::Request& req) {
|
||||
for (const auto& i : req.headers) {
|
||||
if (lowercase_compare(i.first, "cookie")) {
|
||||
set_settings(i.second, "; ", *this);
|
||||
}
|
||||
Cookies cookies = parse_cookies(req);
|
||||
|
||||
for (auto &[name, value] : cookies) {
|
||||
this->set(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void set_settings(std::string_view str, const char* delimiter, UserSettings& settings) {
|
||||
using namespace std::string_literals;
|
||||
size_t offset = 0;
|
||||
size_t new_offset = 0;
|
||||
size_t delimiter_len = strlen(delimiter);
|
||||
|
||||
while (offset < str.size()) {
|
||||
new_offset = str.find(delimiter, offset);
|
||||
|
||||
std::string_view item = str.substr(offset, new_offset != std::string_view::npos ? new_offset - offset : std::string_view::npos);
|
||||
size_t equal_offset = item.find('=');
|
||||
if (equal_offset == std::string_view::npos) {
|
||||
throw std::invalid_argument("invalid user setting item: "s + std::string(item));
|
||||
}
|
||||
settings.set(item.substr(0, equal_offset), item.substr(equal_offset + 1));
|
||||
|
||||
if (new_offset == std::string_view::npos) {
|
||||
break;
|
||||
}
|
||||
offset = new_offset + delimiter_len;
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool lowercase_compare(std::string_view lhs, std::string_view rhs) {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto lower = [](char c) {
|
||||
return c >= 'A' && c <= 'Z' ? c - 'A' + 'a' : c;
|
||||
};
|
||||
for (size_t i = 0; i < lhs.size(); i++) {
|
||||
if (lower(lhs[i]) != lower(rhs[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool parse_bool(std::string_view value) {
|
||||
using namespace std::string_literals;
|
||||
|
||||
|
|
Loading…
Reference in New Issue