Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 9ee3f1d

Browse files
authored
Merge pull request #555 from kuba--/fix-546/null
Implement ifnull and nullif functions.
2 parents 5a108b2 + 88d2d71 commit 9ee3f1d

File tree

9 files changed

+366
-4
lines changed

9 files changed

+366
-4
lines changed

README.md

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,21 @@ go get gopkg.in/src-d/go-mysql-server.v0
5050

5151
We are continuously adding more functionality to go-mysql-server. We support a subset of what is supported in MySQL, to see what is currently included check the [SUPPORTED](./SUPPORTED.md) file.
5252

53-
# Third-party clients
53+
## Third-party clients
5454

5555
We support and actively test against certain third-party clients to ensure compatibility between them and go-mysql-server. You can check out the list of supported third party clients in the [SUPPORTED_CLIENTS](./SUPPORTED_CLIENTS.md) file along with some examples on how to connect to go-mysql-server using them.
5656

5757
## Custom functions
5858

59+
- `COUNT(expr)`: Returns a count of the number of non-NULL values of expr in the rows retrieved by a SELECT statement.
60+
- `MIN(expr)`: Returns the minimum value of expr.
61+
- `MAX(expr)`: Returns the maximum value of expr.
62+
- `AVG(expr)`: Returns the average value of expr.
63+
- `SUM(expr)`: Returns the sum of expr.
5964
- `IS_BINARY(blob)`: Returns whether a BLOB is a binary file or not.
60-
- `SUBSTRING(str, pos)`, `SUBSTRING(str, pos, len)`: Return a substring from the provided string.
65+
- `SUBSTRING(str, pos)`, `SUBSTRING(str, pos, len)` : Return a substring from the provided string.
66+
- `SUBSTR(str, pos)`, `SUBSTR(str, pos, len)` : Return a substring from the provided string.
67+
- `MID(str, pos)`, `MID(str, pos, len)` : Return a substring from the provided string.
6168
- Date and Timestamp functions: `YEAR(date)`, `MONTH(date)`, `DAY(date)`, `WEEKDAY(date)`, `HOUR(date)`, `MINUTE(date)`, `SECOND(date)`, `DAYOFWEEK(date)`, `DAYOFYEAR(date)`.
6269
- `ARRAY_LENGTH(json)`: If the json representation is an array, this function returns its size.
6370
- `SPLIT(str,sep)`: Receives a string and a separator and returns the parts of the string split by the separator as a JSON array of strings.
@@ -70,6 +77,23 @@ We support and actively test against certain third-party clients to ensure compa
7077
- `ROUND(number, decimals)`: Round the `number` to `decimals` decimal places.
7178
- `CONNECTION_ID()`: Return the current connection ID.
7279
- `SOUNDEX(str)`: Returns the soundex of a string.
80+
- `JSON_EXTRACT(json_doc, path, ...)`: Extracts data from a json document using json paths.
81+
- `LN(X)`: Return the natural logarithm of X.
82+
- `LOG2(X)`: Returns the base-2 logarithm of X.
83+
- `LOG10(X)`: Returns the base-10 logarithm of X.
84+
- `LOG(X), LOG(B, X)`: If called with one parameter, this function returns the natural logarithm of X. If called with two parameters, this function returns the logarithm of X to the base B. If X is less than or equal to 0, or if B is less than or equal to 1, then NULL is returned.
85+
- `RPAD(str, len, padstr)`: Returns the string str, right-padded with the string padstr to a length of len characters.
86+
- `LPAD(str, len, padstr)`: Return the string argument, left-padded with the specified string.
87+
- `SQRT(X)`: Returns the square root of a nonnegative number X.
88+
- `POW(X, Y)`, `POWER(X, Y)`: Returns the value of X raised to the power of Y.
89+
- `TRIM(str)`: Returns the string str with all spaces removed.
90+
- `LTRIM(str)`: Returns the string str with leading space characters removed.
91+
- `RTRIM(str)`: Returns the string str with trailing space characters removed.
92+
- `REVERSE(str)`: Returns the string str with the order of the characters reversed.
93+
- `REPEAT(str, count)`: Returns a string consisting of the string str repeated count times.
94+
- `REPLACE(str,from_str,to_str)`: Returns the string str with all occurrences of the string from_str replaced by the string to_str.
95+
- `IFNULL(expr1, expr2)`: If expr1 is not NULL, IFNULL() returns expr1; otherwise it returns expr2.
96+
- `NULLIF(expr1, expr2)`: Returns NULL if expr1 = expr2 is true, otherwise returns expr1.
7397

7498
## Example
7599

engine_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,72 @@ var queries = []struct {
700700
{"tabletest", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil},
701701
},
702702
},
703+
{
704+
`SELECT NULL`,
705+
[]sql.Row{
706+
{nil},
707+
},
708+
},
709+
{
710+
`SELECT nullif('abc', NULL)`,
711+
[]sql.Row{
712+
{"abc"},
713+
},
714+
},
715+
{
716+
`SELECT nullif(NULL, NULL)`,
717+
[]sql.Row{
718+
{sql.Null},
719+
},
720+
},
721+
{
722+
`SELECT nullif(NULL, 123)`,
723+
[]sql.Row{
724+
{nil},
725+
},
726+
},
727+
{
728+
`SELECT nullif(123, 123)`,
729+
[]sql.Row{
730+
{sql.Null},
731+
},
732+
},
733+
{
734+
`SELECT nullif(123, 321)`,
735+
[]sql.Row{
736+
{int64(123)},
737+
},
738+
},
739+
{
740+
`SELECT ifnull(123, NULL)`,
741+
[]sql.Row{
742+
{int64(123)},
743+
},
744+
},
745+
{
746+
`SELECT ifnull(NULL, NULL)`,
747+
[]sql.Row{
748+
{nil},
749+
},
750+
},
751+
{
752+
`SELECT ifnull(NULL, 123)`,
753+
[]sql.Row{
754+
{int64(123)},
755+
},
756+
},
757+
{
758+
`SELECT ifnull(123, 123)`,
759+
[]sql.Row{
760+
{int64(123)},
761+
},
762+
},
763+
{
764+
`SELECT ifnull(123, 321)`,
765+
[]sql.Row{
766+
{int64(123)},
767+
},
768+
},
703769
}
704770

705771
func TestQueries(t *testing.T) {

sql/expression/function/ifnull.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package function
2+
3+
import (
4+
"fmt"
5+
6+
"gopkg.in/src-d/go-mysql-server.v0/sql"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
8+
)
9+
10+
// IfNull function returns the specified value IF the expression is NULL, otherwise return the expression.
11+
type IfNull struct {
12+
expression.BinaryExpression
13+
}
14+
15+
// NewIfNull returns a new IFNULL UDF
16+
func NewIfNull(ex, value sql.Expression) sql.Expression {
17+
return &IfNull{
18+
expression.BinaryExpression{
19+
Left: ex,
20+
Right: value,
21+
},
22+
}
23+
}
24+
25+
// Eval implements the Expression interface.
26+
func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
27+
left, err := f.Left.Eval(ctx, row)
28+
if err != nil {
29+
return nil, err
30+
}
31+
if left != nil {
32+
return left, nil
33+
}
34+
35+
right, err := f.Right.Eval(ctx, row)
36+
if err != nil {
37+
return nil, err
38+
}
39+
return right, nil
40+
}
41+
42+
// Type implements the Expression interface.
43+
func (f *IfNull) Type() sql.Type {
44+
if sql.IsNull(f.Left) {
45+
if sql.IsNull(f.Right) {
46+
return sql.Null
47+
}
48+
return f.Right.Type()
49+
}
50+
return f.Left.Type()
51+
}
52+
53+
// IsNullable implements the Expression interface.
54+
func (f *IfNull) IsNullable() bool {
55+
if sql.IsNull(f.Left) {
56+
if sql.IsNull(f.Right) {
57+
return true
58+
}
59+
return f.Right.IsNullable()
60+
}
61+
return f.Left.IsNullable()
62+
}
63+
64+
func (f *IfNull) String() string {
65+
return fmt.Sprintf("ifnull(%s, %s)", f.Left, f.Right)
66+
}
67+
68+
// TransformUp implements the Expression interface.
69+
func (f *IfNull) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
70+
left, err := f.Left.TransformUp(fn)
71+
if err != nil {
72+
return nil, err
73+
}
74+
75+
right, err := f.Right.TransformUp(fn)
76+
if err != nil {
77+
return nil, err
78+
}
79+
80+
return fn(NewIfNull(left, right))
81+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package function
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
9+
)
10+
11+
func TestIfNull(t *testing.T) {
12+
testCases := []struct {
13+
expression interface{}
14+
value interface{}
15+
expected interface{}
16+
}{
17+
{"foo", "bar", "foo"},
18+
{"foo", "foo", "foo"},
19+
{nil, "foo", "foo"},
20+
{"foo", nil, "foo"},
21+
{nil, nil, nil},
22+
{"", nil, ""},
23+
}
24+
25+
f := NewIfNull(
26+
expression.NewGetField(0, sql.Text, "expression", true),
27+
expression.NewGetField(1, sql.Text, "value", true),
28+
)
29+
require.Equal(t, sql.Text, f.Type())
30+
31+
for _, tc := range testCases {
32+
v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.expression, tc.value))
33+
require.NoError(t, err)
34+
require.Equal(t, tc.expected, v)
35+
}
36+
}

sql/expression/function/nullif.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package function
2+
3+
import (
4+
"fmt"
5+
6+
"gopkg.in/src-d/go-mysql-server.v0/sql"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
8+
)
9+
10+
// NullIf function compares two expressions and returns NULL if they are equal. Otherwise, the first expression is returned.
11+
type NullIf struct {
12+
expression.BinaryExpression
13+
}
14+
15+
// NewNullIf returns a new NULLIF UDF
16+
func NewNullIf(ex1, ex2 sql.Expression) sql.Expression {
17+
return &NullIf{
18+
expression.BinaryExpression{
19+
Left: ex1,
20+
Right: ex2,
21+
},
22+
}
23+
}
24+
25+
// Eval implements the Expression interface.
26+
func (f *NullIf) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
27+
if sql.IsNull(f.Left) && sql.IsNull(f.Right) {
28+
return sql.Null, nil
29+
}
30+
31+
val, err := expression.NewEquals(f.Left, f.Right).Eval(ctx, row)
32+
if err != nil {
33+
return nil, err
34+
}
35+
if b, ok := val.(bool); ok && b {
36+
return sql.Null, nil
37+
}
38+
39+
return f.Left.Eval(ctx, row)
40+
}
41+
42+
// Type implements the Expression interface.
43+
func (f *NullIf) Type() sql.Type {
44+
if sql.IsNull(f.Left) {
45+
return sql.Null
46+
}
47+
48+
return f.Left.Type()
49+
}
50+
51+
// IsNullable implements the Expression interface.
52+
func (f *NullIf) IsNullable() bool {
53+
return true
54+
}
55+
56+
func (f *NullIf) String() string {
57+
return fmt.Sprintf("nullif(%s, %s)", f.Left, f.Right)
58+
}
59+
60+
// TransformUp implements the Expression interface.
61+
func (f *NullIf) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
62+
left, err := f.Left.TransformUp(fn)
63+
if err != nil {
64+
return nil, err
65+
}
66+
67+
right, err := f.Right.TransformUp(fn)
68+
if err != nil {
69+
return nil, err
70+
}
71+
72+
return fn(NewNullIf(left, right))
73+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package function
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
9+
)
10+
11+
func TestNullIf(t *testing.T) {
12+
testCases := []struct {
13+
ex1 interface{}
14+
ex2 interface{}
15+
expected interface{}
16+
}{
17+
{"foo", "bar", "foo"},
18+
{"foo", "foo", sql.Null},
19+
{nil, "foo", nil},
20+
{"foo", nil, "foo"},
21+
{nil, nil, nil},
22+
{"", nil, ""},
23+
}
24+
25+
f := NewNullIf(
26+
expression.NewGetField(0, sql.Text, "ex1", true),
27+
expression.NewGetField(1, sql.Text, "ex2", true),
28+
)
29+
require.Equal(t, sql.Text, f.Type())
30+
31+
for _, tc := range testCases {
32+
v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.ex1, tc.ex2))
33+
require.NoError(t, err)
34+
require.Equal(t, tc.expected, v)
35+
}
36+
}

sql/expression/function/registry.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,16 @@ var Defaults = sql.Functions{
4141
"split": sql.Function2(NewSplit),
4242
"concat": sql.FunctionN(NewConcat),
4343
"concat_ws": sql.FunctionN(NewConcatWithSeparator),
44+
"coalesce": sql.FunctionN(NewCoalesce),
4445
"lower": sql.Function1(NewLower),
4546
"upper": sql.Function1(NewUpper),
4647
"ceiling": sql.Function1(NewCeil),
4748
"ceil": sql.Function1(NewCeil),
4849
"floor": sql.Function1(NewFloor),
4950
"round": sql.FunctionN(NewRound),
50-
"coalesce": sql.FunctionN(NewCoalesce),
51-
"json_extract": sql.FunctionN(NewJSONExtract),
5251
"connection_id": sql.Function0(NewConnectionID),
5352
"soundex": sql.Function1(NewSoundex),
53+
"json_extract": sql.FunctionN(NewJSONExtract),
5454
"ln": sql.Function1(NewLogBaseFunc(float64(math.E))),
5555
"log2": sql.Function1(NewLogBaseFunc(float64(2))),
5656
"log10": sql.Function1(NewLogBaseFunc(float64(10))),
@@ -66,4 +66,6 @@ var Defaults = sql.Functions{
6666
"reverse": sql.Function1(NewReverse),
6767
"repeat": sql.Function2(NewRepeat),
6868
"replace": sql.Function3(NewReplace),
69+
"ifnull": sql.Function2(NewIfNull),
70+
"nullif": sql.Function2(NewNullIf),
6971
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy