cloudflared-mirror/ipaccess/access.go

102 lines
1.9 KiB
Go

package ipaccess
import (
"fmt"
"net"
"sort"
)
type (
	// Policy is an ordered list of IP/port rules together with the verdict
	// used when none of them match.
	Policy struct {
		defaultAllow bool   // verdict returned by Allowed when no rule matches
		rules        []Rule // consulted in order; the first match wins
	}

	// Rule allows or denies addresses inside ipNet, optionally restricted
	// to a specific set of ports.
	Rule struct {
		ipNet *net.IPNet // network an address must fall within to match
		ports []int      // sorted by Validate; empty means every port matches
		allow bool       // verdict when this rule matches
	}
)
// NewPolicy builds a Policy from rules, which Allowed consults in order,
// falling back to defaultAllow when none match. Each rule is validated
// first (which also sorts its ports); the first validation error aborts
// construction and is returned with a nil Policy.
func NewPolicy(defaultAllow bool, rules []Rule) (*Policy, error) {
	for i := range rules {
		if err := rules[i].Validate(); err != nil {
			return nil, err
		}
	}
	return &Policy{defaultAllow: defaultAllow, rules: rules}, nil
}
// NewRuleByCIDR parses prefix as a CIDR block (e.g. "10.0.0.0/8") and
// builds a Rule for it via NewRule.
//
// It returns an error if prefix is nil or empty, or if it cannot be parsed
// as CIDR. In the parse case, the underlying net.ParseCIDR error is wrapped
// so callers can inspect it with errors.As. Errors from NewRule are returned
// unchanged.
func NewRuleByCIDR(prefix *string, ports []int, allow bool) (Rule, error) {
	if prefix == nil || len(*prefix) == 0 {
		return Rule{}, fmt.Errorf("no prefix provided")
	}
	_, ipnet, err := net.ParseCIDR(*prefix)
	if err != nil {
		// Keep the original cause instead of discarding it.
		return Rule{}, fmt.Errorf("unable to parse cidr %q: %w", *prefix, err)
	}
	return NewRule(ipnet, ports, allow)
}
// NewRule builds a Rule matching addresses inside ipnet. If ports is
// non-empty, the rule matches only those ports; otherwise it matches every
// port. allow is the verdict reported when the rule matches.
//
// ports is copied. Validate sorts the rule's ports in place and Allowed
// binary-searches them. Without a copy, construction would reorder the
// caller's slice, and later caller edits could silently break the sorted
// order.
//
// If validation fails, the error is returned with a zero Rule.
func NewRule(ipnet *net.IPNet, ports []int, allow bool) (Rule, error) {
	var owned []int
	if len(ports) > 0 {
		owned = append(make([]int, 0, len(ports)), ports...)
	}
	rule := Rule{
		ipNet: ipnet,
		ports: owned,
		allow: allow,
	}
	// Validate explicitly before returning. Evaluating `rule` and
	// `rule.Validate()` in one return statement has unspecified order.
	if err := rule.Validate(); err != nil {
		return Rule{}, err
	}
	return rule, nil
}
// Validate checks that the rule has a network set and that every port lies
// in 1..65535. As a side effect it sorts the rule's ports in ascending
// order, which Allowed depends on for its binary search. Sorting happens
// before the range check, so the reported invalid port is the smallest one.
func (r *Rule) Validate() error {
	if r.ipNet == nil {
		return fmt.Errorf("no ipnet set on the rule")
	}
	sort.Ints(r.ports) // no-op for a nil or empty slice
	for i := range r.ports {
		if p := r.ports[i]; p < 1 || p > 65535 {
			return fmt.Errorf("invalid port %d, needs to be between 1 and 65535", p)
		}
	}
	return nil
}
// Allowed reports whether ip on port is permitted by the policy.
//
// Rules are tried in order. The first rule whose network contains ip, and
// whose port list is empty or contains port, decides the verdict; a pointer
// to a copy of that rule is returned alongside it. If no rule matches
// (including when there are no rules), the policy default is returned with
// a nil rule.
func (p *Policy) Allowed(ip net.IP, port int) (bool, *Rule) {
	for _, candidate := range p.rules {
		if !candidate.ipNet.Contains(ip) {
			continue
		}
		if len(candidate.ports) == 0 {
			return candidate.allow, &candidate
		}
		// Ports are sorted by Validate, so a binary search is valid.
		idx := sort.SearchInts(candidate.ports, port)
		if idx < len(candidate.ports) && candidate.ports[idx] == port {
			return candidate.allow, &candidate
		}
	}
	return p.defaultAllow, nil
}
// String renders the rule as "prefix:<cidr>/port:<ports>/allow:<bool>",
// where <ports> is produced by PortsString.
func (r *Rule) String() string {
	const layout = "prefix:%s/port:%s/allow:%t"
	return fmt.Sprintf(layout, r.ipNet, r.PortsString(), r.allow)
}
// PortsString returns "all" when the rule has no port restriction, and
// otherwise the port list as formatted by fmt.Sprint (e.g. "[22 80]").
func (r *Rule) PortsString() string {
	if len(r.ports) == 0 {
		return "all"
	}
	return fmt.Sprint(r.ports)
}