Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 1071055

Browse files
committed
sql/(analyzer,parse,expression): implement case expression
Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
1 parent 92c65a5 commit 1071055

File tree

7 files changed

+572
-6
lines changed

7 files changed

+572
-6
lines changed

engine_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,30 @@ var queries = []struct {
776776
{int64(15)},
777777
},
778778
},
779+
{
780+
`SELECT CASE i WHEN 1 THEN 'one' WHEN 2 THEN 'two' ELSE 'other' END FROM mytable`,
781+
[]sql.Row{
782+
{"one"},
783+
{"two"},
784+
{"other"},
785+
},
786+
},
787+
{
788+
`SELECT CASE WHEN i > 2 THEN 'more than two' WHEN i < 2 THEN 'less than two' ELSE 'two' END FROM mytable`,
789+
[]sql.Row{
790+
{"less than two"},
791+
{"two"},
792+
{"more than two"},
793+
},
794+
},
795+
{
796+
`SELECT CASE i WHEN 1 THEN 'one' WHEN 2 THEN 'two' END FROM mytable`,
797+
[]sql.Row{
798+
{"one"},
799+
{"two"},
800+
{nil},
801+
},
802+
},
779803
}
780804

781805
func TestQueries(t *testing.T) {

sql/analyzer/validation_rules.go

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@ import (
1010
)
1111

1212
const (
13-
validateResolvedRule = "validate_resolved"
14-
validateOrderByRule = "validate_order_by"
15-
validateGroupByRule = "validate_group_by"
16-
validateSchemaSourceRule = "validate_schema_source"
17-
validateProjectTuplesRule = "validate_project_tuples"
18-
validateIndexCreationRule = "validate_index_creation"
13+
validateResolvedRule = "validate_resolved"
14+
validateOrderByRule = "validate_order_by"
15+
validateGroupByRule = "validate_group_by"
16+
validateSchemaSourceRule = "validate_schema_source"
17+
validateProjectTuplesRule = "validate_project_tuples"
18+
validateIndexCreationRule = "validate_index_creation"
19+
validateCaseResultTypesRule = "validate_case_result_types"
1920
)
2021

2122
var (
@@ -36,6 +37,12 @@ var (
3637
// ErrUnknownIndexColumns is returned when there are columns in the expr
3738
// to index that are unknown in the table.
3839
ErrUnknownIndexColumns = errors.NewKind("unknown columns to index for table %q: %s")
40+
// ErrCaseResultType is returned when one or more of the types of the values in
41+
// a case expression don't match.
42+
ErrCaseResultType = errors.NewKind(
43+
"expecting all case branches to return values of type %s, " +
44+
"but found value %q of type %s on %s",
45+
)
3946
)
4047

4148
// DefaultValidationRules to apply while analyzing nodes.
@@ -46,6 +53,7 @@ var DefaultValidationRules = []Rule{
4653
{validateSchemaSourceRule, validateSchemaSource},
4754
{validateProjectTuplesRule, validateProjectTuples},
4855
{validateIndexCreationRule, validateIndexCreation},
56+
{validateCaseResultTypesRule, validateCaseResultTypes},
4957
}
5058

5159
func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
@@ -199,6 +207,42 @@ func validateProjectTuples(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node,
199207
return n, nil
200208
}
201209

210+
func validateCaseResultTypes(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
211+
span, ctx := ctx.Span("validate_case_result_types")
212+
defer span.Finish()
213+
214+
var err error
215+
plan.InspectExpressions(n, func(e sql.Expression) bool {
216+
switch e := e.(type) {
217+
case *expression.Case:
218+
typ := e.Type()
219+
for _, b := range e.Branches {
220+
if b.Value.Type() != typ {
221+
err = ErrCaseResultType.New(typ, b.Value, b.Value.Type(), e)
222+
return false
223+
}
224+
}
225+
226+
if e.Else != nil {
227+
if e.Else.Type() != typ {
228+
err = ErrCaseResultType.New(typ, e.Else, e.Else.Type(), e)
229+
return false
230+
}
231+
}
232+
233+
return false
234+
default:
235+
return true
236+
}
237+
})
238+
239+
if err != nil {
240+
return nil, err
241+
}
242+
243+
return n, nil
244+
}
245+
202246
func stringContains(strs []string, target string) bool {
203247
for _, s := range strs {
204248
if s == target {

sql/analyzer/validation_rules_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,88 @@ func TestValidateIndexCreation(t *testing.T) {
349349
}
350350
}
351351

352+
func TestValidateCaseResultTypes(t *testing.T) {
353+
rule := getValidationRule(validateCaseResultTypesRule)
354+
355+
testCases := []struct {
356+
name string
357+
expr *expression.Case
358+
ok bool
359+
}{
360+
{
361+
"one of the branches does not match",
362+
expression.NewCase(
363+
expression.NewGetField(0, sql.Int64, "foo", false),
364+
[]expression.CaseBranch{
365+
{
366+
Cond: expression.NewLiteral(int64(1), sql.Int64),
367+
Value: expression.NewLiteral("foo", sql.Text),
368+
},
369+
{
370+
Cond: expression.NewLiteral(int64(2), sql.Int64),
371+
Value: expression.NewLiteral(int64(1), sql.Int64),
372+
},
373+
},
374+
expression.NewLiteral("foo", sql.Text),
375+
),
376+
false,
377+
},
378+
{
379+
"else does not match",
380+
expression.NewCase(
381+
expression.NewGetField(0, sql.Int64, "foo", false),
382+
[]expression.CaseBranch{
383+
{
384+
Cond: expression.NewLiteral(int64(1), sql.Int64),
385+
Value: expression.NewLiteral("foo", sql.Text),
386+
},
387+
{
388+
Cond: expression.NewLiteral(int64(2), sql.Int64),
389+
Value: expression.NewLiteral("bar", sql.Text),
390+
},
391+
},
392+
expression.NewLiteral(int64(1), sql.Int64),
393+
),
394+
false,
395+
},
396+
{
397+
"all ok",
398+
expression.NewCase(
399+
expression.NewGetField(0, sql.Int64, "foo", false),
400+
[]expression.CaseBranch{
401+
{
402+
Cond: expression.NewLiteral(int64(1), sql.Int64),
403+
Value: expression.NewLiteral("foo", sql.Text),
404+
},
405+
{
406+
Cond: expression.NewLiteral(int64(2), sql.Int64),
407+
Value: expression.NewLiteral("bar", sql.Text),
408+
},
409+
},
410+
expression.NewLiteral("baz", sql.Text),
411+
),
412+
true,
413+
},
414+
}
415+
416+
for _, tt := range testCases {
417+
t.Run(tt.name, func(t *testing.T) {
418+
require := require.New(t)
419+
_, err := rule.Apply(sql.NewEmptyContext(), nil, plan.NewProject(
420+
[]sql.Expression{tt.expr},
421+
plan.NewResolvedTable(dualTable),
422+
))
423+
424+
if tt.ok {
425+
require.NoError(err)
426+
} else {
427+
require.Error(err)
428+
require.True(ErrCaseResultType.Is(err))
429+
}
430+
})
431+
}
432+
}
433+
352434
type dummyNode struct{ resolved bool }
353435

354436
func (n dummyNode) String() string { return "dummynode" }

sql/expression/case.go

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
package expression
2+
3+
import (
4+
"bytes"
5+
6+
"gopkg.in/src-d/go-mysql-server.v0/sql"
7+
)
8+
9+
// CaseBranch is a single branch of a case expression.
10+
type CaseBranch struct {
11+
Cond sql.Expression
12+
Value sql.Expression
13+
}
14+
15+
// Case is an expression that returns the value of one of its branches when a
16+
// condition is met.
17+
type Case struct {
18+
Expr sql.Expression
19+
Branches []CaseBranch
20+
Else sql.Expression
21+
}
22+
23+
// NewCase returns an new Case expression.
24+
func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression) *Case {
25+
return &Case{expr, branches, elseExpr}
26+
}
27+
28+
// Type implements the sql.Expression interface.
29+
func (c *Case) Type() sql.Type {
30+
for _, b := range c.Branches {
31+
return b.Value.Type()
32+
}
33+
return c.Else.Type()
34+
}
35+
36+
// IsNullable implements the sql.Expression interface.
37+
func (c *Case) IsNullable() bool {
38+
for _, b := range c.Branches {
39+
if b.Value.IsNullable() {
40+
return true
41+
}
42+
}
43+
44+
return c.Else == nil || c.Else.IsNullable()
45+
}
46+
47+
// Resolved implements the sql.Expression interface.
48+
func (c *Case) Resolved() bool {
49+
if (c.Expr != nil && !c.Expr.Resolved()) ||
50+
(c.Else != nil && !c.Else.Resolved()) {
51+
return false
52+
}
53+
54+
for _, b := range c.Branches {
55+
if !b.Cond.Resolved() || !b.Value.Resolved() {
56+
return false
57+
}
58+
}
59+
60+
return true
61+
}
62+
63+
// Children implements the sql.Expression interface.
64+
func (c *Case) Children() []sql.Expression {
65+
var children []sql.Expression
66+
67+
if c.Expr != nil {
68+
children = append(children, c.Expr)
69+
}
70+
71+
for _, b := range c.Branches {
72+
children = append(children, b.Cond, b.Value)
73+
}
74+
75+
if c.Else != nil {
76+
children = append(children, c.Else)
77+
}
78+
79+
return children
80+
}
81+
82+
// Eval implements the sql.Expression interface.
83+
func (c *Case) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
84+
span, ctx := ctx.Span("expression.Case")
85+
defer span.Finish()
86+
87+
var expr interface{}
88+
var err error
89+
if c.Expr != nil {
90+
expr, err = c.Expr.Eval(ctx, row)
91+
if err != nil {
92+
return nil, err
93+
}
94+
}
95+
96+
for _, b := range c.Branches {
97+
var cond sql.Expression
98+
if expr != nil {
99+
cond = NewEquals(NewLiteral(expr, c.Expr.Type()), b.Cond)
100+
} else {
101+
cond = b.Cond
102+
}
103+
104+
v, err := cond.Eval(ctx, row)
105+
if err != nil {
106+
return nil, err
107+
}
108+
109+
v, err = sql.Boolean.Convert(v)
110+
if err != nil {
111+
return nil, err
112+
}
113+
114+
if v == true {
115+
return b.Value.Eval(ctx, row)
116+
}
117+
}
118+
119+
if c.Else != nil {
120+
return c.Else.Eval(ctx, row)
121+
}
122+
123+
return nil, nil
124+
}
125+
126+
// TransformUp implements the sql.Expression interface.
127+
func (c *Case) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
128+
var expr sql.Expression
129+
var err error
130+
131+
if c.Expr != nil {
132+
expr, err = c.Expr.TransformUp(f)
133+
if err != nil {
134+
return nil, err
135+
}
136+
}
137+
138+
var branches []CaseBranch
139+
for _, b := range c.Branches {
140+
var nb CaseBranch
141+
142+
nb.Cond, err = b.Cond.TransformUp(f)
143+
if err != nil {
144+
return nil, err
145+
}
146+
147+
nb.Value, err = b.Value.TransformUp(f)
148+
if err != nil {
149+
return nil, err
150+
}
151+
152+
branches = append(branches, nb)
153+
}
154+
155+
var elseExpr sql.Expression
156+
if c.Else != nil {
157+
elseExpr, err = c.Else.TransformUp(f)
158+
if err != nil {
159+
return nil, err
160+
}
161+
}
162+
163+
return f(NewCase(expr, branches, elseExpr))
164+
}
165+
166+
func (c *Case) String() string {
167+
var buf bytes.Buffer
168+
169+
buf.WriteString("CASE ")
170+
if c.Expr != nil {
171+
buf.WriteString(c.Expr.String())
172+
}
173+
174+
for _, b := range c.Branches {
175+
buf.WriteString(" WHEN ")
176+
buf.WriteString(b.Cond.String())
177+
buf.WriteString(" THEN ")
178+
buf.WriteString(b.Value.String())
179+
}
180+
181+
if c.Else != nil {
182+
buf.WriteString(" ELSE ")
183+
buf.WriteString(c.Else.String())
184+
}
185+
186+
buf.WriteString(" END")
187+
return buf.String()
188+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy