package ipaccess

import (
	"fmt"
	"net"
	"sort"
)

// Policy is an IP/port access policy: a list of rules that Allowed consults
// in order, plus the verdict to fall back on when no rule matches.
type Policy struct {
	defaultAllow bool   // verdict Allowed returns when no rule matches
	rules        []Rule // consulted in slice order; the first matching rule decides
}

// Rule pairs an IP network, and optionally a set of ports, with the verdict
// to apply when a lookup matches them.
type Rule struct {
	ipNet *net.IPNet // network the looked-up IP must fall in; Validate rejects nil
	ports []int      // ports the rule covers, sorted in place by Validate; empty means every port
	allow bool       // verdict Policy.Allowed returns when this rule matches
}

// NewPolicy builds a Policy that evaluates rules in the given order and
// falls back to defaultAllow when none of them matches.
//
// Every rule is validated first; the first validation error is returned
// together with a nil Policy. The rules slice itself is copied, so replacing
// or reordering elements of the caller's slice afterwards cannot slip an
// unvalidated rule into the policy.
func NewPolicy(defaultAllow bool, rules []Rule) (*Policy, error) {
	// Detach from the caller's backing array before validating. The copy is
	// shallow: each Rule still shares its ports slice with the caller's Rule.
	owned := append([]Rule(nil), rules...)
	for i := range owned {
		if err := owned[i].Validate(); err != nil {
			return nil, err
		}
	}

	return &Policy{
		defaultAllow: defaultAllow,
		rules:        owned,
	}, nil
}

// NewRuleByCIDR parses *prefix as CIDR notation and hands the resulting
// network, together with ports and allow, to NewRule. The address part
// returned by net.ParseCIDR is discarded; only the network is used.
//
// It returns a zero Rule and an error when prefix is nil, points at an empty
// string, or cannot be parsed; otherwise it returns whatever NewRule returns.
func NewRuleByCIDR(prefix *string, ports []int, allow bool) (Rule, error) {
	if prefix == nil || *prefix == "" {
		return Rule{}, fmt.Errorf("no prefix provided")
	}

	cidr := *prefix
	_, network, parseErr := net.ParseCIDR(cidr)
	if parseErr != nil {
		return Rule{}, fmt.Errorf("unable to parse cidr: %s", cidr)
	}

	return NewRule(network, ports, allow)
}

// NewRule builds a Rule that applies the verdict allow to the network ipnet
// on the given ports; a nil or empty ports slice makes the rule cover every
// port.
//
// The rule keeps its own copy of ports. Validate sorts the rule's ports in
// place, so without the copy the caller's slice would be reordered as a side
// effect, and later writes to that slice by the caller could break the
// sorted order that port lookups rely on.
//
// The rule is returned together with the result of Validate; when the error
// is non-nil the rule must not be used.
func NewRule(ipnet *net.IPNet, ports []int, allow bool) (Rule, error) {
	// Preserve nil-ness so Ports() reports exactly what the caller supplied.
	var owned []int
	if ports != nil {
		owned = make([]int, len(ports))
		copy(owned, ports)
	}

	rule := Rule{
		ipNet: ipnet,
		ports: owned,
		allow: allow,
	}
	err := rule.Validate()
	return rule, err
}

// Validate checks that the rule has a network and that every port lies in
// the range 1-65535, returning an error describing the first problem found.
//
// As a side effect it sorts the rule's ports in ascending order, in place
// and before the range check, so the ports end up sorted even when an error
// is returned. Policy.Allowed depends on that order for its binary search.
func (r *Rule) Validate() error {
	if r.ipNet == nil {
		return fmt.Errorf("no ipnet set on the rule")
	}

	// Sorting and ranging over an empty slice are both no-ops, so no length
	// guard is needed.
	sort.Ints(r.ports)
	for i := range r.ports {
		p := r.ports[i]
		if 1 <= p && p <= 65535 {
			continue
		}
		return fmt.Errorf("invalid port %d, needs to be between 1 and 65535", p)
	}

	return nil
}

// Allowed reports whether traffic for ip on port is permitted by the policy.
//
// Rules are tried in order and the first match decides. A rule matches when
// its network contains ip and either it lists no ports or port is one of its
// ports. On a match the rule's verdict is returned with a pointer to a copy
// of that rule, so the caller cannot modify the policy through it. When
// nothing matches, including when the policy has no rules, the default
// verdict and a nil rule are returned.
//
// Port lists are looked up by binary search and are assumed to be sorted,
// which Rule.Validate takes care of.
func (p *Policy) Allowed(ip net.IP, port int) (bool, *Rule) {
	// Iterate by index: taking the address of a range variable would force a
	// heap-allocated copy of every rule visited, not just the one returned.
	for i := range p.rules {
		candidate := &p.rules[i]
		if !candidate.ipNet.Contains(ip) {
			continue
		}

		if len(candidate.ports) > 0 {
			pos := sort.SearchInts(candidate.ports, port)
			if pos == len(candidate.ports) || candidate.ports[pos] != port {
				continue
			}
		}

		// Copy only the rule that is actually handed out.
		matched := *candidate
		return matched.allow, &matched
	}

	return p.defaultAllow, nil
}

// String renders the rule as "prefix:<network>/port:<ports>/allow:<verdict>",
// where <ports> is whatever PortsString returns.
func (r *Rule) String() string {
	ports := r.PortsString()
	// The network is formatted by fmt rather than by calling its String
	// method directly, so a nil network is printed instead of panicking.
	return fmt.Sprintf("prefix:%s/port:%s/allow:%t", r.ipNet, ports, r.allow)
}

// PortsString describes the ports the rule covers: "all" when the rule lists
// no ports, otherwise the port list in fmt's default slice format, for
// example "[80 443]".
func (r *Rule) PortsString() string {
	if len(r.ports) == 0 {
		return "all"
	}
	return fmt.Sprint(r.ports)
}

// Ports returns the ports the rule covers; a nil or empty result means the
// rule covers every port.
//
// The result is a copy. The rule's own slice has to stay sorted for port
// lookups to work, so handing it out directly would let a caller that
// reorders or edits the result silently change what the rule matches.
func (r *Rule) Ports() []int {
	// Preserve nil-ness so callers see the same value the rule holds.
	if r.ports == nil {
		return nil
	}
	out := make([]int, len(r.ports))
	copy(out, r.ports)
	return out
}

// RulePolicy returns the rule's verdict: the value Policy.Allowed reports
// when this rule is the one that matches.
func (r *Rule) RulePolicy() bool {
	return r.allow
}

// StringCIDR returns the rule's network as formatted by its String method,
// for example "10.0.0.0/8". The rule must have a network set.
func (r *Rule) StringCIDR() string {
	network := r.ipNet
	return network.String()
}