// File: cloudflared-mirror/vendor/go.uber.org/mock/mockgen/package_mode.go
// (359 lines, 9.3 KiB, Go)

package main
import (
"errors"
"flag"
"fmt"
"go/types"
"strings"
"go.uber.org/mock/mockgen/model"
"golang.org/x/tools/go/packages"
)
// buildFlags is the value of the -build_flags command-line flag. It defaults
// to the empty string; in package mode loadPackage forwards its contents as
// packages.Config.BuildFlags.
var buildFlags = flag.String("build_flags", "", "(package mode) Additional flags for go build.")
// packageModeParser converts a Go package, loaded with full type information,
// into mockgen's model types. Its methods walk go/types values and build the
// corresponding model.Package, model.Interface and model.Type values.
type packageModeParser struct {
// pkgName is the package name handed to the most recent parsePackage call.
pkgName string
}
// parsePackage loads the package identified by packageName and converts every
// interface listed in ifaces into its model representation.
//
// packageName is recorded on the parser and is also used verbatim as the
// PkgPath of the returned model.Package, while the package Name is taken from
// the loaded type information. Failures while loading the package or while
// extracting interfaces are wrapped and returned.
func (p *packageModeParser) parsePackage(packageName string, ifaces []string) (*model.Package, error) {
	p.pkgName = packageName

	loaded, loadErr := p.loadPackage(packageName)
	if loadErr != nil {
		return nil, fmt.Errorf("load package: %w", loadErr)
	}

	parsed, extractErr := p.extractInterfacesFromPackage(loaded, ifaces)
	if extractErr != nil {
		return nil, fmt.Errorf("extract interfaces from package: %w", extractErr)
	}

	result := &model.Package{
		Name:       loaded.Types.Name(),
		PkgPath:    packageName,
		Interfaces: parsed,
	}
	return result, nil
}
// loadPackage loads and type-checks the package matching packageName through
// golang.org/x/tools/go/packages.
//
// The value of the -build_flags flag is split on whitespace and forwarded as
// packages.Config.BuildFlags. An error is returned when loading fails, when
// the pattern does not resolve to exactly one package, or when the loaded
// package reports errors (all of them are joined into the returned error).
func (p *packageModeParser) loadPackage(packageName string) (*packages.Package, error) {
	// strings.Fields is used instead of splitting on a single space so that
	// leading, trailing or repeated whitespace (including tabs) never yields
	// empty arguments for the build tooling.
	// NOTE(review): quoted flag values that contain spaces are still split
	// apart, exactly as before.
	var buildFlagsSet []string
	if *buildFlags != "" {
		buildFlagsSet = strings.Fields(*buildFlags)
	}

	cfg := &packages.Config{
		Mode:       packages.NeedDeps | packages.NeedImports | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedEmbedFiles,
		BuildFlags: buildFlagsSet,
	}
	pkgs, err := packages.Load(cfg, packageName)
	if err != nil {
		return nil, fmt.Errorf("load packages: %w", err)
	}

	if len(pkgs) != 1 {
		return nil, fmt.Errorf("packages length must be 1: %d", len(pkgs))
	}

	// Report every problem the loader found for the package at once.
	if len(pkgs[0].Errors) > 0 {
		errs := make([]error, len(pkgs[0].Errors))
		for i, pkgErr := range pkgs[0].Errors {
			errs[i] = pkgErr
		}
		return nil, errors.Join(errs...)
	}

	return pkgs[0], nil
}
// extractInterfacesFromPackage looks up each name in ifaces in the package
// scope of pkg and parses the object found there as an interface.
//
// The returned slice holds one entry per requested name, in the same order.
// The first name that cannot be found, or that fails to parse, aborts the
// whole extraction with an error.
func (p *packageModeParser) extractInterfacesFromPackage(pkg *packages.Package, ifaces []string) ([]*model.Interface, error) {
	parsed := make([]*model.Interface, 0, len(ifaces))
	for _, name := range ifaces {
		obj := pkg.Types.Scope().Lookup(name)
		if obj == nil {
			return nil, fmt.Errorf("interface %s does not exist", name)
		}

		converted, err := p.parseInterface(obj)
		if err != nil {
			return nil, newParseTypeError("parse interface", obj.Name(), err)
		}
		parsed = append(parsed, converted)
	}
	return parsed, nil
}
// parseInterface converts the type-checked object obj into a model.Interface.
//
// obj must denote (possibly through an alias) a named type whose underlying
// type is an interface that isConstraint does not flag as a constraint;
// anything else is rejected with an error. Every method of the interface is
// converted with parseFunc, and when the named type declares type parameters
// each one is recorded together with its parsed constraint.
func (p *packageModeParser) parseInterface(obj types.Object) (*model.Interface, error) {
	named, isNamed := types.Unalias(obj.Type()).(*types.Named)
	if !isNamed {
		return nil, fmt.Errorf("%s is not an interface. it is a %T", obj.Name(), obj.Type().Underlying())
	}

	iface, isIface := named.Underlying().(*types.Interface)
	if !isIface {
		return nil, fmt.Errorf("%s is not an interface. it is a %T", obj.Name(), obj.Type().Underlying())
	}

	if p.isConstraint(iface) {
		return nil, fmt.Errorf("interface %s is a constraint", obj.Name())
	}

	methodCount := iface.NumMethods()
	methods := make([]*model.Method, 0, methodCount)
	for idx := 0; idx < methodCount; idx++ {
		fn := iface.Method(idx)
		sig, isSig := fn.Type().(*types.Signature)
		if !isSig {
			return nil, fmt.Errorf("method %s is not a signature", fn.Name())
		}

		parsed, err := p.parseFunc(sig)
		if err != nil {
			return nil, newParseTypeError("parse method", sig.String(), err)
		}

		methods = append(methods, &model.Method{
			Name:     fn.Name(),
			In:       parsed.In,
			Out:      parsed.Out,
			Variadic: parsed.Variadic,
		})
	}

	result := &model.Interface{Name: obj.Name(), Methods: methods}

	// Without a type parameter list the TypeParams field stays unset.
	declared := named.TypeParams()
	if declared == nil {
		return result, nil
	}

	result.TypeParams = make([]*model.Parameter, 0, declared.Len())
	for idx := 0; idx < declared.Len(); idx++ {
		tp := declared.At(idx)
		constraint, err := p.parseConstraint(tp)
		if err != nil {
			return nil, newParseTypeError("parse type parameter", tp.String(), err)
		}
		result.TypeParams = append(result.TypeParams, &model.Parameter{Name: tp.Obj().Name(), Type: constraint})
	}
	return result, nil
}
// isConstraint reports whether t directly embeds an element whose underlying
// type is not an interface. parseInterface treats such interfaces as type
// constraints and refuses to parse them.
//
// Only the elements embedded directly in t are inspected; an embedded
// interface is accepted without looking at what it embeds in turn.
func (p *packageModeParser) isConstraint(t *types.Interface) bool {
	for i := range t.NumEmbeddeds() {
		embed := t.EmbeddedType(i)
		if _, ok := embed.Underlying().(*types.Interface); !ok {
			return true
		}
	}

	return false
}
// parseType converts a go/types type into the equivalent model.Type.
//
// Handled kinds:
//   - arrays and slices become model.ArrayType (slices use Len -1);
//   - channels become model.ChanType, with the direction left at its zero
//     value unless the channel is receive-only or send-only;
//   - signatures are converted by parseFunc;
//   - named and alias types become model.NamedType, carrying the package
//     path of the declaring package (when there is one) and any type
//     arguments;
//   - the empty interface becomes the predeclared type "any" and the empty
//     struct becomes "struct{}";
//   - maps, pointers and basic types map to their direct counterparts;
//   - type parameters become a model.NamedType holding only their name.
//
// Non-empty unnamed interfaces and structs, tuples, and any other type kind
// are reported as errors rather than panicking, so that callers can attach
// context and fail cleanly.
func (p *packageModeParser) parseType(t types.Type) (model.Type, error) {
	switch t := t.(type) {
	case *types.Array:
		elementType, err := p.parseType(t.Elem())
		if err != nil {
			return nil, newParseTypeError("parse array type", t.Elem().String(), err)
		}
		return &model.ArrayType{Len: int(t.Len()), Type: elementType}, nil
	case *types.Slice:
		elementType, err := p.parseType(t.Elem())
		if err != nil {
			return nil, newParseTypeError("parse slice type", t.Elem().String(), err)
		}
		// A slice is modeled as an array type with length -1.
		return &model.ArrayType{Len: -1, Type: elementType}, nil
	case *types.Chan:
		// dir keeps its zero value for channels that are neither
		// receive-only nor send-only.
		var dir model.ChanDir
		switch t.Dir() {
		case types.RecvOnly:
			dir = model.RecvDir
		case types.SendOnly:
			dir = model.SendDir
		}

		chanType, err := p.parseType(t.Elem())
		if err != nil {
			return nil, newParseTypeError("parse chan type", t.Elem().String(), err)
		}
		return &model.ChanType{Dir: dir, Type: chanType}, nil
	case *types.Signature:
		sig, err := p.parseFunc(t)
		if err != nil {
			return nil, newParseTypeError("parse signature", t.String(), err)
		}
		return sig, nil
	case *types.Named, *types.Alias:
		// Both concrete types expose Obj(); go through a structural
		// interface because t keeps the static type types.Type in a
		// multi-type case.
		object := t.(interface{ Obj() *types.TypeName })
		var pkg string
		if object.Obj().Pkg() != nil {
			pkg = object.Obj().Pkg().Path()
		}

		// TypeArgs method not available for aliases in go1.22
		genericType, ok := t.(interface{ TypeArgs() *types.TypeList })
		if !ok || genericType.TypeArgs() == nil {
			return &model.NamedType{
				Package: pkg,
				Type:    object.Obj().Name(),
			}, nil
		}

		typeParams := &model.TypeParametersType{TypeParameters: make([]model.Type, genericType.TypeArgs().Len())}
		for i := range genericType.TypeArgs().Len() {
			typeParam := genericType.TypeArgs().At(i)
			typedParam, err := p.parseType(typeParam)
			if err != nil {
				return nil, newParseTypeError("parse type parameter", typeParam.String(), err)
			}
			typeParams.TypeParameters[i] = typedParam
		}

		return &model.NamedType{
			Package:    pkg,
			Type:       object.Obj().Name(),
			TypeParams: typeParams,
		}, nil
	case *types.Interface:
		if t.Empty() {
			return model.PredeclaredType("any"), nil
		}
		return nil, fmt.Errorf("cannot handle non-empty unnamed interfaces")
	case *types.Map:
		key, err := p.parseType(t.Key())
		if err != nil {
			return nil, newParseTypeError("parse map key", t.Key().String(), err)
		}
		value, err := p.parseType(t.Elem())
		if err != nil {
			return nil, newParseTypeError("parse map value", t.Elem().String(), err)
		}
		return &model.MapType{Key: key, Value: value}, nil
	case *types.Pointer:
		valueType, err := p.parseType(t.Elem())
		if err != nil {
			return nil, newParseTypeError("parse pointer type", t.Elem().String(), err)
		}
		return &model.PointerType{Type: valueType}, nil
	case *types.Struct:
		if t.NumFields() > 0 {
			return nil, fmt.Errorf("cannot handle non-empty unnamed structs")
		}
		return model.PredeclaredType("struct{}"), nil
	case *types.Basic:
		return model.PredeclaredType(t.Name()), nil
	case *types.Tuple:
		// Tuples have no model counterpart; report them instead of
		// crashing the generator.
		return nil, fmt.Errorf("cannot handle tuple types")
	case *types.TypeParam:
		return &model.NamedType{Type: t.Obj().Name()}, nil
	default:
		// %T also copes with a nil t, for which calling String would panic.
		return nil, fmt.Errorf("cannot handle unknown type %T", t)
	}
}
// parseFunc converts a function signature into a model.FuncType.
//
// Ordinary parameters end up in In and results in Out; both are nil when
// there are none. For a variadic signature the final parameter is reported
// separately as Variadic, typed as the element type of its slice.
func (p *packageModeParser) parseFunc(sig *types.Signature) (*model.FuncType, error) {
	inputs := sig.Params()
	inputCount := inputs.Len()

	var variadic *model.Parameter
	fixed := make([]*model.Parameter, 0, inputCount)
	for idx := 0; idx < inputCount; idx++ {
		arg := inputs.At(idx)
		argType := arg.Type()

		// Only the last parameter of a variadic signature is the variadic
		// one; it is described by its element type rather than the slice.
		isLastVariadic := sig.Variadic() && idx == inputCount-1
		if isLastVariadic {
			slice, isSlice := argType.(*types.Slice)
			if !isSlice {
				return nil, newParseTypeError("variadic parameter is not a slice", arg.String(), nil)
			}
			argType = slice.Elem()
		}

		parsed, err := p.parseType(argType)
		if err != nil {
			return nil, newParseTypeError("parse parameter type", argType.String(), err)
		}

		parameter := &model.Parameter{Type: parsed, Name: arg.Name()}
		if isLastVariadic {
			variadic = parameter
			continue
		}
		fixed = append(fixed, parameter)
	}
	if len(fixed) == 0 {
		fixed = nil
	}

	outputs := sig.Results()
	var results []*model.Parameter
	if outputCount := outputs.Len(); outputCount > 0 {
		results = make([]*model.Parameter, 0, outputCount)
		for idx := 0; idx < outputCount; idx++ {
			out := outputs.At(idx)
			parsed, err := p.parseType(out.Type())
			if err != nil {
				return nil, newParseTypeError("parse result type", out.Type().String(), err)
			}
			results = append(results, &model.Parameter{Type: parsed, Name: out.Name()})
		}
	}

	return &model.FuncType{
		In:       fixed,
		Out:      results,
		Variadic: variadic,
	}, nil
}
// parseConstraint returns the model type of the constraint attached to the
// type parameter t. A nil t is rejected, and a failure to parse the
// constraint is wrapped in a parse-type error.
func (p *packageModeParser) parseConstraint(t *types.TypeParam) (model.Type, error) {
	if t == nil {
		return nil, fmt.Errorf("nil type param")
	}

	constraint := t.Constraint()
	parsed, err := p.parseType(constraint)
	if err == nil {
		return parsed, nil
	}
	return nil, newParseTypeError("parse constraint type", constraint.String(), err)
}
// parseTypeError describes a failure to convert a type. It records what was
// being parsed, a textual form of the offending type and, optionally, the
// underlying cause.
type parseTypeError struct {
	// message says which parsing step failed.
	message string
	// typeString is the textual form of the type that could not be parsed.
	typeString string
	// err is the underlying cause; it may be nil. It is deliberately not
	// named "error" so the predeclared error type is not shadowed.
	err error
}

// newParseTypeError builds a parseTypeError for the given step message, type
// string and optional cause. The cause may be nil.
func newParseTypeError(message string, typeString string, err error) *parseTypeError {
	return &parseTypeError{typeString: typeString, err: err, message: message}
}

// Error implements the error interface. The underlying cause is appended to
// the text when there is one.
func (p parseTypeError) Error() string {
	if p.err != nil {
		return fmt.Sprintf("%s: error parsing %s: %s", p.message, p.typeString, p.err)
	}
	return fmt.Sprintf("%s: error parsing type %s", p.message, p.typeString)
}

// Unwrap returns the underlying cause, or nil when there is none, so that
// errors.Is and errors.As can look through a parseTypeError.
func (p parseTypeError) Unwrap() error {
	return p.err
}