package sqlmock import ( "fmt" "regexp" "strings" ) var re = regexp.MustCompile("\\s+") // strip out new lines and trim spaces func stripQuery(q string) (s string) { return strings.TrimSpace(re.ReplaceAllString(q, " ")) } // QueryMatcher is an SQL query string matcher interface, // which can be used to customize validation of SQL query strings. // As an example, external library could be used to build // and validate SQL ast, columns selected. // // sqlmock can be customized to implement a different QueryMatcher // configured through an option when sqlmock.New or sqlmock.NewWithDSN // is called, default QueryMatcher is QueryMatcherRegexp. type QueryMatcher interface { // Match expected SQL query string without whitespace to // actual SQL. Match(expectedSQL, actualSQL string) error } // QueryMatcherFunc type is an adapter to allow the use of // ordinary functions as QueryMatcher. If f is a function // with the appropriate signature, QueryMatcherFunc(f) is a // QueryMatcher that calls f. type QueryMatcherFunc func(expectedSQL, actualSQL string) error // Match implements the QueryMatcher func (f QueryMatcherFunc) Match(expectedSQL, actualSQL string) error { return f(expectedSQL, actualSQL) } // QueryMatcherRegexp is the default SQL query matcher // used by sqlmock. It parses expectedSQL to a regular // expression and attempts to match actualSQL. var QueryMatcherRegexp QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error { expect := stripQuery(expectedSQL) actual := stripQuery(actualSQL) re, err := regexp.Compile(expect) if err != nil { return err } if !re.MatchString(actual) { return fmt.Errorf(`could not match actual sql: "%s" with expected regexp "%s"`, actual, re.String()) } return nil }) // QueryMatcherEqual is the SQL query matcher // which simply tries a case sensitive match of // expected and actual SQL strings without whitespace. var QueryMatcherEqual QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error { expect := stripQuery(expectedSQL) actual := stripQuery(actualSQL) if actual != expect { return fmt.Errorf(`actual sql: "%s" does not equal to expected "%s"`, actual, expect) } return nil })