Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

sql/(analyzer,parse,expression): implement case expression #576

Merged
merged 1 commit into from
Jan 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,30 @@ var queries = []struct {
{int64(15)},
},
},
{
`SELECT CASE i WHEN 1 THEN 'one' WHEN 2 THEN 'two' ELSE 'other' END FROM mytable`,
[]sql.Row{
{"one"},
{"two"},
{"other"},
},
},
{
`SELECT CASE WHEN i > 2 THEN 'more than two' WHEN i < 2 THEN 'less than two' ELSE 'two' END FROM mytable`,
[]sql.Row{
{"less than two"},
{"two"},
{"more than two"},
},
},
{
`SELECT CASE i WHEN 1 THEN 'one' WHEN 2 THEN 'two' END FROM mytable`,
[]sql.Row{
{"one"},
{"two"},
{nil},
},
},
}

func TestQueries(t *testing.T) {
Expand Down
56 changes: 50 additions & 6 deletions sql/analyzer/validation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ import (
)

const (
validateResolvedRule = "validate_resolved"
validateOrderByRule = "validate_order_by"
validateGroupByRule = "validate_group_by"
validateSchemaSourceRule = "validate_schema_source"
validateProjectTuplesRule = "validate_project_tuples"
validateIndexCreationRule = "validate_index_creation"
validateResolvedRule = "validate_resolved"
validateOrderByRule = "validate_order_by"
validateGroupByRule = "validate_group_by"
validateSchemaSourceRule = "validate_schema_source"
validateProjectTuplesRule = "validate_project_tuples"
validateIndexCreationRule = "validate_index_creation"
validateCaseResultTypesRule = "validate_case_result_types"
)

var (
Expand All @@ -36,6 +37,12 @@ var (
// ErrUnknownIndexColumns is returned when there are columns in the expr
// to index that are unknown in the table.
ErrUnknownIndexColumns = errors.NewKind("unknown columns to index for table %q: %s")
// ErrCaseResultType is returned when one or more of the types of the values in
// a case expression don't match.
ErrCaseResultType = errors.NewKind(
"expecting all case branches to return values of type %s, " +
"but found value %q of type %s on %s",
)
)

// DefaultValidationRules to apply while analyzing nodes.
Expand All @@ -46,6 +53,7 @@ var DefaultValidationRules = []Rule{
{validateSchemaSourceRule, validateSchemaSource},
{validateProjectTuplesRule, validateProjectTuples},
{validateIndexCreationRule, validateIndexCreation},
{validateCaseResultTypesRule, validateCaseResultTypes},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just open question - doesn't make sense as a optimization to squash some rules?
For instance instead of 3 times looping, we may have one loop with some cases - loud thinking.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this query phase is really fast compared with the rest. The idea of having separated rules is the synergy that we can obtain executing them again and again, interacting with other rules.

}

func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
Expand Down Expand Up @@ -199,6 +207,42 @@ func validateProjectTuples(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node,
return n, nil
}

func validateCaseResultTypes(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
span, ctx := ctx.Span("validate_case_result_types")
defer span.Finish()

var err error
plan.InspectExpressions(n, func(e sql.Expression) bool {
switch e := e.(type) {
case *expression.Case:
typ := e.Type()
for _, b := range e.Branches {
if b.Value.Type() != typ {
err = ErrCaseResultType.New(typ, b.Value, b.Value.Type(), e)
return false
}
}

if e.Else != nil {
if e.Else.Type() != typ {
err = ErrCaseResultType.New(typ, e.Else, e.Else.Type(), e)
return false
}
}

return false
default:
return true
}
})

if err != nil {
return nil, err
}

return n, nil
}

func stringContains(strs []string, target string) bool {
for _, s := range strs {
if s == target {
Expand Down
82 changes: 82 additions & 0 deletions sql/analyzer/validation_rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,88 @@ func TestValidateIndexCreation(t *testing.T) {
}
}

func TestValidateCaseResultTypes(t *testing.T) {
rule := getValidationRule(validateCaseResultTypesRule)

testCases := []struct {
name string
expr *expression.Case
ok bool
}{
{
"one of the branches does not match",
expression.NewCase(
expression.NewGetField(0, sql.Int64, "foo", false),
[]expression.CaseBranch{
{
Cond: expression.NewLiteral(int64(1), sql.Int64),
Value: expression.NewLiteral("foo", sql.Text),
},
{
Cond: expression.NewLiteral(int64(2), sql.Int64),
Value: expression.NewLiteral(int64(1), sql.Int64),
},
},
expression.NewLiteral("foo", sql.Text),
),
false,
},
{
"else does not match",
expression.NewCase(
expression.NewGetField(0, sql.Int64, "foo", false),
[]expression.CaseBranch{
{
Cond: expression.NewLiteral(int64(1), sql.Int64),
Value: expression.NewLiteral("foo", sql.Text),
},
{
Cond: expression.NewLiteral(int64(2), sql.Int64),
Value: expression.NewLiteral("bar", sql.Text),
},
},
expression.NewLiteral(int64(1), sql.Int64),
),
false,
},
{
"all ok",
expression.NewCase(
expression.NewGetField(0, sql.Int64, "foo", false),
[]expression.CaseBranch{
{
Cond: expression.NewLiteral(int64(1), sql.Int64),
Value: expression.NewLiteral("foo", sql.Text),
},
{
Cond: expression.NewLiteral(int64(2), sql.Int64),
Value: expression.NewLiteral("bar", sql.Text),
},
},
expression.NewLiteral("baz", sql.Text),
),
true,
},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
_, err := rule.Apply(sql.NewEmptyContext(), nil, plan.NewProject(
[]sql.Expression{tt.expr},
plan.NewResolvedTable(dualTable),
))

if tt.ok {
require.NoError(err)
} else {
require.Error(err)
require.True(ErrCaseResultType.Is(err))
}
})
}
}

type dummyNode struct{ resolved bool }

func (n dummyNode) String() string { return "dummynode" }
Expand Down
188 changes: 188 additions & 0 deletions sql/expression/case.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
package expression

import (
"bytes"

"gopkg.in/src-d/go-mysql-server.v0/sql"
)

// CaseBranch is a single branch of a case expression.
type CaseBranch struct {
Cond sql.Expression
Value sql.Expression
}

// Case is an expression that returns the value of one of its branches when a
// condition is met.
type Case struct {
Expr sql.Expression
Branches []CaseBranch
Else sql.Expression
}

// NewCase returns an new Case expression.
func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression) *Case {
return &Case{expr, branches, elseExpr}
}

// Type implements the sql.Expression interface.
func (c *Case) Type() sql.Type {
for _, b := range c.Branches {
return b.Value.Type()
}
return c.Else.Type()
}

// IsNullable implements the sql.Expression interface.
func (c *Case) IsNullable() bool {
for _, b := range c.Branches {
if b.Value.IsNullable() {
return true
}
}

return c.Else == nil || c.Else.IsNullable()
}

// Resolved implements the sql.Expression interface.
func (c *Case) Resolved() bool {
if (c.Expr != nil && !c.Expr.Resolved()) ||
(c.Else != nil && !c.Else.Resolved()) {
return false
}

for _, b := range c.Branches {
if !b.Cond.Resolved() || !b.Value.Resolved() {
return false
}
}

return true
}

// Children implements the sql.Expression interface.
func (c *Case) Children() []sql.Expression {
var children []sql.Expression

if c.Expr != nil {
children = append(children, c.Expr)
}

for _, b := range c.Branches {
children = append(children, b.Cond, b.Value)
}

if c.Else != nil {
children = append(children, c.Else)
}

return children
}

// Eval implements the sql.Expression interface.
func (c *Case) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
span, ctx := ctx.Span("expression.Case")
defer span.Finish()

var expr interface{}
var err error
if c.Expr != nil {
expr, err = c.Expr.Eval(ctx, row)
if err != nil {
return nil, err
}
}

for _, b := range c.Branches {
var cond sql.Expression
if expr != nil {
cond = NewEquals(NewLiteral(expr, c.Expr.Type()), b.Cond)
} else {
cond = b.Cond
}

v, err := cond.Eval(ctx, row)
if err != nil {
return nil, err
}

v, err = sql.Boolean.Convert(v)
if err != nil {
return nil, err
}

if v == true {
return b.Value.Eval(ctx, row)
}
}

if c.Else != nil {
return c.Else.Eval(ctx, row)
}

return nil, nil
}

// TransformUp implements the sql.Expression interface.
func (c *Case) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
var expr sql.Expression
var err error

if c.Expr != nil {
expr, err = c.Expr.TransformUp(f)
if err != nil {
return nil, err
}
}

var branches []CaseBranch
for _, b := range c.Branches {
var nb CaseBranch

nb.Cond, err = b.Cond.TransformUp(f)
if err != nil {
return nil, err
}

nb.Value, err = b.Value.TransformUp(f)
if err != nil {
return nil, err
}

branches = append(branches, nb)
}

var elseExpr sql.Expression
if c.Else != nil {
elseExpr, err = c.Else.TransformUp(f)
if err != nil {
return nil, err
}
}

return f(NewCase(expr, branches, elseExpr))
}

func (c *Case) String() string {
var buf bytes.Buffer

buf.WriteString("CASE ")
if c.Expr != nil {
buf.WriteString(c.Expr.String())
}

for _, b := range c.Branches {
buf.WriteString(" WHEN ")
buf.WriteString(b.Cond.String())
buf.WriteString(" THEN ")
buf.WriteString(b.Value.String())
}

if c.Else != nil {
buf.WriteString(" ELSE ")
buf.WriteString(c.Else.String())
}

buf.WriteString(" END")
return buf.String()
}
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy