Skip to content

Commit 599699b

Browse files
authored
fix: truly allow overridding default string array (#6874)
1 parent 96ff400 commit 599699b

File tree

7 files changed

+148
-112
lines changed

7 files changed

+148
-112
lines changed

cli/clibase/cmd.go

Lines changed: 63 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ type Invocation struct {
172172

173173
// WithOS returns the invocation as a main package, filling in the invocation's unset
174174
// fields with OS defaults.
175-
func (i *Invocation) WithOS() *Invocation {
176-
return i.with(func(i *Invocation) {
175+
func (inv *Invocation) WithOS() *Invocation {
176+
return inv.with(func(i *Invocation) {
177177
i.Stdout = os.Stdout
178178
i.Stderr = os.Stderr
179179
i.Stdin = os.Stdin
@@ -182,18 +182,18 @@ func (i *Invocation) WithOS() *Invocation {
182182
})
183183
}
184184

185-
func (i *Invocation) Context() context.Context {
186-
if i.ctx == nil {
185+
func (inv *Invocation) Context() context.Context {
186+
if inv.ctx == nil {
187187
return context.Background()
188188
}
189-
return i.ctx
189+
return inv.ctx
190190
}
191191

192-
func (i *Invocation) ParsedFlags() *pflag.FlagSet {
193-
if i.parsedFlags == nil {
192+
func (inv *Invocation) ParsedFlags() *pflag.FlagSet {
193+
if inv.parsedFlags == nil {
194194
panic("flags not parsed, has Run() been called?")
195195
}
196-
return i.parsedFlags
196+
return inv.parsedFlags
197197
}
198198

199199
type runState struct {
@@ -218,39 +218,17 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
218218
// run recursively executes the command and its children.
219219
// allArgs is wired through the stack so that global flags can be accepted
220220
// anywhere in the command invocation.
221-
func (i *Invocation) run(state *runState) error {
222-
err := i.Command.Options.SetDefaults()
223-
if err != nil {
224-
return xerrors.Errorf("setting defaults: %w", err)
225-
}
226-
227-
// If we set the Default of an array but later see a flag for it, we
228-
// don't want to append, we want to replace. So, we need to keep the state
229-
// of defaulted array options.
230-
defaultedArrays := make(map[string]int)
231-
for _, opt := range i.Command.Options {
232-
sv, ok := opt.Value.(pflag.SliceValue)
233-
if !ok {
234-
continue
235-
}
236-
237-
if opt.Flag == "" {
238-
continue
239-
}
240-
241-
defaultedArrays[opt.Flag] = len(sv.GetSlice())
242-
}
243-
244-
err = i.Command.Options.ParseEnv(i.Environ)
221+
func (inv *Invocation) run(state *runState) error {
222+
err := inv.Command.Options.ParseEnv(inv.Environ)
245223
if err != nil {
246224
return xerrors.Errorf("parsing env: %w", err)
247225
}
248226

249227
// Now the fun part, argument parsing!
250228

251229
children := make(map[string]*Cmd)
252-
for _, child := range i.Command.Children {
253-
child.Parent = i.Command
230+
for _, child := range inv.Command.Children {
231+
child.Parent = inv.Command
254232
for _, name := range append(child.Aliases, child.Name()) {
255233
if _, ok := children[name]; ok {
256234
return xerrors.Errorf("duplicate command name: %s", name)
@@ -259,57 +237,44 @@ func (i *Invocation) run(state *runState) error {
259237
}
260238
}
261239

262-
if i.parsedFlags == nil {
263-
i.parsedFlags = pflag.NewFlagSet(i.Command.Name(), pflag.ContinueOnError)
240+
if inv.parsedFlags == nil {
241+
inv.parsedFlags = pflag.NewFlagSet(inv.Command.Name(), pflag.ContinueOnError)
264242
// We handle Usage ourselves.
265-
i.parsedFlags.Usage = func() {}
243+
inv.parsedFlags.Usage = func() {}
266244
}
267245

268246
// If we find a duplicate flag, we want the deeper command's flag to override
269247
// the shallow one. Unfortunately, pflag has no way to remove a flag, so we
270248
// have to create a copy of the flagset without a value.
271-
i.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
272-
if i.parsedFlags.Lookup(f.Name) != nil {
273-
i.parsedFlags = copyFlagSetWithout(i.parsedFlags, f.Name)
249+
inv.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
250+
if inv.parsedFlags.Lookup(f.Name) != nil {
251+
inv.parsedFlags = copyFlagSetWithout(inv.parsedFlags, f.Name)
274252
}
275-
i.parsedFlags.AddFlag(f)
253+
inv.parsedFlags.AddFlag(f)
276254
})
277255

278256
var parsedArgs []string
279257

280-
if !i.Command.RawArgs {
258+
if !inv.Command.RawArgs {
281259
// Flag parsing will fail on intermediate commands in the command tree,
282260
// so we check the error after looking for a child command.
283-
state.flagParseErr = i.parsedFlags.Parse(state.allArgs)
284-
parsedArgs = i.parsedFlags.Args()
285-
286-
i.parsedFlags.VisitAll(func(f *pflag.Flag) {
287-
i, ok := defaultedArrays[f.Name]
288-
if !ok {
289-
return
290-
}
291-
292-
if !f.Changed {
293-
return
294-
}
261+
state.flagParseErr = inv.parsedFlags.Parse(state.allArgs)
262+
parsedArgs = inv.parsedFlags.Args()
263+
}
295264

296-
// If flag was changed, we need to remove the default values.
297-
sv, ok := f.Value.(pflag.SliceValue)
298-
if !ok {
299-
panic("defaulted array option is not a slice value")
300-
}
301-
ss := sv.GetSlice()
302-
if len(ss) == 0 {
303-
// Slice likely zeroed by a flag.
304-
// E.g. "--fruit" may default to "apples,oranges" but the user
305-
// provided "--fruit=""".
306-
return
307-
}
308-
err := sv.Replace(ss[i:])
309-
if err != nil {
310-
panic(err)
311-
}
312-
})
265+
// Set defaults for flags that weren't set by the user.
266+
skipDefaults := make(map[int]struct{}, len(inv.Command.Options))
267+
for i, opt := range inv.Command.Options {
268+
if fl := inv.parsedFlags.Lookup(opt.Flag); fl != nil && fl.Changed {
269+
skipDefaults[i] = struct{}{}
270+
}
271+
if opt.envChanged {
272+
skipDefaults[i] = struct{}{}
273+
}
274+
}
275+
err = inv.Command.Options.SetDefaults(skipDefaults)
276+
if err != nil {
277+
return xerrors.Errorf("setting defaults: %w", err)
313278
}
314279

315280
// Run child command if found (next child only)
@@ -318,64 +283,64 @@ func (i *Invocation) run(state *runState) error {
318283
if len(parsedArgs) > state.commandDepth {
319284
nextArg := parsedArgs[state.commandDepth]
320285
if child, ok := children[nextArg]; ok {
321-
child.Parent = i.Command
322-
i.Command = child
286+
child.Parent = inv.Command
287+
inv.Command = child
323288
state.commandDepth++
324-
return i.run(state)
289+
return inv.run(state)
325290
}
326291
}
327292

328293
// Flag parse errors are irrelevant for raw args commands.
329-
if !i.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
294+
if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
330295
return xerrors.Errorf(
331296
"parsing flags (%v) for %q: %w",
332297
state.allArgs,
333-
i.Command.FullName(), state.flagParseErr,
298+
inv.Command.FullName(), state.flagParseErr,
334299
)
335300
}
336301

337-
if i.Command.RawArgs {
302+
if inv.Command.RawArgs {
338303
// If we're at the root command, then the name is omitted
339304
// from the arguments, so we can just use the entire slice.
340305
if state.commandDepth == 0 {
341-
i.Args = state.allArgs
306+
inv.Args = state.allArgs
342307
} else {
343-
argPos, err := findArg(i.Command.Name(), state.allArgs, i.parsedFlags)
308+
argPos, err := findArg(inv.Command.Name(), state.allArgs, inv.parsedFlags)
344309
if err != nil {
345310
panic(err)
346311
}
347-
i.Args = state.allArgs[argPos+1:]
312+
inv.Args = state.allArgs[argPos+1:]
348313
}
349314
} else {
350315
// In non-raw-arg mode, we want to skip over flags.
351-
i.Args = parsedArgs[state.commandDepth:]
316+
inv.Args = parsedArgs[state.commandDepth:]
352317
}
353318

354-
mw := i.Command.Middleware
319+
mw := inv.Command.Middleware
355320
if mw == nil {
356321
mw = Chain()
357322
}
358323

359-
ctx := i.ctx
324+
ctx := inv.ctx
360325
if ctx == nil {
361326
ctx = context.Background()
362327
}
363328

364329
ctx, cancel := context.WithCancel(ctx)
365330
defer cancel()
366-
i = i.WithContext(ctx)
331+
inv = inv.WithContext(ctx)
367332

368-
if i.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
369-
if i.Command.HelpHandler == nil {
370-
return xerrors.Errorf("no handler or help for command %s", i.Command.FullName())
333+
if inv.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
334+
if inv.Command.HelpHandler == nil {
335+
return xerrors.Errorf("no handler or help for command %s", inv.Command.FullName())
371336
}
372-
return i.Command.HelpHandler(i)
337+
return inv.Command.HelpHandler(inv)
373338
}
374339

375-
err = mw(i.Command.Handler)(i)
340+
err = mw(inv.Command.Handler)(inv)
376341
if err != nil {
377342
return &RunCommandError{
378-
Cmd: i.Command,
343+
Cmd: inv.Command,
379344
Err: err,
380345
}
381346
}
@@ -438,33 +403,33 @@ func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
438403
// If two command share a flag name, the first command wins.
439404
//
440405
//nolint:revive
441-
func (i *Invocation) Run() (err error) {
406+
func (inv *Invocation) Run() (err error) {
442407
defer func() {
443408
// Pflag is panicky, so additional context is helpful in tests.
444409
if flag.Lookup("test.v") == nil {
445410
return
446411
}
447412
if r := recover(); r != nil {
448-
err = xerrors.Errorf("panic recovered for %s: %v", i.Command.FullName(), r)
413+
err = xerrors.Errorf("panic recovered for %s: %v", inv.Command.FullName(), r)
449414
panic(err)
450415
}
451416
}()
452-
err = i.run(&runState{
453-
allArgs: i.Args,
417+
err = inv.run(&runState{
418+
allArgs: inv.Args,
454419
})
455420
return err
456421
}
457422

458423
// WithContext returns a copy of the Invocation with the given context.
459-
func (i *Invocation) WithContext(ctx context.Context) *Invocation {
460-
return i.with(func(i *Invocation) {
424+
func (inv *Invocation) WithContext(ctx context.Context) *Invocation {
425+
return inv.with(func(i *Invocation) {
461426
i.ctx = ctx
462427
})
463428
}
464429

465430
// with returns a copy of the Invocation with the given function applied.
466-
func (i *Invocation) with(fn func(*Invocation)) *Invocation {
467-
i2 := *i
431+
func (inv *Invocation) with(fn func(*Invocation)) *Invocation {
432+
i2 := *inv
468433
fn(&i2)
469434
return &i2
470435
}

cli/clibase/cmd_test.go

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package clibase_test
33
import (
44
"bytes"
55
"context"
6+
"fmt"
67
"strings"
78
"testing"
89

@@ -247,6 +248,7 @@ func TestCommand_FlagOverride(t *testing.T) {
247248
Use: "1",
248249
Options: clibase.OptionSet{
249250
{
251+
Name: "flag",
250252
Flag: "f",
251253
Value: clibase.DiscardValue,
252254
},
@@ -256,6 +258,7 @@ func TestCommand_FlagOverride(t *testing.T) {
256258
Use: "2",
257259
Options: clibase.OptionSet{
258260
{
261+
Name: "flag",
259262
Flag: "f",
260263
Value: clibase.StringOf(&flag),
261264
},
@@ -515,7 +518,7 @@ func TestCommand_EmptySlice(t *testing.T) {
515518
{
516519
Name: "arr",
517520
Flag: "arr",
518-
Default: "bad,bad,bad",
521+
Default: "def,def,def",
519522
Env: "ARR",
520523
Value: clibase.StringArrayOf(&got),
521524
},
@@ -527,11 +530,67 @@ func TestCommand_EmptySlice(t *testing.T) {
527530
}
528531
}
529532

530-
// Base-case
531-
err := cmd("bad", "bad", "bad").Invoke().Run()
533+
// Base-case, uses default.
534+
err := cmd("def", "def", "def").Invoke().Run()
535+
require.NoError(t, err)
536+
537+
// Empty-env uses default, too.
538+
inv := cmd("def", "def", "def").Invoke()
539+
inv.Environ.Set("ARR", "")
540+
require.NoError(t, err)
541+
542+
// Reset to nothing at all via flag.
543+
inv = cmd().Invoke("--arr", "")
544+
inv.Environ.Set("ARR", "cant see")
545+
err = inv.Run()
546+
require.NoError(t, err)
547+
548+
// Reset to a specific value with flag.
549+
inv = cmd("great").Invoke("--arr", "great")
550+
inv.Environ.Set("ARR", "")
551+
err = inv.Run()
552+
require.NoError(t, err)
553+
}
554+
555+
func TestCommand_DefaultsOverride(t *testing.T) {
556+
t.Parallel()
557+
558+
var got string
559+
cmd := &clibase.Cmd{
560+
Options: clibase.OptionSet{
561+
{
562+
Name: "url",
563+
Flag: "url",
564+
Default: "def.com",
565+
Env: "URL",
566+
Value: clibase.StringOf(&got),
567+
},
568+
},
569+
Handler: (func(i *clibase.Invocation) error {
570+
_, _ = fmt.Fprintf(i.Stdout, "%s", got)
571+
return nil
572+
}),
573+
}
574+
575+
// Base case
576+
inv := cmd.Invoke()
577+
stdio := fakeIO(inv)
578+
err := inv.Run()
579+
require.NoError(t, err)
580+
require.Equal(t, "def.com", stdio.Stdout.String())
581+
582+
// Flag overrides
583+
inv = cmd.Invoke("--url", "good.com")
584+
stdio = fakeIO(inv)
585+
err = inv.Run()
532586
require.NoError(t, err)
587+
require.Equal(t, "good.com", stdio.Stdout.String())
533588

534-
inv := cmd().Invoke("--arr", "")
589+
// Env overrides
590+
inv = cmd.Invoke()
591+
inv.Environ.Set("URL", "good.com")
592+
stdio = fakeIO(inv)
535593
err = inv.Run()
536594
require.NoError(t, err)
595+
require.Equal(t, "good.com", stdio.Stdout.String())
537596
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy