From a9f27702e9ab8b76bfd2a4e70db67d7e52e1fc41 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Wed, 9 Oct 2019 11:55:50 +0200 Subject: [PATCH 01/44] fix missing module in node integration test Signed-off-by: Miguel Molina --- _integration/javascript/package-lock.json | 27 +++++++++++++---------- _integration/javascript/package.json | 2 +- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/_integration/javascript/package-lock.json b/_integration/javascript/package-lock.json index 41c153513..b37fc5abe 100644 --- a/_integration/javascript/package-lock.json +++ b/_integration/javascript/package-lock.json @@ -658,11 +658,6 @@ "integrity": "sha1-ibTRmasr7kneFk6gK4nORi1xt2c=", "dev": true }, - "bignumber.js": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/bignumber.js/-/bignumber.js-4.1.0.tgz", - "integrity": "sha512-eJzYkFYy9L4JzXsbymsFn3p54D+llV27oTQ+ziJG7WFRheJcNZilgVXMG0LoZtlQSKBsJdWtLFqOD0u+U0jZKA==" - }, "binary-extensions": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.0.0.tgz", @@ -2266,13 +2261,21 @@ "dev": true }, "mysql": { - "version": "github:mysqljs/mysql#cf5d1e396a343ffa0fba23b3791d2dd5da20f30e", - "from": "github:mysqljs/mysql", + "version": "2.17.1", + "resolved": "https://registry.npmjs.org/mysql/-/mysql-2.17.1.tgz", + "integrity": "sha512-7vMqHQ673SAk5C8fOzTG2LpPcf3bNt0oL3sFpxPEEFp1mdlDcrLK0On7z8ZYKaaHrHwNcQ/MTUz7/oobZ2OyyA==", "requires": { - "bignumber.js": "9.0.0", + "bignumber.js": "7.2.1", "readable-stream": "2.3.6", "safe-buffer": "5.1.2", "sqlstring": "2.3.1" + }, + "dependencies": { + "bignumber.js": { + "version": "7.2.1", + "resolved": "https://registry.npmjs.org/bignumber.js/-/bignumber.js-7.2.1.tgz", + "integrity": "sha512-S4XzBk5sMB+Rcb/LNcpzXr57VRTxgAvaAEDAl1AwRx27j00hT84O6OkteE7u8UB3NuaaygCRrEpqox4uDOrbdQ==" + } } }, "normalize-package-data": { @@ -2674,9 +2677,9 @@ } }, "process-nextick-args": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.0.tgz", - "integrity": "sha512-MtEC1TqN0EU5nephaJ4rAtThHtC86dNN9qCuEhtshvpVBkAW5ZO7BASN9REnF9eoXGcRub+pFuKEpOHE+HbEMw==" + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.1.tgz", + "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==" }, "pseudomap": { "version": "1.0.2", @@ -2803,7 +2806,7 @@ }, "readable-stream": { "version": "2.3.6", - "resolved": "http://registry.npmjs.org/readable-stream/-/readable-stream-2.3.6.tgz", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.6.tgz", "integrity": "sha512-tQtKA9WIAhBF3+VLAseyMqZeBjW0AHJoxOtYqSUZNJxauErmLbVm2FW1y+J/YA9dUrAC39ITejlZWhVIwawkKw==", "requires": { "core-util-is": "~1.0.0", diff --git a/_integration/javascript/package.json b/_integration/javascript/package.json index eff42abd3..da6670c88 100644 --- a/_integration/javascript/package.json +++ b/_integration/javascript/package.json @@ -7,7 +7,7 @@ "test": "./node_modules/.bin/ava" }, "dependencies": { - "mysql": "github:mysqljs/mysql" + "mysql": "2.17.1" }, "devDependencies": { "ava": "2.3.0" From e1a798803618f6db48af7a18a52bff65cef5b7e0 Mon Sep 17 00:00:00 2001 From: Antonio Navarro Perez Date: Fri, 11 Oct 2019 11:09:51 +0200 Subject: [PATCH 02/44] Modify MAINTAINERS and added CODEWONERS Signed-off-by: Antonio Navarro Perez --- CODEOWNERS | 1 + MAINTAINERS | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 CODEOWNERS diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 000000000..f48cffb3a --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @src-d/data-processing diff --git a/MAINTAINERS b/MAINTAINERS index dc6817245..8d8de2618 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -1 +1,2 @@ -Antonio Navarro Perez (@ajnavarro) +Miguel Molina (@erizocosmico) +Juanjo Álvarez Martinez (@juanjux) From 258a7352979bebb3f624cdc594f3d27ecf37614e Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Tue, 8 Oct 2019 17:03:00 +0200 Subject: [PATCH 03/44] *: implement subquery expressions Signed-off-by: Miguel Molina --- SUPPORTED.md | 6 +- engine.go | 2 +- engine_test.go | 30 +- log.go | 2 +- server/server.go | 2 +- sql/analyzer/analyzer.go | 2 +- sql/analyzer/assign_indexes.go | 14 +- sql/analyzer/resolve_subqueries.go | 24 +- sql/analyzer/validation_rules.go | 27 ++ sql/analyzer/validation_rules_test.go | 31 ++ sql/core.go | 2 +- sql/expression/comparison.go | 58 +++- sql/expression/comparison_test.go | 312 +++++++++++++++++---- sql/expression/doc.go | 1 - sql/expression/function/aggregation/avg.go | 2 +- sql/expression/function/arraylength.go | 2 +- sql/expression/subquery.go | 125 +++++++++ sql/expression/subquery_test.go | 69 +++++ sql/expression/transform.go | 4 +- sql/information_schema.go | 2 +- sql/parse/indexes.go | 4 +- sql/parse/indexes_test.go | 2 +- sql/parse/parse.go | 186 ++++++------ sql/parse/parse_test.go | 11 +- sql/parse/util.go | 4 +- sql/plan/common.go | 2 +- 26 files changed, 744 insertions(+), 182 deletions(-) delete mode 100644 sql/expression/doc.go create mode 100644 sql/expression/subquery.go create mode 100644 sql/expression/subquery_test.go diff --git a/SUPPORTED.md b/SUPPORTED.md index a316dffe1..6b89188c0 100644 --- a/SUPPORTED.md +++ b/SUPPORTED.md @@ -80,9 +80,6 @@ - div - % -## Subqueries -- supported only as tables, not as expressions. - ## Functions - ARRAY_LENGTH - CEIL @@ -133,3 +130,6 @@ - WEEKDAY - YEAR - YEARWEEK + +## Subqueries +Supported both as a table and as expressions but they can't access the parent query scope. diff --git a/engine.go b/engine.go index ae11ef701..4cb90c3ad 100644 --- a/engine.go +++ b/engine.go @@ -1,4 +1,4 @@ -package sqle // import "github.com/src-d/go-mysql-server" +package sqle import ( "time" diff --git a/engine_test.go b/engine_test.go index dce9066d6..3ce2c8f81 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1548,6 +1548,28 @@ var queries = []struct { {int64(5), "there is some text in here"}, }, }, + { + `SELECT i FROM mytable WHERE i = (SELECT 1)`, + []sql.Row{{int64(1)}}, + }, + { + `SELECT i FROM mytable WHERE i IN (SELECT i FROM mytable)`, + []sql.Row{ + {int64(1)}, + {int64(2)}, + {int64(3)}, + }, + }, + { + `SELECT i FROM mytable WHERE i NOT IN (SELECT i FROM mytable ORDER BY i ASC LIMIT 2)`, + []sql.Row{ + {int64(3)}, + }, + }, + { + `SELECT (SELECT i FROM mytable ORDER BY i ASC LIMIT 1) AS x`, + []sql.Row{{int64(1)}}, + }, } func TestQueries(t *testing.T) { @@ -1901,7 +1923,7 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1), + int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), int64(0), int64(0), int64(0), int64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), @@ -1919,7 +1941,7 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1), + int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), int64(0), int64(0), int64(0), int64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), @@ -2101,7 +2123,7 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1), + int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), int64(0), int64(0), int64(0), int64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), @@ -2119,7 +2141,7 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1), + int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), int64(0), int64(0), int64(0), int64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), diff --git a/log.go b/log.go index 4880acebf..346b545f7 100644 --- a/log.go +++ b/log.go @@ -1,4 +1,4 @@ -package sqle // import "github.com/src-d/go-mysql-server" +package sqle import ( "github.com/golang/glog" diff --git a/server/server.go b/server/server.go index b2bc8924c..15fdb45c1 100644 --- a/server/server.go +++ b/server/server.go @@ -1,4 +1,4 @@ -package server // import "github.com/src-d/go-mysql-server/server" +package server import ( "time" diff --git a/sql/analyzer/analyzer.go b/sql/analyzer/analyzer.go index 1066d6138..f7bc456d5 100644 --- a/sql/analyzer/analyzer.go +++ b/sql/analyzer/analyzer.go @@ -1,4 +1,4 @@ -package analyzer // import "github.com/src-d/go-mysql-server/sql/analyzer" +package analyzer import ( "os" diff --git a/sql/analyzer/assign_indexes.go b/sql/analyzer/assign_indexes.go index 4462cb915..b6bcb1327 100644 --- a/sql/analyzer/assign_indexes.go +++ b/sql/analyzer/assign_indexes.go @@ -759,8 +759,20 @@ func containsColumns(e sql.Expression) bool { return result } +func containsSubquery(e sql.Expression) bool { + var result bool + expression.Inspect(e, func(e sql.Expression) bool { + if _, ok := e.(*expression.Subquery); ok { + result = true + return false + } + return true + }) + return result +} + func isEvaluable(e sql.Expression) bool { - return !containsColumns(e) + return !containsColumns(e) && !containsSubquery(e) } func canMergeIndexes(a, b sql.IndexLookup) bool { diff --git a/sql/analyzer/resolve_subqueries.go b/sql/analyzer/resolve_subqueries.go index 5015253c1..20b97df3f 100644 --- a/sql/analyzer/resolve_subqueries.go +++ b/sql/analyzer/resolve_subqueries.go @@ -2,6 +2,7 @@ package analyzer import ( "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" ) @@ -10,7 +11,7 @@ func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err defer span.Finish() a.Log("resolving subqueries") - return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + n, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.SubqueryAlias: a.Log("found subquery %q with child of type %T", n.Name(), n.Child) @@ -24,4 +25,25 @@ func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err return n, nil } }) + if err != nil { + return nil, err + } + + return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) { + s, ok := e.(*expression.Subquery) + if !ok || s.Resolved() { + return e, nil + } + + q, err := a.Analyze(ctx, s.Query) + if err != nil { + return nil, err + } + + if qp, ok := q.(*plan.QueryProcess); ok { + q = qp.Child + } + + return s.WithQuery(q), nil + }) } diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index e7448d9e2..7d35ee6be 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -20,6 +20,7 @@ const ( validateCaseResultTypesRule = "validate_case_result_types" validateIntervalUsageRule = "validate_interval_usage" validateExplodeUsageRule = "validate_explode_usage" + validateSubqueryColumnsRule = "validate_subquery_columns" ) var ( @@ -57,6 +58,12 @@ var ( ErrExplodeInvalidUse = errors.NewKind( "using EXPLODE is not supported outside a Project node", ) + + // ErrSubqueryColumns is returned when an expression subquery returns + // more than a single column. + ErrSubqueryColumns = errors.NewKind( + "subquery expressions can only return a single column", + ) ) // DefaultValidationRules to apply while analyzing nodes. @@ -70,6 +77,7 @@ var DefaultValidationRules = []Rule{ {validateCaseResultTypesRule, validateCaseResultTypes}, {validateIntervalUsageRule, validateIntervalUsage}, {validateExplodeUsageRule, validateExplodeUsage}, + {validateSubqueryColumnsRule, validateSubqueryColumns}, } func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { @@ -322,6 +330,25 @@ func validateExplodeUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, return n, nil } +func validateSubqueryColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + valid := true + plan.InspectExpressions(n, func(e sql.Expression) bool { + s, ok := e.(*expression.Subquery) + if ok && len(s.Query.Schema()) != 1 { + valid = false + return false + } + + return true + }) + + if !valid { + return nil, ErrSubqueryColumns.New() + } + + return n, nil +} + func stringContains(strs []string, target string) bool { for _, s := range strs { if s == target { diff --git a/sql/analyzer/validation_rules_test.go b/sql/analyzer/validation_rules_test.go index 1af5ad53a..543fbc688 100644 --- a/sql/analyzer/validation_rules_test.go +++ b/sql/analyzer/validation_rules_test.go @@ -674,6 +674,37 @@ func TestValidateExplodeUsage(t *testing.T) { } } +func TestValidateSubqueryColumns(t *testing.T) { + require := require.New(t) + ctx := sql.NewEmptyContext() + + node := plan.NewProject([]sql.Expression{ + expression.NewSubquery(plan.NewProject( + []sql.Expression{ + lit(1), + lit(2), + }, + dummyNode{true}, + )), + }, dummyNode{true}) + + _, err := validateSubqueryColumns(ctx, nil, node) + require.Error(err) + require.True(ErrSubqueryColumns.Is(err)) + + node = plan.NewProject([]sql.Expression{ + expression.NewSubquery(plan.NewProject( + []sql.Expression{ + lit(1), + }, + dummyNode{true}, + )), + }, dummyNode{true}) + + _, err = validateSubqueryColumns(ctx, nil, node) + require.NoError(err) +} + type dummyNode struct{ resolved bool } func (n dummyNode) String() string { return "dummynode" } diff --git a/sql/core.go b/sql/core.go index a45a1d69f..e91ce0b67 100644 --- a/sql/core.go +++ b/sql/core.go @@ -1,4 +1,4 @@ -package sql // import "github.com/src-d/go-mysql-server/sql" +package sql import ( "fmt" diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 8491f9cef..87601579b 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -466,7 +466,6 @@ func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - // TODO: support subqueries switch right := in.Right().(type) { case Tuple: for _, el := range right { @@ -496,6 +495,34 @@ func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } + return false, nil + case *Subquery: + if leftElems > 1 { + return nil, ErrInvalidOperandColumns.New(leftElems, 1) + } + + typ := right.Type() + values, err := right.EvalMultiple(ctx) + if err != nil { + return nil, err + } + + for _, val := range values { + val, err = typ.Convert(val) + if err != nil { + return nil, err + } + + cmp, err := typ.Compare(left, val) + if err != nil { + return nil, err + } + + if cmp == 0 { + return true, nil + } + } + return false, nil default: return nil, ErrUnsupportedInOperand.New(right) @@ -547,7 +574,6 @@ func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - // TODO: support subqueries switch right := in.Right().(type) { case Tuple: for _, el := range right { @@ -577,6 +603,34 @@ func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } + return true, nil + case *Subquery: + if leftElems > 1 { + return nil, ErrInvalidOperandColumns.New(leftElems, 1) + } + + typ := right.Type() + values, err := right.EvalMultiple(ctx) + if err != nil { + return nil, err + } + + for _, val := range values { + val, err = typ.Convert(val) + if err != nil { + return nil, err + } + + cmp, err := typ.Compare(left, val) + if err != nil { + return nil, err + } + + if cmp == 0 { + return false, nil + } + } + return true, nil default: return nil, ErrUnsupportedInOperand.New(right) diff --git a/sql/expression/comparison_test.go b/sql/expression/comparison_test.go index c69d1fdc5..802c3676a 100644 --- a/sql/expression/comparison_test.go +++ b/sql/expression/comparison_test.go @@ -1,10 +1,13 @@ -package expression +package expression_test import ( "testing" "github.com/src-d/go-mysql-server/internal/regex" + "github.com/src-d/go-mysql-server/memory" "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" errors "gopkg.in/src-d/go-errors.v1" "github.com/stretchr/testify/require" @@ -95,11 +98,11 @@ var likeComparisonCases = map[sql.Type]map[int][][]interface{}{ func TestEquals(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := NewGetField(0, resultType, "col1", true) + get0 := expression.NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := NewGetField(1, resultType, "col2", true) + get1 := expression.NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := NewEquals(get0, get1) + eq := expression.NewEquals(get0, get1) require.NotNil(eq) require.Equal(sql.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -122,11 +125,11 @@ func TestEquals(t *testing.T) { func TestLessThan(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := NewGetField(0, resultType, "col1", true) + get0 := expression.NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := NewGetField(1, resultType, "col2", true) + get1 := expression.NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := NewLessThan(get0, get1) + eq := expression.NewLessThan(get0, get1) require.NotNil(eq) require.Equal(sql.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -149,11 +152,11 @@ func TestLessThan(t *testing.T) { func TestGreaterThan(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := NewGetField(0, resultType, "col1", true) + get0 := expression.NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := NewGetField(1, resultType, "col2", true) + get1 := expression.NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := NewGreaterThan(get0, get1) + eq := expression.NewGreaterThan(get0, get1) require.NotNil(eq) require.Equal(sql.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -185,13 +188,13 @@ func testRegexpCases(t *testing.T) { require := require.New(t) for resultType, cmpCase := range likeComparisonCases { - get0 := NewGetField(0, resultType, "col1", true) + get0 := expression.NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := NewGetField(1, resultType, "col2", true) + get1 := expression.NewGetField(1, resultType, "col2", true) require.NotNil(get1) for cmpResult, cases := range cmpCase { for _, pair := range cases { - eq := NewRegexp(get0, get1) + eq := expression.NewRegexp(get0, get1) require.NotNil(eq) require.Equal(sql.Boolean, eq.Type()) @@ -214,9 +217,9 @@ func TestInvalidRegexp(t *testing.T) { t.Helper() require := require.New(t) - col1 := NewGetField(0, sql.Text, "col1", true) - invalid := NewLiteral("*col1", sql.Text) - r := NewRegexp(col1, invalid) + col1 := expression.NewGetField(0, sql.Text, "col1", true) + invalid := expression.NewLiteral("*col1", sql.Text) + r := expression.NewRegexp(col1, invalid) row := sql.NewRow("col1") _, err := r.Eval(sql.NewEmptyContext(), row) @@ -234,10 +237,10 @@ func TestIn(t *testing.T) { }{ { "left is nil", - NewLiteral(nil, sql.Null), - NewTuple( - NewLiteral(int64(1), sql.Int64), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(nil, sql.Null), + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), ), nil, nil, @@ -245,32 +248,32 @@ func TestIn(t *testing.T) { }, { "left and right don't have the same cols", - NewLiteral(1, sql.Int64), - NewTuple( - NewTuple( - NewLiteral(int64(1), sql.Int64), - NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(1, sql.Int64), + expression.NewTuple( + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), ), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), ), nil, nil, - ErrInvalidOperandColumns, + expression.ErrInvalidOperandColumns, }, { "right is an unsupported operand", - NewLiteral(1, sql.Int64), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(1, sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), nil, nil, - ErrUnsupportedInOperand, + expression.ErrUnsupportedInOperand, }, { "left is in right", - NewGetField(0, sql.Int64, "foo", false), - NewTuple( - NewGetField(0, sql.Int64, "foo", false), - NewLiteral(int64(2), sql.Int64), + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewTuple( + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewLiteral(int64(2), sql.Int64), ), sql.NewRow(int64(1)), true, @@ -278,10 +281,10 @@ func TestIn(t *testing.T) { }, { "left is not in right", - NewGetField(0, sql.Int64, "foo", false), - NewTuple( - NewGetField(1, sql.Int64, "bar", false), - NewLiteral(int64(2), sql.Int64), + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewTuple( + expression.NewGetField(1, sql.Int64, "bar", false), + expression.NewLiteral(int64(2), sql.Int64), ), sql.NewRow(int64(1), int64(3)), false, @@ -293,7 +296,96 @@ func TestIn(t *testing.T) { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - result, err := NewIn(tt.left, tt.right).Eval(sql.NewEmptyContext(), tt.row) + result, err := expression.NewIn(tt.left, tt.right). + Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.result, result) + } + }) + } +} + +func TestInSubquery(t *testing.T) { + ctx := sql.NewEmptyContext() + table := memory.NewTable("foo", sql.Schema{ + {Name: "t", Source: "foo", Type: sql.Text}, + }) + + require.NoError(t, table.Insert(ctx, sql.Row{"one"})) + require.NoError(t, table.Insert(ctx, sql.Row{"two"})) + require.NoError(t, table.Insert(ctx, sql.Row{"three"})) + + project := func(expr sql.Expression) sql.Node { + return plan.NewProject([]sql.Expression{ + expr, + }, plan.NewResolvedTable(table)) + } + + testCases := []struct { + name string + left sql.Expression + right sql.Node + row sql.Row + result interface{} + err *errors.Kind + }{ + { + "left is nil", + expression.NewLiteral(nil, sql.Null), + project( + expression.NewLiteral(int64(1), sql.Int64), + ), + nil, + nil, + nil, + }, + { + "left and right don't have the same cols", + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), + ), + project( + expression.NewLiteral(int64(2), sql.Int64), + ), + nil, + nil, + expression.ErrInvalidOperandColumns, + }, + { + "left is in right", + expression.NewGetField(0, sql.Text, "foo", false), + project( + expression.NewGetField(0, sql.Text, "foo", false), + ), + sql.NewRow("two"), + true, + nil, + }, + { + "left is not in right", + expression.NewGetField(0, sql.Text, "foo", false), + project( + expression.NewGetField(0, sql.Text, "foo", false), + ), + sql.NewRow("four"), + false, + nil, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + result, err := expression.NewIn( + tt.left, + expression.NewSubquery(tt.right), + ).Eval(sql.NewEmptyContext(), tt.row) if tt.err != nil { require.Error(err) require.True(tt.err.Is(err)) @@ -316,10 +408,10 @@ func TestNotIn(t *testing.T) { }{ { "left is nil", - NewLiteral(nil, sql.Null), - NewTuple( - NewLiteral(int64(1), sql.Int64), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(nil, sql.Null), + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), ), nil, nil, @@ -327,32 +419,32 @@ func TestNotIn(t *testing.T) { }, { "left and right don't have the same cols", - NewLiteral(1, sql.Int64), - NewTuple( - NewTuple( - NewLiteral(int64(1), sql.Int64), - NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(1, sql.Int64), + expression.NewTuple( + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), ), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), ), nil, nil, - ErrInvalidOperandColumns, + expression.ErrInvalidOperandColumns, }, { "right is an unsupported operand", - NewLiteral(1, sql.Int64), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(1, sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), nil, nil, - ErrUnsupportedInOperand, + expression.ErrUnsupportedInOperand, }, { "left is in right", - NewGetField(0, sql.Int64, "foo", false), - NewTuple( - NewGetField(0, sql.Int64, "foo", false), - NewLiteral(int64(2), sql.Int64), + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewTuple( + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewLiteral(int64(2), sql.Int64), ), sql.NewRow(int64(1)), false, @@ -360,10 +452,10 @@ func TestNotIn(t *testing.T) { }, { "left is not in right", - NewGetField(0, sql.Int64, "foo", false), - NewTuple( - NewGetField(1, sql.Int64, "bar", false), - NewLiteral(int64(2), sql.Int64), + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewTuple( + expression.NewGetField(1, sql.Int64, "bar", false), + expression.NewLiteral(int64(2), sql.Int64), ), sql.NewRow(int64(1), int64(3)), true, @@ -375,7 +467,8 @@ func TestNotIn(t *testing.T) { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - result, err := NewNotIn(tt.left, tt.right).Eval(sql.NewEmptyContext(), tt.row) + result, err := expression.NewNotIn(tt.left, tt.right). + Eval(sql.NewEmptyContext(), tt.row) if tt.err != nil { require.Error(err) require.True(tt.err.Is(err)) @@ -386,3 +479,98 @@ func TestNotIn(t *testing.T) { }) } } + +func TestNotInSubquery(t *testing.T) { + ctx := sql.NewEmptyContext() + table := memory.NewTable("foo", sql.Schema{ + {Name: "t", Source: "foo", Type: sql.Text}, + }) + + require.NoError(t, table.Insert(ctx, sql.Row{"one"})) + require.NoError(t, table.Insert(ctx, sql.Row{"two"})) + require.NoError(t, table.Insert(ctx, sql.Row{"three"})) + + project := func(expr sql.Expression) sql.Node { + return plan.NewProject([]sql.Expression{ + expr, + }, plan.NewResolvedTable(table)) + } + + testCases := []struct { + name string + left sql.Expression + right sql.Node + row sql.Row + result interface{} + err *errors.Kind + }{ + { + "left is nil", + expression.NewLiteral(nil, sql.Null), + project( + expression.NewLiteral(int64(1), sql.Int64), + ), + nil, + nil, + nil, + }, + { + "left and right don't have the same cols", + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), + ), + project( + expression.NewLiteral(int64(2), sql.Int64), + ), + nil, + nil, + expression.ErrInvalidOperandColumns, + }, + { + "left is in right", + expression.NewGetField(0, sql.Text, "foo", false), + project( + expression.NewGetField(0, sql.Text, "foo", false), + ), + sql.NewRow("two"), + false, + nil, + }, + { + "left is not in right", + expression.NewGetField(0, sql.Text, "foo", false), + project( + expression.NewGetField(0, sql.Text, "foo", false), + ), + sql.NewRow("four"), + true, + nil, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + result, err := expression.NewNotIn( + tt.left, + expression.NewSubquery(tt.right), + ).Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.result, result) + } + }) + } +} + +func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { + t.Helper() + v, err := e.Eval(sql.NewEmptyContext(), row) + require.NoError(t, err) + return v +} diff --git a/sql/expression/doc.go b/sql/expression/doc.go deleted file mode 100644 index 820316fe6..000000000 --- a/sql/expression/doc.go +++ /dev/null @@ -1 +0,0 @@ -package expression // import "github.com/src-d/go-mysql-server/sql/expression" diff --git a/sql/expression/function/aggregation/avg.go b/sql/expression/function/aggregation/avg.go index 5e6b2ff49..eae1c8c0e 100644 --- a/sql/expression/function/aggregation/avg.go +++ b/sql/expression/function/aggregation/avg.go @@ -1,4 +1,4 @@ -package aggregation // import "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" +package aggregation import ( "fmt" diff --git a/sql/expression/function/arraylength.go b/sql/expression/function/arraylength.go index 00d10cfd2..61a902cd9 100644 --- a/sql/expression/function/arraylength.go +++ b/sql/expression/function/arraylength.go @@ -1,4 +1,4 @@ -package function // import "github.com/src-d/go-mysql-server/sql/expression/function" +package function import ( "fmt" diff --git a/sql/expression/subquery.go b/sql/expression/subquery.go new file mode 100644 index 000000000..faae15aa1 --- /dev/null +++ b/sql/expression/subquery.go @@ -0,0 +1,125 @@ +package expression + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" +) + +var errExpectedSingleRow = errors.NewKind("the subquery returned more than 1 row") + +// Subquery that is executed as an expression. +type Subquery struct { + Query sql.Node + value interface{} +} + +// NewSubquery returns a new subquery node. +func NewSubquery(node sql.Node) *Subquery { + return &Subquery{node, nil} +} + +// Eval implements the Expression interface. +func (s *Subquery) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) { + if s.value != nil { + if elems, ok := s.value.([]interface{}); ok { + if len(elems) > 1 { + return nil, errExpectedSingleRow.New() + } + return elems[0], nil + } + return s.value, nil + } + + iter, err := s.Query.RowIter(ctx) + if err != nil { + return nil, err + } + + rows, err := sql.RowIterToRows(iter) + if err != nil { + return nil, err + } + + if len(rows) == 0 { + s.value = nil + return nil, nil + } + + if len(rows) > 1 { + return nil, errExpectedSingleRow.New() + } + + s.value = rows[0][0] + return s.value, nil +} + +// EvalMultiple returns all rows returned by a subquery. +func (s *Subquery) EvalMultiple(ctx *sql.Context) ([]interface{}, error) { + if s.value != nil { + return s.value.([]interface{}), nil + } + + iter, err := s.Query.RowIter(ctx) + if err != nil { + return nil, err + } + + rows, err := sql.RowIterToRows(iter) + if err != nil { + return nil, err + } + + if len(rows) == 0 { + s.value = []interface{}{} + return nil, nil + } + + var result = make([]interface{}, len(rows)) + for i, row := range rows { + result[i] = row[0] + } + s.value = result + + return result, nil +} + +// IsNullable implements the Expression interface. +func (s *Subquery) IsNullable() bool { + return s.Query.Schema()[0].Nullable +} + +func (s *Subquery) String() string { + return fmt.Sprintf("(%s)", s.Query) +} + +// Resolved implements the Expression interface. +func (s *Subquery) Resolved() bool { + return s.Query.Resolved() +} + +// Type implements the Expression interface. +func (s *Subquery) Type() sql.Type { + return s.Query.Schema()[0].Type +} + +// WithChildren implements the Expression interface. +func (s *Subquery) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } + return s, nil +} + +// Children implements the Expression interface. +func (s *Subquery) Children() []sql.Expression { + return nil +} + +// WithQuery returns the subquery with the query node changed. +func (s *Subquery) WithQuery(node sql.Node) *Subquery { + ns := *s + ns.Query = node + return &ns +} diff --git a/sql/expression/subquery_test.go b/sql/expression/subquery_test.go new file mode 100644 index 000000000..0d3f353be --- /dev/null +++ b/sql/expression/subquery_test.go @@ -0,0 +1,69 @@ +package expression_test + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestSubquery(t *testing.T) { + require := require.New(t) + table := memory.NewTable("", nil) + require.NoError(table.Insert(sql.NewEmptyContext(), nil)) + + subquery := expression.NewSubquery(plan.NewProject( + []sql.Expression{ + expression.NewLiteral("one", sql.Text), + }, + plan.NewResolvedTable(table), + )) + + value, err := subquery.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal(value, "one") +} + +func TestSubqueryTooManyRows(t *testing.T) { + require := require.New(t) + table := memory.NewTable("", nil) + require.NoError(table.Insert(sql.NewEmptyContext(), nil)) + require.NoError(table.Insert(sql.NewEmptyContext(), nil)) + + subquery := expression.NewSubquery(plan.NewProject( + []sql.Expression{ + expression.NewLiteral("one", sql.Text), + }, + plan.NewResolvedTable(table), + )) + + _, err := subquery.Eval(sql.NewEmptyContext(), nil) + require.Error(err) +} + +func TestSubqueryMultipleRows(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + table := memory.NewTable("foo", sql.Schema{ + {Name: "t", Source: "foo", Type: sql.Text}, + }) + + require.NoError(table.Insert(ctx, sql.Row{"one"})) + require.NoError(table.Insert(ctx, sql.Row{"two"})) + require.NoError(table.Insert(ctx, sql.Row{"three"})) + + subquery := expression.NewSubquery(plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Text, "t", false), + }, + plan.NewResolvedTable(table), + )) + + values, err := subquery.EvalMultiple(ctx) + require.NoError(err) + require.Equal(values, []interface{}{"one", "two", "three"}) +} diff --git a/sql/expression/transform.go b/sql/expression/transform.go index ccf3d4276..05195c7fd 100644 --- a/sql/expression/transform.go +++ b/sql/expression/transform.go @@ -1,6 +1,8 @@ package expression -import "github.com/src-d/go-mysql-server/sql" +import ( + "github.com/src-d/go-mysql-server/sql" +) // TransformUp applies a transformation function to the given expression from the // bottom up. diff --git a/sql/information_schema.go b/sql/information_schema.go index 0a304786a..d88804751 100644 --- a/sql/information_schema.go +++ b/sql/information_schema.go @@ -1,4 +1,4 @@ -package sql // import "github.com/src-d/go-mysql-server/sql" +package sql import ( "bytes" diff --git a/sql/parse/indexes.go b/sql/parse/indexes.go index 4a0ac4e0e..a435f7888 100644 --- a/sql/parse/indexes.go +++ b/sql/parse/indexes.go @@ -39,7 +39,7 @@ func parseShowIndex(s string) (sql.Node, error) { ), nil } -func parseCreateIndex(s string) (sql.Node, error) { +func parseCreateIndex(ctx *sql.Context, s string) (sql.Node, error) { r := bufio.NewReader(strings.NewReader(s)) var name, table, driver string @@ -78,7 +78,7 @@ func parseCreateIndex(s string) (sql.Node, error) { var indexExprs = make([]sql.Expression, len(exprs)) for i, e := range exprs { var err error - indexExprs[i], err = parseExpr(e) + indexExprs[i], err = parseExpr(ctx, e) if err != nil { return nil, err } diff --git a/sql/parse/indexes_test.go b/sql/parse/indexes_test.go index 4072b7a5d..a9d7df5a4 100644 --- a/sql/parse/indexes_test.go +++ b/sql/parse/indexes_test.go @@ -159,7 +159,7 @@ func TestParseCreateIndex(t *testing.T) { t.Run(tt.query, func(t *testing.T) { require := require.New(t) - result, err := parseCreateIndex(strings.ToLower(tt.query)) + result, err := parseCreateIndex(sql.NewEmptyContext(), strings.ToLower(tt.query)) if tt.err != nil { require.Error(err) require.True(tt.err.Is(err)) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 72a125ecd..14d5e3045 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -1,4 +1,4 @@ -package parse // import "github.com/src-d/go-mysql-server/sql/parse" +package parse import ( "bufio" @@ -26,9 +26,6 @@ var ( // ErrUnsupportedFeature is thrown when a feature is not already supported ErrUnsupportedFeature = errors.NewKind("unsupported feature: %s") - // ErrUnsupportedSubqueryExpression is thrown because subqueries are not supported, yet. - ErrUnsupportedSubqueryExpression = errors.NewKind("unsupported subquery expression") - // ErrInvalidSQLValType is returned when a SQLVal type is not valid. ErrInvalidSQLValType = errors.NewKind("invalid SQLVal of type: %d") @@ -73,7 +70,7 @@ func Parse(ctx *sql.Context, query string) (sql.Node, error) { case describeTablesRegex.MatchString(lowerQuery): return parseDescribeTables(lowerQuery) case createIndexRegex.MatchString(lowerQuery): - return parseCreateIndex(s) + return parseCreateIndex(ctx, s) case dropIndexRegex.MatchString(lowerQuery): return parseDropIndex(s) case showIndexRegex.MatchString(lowerQuery): @@ -85,7 +82,7 @@ func Parse(ctx *sql.Context, query string) (sql.Node, error) { case showWarningsRegex.MatchString(lowerQuery): return parseShowWarnings(ctx, s) case showCollationRegex.MatchString(lowerQuery): - return parseShowCollation(s) + return parseShowCollation(ctx, s) case describeRegex.MatchString(lowerQuery): return parseDescribeQuery(ctx, s) case fullProcessListRegex.MatchString(lowerQuery): @@ -135,7 +132,13 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node default: return nil, ErrUnsupportedSyntax.New(n) case *sqlparser.Show: - return convertShow(n, query) + // When a query is empty it means it comes from a subquery, as we don't + // have the query itself in a subquery. Hence, a SHOW could not be + // parsed. + if query == "" { + return nil, ErrUnsupportedFeature.New("SHOW in subquery") + } + return convertShow(ctx, n, query) case *sqlparser.Select: return convertSelect(ctx, n) case *sqlparser.Insert: @@ -165,7 +168,7 @@ func convertSet(ctx *sql.Context, n *sqlparser.Set) (sql.Node, error) { var variables = make([]plan.SetVariable, len(n.Exprs)) for i, e := range n.Exprs { - expr, err := exprToExpression(e.Expr) + expr, err := exprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -215,7 +218,7 @@ func convertSet(ctx *sql.Context, n *sqlparser.Set) (sql.Node, error) { return plan.NewSet(variables...), nil } -func convertShow(s *sqlparser.Show, query string) (sql.Node, error) { +func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, error) { switch s.Type { case sqlparser.KeywordString(sqlparser.TABLES): var dbName string @@ -228,7 +231,7 @@ func convertShow(s *sqlparser.Show, query string) (sql.Node, error) { if s.ShowTablesOpt.Filter != nil { if s.ShowTablesOpt.Filter.Filter != nil { var err error - filter, err = exprToExpression(s.ShowTablesOpt.Filter.Filter) + filter, err = exprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) if err != nil { return nil, err } @@ -270,7 +273,7 @@ func convertShow(s *sqlparser.Show, query string) (sql.Node, error) { } if s.ShowTablesOpt.Filter.Filter != nil { - filter, err := exprToExpression(s.ShowTablesOpt.Filter.Filter) + filter, err := exprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) if err != nil { return nil, err } @@ -281,7 +284,7 @@ func convertShow(s *sqlparser.Show, query string) (sql.Node, error) { return node, nil case sqlparser.KeywordString(sqlparser.TABLE): - return parseShowTableStatus(query) + return parseShowTableStatus(ctx, query) default: unsupportedShow := fmt.Sprintf("SHOW %s", s.Type) return nil, ErrUnsupportedFeature.New(unsupportedShow) @@ -295,19 +298,19 @@ func convertSelect(ctx *sql.Context, s *sqlparser.Select) (sql.Node, error) { } if s.Where != nil { - node, err = whereToFilter(s.Where, node) + node, err = whereToFilter(ctx, s.Where, node) if err != nil { return nil, err } } - node, err = selectToProjectOrGroupBy(s.SelectExprs, s.GroupBy, node) + node, err = selectToProjectOrGroupBy(ctx, s.SelectExprs, s.GroupBy, node) if err != nil { return nil, err } if s.Having != nil { - node, err = havingToHaving(s.Having, node) + node, err = havingToHaving(ctx, s.Having, node) if err != nil { return nil, err } @@ -318,7 +321,7 @@ func convertSelect(ctx *sql.Context, s *sqlparser.Select) (sql.Node, error) { } if len(s.OrderBy) != 0 { - node, err = orderByToSort(s.OrderBy, node) + node, err = orderByToSort(ctx, s.OrderBy, node) if err != nil { return nil, err } @@ -395,14 +398,14 @@ func convertDelete(ctx *sql.Context, d *sqlparser.Delete) (sql.Node, error) { } if d.Where != nil { - node, err = whereToFilter(d.Where, node) + node, err = whereToFilter(ctx, d.Where, node) if err != nil { return nil, err } } if len(d.OrderBy) != 0 { - node, err = orderByToSort(d.OrderBy, node) + node, err = orderByToSort(ctx, d.OrderBy, node) if err != nil { return nil, err } @@ -463,19 +466,19 @@ func insertRowsToNode(ctx *sql.Context, ir sqlparser.InsertRows) (sql.Node, erro case *sqlparser.Union: return nil, ErrUnsupportedFeature.New("UNION") case sqlparser.Values: - return valuesToValues(v) + return valuesToValues(ctx, v) default: return nil, ErrUnsupportedSyntax.New(ir) } } -func valuesToValues(v sqlparser.Values) (sql.Node, error) { +func valuesToValues(ctx *sql.Context, v sqlparser.Values) (sql.Node, error) { exprTuples := make([][]sql.Expression, len(v)) for i, vt := range v { exprs := make([]sql.Expression, len(vt)) exprTuples[i] = exprs for j, e := range vt { - expr, err := exprToExpression(e) + expr, err := exprToExpression(ctx, e) if err != nil { return nil, err } @@ -573,7 +576,7 @@ func tableExprToTable( return nil, ErrUnsupportedSyntax.New("missed ON clause for JOIN statement") } - cond, err := exprToExpression(t.Condition.On) + cond, err := exprToExpression(ctx, t.Condition.On) if err != nil { return nil, err } @@ -591,8 +594,8 @@ func tableExprToTable( } } -func whereToFilter(w *sqlparser.Where, child sql.Node) (*plan.Filter, error) { - c, err := exprToExpression(w.Expr) +func whereToFilter(ctx *sql.Context, w *sqlparser.Where, child sql.Node) (*plan.Filter, error) { + c, err := exprToExpression(ctx, w.Expr) if err != nil { return nil, err } @@ -600,10 +603,10 @@ func whereToFilter(w *sqlparser.Where, child sql.Node) (*plan.Filter, error) { return plan.NewFilter(c, child), nil } -func orderByToSort(ob sqlparser.OrderBy, child sql.Node) (*plan.Sort, error) { +func orderByToSort(ctx *sql.Context, ob sqlparser.OrderBy, child sql.Node) (*plan.Sort, error) { var sortFields []plan.SortField for _, o := range ob { - e, err := exprToExpression(o.Expr) + e, err := exprToExpression(ctx, o.Expr) if err != nil { return nil, err } @@ -642,8 +645,8 @@ func limitToLimit( return plan.NewLimit(rowCount, child), nil } -func havingToHaving(having *sqlparser.Where, node sql.Node) (sql.Node, error) { - cond, err := exprToExpression(having.Expr) +func havingToHaving(ctx *sql.Context, having *sqlparser.Where, node sql.Node) (sql.Node, error) { + cond, err := exprToExpression(ctx, having.Expr) if err != nil { return nil, err } @@ -670,8 +673,8 @@ func offsetToOffset( // getInt64Literal returns an int64 *expression.Literal for the value given, or an unsupported error with the string // given if the expression doesn't represent an integer literal. -func getInt64Literal(expr sqlparser.Expr, errStr string) (*expression.Literal, error) { - e, err := exprToExpression(expr) +func getInt64Literal(ctx *sql.Context, expr sqlparser.Expr, errStr string) (*expression.Literal, error) { + e, err := exprToExpression(ctx, expr) if err != nil { return nil, err } @@ -679,15 +682,15 @@ func getInt64Literal(expr sqlparser.Expr, errStr string) (*expression.Literal, e nl, ok := e.(*expression.Literal) if !ok || nl.Type() != sql.Int64 { return nil, ErrUnsupportedFeature.New(errStr) - } else { - return nl, nil } + + return nl, nil } // getInt64Value returns the int64 literal value in the expression given, or an error with the errStr given if it // cannot. func getInt64Value(ctx *sql.Context, expr sqlparser.Expr, errStr string) (int64, error) { - ie, err := getInt64Literal(expr, errStr) + ie, err := getInt64Literal(ctx, expr, errStr) if err != nil { return 0, err } @@ -715,8 +718,13 @@ func isAggregate(e sql.Expression) bool { return isAgg } -func selectToProjectOrGroupBy(se sqlparser.SelectExprs, g sqlparser.GroupBy, child sql.Node) (sql.Node, error) { - selectExprs, err := selectExprsToExpressions(se) +func selectToProjectOrGroupBy( + ctx *sql.Context, + se sqlparser.SelectExprs, + g sqlparser.GroupBy, + child sql.Node, +) (sql.Node, error) { + selectExprs, err := selectExprsToExpressions(ctx, se) if err != nil { return nil, err } @@ -732,7 +740,7 @@ func selectToProjectOrGroupBy(se sqlparser.SelectExprs, g sqlparser.GroupBy, chi } if isAgg { - groupingExprs, err := groupByToExpressions(g) + groupingExprs, err := groupByToExpressions(ctx, g) if err != nil { return nil, err } @@ -757,10 +765,10 @@ func selectToProjectOrGroupBy(se sqlparser.SelectExprs, g sqlparser.GroupBy, chi return plan.NewProject(selectExprs, child), nil } -func selectExprsToExpressions(se sqlparser.SelectExprs) ([]sql.Expression, error) { +func selectExprsToExpressions(ctx *sql.Context, se sqlparser.SelectExprs) ([]sql.Expression, error) { var exprs []sql.Expression for _, e := range se { - pe, err := selectExprToExpression(e) + pe, err := selectExprToExpression(ctx, e) if err != nil { return nil, err } @@ -771,7 +779,7 @@ func selectExprsToExpressions(se sqlparser.SelectExprs) ([]sql.Expression, error return exprs, nil } -func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { +func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error) { switch v := e.(type) { default: return nil, ErrUnsupportedSyntax.New(e) @@ -783,14 +791,14 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { err error ) if v.Name != nil { - name, err = exprToExpression(v.Name) + name, err = exprToExpression(ctx, v.Name) } else { - name, err = exprToExpression(v.StrVal) + name, err = exprToExpression(ctx, v.StrVal) } if err != nil { return nil, err } - from, err := exprToExpression(v.From) + from, err := exprToExpression(ctx, v.From) if err != nil { return nil, err } @@ -798,17 +806,17 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { if v.To == nil { return function.NewSubstring(name, from) } - to, err := exprToExpression(v.To) + to, err := exprToExpression(ctx, v.To) if err != nil { return nil, err } return function.NewSubstring(name, from, to) case *sqlparser.ComparisonExpr: - return comparisonExprToExpression(v) + return comparisonExprToExpression(ctx, v) case *sqlparser.IsExpr: - return isExprToExpression(v) + return isExprToExpression(ctx, v) case *sqlparser.NotExpr: - c, err := exprToExpression(v.Expr) + c, err := exprToExpression(ctx, v.Expr) if err != nil { return nil, err } @@ -829,7 +837,7 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { } return expression.NewUnresolvedColumn(v.Name.String()), nil case *sqlparser.FuncExpr: - exprs, err := selectExprsToExpressions(v.Exprs) + exprs, err := selectExprsToExpressions(ctx, v.Exprs) if err != nil { return nil, err } @@ -849,50 +857,50 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { return expression.NewUnresolvedFunction(v.Name.Lowered(), isAggregateFunc(v), exprs...), nil case *sqlparser.ParenExpr: - return exprToExpression(v.Expr) + return exprToExpression(ctx, v.Expr) case *sqlparser.AndExpr: - lhs, err := exprToExpression(v.Left) + lhs, err := exprToExpression(ctx, v.Left) if err != nil { return nil, err } - rhs, err := exprToExpression(v.Right) + rhs, err := exprToExpression(ctx, v.Right) if err != nil { return nil, err } return expression.NewAnd(lhs, rhs), nil case *sqlparser.OrExpr: - lhs, err := exprToExpression(v.Left) + lhs, err := exprToExpression(ctx, v.Left) if err != nil { return nil, err } - rhs, err := exprToExpression(v.Right) + rhs, err := exprToExpression(ctx, v.Right) if err != nil { return nil, err } return expression.NewOr(lhs, rhs), nil case *sqlparser.ConvertExpr: - expr, err := exprToExpression(v.Expr) + expr, err := exprToExpression(ctx, v.Expr) if err != nil { return nil, err } return expression.NewConvert(expr, v.Type.Type), nil case *sqlparser.RangeCond: - val, err := exprToExpression(v.Left) + val, err := exprToExpression(ctx, v.Left) if err != nil { return nil, err } - lower, err := exprToExpression(v.From) + lower, err := exprToExpression(ctx, v.From) if err != nil { return nil, err } - upper, err := exprToExpression(v.To) + upper, err := exprToExpression(ctx, v.To) if err != nil { return nil, err } @@ -908,7 +916,7 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { case sqlparser.ValTuple: var exprs = make([]sql.Expression, len(v)) for i, e := range v { - expr, err := exprToExpression(e) + expr, err := exprToExpression(ctx, e) if err != nil { return nil, err } @@ -917,15 +925,19 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { return expression.NewTuple(exprs...), nil case *sqlparser.BinaryExpr: - return binaryExprToExpression(v) + return binaryExprToExpression(ctx, v) case *sqlparser.UnaryExpr: - return unaryExprToExpression(v) + return unaryExprToExpression(ctx, v) case *sqlparser.Subquery: - return nil, ErrUnsupportedSubqueryExpression.New() + node, err := convert(ctx, v.Select, "") + if err != nil { + return nil, err + } + return expression.NewSubquery(node), nil case *sqlparser.CaseExpr: - return caseExprToExpression(v) + return caseExprToExpression(ctx, v) case *sqlparser.IntervalExpr: - return intervalExprToExpression(v) + return intervalExprToExpression(ctx, v) } } @@ -988,8 +1000,8 @@ func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { return nil, ErrInvalidSQLValType.New(v.Type) } -func isExprToExpression(c *sqlparser.IsExpr) (sql.Expression, error) { - e, err := exprToExpression(c.Expr) +func isExprToExpression(ctx *sql.Context, c *sqlparser.IsExpr) (sql.Expression, error) { + e, err := exprToExpression(ctx, c.Expr) if err != nil { return nil, err } @@ -1012,13 +1024,13 @@ func isExprToExpression(c *sqlparser.IsExpr) (sql.Expression, error) { } } -func comparisonExprToExpression(c *sqlparser.ComparisonExpr) (sql.Expression, error) { - left, err := exprToExpression(c.Left) +func comparisonExprToExpression(ctx *sql.Context, c *sqlparser.ComparisonExpr) (sql.Expression, error) { + left, err := exprToExpression(ctx, c.Left) if err != nil { return nil, err } - right, err := exprToExpression(c.Right) + right, err := exprToExpression(ctx, c.Right) if err != nil { return nil, err } @@ -1055,10 +1067,10 @@ func comparisonExprToExpression(c *sqlparser.ComparisonExpr) (sql.Expression, er } } -func groupByToExpressions(g sqlparser.GroupBy) ([]sql.Expression, error) { +func groupByToExpressions(ctx *sql.Context, g sqlparser.GroupBy) ([]sql.Expression, error) { es := make([]sql.Expression, len(g)) for i, ve := range g { - e, err := exprToExpression(ve) + e, err := exprToExpression(ctx, ve) if err != nil { return nil, err } @@ -1069,7 +1081,7 @@ func groupByToExpressions(g sqlparser.GroupBy) ([]sql.Expression, error) { return es, nil } -func selectExprToExpression(se sqlparser.SelectExpr) (sql.Expression, error) { +func selectExprToExpression(ctx *sql.Context, se sqlparser.SelectExpr) (sql.Expression, error) { switch e := se.(type) { default: return nil, ErrUnsupportedSyntax.New(e) @@ -1079,7 +1091,7 @@ func selectExprToExpression(se sqlparser.SelectExpr) (sql.Expression, error) { } return expression.NewQualifiedStar(e.TableName.Name.String()), nil case *sqlparser.AliasedExpr: - expr, err := exprToExpression(e.Expr) + expr, err := exprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -1093,10 +1105,10 @@ func selectExprToExpression(se sqlparser.SelectExpr) (sql.Expression, error) { } } -func unaryExprToExpression(e *sqlparser.UnaryExpr) (sql.Expression, error) { +func unaryExprToExpression(ctx *sql.Context, e *sqlparser.UnaryExpr) (sql.Expression, error) { switch e.Operator { case sqlparser.MinusStr: - expr, err := exprToExpression(e.Expr) + expr, err := exprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -1108,7 +1120,7 @@ func unaryExprToExpression(e *sqlparser.UnaryExpr) (sql.Expression, error) { } } -func binaryExprToExpression(be *sqlparser.BinaryExpr) (sql.Expression, error) { +func binaryExprToExpression(ctx *sql.Context, be *sqlparser.BinaryExpr) (sql.Expression, error) { switch be.Operator { case sqlparser.PlusStr, @@ -1123,12 +1135,12 @@ func binaryExprToExpression(be *sqlparser.BinaryExpr) (sql.Expression, error) { sqlparser.IntDivStr, sqlparser.ModStr: - l, err := exprToExpression(be.Left) + l, err := exprToExpression(ctx, be.Left) if err != nil { return nil, err } - r, err := exprToExpression(be.Right) + r, err := exprToExpression(ctx, be.Right) if err != nil { return nil, err } @@ -1150,12 +1162,12 @@ func binaryExprToExpression(be *sqlparser.BinaryExpr) (sql.Expression, error) { } } -func caseExprToExpression(e *sqlparser.CaseExpr) (sql.Expression, error) { +func caseExprToExpression(ctx *sql.Context, e *sqlparser.CaseExpr) (sql.Expression, error) { var expr sql.Expression var err error if e.Expr != nil { - expr, err = exprToExpression(e.Expr) + expr, err = exprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -1164,13 +1176,13 @@ func caseExprToExpression(e *sqlparser.CaseExpr) (sql.Expression, error) { var branches []expression.CaseBranch for _, w := range e.Whens { var cond sql.Expression - cond, err = exprToExpression(w.Cond) + cond, err = exprToExpression(ctx, w.Cond) if err != nil { return nil, err } var val sql.Expression - val, err = exprToExpression(w.Val) + val, err = exprToExpression(ctx, w.Val) if err != nil { return nil, err } @@ -1183,7 +1195,7 @@ func caseExprToExpression(e *sqlparser.CaseExpr) (sql.Expression, error) { var elseExpr sql.Expression if e.Else != nil { - elseExpr, err = exprToExpression(e.Else) + elseExpr, err = exprToExpression(ctx, e.Else) if err != nil { return nil, err } @@ -1192,8 +1204,8 @@ func caseExprToExpression(e *sqlparser.CaseExpr) (sql.Expression, error) { return expression.NewCase(expr, branches, elseExpr), nil } -func intervalExprToExpression(e *sqlparser.IntervalExpr) (sql.Expression, error) { - expr, err := exprToExpression(e.Expr) +func intervalExprToExpression(ctx *sql.Context, e *sqlparser.IntervalExpr) (sql.Expression, error) { + expr, err := exprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -1300,7 +1312,7 @@ func readString(r *bufio.Reader, single bool) []rune { return result } -func parseShowTableStatus(query string) (sql.Node, error) { +func parseShowTableStatus(ctx *sql.Context, query string) (sql.Node, error) { buf := bufio.NewReader(strings.NewReader(query)) err := parseFuncs{ expect("show"), @@ -1342,7 +1354,7 @@ func parseShowTableStatus(query string) (sql.Node, error) { return nil, err } - expr, err := parseExpr(string(bs)) + expr, err := parseExpr(ctx, string(bs)) if err != nil { return nil, err } @@ -1366,7 +1378,7 @@ func parseShowTableStatus(query string) (sql.Node, error) { } } -func parseShowCollation(query string) (sql.Node, error) { +func parseShowCollation(ctx *sql.Context, query string) (sql.Node, error) { buf := bufio.NewReader(strings.NewReader(query)) err := parseFuncs{ expect("show"), @@ -1399,7 +1411,7 @@ func parseShowCollation(query string) (sql.Node, error) { return nil, err } - expr, err := parseExpr(string(bs)) + expr, err := parseExpr(ctx, string(bs)) if err != nil { return nil, err } diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 6e74dd243..be176a9c6 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -1208,12 +1208,11 @@ func TestParse(t *testing.T) { } var fixturesErrors = map[string]*errors.Kind{ - `SHOW METHEMONEY`: ErrUnsupportedFeature, - `LOCK TABLES foo AS READ`: errUnexpectedSyntax, - `LOCK TABLES foo LOW_PRIORITY READ`: errUnexpectedSyntax, - `SELECT * FROM mytable WHERE i IN (SELECT i FROM foo)`: ErrUnsupportedSubqueryExpression, - `SELECT * FROM mytable LIMIT -100`: ErrUnsupportedSyntax, - `SELECT * FROM mytable LIMIT 100 OFFSET -1`: ErrUnsupportedSyntax, + `SHOW METHEMONEY`: ErrUnsupportedFeature, + `LOCK TABLES foo AS READ`: errUnexpectedSyntax, + `LOCK TABLES foo LOW_PRIORITY READ`: errUnexpectedSyntax, + `SELECT * FROM mytable LIMIT -100`: ErrUnsupportedSyntax, + `SELECT * FROM mytable LIMIT 100 OFFSET -1`: ErrUnsupportedSyntax, `SELECT * FROM files JOIN commit_files JOIN refs diff --git a/sql/parse/util.go b/sql/parse/util.go index a99e7a76c..bfb358b1d 100644 --- a/sql/parse/util.go +++ b/sql/parse/util.go @@ -251,7 +251,7 @@ func readRemaining(val *string) parseFunc { } } -func parseExpr(str string) (sql.Expression, error) { +func parseExpr(ctx *sql.Context, str string) (sql.Expression, error) { stmt, err := sqlparser.Parse("SELECT " + str) if err != nil { return nil, err @@ -271,7 +271,7 @@ func parseExpr(str string) (sql.Expression, error) { return nil, errInvalidIndexExpression.New(str) } - return exprToExpression(selectExpr.Expr) + return exprToExpression(ctx, selectExpr.Expr) } func readQuotableIdent(ident *string) parseFunc { diff --git a/sql/plan/common.go b/sql/plan/common.go index beec46177..28e0dd935 100644 --- a/sql/plan/common.go +++ b/sql/plan/common.go @@ -1,4 +1,4 @@ -package plan // import "github.com/src-d/go-mysql-server/sql/plan" +package plan import "github.com/src-d/go-mysql-server/sql" From 87546547b55b58dd7b31d66e5c985e8061882fd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Fri, 27 Sep 2019 15:37:11 +0200 Subject: [PATCH 04/44] Modify convertInt so it returns the smallest representation allowed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- sql/parse/parse.go | 57 +++++++++++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 14d5e3045..b03627219 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -950,22 +950,51 @@ func isAggregateFunc(v *sqlparser.FuncExpr) bool { return v.IsAggregate() } +// Convert an integer, represented by the specified string in the specified +// base, to its smallest representation possible, out of: +// int8, uint8, int16, uint16, int32, uint32, int64 and uint64 +func convertInt(value string, base int) (sql.Expression, error) { + i8, err := strconv.ParseInt(value, base, 8) + if err != nil { + ui8, err := strconv.ParseUint(value, base, 8) + if err != nil { + i16, err := strconv.ParseInt(value, base, 16) + if err != nil { + ui16, err := strconv.ParseUint(value, base, 16) + if err != nil { + i32, err := strconv.ParseInt(value, base, 32) + if err != nil { + ui32, err := strconv.ParseUint(value, base, 32) + if err != nil { + i64, err := strconv.ParseInt(value, base, 64) + if err != nil { + ui64, err := strconv.ParseUint(value, base, 64) + if err != nil { + return nil, err + } + return expression.NewLiteral(uint64(ui64), sql.Uint64), nil + } + return expression.NewLiteral(int64(i64), sql.Int64), nil + } + return expression.NewLiteral(uint32(ui32), sql.Uint32), nil + } + return expression.NewLiteral(int32(i32), sql.Int32), nil + } + return expression.NewLiteral(uint16(ui16), sql.Uint16), nil + } + return expression.NewLiteral(int16(i16), sql.Int16), nil + } + return expression.NewLiteral(uint8(ui8), sql.Uint16), nil + } + return expression.NewLiteral(int8(i8), sql.Int8), nil +} + func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { switch v.Type { case sqlparser.StrVal: return expression.NewLiteral(string(v.Val), sql.Text), nil case sqlparser.IntVal: - //TODO: Use smallest integer representation and widen later. - val, err := strconv.ParseInt(string(v.Val), 10, 64) - if err != nil { - // Might be a uint64 value that is greater than int64 max - val, checkErr := strconv.ParseUint(string(v.Val), 10, 64) - if checkErr != nil { - return nil, err - } - return expression.NewLiteral(val, sql.Uint64), nil - } - return expression.NewLiteral(val, sql.Int64), nil + return convertInt(string(v.Val), 10) case sqlparser.FloatVal: val, err := strconv.ParseFloat(string(v.Val), 64) if err != nil { @@ -980,11 +1009,7 @@ func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { v = strings.Trim(v[1:], "'") } - val, err := strconv.ParseInt(v, 16, 64) - if err != nil { - return nil, err - } - return expression.NewLiteral(val, sql.Int64), nil + return convertInt(v, 16) case sqlparser.HexVal: val, err := v.HexDecode() if err != nil { From 09e535d74ce477dc23c87fbb497ce4175e22b669 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Wed, 9 Oct 2019 12:29:56 +0200 Subject: [PATCH 05/44] Add a rule to convert integer literals in INSERT MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - The integer literals in INSERT expressions are now converted to the type of their corresponding column in the schema. - sql.IsInteger has been fixed so that it recognizes smaller types. Signed-off-by: Alejandro García Montoro --- sql/analyzer/resolve_insert_literals.go | 82 +++++++++++++++++++++++++ sql/analyzer/rules.go | 1 + sql/type.go | 4 +- 3 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 sql/analyzer/resolve_insert_literals.go diff --git a/sql/analyzer/resolve_insert_literals.go b/sql/analyzer/resolve_insert_literals.go new file mode 100644 index 000000000..108980e63 --- /dev/null +++ b/sql/analyzer/resolve_insert_literals.go @@ -0,0 +1,82 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + errors "gopkg.in/src-d/go-errors.v1" + +) + +var errWrongNumberOfValues = errors.NewKind("the number of values to insert differ from the expected columns") + +func convertIntegerLiteralsInsert(ctx *sql.Context, analyzer *Analyzer, originalNode sql.Node) (sql.Node, error) { + span, _ := ctx.Span("resolve_insert_literals") + defer span.Finish() + + return plan.TransformUp(originalNode, func(node sql.Node) (sql.Node, error) { + + if node, ok := node.(*plan.InsertInto); ok { + resolvedTable, ok := node.BinaryNode.Left.(*plan.ResolvedTable) + if !ok { + return node, nil + } + + values := node.BinaryNode.Right.(*plan.Values) + if !ok { + return node, nil + } + + analyzer.Log("Transforming integer literals in INSERT node") + + schema := resolvedTable.Table.Schema() + + // If the InsertInto node does not have any explicit columns, + // we assume the values are in the same order as in the table schema + if len(node.Columns) == 0 { + node.Columns = make([]string, len(schema)) + for i, column := range schema { + node.Columns[i] = column.Name + } + } + + // Check that all tuples contain as many values as needed + numColumns := len(node.Columns) + for _, tuple := range values.ExpressionTuples { + if len(tuple) != numColumns { + return nil, errWrongNumberOfValues.New() + } + } + + // Get the columns that should be converted: only those that are in + // node.Columns and whose corresponding type in the schema is an integer + columnsToConvert := make(map[int]sql.Type) + for _, schemaColumn := range schema { + colType := schemaColumn.Type + if sql.IsInteger(colType) { + for nodeIdx, insertColumn := range node.Columns { + if schemaColumn.Name == insertColumn { + columnsToConvert[nodeIdx] = colType + } + } + } + } + + // Replace the values in the node with the converted ones + for _, valuesTuple := range values.ExpressionTuples { + for colIdx, newType := range columnsToConvert { + oldValue := valuesTuple[colIdx].(*expression.Literal).Value() + // Do not convert nil values, Convert() may make them zero + if oldValue != nil { + newValue, err := newType.Convert(oldValue) + if err != nil { + return nil, err + } + valuesTuple[colIdx] = expression.NewLiteral(newValue, newType) + } + } + } + } + return node, nil + }) +} diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index c2b8daf00..40e62e708 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -12,6 +12,7 @@ var DefaultRules = []Rule{ {"resolve_grouping_columns", resolveGroupingColumns}, {"qualify_columns", qualifyColumns}, {"resolve_columns", resolveColumns}, + {"resolve_insert_literals", convertIntegerLiteralsInsert}, {"resolve_database", resolveDatabase}, {"resolve_star", resolveStar}, {"resolve_functions", resolveFunctions}, diff --git a/sql/type.go b/sql/type.go index caadd9254..669919838 100644 --- a/sql/type.go +++ b/sql/type.go @@ -1155,12 +1155,12 @@ func IsNumber(t Type) bool { // IsSigned checks if t is a signed type. func IsSigned(t Type) bool { - return t == Int32 || t == Int64 + return t == Int8 || t == Int16 || t == Int32 || t == Int64 } // IsUnsigned checks if t is an unsigned type. func IsUnsigned(t Type) bool { - return t == Uint64 || t == Uint32 + return t == Uint8 || t == Uint16 || t == Uint32 || t == Uint64 } // IsInteger checks if t is a (U)Int32/64 type. From 25ce62fabd27f25fac7b62cd4789564182bae35a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Fri, 11 Oct 2019 11:32:08 +0200 Subject: [PATCH 06/44] Test the new rule to convert integers in INSERT nodes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- sql/analyzer/resolve_insert_literals_test.go | 165 +++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 sql/analyzer/resolve_insert_literals_test.go diff --git a/sql/analyzer/resolve_insert_literals_test.go b/sql/analyzer/resolve_insert_literals_test.go new file mode 100644 index 000000000..c19ae1642 --- /dev/null +++ b/sql/analyzer/resolve_insert_literals_test.go @@ -0,0 +1,165 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +// Common data for most of the tests below +var ( + schema = sql.Schema{ + {Name: "i8", Type: sql.Int8, Source: "table"}, + {Name: "i16", Type: sql.Int16, Source: "table"}, + {Name: "i32", Type: sql.Int32, Source: "table"}, + {Name: "i64", Type: sql.Int64, Source: "table"}, + {Name: "ui8", Type: sql.Uint8, Source: "table"}, + {Name: "ui16", Type: sql.Uint16, Source: "table"}, + {Name: "ui32", Type: sql.Uint32, Source: "table"}, + {Name: "ui64", Type: sql.Uint64, Source: "table"}, + } + + orderedColumns = []string{"i8", "i16", "i32", "i64", "ui8", "ui16", "ui32", "ui64"} + + inputValues = [][]sql.Expression{{ + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Uint8), + expression.NewLiteral(int8(1), sql.Uint8), + expression.NewLiteral(int8(1), sql.Uint8), + expression.NewLiteral(int8(1), sql.Uint8), + }} +) + +// Test the correct conversion of integer literals in INSERT nodes when no +// columns are explicitely specified by the plan +func TestInsertLiteralsWithoutColumns(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("table", schema) + + // An INSERT node with an empty columns field: the expected columns should be + // the ones in the schema + node := plan.NewInsertInto( + plan.NewResolvedTable(table), + plan.NewValues(inputValues), + false, + []string{}, + ) + + rule := getRule("resolve_insert_literals") + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + // The expected result should have the integers converted to the types + // specified by the schema, as well as the columns field populated with the + // schema columns in order + expected := plan.NewInsertInto( + plan.NewResolvedTable(table), + plan.NewValues([][]sql.Expression{{ + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int16(1), sql.Int16), + expression.NewLiteral(int32(1), sql.Int32), + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(uint8(1), sql.Uint8), + expression.NewLiteral(uint16(1), sql.Uint16), + expression.NewLiteral(uint32(1), sql.Uint32), + expression.NewLiteral(uint64(1), sql.Uint64), + }}), + false, + orderedColumns, + ) + + require.Equal(expected, result) +} + +// Test the correct conversion of integer literals in INSERT nodes when the +// node has a explicit order of columns, different than the one in the schema +func TestInsertLiteralsWithColumns(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("table", schema) + + // First unsigned, then signed + unorderedColumns := []string{"ui8", "ui16", "ui32", "ui64", "i8", "i16", "i32", "i64"} + + // An INSERT node with an explicit columns field, unordered with respect to + // the schema + node := plan.NewInsertInto( + plan.NewResolvedTable(table), + plan.NewValues(inputValues), + false, + unorderedColumns, + ) + + rule := getRule("resolve_insert_literals") + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + // The expected result should have the integers converted to the types + // specified by the schema, in the order specified by the columns field of + // the INSERT node + expected := plan.NewInsertInto( + plan.NewResolvedTable(table), + plan.NewValues([][]sql.Expression{{ + expression.NewLiteral(uint8(1), sql.Uint8), + expression.NewLiteral(uint16(1), sql.Uint16), + expression.NewLiteral(uint32(1), sql.Uint32), + expression.NewLiteral(uint64(1), sql.Uint64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int16(1), sql.Int16), + expression.NewLiteral(int32(1), sql.Int32), + expression.NewLiteral(int64(1), sql.Int64), + }}), + false, + unorderedColumns, + ) + + require.Equal(expected, result) +} + +// Test that non-integer literals are unchanged after applying the conversion +// of integers in INSERT nodes +func TestInsertLiteralsUnchanged(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("table", sql.Schema{ + {Name: "f32", Type: sql.Float32, Source: "typestable", Nullable: true}, + {Name: "f64", Type: sql.Float64, Source: "typestable", Nullable: true}, + {Name: "time", Type: sql.Timestamp, Source: "typestable", Nullable: true}, + {Name: "date", Type: sql.Date, Source: "typestable", Nullable: true}, + {Name: "text", Type: sql.Text, Source: "typestable", Nullable: true}, + {Name: "bool", Type: sql.Boolean, Source: "typestable", Nullable: true}, + {Name: "json", Type: sql.JSON, Source: "typestable", Nullable: true}, + {Name: "blob", Type: sql.Blob, Source: "typestable", Nullable: true}, + }) + + node := plan.NewInsertInto( + plan.NewResolvedTable(table), + plan.NewValues([][]sql.Expression{{ + expression.NewLiteral(float64(1.0), sql.Float32), + expression.NewLiteral(float64(5.0), sql.Float64), + expression.NewLiteral("1234-05-06 07:08:09", sql.Timestamp), + expression.NewLiteral("1234-05-06", sql.Date), + expression.NewLiteral("there be dragons", sql.Text), + expression.NewLiteral(false, sql.Boolean), + expression.NewLiteral(`{"key":"value"}`, sql.JSON), + expression.NewLiteral("blipblop", sql.Blob), + }}), + false, + []string{"f32", "f64", "time", "date", "text", "bool", "json", "blob"}, + ) + + rule := getRule("resolve_insert_literals") + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + // The original node should be unchanged, as there are no integers + require.Equal(node, result) +} From 412960b03d16c750eb9bd3f82d3c634fd91b82c7 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Mon, 14 Oct 2019 11:33:23 +0200 Subject: [PATCH 07/44] sql/index/pilosa: better error messages Signed-off-by: Miguel Molina --- sql/index/config.go | 1 - sql/index/pilosa/driver.go | 25 +++++++++++++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/sql/index/config.go b/sql/index/config.go index dcf846216..7d5b7c9fd 100644 --- a/sql/index/config.go +++ b/sql/index/config.go @@ -24,7 +24,6 @@ func NewConfig( driverID string, driverConfig map[string]string, ) *Config { - cfg := &Config{ DB: db, Table: table, diff --git a/sql/index/pilosa/driver.go b/sql/index/pilosa/driver.go index a9b30cf04..4f9146b21 100644 --- a/sql/index/pilosa/driver.go +++ b/sql/index/pilosa/driver.go @@ -110,6 +110,8 @@ func (*Driver) ID() string { return DriverID } +var errWriteConfigFile = errors.NewKind("unable to write indexes configuration file") + // Create a new index. func (d *Driver) Create( db, table, id string, @@ -133,7 +135,7 @@ func (d *Driver) Create( cfg := index.NewConfig(db, table, id, exprs, d.ID(), config) err = index.WriteConfigFile(d.configFilePath(db, table, id), cfg) if err != nil { - return nil, err + return nil, errWriteConfigFile.Wrap(err) } idx, err := d.newPilosaIndex(db, table) @@ -146,12 +148,14 @@ func (d *Driver) Create( processingFile, []byte{processingFileOnCreate}, ); err != nil { - return nil, err + return nil, errWriteConfigFile.Wrap(err) } return newPilosaIndex(idx, cfg), nil } +var errReadIndexes = errors.NewKind("error loading all indexes for table %s of database %s: %s") + // LoadAll loads all indexes for given db and table func (d *Driver) LoadAll(db, table string) ([]sql.Index, error) { var ( @@ -165,7 +169,7 @@ func (d *Driver) LoadAll(db, table string) ([]sql.Index, error) { if os.IsNotExist(err) { return indexes, nil } - return nil, err + return nil, errReadIndexes.New(table, db, err) } for _, info := range dirs { if info.IsDir() && !strings.HasPrefix(info.Name(), ".") { @@ -187,6 +191,11 @@ func (d *Driver) LoadAll(db, table string) ([]sql.Index, error) { return indexes, nil } +var ( + errLoadingIndexConfig = errors.NewKind("unable to load index configuration") + errReadIndexConfig = errors.NewKind("unable to read index configuration") +) + func (d *Driver) loadIndex(db, table, id string) (*pilosaIndex, error) { idx, err := d.newPilosaIndex(db, table) if err != nil { @@ -206,7 +215,7 @@ func (d *Driver) loadIndex(db, table, id string) (*pilosaIndex, error) { processing := d.processingFilePath(db, table, id) ok, err := index.ExistsProcessingFile(processing) if err != nil { - return nil, err + return nil, errLoadingIndexConfig.Wrap(err) } if ok { log := logrus.WithFields(logrus.Fields{ @@ -226,7 +235,7 @@ func (d *Driver) loadIndex(db, table, id string) (*pilosaIndex, error) { cfg, err := index.ReadConfigFile(config) if err != nil { - return nil, err + return nil, errReadIndexConfig.Wrap(err) } cfgDriver := cfg.Driver(DriverID) if cfgDriver == nil { @@ -382,13 +391,13 @@ func (d *Driver) Save( []byte{processingFileOnSave}, ) if err != nil { - return err + return errWriteConfigFile.Wrap(err) } cfgPath := d.configFilePath(i.Database(), i.Table(), i.ID()) cfg, err := index.ReadConfigFile(cfgPath) if err != nil { - return err + return errReadIndexConfig.Wrap(err) } driverCfg := cfg.Driver(DriverID) @@ -462,7 +471,7 @@ func (d *Driver) Save( return errors[0] } if err = index.WriteConfigFile(cfgPath, cfg); err != nil { - return err + return errWriteConfigFile.Wrap(err) } observeIndex(time.Since(start), timePilosa, timeMapping, rows) From be23624663c952c285e06880771cc2552e650245 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Tue, 15 Oct 2019 08:50:06 +0200 Subject: [PATCH 08/44] Modify getInt64Literal to try to convert smaller representations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- sql/parse/parse.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index b03627219..3f214d5c5 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -680,8 +680,14 @@ func getInt64Literal(ctx *sql.Context, expr sqlparser.Expr, errStr string) (*exp } nl, ok := e.(*expression.Literal) - if !ok || nl.Type() != sql.Int64 { + if !ok || !sql.IsInteger(nl.Type()) { return nil, ErrUnsupportedFeature.New(errStr) + } else { + i64, err := sql.Int64.Convert(nl.Value()) + if err != nil { + return nil, ErrUnsupportedFeature.New(errStr) + } + return expression.NewLiteral(i64, sql.Int64) , nil } return nl, nil From 506c5153f86f7f34825626a2a3de840c9d57b7ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Wed, 9 Oct 2019 17:06:17 +0200 Subject: [PATCH 09/44] Convert literals to specific int64 values in GROUP BY parsing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- sql/parse/parse.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 3f214d5c5..22a5aa09e 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -755,12 +755,14 @@ func selectToProjectOrGroupBy( for i, ge := range groupingExprs { // if GROUP BY index if l, ok := ge.(*expression.Literal); ok && sql.IsNumber(l.Type()) { - if idx, ok := l.Value().(int64); ok && idx > 0 && idx <= agglen { - aggexpr := selectExprs[idx-1] - if alias, ok := aggexpr.(*expression.Alias); ok { - aggexpr = expression.NewUnresolvedColumn(alias.Name()) + if i64, err := sql.Int64.Convert(l.Value()); err == nil { + if idx, ok := i64.(int64); ok && idx > 0 && idx <= agglen { + aggexpr := selectExprs[idx-1] + if alias, ok := aggexpr.(*expression.Alias); ok { + aggexpr = expression.NewUnresolvedColumn(alias.Name()) + } + groupingExprs[i] = aggexpr } - groupingExprs[i] = aggexpr } } } From 463c5e47af67d648d8ad7e2111fcdd5554613de3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Fri, 11 Oct 2019 12:07:58 +0200 Subject: [PATCH 10/44] Modify ROUND and YEARWEEK functions to manage all integer types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- sql/expression/function/ceil_round_floor.go | 24 +++++++++++++++++++++ sql/expression/function/time.go | 6 ++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sql/expression/function/ceil_round_floor.go b/sql/expression/function/ceil_round_floor.go index 4c8c1d757..c056f82da 100644 --- a/sql/expression/function/ceil_round_floor.go +++ b/sql/expression/function/ceil_round_floor.go @@ -203,6 +203,18 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { dVal = float64(dNum) case int32: dVal = float64(dNum) + case int16: + dVal = float64(dNum) + case int8: + dVal = float64(dNum) + case uint64: + dVal = float64(dNum) + case uint32: + dVal = float64(dNum) + case uint16: + dVal = float64(dNum) + case uint8: + dVal = float64(dNum) case int: dVal = float64(dNum) default: @@ -233,6 +245,18 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return int64(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil case int32: return int32(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case int16: + return int16(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case int8: + return int8(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint64: + return uint64(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint32: + return uint32(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint16: + return uint16(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint8: + return uint8(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil case int: return int(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil default: diff --git a/sql/expression/function/time.go b/sql/expression/function/time.go index 2385aec0d..c0dcaf3d4 100644 --- a/sql/expression/function/time.go +++ b/sql/expression/function/time.go @@ -363,8 +363,10 @@ func (d *YearWeek) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } if val != nil { - if mode, ok = val.(int64); ok { - mode %= 8 // mode in [0, 7] + if i64, err := sql.Int64.Convert(val); err == nil { + if mode, ok = i64.(int64); ok { + mode %= 8 // mode in [0, 7] + } } } yyyy, week := calcWeek(yyyy, mm, dd, weekMode(mode)|weekBehaviourYear) From 244439c1819502b6b3f15219d50e8377309301eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Wed, 9 Oct 2019 17:06:56 +0200 Subject: [PATCH 11/44] Adapt all tests to new integer parsing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The tests now expect the correct types specified in the tables schema Signed-off-by: Alejandro García Montoro --- engine_test.go | 52 +++++++++--------- sql/parse/parse_test.go | 116 ++++++++++++++++++++-------------------- 2 files changed, 84 insertions(+), 84 deletions(-) diff --git a/engine_test.go b/engine_test.go index 3ce2c8f81..270f47884 100644 --- a/engine_test.go +++ b/engine_test.go @@ -582,7 +582,7 @@ var queries = []struct { { `SELECT COALESCE(NULL, NULL, NULL, COALESCE(NULL, 1234567890))`, []sql.Row{ - {int64(1234567890)}, + {int32(1234567890)}, }, }, { @@ -981,7 +981,7 @@ var queries = []struct { }, { `SELECT -1`, - []sql.Row{{int64(-1)}}, + []sql.Row{{int8(-1)}}, }, { ` @@ -1043,13 +1043,13 @@ var queries = []struct { { `SELECT nullif(123, 321)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { `SELECT ifnull(123, NULL)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { @@ -1061,19 +1061,19 @@ var queries = []struct { { `SELECT ifnull(NULL, 123)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { `SELECT ifnull(123, 123)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { `SELECT ifnull(123, 321)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { @@ -1085,7 +1085,7 @@ var queries = []struct { { `SELECT round(15, 1)`, []sql.Row{ - {int64(15)}, + {int8(15)}, }, }, { @@ -1452,7 +1452,7 @@ var queries = []struct { }, { `SELECT 1 FROM mytable GROUP BY i HAVING i > 1`, - []sql.Row{{int64(1)}, {int64(1)}}, + []sql.Row{{int8(1)}, {int8(1)}}, }, { `SELECT avg(i) FROM mytable GROUP BY i HAVING avg(i) > 1`, @@ -1887,8 +1887,8 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(math.MaxInt8), int64(math.MaxInt16), int64(math.MaxInt32), int64(math.MaxInt64), - int64(math.MaxUint8), int64(math.MaxUint16), int64(math.MaxUint32), uint64(math.MaxUint64), + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), float64(math.MaxFloat32), float64(math.MaxFloat64), timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), "random text", true, `{"key":"value"}`, "blobdata", @@ -1905,8 +1905,8 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(math.MaxInt8), int64(math.MaxInt16), int64(math.MaxInt32), int64(math.MaxInt64), - int64(math.MaxUint8), int64(math.MaxUint16), int64(math.MaxUint32), uint64(math.MaxUint64), + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), float64(math.MaxFloat32), float64(math.MaxFloat64), timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), "random text", true, `{"key":"value"}`, "blobdata", @@ -1923,8 +1923,8 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), - int64(0), int64(0), int64(0), int64(0), + int64(999), int8(-math.MaxInt8-1), int16(-math.MaxInt16-1), int32(-math.MaxInt32-1), int64(-math.MaxInt64-1), + uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), "", false, ``, "", @@ -1941,8 +1941,8 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), - int64(0), int64(0), int64(0), int64(0), + int64(999), int8(-math.MaxInt8-1), int16(-math.MaxInt16-1), int32(-math.MaxInt32-1), int64(-math.MaxInt64-1), + uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), "", false, ``, "", @@ -2087,8 +2087,8 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(math.MaxInt8), int64(math.MaxInt16), int64(math.MaxInt32), int64(math.MaxInt64), - int64(math.MaxUint8), int64(math.MaxUint16), int64(math.MaxUint32), uint64(math.MaxUint64), + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), float64(math.MaxFloat32), float64(math.MaxFloat64), timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), "random text", true, `{"key":"value"}`, "blobdata", @@ -2105,8 +2105,8 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(math.MaxInt8), int64(math.MaxInt16), int64(math.MaxInt32), int64(math.MaxInt64), - int64(math.MaxUint8), int64(math.MaxUint16), int64(math.MaxUint32), uint64(math.MaxUint64), + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), float64(math.MaxFloat32), float64(math.MaxFloat64), timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), "random text", true, `{"key":"value"}`, "blobdata", @@ -2123,8 +2123,8 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), - int64(0), int64(0), int64(0), int64(0), + int64(999), int8(-math.MaxInt8-1), int16(-math.MaxInt16-1), int32(-math.MaxInt32-1), int64(-math.MaxInt64-1), + uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), "", false, ``, "", @@ -2141,8 +2141,8 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), - int64(0), int64(0), int64(0), int64(0), + int64(999), int8(-math.MaxInt8-1), int16(-math.MaxInt16-1), int32(-math.MaxInt32-1), int64(-math.MaxInt64-1), + uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), "", false, ``, "", @@ -2947,7 +2947,7 @@ func TestSessionVariables(t *testing.T) { rows, err := sql.RowIterToRows(iter) require.NoError(err) - require.Equal([]sql.Row{{int64(1), ",STRICT_TRANS_TABLES"}}, rows) + require.Equal([]sql.Row{{int8(1), ",STRICT_TRANS_TABLES"}}, rows) } func TestSessionVariablesONOFF(t *testing.T) { diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index be176a9c6..ab5d7acfc 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -181,7 +181,7 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewEquals( expression.NewUnresolvedColumn("qux"), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), ), plan.NewUnresolvedTable("foo", ""), ), @@ -258,7 +258,7 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("t1", ""), plan.NewValues([][]sql.Expression{{ expression.NewLiteral("a", sql.Text), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), }}), false, []string{"col1", "col2"}, @@ -267,7 +267,7 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("t1", ""), plan.NewValues([][]sql.Expression{{ expression.NewLiteral("a", sql.Text), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), }}), true, []string{"col1", "col2"}, @@ -360,7 +360,7 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewEquals( expression.NewUnresolvedColumn("a"), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), ), plan.NewUnresolvedTable("foo", ""), ), @@ -432,9 +432,9 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewNot( expression.NewBetween( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(2), sql.Int64), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(2), sql.Int8), + expression.NewLiteral(int8(5), sql.Int8), ), ), plan.NewUnresolvedTable("foo", ""), @@ -444,16 +444,16 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewStar()}, plan.NewFilter( expression.NewBetween( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(2), sql.Int64), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(2), sql.Int8), + expression.NewLiteral(int8(5), sql.Int8), ), plan.NewUnresolvedTable("foo", ""), ), ), `SELECT 0x01AF`: plan.NewProject( []sql.Expression{ - expression.NewLiteral(int64(431), sql.Int64), + expression.NewLiteral(int16(431), sql.Int16), }, plan.NewUnresolvedTable("dual", ""), ), @@ -470,12 +470,12 @@ var fixtures = map[string]sql.Node{ "somefunc", false, expression.NewTuple( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(2), sql.Int8), ), expression.NewTuple( - expression.NewLiteral(int64(3), sql.Int64), - expression.NewLiteral(int64(4), sql.Int64), + expression.NewLiteral(int8(3), sql.Int8), + expression.NewLiteral(int8(4), sql.Int8), ), ), plan.NewUnresolvedTable("b", ""), @@ -486,7 +486,7 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewEquals( expression.NewLiteral(":foo_id", sql.Text), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), plan.NewUnresolvedTable("foo", ""), ), @@ -510,13 +510,13 @@ var fixtures = map[string]sql.Node{ ), `SELECT CAST(-3 AS UNSIGNED) FROM foo`: plan.NewProject( []sql.Expression{ - expression.NewConvert(expression.NewLiteral(int64(-3), sql.Int64), expression.ConvertToUnsigned), + expression.NewConvert(expression.NewLiteral(int8(-3), sql.Int8), expression.ConvertToUnsigned), }, plan.NewUnresolvedTable("foo", ""), ), `SELECT 2 = 2 FROM foo`: plan.NewProject( []sql.Expression{ - expression.NewEquals(expression.NewLiteral(int64(2), sql.Int64), expression.NewLiteral(int64(2), sql.Int64)), + expression.NewEquals(expression.NewLiteral(int8(2), sql.Int8), expression.NewLiteral(int8(2), sql.Int8)), }, plan.NewUnresolvedTable("foo", ""), ), @@ -560,10 +560,10 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewStar()}, plan.NewFilter( expression.NewIn( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), expression.NewTuple( expression.NewLiteral("1", sql.Text), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), ), plan.NewUnresolvedTable("foo", ""), @@ -573,10 +573,10 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewStar()}, plan.NewFilter( expression.NewNotIn( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), expression.NewTuple( expression.NewLiteral("1", sql.Text), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), ), plan.NewUnresolvedTable("foo", ""), @@ -585,12 +585,12 @@ var fixtures = map[string]sql.Node{ `SELECT a, b FROM t ORDER BY 2, 1`: plan.NewSort( []plan.SortField{ { - Column: expression.NewLiteral(int64(2), sql.Int64), + Column: expression.NewLiteral(int8(2), sql.Int8), Order: plan.Ascending, NullOrdering: plan.NullsFirst, }, { - Column: expression.NewLiteral(int64(1), sql.Int64), + Column: expression.NewLiteral(int8(1), sql.Int8), Order: plan.Ascending, NullOrdering: plan.NullsFirst, }, @@ -605,22 +605,22 @@ var fixtures = map[string]sql.Node{ ), `SELECT 1 + 1;`: plan.NewProject( []sql.Expression{ - expression.NewPlus(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(1), sql.Int64)), + expression.NewPlus(expression.NewLiteral(int8(1), sql.Int8), expression.NewLiteral(int8(1), sql.Int8)), }, plan.NewUnresolvedTable("dual", ""), ), `SELECT 1 * (2 + 1);`: plan.NewProject( []sql.Expression{ - expression.NewMult(expression.NewLiteral(int64(1), sql.Int64), - expression.NewPlus(expression.NewLiteral(int64(2), sql.Int64), expression.NewLiteral(int64(1), sql.Int64))), + expression.NewMult(expression.NewLiteral(int8(1), sql.Int8), + expression.NewPlus(expression.NewLiteral(int8(2), sql.Int8), expression.NewLiteral(int8(1), sql.Int8))), }, plan.NewUnresolvedTable("dual", ""), ), `SELECT (0 - 1) * (1 | 1);`: plan.NewProject( []sql.Expression{ expression.NewMult( - expression.NewMinus(expression.NewLiteral(int64(0), sql.Int64), expression.NewLiteral(int64(1), sql.Int64)), - expression.NewBitOr(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(1), sql.Int64)), + expression.NewMinus(expression.NewLiteral(int8(0), sql.Int8), expression.NewLiteral(int8(1), sql.Int8)), + expression.NewBitOr(expression.NewLiteral(int8(1), sql.Int8), expression.NewLiteral(int8(1), sql.Int8)), ), }, plan.NewUnresolvedTable("dual", ""), @@ -628,8 +628,8 @@ var fixtures = map[string]sql.Node{ `SELECT (1 << 3) % (2 div 1);`: plan.NewProject( []sql.Expression{ expression.NewMod( - expression.NewShiftLeft(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(3), sql.Int64)), - expression.NewIntDiv(expression.NewLiteral(int64(2), sql.Int64), expression.NewLiteral(int64(1), sql.Int64))), + expression.NewShiftLeft(expression.NewLiteral(int8(1), sql.Int8), expression.NewLiteral(int8(3), sql.Int8)), + expression.NewIntDiv(expression.NewLiteral(int8(2), sql.Int8), expression.NewLiteral(int8(1), sql.Int8))), }, plan.NewUnresolvedTable("dual", ""), ), @@ -645,7 +645,7 @@ var fixtures = map[string]sql.Node{ `SELECT '1.0' + 2;`: plan.NewProject( []sql.Expression{ expression.NewPlus( - expression.NewLiteral("1.0", sql.Text), expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral("1.0", sql.Text), expression.NewLiteral(int8(2), sql.Int8), ), }, plan.NewUnresolvedTable("dual", ""), @@ -714,7 +714,7 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedFunction( "max", true, expression.NewUnresolvedColumn("i"), ), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), "/", ), }, @@ -742,7 +742,7 @@ var fixtures = map[string]sql.Node{ `SET autocommit=1, foo="bar"`: plan.NewSet( plan.SetVariable{ Name: "autocommit", - Value: expression.NewLiteral(int64(1), sql.Int64), + Value: expression.NewLiteral(int8(1), sql.Int8), }, plan.SetVariable{ Name: "foo", @@ -752,7 +752,7 @@ var fixtures = map[string]sql.Node{ `SET @@session.autocommit=1, foo="bar"`: plan.NewSet( plan.SetVariable{ Name: "@@session.autocommit", - Value: expression.NewLiteral(int64(1), sql.Int64), + Value: expression.NewLiteral(int8(1), sql.Int8), }, plan.SetVariable{ Name: "foo", @@ -818,11 +818,11 @@ var fixtures = map[string]sql.Node{ `SET SESSION NET_READ_TIMEOUT= 700, SESSION NET_WRITE_TIMEOUT= 700`: plan.NewSet( plan.SetVariable{ Name: "@@session.net_read_timeout", - Value: expression.NewLiteral(int64(700), sql.Int64), + Value: expression.NewLiteral(int16(700), sql.Int16), }, plan.SetVariable{ Name: "@@session.net_write_timeout", - Value: expression.NewLiteral(int64(700), sql.Int64), + Value: expression.NewLiteral(int16(700), sql.Int16), }, ), `SET gtid_mode=DEFAULT`: plan.NewSet( @@ -975,11 +975,11 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("foo"), []expression.CaseBranch{ { - Cond: expression.NewLiteral(int64(1), sql.Int64), + Cond: expression.NewLiteral(int8(1), sql.Int8), Value: expression.NewLiteral("foo", sql.Text), }, { - Cond: expression.NewLiteral(int64(2), sql.Int64), + Cond: expression.NewLiteral(int8(2), sql.Int8), Value: expression.NewLiteral("bar", sql.Text), }, }, @@ -992,11 +992,11 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("foo"), []expression.CaseBranch{ { - Cond: expression.NewLiteral(int64(1), sql.Int64), + Cond: expression.NewLiteral(int8(1), sql.Int8), Value: expression.NewLiteral("foo", sql.Text), }, { - Cond: expression.NewLiteral(int64(2), sql.Int64), + Cond: expression.NewLiteral(int8(2), sql.Int8), Value: expression.NewLiteral("bar", sql.Text), }, }, @@ -1011,14 +1011,14 @@ var fixtures = map[string]sql.Node{ { Cond: expression.NewEquals( expression.NewUnresolvedColumn("foo"), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), ), Value: expression.NewLiteral("foo", sql.Text), }, { Cond: expression.NewEquals( expression.NewUnresolvedColumn("foo"), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), Value: expression.NewLiteral("bar", sql.Text), }, @@ -1055,7 +1055,7 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewArithmetic( expression.NewLiteral("2018-05-01", sql.Text), expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), "+", @@ -1066,7 +1066,7 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewArithmetic( expression.NewLiteral("2018-05-01", sql.Text), expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), "-", @@ -1076,7 +1076,7 @@ var fixtures = map[string]sql.Node{ `SELECT INTERVAL 1 DAY + '2018-05-01'`: plan.NewProject( []sql.Expression{expression.NewArithmetic( expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), expression.NewLiteral("2018-05-01", sql.Text), @@ -1089,13 +1089,13 @@ var fixtures = map[string]sql.Node{ expression.NewArithmetic( expression.NewLiteral("2018-05-01", sql.Text), expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), "+", ), expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), "+", @@ -1105,7 +1105,7 @@ var fixtures = map[string]sql.Node{ `SELECT COUNT(*) FROM foo GROUP BY a HAVING COUNT(*) > 5`: plan.NewHaving( expression.NewGreaterThan( expression.NewUnresolvedFunction("count", true, expression.NewStar()), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(5), sql.Int8), ), plan.NewGroupBy( []sql.Expression{expression.NewUnresolvedFunction("count", true, expression.NewStar())}, @@ -1117,7 +1117,7 @@ var fixtures = map[string]sql.Node{ plan.NewHaving( expression.NewGreaterThan( expression.NewUnresolvedFunction("count", true, expression.NewStar()), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(5), sql.Int8), ), plan.NewGroupBy( []sql.Expression{expression.NewUnresolvedFunction("count", true, expression.NewStar())}, @@ -1132,8 +1132,8 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewUnresolvedTable("bar", ""), expression.NewEquals( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), ), ), ), @@ -1143,8 +1143,8 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewUnresolvedTable("bar", ""), expression.NewEquals( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), ), ), ), @@ -1154,8 +1154,8 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewUnresolvedTable("bar", ""), expression.NewEquals( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), ), ), ), @@ -1165,8 +1165,8 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewUnresolvedTable("bar", ""), expression.NewEquals( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), ), ), ), From 5616382938efab2fcb3da84776067b05396e1400 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Thu, 10 Oct 2019 09:42:15 +0200 Subject: [PATCH 12/44] Add missing cases to pilosa decodeGob and compare functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- sql/index/pilosa/lookup.go | 76 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/sql/index/pilosa/lookup.go b/sql/index/pilosa/lookup.go index e50aeeabd..29173921b 100644 --- a/sql/index/pilosa/lookup.go +++ b/sql/index/pilosa/lookup.go @@ -621,6 +621,14 @@ func decodeGob(k []byte, value interface{}) (interface{}, error) { var v string err := decoder.Decode(&v) return v, err + case int8: + var v int8 + err := decoder.Decode(&v) + return v, err + case int16: + var v int16 + err := decoder.Decode(&v) + return v, err case int32: var v int32 err := decoder.Decode(&v) @@ -629,6 +637,14 @@ func decodeGob(k []byte, value interface{}) (interface{}, error) { var v int64 err := decoder.Decode(&v) return v, err + case uint8: + var v uint8 + err := decoder.Decode(&v) + return v, err + case uint16: + var v uint16 + err := decoder.Decode(&v) + return v, err case uint32: var v uint32 err := decoder.Decode(&v) @@ -688,6 +704,36 @@ func compare(a, b interface{}) (int, error) { } return strings.Compare(a, v), nil + case int8: + v, ok := b.(int8) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case int16: + v, ok := b.(int16) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil case int32: v, ok := b.(int32) if !ok { @@ -717,6 +763,36 @@ func compare(a, b interface{}) (int, error) { return -1, nil } + return 1, nil + case uint8: + v, ok := b.(uint8) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case uint16: + v, ok := b.(uint16) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + return 1, nil case uint32: v, ok := b.(uint32) From 95124bbb0a142a00f3ac7a41248595b650ca03b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Thu, 10 Oct 2019 15:41:48 +0200 Subject: [PATCH 13/44] Add smaller integer types to the conversion to sqltypes.Value MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- sql/type.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/type.go b/sql/type.go index 669919838..fa50a19d5 100644 --- a/sql/type.go +++ b/sql/type.go @@ -335,10 +335,18 @@ func (t numberT) SQL(v interface{}) (sqltypes.Value, error) { } switch t.t { + case sqltypes.Int8: + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil + case sqltypes.Int16: + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil case sqltypes.Int32: return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil case sqltypes.Int64: return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil + case sqltypes.Uint8: + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil + case sqltypes.Uint16: + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil case sqltypes.Uint32: return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil case sqltypes.Uint64: From 2395078170ccd6e1a3417e4115f3296469cce5e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Fri, 11 Oct 2019 12:50:06 +0200 Subject: [PATCH 14/44] Fix formatting of convertInt. HT @erizocosmico MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- sql/parse/parse.go | 57 +++++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 22a5aa09e..0e17206a0 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -962,39 +962,34 @@ func isAggregateFunc(v *sqlparser.FuncExpr) bool { // base, to its smallest representation possible, out of: // int8, uint8, int16, uint16, int32, uint32, int64 and uint64 func convertInt(value string, base int) (sql.Expression, error) { - i8, err := strconv.ParseInt(value, base, 8) + if i8, err := strconv.ParseInt(value, base, 8); err == nil { + return expression.NewLiteral(int8(i8), sql.Int8), nil + } + if ui8, err := strconv.ParseUint(value, base, 8); err == nil { + return expression.NewLiteral(uint8(ui8), sql.Uint8), nil + } + if i16, err := strconv.ParseInt(value, base, 16); err == nil { + return expression.NewLiteral(int16(i16), sql.Int16), nil + } + if ui16, err := strconv.ParseUint(value, base, 16); err == nil { + return expression.NewLiteral(uint16(ui16), sql.Uint16), nil + } + if i32, err := strconv.ParseInt(value, base, 32); err == nil { + return expression.NewLiteral(int32(i32), sql.Int32), nil + } + if ui32, err := strconv.ParseUint(value, base, 32); err == nil { + return expression.NewLiteral(uint32(ui32), sql.Uint32), nil + } + if i64, err := strconv.ParseInt(value, base, 64); err == nil { + return expression.NewLiteral(int64(i64), sql.Int64), nil + } + + ui64, err := strconv.ParseUint(value, base, 64); if err != nil { - ui8, err := strconv.ParseUint(value, base, 8) - if err != nil { - i16, err := strconv.ParseInt(value, base, 16) - if err != nil { - ui16, err := strconv.ParseUint(value, base, 16) - if err != nil { - i32, err := strconv.ParseInt(value, base, 32) - if err != nil { - ui32, err := strconv.ParseUint(value, base, 32) - if err != nil { - i64, err := strconv.ParseInt(value, base, 64) - if err != nil { - ui64, err := strconv.ParseUint(value, base, 64) - if err != nil { - return nil, err - } - return expression.NewLiteral(uint64(ui64), sql.Uint64), nil - } - return expression.NewLiteral(int64(i64), sql.Int64), nil - } - return expression.NewLiteral(uint32(ui32), sql.Uint32), nil - } - return expression.NewLiteral(int32(i32), sql.Int32), nil - } - return expression.NewLiteral(uint16(ui16), sql.Uint16), nil - } - return expression.NewLiteral(int16(i16), sql.Int16), nil - } - return expression.NewLiteral(uint8(ui8), sql.Uint16), nil + return nil, err } - return expression.NewLiteral(int8(i8), sql.Int8), nil + + return expression.NewLiteral(uint64(ui64), sql.Uint64), nil } func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { From daf82c59b6a66655deb0bca7db835f59de5b2bde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Fri, 11 Oct 2019 14:17:56 +0200 Subject: [PATCH 15/44] Run gofmt over the whole codebase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- engine_pilosa_test.go | 1 - engine_test.go | 8 +- internal/sockstate/netstat_linux.go | 4 +- sql/analyzer/resolve_insert_literals.go | 124 ++++++++++++------------ sql/parse/parse.go | 4 +- sql/plan/insert.go | 9 +- sql/type.go | 1 - 7 files changed, 72 insertions(+), 79 deletions(-) diff --git a/engine_pilosa_test.go b/engine_pilosa_test.go index 00380a1b9..a6da4c5a9 100644 --- a/engine_pilosa_test.go +++ b/engine_pilosa_test.go @@ -207,4 +207,3 @@ func TestCreateIndex(t *testing.T) { require.NoError(os.RemoveAll(tmpDir)) }() } - diff --git a/engine_test.go b/engine_test.go index 270f47884..1ba6ef32a 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1923,7 +1923,7 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int8(-math.MaxInt8-1), int16(-math.MaxInt16-1), int32(-math.MaxInt32-1), int64(-math.MaxInt64-1), + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), @@ -1941,7 +1941,7 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int8(-math.MaxInt8-1), int16(-math.MaxInt16-1), int32(-math.MaxInt32-1), int64(-math.MaxInt64-1), + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), @@ -2123,7 +2123,7 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int8(-math.MaxInt8-1), int16(-math.MaxInt16-1), int32(-math.MaxInt32-1), int64(-math.MaxInt64-1), + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), @@ -2141,7 +2141,7 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int8(-math.MaxInt8-1), int16(-math.MaxInt16-1), int32(-math.MaxInt32-1), int64(-math.MaxInt64-1), + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), diff --git a/internal/sockstate/netstat_linux.go b/internal/sockstate/netstat_linux.go index 435c7deaa..a7eb6ff62 100644 --- a/internal/sockstate/netstat_linux.go +++ b/internal/sockstate/netstat_linux.go @@ -22,8 +22,8 @@ import ( const ( pathTCP4Tab = "/proc/net/tcp" pathTCP6Tab = "/proc/net/tcp6" - ipv4StrLen = 8 - ipv6StrLen = 32 + ipv4StrLen = 8 + ipv6StrLen = 32 ) type procFd struct { diff --git a/sql/analyzer/resolve_insert_literals.go b/sql/analyzer/resolve_insert_literals.go index 108980e63..f9c641c5f 100644 --- a/sql/analyzer/resolve_insert_literals.go +++ b/sql/analyzer/resolve_insert_literals.go @@ -2,81 +2,79 @@ package analyzer import ( "github.com/src-d/go-mysql-server/sql" - "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" - errors "gopkg.in/src-d/go-errors.v1" - + errors "gopkg.in/src-d/go-errors.v1" ) var errWrongNumberOfValues = errors.NewKind("the number of values to insert differ from the expected columns") func convertIntegerLiteralsInsert(ctx *sql.Context, analyzer *Analyzer, originalNode sql.Node) (sql.Node, error) { - span, _ := ctx.Span("resolve_insert_literals") - defer span.Finish() - - return plan.TransformUp(originalNode, func(node sql.Node) (sql.Node, error) { + span, _ := ctx.Span("resolve_insert_literals") + defer span.Finish() - if node, ok := node.(*plan.InsertInto); ok { - resolvedTable, ok := node.BinaryNode.Left.(*plan.ResolvedTable) - if !ok { - return node, nil - } + return plan.TransformUp(originalNode, func(node sql.Node) (sql.Node, error) { + if node, ok := node.(*plan.InsertInto); ok { + resolvedTable, ok := node.BinaryNode.Left.(*plan.ResolvedTable) + if !ok { + return node, nil + } - values := node.BinaryNode.Right.(*plan.Values) - if !ok { - return node, nil - } + values := node.BinaryNode.Right.(*plan.Values) + if !ok { + return node, nil + } - analyzer.Log("Transforming integer literals in INSERT node") + analyzer.Log("Transforming integer literals in INSERT node") - schema := resolvedTable.Table.Schema() + schema := resolvedTable.Table.Schema() - // If the InsertInto node does not have any explicit columns, - // we assume the values are in the same order as in the table schema - if len(node.Columns) == 0 { - node.Columns = make([]string, len(schema)) - for i, column := range schema { - node.Columns[i] = column.Name - } - } + // If the InsertInto node does not have any explicit columns, + // we assume the values are in the same order as in the table schema + if len(node.Columns) == 0 { + node.Columns = make([]string, len(schema)) + for i, column := range schema { + node.Columns[i] = column.Name + } + } - // Check that all tuples contain as many values as needed - numColumns := len(node.Columns) - for _, tuple := range values.ExpressionTuples { - if len(tuple) != numColumns { - return nil, errWrongNumberOfValues.New() - } - } + // Check that all tuples contain as many values as needed + numColumns := len(node.Columns) + for _, tuple := range values.ExpressionTuples { + if len(tuple) != numColumns { + return nil, errWrongNumberOfValues.New() + } + } - // Get the columns that should be converted: only those that are in - // node.Columns and whose corresponding type in the schema is an integer - columnsToConvert := make(map[int]sql.Type) - for _, schemaColumn := range schema { - colType := schemaColumn.Type - if sql.IsInteger(colType) { - for nodeIdx, insertColumn := range node.Columns { - if schemaColumn.Name == insertColumn { - columnsToConvert[nodeIdx] = colType - } - } - } - } + // Get the columns that should be converted: only those that are in + // node.Columns and whose corresponding type in the schema is an integer + columnsToConvert := make(map[int]sql.Type) + for _, schemaColumn := range schema { + colType := schemaColumn.Type + if sql.IsInteger(colType) { + for nodeIdx, insertColumn := range node.Columns { + if schemaColumn.Name == insertColumn { + columnsToConvert[nodeIdx] = colType + } + } + } + } - // Replace the values in the node with the converted ones - for _, valuesTuple := range values.ExpressionTuples { - for colIdx, newType := range columnsToConvert { - oldValue := valuesTuple[colIdx].(*expression.Literal).Value() - // Do not convert nil values, Convert() may make them zero - if oldValue != nil { - newValue, err := newType.Convert(oldValue) - if err != nil { - return nil, err - } - valuesTuple[colIdx] = expression.NewLiteral(newValue, newType) - } - } - } - } - return node, nil - }) + // Replace the values in the node with the converted ones + for _, valuesTuple := range values.ExpressionTuples { + for colIdx, newType := range columnsToConvert { + oldValue := valuesTuple[colIdx].(*expression.Literal).Value() + // Do not convert nil values, Convert() may make them zero + if oldValue != nil { + newValue, err := newType.Convert(oldValue) + if err != nil { + return nil, err + } + valuesTuple[colIdx] = expression.NewLiteral(newValue, newType) + } + } + } + } + return node, nil + }) } diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 0e17206a0..e77819f2a 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -687,7 +687,7 @@ func getInt64Literal(ctx *sql.Context, expr sqlparser.Expr, errStr string) (*exp if err != nil { return nil, ErrUnsupportedFeature.New(errStr) } - return expression.NewLiteral(i64, sql.Int64) , nil + return expression.NewLiteral(i64, sql.Int64), nil } return nl, nil @@ -984,7 +984,7 @@ func convertInt(value string, base int) (sql.Expression, error) { return expression.NewLiteral(int64(i64), sql.Int64), nil } - ui64, err := strconv.ParseUint(value, base, 64); + ui64, err := strconv.ParseUint(value, base, 64) if err != nil { return nil, err } diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 807191f4b..46f6f4df0 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -11,15 +11,12 @@ import ( // ErrInsertIntoNotSupported is thrown when a table doesn't support inserts var ErrInsertIntoNotSupported = errors.NewKind("table doesn't support INSERT INTO") var ErrReplaceIntoNotSupported = errors.NewKind("table doesn't support REPLACE INTO") -var ErrInsertIntoMismatchValueCount = - errors.NewKind("number of values does not match number of columns provided") +var ErrInsertIntoMismatchValueCount = errors.NewKind("number of values does not match number of columns provided") var ErrInsertIntoUnsupportedValues = errors.NewKind("%T is unsupported for inserts") var ErrInsertIntoDuplicateColumn = errors.NewKind("duplicate column name %v") var ErrInsertIntoNonexistentColumn = errors.NewKind("invalid column name %v") -var ErrInsertIntoNonNullableDefaultNullColumn = - errors.NewKind("column name '%v' is non-nullable but attempted to set default value of null") -var ErrInsertIntoNonNullableProvidedNull = - errors.NewKind("column name '%v' is non-nullable but attempted to set a value of null") +var ErrInsertIntoNonNullableDefaultNullColumn = errors.NewKind("column name '%v' is non-nullable but attempted to set default value of null") +var ErrInsertIntoNonNullableProvidedNull = errors.NewKind("column name '%v' is non-nullable but attempted to set a value of null") // InsertInto is a node describing the insertion into some table. type InsertInto struct { diff --git a/sql/type.go b/sql/type.go index fa50a19d5..94e825328 100644 --- a/sql/type.go +++ b/sql/type.go @@ -743,7 +743,6 @@ func (t charT) Compare(a interface{}, b interface{}) (int, error) { return strings.Compare(a.(string), b.(string)), nil } - type varCharT struct { length int } From f38b019f5cadae0a32fe12fe997cab629501ab80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Fri, 11 Oct 2019 16:26:14 +0200 Subject: [PATCH 16/44] Move integer conversion from ad-hoc rule to Insert.Execute MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- sql/analyzer/resolve_insert_literals.go | 80 --------- sql/analyzer/resolve_insert_literals_test.go | 165 ------------------- sql/analyzer/rules.go | 1 - sql/plan/insert.go | 14 ++ 4 files changed, 14 insertions(+), 246 deletions(-) delete mode 100644 sql/analyzer/resolve_insert_literals.go delete mode 100644 sql/analyzer/resolve_insert_literals_test.go diff --git a/sql/analyzer/resolve_insert_literals.go b/sql/analyzer/resolve_insert_literals.go deleted file mode 100644 index f9c641c5f..000000000 --- a/sql/analyzer/resolve_insert_literals.go +++ /dev/null @@ -1,80 +0,0 @@ -package analyzer - -import ( - "github.com/src-d/go-mysql-server/sql" - "github.com/src-d/go-mysql-server/sql/expression" - "github.com/src-d/go-mysql-server/sql/plan" - errors "gopkg.in/src-d/go-errors.v1" -) - -var errWrongNumberOfValues = errors.NewKind("the number of values to insert differ from the expected columns") - -func convertIntegerLiteralsInsert(ctx *sql.Context, analyzer *Analyzer, originalNode sql.Node) (sql.Node, error) { - span, _ := ctx.Span("resolve_insert_literals") - defer span.Finish() - - return plan.TransformUp(originalNode, func(node sql.Node) (sql.Node, error) { - if node, ok := node.(*plan.InsertInto); ok { - resolvedTable, ok := node.BinaryNode.Left.(*plan.ResolvedTable) - if !ok { - return node, nil - } - - values := node.BinaryNode.Right.(*plan.Values) - if !ok { - return node, nil - } - - analyzer.Log("Transforming integer literals in INSERT node") - - schema := resolvedTable.Table.Schema() - - // If the InsertInto node does not have any explicit columns, - // we assume the values are in the same order as in the table schema - if len(node.Columns) == 0 { - node.Columns = make([]string, len(schema)) - for i, column := range schema { - node.Columns[i] = column.Name - } - } - - // Check that all tuples contain as many values as needed - numColumns := len(node.Columns) - for _, tuple := range values.ExpressionTuples { - if len(tuple) != numColumns { - return nil, errWrongNumberOfValues.New() - } - } - - // Get the columns that should be converted: only those that are in - // node.Columns and whose corresponding type in the schema is an integer - columnsToConvert := make(map[int]sql.Type) - for _, schemaColumn := range schema { - colType := schemaColumn.Type - if sql.IsInteger(colType) { - for nodeIdx, insertColumn := range node.Columns { - if schemaColumn.Name == insertColumn { - columnsToConvert[nodeIdx] = colType - } - } - } - } - - // Replace the values in the node with the converted ones - for _, valuesTuple := range values.ExpressionTuples { - for colIdx, newType := range columnsToConvert { - oldValue := valuesTuple[colIdx].(*expression.Literal).Value() - // Do not convert nil values, Convert() may make them zero - if oldValue != nil { - newValue, err := newType.Convert(oldValue) - if err != nil { - return nil, err - } - valuesTuple[colIdx] = expression.NewLiteral(newValue, newType) - } - } - } - } - return node, nil - }) -} diff --git a/sql/analyzer/resolve_insert_literals_test.go b/sql/analyzer/resolve_insert_literals_test.go deleted file mode 100644 index c19ae1642..000000000 --- a/sql/analyzer/resolve_insert_literals_test.go +++ /dev/null @@ -1,165 +0,0 @@ -package analyzer - -import ( - "testing" - - "github.com/src-d/go-mysql-server/memory" - "github.com/src-d/go-mysql-server/sql" - "github.com/src-d/go-mysql-server/sql/expression" - "github.com/src-d/go-mysql-server/sql/plan" - "github.com/stretchr/testify/require" -) - -// Common data for most of the tests below -var ( - schema = sql.Schema{ - {Name: "i8", Type: sql.Int8, Source: "table"}, - {Name: "i16", Type: sql.Int16, Source: "table"}, - {Name: "i32", Type: sql.Int32, Source: "table"}, - {Name: "i64", Type: sql.Int64, Source: "table"}, - {Name: "ui8", Type: sql.Uint8, Source: "table"}, - {Name: "ui16", Type: sql.Uint16, Source: "table"}, - {Name: "ui32", Type: sql.Uint32, Source: "table"}, - {Name: "ui64", Type: sql.Uint64, Source: "table"}, - } - - orderedColumns = []string{"i8", "i16", "i32", "i64", "ui8", "ui16", "ui32", "ui64"} - - inputValues = [][]sql.Expression{{ - expression.NewLiteral(int8(1), sql.Int8), - expression.NewLiteral(int8(1), sql.Int8), - expression.NewLiteral(int8(1), sql.Int8), - expression.NewLiteral(int8(1), sql.Int8), - expression.NewLiteral(int8(1), sql.Uint8), - expression.NewLiteral(int8(1), sql.Uint8), - expression.NewLiteral(int8(1), sql.Uint8), - expression.NewLiteral(int8(1), sql.Uint8), - }} -) - -// Test the correct conversion of integer literals in INSERT nodes when no -// columns are explicitely specified by the plan -func TestInsertLiteralsWithoutColumns(t *testing.T) { - require := require.New(t) - - table := memory.NewTable("table", schema) - - // An INSERT node with an empty columns field: the expected columns should be - // the ones in the schema - node := plan.NewInsertInto( - plan.NewResolvedTable(table), - plan.NewValues(inputValues), - false, - []string{}, - ) - - rule := getRule("resolve_insert_literals") - result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) - require.NoError(err) - - // The expected result should have the integers converted to the types - // specified by the schema, as well as the columns field populated with the - // schema columns in order - expected := plan.NewInsertInto( - plan.NewResolvedTable(table), - plan.NewValues([][]sql.Expression{{ - expression.NewLiteral(int8(1), sql.Int8), - expression.NewLiteral(int16(1), sql.Int16), - expression.NewLiteral(int32(1), sql.Int32), - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(uint8(1), sql.Uint8), - expression.NewLiteral(uint16(1), sql.Uint16), - expression.NewLiteral(uint32(1), sql.Uint32), - expression.NewLiteral(uint64(1), sql.Uint64), - }}), - false, - orderedColumns, - ) - - require.Equal(expected, result) -} - -// Test the correct conversion of integer literals in INSERT nodes when the -// node has a explicit order of columns, different than the one in the schema -func TestInsertLiteralsWithColumns(t *testing.T) { - require := require.New(t) - - table := memory.NewTable("table", schema) - - // First unsigned, then signed - unorderedColumns := []string{"ui8", "ui16", "ui32", "ui64", "i8", "i16", "i32", "i64"} - - // An INSERT node with an explicit columns field, unordered with respect to - // the schema - node := plan.NewInsertInto( - plan.NewResolvedTable(table), - plan.NewValues(inputValues), - false, - unorderedColumns, - ) - - rule := getRule("resolve_insert_literals") - result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) - require.NoError(err) - - // The expected result should have the integers converted to the types - // specified by the schema, in the order specified by the columns field of - // the INSERT node - expected := plan.NewInsertInto( - plan.NewResolvedTable(table), - plan.NewValues([][]sql.Expression{{ - expression.NewLiteral(uint8(1), sql.Uint8), - expression.NewLiteral(uint16(1), sql.Uint16), - expression.NewLiteral(uint32(1), sql.Uint32), - expression.NewLiteral(uint64(1), sql.Uint64), - expression.NewLiteral(int8(1), sql.Int8), - expression.NewLiteral(int16(1), sql.Int16), - expression.NewLiteral(int32(1), sql.Int32), - expression.NewLiteral(int64(1), sql.Int64), - }}), - false, - unorderedColumns, - ) - - require.Equal(expected, result) -} - -// Test that non-integer literals are unchanged after applying the conversion -// of integers in INSERT nodes -func TestInsertLiteralsUnchanged(t *testing.T) { - require := require.New(t) - - table := memory.NewTable("table", sql.Schema{ - {Name: "f32", Type: sql.Float32, Source: "typestable", Nullable: true}, - {Name: "f64", Type: sql.Float64, Source: "typestable", Nullable: true}, - {Name: "time", Type: sql.Timestamp, Source: "typestable", Nullable: true}, - {Name: "date", Type: sql.Date, Source: "typestable", Nullable: true}, - {Name: "text", Type: sql.Text, Source: "typestable", Nullable: true}, - {Name: "bool", Type: sql.Boolean, Source: "typestable", Nullable: true}, - {Name: "json", Type: sql.JSON, Source: "typestable", Nullable: true}, - {Name: "blob", Type: sql.Blob, Source: "typestable", Nullable: true}, - }) - - node := plan.NewInsertInto( - plan.NewResolvedTable(table), - plan.NewValues([][]sql.Expression{{ - expression.NewLiteral(float64(1.0), sql.Float32), - expression.NewLiteral(float64(5.0), sql.Float64), - expression.NewLiteral("1234-05-06 07:08:09", sql.Timestamp), - expression.NewLiteral("1234-05-06", sql.Date), - expression.NewLiteral("there be dragons", sql.Text), - expression.NewLiteral(false, sql.Boolean), - expression.NewLiteral(`{"key":"value"}`, sql.JSON), - expression.NewLiteral("blipblop", sql.Blob), - }}), - false, - []string{"f32", "f64", "time", "date", "text", "bool", "json", "blob"}, - ) - - rule := getRule("resolve_insert_literals") - result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) - require.NoError(err) - - // The original node should be unchanged, as there are no integers - require.Equal(node, result) -} diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index 40e62e708..c2b8daf00 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -12,7 +12,6 @@ var DefaultRules = []Rule{ {"resolve_grouping_columns", resolveGroupingColumns}, {"qualify_columns", qualifyColumns}, {"resolve_columns", resolveColumns}, - {"resolve_insert_literals", convertIntegerLiteralsInsert}, {"resolve_database", resolveDatabase}, {"resolve_star", resolveStar}, {"resolve_functions", resolveFunctions}, diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 46f6f4df0..06e638d26 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -145,6 +145,20 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) { return i, err } + // Convert integer values in row to specified type in schema + for colIdx, oldValue := range row { + dstColType := projExprs[colIdx].Type() + + if sql.IsInteger(dstColType) && oldValue != nil { + newValue, err := dstColType.Convert(oldValue) + if err != nil { + return i, err + } + + row[colIdx] = newValue + } + } + if replaceable != nil { if err = replaceable.Delete(ctx, row); err != nil { if err != sql.ErrDeleteRowNotFound { From dcf26123d88456861d72361a3f9fbf7f4c2a0211 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Mon, 14 Oct 2019 10:30:17 +0200 Subject: [PATCH 17/44] Test missing cases in ROUND function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- .../function/ceil_round_floor_test.go | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/sql/expression/function/ceil_round_floor_test.go b/sql/expression/function/ceil_round_floor_test.go index 2ad014c42..4af2456ef 100644 --- a/sql/expression/function/ceil_round_floor_test.go +++ b/sql/expression/function/ceil_round_floor_test.go @@ -156,6 +156,50 @@ func TestRound(t *testing.T) { {"int32 with float d", sql.Int32, sql.Float64, sql.NewRow(int32(5), float32(2.123)), int32(5), nil}, {"int32 with float negative d", sql.Int32, sql.Float64, sql.NewRow(int32(52), float32(-1)), int32(50), nil}, {"int32 with blob d", sql.Int32, sql.Blob, sql.NewRow(int32(5), []byte{1, 2, 3}), int32(5), nil}, + {"int16 is nil", sql.Int16, sql.Int16, sql.NewRow(nil, nil), nil, nil}, + {"int16 without d", sql.Int16, sql.Int16, sql.NewRow(int16(5), nil), int16(5), nil}, + {"int16 with d", sql.Int16, sql.Int16, sql.NewRow(int16(5), 2), int16(5), nil}, + {"int16 with negative d", sql.Int16, sql.Int16, sql.NewRow(int16(52), -1), int16(50), nil}, + {"int16 with float d", sql.Int16, sql.Float64, sql.NewRow(int16(5), float32(2.123)), int16(5), nil}, + {"int16 with float negative d", sql.Int16, sql.Float64, sql.NewRow(int16(52), float32(-1)), int16(50), nil}, + {"int16 with blob d", sql.Int16, sql.Blob, sql.NewRow(int16(5), []byte{1, 2, 3}), int16(5), nil}, + {"int8 is nil", sql.Int8, sql.Int8, sql.NewRow(nil, nil), nil, nil}, + {"int8 without d", sql.Int8, sql.Int8, sql.NewRow(int8(5), nil), int8(5), nil}, + {"int8 with d", sql.Int8, sql.Int8, sql.NewRow(int8(5), 2), int8(5), nil}, + {"int8 with negative d", sql.Int8, sql.Int8, sql.NewRow(int8(52), -1), int8(50), nil}, + {"int8 with float d", sql.Int8, sql.Float64, sql.NewRow(int8(5), float32(2.123)), int8(5), nil}, + {"int8 with float negative d", sql.Int8, sql.Float64, sql.NewRow(int8(52), float32(-1)), int8(50), nil}, + {"int8 with blob d", sql.Int8, sql.Blob, sql.NewRow(int8(5), []byte{1, 2, 3}), int8(5), nil}, + {"uint64 is nil", sql.Uint64, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"uint64 without d", sql.Uint64, sql.Int32, sql.NewRow(uint64(5), nil), uint64(5), nil}, + {"uint64 with d", sql.Uint64, sql.Int32, sql.NewRow(uint64(5), 2), uint64(5), nil}, + {"uint64 with negative d", sql.Uint64, sql.Int32, sql.NewRow(uint64(52), -1), uint64(50), nil}, + {"uint64 with float d", sql.Uint64, sql.Float64, sql.NewRow(uint64(5), float32(2.123)), uint64(5), nil}, + {"uint64 with float negative d", sql.Uint64, sql.Float64, sql.NewRow(uint64(52), float32(-1)), uint64(50), nil}, + {"uint32 with blob d", sql.Uint32, sql.Blob, sql.NewRow(uint32(5), []byte{1, 2, 3}), uint32(5), nil}, + {"uint32 is nil", sql.Uint32, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"uint32 without d", sql.Uint32, sql.Int32, sql.NewRow(uint32(5), nil), uint32(5), nil}, + {"uint32 with d", sql.Uint32, sql.Int32, sql.NewRow(uint32(5), 2), uint32(5), nil}, + {"uint32 with negative d", sql.Uint32, sql.Int32, sql.NewRow(uint32(52), -1), uint32(50), nil}, + {"uint32 with float d", sql.Uint32, sql.Float64, sql.NewRow(uint32(5), float32(2.123)), uint32(5), nil}, + {"uint32 with float negative d", sql.Uint32, sql.Float64, sql.NewRow(uint32(52), float32(-1)), uint32(50), nil}, + {"uint32 with blob d", sql.Uint32, sql.Blob, sql.NewRow(uint32(5), []byte{1, 2, 3}), uint32(5), nil}, + {"uint16 with blob d", sql.Uint16, sql.Blob, sql.NewRow(uint16(5), []byte{1, 2, 3}), uint16(5), nil}, + {"uint16 is nil", sql.Uint16, sql.Int16, sql.NewRow(nil, nil), nil, nil}, + {"uint16 without d", sql.Uint16, sql.Int16, sql.NewRow(uint16(5), nil), uint16(5), nil}, + {"uint16 with d", sql.Uint16, sql.Int16, sql.NewRow(uint16(5), 2), uint16(5), nil}, + {"uint16 with negative d", sql.Uint16, sql.Int16, sql.NewRow(uint16(52), -1), uint16(50), nil}, + {"uint16 with float d", sql.Uint16, sql.Float64, sql.NewRow(uint16(5), float32(2.123)), uint16(5), nil}, + {"uint16 with float negative d", sql.Uint16, sql.Float64, sql.NewRow(uint16(52), float32(-1)), uint16(50), nil}, + {"uint16 with blob d", sql.Uint16, sql.Blob, sql.NewRow(uint16(5), []byte{1, 2, 3}), uint16(5), nil}, + {"uint8 with blob d", sql.Uint8, sql.Blob, sql.NewRow(uint8(5), []byte{1, 2, 3}), uint8(5), nil}, + {"uint8 is nil", sql.Uint8, sql.Int8, sql.NewRow(nil, nil), nil, nil}, + {"uint8 without d", sql.Uint8, sql.Int8, sql.NewRow(uint8(5), nil), uint8(5), nil}, + {"uint8 with d", sql.Uint8, sql.Int8, sql.NewRow(uint8(5), 2), uint8(5), nil}, + {"uint8 with negative d", sql.Uint8, sql.Int8, sql.NewRow(uint8(52), -1), uint8(50), nil}, + {"uint8 with float d", sql.Uint8, sql.Float64, sql.NewRow(uint8(5), float32(2.123)), uint8(5), nil}, + {"uint8 with float negative d", sql.Uint8, sql.Float64, sql.NewRow(uint8(52), float32(-1)), uint8(50), nil}, + {"uint8 with blob d", sql.Uint8, sql.Blob, sql.NewRow(uint8(5), []byte{1, 2, 3}), uint8(5), nil}, {"blob is nil", sql.Blob, sql.Int32, sql.NewRow(nil, nil), nil, nil}, {"blob is ok", sql.Blob, sql.Int32, sql.NewRow([]byte{1, 2, 3}, nil), int32(0), nil}, {"text int without d", sql.Text, sql.Int32, sql.NewRow("5", nil), int32(5), nil}, From 2233b4d5051440cb3f35e628d69264dfde121b02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Mon, 14 Oct 2019 10:30:40 +0200 Subject: [PATCH 18/44] Refactor type tests to cover all missing cases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- sql/type_test.go | 130 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 110 insertions(+), 20 deletions(-) diff --git a/sql/type_test.go b/sql/type_test.go index a73bf08be..5ca9ac9e5 100644 --- a/sql/type_test.go +++ b/sql/type_test.go @@ -46,16 +46,118 @@ func TestBoolean(t *testing.T) { eq(t, Boolean, false, false) } +// Test conversion of all types of numbers to the specified signed integer type +// in typ, where minusOne, zero and one are the expected values with the +// same type as typ +func testSignedInt(t *testing.T, typ Type, minusOne, zero, one interface{}) { + t.Helper() + + convert(t, typ, -1, minusOne) + convert(t, typ, int8(-1), minusOne) + convert(t, typ, int16(-1), minusOne) + convert(t, typ, int32(-1), minusOne) + convert(t, typ, int64(-1), minusOne) + convert(t, typ, 0, zero) + convert(t, typ, int8(0), zero) + convert(t, typ, int16(0), zero) + convert(t, typ, int32(0), zero) + convert(t, typ, int64(0), zero) + convert(t, typ, uint8(0), zero) + convert(t, typ, uint16(0), zero) + convert(t, typ, uint32(0), zero) + convert(t, typ, uint64(0), zero) + convert(t, typ, 1, one) + convert(t, typ, int8(1), one) + convert(t, typ, int16(1), one) + convert(t, typ, int32(1), one) + convert(t, typ, int64(1), one) + convert(t, typ, uint8(1), one) + convert(t, typ, uint16(1), one) + convert(t, typ, uint32(1), one) + convert(t, typ, uint64(1), one) + convert(t, typ, "-1", minusOne) + convert(t, typ, "0", zero) + convert(t, typ, "1", one) + convertErr(t, typ, "") + + lt(t, Int8, minusOne, one) + eq(t, Int8, zero, zero) + eq(t, Int8, minusOne, minusOne) + eq(t, Int8, one, one) + gt(t, Int8, one, minusOne) +} + +// Test conversion of all types of numbers to the specified unsigned integer +// type in typ, where zero and one are the expected values with the same type +// as typ. The expected errors when converting from negative numbers are also +// tested +func testUnsignedInt(t *testing.T, typ Type, zero, one interface{}) { + t.Helper() + + convertErr(t, typ, -1) + convertErr(t, typ, int8(-1)) + convertErr(t, typ, int16(-1)) + convertErr(t, typ, int32(-1)) + convertErr(t, typ, int64(-1)) + convert(t, typ, 0, zero) + convert(t, typ, int8(0), zero) + convert(t, typ, int16(0), zero) + convert(t, typ, int32(0), zero) + convert(t, typ, int64(0), zero) + convert(t, typ, uint8(0), zero) + convert(t, typ, uint16(0), zero) + convert(t, typ, uint32(0), zero) + convert(t, typ, uint64(0), zero) + convert(t, typ, 1, one) + convert(t, typ, int8(1), one) + convert(t, typ, int16(1), one) + convert(t, typ, int32(1), one) + convert(t, typ, int64(1), one) + convert(t, typ, uint8(1), one) + convert(t, typ, uint16(1), one) + convert(t, typ, uint32(1), one) + convert(t, typ, uint64(1), one) + convertErr(t, typ, "-1") + convert(t, typ, "0", zero) + convert(t, typ, "1", one) + convertErr(t, typ, "") + + lt(t, Int8, zero, one) + eq(t, Int8, zero, zero) + eq(t, Int8, one, one) + gt(t, Int8, one, zero) +} + +func TestInt8(t *testing.T) { + testSignedInt(t, Int8, int8(-1), int8(0), int8(1)) +} + +func TestInt16(t *testing.T) { + testSignedInt(t, Int16, int16(-1), int16(0), int16(1)) +} + func TestInt32(t *testing.T) { - convert(t, Int32, int32(1), int32(1)) - convert(t, Int32, 1, int32(1)) - convert(t, Int32, int64(1), int32(1)) - convert(t, Int32, "5", int32(5)) - convertErr(t, Int32, "") + testSignedInt(t, Int32, int32(-1), int32(0), int32(1)) +} - lt(t, Int32, int32(1), int32(2)) - eq(t, Int32, int32(1), int32(1)) - gt(t, Int32, int32(3), int32(2)) +func TestInt64(t *testing.T) { + testSignedInt(t, Int64, int64(-1), int64(0), int64(1)) +} + +func TestUint8(t *testing.T) { + testUnsignedInt(t, Uint8, uint8(0), uint8(1)) +} + +func TestUint16(t *testing.T) { + testUnsignedInt(t, Uint16, uint16(0), uint16(1)) +} + +func TestUint32(t *testing.T) { + testUnsignedInt(t, Uint32, uint32(0), uint32(1)) +} + +func TestUint64(t *testing.T) { + testUnsignedInt(t, Uint64, uint64(0), uint64(1)) } func TestNumberComparison(t *testing.T) { @@ -140,18 +242,6 @@ func TestNumberComparison(t *testing.T) { } } -func TestInt64(t *testing.T) { - convert(t, Int64, int32(1), int64(1)) - convert(t, Int64, 1, int64(1)) - convert(t, Int64, int64(1), int64(1)) - convertErr(t, Int64, "") - convert(t, Int64, "5", int64(5)) - - lt(t, Int64, int64(1), int64(2)) - eq(t, Int64, int64(1), int64(1)) - gt(t, Int64, int64(3), int64(2)) -} - func TestFloat64(t *testing.T) { require := require.New(t) From d436a962db5154799a28352a39f9a33b5d8049b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Mon, 14 Oct 2019 11:10:43 +0200 Subject: [PATCH 19/44] Add missing cases to pilosa decodeGob test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- sql/index/pilosa/lookup_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/index/pilosa/lookup_test.go b/sql/index/pilosa/lookup_test.go index e93a3e587..b93da3ca2 100644 --- a/sql/index/pilosa/lookup_test.go +++ b/sql/index/pilosa/lookup_test.go @@ -98,8 +98,12 @@ func TestCompare(t *testing.T) { func TestDecodeGob(t *testing.T) { testCases := []interface{}{ "foo", + int8(1), + int16(1), int32(1), int64(1), + uint8(1), + uint16(1), uint32(1), uint64(1), float64(1), From c0899fbb58302ab97a0ccc4e0f800aa7e194122f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Mon, 14 Oct 2019 11:12:52 +0200 Subject: [PATCH 20/44] Test parsing of integer literals MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alejandro García Montoro --- sql/parse/parse_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index ab5d7acfc..ad8bfc62d 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -1,6 +1,7 @@ package parse import ( + "math" "testing" "github.com/src-d/go-mysql-server/sql/expression" @@ -1191,6 +1192,23 @@ var fixtures = map[string]sql.Node{ []sql.Expression{}, plan.NewUnresolvedTable("foo", ""), ), + `SELECT -128, 127, 255, -32768, 32767, 65535, -2147483648, 2147483647, 4294967295, -9223372036854775808, 9223372036854775807, 18446744073709551615`: plan.NewProject( + []sql.Expression{ + expression.NewLiteral(int8(math.MinInt8), sql.Int8), + expression.NewLiteral(int8(math.MaxInt8), sql.Int8), + expression.NewLiteral(uint8(math.MaxUint8), sql.Uint8), + expression.NewLiteral(int16(math.MinInt16), sql.Int16), + expression.NewLiteral(int16(math.MaxInt16), sql.Int16), + expression.NewLiteral(uint16(math.MaxUint16), sql.Uint16), + expression.NewLiteral(int32(math.MinInt32), sql.Int32), + expression.NewLiteral(int32(math.MaxInt32), sql.Int32), + expression.NewLiteral(uint32(math.MaxUint32), sql.Uint32), + expression.NewLiteral(int64(math.MinInt64), sql.Int64), + expression.NewLiteral(int64(math.MaxInt64), sql.Int64), + expression.NewLiteral(uint64(math.MaxUint64), sql.Uint64), + }, + plan.NewUnresolvedTable("dual", ""), + ), } func TestParse(t *testing.T) { From 2a6ea8de78a8694a78b379a1eae17f81b0f0ccbe Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Wed, 16 Oct 2019 12:32:50 +0200 Subject: [PATCH 21/44] analyzer: only optimize distinct for sorts where first column is in schema Signed-off-by: Miguel Molina --- engine_test.go | 14 +++++++ sql/analyzer/optimization_rules.go | 10 +++-- sql/analyzer/optimization_rules_test.go | 56 +++++++++++++++++++------ 3 files changed, 63 insertions(+), 17 deletions(-) diff --git a/engine_test.go b/engine_test.go index 1ba6ef32a..536e02074 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1570,6 +1570,20 @@ var queries = []struct { `SELECT (SELECT i FROM mytable ORDER BY i ASC LIMIT 1) AS x`, []sql.Row{{int64(1)}}, }, + { + `SELECT DISTINCT n FROM bigtable ORDER BY t`, + []sql.Row{ + {int64(1)}, + {int64(9)}, + {int64(7)}, + {int64(3)}, + {int64(2)}, + {int64(8)}, + {int64(6)}, + {int64(5)}, + {int64(4)}, + }, + }, } func TestQueries(t *testing.T) { diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index 0037d9b0d..88283cf9f 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -34,17 +34,19 @@ func optimizeDistinct(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, e a.Log("optimize distinct, node of type: %T", node) if n, ok := node.(*plan.Distinct); ok { - var isSorted bool + var sortField *expression.GetField plan.Inspect(n, func(node sql.Node) bool { a.Log("checking for optimization in node of type: %T", node) - if _, ok := node.(*plan.Sort); ok { - isSorted = true + if sort, ok := node.(*plan.Sort); ok && sortField == nil { + if col, ok := sort.SortFields[0].Column.(*expression.GetField); ok { + sortField = col + } return false } return true }) - if isSorted { + if sortField != nil && n.Schema().Contains(sortField.Name(), sortField.Table()) { a.Log("distinct optimized for ordered output") return plan.NewOrderedDistinct(n.Child), nil } diff --git a/sql/analyzer/optimization_rules_test.go b/sql/analyzer/optimization_rules_test.go index 285117251..60dee3b0f 100644 --- a/sql/analyzer/optimization_rules_test.go +++ b/sql/analyzer/optimization_rules_test.go @@ -186,24 +186,54 @@ func TestEraseProjection(t *testing.T) { } func TestOptimizeDistinct(t *testing.T) { - require := require.New(t) - - t1 := memory.NewTable("foo", nil) - t2 := memory.NewTable("foo", nil) + t1 := memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo"}, + {Name: "b", Source: "foo"}, + }) - notSorted := plan.NewDistinct(plan.NewResolvedTable(t1)) - sorted := plan.NewDistinct(plan.NewSort(nil, plan.NewResolvedTable(t2))) + testCases := []struct { + name string + child sql.Node + optimized bool + }{ + { + "without sort", + plan.NewResolvedTable(t1), + false, + }, + { + "sort but column not projected", + plan.NewSort( + []plan.SortField{ + {Column: gf(0, "foo", "c")}, + }, + plan.NewResolvedTable(t1), + ), + false, + }, + { + "sort and column projected", + plan.NewSort( + []plan.SortField{ + {Column: gf(0, "foo", "a")}, + }, + plan.NewResolvedTable(t1), + ), + true, + }, + } rule := getRule("optimize_distinct") - analyzedNotSorted, err := rule.Apply(sql.NewEmptyContext(), nil, notSorted) - require.NoError(err) - - analyzedSorted, err := rule.Apply(sql.NewEmptyContext(), nil, sorted) - require.NoError(err) + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + node, err := rule.Apply(sql.NewEmptyContext(), nil, plan.NewDistinct(tt.child)) + require.NoError(t, err) - require.Equal(notSorted, analyzedNotSorted) - require.Equal(plan.NewOrderedDistinct(sorted.Child), analyzedSorted) + _, ok := node.(*plan.OrderedDistinct) + require.Equal(t, tt.optimized, ok) + }) + } } func TestMoveJoinConditionsToFilter(t *testing.T) { From 2c4b01ed05c4cfcc4f13ed62176e6bfdb9239315 Mon Sep 17 00:00:00 2001 From: Daylon Wilkins Date: Tue, 1 Oct 2019 12:14:50 -0700 Subject: [PATCH 22/44] Implemented UPDATE Signed-off-by: Daylon Wilkins --- engine.go | 2 +- engine_test.go | 149 ++++++++++++++++++++++++++++-- memory/table.go | 31 +++++++ sql/analyzer/pushdown.go | 2 +- sql/core.go | 6 ++ sql/expression/set.go | 61 +++++++++++++ sql/parse/parse.go | 61 +++++++++++++ sql/plan/update.go | 189 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 493 insertions(+), 8 deletions(-) create mode 100644 sql/expression/set.go create mode 100644 sql/plan/update.go diff --git a/engine.go b/engine.go index 4cb90c3ad..6be2be3c2 100644 --- a/engine.go +++ b/engine.go @@ -120,7 +120,7 @@ func (e *Engine) Query( case *plan.CreateIndex: typ = sql.CreateIndexProcess perm = auth.ReadPerm | auth.WritePerm - case *plan.InsertInto, *plan.DeleteFrom, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables: + case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables: perm = auth.ReadPerm | auth.WritePerm } diff --git a/engine_test.go b/engine_test.go index 536e02074..e1274e2ab 100644 --- a/engine_test.go +++ b/engine_test.go @@ -18,6 +18,7 @@ import ( "github.com/src-d/go-mysql-server/sql/plan" "github.com/src-d/go-mysql-server/test" + "github.com/stretchr/testify/require" ) @@ -2245,6 +2246,142 @@ func TestReplaceIntoErrors(t *testing.T) { } } +func TestUpdate(t *testing.T) { + var updates = []struct { + updateQuery string + expectedUpdate []sql.Row + selectQuery string + expectedSelect []sql.Row + }{ + { + "UPDATE mytable SET s = 'updated';", + []sql.Row{{int64(3), int64(3)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}}, + }, + { + "UPDATE mytable SET s = 'updated' WHERE i > 9999;", + []sql.Row{{int64(0), int64(0)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}}, + }, + { + "UPDATE mytable SET s = 'updated' WHERE i = 1;", + []sql.Row{{int64(1), int64(1)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "second row"}, {int64(3), "third row"}}, + }, + { + "UPDATE mytable SET s = 'updated' WHERE i <> 9999;", + []sql.Row{{int64(3), int64(3)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"},{int64(2), "updated"},{int64(3), "updated"}}, + }, + { + "UPDATE floattable SET f32 = f32 + f32, f64 = f32 * f64 WHERE i = 2;", + []sql.Row{{int64(1), int64(1)}}, + "SELECT * FROM floattable WHERE i = 2;", + []sql.Row{{int64(2), float32(3.0), float64(4.5)}}, + }, + { + "UPDATE floattable SET f32 = 5, f32 = 4 WHERE i = 1;", + []sql.Row{{int64(1), int64(1)}}, + "SELECT f32 FROM floattable WHERE i = 1;", + []sql.Row{{float32(4.0)}}, + }, + { + "UPDATE mytable SET s = 'first row' WHERE i = 1;", + []sql.Row{{int64(1), int64(0)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}}, + }, + { + "UPDATE niltable SET b = NULL WHERE f IS NULL;", + []sql.Row{{int64(2), int64(1)}}, + "SELECT * FROM niltable WHERE f IS NULL;", + []sql.Row{{int64(4), nil, nil}, {nil, nil, nil}}, + }, + { + "UPDATE mytable SET s = 'updated' ORDER BY i ASC LIMIT 2;", + []sql.Row{{int64(2), int64(2)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "third row"}}, + }, + { + "UPDATE mytable SET s = 'updated' ORDER BY i DESC LIMIT 2;", + []sql.Row{{int64(2), int64(2)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "updated"}, {int64(3), "updated"}}, + }, + { + "UPDATE mytable SET s = 'updated' ORDER BY i LIMIT 1 OFFSET 1;", + []sql.Row{{int64(1), int64(1)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "updated"}, {int64(3), "third row"}}, + }, + { + "UPDATE mytable SET s = 'updated';", + []sql.Row{{int64(3), int64(3)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}}, + }, + } + + for _, update := range updates { + e := newEngine(t) + ctx := newCtx() + testQueryWithContext(ctx, t, e, update.updateQuery, update.expectedUpdate) + testQueryWithContext(ctx, t, e, update.selectQuery, update.expectedSelect) + } +} + +func TestUpdateErrors(t *testing.T) { + var expectedFailures = []struct { + name string + query string + }{ + { + "invalid table", + "UPDATE doesnotexist SET i = 0;", + }, + { + "invalid column set", + "UPDATE mytable SET z = 0;", + }, + { + "invalid column set value", + "UPDATE mytable SET i = z;", + }, + { + "invalid column where", + "UPDATE mytable SET s = 'hi' WHERE z = 1;", + }, + { + "invalid column order by", + "UPDATE mytable SET s = 'hi' ORDER BY z;", + }, + { + "negative limit", + "UPDATE mytable SET s = 'hi' LIMIT -1;", + }, + { + "negative offset", + "UPDATE mytable SET s = 'hi' LIMIT 1 OFFSET -1;", + }, + { + "set null on non-nullable", + "UPDATE mytable SET s = NULL;", + }, + } + + for _, expectedFailure := range expectedFailures { + t.Run(expectedFailure.name, func(t *testing.T) { + _, _, err := newEngine(t).Query(newCtx(), expectedFailure.query) + require.Error(t, err) + }) + } +} + const testNumPartitions = 5 func TestAmbiguousColumnResolution(t *testing.T) { @@ -2670,12 +2807,12 @@ func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine { insertRows( t, floatTable, - sql.NewRow(1, float32(1.0), float64(1.0)), - sql.NewRow(2, float32(1.5), float64(1.5)), - sql.NewRow(3, float32(2.0), float64(2.0)), - sql.NewRow(4, float32(2.5), float64(2.5)), - sql.NewRow(-1, float32(-1.0), float64(-1.0)), - sql.NewRow(-2, float32(-1.5), float64(-1.5)), + sql.NewRow(int64(1), float32(1.0), float64(1.0)), + sql.NewRow(int64(2), float32(1.5), float64(1.5)), + sql.NewRow(int64(3), float32(2.0), float64(2.0)), + sql.NewRow(int64(4), float32(2.5), float64(2.5)), + sql.NewRow(int64(-1), float32(-1.0), float64(-1.0)), + sql.NewRow(int64(-2), float32(-1.5), float64(-1.5)), ) nilTable := memory.NewPartitionedTable("niltable", sql.Schema{ diff --git a/memory/table.go b/memory/table.go index a73ecedb8..088956316 100644 --- a/memory/table.go +++ b/memory/table.go @@ -290,6 +290,37 @@ func (t *Table) Delete(ctx *sql.Context, row sql.Row) error { return nil } +func (t *Table) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error { + if err := checkRow(t.schema, oldRow); err != nil { + return err + } + if err := checkRow(t.schema, newRow); err != nil { + return err + } + + matches := false + for partitionIndex, partition := range t.partitions { + for partitionRowIndex, partitionRow := range partition { + matches = true + for rIndex, val := range oldRow { + if val != partitionRow[rIndex] { + matches = false + break + } + } + if matches { + t.partitions[partitionIndex][partitionRowIndex] = newRow + break + } + } + if matches { + break + } + } + + return nil +} + func checkRow(schema sql.Schema, row sql.Row) error { if len(row) != len(schema) { return sql.ErrUnexpectedRowLength.New(len(schema), len(row)) diff --git a/sql/analyzer/pushdown.go b/sql/analyzer/pushdown.go index 34312b1ac..07049c413 100644 --- a/sql/analyzer/pushdown.go +++ b/sql/analyzer/pushdown.go @@ -20,7 +20,7 @@ func pushdown(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { // don't do pushdown on certain queries switch n.(type) { - case *plan.InsertInto, *plan.DeleteFrom, *plan.CreateIndex: + case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.CreateIndex: return n, nil } diff --git a/sql/core.go b/sql/core.go index e91ce0b67..7a8875062 100644 --- a/sql/core.go +++ b/sql/core.go @@ -217,6 +217,12 @@ type Replacer interface { Inserter } +// Updater allows rows to be updated. +type Updater interface { + // Update the given row. Provides both the old and new rows. + Update(ctx *Context, old Row, new Row) error +} + // Database represents the database. type Database interface { Nameable diff --git a/sql/expression/set.go b/sql/expression/set.go new file mode 100644 index 000000000..d18bde374 --- /dev/null +++ b/sql/expression/set.go @@ -0,0 +1,61 @@ +package expression + +import ( + "fmt" + "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" +) + +var errCannotSetField = errors.NewKind("Expected GetField expression on left but got %T") + +// SetField updates the value of a field from a row. +type SetField struct { + BinaryExpression +} + +// NewSetField creates a new SetField expression. +func NewSetField(colName, expr sql.Expression) sql.Expression { + return &SetField{BinaryExpression{Left: colName, Right: expr}} +} + +func (s *SetField) String() string { + return fmt.Sprintf("SETFIELD %s = %s", s.Left, s.Right) +} + +// Type implements the Expression interface. +func (s *SetField) Type() sql.Type { + return s.Left.Type() +} + +// Eval implements the Expression interface. +// Returns a copy of the given row with an updated value. +func (s *SetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + getField, ok := s.Left.(*GetField) + if !ok { + return nil, errCannotSetField.New(s.Left) + } + if getField.fieldIndex < 0 || getField.fieldIndex >= len(row) { + return nil, ErrIndexOutOfBounds.New(getField.fieldIndex, len(row)) + } + val, err := s.Right.Eval(ctx, row) + if err != nil { + return nil, err + } + if val != nil { + val, err = getField.fieldType.Convert(val) + if err != nil { + return nil, err + } + } + updatedRow := row.Copy() + updatedRow[getField.fieldIndex] = val + return updatedRow, nil +} + +// WithChildren implements the Expression interface. +func (s *SetField) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 2) + } + return NewSetField(children[0], children[1]), nil +} \ No newline at end of file diff --git a/sql/parse/parse.go b/sql/parse/parse.go index e77819f2a..a7f065654 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -153,6 +153,8 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node return plan.NewRollback(), nil case *sqlparser.Delete: return convertDelete(ctx, n) + case *sqlparser.Update: + return convertUpdate(ctx, n) } } @@ -429,6 +431,49 @@ func convertDelete(ctx *sql.Context, d *sqlparser.Delete) (sql.Node, error) { return plan.NewDeleteFrom(node), nil } +func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) { + node, err := tableExprsToTable(ctx, d.TableExprs) + if err != nil { + return nil, err + } + + updateExprs, err := updateExprsToExpressions(d.Exprs) + if err != nil { + return nil, err + } + + if d.Where != nil { + node, err = whereToFilter(d.Where, node) + if err != nil { + return nil, err + } + } + + if len(d.OrderBy) != 0 { + node, err = orderByToSort(d.OrderBy, node) + if err != nil { + return nil, err + } + } + + // Limit must wrap offset, and not vice-versa, so that skipped rows don't count toward the returned row count. + if d.Limit != nil && d.Limit.Offset != nil { + node, err = offsetToOffset(ctx, d.Limit.Offset, node) + if err != nil { + return nil, err + } + } + + if d.Limit != nil { + node, err = limitToLimit(ctx, d.Limit.Rowcount, node) + if err != nil { + return nil, err + } + } + + return plan.NewUpdate(node, updateExprs), nil +} + func columnDefinitionToSchema(colDef []*sqlparser.ColumnDefinition) (sql.Schema, error) { var schema sql.Schema for _, cd := range colDef { @@ -1241,6 +1286,22 @@ func intervalExprToExpression(ctx *sql.Context, e *sqlparser.IntervalExpr) (sql. return expression.NewInterval(expr, e.Unit), nil } +func updateExprsToExpressions(e sqlparser.UpdateExprs) ([]sql.Expression, error) { + res := make([]sql.Expression, len(e)) + for i, updateExpr := range e { + colName, err := exprToExpression(updateExpr.Name) + if err != nil { + return nil, err + } + innerExpr, err := exprToExpression(updateExpr.Expr) + if err != nil { + return nil, err + } + res[i] = expression.NewSetField(colName, innerExpr) + } + return res, nil +} + func removeComments(s string) string { r := bufio.NewReader(strings.NewReader(s)) var result []rune diff --git a/sql/plan/update.go b/sql/plan/update.go new file mode 100644 index 000000000..a29fb7392 --- /dev/null +++ b/sql/plan/update.go @@ -0,0 +1,189 @@ +package plan + +import ( + "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" + "io" +) + +var ErrUpdateNotSupported = errors.NewKind("table doesn't support UPDATE") +var ErrUpdateUnexpectedSetResult = errors.NewKind("attempted to set field but expression returned %T") + +// Update is a node for updating rows on tables. +type Update struct { + sql.Node + UpdateExprs []sql.Expression +} + +// NewUpdate creates an Update node. +func NewUpdate(n sql.Node, updateExprs []sql.Expression) *Update { + return &Update{n, updateExprs} +} + +// Expressions implements the Expressioner interface. +func (p *Update) Expressions() []sql.Expression { + return p.UpdateExprs +} + +// Schema implements the Node interface. +func (p *Update) Schema() sql.Schema { + return sql.Schema{ + { + Name: "matched", + Type: sql.Int64, + Default: int64(0), + Nullable: false, + }, + { + Name: "updated", + Type: sql.Int64, + Default: int64(0), + Nullable: false, + }, + } +} + +// Resolved implements the Resolvable interface. +func (p *Update) Resolved() bool { + if !p.Node.Resolved() { + return false + } + for _, updateExpr := range p.UpdateExprs { + if !updateExpr.Resolved() { + return false + } + } + return true +} + +func (p *Update) Children() []sql.Node { + return []sql.Node{p.Node} +} + +func getUpdatable(node sql.Node) (sql.Updater, error) { + switch node := node.(type) { + case sql.Updater: + return node, nil + case *ResolvedTable: + return getUpdatableTable(node.Table) + } + for _, child := range node.Children() { + updater, _ := getUpdatable(child) + if updater != nil { + return updater, nil + } + } + return nil, ErrUpdateNotSupported.New() +} + +func getUpdatableTable(t sql.Table) (sql.Updater, error) { + switch t := t.(type) { + case sql.Updater: + return t, nil + case sql.TableWrapper: + return getUpdatableTable(t.Underlying()) + default: + return nil, ErrUpdateNotSupported.New() + } +} + +// Execute inserts the rows in the database. +func (p *Update) Execute(ctx *sql.Context) (int, int, error) { + updatable, err := getUpdatable(p.Node) + if err != nil { + return 0, 0, err + } + schema := p.Node.Schema() + + iter, err := p.Node.RowIter(ctx) + if err != nil { + return 0, 0, err + } + + rowsMatched := 0 + rowsUpdated := 0 + for { + oldRow, err := iter.Next() + if err == io.EOF { + break + } + if err != nil { + _ = iter.Close() + return rowsMatched, rowsUpdated, err + } + rowsMatched++ + + newRow, err := p.applyUpdates(ctx, oldRow) + if err != nil { + _ = iter.Close() + return rowsMatched, rowsUpdated, err + } + if equals, err := oldRow.Equals(newRow, schema); err == nil { + if !equals { + err = updatable.Update(ctx, oldRow, newRow) + if err != nil { + _ = iter.Close() + return rowsMatched, rowsUpdated, err + } + rowsUpdated++ + } + } else { + _ = iter.Close() + return rowsMatched, rowsUpdated, err + } + } + + return rowsMatched, rowsUpdated, nil +} + +// RowIter implements the Node interface. +func (p *Update) RowIter(ctx *sql.Context) (sql.RowIter, error) { + matched, updated, err := p.Execute(ctx) + if err != nil { + return nil, err + } + + return sql.RowsToRowIter(sql.NewRow(int64(matched), int64(updated))), nil +} + +// WithChildren implements the Node interface. +func (p *Update) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + return NewUpdate(children[0], p.UpdateExprs), nil +} + +// WithExpressions implements the Expressioner interface. +func (p *Update) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) { + if len(newExprs) != len(p.UpdateExprs) { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(p.UpdateExprs), 1) + } + return NewUpdate(p.Node, newExprs), nil +} + +func (p Update) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("Update") + _ = pr.WriteChildren(p.Node.String()) + for _, updateExpr := range p.UpdateExprs { + _ = pr.WriteChildren(updateExpr.String()) + } + return pr.String() +} + +func (p *Update) applyUpdates(ctx *sql.Context, row sql.Row) (sql.Row, error) { + var ok bool + prev := row + for _, updateExpr := range p.UpdateExprs { + val, err := updateExpr.Eval(ctx, prev) + if err != nil { + return nil, err + } + prev, ok = val.(sql.Row) + if !ok { + return nil, ErrUpdateUnexpectedSetResult.New(val) + } + } + return prev, nil +} From e847998dd283d6af3afa4f71ab98eef7afcf5bd2 Mon Sep 17 00:00:00 2001 From: Daylon Wilkins Date: Wed, 16 Oct 2019 10:56:09 -0700 Subject: [PATCH 23/44] Passed sql.Context into functions where it was missing Signed-off-by: Daylon Wilkins --- sql/parse/parse.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index a7f065654..5b6b9ca87 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -437,20 +437,20 @@ func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) { return nil, err } - updateExprs, err := updateExprsToExpressions(d.Exprs) + updateExprs, err := updateExprsToExpressions(ctx, d.Exprs) if err != nil { return nil, err } if d.Where != nil { - node, err = whereToFilter(d.Where, node) + node, err = whereToFilter(ctx, d.Where, node) if err != nil { return nil, err } } if len(d.OrderBy) != 0 { - node, err = orderByToSort(d.OrderBy, node) + node, err = orderByToSort(ctx, d.OrderBy, node) if err != nil { return nil, err } @@ -1286,14 +1286,14 @@ func intervalExprToExpression(ctx *sql.Context, e *sqlparser.IntervalExpr) (sql. return expression.NewInterval(expr, e.Unit), nil } -func updateExprsToExpressions(e sqlparser.UpdateExprs) ([]sql.Expression, error) { +func updateExprsToExpressions(ctx *sql.Context, e sqlparser.UpdateExprs) ([]sql.Expression, error) { res := make([]sql.Expression, len(e)) for i, updateExpr := range e { - colName, err := exprToExpression(updateExpr.Name) + colName, err := exprToExpression(ctx, updateExpr.Name) if err != nil { return nil, err } - innerExpr, err := exprToExpression(updateExpr.Expr) + innerExpr, err := exprToExpression(ctx, updateExpr.Expr) if err != nil { return nil, err } From 029818162b16be6c81a85f00d4bca6a82fdc7644 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 27 Sep 2019 15:08:26 -0700 Subject: [PATCH 24/44] First pass at drop table Signed-off-by: Zach Musgrave --- memory/database.go | 11 ++++++ sql/core.go | 5 +++ sql/parse/parse.go | 10 ++++++ sql/parse/parse_test.go | 9 +++++ sql/plan/ddl.go | 78 +++++++++++++++++++++++++++++++++++++++++ sql/plan/ddl_test.go | 59 ++++++++++++++++++++++++++----- 6 files changed, 163 insertions(+), 9 deletions(-) diff --git a/memory/database.go b/memory/database.go index b5b87a741..fcfda3fbc 100644 --- a/memory/database.go +++ b/memory/database.go @@ -43,3 +43,14 @@ func (d *Database) Create(name string, schema sql.Schema) error { d.tables[name] = NewTable(name, schema) return nil } + +func (d *Database) DropTable(name string, ifExists bool) error { + _, ok := d.tables[name] + if !ok && !ifExists { + return sql.ErrTableNotFound.New(name) + } + + delete(d.tables, name) + return nil +} + diff --git a/sql/core.go b/sql/core.go index 7a8875062..55424022f 100644 --- a/sql/core.go +++ b/sql/core.go @@ -235,6 +235,11 @@ type Alterable interface { Create(name string, schema Schema) error } +// Droppable should be implemented by databases that can drop tables. +type Droppable interface { + DropTable(name string, ifExists bool) error +} + // Lockable should be implemented by tables that can be locked and unlocked. type Lockable interface { Nameable diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 5b6b9ca87..eaa42d1b8 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -354,11 +354,21 @@ func convertDDL(c *sqlparser.DDL) (sql.Node, error) { switch c.Action { case sqlparser.CreateStr: return convertCreateTable(c) + case sqlparser.DropStr: + return convertDropTable(c) default: return nil, ErrUnsupportedSyntax.New(c) } } +func convertDropTable(c *sqlparser.DDL) (sql.Node, error) { + tableNames := make([]string, len(c.FromTables)) + for i, t := range c.FromTables { + tableNames[i] = t.Name.String() + } + return plan.NewDropTable(sql.UnresolvedDatabase(""), c.IfExists, tableNames...), nil +} + func convertCreateTable(c *sqlparser.DDL) (sql.Node, error) { schema, err := columnDefinitionToSchema(c.TableSpec.Columns) if err != nil { diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index ad8bfc62d..0f27794b4 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -51,6 +51,15 @@ var fixtures = map[string]sql.Node{ Nullable: true, }}, ), + `DROP TABLE foo;`: plan.NewDropTable( + sql.UnresolvedDatabase(""), false, "foo", + ), + `DROP TABLE IF EXISTS foo;`: plan.NewDropTable( + sql.UnresolvedDatabase(""), true, "foo", + ), + `DROP TABLE IF EXISTS foo, bar, baz;`: plan.NewDropTable( + sql.UnresolvedDatabase(""), true, "foo", "bar", "baz", + ), `DESCRIBE TABLE foo;`: plan.NewDescribe( plan.NewUnresolvedTable("foo", ""), ), diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index 3eb97331e..8e0af21e1 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -1,12 +1,14 @@ package plan import ( + "fmt" "github.com/src-d/go-mysql-server/sql" "gopkg.in/src-d/go-errors.v1" ) // ErrCreateTable is thrown when the database doesn't support table creation var ErrCreateTable = errors.NewKind("tables cannot be created on database %s") +var ErrDropTableNotSupported = errors.NewKind("tables cannot be dropped on database %s") // CreateTable is a node describing the creation of some table. type CreateTable struct { @@ -15,6 +17,13 @@ type CreateTable struct { schema sql.Schema } +// DropTable is a node describing dropping a table +type DropTable struct { + db sql.Database + names []string + ifExists bool +} + // NewCreateTable creates a new CreateTable node func NewCreateTable(db sql.Database, name string, schema sql.Schema) *CreateTable { for _, s := range schema { @@ -75,3 +84,72 @@ func (c *CreateTable) WithChildren(children ...sql.Node) (sql.Node, error) { func (c *CreateTable) String() string { return "CreateTable" } + +// NewDropTable creates a new DropTable node +func NewDropTable(db sql.Database, ifExists bool, tableNames ...string) *DropTable { + return &DropTable{ + db: db, + names: tableNames, + ifExists: ifExists, + } +} + +var _ sql.Databaser = (*DropTable)(nil) + +// Database implements the sql.Databaser interface. +func (d *DropTable) Database() sql.Database { + return d.db +} + +// WithDatabase implements the sql.Databaser interface. +func (d *DropTable) WithDatabase(db sql.Database) (sql.Node, error) { + nc := *d + nc.db = db + return &nc, nil +} + +// Resolved implements the Resolvable interface. +func (d *DropTable) Resolved() bool { + _, ok := d.db.(sql.UnresolvedDatabase) + return !ok +} + +// RowIter implements the Node interface. +func (d *DropTable) RowIter(s *sql.Context) (sql.RowIter, error) { + droppable, ok := d.db.(sql.Droppable) + if !ok { + return nil, ErrDropTableNotSupported.New(d.db.Name()) + } + + var err error + for _, tableName := range d.names { + err = droppable.DropTable(tableName, d.ifExists) + if err != nil { + break + } + } + + return sql.RowsToRowIter(), err +} + +// Schema implements the Node interface. +func (d *DropTable) Schema() sql.Schema { return nil } + +// Children implements the Node interface. +func (d *DropTable) Children() []sql.Node { return nil } + +// WithChildren implements the Node interface. +func (d *DropTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 0) + } + return d, nil +} + +func (d *DropTable) String() string { + ifExists := "" + if d.ifExists { + ifExists = "if exists " + } + return fmt.Sprintf("Drop table %s%s", ifExists, d.names) +} diff --git a/sql/plan/ddl_test.go b/sql/plan/ddl_test.go index a644857e9..63328efbf 100644 --- a/sql/plan/ddl_test.go +++ b/sql/plan/ddl_test.go @@ -22,15 +22,7 @@ func TestCreateTable(t *testing.T) { {Name: "c2", Type: sql.Int32}, } - c := NewCreateTable(db, "testTable", s) - - rows, err := c.RowIter(sql.NewEmptyContext()) - - require.NoError(err) - - r, err := rows.Next() - require.Equal(err, io.EOF) - require.Nil(r) + createTable(t, db, "testTable", s) tables = db.Tables() @@ -43,3 +35,52 @@ func TestCreateTable(t *testing.T) { require.Equal("testTable", s.Source) } } + +func TestDropTable(t *testing.T) { + require := require.New(t) + + db := memory.NewDatabase("test") + + s := sql.Schema{ + {Name: "c1", Type: sql.Text}, + {Name: "c2", Type: sql.Int32}, + } + + createTable(t, db, "testTable1", s) + createTable(t, db, "testTable2", s) + createTable(t, db, "testTable3", s) + + d := NewDropTable(db, false, "testTable1", "testTable2") + rows, err := d.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + r, err := rows.Next() + require.Equal(err, io.EOF) + require.Nil(r) + + _, ok := db.Tables()["testTable1"] + require.False(ok) + _, ok = db.Tables()["testTable2"] + require.False(ok) + _, ok = db.Tables()["testTable3"] + require.True(ok) + + d = NewDropTable(db, false, "testTable1") + _, err = d.RowIter(sql.NewEmptyContext()) + require.Error(err) + + d = NewDropTable(db, true, "testTable1") + _, err = d.RowIter(sql.NewEmptyContext()) + require.NoError(err) +} + +func createTable(t *testing.T, db sql.Database, name string, schema sql.Schema) { + c := NewCreateTable(db, name, schema) + + rows, err := c.RowIter(sql.NewEmptyContext()) + require.NoError(t, err) + + r, err := rows.Next() + require.Equal(t, err, io.EOF) + require.Nil(t, r) +} \ No newline at end of file From a514766021a0714126a3aaa8b413c301aecc73a8 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 27 Sep 2019 15:29:24 -0700 Subject: [PATCH 25/44] Changed the interface for Droppable to not consider the if exists flag Signed-off-by: Zach Musgrave --- memory/database.go | 4 ++-- sql/core.go | 2 +- sql/plan/ddl.go | 9 ++++++++- sql/plan/ddl_test.go | 7 +++++++ 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/memory/database.go b/memory/database.go index fcfda3fbc..8b0eeb2a5 100644 --- a/memory/database.go +++ b/memory/database.go @@ -44,9 +44,9 @@ func (d *Database) Create(name string, schema sql.Schema) error { return nil } -func (d *Database) DropTable(name string, ifExists bool) error { +func (d *Database) DropTable(name string) error { _, ok := d.tables[name] - if !ok && !ifExists { + if !ok { return sql.ErrTableNotFound.New(name) } diff --git a/sql/core.go b/sql/core.go index 55424022f..2fa51260a 100644 --- a/sql/core.go +++ b/sql/core.go @@ -237,7 +237,7 @@ type Alterable interface { // Droppable should be implemented by databases that can drop tables. type Droppable interface { - DropTable(name string, ifExists bool) error + DropTable(name string) error } // Lockable should be implemented by tables that can be locked and unlocked. diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index 8e0af21e1..308fec21f 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -123,7 +123,14 @@ func (d *DropTable) RowIter(s *sql.Context) (sql.RowIter, error) { var err error for _, tableName := range d.names { - err = droppable.DropTable(tableName, d.ifExists) + _, ok := d.db.Tables()[tableName] + if !ok { + if d.ifExists { + continue + } + return nil, sql.ErrTableNotFound.New(tableName) + } + err = droppable.DropTable(tableName) if err != nil { break } diff --git a/sql/plan/ddl_test.go b/sql/plan/ddl_test.go index 63328efbf..a226a631c 100644 --- a/sql/plan/ddl_test.go +++ b/sql/plan/ddl_test.go @@ -72,6 +72,13 @@ func TestDropTable(t *testing.T) { d = NewDropTable(db, true, "testTable1") _, err = d.RowIter(sql.NewEmptyContext()) require.NoError(err) + + d = NewDropTable(db, true, "testTable1", "testTable2", "testTable3") + _, err = d.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + _, ok = db.Tables()["testTable3"] + require.False(ok) } func createTable(t *testing.T, db sql.Database, name string, schema sql.Schema) { From 9af52bffde54f3f9570365c717ccd624397ae2b7 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 27 Sep 2019 15:59:19 -0700 Subject: [PATCH 26/44] Added a context to DropTable and CreateTable, and declared a new interface for create table statements, deprecating Alterable (which only supports Create) Signed-off-by: Zach Musgrave --- memory/database.go | 13 ++++++++++++- sql/core.go | 12 +++++++++--- sql/plan/ddl.go | 10 ++++++++-- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/memory/database.go b/memory/database.go index 8b0eeb2a5..0bacbe573 100644 --- a/memory/database.go +++ b/memory/database.go @@ -44,7 +44,18 @@ func (d *Database) Create(name string, schema sql.Schema) error { return nil } -func (d *Database) DropTable(name string) error { +// Create creates a table with the given name and schema +func (d *Database) CreateTable(ctx *sql.Context, name string, schema sql.Schema) error { + _, ok := d.tables[name] + if ok { + return sql.ErrTableAlreadyExists.New(name) + } + + d.tables[name] = NewTable(name, schema) + return nil +} + +func (d *Database) DropTable(ctx *sql.Context, name string) error { _, ok := d.tables[name] if !ok { return sql.ErrTableNotFound.New(name) diff --git a/sql/core.go b/sql/core.go index 2fa51260a..40fe0c702 100644 --- a/sql/core.go +++ b/sql/core.go @@ -230,14 +230,20 @@ type Database interface { Tables() map[string]Table } +// DEPRECATED. Use TableCreator and TableDropper. // Alterable should be implemented by databases that can handle DDL statements type Alterable interface { Create(name string, schema Schema) error } -// Droppable should be implemented by databases that can drop tables. -type Droppable interface { - DropTable(name string) error +// TableCreator should be implemented by databases that can create new tables. +type TableCreator interface { + CreateTable(ctx *Context, name string, schema Schema) error +} + +// TableDropper should be implemented by databases that can drop tables. +type TableDropper interface { + DropTable(ctx *Context, name string) error } // Lockable should be implemented by tables that can be locked and unlocked. diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index 308fec21f..2206fd0fd 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -59,6 +59,12 @@ func (c *CreateTable) Resolved() bool { // RowIter implements the Node interface. func (c *CreateTable) RowIter(s *sql.Context) (sql.RowIter, error) { + creatable, ok := c.db.(sql.TableCreator) + if ok { + return sql.RowsToRowIter(), creatable.CreateTable(s, c.name, c.schema) + } + + // TODO: phase out this interface d, ok := c.db.(sql.Alterable) if !ok { return nil, ErrCreateTable.New(c.db.Name()) @@ -116,7 +122,7 @@ func (d *DropTable) Resolved() bool { // RowIter implements the Node interface. func (d *DropTable) RowIter(s *sql.Context) (sql.RowIter, error) { - droppable, ok := d.db.(sql.Droppable) + droppable, ok := d.db.(sql.TableDropper) if !ok { return nil, ErrDropTableNotSupported.New(d.db.Name()) } @@ -130,7 +136,7 @@ func (d *DropTable) RowIter(s *sql.Context) (sql.RowIter, error) { } return nil, sql.ErrTableNotFound.New(tableName) } - err = droppable.DropTable(tableName) + err = droppable.DropTable(s, tableName) if err != nil { break } From 3fe1190fa13b36c9f70d257a7ae56b041f513dc2 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Tue, 1 Oct 2019 15:58:07 -0700 Subject: [PATCH 27/44] Support for primary keys in column definitions Signed-off-by: Zach Musgrave --- sql/parse/parse.go | 71 +++++++++++++++++++++++++++++++++-------- sql/parse/parse_test.go | 45 ++++++++++++++++++++++++++ sql/type.go | 2 ++ 3 files changed, 105 insertions(+), 13 deletions(-) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index eaa42d1b8..50d242227 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -49,6 +49,16 @@ var ( setRegex = regexp.MustCompile(`^set\s+`) ) +// These constants aren't exported from vitess for some reason. This could be removed if we changed this. +const ( + colKeyNone sqlparser.ColumnKeyOption = iota + colKeyPrimary + colKeySpatialKey + colKeyUnique + colKeyUniqueKey + colKey +) + // Parse parses the given SQL sentence and returns the corresponding node. func Parse(ctx *sql.Context, query string) (sql.Node, error) { span, ctx := ctx.Span("parse", opentracing.Tag{Key: "query", Value: query}) @@ -144,7 +154,12 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node case *sqlparser.Insert: return convertInsert(ctx, n) case *sqlparser.DDL: - return convertDDL(n) + // unlike other statements, DDL statements have loose parsing by default + ddl, err := sqlparser.ParseStrictDDL(query) + if err != nil { + return nil, err + } + return convertDDL(ddl.(*sqlparser.DDL)) case *sqlparser.Set: return convertSet(ctx, n) case *sqlparser.Use: @@ -370,7 +385,7 @@ func convertDropTable(c *sqlparser.DDL) (sql.Node, error) { } func convertCreateTable(c *sqlparser.DDL) (sql.Node, error) { - schema, err := columnDefinitionToSchema(c.TableSpec.Columns) + schema, err := columnDefinitionToSchema(c.TableSpec) if err != nil { return nil, err } @@ -472,6 +487,7 @@ func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) { if err != nil { return nil, err } + } if d.Limit != nil { @@ -484,27 +500,56 @@ func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) { return plan.NewUpdate(node, updateExprs), nil } -func columnDefinitionToSchema(colDef []*sqlparser.ColumnDefinition) (sql.Schema, error) { +func columnDefinitionToSchema(tableSpec *sqlparser.TableSpec) (sql.Schema, error) { var schema sql.Schema - for _, cd := range colDef { - typ := cd.Type - internalTyp, err := sql.MysqlTypeToType(typ.SQLType()) + for _, cd := range tableSpec.Columns { + column, err := getColumn(cd, tableSpec.Indexes) if err != nil { return nil, err } - schema = append(schema, &sql.Column{ - Nullable: !bool(typ.NotNull), - Type: internalTyp, - Name: cd.Name.String(), - // TODO - Default: nil, - }) + schema = append(schema, column) } return schema, nil } +// getColumn returns the sql.Column for the column definition given, as part of a create table statement. +func getColumn(cd *sqlparser.ColumnDefinition, indexes []*sqlparser.IndexDefinition) (*sql.Column, error) { + typ := cd.Type + internalTyp, err := sql.MysqlTypeToType(typ.SQLType()) + if err != nil { + return nil, err + } + + // Primary key info can either be specified in the column's type info (for in-line declarations), or in a slice of + // indexes attached to the table def. We have to check both places to find if a column is part of the primary key + isPkey := cd.Type.KeyOpt == colKeyPrimary + + if !isPkey { + OuterLoop: + for _, index := range indexes { + if index.Info.Primary { + for _, indexCol := range index.Columns { + if indexCol.Column.Equal(cd.Name) { + isPkey = true + break OuterLoop + } + } + } + } + } + + return &sql.Column{ + Nullable: !bool(typ.NotNull), + Type: internalTyp, + Name: cd.Name.String(), + PrimaryKey: isPkey, + // TODO + Default: nil, + }, nil +} + func columnsToStrings(cols sqlparser.Columns) []string { res := make([]string, len(cols)) for i, c := range cols { diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 0f27794b4..664b713fd 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -51,6 +51,51 @@ var fixtures = map[string]sql.Node{ Nullable: true, }}, ), + `CREATE TABLE t1(a INTEGER NOT NULL PRIMARY KEY, b TEXT)`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: false, + }}, + ), + `CREATE TABLE t1(a INTEGER, b TEXT, PRIMARY KEY (a))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: true, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: false, + }}, + ), + `CREATE TABLE t1(a INTEGER, b TEXT, PRIMARY KEY (a, b))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: true, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: true, + }}, + ), `DROP TABLE foo;`: plan.NewDropTable( sql.UnresolvedDatabase(""), false, "foo", ), diff --git a/sql/type.go b/sql/type.go index 94e825328..bb683b53a 100644 --- a/sql/type.go +++ b/sql/type.go @@ -124,6 +124,8 @@ type Column struct { Nullable bool // Source is the name of the table this column came from. Source string + // PrimaryKey is true if the column is part of the primary key for its table. + PrimaryKey bool } // Check ensures the value is correct for this column. From 994fc517fc2b505bbc3fd77ea164c5bf5f020d7e Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Tue, 1 Oct 2019 17:23:25 -0700 Subject: [PATCH 28/44] Added support for 24-bit integers (MySQL's MEDIUMINT) Signed-off-by: Zach Musgrave --- sql/type.go | 12 ++++++++++-- sql/type_test.go | 2 ++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/type.go b/sql/type.go index bb683b53a..63272445d 100644 --- a/sql/type.go +++ b/sql/type.go @@ -186,6 +186,10 @@ var ( Int16 = numberT{t: sqltypes.Int16} // Uint16 is an unsigned integer of 16 bits Uint16 = numberT{t: sqltypes.Uint16} + // Int24 is an integer of 24 bits. + Int24 = numberT{t: sqltypes.Int24} + // Uint24 is an unsigned integer of 24 bits. + Uint24 = numberT{t: sqltypes.Uint24} // Int32 is an integer of 32 bits. Int32 = numberT{t: sqltypes.Int32} // Uint32 is an unsigned integer of 32 bits. @@ -248,12 +252,16 @@ func MysqlTypeToType(sql query.Type) (Type, error) { return Int16, nil case sqltypes.Uint16: return Uint16, nil + case sqltypes.Int24: + return Int24, nil + case sqltypes.Uint24: + return Uint24, nil case sqltypes.Int32: return Int32, nil - case sqltypes.Int64: - return Int64, nil case sqltypes.Uint32: return Uint32, nil + case sqltypes.Int64: + return Int64, nil case sqltypes.Uint64: return Uint64, nil case sqltypes.Float32: diff --git a/sql/type_test.go b/sql/type_test.go index 5ca9ac9e5..0dab60260 100644 --- a/sql/type_test.go +++ b/sql/type_test.go @@ -192,6 +192,8 @@ func TestNumberComparison(t *testing.T) { {Uint8, uint8(42)}, {Int16, int16(42)}, {Uint16, uint16(42)}, + {Int24, int32(42)}, + {Uint24, uint32(42)}, {Int32, int32(42)}, {Uint32, uint32(42)}, {Int64, int64(42)}, From a7bfc0a5336d9788c657ef271df35ab9f1cbeda0 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Wed, 16 Oct 2019 11:22:39 -0700 Subject: [PATCH 29/44] Throw an error on CREATE VIEW statements, which were parsing as CREATE TABLE with an empty table spec prior to this change. Signed-off-by: Zach Musgrave --- sql/parse/parse.go | 4 ++++ sql/parse/parse_test.go | 1 + 2 files changed, 5 insertions(+) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 50d242227..2f3d7d732 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -47,6 +47,7 @@ var ( unlockTablesRegex = regexp.MustCompile(`^unlock\s+tables$`) lockTablesRegex = regexp.MustCompile(`^lock\s+tables\s`) setRegex = regexp.MustCompile(`^set\s+`) + createViewRegex = regexp.MustCompile(`^create\s+view\s+`) ) // These constants aren't exported from vitess for some reason. This could be removed if we changed this. @@ -103,6 +104,9 @@ func Parse(ctx *sql.Context, query string) (sql.Node, error) { return parseLockTables(ctx, s) case setRegex.MatchString(lowerQuery): s = fixSetQuery(s) + case createViewRegex.MatchString(lowerQuery): + // CREATE VIEW parses as a CREATE DDL statement with an empty table spec + return nil, ErrUnsupportedFeature.New("CREATE VIEW") } stmt, err := sqlparser.Parse(s) diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 664b713fd..b66dca943 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -1296,6 +1296,7 @@ var fixturesErrors = map[string]*errors.Kind{ `SELECT INTERVAL 1 DAY + INTERVAL 1 DAY`: ErrUnsupportedSyntax, `SELECT '2018-05-01' + (INTERVAL 1 DAY + INTERVAL 1 DAY)`: ErrUnsupportedSyntax, `SELECT AVG(DISTINCT foo) FROM b`: ErrUnsupportedSyntax, + `CREATE VIEW view1 AS SELECT x FROM t1 WHERE x>0`: ErrUnsupportedFeature, } func TestParseErrors(t *testing.T) { From 808d3ef879c4ed646ead3a43915b8d0ea248429c Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Thu, 17 Oct 2019 11:32:48 -0700 Subject: [PATCH 30/44] Renamed columnDefinitiontoSchema to tableSpecToSchema Signed-off-by: Zach Musgrave --- sql/parse/parse.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 2f3d7d732..f4dfe6f0a 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -389,7 +389,7 @@ func convertDropTable(c *sqlparser.DDL) (sql.Node, error) { } func convertCreateTable(c *sqlparser.DDL) (sql.Node, error) { - schema, err := columnDefinitionToSchema(c.TableSpec) + schema, err := tableSpecToSchema(c.TableSpec) if err != nil { return nil, err } @@ -504,7 +504,7 @@ func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) { return plan.NewUpdate(node, updateExprs), nil } -func columnDefinitionToSchema(tableSpec *sqlparser.TableSpec) (sql.Schema, error) { +func tableSpecToSchema(tableSpec *sqlparser.TableSpec) (sql.Schema, error) { var schema sql.Schema for _, cd := range tableSpec.Columns { column, err := getColumn(cd, tableSpec.Indexes) From 0ddfce4dbfbabf16ea80e9c83e04a2cb2f1022dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mart=C3=ADn?= Date: Fri, 18 Oct 2019 11:38:36 +0100 Subject: [PATCH 31/44] sql: Add length to VARCHAR MySQLTypeName string MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Carlos Martín --- sql/plan/show_create_table_test.go | 5 ++++- sql/type.go | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/plan/show_create_table_test.go b/sql/plan/show_create_table_test.go index ff597682c..16b150ea4 100644 --- a/sql/plan/show_create_table_test.go +++ b/sql/plan/show_create_table_test.go @@ -19,6 +19,7 @@ func TestShowCreateTable(t *testing.T) { &sql.Column{Name: "baz", Type: sql.Text, Default: "", Nullable: false}, &sql.Column{Name: "zab", Type: sql.Int32, Default: int32(0), Nullable: true}, &sql.Column{Name: "bza", Type: sql.Uint64, Default: uint64(0), Nullable: true}, + &sql.Column{Name: "foo", Type: sql.VarChar(123), Default: "", Nullable: true}, }) db.AddTable(table.Name(), table) @@ -39,7 +40,9 @@ func TestShowCreateTable(t *testing.T) { table.Name(), "CREATE TABLE `test-table` (\n `baz` text NOT NULL,\n"+ " `zab` integer DEFAULT 0,\n"+ - " `bza` bigint unsigned DEFAULT 0\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", + " `bza` bigint unsigned DEFAULT 0,\n"+ + " `foo` varchar(123)\n"+ + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", ) require.Equal(expected, row) diff --git a/sql/type.go b/sql/type.go index 94e825328..46d43299d 100644 --- a/sql/type.go +++ b/sql/type.go @@ -1258,7 +1258,7 @@ func MySQLTypeName(t Type) string { case sqltypes.Char: return "CHAR" case sqltypes.VarChar: - return "VARCHAR" + return fmt.Sprintf("VARCHAR(%v)", t.(varCharT).Capacity()) case sqltypes.Text: return "TEXT" case sqltypes.Bit: From 058210282e146cf186bd3c97e19f3e4a4e8c4d6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mart=C3=ADn?= Date: Fri, 18 Oct 2019 11:51:20 +0100 Subject: [PATCH 32/44] docs: Remove DESCRIBE [table name] from supported expressions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Carlos Martín --- SUPPORTED.md | 1 - 1 file changed, 1 deletion(-) diff --git a/SUPPORTED.md b/SUPPORTED.md index 6b89188c0..87841ac72 100644 --- a/SUPPORTED.md +++ b/SUPPORTED.md @@ -27,7 +27,6 @@ - ALIAS (AS) - CAST/CONVERT - CREATE TABLE -- DESCRIBE/DESC/EXPLAIN [table name] - DESCRIBE/DESC/EXPLAIN FORMAT=TREE [query] - DISTINCT - FILTER (WHERE) From d63b44ecdb8818ee0aedb530641468196f58c9e7 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 18 Oct 2019 16:54:25 -0700 Subject: [PATCH 33/44] PR feedback: added engine tests, removed Alterable interface, added comments. Signed-off-by: Zach Musgrave --- engine_test.go | 79 +++++++++++++++++++++++++++++++++++++++++++++- memory/database.go | 14 ++------ sql/core.go | 6 ---- sql/plan/ddl.go | 24 ++++++-------- 4 files changed, 89 insertions(+), 34 deletions(-) diff --git a/engine_test.go b/engine_test.go index e1274e2ab..5070589ee 100644 --- a/engine_test.go +++ b/engine_test.go @@ -2443,7 +2443,7 @@ func TestAmbiguousColumnResolution(t *testing.T) { require.Equal(expected, rs) } -func TestDDL(t *testing.T) { +func TestCreateTable(t *testing.T) { require := require.New(t) e := newEngine(t) @@ -2474,6 +2474,83 @@ func TestDDL(t *testing.T) { } require.Equal(s, testTable.Schema()) + + testQuery(t, e, + "CREATE TABLE t2 (a INTEGER NOT NULL PRIMARY KEY, "+ + "b VARCHAR(10) NOT NULL)", + []sql.Row(nil), + ) + + db, err = e.Catalog.Database("mydb") + require.NoError(err) + + testTable, ok = db.Tables()["t2"] + require.True(ok) + + s = sql.Schema{ + {Name: "a", Type: sql.Int32, Nullable: false, PrimaryKey: true, Source: "t2"}, + {Name: "b", Type: sql.Text, Nullable: false, Source: "t2"}, + } + + require.Equal(s, testTable.Schema()) + + testQuery(t, e, + "CREATE TABLE t3(a INTEGER NOT NULL,"+ + "b TEXT NOT NULL,"+ + "c bool, primary key (a,b))", + []sql.Row(nil), + ) + + db, err = e.Catalog.Database("mydb") + require.NoError(err) + + testTable, ok = db.Tables()["t3"] + require.True(ok) + + s = sql.Schema{ + {Name: "a", Type: sql.Int32, Nullable: false, PrimaryKey: true, Source: "t3"}, + {Name: "b", Type: sql.Text, Nullable: false, PrimaryKey: true, Source: "t3"}, + {Name: "c", Type: sql.Uint8, Nullable: true, Source: "t3"}, + } + + require.Equal(s, testTable.Schema()) +} + +func TestDropTable(t *testing.T) { + require := require.New(t) + + e := newEngine(t) + db, err := e.Catalog.Database("mydb") + require.NoError(err) + + _, ok := db.Tables()["mytable"] + require.True(ok) + + testQuery(t, e, + "DROP TABLE IF EXISTS mytable, not_exist", + []sql.Row(nil), + ) + + _, ok = db.Tables()["mytable"] + require.False(ok) + + _, ok = db.Tables()["othertable"] + require.True(ok) + _, ok = db.Tables()["tabletest"] + require.True(ok) + + testQuery(t, e, + "DROP TABLE IF EXISTS othertable, tabletest", + []sql.Row(nil), + ) + + _, ok = db.Tables()["othertable"] + require.False(ok) + _, ok = db.Tables()["tabletest"] + require.False(ok) + + _, _, err = e.Query(newCtx(), "DROP TABLE not_exist") + require.Error(err) } func TestNaturalJoin(t *testing.T) { diff --git a/memory/database.go b/memory/database.go index 0bacbe573..133c4f5de 100644 --- a/memory/database.go +++ b/memory/database.go @@ -33,18 +33,7 @@ func (d *Database) AddTable(name string, t sql.Table) { d.tables[name] = t } -// Create creates a table with the given name and schema -func (d *Database) Create(name string, schema sql.Schema) error { - _, ok := d.tables[name] - if ok { - return sql.ErrTableAlreadyExists.New(name) - } - - d.tables[name] = NewTable(name, schema) - return nil -} - -// Create creates a table with the given name and schema +// CreateTable creates a table with the given name and schema func (d *Database) CreateTable(ctx *sql.Context, name string, schema sql.Schema) error { _, ok := d.tables[name] if ok { @@ -55,6 +44,7 @@ func (d *Database) CreateTable(ctx *sql.Context, name string, schema sql.Schema) return nil } +// DropTable drops the table with the given name func (d *Database) DropTable(ctx *sql.Context, name string) error { _, ok := d.tables[name] if !ok { diff --git a/sql/core.go b/sql/core.go index 40fe0c702..2d49ece19 100644 --- a/sql/core.go +++ b/sql/core.go @@ -230,12 +230,6 @@ type Database interface { Tables() map[string]Table } -// DEPRECATED. Use TableCreator and TableDropper. -// Alterable should be implemented by databases that can handle DDL statements -type Alterable interface { - Create(name string, schema Schema) error -} - // TableCreator should be implemented by databases that can create new tables. type TableCreator interface { CreateTable(ctx *Context, name string, schema Schema) error diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index 2206fd0fd..d9acf8a09 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -7,7 +7,7 @@ import ( ) // ErrCreateTable is thrown when the database doesn't support table creation -var ErrCreateTable = errors.NewKind("tables cannot be created on database %s") +var ErrCreateTableNotSupported = errors.NewKind("tables cannot be created on database %s") var ErrDropTableNotSupported = errors.NewKind("tables cannot be dropped on database %s") // CreateTable is a node describing the creation of some table. @@ -17,13 +17,6 @@ type CreateTable struct { schema sql.Schema } -// DropTable is a node describing dropping a table -type DropTable struct { - db sql.Database - names []string - ifExists bool -} - // NewCreateTable creates a new CreateTable node func NewCreateTable(db sql.Database, name string, schema sql.Schema) *CreateTable { for _, s := range schema { @@ -64,13 +57,7 @@ func (c *CreateTable) RowIter(s *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(), creatable.CreateTable(s, c.name, c.schema) } - // TODO: phase out this interface - d, ok := c.db.(sql.Alterable) - if !ok { - return nil, ErrCreateTable.New(c.db.Name()) - } - - return sql.RowsToRowIter(), d.Create(c.name, c.schema) + return nil, ErrCreateTableNotSupported.New(c.db.Name()) } // Schema implements the Node interface. @@ -91,6 +78,13 @@ func (c *CreateTable) String() string { return "CreateTable" } +// DropTable is a node describing dropping one or more tables +type DropTable struct { + db sql.Database + names []string + ifExists bool +} + // NewDropTable creates a new DropTable node func NewDropTable(db sql.Database, ifExists bool, tableNames ...string) *DropTable { return &DropTable{ From 939ca769e0f0fcaf6bcac02c6a411a35754e00a2 Mon Sep 17 00:00:00 2001 From: Juanjo Alvarez Date: Mon, 21 Oct 2019 10:54:59 +0200 Subject: [PATCH 34/44] Add length to Char.String Signed-off-by: Juanjo Alvarez --- sql/plan/show_create_table_test.go | 4 +++- sql/type.go | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/plan/show_create_table_test.go b/sql/plan/show_create_table_test.go index 16b150ea4..0da6418bf 100644 --- a/sql/plan/show_create_table_test.go +++ b/sql/plan/show_create_table_test.go @@ -20,6 +20,7 @@ func TestShowCreateTable(t *testing.T) { &sql.Column{Name: "zab", Type: sql.Int32, Default: int32(0), Nullable: true}, &sql.Column{Name: "bza", Type: sql.Uint64, Default: uint64(0), Nullable: true}, &sql.Column{Name: "foo", Type: sql.VarChar(123), Default: "", Nullable: true}, + &sql.Column{Name: "pok", Type: sql.Char(123), Default: "", Nullable: true}, }) db.AddTable(table.Name(), table) @@ -41,7 +42,8 @@ func TestShowCreateTable(t *testing.T) { "CREATE TABLE `test-table` (\n `baz` text NOT NULL,\n"+ " `zab` integer DEFAULT 0,\n"+ " `bza` bigint unsigned DEFAULT 0,\n"+ - " `foo` varchar(123)\n"+ + " `foo` varchar(123),\n"+ + " `pok` char(123)\n"+ ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", ) diff --git a/sql/type.go b/sql/type.go index 46d43299d..48ee57f8d 100644 --- a/sql/type.go +++ b/sql/type.go @@ -1256,7 +1256,7 @@ func MySQLTypeName(t Type) string { case sqltypes.Date: return "DATE" case sqltypes.Char: - return "CHAR" + return fmt.Sprintf("CHAR(%v)", t.(charT).Capacity()) case sqltypes.VarChar: return fmt.Sprintf("VARCHAR(%v)", t.(varCharT).Capacity()) case sqltypes.Text: From 2f20a50a803dbf292d8897dac023f1411442001b Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Mon, 21 Oct 2019 09:34:06 -0700 Subject: [PATCH 35/44] Fixed test error Signed-off-by: Zach Musgrave --- memory/database_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/memory/database_test.go b/memory/database_test.go index e41c9e289..c6a2f9182 100644 --- a/memory/database_test.go +++ b/memory/database_test.go @@ -19,8 +19,7 @@ func TestDatabase_AddTable(t *testing.T) { tables := db.Tables() require.Equal(0, len(tables)) - var altDb sql.Alterable = db - err := altDb.Create("test_table", nil) + err := db.CreateTable(sql.NewEmptyContext(), "test_table", nil) require.NoError(err) tables = db.Tables() @@ -29,6 +28,6 @@ func TestDatabase_AddTable(t *testing.T) { require.True(ok) require.NotNil(tt) - err = altDb.Create("test_table", nil) + err = db.CreateTable(sql.NewEmptyContext(), "test_table", nil) require.Error(err) } From 256db1a5500c1d71591e0cc147a1f9fe54865d0f Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Tue, 22 Oct 2019 15:01:08 +0200 Subject: [PATCH 36/44] sql: information schema column types should be lowercase Fixes #849 Signed-off-by: Miguel Molina --- engine_test.go | 17 ++++++++-------- sql/information_schema.go | 43 ++++++++++++++++++++------------------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/engine_test.go b/engine_test.go index 5070589ee..0c3f4807f 100644 --- a/engine_test.go +++ b/engine_test.go @@ -18,7 +18,6 @@ import ( "github.com/src-d/go-mysql-server/sql/plan" "github.com/src-d/go-mysql-server/test" - "github.com/stretchr/testify/require" ) @@ -867,8 +866,8 @@ var queries = []struct { WHERE TABLE_SCHEMA='mydb' AND TABLE_NAME='mytable' `, []sql.Row{ - {"s", "TEXT"}, - {"i", "BIGINT"}, + {"s", "text"}, + {"i", "bigint"}, }, }, { @@ -2250,8 +2249,8 @@ func TestUpdate(t *testing.T) { var updates = []struct { updateQuery string expectedUpdate []sql.Row - selectQuery string - expectedSelect []sql.Row + selectQuery string + expectedSelect []sql.Row }{ { "UPDATE mytable SET s = 'updated';", @@ -2275,7 +2274,7 @@ func TestUpdate(t *testing.T) { "UPDATE mytable SET s = 'updated' WHERE i <> 9999;", []sql.Row{{int64(3), int64(3)}}, "SELECT * FROM mytable;", - []sql.Row{{int64(1), "updated"},{int64(2), "updated"},{int64(3), "updated"}}, + []sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}}, }, { "UPDATE floattable SET f32 = f32 + f32, f64 = f32 * f64 WHERE i = 2;", @@ -2477,7 +2476,7 @@ func TestCreateTable(t *testing.T) { testQuery(t, e, "CREATE TABLE t2 (a INTEGER NOT NULL PRIMARY KEY, "+ - "b VARCHAR(10) NOT NULL)", + "b VARCHAR(10) NOT NULL)", []sql.Row(nil), ) @@ -2496,8 +2495,8 @@ func TestCreateTable(t *testing.T) { testQuery(t, e, "CREATE TABLE t3(a INTEGER NOT NULL,"+ - "b TEXT NOT NULL,"+ - "c bool, primary key (a,b))", + "b TEXT NOT NULL,"+ + "c bool, primary key (a,b))", []sql.Row(nil), ) diff --git a/sql/information_schema.go b/sql/information_schema.go index d88804751..48147c7d5 100644 --- a/sql/information_schema.go +++ b/sql/information_schema.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "io" + "strings" ) const ( @@ -214,27 +215,27 @@ func columnsRowIter(cat *Catalog) RowIter { collName = "utf8_bin" } rows = append(rows, Row{ - "def", // table_catalog - db.Name(), // table_schema - t.Name(), // table_name - c.Name, // column_name - uint64(i), // ordinal_position - c.Default, // column_default - nullable, // is_nullable - MySQLTypeName(c.Type), // data_type - nil, // character_maximum_length - nil, // character_octet_length - nil, // numeric_precision - nil, // numeric_scale - nil, // datetime_precision - charName, // character_set_name - collName, // collation_name - MySQLTypeName(c.Type), // column_type - "", // column_key - "", // extra - "select", // privileges - "", // column_comment - "", // generation_expression + "def", // table_catalog + db.Name(), // table_schema + t.Name(), // table_name + c.Name, // column_name + uint64(i), // ordinal_position + c.Default, // column_default + nullable, // is_nullable + strings.ToLower(MySQLTypeName(c.Type)), // data_type + nil, // character_maximum_length + nil, // character_octet_length + nil, // numeric_precision + nil, // numeric_scale + nil, // datetime_precision + charName, // character_set_name + collName, // collation_name + strings.ToLower(MySQLTypeName(c.Type)), // column_type + "", // column_key + "", // extra + "select", // privileges + "", // column_comment + "", // generation_expression }) } } From d3ef2bcabd6d93cf084495f0f2158b9cf061dad0 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Tue, 22 Oct 2019 15:04:44 +0200 Subject: [PATCH 37/44] ci: add python-mysql and python-sqlalchemy to matrix Signed-off-by: Miguel Molina --- .gitignore | 1 + .travis.yml | 14 ++++++++++++++ _integration/python-sqlalchemy/requirements.txt | 3 ++- _integration/python-sqlalchemy/test.py | 5 ++--- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 251b391c1..cc1c664a6 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ Makefile.main .ci/ _example/main _example/*.exe +test-server diff --git a/.travis.yml b/.travis.yml index 9ebb6ac49..64eecc9a1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -38,6 +38,20 @@ jobs: script: - make TEST=python-pymysql integration + - language: python + python: '3.6' + before_install: + - eval "$(gimme 1.12.4)" + script: + - make TEST=python-mysql integration + + - language: python + python: '3.6' + before_install: + - eval "$(gimme 1.12.4)" + script: + - make TEST=python-sqlalchemy integration + - language: php php: '7.1' before_install: diff --git a/_integration/python-sqlalchemy/requirements.txt b/_integration/python-sqlalchemy/requirements.txt index a0e94bb24..4a2392f06 100644 --- a/_integration/python-sqlalchemy/requirements.txt +++ b/_integration/python-sqlalchemy/requirements.txt @@ -1,2 +1,3 @@ pandas -sqlalchemy \ No newline at end of file +sqlalchemy +mysqlclient diff --git a/_integration/python-sqlalchemy/test.py b/_integration/python-sqlalchemy/test.py index 6c8767d7e..e1baa052d 100644 --- a/_integration/python-sqlalchemy/test.py +++ b/_integration/python-sqlalchemy/test.py @@ -6,14 +6,13 @@ class TestMySQL(unittest.TestCase): def test_connect(self): - engine = sqlalchemy.create_engine('mysql+pymysql://root:@127.0.0.1:3306/mydb') + engine = sqlalchemy.create_engine('mysql+mysqldb://root:@127.0.0.1:3306/mydb') with engine.connect() as conn: expected = { "name": {0: 'John Doe', 1: 'John Doe', 2: 'Jane Doe', 3: 'Evil Bob'}, "email": {0: 'john@doe.com', 1: 'johnalt@doe.com', 2: 'jane@doe.com', 3: 'evilbob@gmail.com'}, - "phone_numbers": {0: '["555-555-555"]', 1: '[]', 2: '[]', 3: '["555-666-555","666-666-666"]'}, + "phone_numbers": {0: ['555-555-555'], 1: [], 2: [], 3: ['555-666-555', '666-666-666']}, } - repo_df = pd.read_sql_table("mytable", con=conn) d = repo_df.to_dict() del d["created_at"] From b6ada0efbc3c738d9ff98a9e860d46a025e9e384 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Tue, 22 Oct 2019 16:46:16 +0200 Subject: [PATCH 38/44] sql: finish root span of the query Fixes #853 Signed-off-by: Miguel Molina --- engine_test.go | 28 ++++++++++++++++++++++++++++ server/context.go | 6 ++---- sql/analyzer/process.go | 7 ++++++- sql/session.go | 27 ++++++++++++++++++++------- 4 files changed, 56 insertions(+), 12 deletions(-) diff --git a/engine_test.go b/engine_test.go index 0c3f4807f..b022a3e6b 100644 --- a/engine_test.go +++ b/engine_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "github.com/opentracing/opentracing-go" + sqle "github.com/src-d/go-mysql-server" "github.com/src-d/go-mysql-server/auth" "github.com/src-d/go-mysql-server/memory" @@ -3420,6 +3422,32 @@ func TestDeleteFromErrors(t *testing.T) { } } +type mockSpan struct { + opentracing.Span + finished bool +} + +func (m *mockSpan) Finish() { + m.finished = true +} + +func TestRootSpanFinish(t *testing.T) { + e := newEngine(t) + fakeSpan := &mockSpan{Span: opentracing.NoopTracer{}.StartSpan("")} + ctx := sql.NewContext( + context.Background(), + sql.WithRootSpan(fakeSpan), + ) + + _, iter, err := e.Query(ctx, "SELECT 1") + require.NoError(t, err) + + _, err = sql.RowIterToRows(iter) + require.NoError(t, err) + + require.True(t, fakeSpan.finished) +} + var generatorQueries = []struct { query string expected []sql.Row diff --git a/server/context.go b/server/context.go index ca2d7ce7b..6ee2ee508 100644 --- a/server/context.go +++ b/server/context.go @@ -92,15 +92,13 @@ func (s *SessionManager) NewContextWithQuery( s.mu.Unlock() context := sql.NewContext( - opentracing.ContextWithSpan( - context.Background(), - s.tracer.StartSpan("query"), - ), + context.Background(), sql.WithSession(sess), sql.WithTracer(s.tracer), sql.WithPid(s.nextPid()), sql.WithQuery(query), sql.WithMemoryManager(s.memory), + sql.WithRootSpan(s.tracer.StartSpan("query")), ) return context diff --git a/sql/analyzer/process.go b/sql/analyzer/process.go index 926d52c83..24170e615 100644 --- a/sql/analyzer/process.go +++ b/sql/analyzer/process.go @@ -85,5 +85,10 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { return nil, err } - return plan.NewQueryProcess(node, func() { processList.Done(ctx.Pid()) }), nil + return plan.NewQueryProcess(node, func() { + processList.Done(ctx.Pid()) + if span := ctx.RootSpan(); span != nil { + span.Finish() + } + }), nil } diff --git a/sql/session.go b/sql/session.go index 8123ed51e..251467ba0 100644 --- a/sql/session.go +++ b/sql/session.go @@ -211,10 +211,11 @@ func NewBaseSession() Session { type Context struct { context.Context Session - Memory *MemoryManager - pid uint64 - query string - tracer opentracing.Tracer + Memory *MemoryManager + pid uint64 + query string + tracer opentracing.Tracer + rootSpan opentracing.Span } // ContextOption is a function to configure the context. @@ -255,6 +256,13 @@ func WithMemoryManager(m *MemoryManager) ContextOption { } } +// WithRootSpan sets the root span of the context. +func WithRootSpan(s opentracing.Span) ContextOption { + return func(ctx *Context) { + ctx.rootSpan = s + } +} + // NewContext creates a new query context. Options can be passed to configure // the context. If some aspect of the context is not configure, the default // value will be used. @@ -264,7 +272,7 @@ func NewContext( ctx context.Context, opts ...ContextOption, ) *Context { - c := &Context{ctx, NewBaseSession(), nil, 0, "", opentracing.NoopTracer{}} + c := &Context{ctx, NewBaseSession(), nil, 0, "", opentracing.NoopTracer{}, nil} for _, opt := range opts { opt(c) } @@ -298,12 +306,17 @@ func (c *Context) Span( span := c.tracer.StartSpan(opName, opts...) ctx := opentracing.ContextWithSpan(c.Context, span) - return span, &Context{ctx, c.Session, c.Memory, c.Pid(), c.Query(), c.tracer} + return span, &Context{ctx, c.Session, c.Memory, c.Pid(), c.Query(), c.tracer, c.rootSpan} } // WithContext returns a new context with the given underlying context. func (c *Context) WithContext(ctx context.Context) *Context { - return &Context{ctx, c.Session, c.Memory, c.Pid(), c.Query(), c.tracer} + return &Context{ctx, c.Session, c.Memory, c.Pid(), c.Query(), c.tracer, c.rootSpan} +} + +// RootSpan returns the root span, if any. +func (c *Context) RootSpan() opentracing.Span { + return c.rootSpan } // Error adds an error as warning to the session. From 2f000c0c9a0f2bb98f1b06b62f418b4484ac16d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mart=C3=ADn?= Date: Wed, 23 Oct 2019 17:01:09 +0100 Subject: [PATCH 39/44] Add progress for each partition in SHOW PROCESSLIST MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Carlos Martín --- sql/analyzer/process.go | 15 ++++- sql/plan/process.go | 117 ++++++++++++++++++++++++++++++++------- sql/plan/process_test.go | 36 +++++++++--- sql/processlist.go | 14 +++++ 4 files changed, 151 insertions(+), 31 deletions(-) diff --git a/sql/analyzer/process.go b/sql/analyzer/process.go index 24170e615..09fbae4e0 100644 --- a/sql/analyzer/process.go +++ b/sql/analyzer/process.go @@ -44,16 +44,25 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { seen[name] = struct{}{} - notify := func() { + onPartitionDone := func(partitionName string) { processList.UpdateProgress(ctx.Pid(), name, 1) + processList.RemoveProgressItem(ctx.Pid(), partitionName) + } + + onPartitionStart := func(partitionName string) { + processList.AddProgressItem(ctx.Pid(), partitionName, -1) + } + + onRowNext := func(partitionName string) { + processList.UpdateProgress(ctx.Pid(), partitionName, 1) } var t sql.Table switch table := n.Table.(type) { case sql.IndexableTable: - t = plan.NewProcessIndexableTable(table, notify) + t = plan.NewProcessIndexableTable(table, onPartitionDone, onPartitionStart, onRowNext) default: - t = plan.NewProcessTable(table, notify) + t = plan.NewProcessTable(table, onPartitionDone, onPartitionStart, onRowNext) } return plan.NewResolvedTable(t), nil diff --git a/sql/plan/process.go b/sql/plan/process.go index 32a8452d6..52d2933a7 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -37,7 +37,7 @@ func (p *QueryProcess) RowIter(ctx *sql.Context) (sql.RowIter, error) { return nil, err } - return &trackedRowIter{iter, p.Notify}, nil + return &trackedRowIter{iter: iter, onDone: p.Notify}, nil } func (p *QueryProcess) String() string { return p.Child.String() } @@ -48,12 +48,14 @@ func (p *QueryProcess) String() string { return p.Child.String() } // partition is processed. type ProcessIndexableTable struct { sql.IndexableTable - Notify NotifyFunc + OnPartitionDone NamedNotifyFunc + OnPartitionStart NamedNotifyFunc + OnRowNext NamedNotifyFunc } // NewProcessIndexableTable returns a new ProcessIndexableTable. -func NewProcessIndexableTable(t sql.IndexableTable, notify NotifyFunc) *ProcessIndexableTable { - return &ProcessIndexableTable{t, notify} +func NewProcessIndexableTable(t sql.IndexableTable, onPartitionDone, onPartitionStart, OnRowNext NamedNotifyFunc) *ProcessIndexableTable { + return &ProcessIndexableTable{t, onPartitionDone, onPartitionStart, OnRowNext} } // Underlying implements sql.TableWrapper interface. @@ -71,7 +73,7 @@ func (t *ProcessIndexableTable) IndexKeyValues( return nil, err } - return &trackedPartitionIndexKeyValueIter{iter, t.Notify}, nil + return &trackedPartitionIndexKeyValueIter{iter, t.OnPartitionDone, t.OnPartitionStart, t.OnRowNext}, nil } // PartitionRows implements the sql.Table interface. @@ -81,22 +83,46 @@ func (t *ProcessIndexableTable) PartitionRows(ctx *sql.Context, p sql.Partition) return nil, err } - return &trackedRowIter{iter, t.Notify}, nil + partitionName := string(p.Key()) + if t.OnPartitionStart != nil { + t.OnPartitionStart(partitionName) + } + + var onDone NotifyFunc + if t.OnPartitionDone != nil { + onDone = func() { + t.OnPartitionDone(partitionName) + } + } + + var onNext NotifyFunc + if t.OnRowNext != nil { + onNext = func() { + t.OnRowNext(partitionName) + } + } + + return &trackedRowIter{iter: iter, onNext: onNext, onDone: onDone}, nil } var _ sql.IndexableTable = (*ProcessIndexableTable)(nil) +// NamedNotifyFunc is a function to notify about some event with a string argument. +type NamedNotifyFunc func(name string) + // ProcessTable is a wrapper for sql.Tables inside a query process. It // notifies the process manager about the status of a query when a partition // is processed. type ProcessTable struct { sql.Table - Notify NotifyFunc + OnPartitionDone NamedNotifyFunc + OnPartitionStart NamedNotifyFunc + OnRowNext NamedNotifyFunc } // NewProcessTable returns a new ProcessTable. -func NewProcessTable(t sql.Table, notify NotifyFunc) *ProcessTable { - return &ProcessTable{t, notify} +func NewProcessTable(t sql.Table, onPartitionDone, onPartitionStart, OnRowNext NamedNotifyFunc) *ProcessTable { + return &ProcessTable{t, onPartitionDone, onPartitionStart, OnRowNext} } // Underlying implements sql.TableWrapper interface. @@ -111,18 +137,38 @@ func (t *ProcessTable) PartitionRows(ctx *sql.Context, p sql.Partition) (sql.Row return nil, err } - return &trackedRowIter{iter, t.Notify}, nil + partitionName := string(p.Key()) + if t.OnPartitionStart != nil { + t.OnPartitionStart(partitionName) + } + + var onDone NotifyFunc + if t.OnPartitionDone != nil { + onDone = func() { + t.OnPartitionDone(partitionName) + } + } + + var onNext NotifyFunc + if t.OnRowNext != nil { + onNext = func() { + t.OnRowNext(partitionName) + } + } + + return &trackedRowIter{iter: iter, onNext: onNext, onDone: onDone}, nil } type trackedRowIter struct { iter sql.RowIter - notify NotifyFunc + onDone NotifyFunc + onNext NotifyFunc } func (i *trackedRowIter) done() { - if i.notify != nil { - i.notify() - i.notify = nil + if i.onDone != nil { + i.onDone() + i.onDone = nil } } @@ -134,6 +180,11 @@ func (i *trackedRowIter) Next() (sql.Row, error) { } return nil, err } + + if i.onNext != nil { + i.onNext() + } + return row, nil } @@ -144,7 +195,9 @@ func (i *trackedRowIter) Close() error { type trackedPartitionIndexKeyValueIter struct { sql.PartitionIndexKeyValueIter - notify NotifyFunc + OnPartitionDone NamedNotifyFunc + OnPartitionStart NamedNotifyFunc + OnRowNext NamedNotifyFunc } func (i *trackedPartitionIndexKeyValueIter) Next() (sql.Partition, sql.IndexKeyValueIter, error) { @@ -153,18 +206,38 @@ func (i *trackedPartitionIndexKeyValueIter) Next() (sql.Partition, sql.IndexKeyV return nil, nil, err } - return p, &trackedIndexKeyValueIter{iter, i.notify}, nil + partitionName := string(p.Key()) + if i.OnPartitionStart != nil { + i.OnPartitionStart(partitionName) + } + + var onDone NotifyFunc + if i.OnPartitionDone != nil { + onDone = func() { + i.OnPartitionDone(partitionName) + } + } + + var onNext NotifyFunc + if i.OnRowNext != nil { + onNext = func() { + i.OnRowNext(partitionName) + } + } + + return p, &trackedIndexKeyValueIter{iter, onDone, onNext}, nil } type trackedIndexKeyValueIter struct { iter sql.IndexKeyValueIter - notify NotifyFunc + onDone NotifyFunc + onNext NotifyFunc } func (i *trackedIndexKeyValueIter) done() { - if i.notify != nil { - i.notify() - i.notify = nil + if i.onDone != nil { + i.onDone() + i.onDone = nil } } @@ -185,5 +258,9 @@ func (i *trackedIndexKeyValueIter) Next() ([]interface{}, []byte, error) { return nil, nil, err } + if i.onNext != nil { + i.onNext() + } + return v, k, nil } diff --git a/sql/plan/process_test.go b/sql/plan/process_test.go index 2d4e65763..de819edb0 100644 --- a/sql/plan/process_test.go +++ b/sql/plan/process_test.go @@ -61,7 +61,9 @@ func TestProcessTable(t *testing.T) { table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(3))) table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(4))) - var notifications int + var partitionDoneNotifications int + var partitionStartNotifications int + var rowNextNotifications int node := NewProject( []sql.Expression{ @@ -70,8 +72,14 @@ func TestProcessTable(t *testing.T) { NewResolvedTable( NewProcessTable( table, - func() { - notifications++ + func(partitionName string) { + partitionDoneNotifications++ + }, + func(partitionName string) { + partitionStartNotifications++ + }, + func(partitionName string) { + rowNextNotifications++ }, ), ), @@ -91,7 +99,9 @@ func TestProcessTable(t *testing.T) { } require.ElementsMatch(expected, rows) - require.Equal(2, notifications) + require.Equal(2, partitionDoneNotifications) + require.Equal(2, partitionStartNotifications) + require.Equal(4, rowNextNotifications) } func TestProcessIndexableTable(t *testing.T) { @@ -106,12 +116,20 @@ func TestProcessIndexableTable(t *testing.T) { table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(3))) table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(4))) - var notifications int + var partitionDoneNotifications int + var partitionStartNotifications int + var rowNextNotifications int pt := NewProcessIndexableTable( table, - func() { - notifications++ + func(partitionName string) { + partitionDoneNotifications++ + }, + func(partitionName string) { + partitionStartNotifications++ + }, + func(partitionName string) { + rowNextNotifications++ }, ) @@ -144,5 +162,7 @@ func TestProcessIndexableTable(t *testing.T) { } require.ElementsMatch(expectedValues, values) - require.Equal(2, notifications) + require.Equal(2, partitionDoneNotifications) + require.Equal(2, partitionStartNotifications) + require.Equal(4, rowNextNotifications) } diff --git a/sql/processlist.go b/sql/processlist.go index 3a17f50d0..580323238 100644 --- a/sql/processlist.go +++ b/sql/processlist.go @@ -156,6 +156,20 @@ func (pl *ProcessList) AddProgressItem(pid uint64, name string, total int64) { } } +// RemoveProgressItem removes an existing item tracking progress from the +// process with the given pid, if it exists. +func (pl *ProcessList) RemoveProgressItem(pid uint64, name string) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + delete(p.Progress, name) +} + // Kill terminates all queries for a given connection id. func (pl *ProcessList) Kill(connID uint32) { pl.mu.Lock() From 64b899d2c31e2f32fe3ddecbfc504f9dc076e0f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mart=C3=ADn?= Date: Wed, 23 Oct 2019 17:28:35 +0100 Subject: [PATCH 40/44] Add names to travis tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Carlos Martín --- .travis.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.travis.yml b/.travis.yml index 64eecc9a1..e06f94818 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,10 +24,13 @@ script: jobs: include: - go: 1.11.x + name: 'Unit tests Go 1.11' - go: 1.12.x + name: 'Unit tests Go 1.12' # Integration test builds for 3rd party clients - go: 1.12.x + name: 'Integration test go' script: - make TEST=go integration @@ -35,6 +38,7 @@ jobs: python: '3.6' before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test python-pymysql' script: - make TEST=python-pymysql integration @@ -42,6 +46,7 @@ jobs: python: '3.6' before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test python-mysql' script: - make TEST=python-mysql integration @@ -49,6 +54,7 @@ jobs: python: '3.6' before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test python-sqlalchemy' script: - make TEST=python-sqlalchemy integration @@ -56,6 +62,7 @@ jobs: php: '7.1' before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test php' script: - make TEST=php integration @@ -63,6 +70,7 @@ jobs: ruby: '2.3' before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test ruby' script: - make TEST=ruby integration @@ -70,6 +78,7 @@ jobs: jdk: openjdk8 before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test jdbc-mariadb' script: - make TEST=jdbc-mariadb integration @@ -77,6 +86,7 @@ jobs: node_js: '12' before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test javascript' script: - make TEST=javascript integration @@ -85,6 +95,7 @@ jobs: dotnet: '2.1' before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test dotnet' script: - make TEST=dotnet integration @@ -92,5 +103,6 @@ jobs: compiler: clang before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test c' script: - make TEST=c integration From a385e34fedfea7d97817302bbf90bcdaa02c67c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mart=C3=ADn?= Date: Thu, 24 Oct 2019 11:31:44 +0100 Subject: [PATCH 41/44] Check if partition is Nameable in SHOW PROCESSLIST output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Carlos Martín --- sql/plan/process.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sql/plan/process.go b/sql/plan/process.go index 52d2933a7..2e0bec439 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -83,7 +83,7 @@ func (t *ProcessIndexableTable) PartitionRows(ctx *sql.Context, p sql.Partition) return nil, err } - partitionName := string(p.Key()) + partitionName := partitionName(p) if t.OnPartitionStart != nil { t.OnPartitionStart(partitionName) } @@ -137,7 +137,7 @@ func (t *ProcessTable) PartitionRows(ctx *sql.Context, p sql.Partition) (sql.Row return nil, err } - partitionName := string(p.Key()) + partitionName := partitionName(p) if t.OnPartitionStart != nil { t.OnPartitionStart(partitionName) } @@ -206,7 +206,7 @@ func (i *trackedPartitionIndexKeyValueIter) Next() (sql.Partition, sql.IndexKeyV return nil, nil, err } - partitionName := string(p.Key()) + partitionName := partitionName(p) if i.OnPartitionStart != nil { i.OnPartitionStart(partitionName) } @@ -264,3 +264,10 @@ func (i *trackedIndexKeyValueIter) Next() ([]interface{}, []byte, error) { return v, k, nil } + +func partitionName(p sql.Partition) string { + if n, ok := p.(sql.Nameable); ok { + return n.Name() + } + return string(p.Key()) +} From 6059987c050728ff74cb734d8d4e4a3d8c2d8dd0 Mon Sep 17 00:00:00 2001 From: Carlos Date: Fri, 25 Oct 2019 16:31:04 +0100 Subject: [PATCH 42/44] Do not cancel context for async queries on success Signed-off-by: Carlos --- engine.go | 12 ++++++++++++ server/handler.go | 10 +++++++--- sql/core.go | 6 ++++++ sql/plan/create_index.go | 5 +++++ 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/engine.go b/engine.go index 6be2be3c2..f09c167c5 100644 --- a/engine.go +++ b/engine.go @@ -153,6 +153,18 @@ func (e *Engine) Query( return analyzed.Schema(), iter, nil } +// Async returns true if the query is async. If there are any errors with the +// query it returns false +func (e *Engine) Async(ctx *sql.Context, query string) bool { + parsed, err := parse.Parse(ctx, query) + if err != nil { + return false + } + + asyncNode, ok := parsed.(sql.AsyncNode) + return ok && asyncNode.IsAsync() +} + // AddDatabase adds the given database to the catalog. func (e *Engine) AddDatabase(db sql.Database) { e.Catalog.AddDatabase(db) diff --git a/server/handler.go b/server/handler.go index c76597eee..ed71a5f18 100644 --- a/server/handler.go +++ b/server/handler.go @@ -113,9 +113,13 @@ func (h *Handler) ComQuery( callback func(*sqltypes.Result) error, ) (err error) { ctx := h.sm.NewContextWithQuery(c, query) - newCtx, cancel := context.WithCancel(ctx) - defer cancel() - ctx = ctx.WithContext(newCtx) + + if !h.e.Async(ctx, query) { + newCtx, cancel := context.WithCancel(ctx) + ctx = ctx.WithContext(newCtx) + + defer cancel() + } handled, err := h.handleKill(c, query) if err != nil { diff --git a/sql/core.go b/sql/core.go index 2d49ece19..16ef6fa8d 100644 --- a/sql/core.go +++ b/sql/core.go @@ -120,6 +120,12 @@ type OpaqueNode interface { Opaque() bool } +// AsyncNode is a node that can be executed asynchronously. +type AsyncNode interface { + // IsAsync reports whether the node is async or not. + IsAsync() bool +} + // Expressioner is a node that contains expressions. type Expressioner interface { // Expressions returns the list of expressions contained by the node. diff --git a/sql/plan/create_index.go b/sql/plan/create_index.go index f45de2f7e..00455a127 100644 --- a/sql/plan/create_index.go +++ b/sql/plan/create_index.go @@ -282,6 +282,11 @@ func (c *CreateIndex) WithChildren(children ...sql.Node) (sql.Node, error) { return &nc, nil } +// IsAsync implements the AsyncNode interface. +func (c *CreateIndex) IsAsync() bool { + return c.Async +} + // getColumnsAndPrepareExpressions extracts the unique columns required by all // those expressions and fixes the indexes of the GetFields in the expressions // to match a row with only the returned columns in that same order. From e81f029b62114afaf9b2288d7cd30acabb0c32e5 Mon Sep 17 00:00:00 2001 From: Carlos Date: Tue, 29 Oct 2019 12:57:53 +0000 Subject: [PATCH 43/44] Format SHOW PROCESSLIST progress as a tree Signed-off-by: Carlos --- sql/analyzer/process.go | 10 +-- sql/analyzer/process_test.go | 14 ++-- sql/plan/processlist.go | 26 ++++++-- sql/plan/processlist_test.go | 27 +++++--- sql/processlist.go | 124 +++++++++++++++++++++++++++++++---- sql/processlist_test.go | 37 ++++++++--- 6 files changed, 191 insertions(+), 47 deletions(-) diff --git a/sql/analyzer/process.go b/sql/analyzer/process.go index 09fbae4e0..aa0d9d0c7 100644 --- a/sql/analyzer/process.go +++ b/sql/analyzer/process.go @@ -40,21 +40,21 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { } total = count } - processList.AddProgressItem(ctx.Pid(), name, total) + processList.AddTableProgress(ctx.Pid(), name, total) seen[name] = struct{}{} onPartitionDone := func(partitionName string) { - processList.UpdateProgress(ctx.Pid(), name, 1) - processList.RemoveProgressItem(ctx.Pid(), partitionName) + processList.UpdateTableProgress(ctx.Pid(), name, 1) + processList.RemovePartitionProgress(ctx.Pid(), name, partitionName) } onPartitionStart := func(partitionName string) { - processList.AddProgressItem(ctx.Pid(), partitionName, -1) + processList.AddPartitionProgress(ctx.Pid(), name, partitionName, -1) } onRowNext := func(partitionName string) { - processList.UpdateProgress(ctx.Pid(), partitionName, 1) + processList.UpdatePartitionProgress(ctx.Pid(), name, partitionName, 1) } var t sql.Table diff --git a/sql/analyzer/process_test.go b/sql/analyzer/process_test.go index 32a719500..91271af4c 100644 --- a/sql/analyzer/process_test.go +++ b/sql/analyzer/process_test.go @@ -35,10 +35,16 @@ func TestTrackProcess(t *testing.T) { require.Len(processes, 1) require.Equal("SELECT foo", processes[0].Query) require.Equal(sql.QueryProcess, processes[0].Type) - require.Equal(map[string]sql.Progress{ - "foo": sql.Progress{Total: 2}, - "bar": sql.Progress{Total: 4}, - }, processes[0].Progress) + require.Equal( + map[string]sql.TableProgress{ + "foo": sql.TableProgress{ + Progress: sql.Progress{Name: "foo", Done: 0, Total: 2}, + PartitionsProgress: map[string]sql.PartitionProgress{}}, + "bar": sql.TableProgress{ + Progress: sql.Progress{Name: "bar", Done: 0, Total: 4}, + PartitionsProgress: map[string]sql.PartitionProgress{}}, + }, + processes[0].Progress) proc, ok := result.(*plan.QueryProcess) require.True(ok) diff --git a/sql/plan/processlist.go b/sql/plan/processlist.go index bc9f4d18c..a447fb79d 100644 --- a/sql/plan/processlist.go +++ b/sql/plan/processlist.go @@ -1,7 +1,6 @@ package plan import ( - "fmt" "sort" "strings" @@ -77,21 +76,36 @@ func (p *ShowProcessList) RowIter(ctx *sql.Context) (sql.RowIter, error) { for i, proc := range processes { var status []string - for name, progress := range proc.Progress { - status = append(status, fmt.Sprintf("%s(%s)", name, progress)) + var names []string + for name := range proc.Progress { + names = append(names, name) + } + sort.Strings(names) + + for _, name := range names { + progress := proc.Progress[name] + + printer := sql.NewTreePrinter() + _ = printer.WriteNode("\n" + progress.String()) + children := []string{} + for _, partitionProgress := range progress.PartitionsProgress { + children = append(children, partitionProgress.String()) + } + sort.Strings(children) + _ = printer.WriteChildren(children...) + + status = append(status, printer.String()) } if len(status) == 0 { status = []string{"running"} } - sort.Strings(status) - rows[i] = process{ id: int64(proc.Connection), user: proc.User, time: int64(proc.Seconds()), - state: strings.Join(status, ", "), + state: strings.Join(status, ""), command: proc.Type.String(), host: ctx.Session.Client().Address, info: proc.Query, diff --git a/sql/plan/processlist_test.go b/sql/plan/processlist_test.go index 401661ba2..12ee98b9a 100644 --- a/sql/plan/processlist_test.go +++ b/sql/plan/processlist_test.go @@ -21,19 +21,21 @@ func TestShowProcessList(t *testing.T) { ctx, err := p.AddProcess(ctx, sql.QueryProcess, "SELECT foo") require.NoError(err) - p.AddProgressItem(ctx.Pid(), "a", 5) - p.AddProgressItem(ctx.Pid(), "b", 6) + p.AddTableProgress(ctx.Pid(), "a", 5) + p.AddTableProgress(ctx.Pid(), "b", 6) ctx = sql.NewContext(context.Background(), sql.WithPid(2), sql.WithSession(sess)) ctx, err = p.AddProcess(ctx, sql.CreateIndexProcess, "SELECT bar") require.NoError(err) - p.AddProgressItem(ctx.Pid(), "foo", 2) + p.AddTableProgress(ctx.Pid(), "foo", 2) - p.UpdateProgress(1, "a", 3) - p.UpdateProgress(1, "a", 1) - p.UpdateProgress(1, "b", 2) - p.UpdateProgress(2, "foo", 1) + p.UpdateTableProgress(1, "a", 3) + p.UpdateTableProgress(1, "a", 1) + p.UpdatePartitionProgress(1, "a", "a-1", 7) + p.UpdatePartitionProgress(1, "a", "a-2", 9) + p.UpdateTableProgress(1, "b", 2) + p.UpdateTableProgress(2, "foo", 1) n.ProcessList = p n.Database = "foo" @@ -44,8 +46,15 @@ func TestShowProcessList(t *testing.T) { require.NoError(err) expected := []sql.Row{ - {int64(1), "foo", addr, "foo", "query", int64(0), "a(4/5), b(2/6)", "SELECT foo"}, - {int64(1), "foo", addr, "foo", "create_index", int64(0), "foo(1/2)", "SELECT bar"}, + {int64(1), "foo", addr, "foo", "query", int64(0), + ` +a (4/5 partitions) + ├─ a-1 (7/? rows) + └─ a-2 (9/? rows) + +b (2/6 partitions) +`, "SELECT foo"}, + {int64(1), "foo", addr, "foo", "create_index", int64(0), "\nfoo (1/2 partitions)\n", "SELECT bar"}, } require.ElementsMatch(expected, rows) diff --git a/sql/processlist.go b/sql/processlist.go index 580323238..5bc1fecf6 100644 --- a/sql/processlist.go +++ b/sql/processlist.go @@ -12,17 +12,46 @@ import ( // Progress between done items and total items. type Progress struct { + Name string Done int64 Total int64 } -func (p Progress) String() string { +func (p Progress) totalString() string { var total = "?" if p.Total > 0 { total = fmt.Sprint(p.Total) } + return total +} + +// TableProgress keeps track of a table progress, and for each of its partitions +type TableProgress struct { + Progress + PartitionsProgress map[string]PartitionProgress +} + +func NewTableProgress(name string, total int64) TableProgress { + return TableProgress{ + Progress: Progress{ + Name: name, + Total: total, + }, + PartitionsProgress: make(map[string]PartitionProgress), + } +} + +func (p TableProgress) String() string { + return fmt.Sprintf("%s (%d/%s partitions)", p.Name, p.Done, p.totalString()) +} - return fmt.Sprintf("%d/%s", p.Done, total) +// PartitionProgress keeps track of a partition progress +type PartitionProgress struct { + Progress +} + +func (p PartitionProgress) String() string { + return fmt.Sprintf("%s (%d/%s rows)", p.Name, p.Done, p.totalString()) } // ProcessType is the type of process. @@ -53,7 +82,7 @@ type Process struct { User string Type ProcessType Query string - Progress map[string]Progress + Progress map[string]TableProgress StartedAt time.Time Kill context.CancelFunc } @@ -108,7 +137,7 @@ func (pl *ProcessList) AddProcess( Connection: ctx.ID(), Type: typ, Query: query, - Progress: make(map[string]Progress), + Progress: make(map[string]TableProgress), User: ctx.Session.Client().User, StartedAt: time.Now(), Kill: cancel, @@ -117,9 +146,9 @@ func (pl *ProcessList) AddProcess( return ctx, nil } -// UpdateProgress updates the progress of the item with the given name for the +// UpdateTableProgress updates the progress of the table with the given name for the // process with the given pid. -func (pl *ProcessList) UpdateProgress(pid uint64, name string, delta int64) { +func (pl *ProcessList) UpdateTableProgress(pid uint64, name string, delta int64) { pl.mu.Lock() defer pl.mu.Unlock() @@ -130,16 +159,41 @@ func (pl *ProcessList) UpdateProgress(pid uint64, name string, delta int64) { progress, ok := p.Progress[name] if !ok { - progress = Progress{Total: -1} + progress = NewTableProgress(name, -1) } progress.Done += delta p.Progress[name] = progress } -// AddProgressItem adds a new item to track progress from to the process with +// UpdatePartitionProgress updates the progress of the table partition with the +// given name for the process with the given pid. +func (pl *ProcessList) UpdatePartitionProgress(pid uint64, tableName, partitionName string, delta int64) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + tablePg, ok := p.Progress[tableName] + if !ok { + return + } + + partitionPg, ok := tablePg.PartitionsProgress[partitionName] + if !ok { + partitionPg = PartitionProgress{Progress: Progress{Name: partitionName, Total: -1}} + } + + partitionPg.Done += delta + tablePg.PartitionsProgress[partitionName] = partitionPg +} + +// AddTableProgress adds a new item to track progress from to the process with // the given pid. If the pid does not exist, it will do nothing. -func (pl *ProcessList) AddProgressItem(pid uint64, name string, total int64) { +func (pl *ProcessList) AddTableProgress(pid uint64, name string, total int64) { pl.mu.Lock() defer pl.mu.Unlock() @@ -152,13 +206,38 @@ func (pl *ProcessList) AddProgressItem(pid uint64, name string, total int64) { pg.Total = total p.Progress[name] = pg } else { - p.Progress[name] = Progress{Total: total} + p.Progress[name] = NewTableProgress(name, total) } } -// RemoveProgressItem removes an existing item tracking progress from the +// AddPartitionProgress adds a new item to track progress from to the process with +// the given pid. If the pid or the table does not exist, it will do nothing. +func (pl *ProcessList) AddPartitionProgress(pid uint64, tableName, partitionName string, total int64) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + tablePg, ok := p.Progress[tableName] + if !ok { + return + } + + if pg, ok := tablePg.PartitionsProgress[partitionName]; ok { + pg.Total = total + tablePg.PartitionsProgress[partitionName] = pg + } else { + tablePg.PartitionsProgress[partitionName] = + PartitionProgress{Progress: Progress{Name: partitionName, Total: total}} + } +} + +// RemoveTableProgress removes an existing item tracking progress from the // process with the given pid, if it exists. -func (pl *ProcessList) RemoveProgressItem(pid uint64, name string) { +func (pl *ProcessList) RemoveTableProgress(pid uint64, name string) { pl.mu.Lock() defer pl.mu.Unlock() @@ -170,6 +249,25 @@ func (pl *ProcessList) RemoveProgressItem(pid uint64, name string) { delete(p.Progress, name) } +// RemovePartitionProgress removes an existing item tracking progress from the +// process with the given pid, if it exists. +func (pl *ProcessList) RemovePartitionProgress(pid uint64, tableName, partitionName string) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + tablePg, ok := p.Progress[tableName] + if !ok { + return + } + + delete(tablePg.PartitionsProgress, partitionName) +} + // Kill terminates all queries for a given connection id. func (pl *ProcessList) Kill(connID uint32) { pl.mu.Lock() @@ -220,7 +318,7 @@ func (pl *ProcessList) Processes() []Process { for _, proc := range pl.procs { p := *proc - var progress = make(map[string]Progress, len(p.Progress)) + var progress = make(map[string]TableProgress, len(p.Progress)) for n, p := range p.Progress { progress[n] = p } diff --git a/sql/processlist_test.go b/sql/processlist_test.go index a6fb3a44f..198f40b12 100644 --- a/sql/processlist_test.go +++ b/sql/processlist_test.go @@ -20,16 +20,16 @@ func TestProcessList(t *testing.T) { require.Equal(uint64(1), ctx.Pid()) require.Len(p.procs, 1) - p.AddProgressItem(ctx.Pid(), "a", 5) - p.AddProgressItem(ctx.Pid(), "b", 6) + p.AddTableProgress(ctx.Pid(), "a", 5) + p.AddTableProgress(ctx.Pid(), "b", 6) expectedProcess := &Process{ Pid: 1, Connection: 1, Type: QueryProcess, - Progress: map[string]Progress{ - "a": Progress{0, 5}, - "b": Progress{0, 6}, + Progress: map[string]TableProgress{ + "a": {Progress{Name: "a", Done: 0, Total: 5}, map[string]PartitionProgress{}}, + "b": {Progress{Name: "b", Done: 0, Total: 6}, map[string]PartitionProgress{}}, }, User: "foo", Query: "SELECT foo", @@ -39,19 +39,36 @@ func TestProcessList(t *testing.T) { p.procs[ctx.Pid()].Kill = nil require.Equal(expectedProcess, p.procs[ctx.Pid()]) + p.AddPartitionProgress(ctx.Pid(), "b", "b-1", -1) + p.AddPartitionProgress(ctx.Pid(), "b", "b-2", -1) + p.AddPartitionProgress(ctx.Pid(), "b", "b-3", -1) + + p.UpdatePartitionProgress(ctx.Pid(), "b", "b-2", 1) + + p.RemovePartitionProgress(ctx.Pid(), "b", "b-3") + + expectedProgress := map[string]TableProgress{ + "a": {Progress{Name: "a", Total: 5}, map[string]PartitionProgress{}}, + "b": {Progress{Name: "b", Total: 6}, map[string]PartitionProgress{ + "b-1": {Progress{Name: "b-1", Done: 0, Total: -1}}, + "b-2": {Progress{Name: "b-2", Done: 1, Total: -1}}, + }}, + } + require.Equal(expectedProgress, p.procs[ctx.Pid()].Progress) + ctx = NewContext(context.Background(), WithPid(2), WithSession(sess)) ctx, err = p.AddProcess(ctx, CreateIndexProcess, "SELECT bar") require.NoError(err) - p.AddProgressItem(ctx.Pid(), "foo", 2) + p.AddTableProgress(ctx.Pid(), "foo", 2) require.Equal(uint64(2), ctx.Pid()) require.Len(p.procs, 2) - p.UpdateProgress(1, "a", 3) - p.UpdateProgress(1, "a", 1) - p.UpdateProgress(1, "b", 2) - p.UpdateProgress(2, "foo", 1) + p.UpdateTableProgress(1, "a", 3) + p.UpdateTableProgress(1, "a", 1) + p.UpdateTableProgress(1, "b", 2) + p.UpdateTableProgress(2, "foo", 1) require.Equal(int64(4), p.procs[1].Progress["a"].Done) require.Equal(int64(2), p.procs[1].Progress["b"].Done) From af7466b4e4dc47c09b56b4be84b0f375c2e6f091 Mon Sep 17 00:00:00 2001 From: Marcelo Novaes Date: Wed, 27 Jan 2021 16:40:42 -0300 Subject: [PATCH 44/44] Archive repository and add notice Close #873 --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 55271b663..1ab7e9ee5 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +_**Notice: This repository is no longer actively maintained, and no further updates will be done, nor issues/PRs will be answered or attended. An alternative actively maintained can be found at https://github.com/dolthub/go-mysql-server repository.**_ + # go-mysql-server [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) Build Status pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy