102 lines
1.9 KiB
Go
102 lines
1.9 KiB
Go
package ipaccess
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"sort"
|
|
)
|
|
|
|
type Policy struct {
|
|
defaultAllow bool
|
|
rules []Rule
|
|
}
|
|
|
|
type Rule struct {
|
|
ipNet *net.IPNet
|
|
ports []int
|
|
allow bool
|
|
}
|
|
|
|
func NewPolicy(defaultAllow bool, rules []Rule) (*Policy, error) {
|
|
for _, rule := range rules {
|
|
if err := rule.Validate(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
policy := Policy{
|
|
defaultAllow: defaultAllow,
|
|
rules: rules,
|
|
}
|
|
|
|
return &policy, nil
|
|
}
|
|
|
|
func NewRuleByCIDR(prefix *string, ports []int, allow bool) (Rule, error) {
|
|
if prefix == nil || len(*prefix) == 0 {
|
|
return Rule{}, fmt.Errorf("no prefix provided")
|
|
}
|
|
|
|
_, ipnet, err := net.ParseCIDR(*prefix)
|
|
if err != nil {
|
|
return Rule{}, fmt.Errorf("unable to parse cidr: %s", *prefix)
|
|
}
|
|
|
|
return NewRule(ipnet, ports, allow)
|
|
}
|
|
|
|
func NewRule(ipnet *net.IPNet, ports []int, allow bool) (Rule, error) {
|
|
rule := Rule{
|
|
ipNet: ipnet,
|
|
ports: ports,
|
|
allow: allow,
|
|
}
|
|
return rule, rule.Validate()
|
|
}
|
|
|
|
func (r *Rule) Validate() error {
|
|
if r.ipNet == nil {
|
|
return fmt.Errorf("no ipnet set on the rule")
|
|
}
|
|
|
|
if len(r.ports) > 0 {
|
|
sort.Ints(r.ports)
|
|
for _, port := range r.ports {
|
|
if port < 1 || port > 65535 {
|
|
return fmt.Errorf("invalid port %d, needs to be between 1 and 65535", port)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *Policy) Allowed(ip net.IP, port int) (bool, *Rule) {
|
|
if len(h.rules) == 0 {
|
|
return h.defaultAllow, nil
|
|
}
|
|
|
|
for _, rule := range h.rules {
|
|
if rule.ipNet.Contains(ip) {
|
|
if len(rule.ports) == 0 {
|
|
return rule.allow, &rule
|
|
} else if pos := sort.SearchInts(rule.ports, port); pos < len(rule.ports) && rule.ports[pos] == port {
|
|
return rule.allow, &rule
|
|
}
|
|
}
|
|
}
|
|
|
|
return h.defaultAllow, nil
|
|
}
|
|
|
|
func (ipr *Rule) String() string {
|
|
return fmt.Sprintf("prefix:%s/port:%s/allow:%t", ipr.ipNet, ipr.PortsString(), ipr.allow)
|
|
}
|
|
|
|
func (ipr *Rule) PortsString() string {
|
|
if len(ipr.ports) > 0 {
|
|
return fmt.Sprint(ipr.ports)
|
|
}
|
|
return "all"
|
|
}
|