diff --git a/engine_test.go b/engine_test.go index 7864ed426..e507f0d6d 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1230,6 +1230,10 @@ var queries = []struct { `SELECT (NULL+1)`, []sql.Row{{nil}}, }, + { + `SELECT * FROM mytable WHERE NULL AND i = 3`, + []sql.Row{}, + }, } func TestQueries(t *testing.T) { diff --git a/sql/core.go b/sql/core.go index 57e579764..5d7fe82f2 100644 --- a/sql/core.go +++ b/sql/core.go @@ -3,6 +3,8 @@ package sql // import "github.com/src-d/go-mysql-server/sql" import ( "fmt" "io" + "math" + "time" "gopkg.in/src-d/go-errors.v1" ) @@ -218,3 +220,27 @@ type Lockable interface { // available. Unlock(ctx *Context, id uint32) error } + +// EvaluateCondition evaluates a condition, which is an expression whose value +// will be coerced to boolean. +func EvaluateCondition(ctx *Context, cond Expression, row Row) (bool, error) { + v, err := cond.Eval(ctx, row) + if err != nil { + return false, err + } + + switch b := v.(type) { + case bool: + return b, nil + case int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8: + return b != 0, nil + case time.Duration: + return int64(b) != 0, nil + case time.Time: + return b.UnixNano() != 0, nil + case float32, float64: + return int(math.Round(v.(float64))) != 0, nil + default: + return false, nil + } +} diff --git a/sql/plan/filter.go b/sql/plan/filter.go index ac2a4772a..a86c51eab 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -107,12 +107,12 @@ func (i *FilterIter) Next() (sql.Row, error) { return nil, err } - result, err := i.cond.Eval(i.ctx, row) + ok, err := sql.EvaluateCondition(i.ctx, i.cond, row) if err != nil { return nil, err } - if result == true { + if ok { return row, nil } } diff --git a/sql/type.go b/sql/type.go index e082abd17..37c483785 100644 --- a/sql/type.go +++ b/sql/type.go @@ -616,25 +616,13 @@ func (t booleanT) Convert(v interface{}) (interface{}, error) { case bool: return b, nil case int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8: - if b != 0 { - return true, nil - } - return false, nil + return b != 0, nil case time.Duration: - if int64(b) != 0 { - return true, nil - } - return false, nil + return int64(b) != 0, nil case time.Time: - if b.UnixNano() != 0 { - return true, nil - } - return false, nil + return b.UnixNano() != 0, nil case float32, float64: - if int(math.Round(v.(float64))) != 0 { - return true, nil - } - return false, nil + return int(math.Round(v.(float64))) != 0, nil case string: return false, fmt.Errorf("unable to cast string to bool")
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: