cloudflared-mirror/vendor/github.com/sudarshan-reddy/go-sumtype/check.go

184 lines
5.0 KiB
Go

package main
import (
"fmt"
"go/ast"
"go/token"
"go/types"
"sort"
"strings"
"golang.org/x/tools/go/packages"
)
// inexhaustiveError is returned from check for each occurrence of inexhaustive
// case analysis in a Go type switch statement.
type inexhaustiveError struct {
// Pos is the source position of the offending type switch statement.
Pos token.Position
// Def is the sum type definition the switch was checked against.
Def sumTypeDef
// Missing holds the variants of the sum type that the switch fails to
// cover.
Missing []types.Object
}
// Error implements the error interface. The message consists of the position
// of the type switch, the name of the sum type it switches on, and a
// comma-separated, sorted list of the variant names that were not handled.
func (e inexhaustiveError) Error() string {
	missing := strings.Join(e.Names(), ", ")
	return fmt.Sprintf(
		"%s: exhaustiveness check failed for sum type '%s': missing cases for %s",
		e.Pos, e.Def.Decl.TypeName, missing)
}
// Names collects the name of every missing variant and returns them in
// ascending lexical order. The result is nil when no variants are missing.
func (e inexhaustiveError) Names() []string {
	var names []string
	for i := range e.Missing {
		names = append(names, e.Missing[i].Name())
	}
	// Sort so that the reported order is independent of the order in which
	// the missing variants were discovered.
	sort.Strings(names)
	return names
}
// check runs the exhaustiveness analysis over every syntax tree of pkg,
// testing each type switch statement it finds against the given sum type
// definitions. One error is collected per inexhaustive type switch; the
// result is nil when none are found.
func check(pkg *packages.Package, defs []sumTypeDef) []error {
	var found []error
	visit := func(node ast.Node) bool {
		if ts, isTypeSwitch := node.(*ast.TypeSwitchStmt); isTypeSwitch {
			if err := checkSwitch(pkg, defs, ts); err != nil {
				found = append(found, err)
			}
		}
		// Always keep descending: a type switch can appear anywhere in the
		// tree, including inside the body of another type switch.
		return true
	}
	for _, file := range pkg.Syntax {
		ast.Inspect(file, visit)
	}
	return found
}
// checkSwitch runs the exhaustiveness check on a single type switch
// statement. When the switch operates on one of the given sum types and
// leaves some of its variants unhandled, an inexhaustiveError naming those
// variants is returned; otherwise the result is nil.
//
// Note that a type switch with a non-panicking default case is never
// reported, because missingVariantsInSwitch yields no missing variants for
// such a switch.
func checkSwitch(
	pkg *packages.Package,
	defs []sumTypeDef,
	swtch *ast.TypeSwitchStmt,
) error {
	def, missing := missingVariantsInSwitch(pkg, defs, swtch)
	if len(missing) == 0 {
		return nil
	}
	return inexhaustiveError{
		Pos:     pkg.Fset.Position(swtch.Pos()),
		Def:     *def,
		Missing: missing,
	}
}
// missingVariantsInSwitch determines which variants of a sum type are not
// handled by the given type switch statement. It returns the sum type
// definition matching the switched-on expression together with the missing
// variants.
//
// Both results are nil when the expression does not correspond to any of the
// given definitions. When a definition is found but the switch has a default
// clause that is not recognized as always panicking, the definition is
// returned with no missing variants.
func missingVariantsInSwitch(
	pkg *packages.Package,
	defs []sumTypeDef,
	swtch *ast.TypeSwitchStmt,
) (*sumTypeDef, []types.Object) {
	subject := findTypeAssertExpr(swtch)
	info := pkg.TypesInfo

	def := findDef(defs, info.TypeOf(subject))
	if def == nil {
		// Not a switch over a known sum type; nothing to check.
		return nil, nil
	}

	caseExprs, hasDefault := switchVariants(swtch)
	if hasDefault && !defaultClauseAlwaysPanics(swtch) {
		// The default clause acts as a catch-all, so every variant counts
		// as handled.
		return def, nil
	}

	var covered []types.Type
	for i := range caseExprs {
		covered = append(covered, info.TypeOf(caseExprs[i]))
	}
	return def, def.missing(covered)
}
// switchVariants gathers the case expressions of a type switch, flattening
// clauses that list several types into a single slice in source order.
// hasDefault reports whether the switch contains a default clause.
func switchVariants(swtch *ast.TypeSwitchStmt) (exprs []ast.Expr, hasDefault bool) {
	for _, stmt := range swtch.Body.List {
		cc := stmt.(*ast.CaseClause)
		if cc.List == nil {
			// A clause without an expression list is the default clause.
			hasDefault = true
			continue
		}
		exprs = append(exprs, cc.List...)
	}
	return exprs, hasDefault
}
// defaultClauseAlwaysPanics reports whether the default clause of the given
// type switch is recognized as always panicking. Recognition is purely
// syntactic and best-effort: the clause qualifies only when its body is
// exactly one statement and that statement is a call whose callee is the
// plain identifier panic. Any other body shape yields false, even if it would
// in fact always panic.
//
// If the given switch statement has no default clause, then this function
// panics.
func defaultClauseAlwaysPanics(swtch *ast.TypeSwitchStmt) bool {
	var dflt *ast.CaseClause
	for _, stmt := range swtch.Body.List {
		// The default clause is the one without an expression list.
		if cc := stmt.(*ast.CaseClause); cc.List == nil {
			dflt = cc
			break
		}
	}
	if dflt == nil {
		panic("switch statement has no default clause")
	}

	if len(dflt.Body) != 1 {
		return false
	}
	if stmt, ok := dflt.Body[0].(*ast.ExprStmt); ok {
		if call, ok := stmt.X.(*ast.CallExpr); ok {
			if callee, ok := call.Fun.(*ast.Ident); ok {
				return callee.Name == "panic"
			}
		}
	}
	return false
}
// findTypeAssertExpr returns the operand of the type assertion in the guard
// of a type switch statement, i.e. the x in both `switch x.(type)` and
// `switch v := x.(type)`.
func findTypeAssertExpr(swtch *ast.TypeSwitchStmt) ast.Expr {
	var guard ast.Expr
	switch stmt := swtch.Assign.(type) {
	case *ast.AssignStmt:
		// `v := x.(type)`: the assertion is the sole right-hand side.
		guard = stmt.Rhs[0]
	default:
		// `x.(type)`: the assertion is a bare expression statement.
		guard = swtch.Assign.(*ast.ExprStmt).X
	}
	return guard.(*ast.TypeAssertExpr).X
}
// findDef returns the sum type definition whose type is identical to the
// underlying type of needle. The returned pointer refers to the matching
// element of defs. If no such definition exists, nil is returned.
//
// A nil needle also yields nil: types.Info.TypeOf reports nil for an
// expression it has no recorded type for, and calling Underlying on a nil
// types.Type would panic.
func findDef(defs []sumTypeDef, needle types.Type) *sumTypeDef {
	if needle == nil {
		return nil
	}
	// Underlying does not depend on the loop, so compute it once.
	underlying := needle.Underlying()
	for i := range defs {
		def := &defs[i]
		if types.Identical(underlying, def.Ty) {
			return def
		}
	}
	return nil
}