359 lines
9.3 KiB
Go
359 lines
9.3 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"go/types"
|
|
"strings"
|
|
|
|
"go.uber.org/mock/mockgen/model"
|
|
"golang.org/x/tools/go/packages"
|
|
)
|
|
|
|
// buildFlags carries extra arguments for the build system used when loading
// packages in package mode; loadPackage splits the value into individual
// arguments before passing them on.
var buildFlags = flag.String("build_flags", "", "(package mode) Additional flags for go build.")
|
|
|
|
// packageModeParser builds a mockgen model from a package loaded with
// golang.org/x/tools/go/packages ("package mode").
type packageModeParser struct {
	pkgName string // package pattern most recently given to parsePackage
}
|
|
|
|
func (p *packageModeParser) parsePackage(packageName string, ifaces []string) (*model.Package, error) {
|
|
p.pkgName = packageName
|
|
|
|
pkg, err := p.loadPackage(packageName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load package: %w", err)
|
|
}
|
|
|
|
interfaces, err := p.extractInterfacesFromPackage(pkg, ifaces)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("extract interfaces from package: %w", err)
|
|
}
|
|
|
|
return &model.Package{
|
|
Name: pkg.Types.Name(),
|
|
PkgPath: packageName,
|
|
Interfaces: interfaces,
|
|
}, nil
|
|
}
|
|
|
|
func (p *packageModeParser) loadPackage(packageName string) (*packages.Package, error) {
|
|
var buildFlagsSet []string
|
|
if *buildFlags != "" {
|
|
buildFlagsSet = strings.Split(*buildFlags, " ")
|
|
}
|
|
|
|
cfg := &packages.Config{
|
|
Mode: packages.NeedDeps | packages.NeedImports | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedEmbedFiles,
|
|
BuildFlags: buildFlagsSet,
|
|
}
|
|
pkgs, err := packages.Load(cfg, packageName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load packages: %w", err)
|
|
}
|
|
|
|
if len(pkgs) != 1 {
|
|
return nil, fmt.Errorf("packages length must be 1: %d", len(pkgs))
|
|
}
|
|
|
|
if len(pkgs[0].Errors) > 0 {
|
|
errs := make([]error, len(pkgs[0].Errors))
|
|
for i, err := range pkgs[0].Errors {
|
|
errs[i] = err
|
|
}
|
|
|
|
return nil, errors.Join(errs...)
|
|
}
|
|
|
|
return pkgs[0], nil
|
|
}
|
|
|
|
func (p *packageModeParser) extractInterfacesFromPackage(pkg *packages.Package, ifaces []string) ([]*model.Interface, error) {
|
|
interfaces := make([]*model.Interface, len(ifaces))
|
|
for i, iface := range ifaces {
|
|
obj := pkg.Types.Scope().Lookup(iface)
|
|
if obj == nil {
|
|
return nil, fmt.Errorf("interface %s does not exist", iface)
|
|
}
|
|
|
|
modelIface, err := p.parseInterface(obj)
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse interface", obj.Name(), err)
|
|
}
|
|
|
|
interfaces[i] = modelIface
|
|
}
|
|
|
|
return interfaces, nil
|
|
}
|
|
|
|
func (p *packageModeParser) parseInterface(obj types.Object) (*model.Interface, error) {
|
|
named, ok := types.Unalias(obj.Type()).(*types.Named)
|
|
if !ok {
|
|
return nil, fmt.Errorf("%s is not an interface. it is a %T", obj.Name(), obj.Type().Underlying())
|
|
}
|
|
|
|
iface, ok := named.Underlying().(*types.Interface)
|
|
if !ok {
|
|
return nil, fmt.Errorf("%s is not an interface. it is a %T", obj.Name(), obj.Type().Underlying())
|
|
}
|
|
|
|
if p.isConstraint(iface) {
|
|
return nil, fmt.Errorf("interface %s is a constraint", obj.Name())
|
|
}
|
|
|
|
methods := make([]*model.Method, iface.NumMethods())
|
|
for i := range iface.NumMethods() {
|
|
method := iface.Method(i)
|
|
typedMethod, ok := method.Type().(*types.Signature)
|
|
if !ok {
|
|
return nil, fmt.Errorf("method %s is not a signature", method.Name())
|
|
}
|
|
|
|
modelFunc, err := p.parseFunc(typedMethod)
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse method", typedMethod.String(), err)
|
|
}
|
|
|
|
methods[i] = &model.Method{
|
|
Name: method.Name(),
|
|
In: modelFunc.In,
|
|
Out: modelFunc.Out,
|
|
Variadic: modelFunc.Variadic,
|
|
}
|
|
}
|
|
|
|
if named.TypeParams() == nil {
|
|
return &model.Interface{Name: obj.Name(), Methods: methods}, nil
|
|
}
|
|
|
|
typeParams := make([]*model.Parameter, named.TypeParams().Len())
|
|
for i := range named.TypeParams().Len() {
|
|
param := named.TypeParams().At(i)
|
|
typeParam, err := p.parseConstraint(param)
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse type parameter", param.String(), err)
|
|
}
|
|
|
|
typeParams[i] = &model.Parameter{Name: param.Obj().Name(), Type: typeParam}
|
|
}
|
|
|
|
return &model.Interface{Name: obj.Name(), Methods: methods, TypeParams: typeParams}, nil
|
|
}
|
|
|
|
func (o *packageModeParser) isConstraint(t *types.Interface) bool {
|
|
for i := range t.NumEmbeddeds() {
|
|
embed := t.EmbeddedType(i)
|
|
if _, ok := embed.Underlying().(*types.Interface); !ok {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (p *packageModeParser) parseType(t types.Type) (model.Type, error) {
|
|
switch t := t.(type) {
|
|
case *types.Array:
|
|
elementType, err := p.parseType(t.Elem())
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse array type", t.Elem().String(), err)
|
|
}
|
|
return &model.ArrayType{Len: int(t.Len()), Type: elementType}, nil
|
|
case *types.Slice:
|
|
elementType, err := p.parseType(t.Elem())
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse slice type", t.Elem().String(), err)
|
|
}
|
|
|
|
return &model.ArrayType{Len: -1, Type: elementType}, nil
|
|
case *types.Chan:
|
|
var dir model.ChanDir
|
|
switch t.Dir() {
|
|
case types.RecvOnly:
|
|
dir = model.RecvDir
|
|
case types.SendOnly:
|
|
dir = model.SendDir
|
|
}
|
|
|
|
chanType, err := p.parseType(t.Elem())
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse chan type", t.Elem().String(), err)
|
|
}
|
|
|
|
return &model.ChanType{Dir: dir, Type: chanType}, nil
|
|
case *types.Signature:
|
|
sig, err := p.parseFunc(t)
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse signature", t.String(), err)
|
|
}
|
|
|
|
return sig, nil
|
|
case *types.Named, *types.Alias:
|
|
object := t.(interface{ Obj() *types.TypeName })
|
|
var pkg string
|
|
if object.Obj().Pkg() != nil {
|
|
pkg = object.Obj().Pkg().Path()
|
|
}
|
|
|
|
// TypeArgs method not available for aliases in go1.22
|
|
genericType, ok := t.(interface{ TypeArgs() *types.TypeList })
|
|
if !ok || genericType.TypeArgs() == nil {
|
|
return &model.NamedType{
|
|
Package: pkg,
|
|
Type: object.Obj().Name(),
|
|
}, nil
|
|
}
|
|
|
|
typeParams := &model.TypeParametersType{TypeParameters: make([]model.Type, genericType.TypeArgs().Len())}
|
|
for i := range genericType.TypeArgs().Len() {
|
|
typeParam := genericType.TypeArgs().At(i)
|
|
typedParam, err := p.parseType(typeParam)
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse type parameter", typeParam.String(), err)
|
|
}
|
|
|
|
typeParams.TypeParameters[i] = typedParam
|
|
}
|
|
|
|
return &model.NamedType{
|
|
Package: pkg,
|
|
Type: object.Obj().Name(),
|
|
TypeParams: typeParams,
|
|
}, nil
|
|
case *types.Interface:
|
|
if t.Empty() {
|
|
return model.PredeclaredType("any"), nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("cannot handle non-empty unnamed interfaces")
|
|
case *types.Map:
|
|
key, err := p.parseType(t.Key())
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse map key", t.Key().String(), err)
|
|
}
|
|
value, err := p.parseType(t.Elem())
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse map value", t.Elem().String(), err)
|
|
}
|
|
|
|
return &model.MapType{Key: key, Value: value}, nil
|
|
case *types.Pointer:
|
|
valueType, err := p.parseType(t.Elem())
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse pointer type", t.Elem().String(), err)
|
|
}
|
|
|
|
return &model.PointerType{Type: valueType}, nil
|
|
case *types.Struct:
|
|
if t.NumFields() > 0 {
|
|
return nil, fmt.Errorf("cannot handle non-empty unnamed structs")
|
|
}
|
|
|
|
return model.PredeclaredType("struct{}"), nil
|
|
case *types.Basic:
|
|
return model.PredeclaredType(t.Name()), nil
|
|
case *types.Tuple:
|
|
panic("tuple field") // TODO
|
|
case *types.TypeParam:
|
|
return &model.NamedType{Type: t.Obj().Name()}, nil
|
|
default:
|
|
panic("unknown type") // TODO
|
|
}
|
|
}
|
|
|
|
func (p *packageModeParser) parseFunc(sig *types.Signature) (*model.FuncType, error) {
|
|
var variadic *model.Parameter
|
|
params := make([]*model.Parameter, 0, sig.Params().Len())
|
|
for i := range sig.Params().Len() {
|
|
param := sig.Params().At(i)
|
|
|
|
isVariadicParam := i == sig.Params().Len()-1 && sig.Variadic()
|
|
parseType := param.Type()
|
|
if isVariadicParam {
|
|
sliceType, ok := param.Type().(*types.Slice)
|
|
if !ok {
|
|
return nil, newParseTypeError("variadic parameter is not a slice", param.String(), nil)
|
|
}
|
|
|
|
parseType = sliceType.Elem()
|
|
}
|
|
|
|
paramType, err := p.parseType(parseType)
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse parameter type", parseType.String(), err)
|
|
}
|
|
|
|
modelParameter := &model.Parameter{Type: paramType, Name: param.Name()}
|
|
|
|
if isVariadicParam {
|
|
variadic = modelParameter
|
|
} else {
|
|
params = append(params, modelParameter)
|
|
}
|
|
}
|
|
|
|
if len(params) == 0 {
|
|
params = nil
|
|
}
|
|
|
|
results := make([]*model.Parameter, sig.Results().Len())
|
|
for i := range sig.Results().Len() {
|
|
result := sig.Results().At(i)
|
|
|
|
resultType, err := p.parseType(result.Type())
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse result type", result.Type().String(), err)
|
|
}
|
|
|
|
results[i] = &model.Parameter{Type: resultType, Name: result.Name()}
|
|
}
|
|
|
|
if len(results) == 0 {
|
|
results = nil
|
|
}
|
|
|
|
return &model.FuncType{
|
|
In: params,
|
|
Out: results,
|
|
Variadic: variadic,
|
|
}, nil
|
|
}
|
|
|
|
func (p *packageModeParser) parseConstraint(t *types.TypeParam) (model.Type, error) {
|
|
if t == nil {
|
|
return nil, fmt.Errorf("nil type param")
|
|
}
|
|
|
|
typeParam, err := p.parseType(t.Constraint())
|
|
if err != nil {
|
|
return nil, newParseTypeError("parse constraint type", t.Constraint().String(), err)
|
|
}
|
|
|
|
return typeParam, nil
|
|
}
|
|
|
|
// parseTypeError describes a failure to convert a go/types type into the
// mockgen model. It records:
//   - message: what was being attempted.
//   - typeString: the textual form of the type involved.
//   - error: an optional underlying cause, exposed through Unwrap so that
//     errors.Is and errors.As can see through it.
type parseTypeError struct {
	message    string
	typeString string
	error      error // underlying cause; may be nil
}

// newParseTypeError returns a *parseTypeError for typeString with the given
// message. cause may be nil when there is no underlying error.
func newParseTypeError(message string, typeString string, cause error) *parseTypeError {
	// The parameter is named cause (not error) to avoid shadowing the builtin
	// error type inside the function.
	return &parseTypeError{message: message, typeString: typeString, error: cause}
}

// Error formats the error as "<message>: error parsing <type>: <cause>", or as
// "<message>: error parsing type <type>" when there is no underlying cause.
func (e parseTypeError) Error() string {
	if e.error == nil {
		return fmt.Sprintf("%s: error parsing type %s", e.message, e.typeString)
	}
	return fmt.Sprintf("%s: error parsing %s: %s", e.message, e.typeString, e.error)
}

// Unwrap returns the underlying cause, or nil if there is none.
func (e parseTypeError) Unwrap() error {
	return e.error
}
|