diff --git a/logger/create.go b/logger/create.go index 315435fc..040a4ec8 100644 --- a/logger/create.go +++ b/logger/create.go @@ -3,6 +3,7 @@ package logger import ( "fmt" "os" + "path/filepath" "strings" "time" @@ -106,7 +107,7 @@ func New(opts ...Option) (Service, error) { l := NewOutputWriter(SharedWriteManager) if config.logFileDirectory != "" { - l.Add(NewFileRollingWriter(config.logFileDirectory, + l.Add(NewFileRollingWriter(SanitizeLogPath(config.logFileDirectory), "cloudflared", int64(config.maxFileSize), config.maxFileCount), @@ -139,3 +140,13 @@ func ParseLevelString(lvl string) ([]Level, error) { } return []Level{}, fmt.Errorf("not a valid log level: %q", lvl) } + +// SanitizeLogPath checks that the logger log path +func SanitizeLogPath(path string) string { + newPath := strings.TrimSpace(path) + // make sure it has a log file extension and is not a directory + if filepath.Ext(newPath) != ".log" && !(isDirectory(newPath) || strings.HasSuffix(newPath, "/")) { + newPath = newPath + ".log" + } + return newPath +} diff --git a/logger/create_test.go b/logger/create_test.go new file mode 100644 index 00000000..a0617351 --- /dev/null +++ b/logger/create_test.go @@ -0,0 +1,46 @@ +package logger + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLogLevelParse(t *testing.T) { + lvls, err := ParseLevelString("fatal") + assert.NoError(t, err) + assert.Equal(t, []Level{FatalLevel}, lvls) + + lvls, err = ParseLevelString("error") + assert.NoError(t, err) + assert.Equal(t, []Level{FatalLevel, ErrorLevel}, lvls) + + lvls, err = ParseLevelString("info") + assert.NoError(t, err) + assert.Equal(t, []Level{FatalLevel, ErrorLevel, InfoLevel}, lvls) + + lvls, err = ParseLevelString("info") + assert.NoError(t, err) + assert.Equal(t, []Level{FatalLevel, ErrorLevel, InfoLevel}, lvls) + + lvls, err = ParseLevelString("warn") + assert.NoError(t, err) + assert.Equal(t, []Level{FatalLevel, ErrorLevel, InfoLevel}, lvls) + + lvls, err = ParseLevelString("debug") + assert.NoError(t, err) + assert.Equal(t, []Level{FatalLevel, ErrorLevel, InfoLevel, DebugLevel}, lvls) + + _, err = ParseLevelString("blah") + assert.Error(t, err) + + _, err = ParseLevelString("") + assert.Error(t, err) +} + +func TestPathSanitizer(t *testing.T) { + assert.Equal(t, "somebad/path/log.bat.log", SanitizeLogPath("\t somebad/path/log.bat\n\n")) + assert.Equal(t, "proper/path/cloudflared.log", SanitizeLogPath("proper/path/cloudflared.log")) + assert.Equal(t, "proper/path/", SanitizeLogPath("proper/path/")) + assert.Equal(t, "proper/path/cloudflared.log", SanitizeLogPath("\tproper/path/cloudflared\n\n")) +}