youtubebeat/vendor/github.com/elastic/beats/libbeat/common/flags_test.go

195 lines
5.1 KiB
Go

// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package common
import (
"bytes"
"flag"
"fmt"
"io"
"os"
"strings"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
// TestStringArrFlag exercises StringArrVarFlag registered on a standard
// flag.FlagSet. For every table entry it checks the usage line rendered by
// the flag package and by cobra (both captured before any Set call, so they
// show the default), then applies the inputs via Set and compares both the
// bound slice and the flag's List() against the expected values.
func TestStringArrFlag(t *testing.T) {
	cases := []struct {
		init     []string
		def      string
		in       []string
		expected []string
	}{
		{nil, "test", nil, []string{"test"}},
		{nil, "test", []string{"new"}, []string{"new"}},
		{nil, "test", []string{"a", "b"}, []string{"a", "b"}},
		{[]string{"default"}, "newdefault", nil, []string{"newdefault"}},
		{[]string{"default"}, "newdefault", []string{"arg"}, []string{"arg"}},
		{[]string{"default"}, "newdefault", []string{"a", "b"}, []string{"a", "b"}},
		{[]string{"default"}, "newdefault", []string{"a", "b", "a", "b"}, []string{"a", "b"}},
	}

	for _, tc := range cases {
		tc := tc // capture the range variable for the subtest closure
		label := fmt.Sprintf("init=%v,default=%v,in=%v,out=%v", tc.init, tc.def, tc.in, tc.expected)
		t.Run(label, func(t *testing.T) {
			// Bind the flag to a private copy so the table entry is never mutated.
			values := make([]string, len(tc.init))
			copy(values, tc.init)

			fs := flag.NewFlagSet("test", flag.ContinueOnError)
			sf := StringArrVarFlag(fs, &values, "a", "add")
			if tc.def != "" {
				sf.SetDefault(tc.def)
			}

			// Capture the rendered usage before Set is called so it reflects the default.
			defaultValue := sf.String()
			goflagUsage, _ := withStderr(fs.PrintDefaults)
			goflagExpectedUsage := fmt.Sprintf(" -a value\n \tadd (default %v)\n", defaultValue)

			var cmd cobra.Command
			cmd.PersistentFlags().AddGoFlag(fs.Lookup("a"))
			cobraUsage := cmd.LocalFlags().FlagUsages()
			cobraExpectedUsage := fmt.Sprintf(" -a, --a string add (default \"%v\")\n", defaultValue)

			for _, arg := range tc.in {
				if err := sf.Set(arg); err != nil {
					t.Error(err)
				}
			}

			assert.Equal(t, goflagExpectedUsage, goflagUsage)
			assert.Equal(t, cobraExpectedUsage, cobraUsage)
			assert.Equal(t, tc.expected, values)
			assert.Equal(t, tc.expected, sf.List())
		})
	}
}
// TestSettingsFlag exercises a SettingsFlag bound to a fresh config. For each
// table entry it checks the usage text rendered by the flag package and by
// cobra, applies every argument through Set ("key=value" or a bare "key"),
// unpacks the resulting config into a map and compares it with the
// expectation from the table.
func TestSettingsFlag(t *testing.T) {
	cases := []struct {
		in       []string
		expected map[string]interface{}
	}{
		{nil, nil},
		{[]string{"a=1"}, map[string]interface{}{"a": uint64(1)}},
		{[]string{"a=1", "b=false"}, map[string]interface{}{"a": uint64(1), "b": false}},
		{[]string{"a=1", "b"}, map[string]interface{}{"a": uint64(1), "b": true}},
		{[]string{"a=1", "c=${a}"}, map[string]interface{}{"a": uint64(1), "c": uint64(1)}},
	}

	for _, tc := range cases {
		tc := tc // capture the range variable for the subtest closure
		label := strings.Join(tc.in, ",")
		t.Run(label, func(t *testing.T) {
			config := NewConfig()
			sf := NewSettingsFlag(config)
			fs := flag.NewFlagSet("test", flag.ContinueOnError)
			fs.Var(sf, "s", "message")

			goflagUsage, _ := withStderr(fs.PrintDefaults)
			goflagExpectedUsage := " -s value\n \tmessage\n"

			var cmd cobra.Command
			cmd.PersistentFlags().AddGoFlag(fs.Lookup("s"))
			cobraUsage := cmd.LocalFlags().FlagUsages()
			cobraExpectedUsage := " -s, --s setting=value message\n"

			for _, arg := range tc.in {
				if err := sf.Set(arg); err != nil {
					t.Error(err)
				}
			}

			var result map[string]interface{}
			if err := config.Unpack(&result); err != nil {
				t.Fatal(err)
			}

			assert.Equal(t, goflagExpectedUsage, goflagUsage)
			assert.Equal(t, cobraExpectedUsage, cobraUsage)
			assert.Equal(t, tc.expected, result)
		})
	}
}
// TestOverwriteFlag checks that a flag registered through ConfigOverwriteFlag
// renders the expected usage text through both the flag package and cobra,
// and that setting the flag replaces the existing "a" value in the config.
func TestOverwriteFlag(t *testing.T) {
	config, err := NewConfigFrom(map[string]interface{}{
		"a": "test",
	})
	if err != nil {
		// Report through the testing framework rather than panicking, which
		// would abort the whole test binary instead of failing this test.
		t.Fatal(err)
	}

	fs := flag.NewFlagSet("test", flag.ContinueOnError)
	ConfigOverwriteFlag(fs, config, "a", "a", "", "message")

	goflagUsage, err := withStderr(fs.PrintDefaults)
	if err != nil {
		t.Fatal(err)
	}
	goflagExpectedUsage := " -a value\n \tmessage\n"
	assert.Equal(t, goflagExpectedUsage, goflagUsage)

	cmd := cobra.Command{}
	cmd.PersistentFlags().AddGoFlag(fs.Lookup("a"))
	cobraUsage := cmd.LocalFlags().FlagUsages()
	cobraExpectedUsage := " -a, --a string message\n"
	assert.Equal(t, cobraExpectedUsage, cobraUsage)

	// Fail right away if Set fails; otherwise the failure would only show up
	// indirectly as a value mismatch in the assertions below.
	if err := fs.Set("a", "overwrite"); err != nil {
		t.Fatal(err)
	}
	final, err := config.String("a", -1)
	assert.NoError(t, err)
	assert.Equal(t, "overwrite", final)
}
// withStderr runs fn while os.Stderr is redirected to a pipe and returns
// everything fn wrote to it.
//
// The returned error is non-nil if the pipe could not be created or if
// reading the captured output failed. os.Stderr is restored and the write
// end of the pipe is closed even if fn panics. This guarantees the reader
// goroutine reaches EOF and exits instead of blocking forever. The panic
// itself still propagates to the caller.
//
// Only the os.Stderr package variable is swapped. Code that cached
// os.Stderr before this call keeps writing to the original stream.
func withStderr(fn func()) (string, error) {
	r, w, err := os.Pipe()
	if err != nil {
		return "", err
	}
	defer r.Close()

	type captured struct {
		out string
		err error
	}
	// Buffered so the reader can always deliver its result and exit, even
	// when nobody receives it (e.g. because fn panicked).
	done := make(chan captured, 1)
	go func() {
		// Drain concurrently so fn never blocks on a full pipe buffer.
		var buf bytes.Buffer
		_, copyErr := io.Copy(&buf, r)
		done <- captured{out: buf.String(), err: copyErr}
	}()

	saved := os.Stderr
	os.Stderr = w
	func() {
		defer func() {
			os.Stderr = saved
			// Closing the write end delivers EOF to io.Copy above.
			w.Close()
		}()
		fn()
	}()

	res := <-done
	return res.out, res.err
}