78 lines
2.0 KiB
Go
78 lines
2.0 KiB
Go
package socks
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
)
|
|
|
|
const (
	// NoAuth is the method code for connecting without any authentication.
	NoAuth uint8 = 0

	// UserPassAuth is the method code for user/password authentication.
	UserPassAuth uint8 = 2

	// noAcceptable is the method code sent back to the client when none of
	// the methods it offered has a registered authenticator.
	noAcceptable uint8 = 255

	// The codes below are not referenced in this part of the file;
	// presumably they are the version byte and status replies of the
	// user/password sub-negotiation. NOTE(review): confirm against their users.
	userAuthVersion uint8 = 1
	authSuccess     uint8 = 0
	authFailure     uint8 = 1
)
|
|
|
|
// AuthHandler handles socks authentication requests
|
|
type AuthHandler interface {
|
|
Handle(io.Reader, io.Writer) error
|
|
Register(uint8, Authenticator)
|
|
}
|
|
|
|
// StandardAuthHandler loads the default authenticators
|
|
type StandardAuthHandler struct {
|
|
authenticators map[uint8]Authenticator
|
|
}
|
|
|
|
// NewAuthHandler creates a default auth handler
|
|
func NewAuthHandler() AuthHandler {
|
|
defaults := make(map[uint8]Authenticator)
|
|
defaults[NoAuth] = NewNoAuthAuthenticator()
|
|
return &StandardAuthHandler{
|
|
authenticators: defaults,
|
|
}
|
|
}
|
|
|
|
// Register adds/replaces an Authenticator to use when handling Authentication requests
|
|
func (h *StandardAuthHandler) Register(method uint8, a Authenticator) {
|
|
h.authenticators[method] = a
|
|
}
|
|
|
|
// Handle gets the methods from the SOCKS5 client and authenticates with the first supported method
|
|
func (h *StandardAuthHandler) Handle(bufConn io.Reader, conn io.Writer) error {
|
|
methods, err := readMethods(bufConn)
|
|
if err != nil {
|
|
return fmt.Errorf("Failed to read auth methods: %v", err)
|
|
}
|
|
|
|
// first supported method is used
|
|
for _, method := range methods {
|
|
authenticator := h.authenticators[method]
|
|
if authenticator != nil {
|
|
return authenticator.Handle(bufConn, conn)
|
|
}
|
|
}
|
|
|
|
// failed to authenticate. No supported authentication type found
|
|
conn.Write([]byte{socks5Version, noAcceptable})
|
|
return fmt.Errorf("unknown authentication type")
|
|
}
|
|
|
|
// readMethods reads a method list from r: one byte giving the number of
// methods, followed by that many one-byte method codes.
//
// It returns the method codes in the order they were read. On any read error
// (including a stream that ends before all announced methods have arrived)
// it returns nil and the error.
func readMethods(r io.Reader) ([]byte, error) {
	// io.ReadFull rather than a bare Read: a Reader may legally return
	// (0, nil), or deliver the byte together with an error. A bare Read
	// would misread the first as "zero methods" and the second as a failure.
	var header [1]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return nil, err
	}

	methods := make([]byte, int(header[0]))
	if _, err := io.ReadFull(r, methods); err != nil {
		return nil, err
	}
	return methods, nil
}
|