490 lines
12 KiB
Go
490 lines
12 KiB
Go
|
package types
|
||
|
|
||
|
import (
	"flag"
	"fmt"
	"io"
	"reflect"
	"strconv"
	"strings"
	"time"

	"github.com/onsi/ginkgo/v2/formatter"
)
|
||
|
|
||
|
// GinkgoFlag describes one command-line flag understood by Ginkgo and where
// its value lives inside a bindings object.
//
// The flag is registered under Name and, when DeprecatedName is non-empty,
// also under that alias. KeyPath is a dot-separated path (struct field names
// and/or map keys) resolved against the bindings value to locate the variable
// the flag reads and writes. SectionKey selects the GinkgoFlagSection whose
// Key matches when rendering usage. Usage, UsageArgument and UsageDefaultValue
// control the help text; an empty UsageArgument means the placeholder is
// derived from the bound value's type. DeprecatedDocLink and DeprecatedVersion
// are reported when the deprecated alias is used. ExportAs, when set, replaces
// Name in GenerateFlagArgs.
type GinkgoFlag struct {
	// Identity, binding and usage grouping.
	Name, KeyPath, SectionKey string

	// Help text rendered by Usage.
	Usage, UsageArgument, UsageDefaultValue string

	// Deprecated alias and the metadata reported when it is used.
	DeprecatedName, DeprecatedDocLink, DeprecatedVersion string

	// Alternate name used when regenerating command-line arguments.
	ExportAs string
}

// GinkgoFlags is an ordered collection of GinkgoFlag definitions.
type GinkgoFlags []GinkgoFlag
|
||
|
|
||
|
func (f GinkgoFlags) CopyAppend(flags ...GinkgoFlag) GinkgoFlags {
|
||
|
out := GinkgoFlags{}
|
||
|
out = append(out, f...)
|
||
|
out = append(out, flags...)
|
||
|
return out
|
||
|
}
|
||
|
|
||
|
func (f GinkgoFlags) WithPrefix(prefix string) GinkgoFlags {
|
||
|
if prefix == "" {
|
||
|
return f
|
||
|
}
|
||
|
out := GinkgoFlags{}
|
||
|
for _, flag := range f {
|
||
|
if flag.Name != "" {
|
||
|
flag.Name = prefix + "." + flag.Name
|
||
|
}
|
||
|
if flag.DeprecatedName != "" {
|
||
|
flag.DeprecatedName = prefix + "." + flag.DeprecatedName
|
||
|
}
|
||
|
if flag.ExportAs != "" {
|
||
|
flag.ExportAs = prefix + "." + flag.ExportAs
|
||
|
}
|
||
|
out = append(out, flag)
|
||
|
}
|
||
|
return out
|
||
|
}
|
||
|
|
||
|
func (f GinkgoFlags) SubsetWithNames(names ...string) GinkgoFlags {
|
||
|
out := GinkgoFlags{}
|
||
|
for _, flag := range f {
|
||
|
for _, name := range names {
|
||
|
if flag.Name == name {
|
||
|
out = append(out, flag)
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return out
|
||
|
}
|
||
|
|
||
|
// GinkgoFlagSection describes one group of flags in Ginkgo's usage output.
//
// A flag belongs to the section whose Key equals the flag's SectionKey. Style
// is a formatter style prefix applied to the heading and to the flag entries.
// Heading and Description are printed above the group. Succinct renders the
// group as a single comma-separated list of flag names instead of one entry
// per flag.
type GinkgoFlagSection struct {
	Key, Style           string
	Succinct             bool
	Heading, Description string
}

// GinkgoFlagSections is an ordered list of sections; usage output follows this
// order.
type GinkgoFlagSections []GinkgoFlagSection
|
||
|
|
||
|
func (gfs GinkgoFlagSections) Lookup(key string) (GinkgoFlagSection, bool) {
|
||
|
for _, section := range gfs {
|
||
|
if section.Key == key {
|
||
|
return section, true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return GinkgoFlagSection{}, false
|
||
|
}
|
||
|
|
||
|
type GinkgoFlagSet struct {
|
||
|
flags GinkgoFlags
|
||
|
bindings interface{}
|
||
|
|
||
|
sections GinkgoFlagSections
|
||
|
extraGoFlagsSection GinkgoFlagSection
|
||
|
|
||
|
flagSet *flag.FlagSet
|
||
|
}
|
||
|
|
||
|
// Call NewGinkgoFlagSet to create GinkgoFlagSet that creates and binds to it's own *flag.FlagSet
|
||
|
func NewGinkgoFlagSet(flags GinkgoFlags, bindings interface{}, sections GinkgoFlagSections) (GinkgoFlagSet, error) {
|
||
|
return bindFlagSet(GinkgoFlagSet{
|
||
|
flags: flags,
|
||
|
bindings: bindings,
|
||
|
sections: sections,
|
||
|
}, nil)
|
||
|
}
|
||
|
|
||
|
// Call NewGinkgoFlagSet to create GinkgoFlagSet that extends an existing *flag.FlagSet
|
||
|
func NewAttachedGinkgoFlagSet(flagSet *flag.FlagSet, flags GinkgoFlags, bindings interface{}, sections GinkgoFlagSections, extraGoFlagsSection GinkgoFlagSection) (GinkgoFlagSet, error) {
|
||
|
return bindFlagSet(GinkgoFlagSet{
|
||
|
flags: flags,
|
||
|
bindings: bindings,
|
||
|
sections: sections,
|
||
|
extraGoFlagsSection: extraGoFlagsSection,
|
||
|
}, flagSet)
|
||
|
}
|
||
|
|
||
|
func bindFlagSet(f GinkgoFlagSet, flagSet *flag.FlagSet) (GinkgoFlagSet, error) {
|
||
|
if flagSet == nil {
|
||
|
f.flagSet = flag.NewFlagSet("", flag.ContinueOnError)
|
||
|
//suppress all output as Ginkgo is responsible for formatting usage
|
||
|
f.flagSet.SetOutput(io.Discard)
|
||
|
} else {
|
||
|
f.flagSet = flagSet
|
||
|
//we're piggybacking on an existing flagset (typically go test) so we have limited control
|
||
|
//on user feedback
|
||
|
f.flagSet.Usage = f.substituteUsage
|
||
|
}
|
||
|
|
||
|
for _, flag := range f.flags {
|
||
|
name := flag.Name
|
||
|
|
||
|
deprecatedUsage := "[DEPRECATED]"
|
||
|
deprecatedName := flag.DeprecatedName
|
||
|
if name != "" {
|
||
|
deprecatedUsage = fmt.Sprintf("[DEPRECATED] use --%s instead", name)
|
||
|
} else if flag.Usage != "" {
|
||
|
deprecatedUsage += " " + flag.Usage
|
||
|
}
|
||
|
|
||
|
value, ok := valueAtKeyPath(f.bindings, flag.KeyPath)
|
||
|
if !ok {
|
||
|
return GinkgoFlagSet{}, fmt.Errorf("could not load KeyPath: %s", flag.KeyPath)
|
||
|
}
|
||
|
|
||
|
iface, addr := value.Interface(), value.Addr().Interface()
|
||
|
|
||
|
switch value.Type() {
|
||
|
case reflect.TypeOf(string("")):
|
||
|
if name != "" {
|
||
|
f.flagSet.StringVar(addr.(*string), name, iface.(string), flag.Usage)
|
||
|
}
|
||
|
if deprecatedName != "" {
|
||
|
f.flagSet.StringVar(addr.(*string), deprecatedName, iface.(string), deprecatedUsage)
|
||
|
}
|
||
|
case reflect.TypeOf(int64(0)):
|
||
|
if name != "" {
|
||
|
f.flagSet.Int64Var(addr.(*int64), name, iface.(int64), flag.Usage)
|
||
|
}
|
||
|
if deprecatedName != "" {
|
||
|
f.flagSet.Int64Var(addr.(*int64), deprecatedName, iface.(int64), deprecatedUsage)
|
||
|
}
|
||
|
case reflect.TypeOf(float64(0)):
|
||
|
if name != "" {
|
||
|
f.flagSet.Float64Var(addr.(*float64), name, iface.(float64), flag.Usage)
|
||
|
}
|
||
|
if deprecatedName != "" {
|
||
|
f.flagSet.Float64Var(addr.(*float64), deprecatedName, iface.(float64), deprecatedUsage)
|
||
|
}
|
||
|
case reflect.TypeOf(int(0)):
|
||
|
if name != "" {
|
||
|
f.flagSet.IntVar(addr.(*int), name, iface.(int), flag.Usage)
|
||
|
}
|
||
|
if deprecatedName != "" {
|
||
|
f.flagSet.IntVar(addr.(*int), deprecatedName, iface.(int), deprecatedUsage)
|
||
|
}
|
||
|
case reflect.TypeOf(bool(true)):
|
||
|
if name != "" {
|
||
|
f.flagSet.BoolVar(addr.(*bool), name, iface.(bool), flag.Usage)
|
||
|
}
|
||
|
if deprecatedName != "" {
|
||
|
f.flagSet.BoolVar(addr.(*bool), deprecatedName, iface.(bool), deprecatedUsage)
|
||
|
}
|
||
|
case reflect.TypeOf(time.Duration(0)):
|
||
|
if name != "" {
|
||
|
f.flagSet.DurationVar(addr.(*time.Duration), name, iface.(time.Duration), flag.Usage)
|
||
|
}
|
||
|
if deprecatedName != "" {
|
||
|
f.flagSet.DurationVar(addr.(*time.Duration), deprecatedName, iface.(time.Duration), deprecatedUsage)
|
||
|
}
|
||
|
|
||
|
case reflect.TypeOf([]string{}):
|
||
|
if name != "" {
|
||
|
f.flagSet.Var(stringSliceVar{value}, name, flag.Usage)
|
||
|
}
|
||
|
if deprecatedName != "" {
|
||
|
f.flagSet.Var(stringSliceVar{value}, deprecatedName, deprecatedUsage)
|
||
|
}
|
||
|
default:
|
||
|
return GinkgoFlagSet{}, fmt.Errorf("unsupported type %T", iface)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return f, nil
|
||
|
}
|
||
|
|
||
|
func (f GinkgoFlagSet) IsZero() bool {
|
||
|
return f.flagSet == nil
|
||
|
}
|
||
|
|
||
|
func (f GinkgoFlagSet) WasSet(name string) bool {
|
||
|
found := false
|
||
|
f.flagSet.Visit(func(f *flag.Flag) {
|
||
|
if f.Name == name {
|
||
|
found = true
|
||
|
}
|
||
|
})
|
||
|
|
||
|
return found
|
||
|
}
|
||
|
|
||
|
func (f GinkgoFlagSet) Lookup(name string) *flag.Flag {
|
||
|
return f.flagSet.Lookup(name)
|
||
|
}
|
||
|
|
||
|
func (f GinkgoFlagSet) Parse(args []string) ([]string, error) {
|
||
|
if f.IsZero() {
|
||
|
return args, nil
|
||
|
}
|
||
|
err := f.flagSet.Parse(args)
|
||
|
if err != nil {
|
||
|
return []string{}, err
|
||
|
}
|
||
|
return f.flagSet.Args(), nil
|
||
|
}
|
||
|
|
||
|
func (f GinkgoFlagSet) ValidateDeprecations(deprecationTracker *DeprecationTracker) {
|
||
|
if f.IsZero() {
|
||
|
return
|
||
|
}
|
||
|
f.flagSet.Visit(func(flag *flag.Flag) {
|
||
|
for _, ginkgoFlag := range f.flags {
|
||
|
if ginkgoFlag.DeprecatedName != "" && strings.HasSuffix(flag.Name, ginkgoFlag.DeprecatedName) {
|
||
|
message := fmt.Sprintf("--%s is deprecated", ginkgoFlag.DeprecatedName)
|
||
|
if ginkgoFlag.Name != "" {
|
||
|
message = fmt.Sprintf("--%s is deprecated, use --%s instead", ginkgoFlag.DeprecatedName, ginkgoFlag.Name)
|
||
|
} else if ginkgoFlag.Usage != "" {
|
||
|
message += " " + ginkgoFlag.Usage
|
||
|
}
|
||
|
|
||
|
deprecationTracker.TrackDeprecation(Deprecation{
|
||
|
Message: message,
|
||
|
DocLink: ginkgoFlag.DeprecatedDocLink,
|
||
|
Version: ginkgoFlag.DeprecatedVersion,
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (f GinkgoFlagSet) Usage() string {
|
||
|
if f.IsZero() {
|
||
|
return ""
|
||
|
}
|
||
|
groupedFlags := map[GinkgoFlagSection]GinkgoFlags{}
|
||
|
ungroupedFlags := GinkgoFlags{}
|
||
|
managedFlags := map[string]bool{}
|
||
|
extraGoFlags := []*flag.Flag{}
|
||
|
|
||
|
for _, flag := range f.flags {
|
||
|
managedFlags[flag.Name] = true
|
||
|
managedFlags[flag.DeprecatedName] = true
|
||
|
|
||
|
if flag.Name == "" {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
section, ok := f.sections.Lookup(flag.SectionKey)
|
||
|
if ok {
|
||
|
groupedFlags[section] = append(groupedFlags[section], flag)
|
||
|
} else {
|
||
|
ungroupedFlags = append(ungroupedFlags, flag)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
f.flagSet.VisitAll(func(flag *flag.Flag) {
|
||
|
if !managedFlags[flag.Name] {
|
||
|
extraGoFlags = append(extraGoFlags, flag)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
out := ""
|
||
|
for _, section := range f.sections {
|
||
|
flags := groupedFlags[section]
|
||
|
if len(flags) == 0 {
|
||
|
continue
|
||
|
}
|
||
|
out += f.usageForSection(section)
|
||
|
if section.Succinct {
|
||
|
succinctFlags := []string{}
|
||
|
for _, flag := range flags {
|
||
|
if flag.Name != "" {
|
||
|
succinctFlags = append(succinctFlags, fmt.Sprintf("--%s", flag.Name))
|
||
|
}
|
||
|
}
|
||
|
out += formatter.Fiw(1, formatter.COLS, section.Style+strings.Join(succinctFlags, ", ")+"{{/}}\n")
|
||
|
} else {
|
||
|
for _, flag := range flags {
|
||
|
out += f.usageForFlag(flag, section.Style)
|
||
|
}
|
||
|
}
|
||
|
out += "\n"
|
||
|
}
|
||
|
if len(ungroupedFlags) > 0 {
|
||
|
for _, flag := range ungroupedFlags {
|
||
|
out += f.usageForFlag(flag, "")
|
||
|
}
|
||
|
out += "\n"
|
||
|
}
|
||
|
if len(extraGoFlags) > 0 {
|
||
|
out += f.usageForSection(f.extraGoFlagsSection)
|
||
|
for _, goFlag := range extraGoFlags {
|
||
|
out += f.usageForGoFlag(goFlag)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return out
|
||
|
}
|
||
|
|
||
|
func (f GinkgoFlagSet) substituteUsage() {
|
||
|
fmt.Fprintln(f.flagSet.Output(), f.Usage())
|
||
|
}
|
||
|
|
||
|
func valueAtKeyPath(root interface{}, keyPath string) (reflect.Value, bool) {
|
||
|
if len(keyPath) == 0 {
|
||
|
return reflect.Value{}, false
|
||
|
}
|
||
|
|
||
|
val := reflect.ValueOf(root)
|
||
|
components := strings.Split(keyPath, ".")
|
||
|
for _, component := range components {
|
||
|
val = reflect.Indirect(val)
|
||
|
switch val.Kind() {
|
||
|
case reflect.Map:
|
||
|
val = val.MapIndex(reflect.ValueOf(component))
|
||
|
if val.Kind() == reflect.Interface {
|
||
|
val = reflect.ValueOf(val.Interface())
|
||
|
}
|
||
|
case reflect.Struct:
|
||
|
val = val.FieldByName(component)
|
||
|
default:
|
||
|
return reflect.Value{}, false
|
||
|
}
|
||
|
if (val == reflect.Value{}) {
|
||
|
return reflect.Value{}, false
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return val, true
|
||
|
}
|
||
|
|
||
|
func (f GinkgoFlagSet) usageForSection(section GinkgoFlagSection) string {
|
||
|
out := formatter.F(section.Style + "{{bold}}{{underline}}" + section.Heading + "{{/}}\n")
|
||
|
if section.Description != "" {
|
||
|
out += formatter.Fiw(0, formatter.COLS, section.Description+"\n")
|
||
|
}
|
||
|
return out
|
||
|
}
|
||
|
|
||
|
func (f GinkgoFlagSet) usageForFlag(flag GinkgoFlag, style string) string {
|
||
|
argument := flag.UsageArgument
|
||
|
defValue := flag.UsageDefaultValue
|
||
|
if argument == "" {
|
||
|
value, _ := valueAtKeyPath(f.bindings, flag.KeyPath)
|
||
|
switch value.Type() {
|
||
|
case reflect.TypeOf(string("")):
|
||
|
argument = "string"
|
||
|
case reflect.TypeOf(int64(0)), reflect.TypeOf(int(0)):
|
||
|
argument = "int"
|
||
|
case reflect.TypeOf(time.Duration(0)):
|
||
|
argument = "duration"
|
||
|
case reflect.TypeOf(float64(0)):
|
||
|
argument = "float"
|
||
|
case reflect.TypeOf([]string{}):
|
||
|
argument = "string"
|
||
|
}
|
||
|
}
|
||
|
if argument != "" {
|
||
|
argument = "[" + argument + "] "
|
||
|
}
|
||
|
if defValue != "" {
|
||
|
defValue = fmt.Sprintf("(default: %s)", defValue)
|
||
|
}
|
||
|
hyphens := "--"
|
||
|
if len(flag.Name) == 1 {
|
||
|
hyphens = "-"
|
||
|
}
|
||
|
|
||
|
out := formatter.Fi(1, style+"%s%s{{/}} %s{{gray}}%s{{/}}\n", hyphens, flag.Name, argument, defValue)
|
||
|
out += formatter.Fiw(2, formatter.COLS, "{{light-gray}}%s{{/}}\n", flag.Usage)
|
||
|
return out
|
||
|
}
|
||
|
|
||
|
func (f GinkgoFlagSet) usageForGoFlag(goFlag *flag.Flag) string {
|
||
|
//Taken directly from the flag package
|
||
|
out := fmt.Sprintf(" -%s", goFlag.Name)
|
||
|
name, usage := flag.UnquoteUsage(goFlag)
|
||
|
if len(name) > 0 {
|
||
|
out += " " + name
|
||
|
}
|
||
|
if len(out) <= 4 {
|
||
|
out += "\t"
|
||
|
} else {
|
||
|
out += "\n \t"
|
||
|
}
|
||
|
out += strings.ReplaceAll(usage, "\n", "\n \t")
|
||
|
out += "\n"
|
||
|
return out
|
||
|
}
|
||
|
|
||
|
// stringSliceVar adapts a settable []string reflect.Value to flag.Value so a
// flag can be repeated, with each occurrence appending one element.
type stringSliceVar struct {
	slice reflect.Value
}

// String reports no textual default. It returns "" without reading slice, so
// it is safe to call on a zero stringSliceVar.
func (ssv stringSliceVar) String() string {
	return ""
}

// Set appends s to the bound slice and always returns nil.
func (ssv stringSliceVar) Set(s string) error {
	grown := reflect.AppendSlice(ssv.slice, reflect.ValueOf([]string{s}))
	ssv.slice.Set(grown)
	return nil
}
|
||
|
|
||
|
//given a set of GinkgoFlags and bindings, generate flag arguments suitable to be passed to an application with that set of flags configured.
|
||
|
func GenerateFlagArgs(flags GinkgoFlags, bindings interface{}) ([]string, error) {
|
||
|
result := []string{}
|
||
|
for _, flag := range flags {
|
||
|
name := flag.ExportAs
|
||
|
if name == "" {
|
||
|
name = flag.Name
|
||
|
}
|
||
|
if name == "" {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
value, ok := valueAtKeyPath(bindings, flag.KeyPath)
|
||
|
if !ok {
|
||
|
return []string{}, fmt.Errorf("could not load KeyPath: %s", flag.KeyPath)
|
||
|
}
|
||
|
|
||
|
iface := value.Interface()
|
||
|
switch value.Type() {
|
||
|
case reflect.TypeOf(string("")):
|
||
|
if iface.(string) != "" {
|
||
|
result = append(result, fmt.Sprintf("--%s=%s", name, iface))
|
||
|
}
|
||
|
case reflect.TypeOf(int64(0)):
|
||
|
if iface.(int64) != 0 {
|
||
|
result = append(result, fmt.Sprintf("--%s=%d", name, iface))
|
||
|
}
|
||
|
case reflect.TypeOf(float64(0)):
|
||
|
if iface.(float64) != 0 {
|
||
|
result = append(result, fmt.Sprintf("--%s=%f", name, iface))
|
||
|
}
|
||
|
case reflect.TypeOf(int(0)):
|
||
|
if iface.(int) != 0 {
|
||
|
result = append(result, fmt.Sprintf("--%s=%d", name, iface))
|
||
|
}
|
||
|
case reflect.TypeOf(bool(true)):
|
||
|
if iface.(bool) {
|
||
|
result = append(result, fmt.Sprintf("--%s", name))
|
||
|
}
|
||
|
case reflect.TypeOf(time.Duration(0)):
|
||
|
if iface.(time.Duration) != time.Duration(0) {
|
||
|
result = append(result, fmt.Sprintf("--%s=%s", name, iface))
|
||
|
}
|
||
|
|
||
|
case reflect.TypeOf([]string{}):
|
||
|
strings := iface.([]string)
|
||
|
for _, s := range strings {
|
||
|
result = append(result, fmt.Sprintf("--%s=%s", name, s))
|
||
|
}
|
||
|
default:
|
||
|
return []string{}, fmt.Errorf("unsupported type %T", iface)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return result, nil
|
||
|
}
|