diff --git a/.gitignore b/.gitignore index 251b391c1..cc1c664a6 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ Makefile.main .ci/ _example/main _example/*.exe +test-server diff --git a/.travis.yml b/.travis.yml index 9ebb6ac49..e06f94818 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,10 +24,13 @@ script: jobs: include: - go: 1.11.x + name: 'Unit tests Go 1.11' - go: 1.12.x + name: 'Unit tests Go 1.12' # Integration test builds for 3rd party clients - go: 1.12.x + name: 'Integration test go' script: - make TEST=go integration @@ -35,13 +38,31 @@ jobs: python: '3.6' before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test python-pymysql' script: - make TEST=python-pymysql integration + - language: python + python: '3.6' + before_install: + - eval "$(gimme 1.12.4)" + name: 'Integration test python-mysql' + script: + - make TEST=python-mysql integration + + - language: python + python: '3.6' + before_install: + - eval "$(gimme 1.12.4)" + name: 'Integration test python-sqlalchemy' + script: + - make TEST=python-sqlalchemy integration + - language: php php: '7.1' before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test php' script: - make TEST=php integration @@ -49,6 +70,7 @@ jobs: ruby: '2.3' before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test ruby' script: - make TEST=ruby integration @@ -56,6 +78,7 @@ jobs: jdk: openjdk8 before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test jdbc-mariadb' script: - make TEST=jdbc-mariadb integration @@ -63,6 +86,7 @@ jobs: node_js: '12' before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test javascript' script: - make TEST=javascript integration @@ -71,6 +95,7 @@ jobs: dotnet: '2.1' before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test dotnet' script: - make TEST=dotnet integration @@ -78,5 +103,6 @@ jobs: compiler: clang before_install: - eval "$(gimme 1.12.4)" + name: 'Integration test c' script: - make TEST=c integration diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 000000000..f48cffb3a --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @src-d/data-processing diff --git a/MAINTAINERS b/MAINTAINERS index dc6817245..8d8de2618 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -1 +1,2 @@ -Antonio Navarro Perez (@ajnavarro) +Miguel Molina (@erizocosmico) +Juanjo Álvarez Martinez (@juanjux) diff --git a/README.md b/README.md index 55271b663..1ab7e9ee5 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +_**Notice: This repository is no longer actively maintained, and no further updates will be done, nor issues/PRs will be answered or attended. An alternative actively maintained can be found at https://github.com/dolthub/go-mysql-server repository.**_ + # go-mysql-server [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) Build Status diff --git a/SUPPORTED.md b/SUPPORTED.md index a316dffe1..87841ac72 100644 --- a/SUPPORTED.md +++ b/SUPPORTED.md @@ -27,7 +27,6 @@ - ALIAS (AS) - CAST/CONVERT - CREATE TABLE -- DESCRIBE/DESC/EXPLAIN [table name] - DESCRIBE/DESC/EXPLAIN FORMAT=TREE [query] - DISTINCT - FILTER (WHERE) @@ -80,9 +79,6 @@ - div - % -## Subqueries -- supported only as tables, not as expressions. - ## Functions - ARRAY_LENGTH - CEIL @@ -133,3 +129,6 @@ - WEEKDAY - YEAR - YEARWEEK + +## Subqueries +Supported both as a table and as expressions but they can't access the parent query scope. diff --git a/_integration/javascript/package-lock.json b/_integration/javascript/package-lock.json index 41c153513..b37fc5abe 100644 --- a/_integration/javascript/package-lock.json +++ b/_integration/javascript/package-lock.json @@ -658,11 +658,6 @@ "integrity": "sha1-ibTRmasr7kneFk6gK4nORi1xt2c=", "dev": true }, - "bignumber.js": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/bignumber.js/-/bignumber.js-4.1.0.tgz", - "integrity": "sha512-eJzYkFYy9L4JzXsbymsFn3p54D+llV27oTQ+ziJG7WFRheJcNZilgVXMG0LoZtlQSKBsJdWtLFqOD0u+U0jZKA==" - }, "binary-extensions": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.0.0.tgz", @@ -2266,13 +2261,21 @@ "dev": true }, "mysql": { - "version": "github:mysqljs/mysql#cf5d1e396a343ffa0fba23b3791d2dd5da20f30e", - "from": "github:mysqljs/mysql", + "version": "2.17.1", + "resolved": "https://registry.npmjs.org/mysql/-/mysql-2.17.1.tgz", + "integrity": "sha512-7vMqHQ673SAk5C8fOzTG2LpPcf3bNt0oL3sFpxPEEFp1mdlDcrLK0On7z8ZYKaaHrHwNcQ/MTUz7/oobZ2OyyA==", "requires": { - "bignumber.js": "9.0.0", + "bignumber.js": "7.2.1", "readable-stream": "2.3.6", "safe-buffer": "5.1.2", "sqlstring": "2.3.1" + }, + "dependencies": { + "bignumber.js": { + "version": "7.2.1", + "resolved": "https://registry.npmjs.org/bignumber.js/-/bignumber.js-7.2.1.tgz", + "integrity": "sha512-S4XzBk5sMB+Rcb/LNcpzXr57VRTxgAvaAEDAl1AwRx27j00hT84O6OkteE7u8UB3NuaaygCRrEpqox4uDOrbdQ==" + } } }, "normalize-package-data": { @@ -2674,9 +2677,9 @@ } }, "process-nextick-args": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.0.tgz", - "integrity": "sha512-MtEC1TqN0EU5nephaJ4rAtThHtC86dNN9qCuEhtshvpVBkAW5ZO7BASN9REnF9eoXGcRub+pFuKEpOHE+HbEMw==" + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.1.tgz", + "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==" }, "pseudomap": { "version": "1.0.2", @@ -2803,7 +2806,7 @@ }, "readable-stream": { "version": "2.3.6", - "resolved": "http://registry.npmjs.org/readable-stream/-/readable-stream-2.3.6.tgz", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.6.tgz", "integrity": "sha512-tQtKA9WIAhBF3+VLAseyMqZeBjW0AHJoxOtYqSUZNJxauErmLbVm2FW1y+J/YA9dUrAC39ITejlZWhVIwawkKw==", "requires": { "core-util-is": "~1.0.0", diff --git a/_integration/javascript/package.json b/_integration/javascript/package.json index eff42abd3..da6670c88 100644 --- a/_integration/javascript/package.json +++ b/_integration/javascript/package.json @@ -7,7 +7,7 @@ "test": "./node_modules/.bin/ava" }, "dependencies": { - "mysql": "github:mysqljs/mysql" + "mysql": "2.17.1" }, "devDependencies": { "ava": "2.3.0" diff --git a/_integration/python-sqlalchemy/requirements.txt b/_integration/python-sqlalchemy/requirements.txt index a0e94bb24..4a2392f06 100644 --- a/_integration/python-sqlalchemy/requirements.txt +++ b/_integration/python-sqlalchemy/requirements.txt @@ -1,2 +1,3 @@ pandas -sqlalchemy \ No newline at end of file +sqlalchemy +mysqlclient diff --git a/_integration/python-sqlalchemy/test.py b/_integration/python-sqlalchemy/test.py index 6c8767d7e..e1baa052d 100644 --- a/_integration/python-sqlalchemy/test.py +++ b/_integration/python-sqlalchemy/test.py @@ -6,14 +6,13 @@ class TestMySQL(unittest.TestCase): def test_connect(self): - engine = sqlalchemy.create_engine('mysql+pymysql://root:@127.0.0.1:3306/mydb') + engine = sqlalchemy.create_engine('mysql+mysqldb://root:@127.0.0.1:3306/mydb') with engine.connect() as conn: expected = { "name": {0: 'John Doe', 1: 'John Doe', 2: 'Jane Doe', 3: 'Evil Bob'}, "email": {0: 'john@doe.com', 1: 'johnalt@doe.com', 2: 'jane@doe.com', 3: 'evilbob@gmail.com'}, - "phone_numbers": {0: '["555-555-555"]', 1: '[]', 2: '[]', 3: '["555-666-555","666-666-666"]'}, + "phone_numbers": {0: ['555-555-555'], 1: [], 2: [], 3: ['555-666-555', '666-666-666']}, } - repo_df = pd.read_sql_table("mytable", con=conn) d = repo_df.to_dict() del d["created_at"] diff --git a/engine.go b/engine.go index ae11ef701..f09c167c5 100644 --- a/engine.go +++ b/engine.go @@ -1,4 +1,4 @@ -package sqle // import "github.com/src-d/go-mysql-server" +package sqle import ( "time" @@ -120,7 +120,7 @@ func (e *Engine) Query( case *plan.CreateIndex: typ = sql.CreateIndexProcess perm = auth.ReadPerm | auth.WritePerm - case *plan.InsertInto, *plan.DeleteFrom, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables: + case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables: perm = auth.ReadPerm | auth.WritePerm } @@ -153,6 +153,18 @@ func (e *Engine) Query( return analyzed.Schema(), iter, nil } +// Async returns true if the query is async. If there are any errors with the +// query it returns false +func (e *Engine) Async(ctx *sql.Context, query string) bool { + parsed, err := parse.Parse(ctx, query) + if err != nil { + return false + } + + asyncNode, ok := parsed.(sql.AsyncNode) + return ok && asyncNode.IsAsync() +} + // AddDatabase adds the given database to the catalog. func (e *Engine) AddDatabase(db sql.Database) { e.Catalog.AddDatabase(db) diff --git a/engine_pilosa_test.go b/engine_pilosa_test.go index 00380a1b9..a6da4c5a9 100644 --- a/engine_pilosa_test.go +++ b/engine_pilosa_test.go @@ -207,4 +207,3 @@ func TestCreateIndex(t *testing.T) { require.NoError(os.RemoveAll(tmpDir)) }() } - diff --git a/engine_test.go b/engine_test.go index dce9066d6..b022a3e6b 100644 --- a/engine_test.go +++ b/engine_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "github.com/opentracing/opentracing-go" + sqle "github.com/src-d/go-mysql-server" "github.com/src-d/go-mysql-server/auth" "github.com/src-d/go-mysql-server/memory" @@ -582,7 +584,7 @@ var queries = []struct { { `SELECT COALESCE(NULL, NULL, NULL, COALESCE(NULL, 1234567890))`, []sql.Row{ - {int64(1234567890)}, + {int32(1234567890)}, }, }, { @@ -866,8 +868,8 @@ var queries = []struct { WHERE TABLE_SCHEMA='mydb' AND TABLE_NAME='mytable' `, []sql.Row{ - {"s", "TEXT"}, - {"i", "BIGINT"}, + {"s", "text"}, + {"i", "bigint"}, }, }, { @@ -981,7 +983,7 @@ var queries = []struct { }, { `SELECT -1`, - []sql.Row{{int64(-1)}}, + []sql.Row{{int8(-1)}}, }, { ` @@ -1043,13 +1045,13 @@ var queries = []struct { { `SELECT nullif(123, 321)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { `SELECT ifnull(123, NULL)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { @@ -1061,19 +1063,19 @@ var queries = []struct { { `SELECT ifnull(NULL, 123)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { `SELECT ifnull(123, 123)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { `SELECT ifnull(123, 321)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { @@ -1085,7 +1087,7 @@ var queries = []struct { { `SELECT round(15, 1)`, []sql.Row{ - {int64(15)}, + {int8(15)}, }, }, { @@ -1452,7 +1454,7 @@ var queries = []struct { }, { `SELECT 1 FROM mytable GROUP BY i HAVING i > 1`, - []sql.Row{{int64(1)}, {int64(1)}}, + []sql.Row{{int8(1)}, {int8(1)}}, }, { `SELECT avg(i) FROM mytable GROUP BY i HAVING avg(i) > 1`, @@ -1548,6 +1550,42 @@ var queries = []struct { {int64(5), "there is some text in here"}, }, }, + { + `SELECT i FROM mytable WHERE i = (SELECT 1)`, + []sql.Row{{int64(1)}}, + }, + { + `SELECT i FROM mytable WHERE i IN (SELECT i FROM mytable)`, + []sql.Row{ + {int64(1)}, + {int64(2)}, + {int64(3)}, + }, + }, + { + `SELECT i FROM mytable WHERE i NOT IN (SELECT i FROM mytable ORDER BY i ASC LIMIT 2)`, + []sql.Row{ + {int64(3)}, + }, + }, + { + `SELECT (SELECT i FROM mytable ORDER BY i ASC LIMIT 1) AS x`, + []sql.Row{{int64(1)}}, + }, + { + `SELECT DISTINCT n FROM bigtable ORDER BY t`, + []sql.Row{ + {int64(1)}, + {int64(9)}, + {int64(7)}, + {int64(3)}, + {int64(2)}, + {int64(8)}, + {int64(6)}, + {int64(5)}, + {int64(4)}, + }, + }, } func TestQueries(t *testing.T) { @@ -1865,8 +1903,8 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(math.MaxInt8), int64(math.MaxInt16), int64(math.MaxInt32), int64(math.MaxInt64), - int64(math.MaxUint8), int64(math.MaxUint16), int64(math.MaxUint32), uint64(math.MaxUint64), + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), float64(math.MaxFloat32), float64(math.MaxFloat64), timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), "random text", true, `{"key":"value"}`, "blobdata", @@ -1883,8 +1921,8 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(math.MaxInt8), int64(math.MaxInt16), int64(math.MaxInt32), int64(math.MaxInt64), - int64(math.MaxUint8), int64(math.MaxUint16), int64(math.MaxUint32), uint64(math.MaxUint64), + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), float64(math.MaxFloat32), float64(math.MaxFloat64), timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), "random text", true, `{"key":"value"}`, "blobdata", @@ -1901,8 +1939,8 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1), - int64(0), int64(0), int64(0), int64(0), + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), + uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), "", false, ``, "", @@ -1919,8 +1957,8 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1), - int64(0), int64(0), int64(0), int64(0), + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), + uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), "", false, ``, "", @@ -2065,8 +2103,8 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(math.MaxInt8), int64(math.MaxInt16), int64(math.MaxInt32), int64(math.MaxInt64), - int64(math.MaxUint8), int64(math.MaxUint16), int64(math.MaxUint32), uint64(math.MaxUint64), + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), float64(math.MaxFloat32), float64(math.MaxFloat64), timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), "random text", true, `{"key":"value"}`, "blobdata", @@ -2083,8 +2121,8 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(math.MaxInt8), int64(math.MaxInt16), int64(math.MaxInt32), int64(math.MaxInt64), - int64(math.MaxUint8), int64(math.MaxUint16), int64(math.MaxUint32), uint64(math.MaxUint64), + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), float64(math.MaxFloat32), float64(math.MaxFloat64), timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), "random text", true, `{"key":"value"}`, "blobdata", @@ -2101,8 +2139,8 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1), - int64(0), int64(0), int64(0), int64(0), + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), + uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), "", false, ``, "", @@ -2119,8 +2157,8 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1), - int64(0), int64(0), int64(0), int64(0), + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), + uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), "", false, ``, "", @@ -2209,6 +2247,142 @@ func TestReplaceIntoErrors(t *testing.T) { } } +func TestUpdate(t *testing.T) { + var updates = []struct { + updateQuery string + expectedUpdate []sql.Row + selectQuery string + expectedSelect []sql.Row + }{ + { + "UPDATE mytable SET s = 'updated';", + []sql.Row{{int64(3), int64(3)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}}, + }, + { + "UPDATE mytable SET s = 'updated' WHERE i > 9999;", + []sql.Row{{int64(0), int64(0)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}}, + }, + { + "UPDATE mytable SET s = 'updated' WHERE i = 1;", + []sql.Row{{int64(1), int64(1)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "second row"}, {int64(3), "third row"}}, + }, + { + "UPDATE mytable SET s = 'updated' WHERE i <> 9999;", + []sql.Row{{int64(3), int64(3)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}}, + }, + { + "UPDATE floattable SET f32 = f32 + f32, f64 = f32 * f64 WHERE i = 2;", + []sql.Row{{int64(1), int64(1)}}, + "SELECT * FROM floattable WHERE i = 2;", + []sql.Row{{int64(2), float32(3.0), float64(4.5)}}, + }, + { + "UPDATE floattable SET f32 = 5, f32 = 4 WHERE i = 1;", + []sql.Row{{int64(1), int64(1)}}, + "SELECT f32 FROM floattable WHERE i = 1;", + []sql.Row{{float32(4.0)}}, + }, + { + "UPDATE mytable SET s = 'first row' WHERE i = 1;", + []sql.Row{{int64(1), int64(0)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}}, + }, + { + "UPDATE niltable SET b = NULL WHERE f IS NULL;", + []sql.Row{{int64(2), int64(1)}}, + "SELECT * FROM niltable WHERE f IS NULL;", + []sql.Row{{int64(4), nil, nil}, {nil, nil, nil}}, + }, + { + "UPDATE mytable SET s = 'updated' ORDER BY i ASC LIMIT 2;", + []sql.Row{{int64(2), int64(2)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "third row"}}, + }, + { + "UPDATE mytable SET s = 'updated' ORDER BY i DESC LIMIT 2;", + []sql.Row{{int64(2), int64(2)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "updated"}, {int64(3), "updated"}}, + }, + { + "UPDATE mytable SET s = 'updated' ORDER BY i LIMIT 1 OFFSET 1;", + []sql.Row{{int64(1), int64(1)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "updated"}, {int64(3), "third row"}}, + }, + { + "UPDATE mytable SET s = 'updated';", + []sql.Row{{int64(3), int64(3)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}}, + }, + } + + for _, update := range updates { + e := newEngine(t) + ctx := newCtx() + testQueryWithContext(ctx, t, e, update.updateQuery, update.expectedUpdate) + testQueryWithContext(ctx, t, e, update.selectQuery, update.expectedSelect) + } +} + +func TestUpdateErrors(t *testing.T) { + var expectedFailures = []struct { + name string + query string + }{ + { + "invalid table", + "UPDATE doesnotexist SET i = 0;", + }, + { + "invalid column set", + "UPDATE mytable SET z = 0;", + }, + { + "invalid column set value", + "UPDATE mytable SET i = z;", + }, + { + "invalid column where", + "UPDATE mytable SET s = 'hi' WHERE z = 1;", + }, + { + "invalid column order by", + "UPDATE mytable SET s = 'hi' ORDER BY z;", + }, + { + "negative limit", + "UPDATE mytable SET s = 'hi' LIMIT -1;", + }, + { + "negative offset", + "UPDATE mytable SET s = 'hi' LIMIT 1 OFFSET -1;", + }, + { + "set null on non-nullable", + "UPDATE mytable SET s = NULL;", + }, + } + + for _, expectedFailure := range expectedFailures { + t.Run(expectedFailure.name, func(t *testing.T) { + _, _, err := newEngine(t).Query(newCtx(), expectedFailure.query) + require.Error(t, err) + }) + } +} + const testNumPartitions = 5 func TestAmbiguousColumnResolution(t *testing.T) { @@ -2270,7 +2444,7 @@ func TestAmbiguousColumnResolution(t *testing.T) { require.Equal(expected, rs) } -func TestDDL(t *testing.T) { +func TestCreateTable(t *testing.T) { require := require.New(t) e := newEngine(t) @@ -2301,6 +2475,83 @@ func TestDDL(t *testing.T) { } require.Equal(s, testTable.Schema()) + + testQuery(t, e, + "CREATE TABLE t2 (a INTEGER NOT NULL PRIMARY KEY, "+ + "b VARCHAR(10) NOT NULL)", + []sql.Row(nil), + ) + + db, err = e.Catalog.Database("mydb") + require.NoError(err) + + testTable, ok = db.Tables()["t2"] + require.True(ok) + + s = sql.Schema{ + {Name: "a", Type: sql.Int32, Nullable: false, PrimaryKey: true, Source: "t2"}, + {Name: "b", Type: sql.Text, Nullable: false, Source: "t2"}, + } + + require.Equal(s, testTable.Schema()) + + testQuery(t, e, + "CREATE TABLE t3(a INTEGER NOT NULL,"+ + "b TEXT NOT NULL,"+ + "c bool, primary key (a,b))", + []sql.Row(nil), + ) + + db, err = e.Catalog.Database("mydb") + require.NoError(err) + + testTable, ok = db.Tables()["t3"] + require.True(ok) + + s = sql.Schema{ + {Name: "a", Type: sql.Int32, Nullable: false, PrimaryKey: true, Source: "t3"}, + {Name: "b", Type: sql.Text, Nullable: false, PrimaryKey: true, Source: "t3"}, + {Name: "c", Type: sql.Uint8, Nullable: true, Source: "t3"}, + } + + require.Equal(s, testTable.Schema()) +} + +func TestDropTable(t *testing.T) { + require := require.New(t) + + e := newEngine(t) + db, err := e.Catalog.Database("mydb") + require.NoError(err) + + _, ok := db.Tables()["mytable"] + require.True(ok) + + testQuery(t, e, + "DROP TABLE IF EXISTS mytable, not_exist", + []sql.Row(nil), + ) + + _, ok = db.Tables()["mytable"] + require.False(ok) + + _, ok = db.Tables()["othertable"] + require.True(ok) + _, ok = db.Tables()["tabletest"] + require.True(ok) + + testQuery(t, e, + "DROP TABLE IF EXISTS othertable, tabletest", + []sql.Row(nil), + ) + + _, ok = db.Tables()["othertable"] + require.False(ok) + _, ok = db.Tables()["tabletest"] + require.False(ok) + + _, _, err = e.Query(newCtx(), "DROP TABLE not_exist") + require.Error(err) } func TestNaturalJoin(t *testing.T) { @@ -2634,12 +2885,12 @@ func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine { insertRows( t, floatTable, - sql.NewRow(1, float32(1.0), float64(1.0)), - sql.NewRow(2, float32(1.5), float64(1.5)), - sql.NewRow(3, float32(2.0), float64(2.0)), - sql.NewRow(4, float32(2.5), float64(2.5)), - sql.NewRow(-1, float32(-1.0), float64(-1.0)), - sql.NewRow(-2, float32(-1.5), float64(-1.5)), + sql.NewRow(int64(1), float32(1.0), float64(1.0)), + sql.NewRow(int64(2), float32(1.5), float64(1.5)), + sql.NewRow(int64(3), float32(2.0), float64(2.0)), + sql.NewRow(int64(4), float32(2.5), float64(2.5)), + sql.NewRow(int64(-1), float32(-1.0), float64(-1.0)), + sql.NewRow(int64(-2), float32(-1.5), float64(-1.5)), ) nilTable := memory.NewPartitionedTable("niltable", sql.Schema{ @@ -2925,7 +3176,7 @@ func TestSessionVariables(t *testing.T) { rows, err := sql.RowIterToRows(iter) require.NoError(err) - require.Equal([]sql.Row{{int64(1), ",STRICT_TRANS_TABLES"}}, rows) + require.Equal([]sql.Row{{int8(1), ",STRICT_TRANS_TABLES"}}, rows) } func TestSessionVariablesONOFF(t *testing.T) { @@ -3171,6 +3422,32 @@ func TestDeleteFromErrors(t *testing.T) { } } +type mockSpan struct { + opentracing.Span + finished bool +} + +func (m *mockSpan) Finish() { + m.finished = true +} + +func TestRootSpanFinish(t *testing.T) { + e := newEngine(t) + fakeSpan := &mockSpan{Span: opentracing.NoopTracer{}.StartSpan("")} + ctx := sql.NewContext( + context.Background(), + sql.WithRootSpan(fakeSpan), + ) + + _, iter, err := e.Query(ctx, "SELECT 1") + require.NoError(t, err) + + _, err = sql.RowIterToRows(iter) + require.NoError(t, err) + + require.True(t, fakeSpan.finished) +} + var generatorQueries = []struct { query string expected []sql.Row diff --git a/internal/sockstate/netstat_linux.go b/internal/sockstate/netstat_linux.go index 435c7deaa..a7eb6ff62 100644 --- a/internal/sockstate/netstat_linux.go +++ b/internal/sockstate/netstat_linux.go @@ -22,8 +22,8 @@ import ( const ( pathTCP4Tab = "/proc/net/tcp" pathTCP6Tab = "/proc/net/tcp6" - ipv4StrLen = 8 - ipv6StrLen = 32 + ipv4StrLen = 8 + ipv6StrLen = 32 ) type procFd struct { diff --git a/log.go b/log.go index 4880acebf..346b545f7 100644 --- a/log.go +++ b/log.go @@ -1,4 +1,4 @@ -package sqle // import "github.com/src-d/go-mysql-server" +package sqle import ( "github.com/golang/glog" diff --git a/memory/database.go b/memory/database.go index b5b87a741..133c4f5de 100644 --- a/memory/database.go +++ b/memory/database.go @@ -33,8 +33,8 @@ func (d *Database) AddTable(name string, t sql.Table) { d.tables[name] = t } -// Create creates a table with the given name and schema -func (d *Database) Create(name string, schema sql.Schema) error { +// CreateTable creates a table with the given name and schema +func (d *Database) CreateTable(ctx *sql.Context, name string, schema sql.Schema) error { _, ok := d.tables[name] if ok { return sql.ErrTableAlreadyExists.New(name) @@ -43,3 +43,15 @@ func (d *Database) Create(name string, schema sql.Schema) error { d.tables[name] = NewTable(name, schema) return nil } + +// DropTable drops the table with the given name +func (d *Database) DropTable(ctx *sql.Context, name string) error { + _, ok := d.tables[name] + if !ok { + return sql.ErrTableNotFound.New(name) + } + + delete(d.tables, name) + return nil +} + diff --git a/memory/database_test.go b/memory/database_test.go index e41c9e289..c6a2f9182 100644 --- a/memory/database_test.go +++ b/memory/database_test.go @@ -19,8 +19,7 @@ func TestDatabase_AddTable(t *testing.T) { tables := db.Tables() require.Equal(0, len(tables)) - var altDb sql.Alterable = db - err := altDb.Create("test_table", nil) + err := db.CreateTable(sql.NewEmptyContext(), "test_table", nil) require.NoError(err) tables = db.Tables() @@ -29,6 +28,6 @@ func TestDatabase_AddTable(t *testing.T) { require.True(ok) require.NotNil(tt) - err = altDb.Create("test_table", nil) + err = db.CreateTable(sql.NewEmptyContext(), "test_table", nil) require.Error(err) } diff --git a/memory/table.go b/memory/table.go index a73ecedb8..088956316 100644 --- a/memory/table.go +++ b/memory/table.go @@ -290,6 +290,37 @@ func (t *Table) Delete(ctx *sql.Context, row sql.Row) error { return nil } +func (t *Table) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error { + if err := checkRow(t.schema, oldRow); err != nil { + return err + } + if err := checkRow(t.schema, newRow); err != nil { + return err + } + + matches := false + for partitionIndex, partition := range t.partitions { + for partitionRowIndex, partitionRow := range partition { + matches = true + for rIndex, val := range oldRow { + if val != partitionRow[rIndex] { + matches = false + break + } + } + if matches { + t.partitions[partitionIndex][partitionRowIndex] = newRow + break + } + } + if matches { + break + } + } + + return nil +} + func checkRow(schema sql.Schema, row sql.Row) error { if len(row) != len(schema) { return sql.ErrUnexpectedRowLength.New(len(schema), len(row)) diff --git a/server/context.go b/server/context.go index ca2d7ce7b..6ee2ee508 100644 --- a/server/context.go +++ b/server/context.go @@ -92,15 +92,13 @@ func (s *SessionManager) NewContextWithQuery( s.mu.Unlock() context := sql.NewContext( - opentracing.ContextWithSpan( - context.Background(), - s.tracer.StartSpan("query"), - ), + context.Background(), sql.WithSession(sess), sql.WithTracer(s.tracer), sql.WithPid(s.nextPid()), sql.WithQuery(query), sql.WithMemoryManager(s.memory), + sql.WithRootSpan(s.tracer.StartSpan("query")), ) return context diff --git a/server/handler.go b/server/handler.go index c76597eee..ed71a5f18 100644 --- a/server/handler.go +++ b/server/handler.go @@ -113,9 +113,13 @@ func (h *Handler) ComQuery( callback func(*sqltypes.Result) error, ) (err error) { ctx := h.sm.NewContextWithQuery(c, query) - newCtx, cancel := context.WithCancel(ctx) - defer cancel() - ctx = ctx.WithContext(newCtx) + + if !h.e.Async(ctx, query) { + newCtx, cancel := context.WithCancel(ctx) + ctx = ctx.WithContext(newCtx) + + defer cancel() + } handled, err := h.handleKill(c, query) if err != nil { diff --git a/server/server.go b/server/server.go index b2bc8924c..15fdb45c1 100644 --- a/server/server.go +++ b/server/server.go @@ -1,4 +1,4 @@ -package server // import "github.com/src-d/go-mysql-server/server" +package server import ( "time" diff --git a/sql/analyzer/analyzer.go b/sql/analyzer/analyzer.go index 1066d6138..f7bc456d5 100644 --- a/sql/analyzer/analyzer.go +++ b/sql/analyzer/analyzer.go @@ -1,4 +1,4 @@ -package analyzer // import "github.com/src-d/go-mysql-server/sql/analyzer" +package analyzer import ( "os" diff --git a/sql/analyzer/assign_indexes.go b/sql/analyzer/assign_indexes.go index 4462cb915..b6bcb1327 100644 --- a/sql/analyzer/assign_indexes.go +++ b/sql/analyzer/assign_indexes.go @@ -759,8 +759,20 @@ func containsColumns(e sql.Expression) bool { return result } +func containsSubquery(e sql.Expression) bool { + var result bool + expression.Inspect(e, func(e sql.Expression) bool { + if _, ok := e.(*expression.Subquery); ok { + result = true + return false + } + return true + }) + return result +} + func isEvaluable(e sql.Expression) bool { - return !containsColumns(e) + return !containsColumns(e) && !containsSubquery(e) } func canMergeIndexes(a, b sql.IndexLookup) bool { diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index 0037d9b0d..88283cf9f 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -34,17 +34,19 @@ func optimizeDistinct(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, e a.Log("optimize distinct, node of type: %T", node) if n, ok := node.(*plan.Distinct); ok { - var isSorted bool + var sortField *expression.GetField plan.Inspect(n, func(node sql.Node) bool { a.Log("checking for optimization in node of type: %T", node) - if _, ok := node.(*plan.Sort); ok { - isSorted = true + if sort, ok := node.(*plan.Sort); ok && sortField == nil { + if col, ok := sort.SortFields[0].Column.(*expression.GetField); ok { + sortField = col + } return false } return true }) - if isSorted { + if sortField != nil && n.Schema().Contains(sortField.Name(), sortField.Table()) { a.Log("distinct optimized for ordered output") return plan.NewOrderedDistinct(n.Child), nil } diff --git a/sql/analyzer/optimization_rules_test.go b/sql/analyzer/optimization_rules_test.go index 285117251..60dee3b0f 100644 --- a/sql/analyzer/optimization_rules_test.go +++ b/sql/analyzer/optimization_rules_test.go @@ -186,24 +186,54 @@ func TestEraseProjection(t *testing.T) { } func TestOptimizeDistinct(t *testing.T) { - require := require.New(t) - - t1 := memory.NewTable("foo", nil) - t2 := memory.NewTable("foo", nil) + t1 := memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo"}, + {Name: "b", Source: "foo"}, + }) - notSorted := plan.NewDistinct(plan.NewResolvedTable(t1)) - sorted := plan.NewDistinct(plan.NewSort(nil, plan.NewResolvedTable(t2))) + testCases := []struct { + name string + child sql.Node + optimized bool + }{ + { + "without sort", + plan.NewResolvedTable(t1), + false, + }, + { + "sort but column not projected", + plan.NewSort( + []plan.SortField{ + {Column: gf(0, "foo", "c")}, + }, + plan.NewResolvedTable(t1), + ), + false, + }, + { + "sort and column projected", + plan.NewSort( + []plan.SortField{ + {Column: gf(0, "foo", "a")}, + }, + plan.NewResolvedTable(t1), + ), + true, + }, + } rule := getRule("optimize_distinct") - analyzedNotSorted, err := rule.Apply(sql.NewEmptyContext(), nil, notSorted) - require.NoError(err) - - analyzedSorted, err := rule.Apply(sql.NewEmptyContext(), nil, sorted) - require.NoError(err) + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + node, err := rule.Apply(sql.NewEmptyContext(), nil, plan.NewDistinct(tt.child)) + require.NoError(t, err) - require.Equal(notSorted, analyzedNotSorted) - require.Equal(plan.NewOrderedDistinct(sorted.Child), analyzedSorted) + _, ok := node.(*plan.OrderedDistinct) + require.Equal(t, tt.optimized, ok) + }) + } } func TestMoveJoinConditionsToFilter(t *testing.T) { diff --git a/sql/analyzer/process.go b/sql/analyzer/process.go index 926d52c83..aa0d9d0c7 100644 --- a/sql/analyzer/process.go +++ b/sql/analyzer/process.go @@ -40,20 +40,29 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { } total = count } - processList.AddProgressItem(ctx.Pid(), name, total) + processList.AddTableProgress(ctx.Pid(), name, total) seen[name] = struct{}{} - notify := func() { - processList.UpdateProgress(ctx.Pid(), name, 1) + onPartitionDone := func(partitionName string) { + processList.UpdateTableProgress(ctx.Pid(), name, 1) + processList.RemovePartitionProgress(ctx.Pid(), name, partitionName) + } + + onPartitionStart := func(partitionName string) { + processList.AddPartitionProgress(ctx.Pid(), name, partitionName, -1) + } + + onRowNext := func(partitionName string) { + processList.UpdatePartitionProgress(ctx.Pid(), name, partitionName, 1) } var t sql.Table switch table := n.Table.(type) { case sql.IndexableTable: - t = plan.NewProcessIndexableTable(table, notify) + t = plan.NewProcessIndexableTable(table, onPartitionDone, onPartitionStart, onRowNext) default: - t = plan.NewProcessTable(table, notify) + t = plan.NewProcessTable(table, onPartitionDone, onPartitionStart, onRowNext) } return plan.NewResolvedTable(t), nil @@ -85,5 +94,10 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { return nil, err } - return plan.NewQueryProcess(node, func() { processList.Done(ctx.Pid()) }), nil + return plan.NewQueryProcess(node, func() { + processList.Done(ctx.Pid()) + if span := ctx.RootSpan(); span != nil { + span.Finish() + } + }), nil } diff --git a/sql/analyzer/process_test.go b/sql/analyzer/process_test.go index 32a719500..91271af4c 100644 --- a/sql/analyzer/process_test.go +++ b/sql/analyzer/process_test.go @@ -35,10 +35,16 @@ func TestTrackProcess(t *testing.T) { require.Len(processes, 1) require.Equal("SELECT foo", processes[0].Query) require.Equal(sql.QueryProcess, processes[0].Type) - require.Equal(map[string]sql.Progress{ - "foo": sql.Progress{Total: 2}, - "bar": sql.Progress{Total: 4}, - }, processes[0].Progress) + require.Equal( + map[string]sql.TableProgress{ + "foo": sql.TableProgress{ + Progress: sql.Progress{Name: "foo", Done: 0, Total: 2}, + PartitionsProgress: map[string]sql.PartitionProgress{}}, + "bar": sql.TableProgress{ + Progress: sql.Progress{Name: "bar", Done: 0, Total: 4}, + PartitionsProgress: map[string]sql.PartitionProgress{}}, + }, + processes[0].Progress) proc, ok := result.(*plan.QueryProcess) require.True(ok) diff --git a/sql/analyzer/pushdown.go b/sql/analyzer/pushdown.go index 34312b1ac..07049c413 100644 --- a/sql/analyzer/pushdown.go +++ b/sql/analyzer/pushdown.go @@ -20,7 +20,7 @@ func pushdown(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { // don't do pushdown on certain queries switch n.(type) { - case *plan.InsertInto, *plan.DeleteFrom, *plan.CreateIndex: + case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.CreateIndex: return n, nil } diff --git a/sql/analyzer/resolve_subqueries.go b/sql/analyzer/resolve_subqueries.go index 5015253c1..20b97df3f 100644 --- a/sql/analyzer/resolve_subqueries.go +++ b/sql/analyzer/resolve_subqueries.go @@ -2,6 +2,7 @@ package analyzer import ( "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" ) @@ -10,7 +11,7 @@ func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err defer span.Finish() a.Log("resolving subqueries") - return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + n, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.SubqueryAlias: a.Log("found subquery %q with child of type %T", n.Name(), n.Child) @@ -24,4 +25,25 @@ func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err return n, nil } }) + if err != nil { + return nil, err + } + + return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) { + s, ok := e.(*expression.Subquery) + if !ok || s.Resolved() { + return e, nil + } + + q, err := a.Analyze(ctx, s.Query) + if err != nil { + return nil, err + } + + if qp, ok := q.(*plan.QueryProcess); ok { + q = qp.Child + } + + return s.WithQuery(q), nil + }) } diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index e7448d9e2..7d35ee6be 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -20,6 +20,7 @@ const ( validateCaseResultTypesRule = "validate_case_result_types" validateIntervalUsageRule = "validate_interval_usage" validateExplodeUsageRule = "validate_explode_usage" + validateSubqueryColumnsRule = "validate_subquery_columns" ) var ( @@ -57,6 +58,12 @@ var ( ErrExplodeInvalidUse = errors.NewKind( "using EXPLODE is not supported outside a Project node", ) + + // ErrSubqueryColumns is returned when an expression subquery returns + // more than a single column. + ErrSubqueryColumns = errors.NewKind( + "subquery expressions can only return a single column", + ) ) // DefaultValidationRules to apply while analyzing nodes. @@ -70,6 +77,7 @@ var DefaultValidationRules = []Rule{ {validateCaseResultTypesRule, validateCaseResultTypes}, {validateIntervalUsageRule, validateIntervalUsage}, {validateExplodeUsageRule, validateExplodeUsage}, + {validateSubqueryColumnsRule, validateSubqueryColumns}, } func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { @@ -322,6 +330,25 @@ func validateExplodeUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, return n, nil } +func validateSubqueryColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + valid := true + plan.InspectExpressions(n, func(e sql.Expression) bool { + s, ok := e.(*expression.Subquery) + if ok && len(s.Query.Schema()) != 1 { + valid = false + return false + } + + return true + }) + + if !valid { + return nil, ErrSubqueryColumns.New() + } + + return n, nil +} + func stringContains(strs []string, target string) bool { for _, s := range strs { if s == target { diff --git a/sql/analyzer/validation_rules_test.go b/sql/analyzer/validation_rules_test.go index 1af5ad53a..543fbc688 100644 --- a/sql/analyzer/validation_rules_test.go +++ b/sql/analyzer/validation_rules_test.go @@ -674,6 +674,37 @@ func TestValidateExplodeUsage(t *testing.T) { } } +func TestValidateSubqueryColumns(t *testing.T) { + require := require.New(t) + ctx := sql.NewEmptyContext() + + node := plan.NewProject([]sql.Expression{ + expression.NewSubquery(plan.NewProject( + []sql.Expression{ + lit(1), + lit(2), + }, + dummyNode{true}, + )), + }, dummyNode{true}) + + _, err := validateSubqueryColumns(ctx, nil, node) + require.Error(err) + require.True(ErrSubqueryColumns.Is(err)) + + node = plan.NewProject([]sql.Expression{ + expression.NewSubquery(plan.NewProject( + []sql.Expression{ + lit(1), + }, + dummyNode{true}, + )), + }, dummyNode{true}) + + _, err = validateSubqueryColumns(ctx, nil, node) + require.NoError(err) +} + type dummyNode struct{ resolved bool } func (n dummyNode) String() string { return "dummynode" } diff --git a/sql/core.go b/sql/core.go index a45a1d69f..16ef6fa8d 100644 --- a/sql/core.go +++ b/sql/core.go @@ -1,4 +1,4 @@ -package sql // import "github.com/src-d/go-mysql-server/sql" +package sql import ( "fmt" @@ -120,6 +120,12 @@ type OpaqueNode interface { Opaque() bool } +// AsyncNode is a node that can be executed asynchronously. +type AsyncNode interface { + // IsAsync reports whether the node is async or not. + IsAsync() bool +} + // Expressioner is a node that contains expressions. type Expressioner interface { // Expressions returns the list of expressions contained by the node. @@ -217,6 +223,12 @@ type Replacer interface { Inserter } +// Updater allows rows to be updated. +type Updater interface { + // Update the given row. Provides both the old and new rows. + Update(ctx *Context, old Row, new Row) error +} + // Database represents the database. type Database interface { Nameable @@ -224,9 +236,14 @@ type Database interface { Tables() map[string]Table } -// Alterable should be implemented by databases that can handle DDL statements -type Alterable interface { - Create(name string, schema Schema) error +// TableCreator should be implemented by databases that can create new tables. +type TableCreator interface { + CreateTable(ctx *Context, name string, schema Schema) error +} + +// TableDropper should be implemented by databases that can drop tables. +type TableDropper interface { + DropTable(ctx *Context, name string) error } // Lockable should be implemented by tables that can be locked and unlocked. diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 8491f9cef..87601579b 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -466,7 +466,6 @@ func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - // TODO: support subqueries switch right := in.Right().(type) { case Tuple: for _, el := range right { @@ -496,6 +495,34 @@ func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } + return false, nil + case *Subquery: + if leftElems > 1 { + return nil, ErrInvalidOperandColumns.New(leftElems, 1) + } + + typ := right.Type() + values, err := right.EvalMultiple(ctx) + if err != nil { + return nil, err + } + + for _, val := range values { + val, err = typ.Convert(val) + if err != nil { + return nil, err + } + + cmp, err := typ.Compare(left, val) + if err != nil { + return nil, err + } + + if cmp == 0 { + return true, nil + } + } + return false, nil default: return nil, ErrUnsupportedInOperand.New(right) @@ -547,7 +574,6 @@ func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - // TODO: support subqueries switch right := in.Right().(type) { case Tuple: for _, el := range right { @@ -577,6 +603,34 @@ func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } + return true, nil + case *Subquery: + if leftElems > 1 { + return nil, ErrInvalidOperandColumns.New(leftElems, 1) + } + + typ := right.Type() + values, err := right.EvalMultiple(ctx) + if err != nil { + return nil, err + } + + for _, val := range values { + val, err = typ.Convert(val) + if err != nil { + return nil, err + } + + cmp, err := typ.Compare(left, val) + if err != nil { + return nil, err + } + + if cmp == 0 { + return false, nil + } + } + return true, nil default: return nil, ErrUnsupportedInOperand.New(right) diff --git a/sql/expression/comparison_test.go b/sql/expression/comparison_test.go index c69d1fdc5..802c3676a 100644 --- a/sql/expression/comparison_test.go +++ b/sql/expression/comparison_test.go @@ -1,10 +1,13 @@ -package expression +package expression_test import ( "testing" "github.com/src-d/go-mysql-server/internal/regex" + "github.com/src-d/go-mysql-server/memory" "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" errors "gopkg.in/src-d/go-errors.v1" "github.com/stretchr/testify/require" @@ -95,11 +98,11 @@ var likeComparisonCases = map[sql.Type]map[int][][]interface{}{ func TestEquals(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := NewGetField(0, resultType, "col1", true) + get0 := expression.NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := NewGetField(1, resultType, "col2", true) + get1 := expression.NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := NewEquals(get0, get1) + eq := expression.NewEquals(get0, get1) require.NotNil(eq) require.Equal(sql.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -122,11 +125,11 @@ func TestEquals(t *testing.T) { func TestLessThan(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := NewGetField(0, resultType, "col1", true) + get0 := expression.NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := NewGetField(1, resultType, "col2", true) + get1 := expression.NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := NewLessThan(get0, get1) + eq := expression.NewLessThan(get0, get1) require.NotNil(eq) require.Equal(sql.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -149,11 +152,11 @@ func TestLessThan(t *testing.T) { func TestGreaterThan(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := NewGetField(0, resultType, "col1", true) + get0 := expression.NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := NewGetField(1, resultType, "col2", true) + get1 := expression.NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := NewGreaterThan(get0, get1) + eq := expression.NewGreaterThan(get0, get1) require.NotNil(eq) require.Equal(sql.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -185,13 +188,13 @@ func testRegexpCases(t *testing.T) { require := require.New(t) for resultType, cmpCase := range likeComparisonCases { - get0 := NewGetField(0, resultType, "col1", true) + get0 := expression.NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := NewGetField(1, resultType, "col2", true) + get1 := expression.NewGetField(1, resultType, "col2", true) require.NotNil(get1) for cmpResult, cases := range cmpCase { for _, pair := range cases { - eq := NewRegexp(get0, get1) + eq := expression.NewRegexp(get0, get1) require.NotNil(eq) require.Equal(sql.Boolean, eq.Type()) @@ -214,9 +217,9 @@ func TestInvalidRegexp(t *testing.T) { t.Helper() require := require.New(t) - col1 := NewGetField(0, sql.Text, "col1", true) - invalid := NewLiteral("*col1", sql.Text) - r := NewRegexp(col1, invalid) + col1 := expression.NewGetField(0, sql.Text, "col1", true) + invalid := expression.NewLiteral("*col1", sql.Text) + r := expression.NewRegexp(col1, invalid) row := sql.NewRow("col1") _, err := r.Eval(sql.NewEmptyContext(), row) @@ -234,10 +237,10 @@ func TestIn(t *testing.T) { }{ { "left is nil", - NewLiteral(nil, sql.Null), - NewTuple( - NewLiteral(int64(1), sql.Int64), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(nil, sql.Null), + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), ), nil, nil, @@ -245,32 +248,32 @@ func TestIn(t *testing.T) { }, { "left and right don't have the same cols", - NewLiteral(1, sql.Int64), - NewTuple( - NewTuple( - NewLiteral(int64(1), sql.Int64), - NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(1, sql.Int64), + expression.NewTuple( + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), ), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), ), nil, nil, - ErrInvalidOperandColumns, + expression.ErrInvalidOperandColumns, }, { "right is an unsupported operand", - NewLiteral(1, sql.Int64), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(1, sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), nil, nil, - ErrUnsupportedInOperand, + expression.ErrUnsupportedInOperand, }, { "left is in right", - NewGetField(0, sql.Int64, "foo", false), - NewTuple( - NewGetField(0, sql.Int64, "foo", false), - NewLiteral(int64(2), sql.Int64), + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewTuple( + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewLiteral(int64(2), sql.Int64), ), sql.NewRow(int64(1)), true, @@ -278,10 +281,10 @@ func TestIn(t *testing.T) { }, { "left is not in right", - NewGetField(0, sql.Int64, "foo", false), - NewTuple( - NewGetField(1, sql.Int64, "bar", false), - NewLiteral(int64(2), sql.Int64), + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewTuple( + expression.NewGetField(1, sql.Int64, "bar", false), + expression.NewLiteral(int64(2), sql.Int64), ), sql.NewRow(int64(1), int64(3)), false, @@ -293,7 +296,96 @@ func TestIn(t *testing.T) { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - result, err := NewIn(tt.left, tt.right).Eval(sql.NewEmptyContext(), tt.row) + result, err := expression.NewIn(tt.left, tt.right). + Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.result, result) + } + }) + } +} + +func TestInSubquery(t *testing.T) { + ctx := sql.NewEmptyContext() + table := memory.NewTable("foo", sql.Schema{ + {Name: "t", Source: "foo", Type: sql.Text}, + }) + + require.NoError(t, table.Insert(ctx, sql.Row{"one"})) + require.NoError(t, table.Insert(ctx, sql.Row{"two"})) + require.NoError(t, table.Insert(ctx, sql.Row{"three"})) + + project := func(expr sql.Expression) sql.Node { + return plan.NewProject([]sql.Expression{ + expr, + }, plan.NewResolvedTable(table)) + } + + testCases := []struct { + name string + left sql.Expression + right sql.Node + row sql.Row + result interface{} + err *errors.Kind + }{ + { + "left is nil", + expression.NewLiteral(nil, sql.Null), + project( + expression.NewLiteral(int64(1), sql.Int64), + ), + nil, + nil, + nil, + }, + { + "left and right don't have the same cols", + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), + ), + project( + expression.NewLiteral(int64(2), sql.Int64), + ), + nil, + nil, + expression.ErrInvalidOperandColumns, + }, + { + "left is in right", + expression.NewGetField(0, sql.Text, "foo", false), + project( + expression.NewGetField(0, sql.Text, "foo", false), + ), + sql.NewRow("two"), + true, + nil, + }, + { + "left is not in right", + expression.NewGetField(0, sql.Text, "foo", false), + project( + expression.NewGetField(0, sql.Text, "foo", false), + ), + sql.NewRow("four"), + false, + nil, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + result, err := expression.NewIn( + tt.left, + expression.NewSubquery(tt.right), + ).Eval(sql.NewEmptyContext(), tt.row) if tt.err != nil { require.Error(err) require.True(tt.err.Is(err)) @@ -316,10 +408,10 @@ func TestNotIn(t *testing.T) { }{ { "left is nil", - NewLiteral(nil, sql.Null), - NewTuple( - NewLiteral(int64(1), sql.Int64), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(nil, sql.Null), + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), ), nil, nil, @@ -327,32 +419,32 @@ func TestNotIn(t *testing.T) { }, { "left and right don't have the same cols", - NewLiteral(1, sql.Int64), - NewTuple( - NewTuple( - NewLiteral(int64(1), sql.Int64), - NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(1, sql.Int64), + expression.NewTuple( + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), ), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), ), nil, nil, - ErrInvalidOperandColumns, + expression.ErrInvalidOperandColumns, }, { "right is an unsupported operand", - NewLiteral(1, sql.Int64), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(1, sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), nil, nil, - ErrUnsupportedInOperand, + expression.ErrUnsupportedInOperand, }, { "left is in right", - NewGetField(0, sql.Int64, "foo", false), - NewTuple( - NewGetField(0, sql.Int64, "foo", false), - NewLiteral(int64(2), sql.Int64), + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewTuple( + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewLiteral(int64(2), sql.Int64), ), sql.NewRow(int64(1)), false, @@ -360,10 +452,10 @@ func TestNotIn(t *testing.T) { }, { "left is not in right", - NewGetField(0, sql.Int64, "foo", false), - NewTuple( - NewGetField(1, sql.Int64, "bar", false), - NewLiteral(int64(2), sql.Int64), + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewTuple( + expression.NewGetField(1, sql.Int64, "bar", false), + expression.NewLiteral(int64(2), sql.Int64), ), sql.NewRow(int64(1), int64(3)), true, @@ -375,7 +467,8 @@ func TestNotIn(t *testing.T) { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - result, err := NewNotIn(tt.left, tt.right).Eval(sql.NewEmptyContext(), tt.row) + result, err := expression.NewNotIn(tt.left, tt.right). + Eval(sql.NewEmptyContext(), tt.row) if tt.err != nil { require.Error(err) require.True(tt.err.Is(err)) @@ -386,3 +479,98 @@ func TestNotIn(t *testing.T) { }) } } + +func TestNotInSubquery(t *testing.T) { + ctx := sql.NewEmptyContext() + table := memory.NewTable("foo", sql.Schema{ + {Name: "t", Source: "foo", Type: sql.Text}, + }) + + require.NoError(t, table.Insert(ctx, sql.Row{"one"})) + require.NoError(t, table.Insert(ctx, sql.Row{"two"})) + require.NoError(t, table.Insert(ctx, sql.Row{"three"})) + + project := func(expr sql.Expression) sql.Node { + return plan.NewProject([]sql.Expression{ + expr, + }, plan.NewResolvedTable(table)) + } + + testCases := []struct { + name string + left sql.Expression + right sql.Node + row sql.Row + result interface{} + err *errors.Kind + }{ + { + "left is nil", + expression.NewLiteral(nil, sql.Null), + project( + expression.NewLiteral(int64(1), sql.Int64), + ), + nil, + nil, + nil, + }, + { + "left and right don't have the same cols", + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), + ), + project( + expression.NewLiteral(int64(2), sql.Int64), + ), + nil, + nil, + expression.ErrInvalidOperandColumns, + }, + { + "left is in right", + expression.NewGetField(0, sql.Text, "foo", false), + project( + expression.NewGetField(0, sql.Text, "foo", false), + ), + sql.NewRow("two"), + false, + nil, + }, + { + "left is not in right", + expression.NewGetField(0, sql.Text, "foo", false), + project( + expression.NewGetField(0, sql.Text, "foo", false), + ), + sql.NewRow("four"), + true, + nil, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + result, err := expression.NewNotIn( + tt.left, + expression.NewSubquery(tt.right), + ).Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.result, result) + } + }) + } +} + +func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { + t.Helper() + v, err := e.Eval(sql.NewEmptyContext(), row) + require.NoError(t, err) + return v +} diff --git a/sql/expression/doc.go b/sql/expression/doc.go deleted file mode 100644 index 820316fe6..000000000 --- a/sql/expression/doc.go +++ /dev/null @@ -1 +0,0 @@ -package expression // import "github.com/src-d/go-mysql-server/sql/expression" diff --git a/sql/expression/function/aggregation/avg.go b/sql/expression/function/aggregation/avg.go index 5e6b2ff49..eae1c8c0e 100644 --- a/sql/expression/function/aggregation/avg.go +++ b/sql/expression/function/aggregation/avg.go @@ -1,4 +1,4 @@ -package aggregation // import "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" +package aggregation import ( "fmt" diff --git a/sql/expression/function/arraylength.go b/sql/expression/function/arraylength.go index 00d10cfd2..61a902cd9 100644 --- a/sql/expression/function/arraylength.go +++ b/sql/expression/function/arraylength.go @@ -1,4 +1,4 @@ -package function // import "github.com/src-d/go-mysql-server/sql/expression/function" +package function import ( "fmt" diff --git a/sql/expression/function/ceil_round_floor.go b/sql/expression/function/ceil_round_floor.go index 4c8c1d757..c056f82da 100644 --- a/sql/expression/function/ceil_round_floor.go +++ b/sql/expression/function/ceil_round_floor.go @@ -203,6 +203,18 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { dVal = float64(dNum) case int32: dVal = float64(dNum) + case int16: + dVal = float64(dNum) + case int8: + dVal = float64(dNum) + case uint64: + dVal = float64(dNum) + case uint32: + dVal = float64(dNum) + case uint16: + dVal = float64(dNum) + case uint8: + dVal = float64(dNum) case int: dVal = float64(dNum) default: @@ -233,6 +245,18 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return int64(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil case int32: return int32(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case int16: + return int16(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case int8: + return int8(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint64: + return uint64(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint32: + return uint32(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint16: + return uint16(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint8: + return uint8(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil case int: return int(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil default: diff --git a/sql/expression/function/ceil_round_floor_test.go b/sql/expression/function/ceil_round_floor_test.go index 2ad014c42..4af2456ef 100644 --- a/sql/expression/function/ceil_round_floor_test.go +++ b/sql/expression/function/ceil_round_floor_test.go @@ -156,6 +156,50 @@ func TestRound(t *testing.T) { {"int32 with float d", sql.Int32, sql.Float64, sql.NewRow(int32(5), float32(2.123)), int32(5), nil}, {"int32 with float negative d", sql.Int32, sql.Float64, sql.NewRow(int32(52), float32(-1)), int32(50), nil}, {"int32 with blob d", sql.Int32, sql.Blob, sql.NewRow(int32(5), []byte{1, 2, 3}), int32(5), nil}, + {"int16 is nil", sql.Int16, sql.Int16, sql.NewRow(nil, nil), nil, nil}, + {"int16 without d", sql.Int16, sql.Int16, sql.NewRow(int16(5), nil), int16(5), nil}, + {"int16 with d", sql.Int16, sql.Int16, sql.NewRow(int16(5), 2), int16(5), nil}, + {"int16 with negative d", sql.Int16, sql.Int16, sql.NewRow(int16(52), -1), int16(50), nil}, + {"int16 with float d", sql.Int16, sql.Float64, sql.NewRow(int16(5), float32(2.123)), int16(5), nil}, + {"int16 with float negative d", sql.Int16, sql.Float64, sql.NewRow(int16(52), float32(-1)), int16(50), nil}, + {"int16 with blob d", sql.Int16, sql.Blob, sql.NewRow(int16(5), []byte{1, 2, 3}), int16(5), nil}, + {"int8 is nil", sql.Int8, sql.Int8, sql.NewRow(nil, nil), nil, nil}, + {"int8 without d", sql.Int8, sql.Int8, sql.NewRow(int8(5), nil), int8(5), nil}, + {"int8 with d", sql.Int8, sql.Int8, sql.NewRow(int8(5), 2), int8(5), nil}, + {"int8 with negative d", sql.Int8, sql.Int8, sql.NewRow(int8(52), -1), int8(50), nil}, + {"int8 with float d", sql.Int8, sql.Float64, sql.NewRow(int8(5), float32(2.123)), int8(5), nil}, + {"int8 with float negative d", sql.Int8, sql.Float64, sql.NewRow(int8(52), float32(-1)), int8(50), nil}, + {"int8 with blob d", sql.Int8, sql.Blob, sql.NewRow(int8(5), []byte{1, 2, 3}), int8(5), nil}, + {"uint64 is nil", sql.Uint64, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"uint64 without d", sql.Uint64, sql.Int32, sql.NewRow(uint64(5), nil), uint64(5), nil}, + {"uint64 with d", sql.Uint64, sql.Int32, sql.NewRow(uint64(5), 2), uint64(5), nil}, + {"uint64 with negative d", sql.Uint64, sql.Int32, sql.NewRow(uint64(52), -1), uint64(50), nil}, + {"uint64 with float d", sql.Uint64, sql.Float64, sql.NewRow(uint64(5), float32(2.123)), uint64(5), nil}, + {"uint64 with float negative d", sql.Uint64, sql.Float64, sql.NewRow(uint64(52), float32(-1)), uint64(50), nil}, + {"uint32 with blob d", sql.Uint32, sql.Blob, sql.NewRow(uint32(5), []byte{1, 2, 3}), uint32(5), nil}, + {"uint32 is nil", sql.Uint32, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"uint32 without d", sql.Uint32, sql.Int32, sql.NewRow(uint32(5), nil), uint32(5), nil}, + {"uint32 with d", sql.Uint32, sql.Int32, sql.NewRow(uint32(5), 2), uint32(5), nil}, + {"uint32 with negative d", sql.Uint32, sql.Int32, sql.NewRow(uint32(52), -1), uint32(50), nil}, + {"uint32 with float d", sql.Uint32, sql.Float64, sql.NewRow(uint32(5), float32(2.123)), uint32(5), nil}, + {"uint32 with float negative d", sql.Uint32, sql.Float64, sql.NewRow(uint32(52), float32(-1)), uint32(50), nil}, + {"uint32 with blob d", sql.Uint32, sql.Blob, sql.NewRow(uint32(5), []byte{1, 2, 3}), uint32(5), nil}, + {"uint16 with blob d", sql.Uint16, sql.Blob, sql.NewRow(uint16(5), []byte{1, 2, 3}), uint16(5), nil}, + {"uint16 is nil", sql.Uint16, sql.Int16, sql.NewRow(nil, nil), nil, nil}, + {"uint16 without d", sql.Uint16, sql.Int16, sql.NewRow(uint16(5), nil), uint16(5), nil}, + {"uint16 with d", sql.Uint16, sql.Int16, sql.NewRow(uint16(5), 2), uint16(5), nil}, + {"uint16 with negative d", sql.Uint16, sql.Int16, sql.NewRow(uint16(52), -1), uint16(50), nil}, + {"uint16 with float d", sql.Uint16, sql.Float64, sql.NewRow(uint16(5), float32(2.123)), uint16(5), nil}, + {"uint16 with float negative d", sql.Uint16, sql.Float64, sql.NewRow(uint16(52), float32(-1)), uint16(50), nil}, + {"uint16 with blob d", sql.Uint16, sql.Blob, sql.NewRow(uint16(5), []byte{1, 2, 3}), uint16(5), nil}, + {"uint8 with blob d", sql.Uint8, sql.Blob, sql.NewRow(uint8(5), []byte{1, 2, 3}), uint8(5), nil}, + {"uint8 is nil", sql.Uint8, sql.Int8, sql.NewRow(nil, nil), nil, nil}, + {"uint8 without d", sql.Uint8, sql.Int8, sql.NewRow(uint8(5), nil), uint8(5), nil}, + {"uint8 with d", sql.Uint8, sql.Int8, sql.NewRow(uint8(5), 2), uint8(5), nil}, + {"uint8 with negative d", sql.Uint8, sql.Int8, sql.NewRow(uint8(52), -1), uint8(50), nil}, + {"uint8 with float d", sql.Uint8, sql.Float64, sql.NewRow(uint8(5), float32(2.123)), uint8(5), nil}, + {"uint8 with float negative d", sql.Uint8, sql.Float64, sql.NewRow(uint8(52), float32(-1)), uint8(50), nil}, + {"uint8 with blob d", sql.Uint8, sql.Blob, sql.NewRow(uint8(5), []byte{1, 2, 3}), uint8(5), nil}, {"blob is nil", sql.Blob, sql.Int32, sql.NewRow(nil, nil), nil, nil}, {"blob is ok", sql.Blob, sql.Int32, sql.NewRow([]byte{1, 2, 3}, nil), int32(0), nil}, {"text int without d", sql.Text, sql.Int32, sql.NewRow("5", nil), int32(5), nil}, diff --git a/sql/expression/function/time.go b/sql/expression/function/time.go index 2385aec0d..c0dcaf3d4 100644 --- a/sql/expression/function/time.go +++ b/sql/expression/function/time.go @@ -363,8 +363,10 @@ func (d *YearWeek) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } if val != nil { - if mode, ok = val.(int64); ok { - mode %= 8 // mode in [0, 7] + if i64, err := sql.Int64.Convert(val); err == nil { + if mode, ok = i64.(int64); ok { + mode %= 8 // mode in [0, 7] + } } } yyyy, week := calcWeek(yyyy, mm, dd, weekMode(mode)|weekBehaviourYear) diff --git a/sql/expression/set.go b/sql/expression/set.go new file mode 100644 index 000000000..d18bde374 --- /dev/null +++ b/sql/expression/set.go @@ -0,0 +1,61 @@ +package expression + +import ( + "fmt" + "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" +) + +var errCannotSetField = errors.NewKind("Expected GetField expression on left but got %T") + +// SetField updates the value of a field from a row. +type SetField struct { + BinaryExpression +} + +// NewSetField creates a new SetField expression. +func NewSetField(colName, expr sql.Expression) sql.Expression { + return &SetField{BinaryExpression{Left: colName, Right: expr}} +} + +func (s *SetField) String() string { + return fmt.Sprintf("SETFIELD %s = %s", s.Left, s.Right) +} + +// Type implements the Expression interface. +func (s *SetField) Type() sql.Type { + return s.Left.Type() +} + +// Eval implements the Expression interface. +// Returns a copy of the given row with an updated value. +func (s *SetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + getField, ok := s.Left.(*GetField) + if !ok { + return nil, errCannotSetField.New(s.Left) + } + if getField.fieldIndex < 0 || getField.fieldIndex >= len(row) { + return nil, ErrIndexOutOfBounds.New(getField.fieldIndex, len(row)) + } + val, err := s.Right.Eval(ctx, row) + if err != nil { + return nil, err + } + if val != nil { + val, err = getField.fieldType.Convert(val) + if err != nil { + return nil, err + } + } + updatedRow := row.Copy() + updatedRow[getField.fieldIndex] = val + return updatedRow, nil +} + +// WithChildren implements the Expression interface. +func (s *SetField) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 2) + } + return NewSetField(children[0], children[1]), nil +} \ No newline at end of file diff --git a/sql/expression/subquery.go b/sql/expression/subquery.go new file mode 100644 index 000000000..faae15aa1 --- /dev/null +++ b/sql/expression/subquery.go @@ -0,0 +1,125 @@ +package expression + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" +) + +var errExpectedSingleRow = errors.NewKind("the subquery returned more than 1 row") + +// Subquery that is executed as an expression. +type Subquery struct { + Query sql.Node + value interface{} +} + +// NewSubquery returns a new subquery node. +func NewSubquery(node sql.Node) *Subquery { + return &Subquery{node, nil} +} + +// Eval implements the Expression interface. +func (s *Subquery) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) { + if s.value != nil { + if elems, ok := s.value.([]interface{}); ok { + if len(elems) > 1 { + return nil, errExpectedSingleRow.New() + } + return elems[0], nil + } + return s.value, nil + } + + iter, err := s.Query.RowIter(ctx) + if err != nil { + return nil, err + } + + rows, err := sql.RowIterToRows(iter) + if err != nil { + return nil, err + } + + if len(rows) == 0 { + s.value = nil + return nil, nil + } + + if len(rows) > 1 { + return nil, errExpectedSingleRow.New() + } + + s.value = rows[0][0] + return s.value, nil +} + +// EvalMultiple returns all rows returned by a subquery. +func (s *Subquery) EvalMultiple(ctx *sql.Context) ([]interface{}, error) { + if s.value != nil { + return s.value.([]interface{}), nil + } + + iter, err := s.Query.RowIter(ctx) + if err != nil { + return nil, err + } + + rows, err := sql.RowIterToRows(iter) + if err != nil { + return nil, err + } + + if len(rows) == 0 { + s.value = []interface{}{} + return nil, nil + } + + var result = make([]interface{}, len(rows)) + for i, row := range rows { + result[i] = row[0] + } + s.value = result + + return result, nil +} + +// IsNullable implements the Expression interface. +func (s *Subquery) IsNullable() bool { + return s.Query.Schema()[0].Nullable +} + +func (s *Subquery) String() string { + return fmt.Sprintf("(%s)", s.Query) +} + +// Resolved implements the Expression interface. +func (s *Subquery) Resolved() bool { + return s.Query.Resolved() +} + +// Type implements the Expression interface. +func (s *Subquery) Type() sql.Type { + return s.Query.Schema()[0].Type +} + +// WithChildren implements the Expression interface. +func (s *Subquery) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } + return s, nil +} + +// Children implements the Expression interface. +func (s *Subquery) Children() []sql.Expression { + return nil +} + +// WithQuery returns the subquery with the query node changed. +func (s *Subquery) WithQuery(node sql.Node) *Subquery { + ns := *s + ns.Query = node + return &ns +} diff --git a/sql/expression/subquery_test.go b/sql/expression/subquery_test.go new file mode 100644 index 000000000..0d3f353be --- /dev/null +++ b/sql/expression/subquery_test.go @@ -0,0 +1,69 @@ +package expression_test + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestSubquery(t *testing.T) { + require := require.New(t) + table := memory.NewTable("", nil) + require.NoError(table.Insert(sql.NewEmptyContext(), nil)) + + subquery := expression.NewSubquery(plan.NewProject( + []sql.Expression{ + expression.NewLiteral("one", sql.Text), + }, + plan.NewResolvedTable(table), + )) + + value, err := subquery.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal(value, "one") +} + +func TestSubqueryTooManyRows(t *testing.T) { + require := require.New(t) + table := memory.NewTable("", nil) + require.NoError(table.Insert(sql.NewEmptyContext(), nil)) + require.NoError(table.Insert(sql.NewEmptyContext(), nil)) + + subquery := expression.NewSubquery(plan.NewProject( + []sql.Expression{ + expression.NewLiteral("one", sql.Text), + }, + plan.NewResolvedTable(table), + )) + + _, err := subquery.Eval(sql.NewEmptyContext(), nil) + require.Error(err) +} + +func TestSubqueryMultipleRows(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + table := memory.NewTable("foo", sql.Schema{ + {Name: "t", Source: "foo", Type: sql.Text}, + }) + + require.NoError(table.Insert(ctx, sql.Row{"one"})) + require.NoError(table.Insert(ctx, sql.Row{"two"})) + require.NoError(table.Insert(ctx, sql.Row{"three"})) + + subquery := expression.NewSubquery(plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Text, "t", false), + }, + plan.NewResolvedTable(table), + )) + + values, err := subquery.EvalMultiple(ctx) + require.NoError(err) + require.Equal(values, []interface{}{"one", "two", "three"}) +} diff --git a/sql/expression/transform.go b/sql/expression/transform.go index ccf3d4276..05195c7fd 100644 --- a/sql/expression/transform.go +++ b/sql/expression/transform.go @@ -1,6 +1,8 @@ package expression -import "github.com/src-d/go-mysql-server/sql" +import ( + "github.com/src-d/go-mysql-server/sql" +) // TransformUp applies a transformation function to the given expression from the // bottom up. diff --git a/sql/index/config.go b/sql/index/config.go index dcf846216..7d5b7c9fd 100644 --- a/sql/index/config.go +++ b/sql/index/config.go @@ -24,7 +24,6 @@ func NewConfig( driverID string, driverConfig map[string]string, ) *Config { - cfg := &Config{ DB: db, Table: table, diff --git a/sql/index/pilosa/driver.go b/sql/index/pilosa/driver.go index a9b30cf04..4f9146b21 100644 --- a/sql/index/pilosa/driver.go +++ b/sql/index/pilosa/driver.go @@ -110,6 +110,8 @@ func (*Driver) ID() string { return DriverID } +var errWriteConfigFile = errors.NewKind("unable to write indexes configuration file") + // Create a new index. func (d *Driver) Create( db, table, id string, @@ -133,7 +135,7 @@ func (d *Driver) Create( cfg := index.NewConfig(db, table, id, exprs, d.ID(), config) err = index.WriteConfigFile(d.configFilePath(db, table, id), cfg) if err != nil { - return nil, err + return nil, errWriteConfigFile.Wrap(err) } idx, err := d.newPilosaIndex(db, table) @@ -146,12 +148,14 @@ func (d *Driver) Create( processingFile, []byte{processingFileOnCreate}, ); err != nil { - return nil, err + return nil, errWriteConfigFile.Wrap(err) } return newPilosaIndex(idx, cfg), nil } +var errReadIndexes = errors.NewKind("error loading all indexes for table %s of database %s: %s") + // LoadAll loads all indexes for given db and table func (d *Driver) LoadAll(db, table string) ([]sql.Index, error) { var ( @@ -165,7 +169,7 @@ func (d *Driver) LoadAll(db, table string) ([]sql.Index, error) { if os.IsNotExist(err) { return indexes, nil } - return nil, err + return nil, errReadIndexes.New(table, db, err) } for _, info := range dirs { if info.IsDir() && !strings.HasPrefix(info.Name(), ".") { @@ -187,6 +191,11 @@ func (d *Driver) LoadAll(db, table string) ([]sql.Index, error) { return indexes, nil } +var ( + errLoadingIndexConfig = errors.NewKind("unable to load index configuration") + errReadIndexConfig = errors.NewKind("unable to read index configuration") +) + func (d *Driver) loadIndex(db, table, id string) (*pilosaIndex, error) { idx, err := d.newPilosaIndex(db, table) if err != nil { @@ -206,7 +215,7 @@ func (d *Driver) loadIndex(db, table, id string) (*pilosaIndex, error) { processing := d.processingFilePath(db, table, id) ok, err := index.ExistsProcessingFile(processing) if err != nil { - return nil, err + return nil, errLoadingIndexConfig.Wrap(err) } if ok { log := logrus.WithFields(logrus.Fields{ @@ -226,7 +235,7 @@ func (d *Driver) loadIndex(db, table, id string) (*pilosaIndex, error) { cfg, err := index.ReadConfigFile(config) if err != nil { - return nil, err + return nil, errReadIndexConfig.Wrap(err) } cfgDriver := cfg.Driver(DriverID) if cfgDriver == nil { @@ -382,13 +391,13 @@ func (d *Driver) Save( []byte{processingFileOnSave}, ) if err != nil { - return err + return errWriteConfigFile.Wrap(err) } cfgPath := d.configFilePath(i.Database(), i.Table(), i.ID()) cfg, err := index.ReadConfigFile(cfgPath) if err != nil { - return err + return errReadIndexConfig.Wrap(err) } driverCfg := cfg.Driver(DriverID) @@ -462,7 +471,7 @@ func (d *Driver) Save( return errors[0] } if err = index.WriteConfigFile(cfgPath, cfg); err != nil { - return err + return errWriteConfigFile.Wrap(err) } observeIndex(time.Since(start), timePilosa, timeMapping, rows) diff --git a/sql/index/pilosa/lookup.go b/sql/index/pilosa/lookup.go index e50aeeabd..29173921b 100644 --- a/sql/index/pilosa/lookup.go +++ b/sql/index/pilosa/lookup.go @@ -621,6 +621,14 @@ func decodeGob(k []byte, value interface{}) (interface{}, error) { var v string err := decoder.Decode(&v) return v, err + case int8: + var v int8 + err := decoder.Decode(&v) + return v, err + case int16: + var v int16 + err := decoder.Decode(&v) + return v, err case int32: var v int32 err := decoder.Decode(&v) @@ -629,6 +637,14 @@ func decodeGob(k []byte, value interface{}) (interface{}, error) { var v int64 err := decoder.Decode(&v) return v, err + case uint8: + var v uint8 + err := decoder.Decode(&v) + return v, err + case uint16: + var v uint16 + err := decoder.Decode(&v) + return v, err case uint32: var v uint32 err := decoder.Decode(&v) @@ -688,6 +704,36 @@ func compare(a, b interface{}) (int, error) { } return strings.Compare(a, v), nil + case int8: + v, ok := b.(int8) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case int16: + v, ok := b.(int16) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil case int32: v, ok := b.(int32) if !ok { @@ -717,6 +763,36 @@ func compare(a, b interface{}) (int, error) { return -1, nil } + return 1, nil + case uint8: + v, ok := b.(uint8) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case uint16: + v, ok := b.(uint16) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + return 1, nil case uint32: v, ok := b.(uint32) diff --git a/sql/index/pilosa/lookup_test.go b/sql/index/pilosa/lookup_test.go index e93a3e587..b93da3ca2 100644 --- a/sql/index/pilosa/lookup_test.go +++ b/sql/index/pilosa/lookup_test.go @@ -98,8 +98,12 @@ func TestCompare(t *testing.T) { func TestDecodeGob(t *testing.T) { testCases := []interface{}{ "foo", + int8(1), + int16(1), int32(1), int64(1), + uint8(1), + uint16(1), uint32(1), uint64(1), float64(1), diff --git a/sql/information_schema.go b/sql/information_schema.go index 0a304786a..48147c7d5 100644 --- a/sql/information_schema.go +++ b/sql/information_schema.go @@ -1,9 +1,10 @@ -package sql // import "github.com/src-d/go-mysql-server/sql" +package sql import ( "bytes" "fmt" "io" + "strings" ) const ( @@ -214,27 +215,27 @@ func columnsRowIter(cat *Catalog) RowIter { collName = "utf8_bin" } rows = append(rows, Row{ - "def", // table_catalog - db.Name(), // table_schema - t.Name(), // table_name - c.Name, // column_name - uint64(i), // ordinal_position - c.Default, // column_default - nullable, // is_nullable - MySQLTypeName(c.Type), // data_type - nil, // character_maximum_length - nil, // character_octet_length - nil, // numeric_precision - nil, // numeric_scale - nil, // datetime_precision - charName, // character_set_name - collName, // collation_name - MySQLTypeName(c.Type), // column_type - "", // column_key - "", // extra - "select", // privileges - "", // column_comment - "", // generation_expression + "def", // table_catalog + db.Name(), // table_schema + t.Name(), // table_name + c.Name, // column_name + uint64(i), // ordinal_position + c.Default, // column_default + nullable, // is_nullable + strings.ToLower(MySQLTypeName(c.Type)), // data_type + nil, // character_maximum_length + nil, // character_octet_length + nil, // numeric_precision + nil, // numeric_scale + nil, // datetime_precision + charName, // character_set_name + collName, // collation_name + strings.ToLower(MySQLTypeName(c.Type)), // column_type + "", // column_key + "", // extra + "select", // privileges + "", // column_comment + "", // generation_expression }) } } diff --git a/sql/parse/indexes.go b/sql/parse/indexes.go index 4a0ac4e0e..a435f7888 100644 --- a/sql/parse/indexes.go +++ b/sql/parse/indexes.go @@ -39,7 +39,7 @@ func parseShowIndex(s string) (sql.Node, error) { ), nil } -func parseCreateIndex(s string) (sql.Node, error) { +func parseCreateIndex(ctx *sql.Context, s string) (sql.Node, error) { r := bufio.NewReader(strings.NewReader(s)) var name, table, driver string @@ -78,7 +78,7 @@ func parseCreateIndex(s string) (sql.Node, error) { var indexExprs = make([]sql.Expression, len(exprs)) for i, e := range exprs { var err error - indexExprs[i], err = parseExpr(e) + indexExprs[i], err = parseExpr(ctx, e) if err != nil { return nil, err } diff --git a/sql/parse/indexes_test.go b/sql/parse/indexes_test.go index 4072b7a5d..a9d7df5a4 100644 --- a/sql/parse/indexes_test.go +++ b/sql/parse/indexes_test.go @@ -159,7 +159,7 @@ func TestParseCreateIndex(t *testing.T) { t.Run(tt.query, func(t *testing.T) { require := require.New(t) - result, err := parseCreateIndex(strings.ToLower(tt.query)) + result, err := parseCreateIndex(sql.NewEmptyContext(), strings.ToLower(tt.query)) if tt.err != nil { require.Error(err) require.True(tt.err.Is(err)) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 72a125ecd..f4dfe6f0a 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -1,4 +1,4 @@ -package parse // import "github.com/src-d/go-mysql-server/sql/parse" +package parse import ( "bufio" @@ -26,9 +26,6 @@ var ( // ErrUnsupportedFeature is thrown when a feature is not already supported ErrUnsupportedFeature = errors.NewKind("unsupported feature: %s") - // ErrUnsupportedSubqueryExpression is thrown because subqueries are not supported, yet. - ErrUnsupportedSubqueryExpression = errors.NewKind("unsupported subquery expression") - // ErrInvalidSQLValType is returned when a SQLVal type is not valid. ErrInvalidSQLValType = errors.NewKind("invalid SQLVal of type: %d") @@ -50,6 +47,17 @@ var ( unlockTablesRegex = regexp.MustCompile(`^unlock\s+tables$`) lockTablesRegex = regexp.MustCompile(`^lock\s+tables\s`) setRegex = regexp.MustCompile(`^set\s+`) + createViewRegex = regexp.MustCompile(`^create\s+view\s+`) +) + +// These constants aren't exported from vitess for some reason. This could be removed if we changed this. +const ( + colKeyNone sqlparser.ColumnKeyOption = iota + colKeyPrimary + colKeySpatialKey + colKeyUnique + colKeyUniqueKey + colKey ) // Parse parses the given SQL sentence and returns the corresponding node. @@ -73,7 +81,7 @@ func Parse(ctx *sql.Context, query string) (sql.Node, error) { case describeTablesRegex.MatchString(lowerQuery): return parseDescribeTables(lowerQuery) case createIndexRegex.MatchString(lowerQuery): - return parseCreateIndex(s) + return parseCreateIndex(ctx, s) case dropIndexRegex.MatchString(lowerQuery): return parseDropIndex(s) case showIndexRegex.MatchString(lowerQuery): @@ -85,7 +93,7 @@ func Parse(ctx *sql.Context, query string) (sql.Node, error) { case showWarningsRegex.MatchString(lowerQuery): return parseShowWarnings(ctx, s) case showCollationRegex.MatchString(lowerQuery): - return parseShowCollation(s) + return parseShowCollation(ctx, s) case describeRegex.MatchString(lowerQuery): return parseDescribeQuery(ctx, s) case fullProcessListRegex.MatchString(lowerQuery): @@ -96,6 +104,9 @@ func Parse(ctx *sql.Context, query string) (sql.Node, error) { return parseLockTables(ctx, s) case setRegex.MatchString(lowerQuery): s = fixSetQuery(s) + case createViewRegex.MatchString(lowerQuery): + // CREATE VIEW parses as a CREATE DDL statement with an empty table spec + return nil, ErrUnsupportedFeature.New("CREATE VIEW") } stmt, err := sqlparser.Parse(s) @@ -135,13 +146,24 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node default: return nil, ErrUnsupportedSyntax.New(n) case *sqlparser.Show: - return convertShow(n, query) + // When a query is empty it means it comes from a subquery, as we don't + // have the query itself in a subquery. Hence, a SHOW could not be + // parsed. + if query == "" { + return nil, ErrUnsupportedFeature.New("SHOW in subquery") + } + return convertShow(ctx, n, query) case *sqlparser.Select: return convertSelect(ctx, n) case *sqlparser.Insert: return convertInsert(ctx, n) case *sqlparser.DDL: - return convertDDL(n) + // unlike other statements, DDL statements have loose parsing by default + ddl, err := sqlparser.ParseStrictDDL(query) + if err != nil { + return nil, err + } + return convertDDL(ddl.(*sqlparser.DDL)) case *sqlparser.Set: return convertSet(ctx, n) case *sqlparser.Use: @@ -150,6 +172,8 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node return plan.NewRollback(), nil case *sqlparser.Delete: return convertDelete(ctx, n) + case *sqlparser.Update: + return convertUpdate(ctx, n) } } @@ -165,7 +189,7 @@ func convertSet(ctx *sql.Context, n *sqlparser.Set) (sql.Node, error) { var variables = make([]plan.SetVariable, len(n.Exprs)) for i, e := range n.Exprs { - expr, err := exprToExpression(e.Expr) + expr, err := exprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -215,7 +239,7 @@ func convertSet(ctx *sql.Context, n *sqlparser.Set) (sql.Node, error) { return plan.NewSet(variables...), nil } -func convertShow(s *sqlparser.Show, query string) (sql.Node, error) { +func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, error) { switch s.Type { case sqlparser.KeywordString(sqlparser.TABLES): var dbName string @@ -228,7 +252,7 @@ func convertShow(s *sqlparser.Show, query string) (sql.Node, error) { if s.ShowTablesOpt.Filter != nil { if s.ShowTablesOpt.Filter.Filter != nil { var err error - filter, err = exprToExpression(s.ShowTablesOpt.Filter.Filter) + filter, err = exprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) if err != nil { return nil, err } @@ -270,7 +294,7 @@ func convertShow(s *sqlparser.Show, query string) (sql.Node, error) { } if s.ShowTablesOpt.Filter.Filter != nil { - filter, err := exprToExpression(s.ShowTablesOpt.Filter.Filter) + filter, err := exprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) if err != nil { return nil, err } @@ -281,7 +305,7 @@ func convertShow(s *sqlparser.Show, query string) (sql.Node, error) { return node, nil case sqlparser.KeywordString(sqlparser.TABLE): - return parseShowTableStatus(query) + return parseShowTableStatus(ctx, query) default: unsupportedShow := fmt.Sprintf("SHOW %s", s.Type) return nil, ErrUnsupportedFeature.New(unsupportedShow) @@ -295,19 +319,19 @@ func convertSelect(ctx *sql.Context, s *sqlparser.Select) (sql.Node, error) { } if s.Where != nil { - node, err = whereToFilter(s.Where, node) + node, err = whereToFilter(ctx, s.Where, node) if err != nil { return nil, err } } - node, err = selectToProjectOrGroupBy(s.SelectExprs, s.GroupBy, node) + node, err = selectToProjectOrGroupBy(ctx, s.SelectExprs, s.GroupBy, node) if err != nil { return nil, err } if s.Having != nil { - node, err = havingToHaving(s.Having, node) + node, err = havingToHaving(ctx, s.Having, node) if err != nil { return nil, err } @@ -318,7 +342,7 @@ func convertSelect(ctx *sql.Context, s *sqlparser.Select) (sql.Node, error) { } if len(s.OrderBy) != 0 { - node, err = orderByToSort(s.OrderBy, node) + node, err = orderByToSort(ctx, s.OrderBy, node) if err != nil { return nil, err } @@ -349,13 +373,23 @@ func convertDDL(c *sqlparser.DDL) (sql.Node, error) { switch c.Action { case sqlparser.CreateStr: return convertCreateTable(c) + case sqlparser.DropStr: + return convertDropTable(c) default: return nil, ErrUnsupportedSyntax.New(c) } } +func convertDropTable(c *sqlparser.DDL) (sql.Node, error) { + tableNames := make([]string, len(c.FromTables)) + for i, t := range c.FromTables { + tableNames[i] = t.Name.String() + } + return plan.NewDropTable(sql.UnresolvedDatabase(""), c.IfExists, tableNames...), nil +} + func convertCreateTable(c *sqlparser.DDL) (sql.Node, error) { - schema, err := columnDefinitionToSchema(c.TableSpec.Columns) + schema, err := tableSpecToSchema(c.TableSpec) if err != nil { return nil, err } @@ -395,14 +429,14 @@ func convertDelete(ctx *sql.Context, d *sqlparser.Delete) (sql.Node, error) { } if d.Where != nil { - node, err = whereToFilter(d.Where, node) + node, err = whereToFilter(ctx, d.Where, node) if err != nil { return nil, err } } if len(d.OrderBy) != 0 { - node, err = orderByToSort(d.OrderBy, node) + node, err = orderByToSort(ctx, d.OrderBy, node) if err != nil { return nil, err } @@ -426,27 +460,100 @@ func convertDelete(ctx *sql.Context, d *sqlparser.Delete) (sql.Node, error) { return plan.NewDeleteFrom(node), nil } -func columnDefinitionToSchema(colDef []*sqlparser.ColumnDefinition) (sql.Schema, error) { +func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) { + node, err := tableExprsToTable(ctx, d.TableExprs) + if err != nil { + return nil, err + } + + updateExprs, err := updateExprsToExpressions(ctx, d.Exprs) + if err != nil { + return nil, err + } + + if d.Where != nil { + node, err = whereToFilter(ctx, d.Where, node) + if err != nil { + return nil, err + } + } + + if len(d.OrderBy) != 0 { + node, err = orderByToSort(ctx, d.OrderBy, node) + if err != nil { + return nil, err + } + } + + // Limit must wrap offset, and not vice-versa, so that skipped rows don't count toward the returned row count. + if d.Limit != nil && d.Limit.Offset != nil { + node, err = offsetToOffset(ctx, d.Limit.Offset, node) + if err != nil { + return nil, err + } + + } + + if d.Limit != nil { + node, err = limitToLimit(ctx, d.Limit.Rowcount, node) + if err != nil { + return nil, err + } + } + + return plan.NewUpdate(node, updateExprs), nil +} + +func tableSpecToSchema(tableSpec *sqlparser.TableSpec) (sql.Schema, error) { var schema sql.Schema - for _, cd := range colDef { - typ := cd.Type - internalTyp, err := sql.MysqlTypeToType(typ.SQLType()) + for _, cd := range tableSpec.Columns { + column, err := getColumn(cd, tableSpec.Indexes) if err != nil { return nil, err } - schema = append(schema, &sql.Column{ - Nullable: !bool(typ.NotNull), - Type: internalTyp, - Name: cd.Name.String(), - // TODO - Default: nil, - }) + schema = append(schema, column) } return schema, nil } +// getColumn returns the sql.Column for the column definition given, as part of a create table statement. +func getColumn(cd *sqlparser.ColumnDefinition, indexes []*sqlparser.IndexDefinition) (*sql.Column, error) { + typ := cd.Type + internalTyp, err := sql.MysqlTypeToType(typ.SQLType()) + if err != nil { + return nil, err + } + + // Primary key info can either be specified in the column's type info (for in-line declarations), or in a slice of + // indexes attached to the table def. We have to check both places to find if a column is part of the primary key + isPkey := cd.Type.KeyOpt == colKeyPrimary + + if !isPkey { + OuterLoop: + for _, index := range indexes { + if index.Info.Primary { + for _, indexCol := range index.Columns { + if indexCol.Column.Equal(cd.Name) { + isPkey = true + break OuterLoop + } + } + } + } + } + + return &sql.Column{ + Nullable: !bool(typ.NotNull), + Type: internalTyp, + Name: cd.Name.String(), + PrimaryKey: isPkey, + // TODO + Default: nil, + }, nil +} + func columnsToStrings(cols sqlparser.Columns) []string { res := make([]string, len(cols)) for i, c := range cols { @@ -463,19 +570,19 @@ func insertRowsToNode(ctx *sql.Context, ir sqlparser.InsertRows) (sql.Node, erro case *sqlparser.Union: return nil, ErrUnsupportedFeature.New("UNION") case sqlparser.Values: - return valuesToValues(v) + return valuesToValues(ctx, v) default: return nil, ErrUnsupportedSyntax.New(ir) } } -func valuesToValues(v sqlparser.Values) (sql.Node, error) { +func valuesToValues(ctx *sql.Context, v sqlparser.Values) (sql.Node, error) { exprTuples := make([][]sql.Expression, len(v)) for i, vt := range v { exprs := make([]sql.Expression, len(vt)) exprTuples[i] = exprs for j, e := range vt { - expr, err := exprToExpression(e) + expr, err := exprToExpression(ctx, e) if err != nil { return nil, err } @@ -573,7 +680,7 @@ func tableExprToTable( return nil, ErrUnsupportedSyntax.New("missed ON clause for JOIN statement") } - cond, err := exprToExpression(t.Condition.On) + cond, err := exprToExpression(ctx, t.Condition.On) if err != nil { return nil, err } @@ -591,8 +698,8 @@ func tableExprToTable( } } -func whereToFilter(w *sqlparser.Where, child sql.Node) (*plan.Filter, error) { - c, err := exprToExpression(w.Expr) +func whereToFilter(ctx *sql.Context, w *sqlparser.Where, child sql.Node) (*plan.Filter, error) { + c, err := exprToExpression(ctx, w.Expr) if err != nil { return nil, err } @@ -600,10 +707,10 @@ func whereToFilter(w *sqlparser.Where, child sql.Node) (*plan.Filter, error) { return plan.NewFilter(c, child), nil } -func orderByToSort(ob sqlparser.OrderBy, child sql.Node) (*plan.Sort, error) { +func orderByToSort(ctx *sql.Context, ob sqlparser.OrderBy, child sql.Node) (*plan.Sort, error) { var sortFields []plan.SortField for _, o := range ob { - e, err := exprToExpression(o.Expr) + e, err := exprToExpression(ctx, o.Expr) if err != nil { return nil, err } @@ -642,8 +749,8 @@ func limitToLimit( return plan.NewLimit(rowCount, child), nil } -func havingToHaving(having *sqlparser.Where, node sql.Node) (sql.Node, error) { - cond, err := exprToExpression(having.Expr) +func havingToHaving(ctx *sql.Context, having *sqlparser.Where, node sql.Node) (sql.Node, error) { + cond, err := exprToExpression(ctx, having.Expr) if err != nil { return nil, err } @@ -670,24 +777,30 @@ func offsetToOffset( // getInt64Literal returns an int64 *expression.Literal for the value given, or an unsupported error with the string // given if the expression doesn't represent an integer literal. -func getInt64Literal(expr sqlparser.Expr, errStr string) (*expression.Literal, error) { - e, err := exprToExpression(expr) +func getInt64Literal(ctx *sql.Context, expr sqlparser.Expr, errStr string) (*expression.Literal, error) { + e, err := exprToExpression(ctx, expr) if err != nil { return nil, err } nl, ok := e.(*expression.Literal) - if !ok || nl.Type() != sql.Int64 { + if !ok || !sql.IsInteger(nl.Type()) { return nil, ErrUnsupportedFeature.New(errStr) } else { - return nl, nil + i64, err := sql.Int64.Convert(nl.Value()) + if err != nil { + return nil, ErrUnsupportedFeature.New(errStr) + } + return expression.NewLiteral(i64, sql.Int64), nil } + + return nl, nil } // getInt64Value returns the int64 literal value in the expression given, or an error with the errStr given if it // cannot. func getInt64Value(ctx *sql.Context, expr sqlparser.Expr, errStr string) (int64, error) { - ie, err := getInt64Literal(expr, errStr) + ie, err := getInt64Literal(ctx, expr, errStr) if err != nil { return 0, err } @@ -715,8 +828,13 @@ func isAggregate(e sql.Expression) bool { return isAgg } -func selectToProjectOrGroupBy(se sqlparser.SelectExprs, g sqlparser.GroupBy, child sql.Node) (sql.Node, error) { - selectExprs, err := selectExprsToExpressions(se) +func selectToProjectOrGroupBy( + ctx *sql.Context, + se sqlparser.SelectExprs, + g sqlparser.GroupBy, + child sql.Node, +) (sql.Node, error) { + selectExprs, err := selectExprsToExpressions(ctx, se) if err != nil { return nil, err } @@ -732,7 +850,7 @@ func selectToProjectOrGroupBy(se sqlparser.SelectExprs, g sqlparser.GroupBy, chi } if isAgg { - groupingExprs, err := groupByToExpressions(g) + groupingExprs, err := groupByToExpressions(ctx, g) if err != nil { return nil, err } @@ -741,12 +859,14 @@ func selectToProjectOrGroupBy(se sqlparser.SelectExprs, g sqlparser.GroupBy, chi for i, ge := range groupingExprs { // if GROUP BY index if l, ok := ge.(*expression.Literal); ok && sql.IsNumber(l.Type()) { - if idx, ok := l.Value().(int64); ok && idx > 0 && idx <= agglen { - aggexpr := selectExprs[idx-1] - if alias, ok := aggexpr.(*expression.Alias); ok { - aggexpr = expression.NewUnresolvedColumn(alias.Name()) + if i64, err := sql.Int64.Convert(l.Value()); err == nil { + if idx, ok := i64.(int64); ok && idx > 0 && idx <= agglen { + aggexpr := selectExprs[idx-1] + if alias, ok := aggexpr.(*expression.Alias); ok { + aggexpr = expression.NewUnresolvedColumn(alias.Name()) + } + groupingExprs[i] = aggexpr } - groupingExprs[i] = aggexpr } } } @@ -757,10 +877,10 @@ func selectToProjectOrGroupBy(se sqlparser.SelectExprs, g sqlparser.GroupBy, chi return plan.NewProject(selectExprs, child), nil } -func selectExprsToExpressions(se sqlparser.SelectExprs) ([]sql.Expression, error) { +func selectExprsToExpressions(ctx *sql.Context, se sqlparser.SelectExprs) ([]sql.Expression, error) { var exprs []sql.Expression for _, e := range se { - pe, err := selectExprToExpression(e) + pe, err := selectExprToExpression(ctx, e) if err != nil { return nil, err } @@ -771,7 +891,7 @@ func selectExprsToExpressions(se sqlparser.SelectExprs) ([]sql.Expression, error return exprs, nil } -func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { +func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error) { switch v := e.(type) { default: return nil, ErrUnsupportedSyntax.New(e) @@ -783,14 +903,14 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { err error ) if v.Name != nil { - name, err = exprToExpression(v.Name) + name, err = exprToExpression(ctx, v.Name) } else { - name, err = exprToExpression(v.StrVal) + name, err = exprToExpression(ctx, v.StrVal) } if err != nil { return nil, err } - from, err := exprToExpression(v.From) + from, err := exprToExpression(ctx, v.From) if err != nil { return nil, err } @@ -798,17 +918,17 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { if v.To == nil { return function.NewSubstring(name, from) } - to, err := exprToExpression(v.To) + to, err := exprToExpression(ctx, v.To) if err != nil { return nil, err } return function.NewSubstring(name, from, to) case *sqlparser.ComparisonExpr: - return comparisonExprToExpression(v) + return comparisonExprToExpression(ctx, v) case *sqlparser.IsExpr: - return isExprToExpression(v) + return isExprToExpression(ctx, v) case *sqlparser.NotExpr: - c, err := exprToExpression(v.Expr) + c, err := exprToExpression(ctx, v.Expr) if err != nil { return nil, err } @@ -829,7 +949,7 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { } return expression.NewUnresolvedColumn(v.Name.String()), nil case *sqlparser.FuncExpr: - exprs, err := selectExprsToExpressions(v.Exprs) + exprs, err := selectExprsToExpressions(ctx, v.Exprs) if err != nil { return nil, err } @@ -849,50 +969,50 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { return expression.NewUnresolvedFunction(v.Name.Lowered(), isAggregateFunc(v), exprs...), nil case *sqlparser.ParenExpr: - return exprToExpression(v.Expr) + return exprToExpression(ctx, v.Expr) case *sqlparser.AndExpr: - lhs, err := exprToExpression(v.Left) + lhs, err := exprToExpression(ctx, v.Left) if err != nil { return nil, err } - rhs, err := exprToExpression(v.Right) + rhs, err := exprToExpression(ctx, v.Right) if err != nil { return nil, err } return expression.NewAnd(lhs, rhs), nil case *sqlparser.OrExpr: - lhs, err := exprToExpression(v.Left) + lhs, err := exprToExpression(ctx, v.Left) if err != nil { return nil, err } - rhs, err := exprToExpression(v.Right) + rhs, err := exprToExpression(ctx, v.Right) if err != nil { return nil, err } return expression.NewOr(lhs, rhs), nil case *sqlparser.ConvertExpr: - expr, err := exprToExpression(v.Expr) + expr, err := exprToExpression(ctx, v.Expr) if err != nil { return nil, err } return expression.NewConvert(expr, v.Type.Type), nil case *sqlparser.RangeCond: - val, err := exprToExpression(v.Left) + val, err := exprToExpression(ctx, v.Left) if err != nil { return nil, err } - lower, err := exprToExpression(v.From) + lower, err := exprToExpression(ctx, v.From) if err != nil { return nil, err } - upper, err := exprToExpression(v.To) + upper, err := exprToExpression(ctx, v.To) if err != nil { return nil, err } @@ -908,7 +1028,7 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { case sqlparser.ValTuple: var exprs = make([]sql.Expression, len(v)) for i, e := range v { - expr, err := exprToExpression(e) + expr, err := exprToExpression(ctx, e) if err != nil { return nil, err } @@ -917,15 +1037,19 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { return expression.NewTuple(exprs...), nil case *sqlparser.BinaryExpr: - return binaryExprToExpression(v) + return binaryExprToExpression(ctx, v) case *sqlparser.UnaryExpr: - return unaryExprToExpression(v) + return unaryExprToExpression(ctx, v) case *sqlparser.Subquery: - return nil, ErrUnsupportedSubqueryExpression.New() + node, err := convert(ctx, v.Select, "") + if err != nil { + return nil, err + } + return expression.NewSubquery(node), nil case *sqlparser.CaseExpr: - return caseExprToExpression(v) + return caseExprToExpression(ctx, v) case *sqlparser.IntervalExpr: - return intervalExprToExpression(v) + return intervalExprToExpression(ctx, v) } } @@ -938,22 +1062,46 @@ func isAggregateFunc(v *sqlparser.FuncExpr) bool { return v.IsAggregate() } +// Convert an integer, represented by the specified string in the specified +// base, to its smallest representation possible, out of: +// int8, uint8, int16, uint16, int32, uint32, int64 and uint64 +func convertInt(value string, base int) (sql.Expression, error) { + if i8, err := strconv.ParseInt(value, base, 8); err == nil { + return expression.NewLiteral(int8(i8), sql.Int8), nil + } + if ui8, err := strconv.ParseUint(value, base, 8); err == nil { + return expression.NewLiteral(uint8(ui8), sql.Uint8), nil + } + if i16, err := strconv.ParseInt(value, base, 16); err == nil { + return expression.NewLiteral(int16(i16), sql.Int16), nil + } + if ui16, err := strconv.ParseUint(value, base, 16); err == nil { + return expression.NewLiteral(uint16(ui16), sql.Uint16), nil + } + if i32, err := strconv.ParseInt(value, base, 32); err == nil { + return expression.NewLiteral(int32(i32), sql.Int32), nil + } + if ui32, err := strconv.ParseUint(value, base, 32); err == nil { + return expression.NewLiteral(uint32(ui32), sql.Uint32), nil + } + if i64, err := strconv.ParseInt(value, base, 64); err == nil { + return expression.NewLiteral(int64(i64), sql.Int64), nil + } + + ui64, err := strconv.ParseUint(value, base, 64) + if err != nil { + return nil, err + } + + return expression.NewLiteral(uint64(ui64), sql.Uint64), nil +} + func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { switch v.Type { case sqlparser.StrVal: return expression.NewLiteral(string(v.Val), sql.Text), nil case sqlparser.IntVal: - //TODO: Use smallest integer representation and widen later. - val, err := strconv.ParseInt(string(v.Val), 10, 64) - if err != nil { - // Might be a uint64 value that is greater than int64 max - val, checkErr := strconv.ParseUint(string(v.Val), 10, 64) - if checkErr != nil { - return nil, err - } - return expression.NewLiteral(val, sql.Uint64), nil - } - return expression.NewLiteral(val, sql.Int64), nil + return convertInt(string(v.Val), 10) case sqlparser.FloatVal: val, err := strconv.ParseFloat(string(v.Val), 64) if err != nil { @@ -968,11 +1116,7 @@ func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { v = strings.Trim(v[1:], "'") } - val, err := strconv.ParseInt(v, 16, 64) - if err != nil { - return nil, err - } - return expression.NewLiteral(val, sql.Int64), nil + return convertInt(v, 16) case sqlparser.HexVal: val, err := v.HexDecode() if err != nil { @@ -988,8 +1132,8 @@ func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { return nil, ErrInvalidSQLValType.New(v.Type) } -func isExprToExpression(c *sqlparser.IsExpr) (sql.Expression, error) { - e, err := exprToExpression(c.Expr) +func isExprToExpression(ctx *sql.Context, c *sqlparser.IsExpr) (sql.Expression, error) { + e, err := exprToExpression(ctx, c.Expr) if err != nil { return nil, err } @@ -1012,13 +1156,13 @@ func isExprToExpression(c *sqlparser.IsExpr) (sql.Expression, error) { } } -func comparisonExprToExpression(c *sqlparser.ComparisonExpr) (sql.Expression, error) { - left, err := exprToExpression(c.Left) +func comparisonExprToExpression(ctx *sql.Context, c *sqlparser.ComparisonExpr) (sql.Expression, error) { + left, err := exprToExpression(ctx, c.Left) if err != nil { return nil, err } - right, err := exprToExpression(c.Right) + right, err := exprToExpression(ctx, c.Right) if err != nil { return nil, err } @@ -1055,10 +1199,10 @@ func comparisonExprToExpression(c *sqlparser.ComparisonExpr) (sql.Expression, er } } -func groupByToExpressions(g sqlparser.GroupBy) ([]sql.Expression, error) { +func groupByToExpressions(ctx *sql.Context, g sqlparser.GroupBy) ([]sql.Expression, error) { es := make([]sql.Expression, len(g)) for i, ve := range g { - e, err := exprToExpression(ve) + e, err := exprToExpression(ctx, ve) if err != nil { return nil, err } @@ -1069,7 +1213,7 @@ func groupByToExpressions(g sqlparser.GroupBy) ([]sql.Expression, error) { return es, nil } -func selectExprToExpression(se sqlparser.SelectExpr) (sql.Expression, error) { +func selectExprToExpression(ctx *sql.Context, se sqlparser.SelectExpr) (sql.Expression, error) { switch e := se.(type) { default: return nil, ErrUnsupportedSyntax.New(e) @@ -1079,7 +1223,7 @@ func selectExprToExpression(se sqlparser.SelectExpr) (sql.Expression, error) { } return expression.NewQualifiedStar(e.TableName.Name.String()), nil case *sqlparser.AliasedExpr: - expr, err := exprToExpression(e.Expr) + expr, err := exprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -1093,10 +1237,10 @@ func selectExprToExpression(se sqlparser.SelectExpr) (sql.Expression, error) { } } -func unaryExprToExpression(e *sqlparser.UnaryExpr) (sql.Expression, error) { +func unaryExprToExpression(ctx *sql.Context, e *sqlparser.UnaryExpr) (sql.Expression, error) { switch e.Operator { case sqlparser.MinusStr: - expr, err := exprToExpression(e.Expr) + expr, err := exprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -1108,7 +1252,7 @@ func unaryExprToExpression(e *sqlparser.UnaryExpr) (sql.Expression, error) { } } -func binaryExprToExpression(be *sqlparser.BinaryExpr) (sql.Expression, error) { +func binaryExprToExpression(ctx *sql.Context, be *sqlparser.BinaryExpr) (sql.Expression, error) { switch be.Operator { case sqlparser.PlusStr, @@ -1123,12 +1267,12 @@ func binaryExprToExpression(be *sqlparser.BinaryExpr) (sql.Expression, error) { sqlparser.IntDivStr, sqlparser.ModStr: - l, err := exprToExpression(be.Left) + l, err := exprToExpression(ctx, be.Left) if err != nil { return nil, err } - r, err := exprToExpression(be.Right) + r, err := exprToExpression(ctx, be.Right) if err != nil { return nil, err } @@ -1150,12 +1294,12 @@ func binaryExprToExpression(be *sqlparser.BinaryExpr) (sql.Expression, error) { } } -func caseExprToExpression(e *sqlparser.CaseExpr) (sql.Expression, error) { +func caseExprToExpression(ctx *sql.Context, e *sqlparser.CaseExpr) (sql.Expression, error) { var expr sql.Expression var err error if e.Expr != nil { - expr, err = exprToExpression(e.Expr) + expr, err = exprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -1164,13 +1308,13 @@ func caseExprToExpression(e *sqlparser.CaseExpr) (sql.Expression, error) { var branches []expression.CaseBranch for _, w := range e.Whens { var cond sql.Expression - cond, err = exprToExpression(w.Cond) + cond, err = exprToExpression(ctx, w.Cond) if err != nil { return nil, err } var val sql.Expression - val, err = exprToExpression(w.Val) + val, err = exprToExpression(ctx, w.Val) if err != nil { return nil, err } @@ -1183,7 +1327,7 @@ func caseExprToExpression(e *sqlparser.CaseExpr) (sql.Expression, error) { var elseExpr sql.Expression if e.Else != nil { - elseExpr, err = exprToExpression(e.Else) + elseExpr, err = exprToExpression(ctx, e.Else) if err != nil { return nil, err } @@ -1192,8 +1336,8 @@ func caseExprToExpression(e *sqlparser.CaseExpr) (sql.Expression, error) { return expression.NewCase(expr, branches, elseExpr), nil } -func intervalExprToExpression(e *sqlparser.IntervalExpr) (sql.Expression, error) { - expr, err := exprToExpression(e.Expr) +func intervalExprToExpression(ctx *sql.Context, e *sqlparser.IntervalExpr) (sql.Expression, error) { + expr, err := exprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -1201,6 +1345,22 @@ func intervalExprToExpression(e *sqlparser.IntervalExpr) (sql.Expression, error) return expression.NewInterval(expr, e.Unit), nil } +func updateExprsToExpressions(ctx *sql.Context, e sqlparser.UpdateExprs) ([]sql.Expression, error) { + res := make([]sql.Expression, len(e)) + for i, updateExpr := range e { + colName, err := exprToExpression(ctx, updateExpr.Name) + if err != nil { + return nil, err + } + innerExpr, err := exprToExpression(ctx, updateExpr.Expr) + if err != nil { + return nil, err + } + res[i] = expression.NewSetField(colName, innerExpr) + } + return res, nil +} + func removeComments(s string) string { r := bufio.NewReader(strings.NewReader(s)) var result []rune @@ -1300,7 +1460,7 @@ func readString(r *bufio.Reader, single bool) []rune { return result } -func parseShowTableStatus(query string) (sql.Node, error) { +func parseShowTableStatus(ctx *sql.Context, query string) (sql.Node, error) { buf := bufio.NewReader(strings.NewReader(query)) err := parseFuncs{ expect("show"), @@ -1342,7 +1502,7 @@ func parseShowTableStatus(query string) (sql.Node, error) { return nil, err } - expr, err := parseExpr(string(bs)) + expr, err := parseExpr(ctx, string(bs)) if err != nil { return nil, err } @@ -1366,7 +1526,7 @@ func parseShowTableStatus(query string) (sql.Node, error) { } } -func parseShowCollation(query string) (sql.Node, error) { +func parseShowCollation(ctx *sql.Context, query string) (sql.Node, error) { buf := bufio.NewReader(strings.NewReader(query)) err := parseFuncs{ expect("show"), @@ -1399,7 +1559,7 @@ func parseShowCollation(query string) (sql.Node, error) { return nil, err } - expr, err := parseExpr(string(bs)) + expr, err := parseExpr(ctx, string(bs)) if err != nil { return nil, err } diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 6e74dd243..b66dca943 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -1,6 +1,7 @@ package parse import ( + "math" "testing" "github.com/src-d/go-mysql-server/sql/expression" @@ -50,6 +51,60 @@ var fixtures = map[string]sql.Node{ Nullable: true, }}, ), + `CREATE TABLE t1(a INTEGER NOT NULL PRIMARY KEY, b TEXT)`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: false, + }}, + ), + `CREATE TABLE t1(a INTEGER, b TEXT, PRIMARY KEY (a))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: true, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: false, + }}, + ), + `CREATE TABLE t1(a INTEGER, b TEXT, PRIMARY KEY (a, b))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: true, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: true, + }}, + ), + `DROP TABLE foo;`: plan.NewDropTable( + sql.UnresolvedDatabase(""), false, "foo", + ), + `DROP TABLE IF EXISTS foo;`: plan.NewDropTable( + sql.UnresolvedDatabase(""), true, "foo", + ), + `DROP TABLE IF EXISTS foo, bar, baz;`: plan.NewDropTable( + sql.UnresolvedDatabase(""), true, "foo", "bar", "baz", + ), `DESCRIBE TABLE foo;`: plan.NewDescribe( plan.NewUnresolvedTable("foo", ""), ), @@ -181,7 +236,7 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewEquals( expression.NewUnresolvedColumn("qux"), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), ), plan.NewUnresolvedTable("foo", ""), ), @@ -258,7 +313,7 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("t1", ""), plan.NewValues([][]sql.Expression{{ expression.NewLiteral("a", sql.Text), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), }}), false, []string{"col1", "col2"}, @@ -267,7 +322,7 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("t1", ""), plan.NewValues([][]sql.Expression{{ expression.NewLiteral("a", sql.Text), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), }}), true, []string{"col1", "col2"}, @@ -360,7 +415,7 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewEquals( expression.NewUnresolvedColumn("a"), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), ), plan.NewUnresolvedTable("foo", ""), ), @@ -432,9 +487,9 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewNot( expression.NewBetween( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(2), sql.Int64), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(2), sql.Int8), + expression.NewLiteral(int8(5), sql.Int8), ), ), plan.NewUnresolvedTable("foo", ""), @@ -444,16 +499,16 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewStar()}, plan.NewFilter( expression.NewBetween( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(2), sql.Int64), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(2), sql.Int8), + expression.NewLiteral(int8(5), sql.Int8), ), plan.NewUnresolvedTable("foo", ""), ), ), `SELECT 0x01AF`: plan.NewProject( []sql.Expression{ - expression.NewLiteral(int64(431), sql.Int64), + expression.NewLiteral(int16(431), sql.Int16), }, plan.NewUnresolvedTable("dual", ""), ), @@ -470,12 +525,12 @@ var fixtures = map[string]sql.Node{ "somefunc", false, expression.NewTuple( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(2), sql.Int8), ), expression.NewTuple( - expression.NewLiteral(int64(3), sql.Int64), - expression.NewLiteral(int64(4), sql.Int64), + expression.NewLiteral(int8(3), sql.Int8), + expression.NewLiteral(int8(4), sql.Int8), ), ), plan.NewUnresolvedTable("b", ""), @@ -486,7 +541,7 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewEquals( expression.NewLiteral(":foo_id", sql.Text), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), plan.NewUnresolvedTable("foo", ""), ), @@ -510,13 +565,13 @@ var fixtures = map[string]sql.Node{ ), `SELECT CAST(-3 AS UNSIGNED) FROM foo`: plan.NewProject( []sql.Expression{ - expression.NewConvert(expression.NewLiteral(int64(-3), sql.Int64), expression.ConvertToUnsigned), + expression.NewConvert(expression.NewLiteral(int8(-3), sql.Int8), expression.ConvertToUnsigned), }, plan.NewUnresolvedTable("foo", ""), ), `SELECT 2 = 2 FROM foo`: plan.NewProject( []sql.Expression{ - expression.NewEquals(expression.NewLiteral(int64(2), sql.Int64), expression.NewLiteral(int64(2), sql.Int64)), + expression.NewEquals(expression.NewLiteral(int8(2), sql.Int8), expression.NewLiteral(int8(2), sql.Int8)), }, plan.NewUnresolvedTable("foo", ""), ), @@ -560,10 +615,10 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewStar()}, plan.NewFilter( expression.NewIn( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), expression.NewTuple( expression.NewLiteral("1", sql.Text), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), ), plan.NewUnresolvedTable("foo", ""), @@ -573,10 +628,10 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewStar()}, plan.NewFilter( expression.NewNotIn( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), expression.NewTuple( expression.NewLiteral("1", sql.Text), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), ), plan.NewUnresolvedTable("foo", ""), @@ -585,12 +640,12 @@ var fixtures = map[string]sql.Node{ `SELECT a, b FROM t ORDER BY 2, 1`: plan.NewSort( []plan.SortField{ { - Column: expression.NewLiteral(int64(2), sql.Int64), + Column: expression.NewLiteral(int8(2), sql.Int8), Order: plan.Ascending, NullOrdering: plan.NullsFirst, }, { - Column: expression.NewLiteral(int64(1), sql.Int64), + Column: expression.NewLiteral(int8(1), sql.Int8), Order: plan.Ascending, NullOrdering: plan.NullsFirst, }, @@ -605,22 +660,22 @@ var fixtures = map[string]sql.Node{ ), `SELECT 1 + 1;`: plan.NewProject( []sql.Expression{ - expression.NewPlus(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(1), sql.Int64)), + expression.NewPlus(expression.NewLiteral(int8(1), sql.Int8), expression.NewLiteral(int8(1), sql.Int8)), }, plan.NewUnresolvedTable("dual", ""), ), `SELECT 1 * (2 + 1);`: plan.NewProject( []sql.Expression{ - expression.NewMult(expression.NewLiteral(int64(1), sql.Int64), - expression.NewPlus(expression.NewLiteral(int64(2), sql.Int64), expression.NewLiteral(int64(1), sql.Int64))), + expression.NewMult(expression.NewLiteral(int8(1), sql.Int8), + expression.NewPlus(expression.NewLiteral(int8(2), sql.Int8), expression.NewLiteral(int8(1), sql.Int8))), }, plan.NewUnresolvedTable("dual", ""), ), `SELECT (0 - 1) * (1 | 1);`: plan.NewProject( []sql.Expression{ expression.NewMult( - expression.NewMinus(expression.NewLiteral(int64(0), sql.Int64), expression.NewLiteral(int64(1), sql.Int64)), - expression.NewBitOr(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(1), sql.Int64)), + expression.NewMinus(expression.NewLiteral(int8(0), sql.Int8), expression.NewLiteral(int8(1), sql.Int8)), + expression.NewBitOr(expression.NewLiteral(int8(1), sql.Int8), expression.NewLiteral(int8(1), sql.Int8)), ), }, plan.NewUnresolvedTable("dual", ""), @@ -628,8 +683,8 @@ var fixtures = map[string]sql.Node{ `SELECT (1 << 3) % (2 div 1);`: plan.NewProject( []sql.Expression{ expression.NewMod( - expression.NewShiftLeft(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(3), sql.Int64)), - expression.NewIntDiv(expression.NewLiteral(int64(2), sql.Int64), expression.NewLiteral(int64(1), sql.Int64))), + expression.NewShiftLeft(expression.NewLiteral(int8(1), sql.Int8), expression.NewLiteral(int8(3), sql.Int8)), + expression.NewIntDiv(expression.NewLiteral(int8(2), sql.Int8), expression.NewLiteral(int8(1), sql.Int8))), }, plan.NewUnresolvedTable("dual", ""), ), @@ -645,7 +700,7 @@ var fixtures = map[string]sql.Node{ `SELECT '1.0' + 2;`: plan.NewProject( []sql.Expression{ expression.NewPlus( - expression.NewLiteral("1.0", sql.Text), expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral("1.0", sql.Text), expression.NewLiteral(int8(2), sql.Int8), ), }, plan.NewUnresolvedTable("dual", ""), @@ -714,7 +769,7 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedFunction( "max", true, expression.NewUnresolvedColumn("i"), ), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), "/", ), }, @@ -742,7 +797,7 @@ var fixtures = map[string]sql.Node{ `SET autocommit=1, foo="bar"`: plan.NewSet( plan.SetVariable{ Name: "autocommit", - Value: expression.NewLiteral(int64(1), sql.Int64), + Value: expression.NewLiteral(int8(1), sql.Int8), }, plan.SetVariable{ Name: "foo", @@ -752,7 +807,7 @@ var fixtures = map[string]sql.Node{ `SET @@session.autocommit=1, foo="bar"`: plan.NewSet( plan.SetVariable{ Name: "@@session.autocommit", - Value: expression.NewLiteral(int64(1), sql.Int64), + Value: expression.NewLiteral(int8(1), sql.Int8), }, plan.SetVariable{ Name: "foo", @@ -818,11 +873,11 @@ var fixtures = map[string]sql.Node{ `SET SESSION NET_READ_TIMEOUT= 700, SESSION NET_WRITE_TIMEOUT= 700`: plan.NewSet( plan.SetVariable{ Name: "@@session.net_read_timeout", - Value: expression.NewLiteral(int64(700), sql.Int64), + Value: expression.NewLiteral(int16(700), sql.Int16), }, plan.SetVariable{ Name: "@@session.net_write_timeout", - Value: expression.NewLiteral(int64(700), sql.Int64), + Value: expression.NewLiteral(int16(700), sql.Int16), }, ), `SET gtid_mode=DEFAULT`: plan.NewSet( @@ -975,11 +1030,11 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("foo"), []expression.CaseBranch{ { - Cond: expression.NewLiteral(int64(1), sql.Int64), + Cond: expression.NewLiteral(int8(1), sql.Int8), Value: expression.NewLiteral("foo", sql.Text), }, { - Cond: expression.NewLiteral(int64(2), sql.Int64), + Cond: expression.NewLiteral(int8(2), sql.Int8), Value: expression.NewLiteral("bar", sql.Text), }, }, @@ -992,11 +1047,11 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("foo"), []expression.CaseBranch{ { - Cond: expression.NewLiteral(int64(1), sql.Int64), + Cond: expression.NewLiteral(int8(1), sql.Int8), Value: expression.NewLiteral("foo", sql.Text), }, { - Cond: expression.NewLiteral(int64(2), sql.Int64), + Cond: expression.NewLiteral(int8(2), sql.Int8), Value: expression.NewLiteral("bar", sql.Text), }, }, @@ -1011,14 +1066,14 @@ var fixtures = map[string]sql.Node{ { Cond: expression.NewEquals( expression.NewUnresolvedColumn("foo"), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), ), Value: expression.NewLiteral("foo", sql.Text), }, { Cond: expression.NewEquals( expression.NewUnresolvedColumn("foo"), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), Value: expression.NewLiteral("bar", sql.Text), }, @@ -1055,7 +1110,7 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewArithmetic( expression.NewLiteral("2018-05-01", sql.Text), expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), "+", @@ -1066,7 +1121,7 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewArithmetic( expression.NewLiteral("2018-05-01", sql.Text), expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), "-", @@ -1076,7 +1131,7 @@ var fixtures = map[string]sql.Node{ `SELECT INTERVAL 1 DAY + '2018-05-01'`: plan.NewProject( []sql.Expression{expression.NewArithmetic( expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), expression.NewLiteral("2018-05-01", sql.Text), @@ -1089,13 +1144,13 @@ var fixtures = map[string]sql.Node{ expression.NewArithmetic( expression.NewLiteral("2018-05-01", sql.Text), expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), "+", ), expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), "+", @@ -1105,7 +1160,7 @@ var fixtures = map[string]sql.Node{ `SELECT COUNT(*) FROM foo GROUP BY a HAVING COUNT(*) > 5`: plan.NewHaving( expression.NewGreaterThan( expression.NewUnresolvedFunction("count", true, expression.NewStar()), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(5), sql.Int8), ), plan.NewGroupBy( []sql.Expression{expression.NewUnresolvedFunction("count", true, expression.NewStar())}, @@ -1117,7 +1172,7 @@ var fixtures = map[string]sql.Node{ plan.NewHaving( expression.NewGreaterThan( expression.NewUnresolvedFunction("count", true, expression.NewStar()), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(5), sql.Int8), ), plan.NewGroupBy( []sql.Expression{expression.NewUnresolvedFunction("count", true, expression.NewStar())}, @@ -1132,8 +1187,8 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewUnresolvedTable("bar", ""), expression.NewEquals( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), ), ), ), @@ -1143,8 +1198,8 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewUnresolvedTable("bar", ""), expression.NewEquals( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), ), ), ), @@ -1154,8 +1209,8 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewUnresolvedTable("bar", ""), expression.NewEquals( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), ), ), ), @@ -1165,8 +1220,8 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewUnresolvedTable("bar", ""), expression.NewEquals( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), ), ), ), @@ -1191,6 +1246,23 @@ var fixtures = map[string]sql.Node{ []sql.Expression{}, plan.NewUnresolvedTable("foo", ""), ), + `SELECT -128, 127, 255, -32768, 32767, 65535, -2147483648, 2147483647, 4294967295, -9223372036854775808, 9223372036854775807, 18446744073709551615`: plan.NewProject( + []sql.Expression{ + expression.NewLiteral(int8(math.MinInt8), sql.Int8), + expression.NewLiteral(int8(math.MaxInt8), sql.Int8), + expression.NewLiteral(uint8(math.MaxUint8), sql.Uint8), + expression.NewLiteral(int16(math.MinInt16), sql.Int16), + expression.NewLiteral(int16(math.MaxInt16), sql.Int16), + expression.NewLiteral(uint16(math.MaxUint16), sql.Uint16), + expression.NewLiteral(int32(math.MinInt32), sql.Int32), + expression.NewLiteral(int32(math.MaxInt32), sql.Int32), + expression.NewLiteral(uint32(math.MaxUint32), sql.Uint32), + expression.NewLiteral(int64(math.MinInt64), sql.Int64), + expression.NewLiteral(int64(math.MaxInt64), sql.Int64), + expression.NewLiteral(uint64(math.MaxUint64), sql.Uint64), + }, + plan.NewUnresolvedTable("dual", ""), + ), } func TestParse(t *testing.T) { @@ -1208,12 +1280,11 @@ func TestParse(t *testing.T) { } var fixturesErrors = map[string]*errors.Kind{ - `SHOW METHEMONEY`: ErrUnsupportedFeature, - `LOCK TABLES foo AS READ`: errUnexpectedSyntax, - `LOCK TABLES foo LOW_PRIORITY READ`: errUnexpectedSyntax, - `SELECT * FROM mytable WHERE i IN (SELECT i FROM foo)`: ErrUnsupportedSubqueryExpression, - `SELECT * FROM mytable LIMIT -100`: ErrUnsupportedSyntax, - `SELECT * FROM mytable LIMIT 100 OFFSET -1`: ErrUnsupportedSyntax, + `SHOW METHEMONEY`: ErrUnsupportedFeature, + `LOCK TABLES foo AS READ`: errUnexpectedSyntax, + `LOCK TABLES foo LOW_PRIORITY READ`: errUnexpectedSyntax, + `SELECT * FROM mytable LIMIT -100`: ErrUnsupportedSyntax, + `SELECT * FROM mytable LIMIT 100 OFFSET -1`: ErrUnsupportedSyntax, `SELECT * FROM files JOIN commit_files JOIN refs @@ -1225,6 +1296,7 @@ var fixturesErrors = map[string]*errors.Kind{ `SELECT INTERVAL 1 DAY + INTERVAL 1 DAY`: ErrUnsupportedSyntax, `SELECT '2018-05-01' + (INTERVAL 1 DAY + INTERVAL 1 DAY)`: ErrUnsupportedSyntax, `SELECT AVG(DISTINCT foo) FROM b`: ErrUnsupportedSyntax, + `CREATE VIEW view1 AS SELECT x FROM t1 WHERE x>0`: ErrUnsupportedFeature, } func TestParseErrors(t *testing.T) { diff --git a/sql/parse/util.go b/sql/parse/util.go index a99e7a76c..bfb358b1d 100644 --- a/sql/parse/util.go +++ b/sql/parse/util.go @@ -251,7 +251,7 @@ func readRemaining(val *string) parseFunc { } } -func parseExpr(str string) (sql.Expression, error) { +func parseExpr(ctx *sql.Context, str string) (sql.Expression, error) { stmt, err := sqlparser.Parse("SELECT " + str) if err != nil { return nil, err @@ -271,7 +271,7 @@ func parseExpr(str string) (sql.Expression, error) { return nil, errInvalidIndexExpression.New(str) } - return exprToExpression(selectExpr.Expr) + return exprToExpression(ctx, selectExpr.Expr) } func readQuotableIdent(ident *string) parseFunc { diff --git a/sql/plan/common.go b/sql/plan/common.go index beec46177..28e0dd935 100644 --- a/sql/plan/common.go +++ b/sql/plan/common.go @@ -1,4 +1,4 @@ -package plan // import "github.com/src-d/go-mysql-server/sql/plan" +package plan import "github.com/src-d/go-mysql-server/sql" diff --git a/sql/plan/create_index.go b/sql/plan/create_index.go index f45de2f7e..00455a127 100644 --- a/sql/plan/create_index.go +++ b/sql/plan/create_index.go @@ -282,6 +282,11 @@ func (c *CreateIndex) WithChildren(children ...sql.Node) (sql.Node, error) { return &nc, nil } +// IsAsync implements the AsyncNode interface. +func (c *CreateIndex) IsAsync() bool { + return c.Async +} + // getColumnsAndPrepareExpressions extracts the unique columns required by all // those expressions and fixes the indexes of the GetFields in the expressions // to match a row with only the returned columns in that same order. diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index 3eb97331e..d9acf8a09 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -1,12 +1,14 @@ package plan import ( + "fmt" "github.com/src-d/go-mysql-server/sql" "gopkg.in/src-d/go-errors.v1" ) // ErrCreateTable is thrown when the database doesn't support table creation -var ErrCreateTable = errors.NewKind("tables cannot be created on database %s") +var ErrCreateTableNotSupported = errors.NewKind("tables cannot be created on database %s") +var ErrDropTableNotSupported = errors.NewKind("tables cannot be dropped on database %s") // CreateTable is a node describing the creation of some table. type CreateTable struct { @@ -50,12 +52,12 @@ func (c *CreateTable) Resolved() bool { // RowIter implements the Node interface. func (c *CreateTable) RowIter(s *sql.Context) (sql.RowIter, error) { - d, ok := c.db.(sql.Alterable) - if !ok { - return nil, ErrCreateTable.New(c.db.Name()) + creatable, ok := c.db.(sql.TableCreator) + if ok { + return sql.RowsToRowIter(), creatable.CreateTable(s, c.name, c.schema) } - return sql.RowsToRowIter(), d.Create(c.name, c.schema) + return nil, ErrCreateTableNotSupported.New(c.db.Name()) } // Schema implements the Node interface. @@ -75,3 +77,86 @@ func (c *CreateTable) WithChildren(children ...sql.Node) (sql.Node, error) { func (c *CreateTable) String() string { return "CreateTable" } + +// DropTable is a node describing dropping one or more tables +type DropTable struct { + db sql.Database + names []string + ifExists bool +} + +// NewDropTable creates a new DropTable node +func NewDropTable(db sql.Database, ifExists bool, tableNames ...string) *DropTable { + return &DropTable{ + db: db, + names: tableNames, + ifExists: ifExists, + } +} + +var _ sql.Databaser = (*DropTable)(nil) + +// Database implements the sql.Databaser interface. +func (d *DropTable) Database() sql.Database { + return d.db +} + +// WithDatabase implements the sql.Databaser interface. +func (d *DropTable) WithDatabase(db sql.Database) (sql.Node, error) { + nc := *d + nc.db = db + return &nc, nil +} + +// Resolved implements the Resolvable interface. +func (d *DropTable) Resolved() bool { + _, ok := d.db.(sql.UnresolvedDatabase) + return !ok +} + +// RowIter implements the Node interface. +func (d *DropTable) RowIter(s *sql.Context) (sql.RowIter, error) { + droppable, ok := d.db.(sql.TableDropper) + if !ok { + return nil, ErrDropTableNotSupported.New(d.db.Name()) + } + + var err error + for _, tableName := range d.names { + _, ok := d.db.Tables()[tableName] + if !ok { + if d.ifExists { + continue + } + return nil, sql.ErrTableNotFound.New(tableName) + } + err = droppable.DropTable(s, tableName) + if err != nil { + break + } + } + + return sql.RowsToRowIter(), err +} + +// Schema implements the Node interface. +func (d *DropTable) Schema() sql.Schema { return nil } + +// Children implements the Node interface. +func (d *DropTable) Children() []sql.Node { return nil } + +// WithChildren implements the Node interface. +func (d *DropTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 0) + } + return d, nil +} + +func (d *DropTable) String() string { + ifExists := "" + if d.ifExists { + ifExists = "if exists " + } + return fmt.Sprintf("Drop table %s%s", ifExists, d.names) +} diff --git a/sql/plan/ddl_test.go b/sql/plan/ddl_test.go index a644857e9..a226a631c 100644 --- a/sql/plan/ddl_test.go +++ b/sql/plan/ddl_test.go @@ -22,15 +22,7 @@ func TestCreateTable(t *testing.T) { {Name: "c2", Type: sql.Int32}, } - c := NewCreateTable(db, "testTable", s) - - rows, err := c.RowIter(sql.NewEmptyContext()) - - require.NoError(err) - - r, err := rows.Next() - require.Equal(err, io.EOF) - require.Nil(r) + createTable(t, db, "testTable", s) tables = db.Tables() @@ -43,3 +35,59 @@ func TestCreateTable(t *testing.T) { require.Equal("testTable", s.Source) } } + +func TestDropTable(t *testing.T) { + require := require.New(t) + + db := memory.NewDatabase("test") + + s := sql.Schema{ + {Name: "c1", Type: sql.Text}, + {Name: "c2", Type: sql.Int32}, + } + + createTable(t, db, "testTable1", s) + createTable(t, db, "testTable2", s) + createTable(t, db, "testTable3", s) + + d := NewDropTable(db, false, "testTable1", "testTable2") + rows, err := d.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + r, err := rows.Next() + require.Equal(err, io.EOF) + require.Nil(r) + + _, ok := db.Tables()["testTable1"] + require.False(ok) + _, ok = db.Tables()["testTable2"] + require.False(ok) + _, ok = db.Tables()["testTable3"] + require.True(ok) + + d = NewDropTable(db, false, "testTable1") + _, err = d.RowIter(sql.NewEmptyContext()) + require.Error(err) + + d = NewDropTable(db, true, "testTable1") + _, err = d.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + d = NewDropTable(db, true, "testTable1", "testTable2", "testTable3") + _, err = d.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + _, ok = db.Tables()["testTable3"] + require.False(ok) +} + +func createTable(t *testing.T, db sql.Database, name string, schema sql.Schema) { + c := NewCreateTable(db, name, schema) + + rows, err := c.RowIter(sql.NewEmptyContext()) + require.NoError(t, err) + + r, err := rows.Next() + require.Equal(t, err, io.EOF) + require.Nil(t, r) +} \ No newline at end of file diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 807191f4b..06e638d26 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -11,15 +11,12 @@ import ( // ErrInsertIntoNotSupported is thrown when a table doesn't support inserts var ErrInsertIntoNotSupported = errors.NewKind("table doesn't support INSERT INTO") var ErrReplaceIntoNotSupported = errors.NewKind("table doesn't support REPLACE INTO") -var ErrInsertIntoMismatchValueCount = - errors.NewKind("number of values does not match number of columns provided") +var ErrInsertIntoMismatchValueCount = errors.NewKind("number of values does not match number of columns provided") var ErrInsertIntoUnsupportedValues = errors.NewKind("%T is unsupported for inserts") var ErrInsertIntoDuplicateColumn = errors.NewKind("duplicate column name %v") var ErrInsertIntoNonexistentColumn = errors.NewKind("invalid column name %v") -var ErrInsertIntoNonNullableDefaultNullColumn = - errors.NewKind("column name '%v' is non-nullable but attempted to set default value of null") -var ErrInsertIntoNonNullableProvidedNull = - errors.NewKind("column name '%v' is non-nullable but attempted to set a value of null") +var ErrInsertIntoNonNullableDefaultNullColumn = errors.NewKind("column name '%v' is non-nullable but attempted to set default value of null") +var ErrInsertIntoNonNullableProvidedNull = errors.NewKind("column name '%v' is non-nullable but attempted to set a value of null") // InsertInto is a node describing the insertion into some table. type InsertInto struct { @@ -148,6 +145,20 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) { return i, err } + // Convert integer values in row to specified type in schema + for colIdx, oldValue := range row { + dstColType := projExprs[colIdx].Type() + + if sql.IsInteger(dstColType) && oldValue != nil { + newValue, err := dstColType.Convert(oldValue) + if err != nil { + return i, err + } + + row[colIdx] = newValue + } + } + if replaceable != nil { if err = replaceable.Delete(ctx, row); err != nil { if err != sql.ErrDeleteRowNotFound { diff --git a/sql/plan/process.go b/sql/plan/process.go index 32a8452d6..2e0bec439 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -37,7 +37,7 @@ func (p *QueryProcess) RowIter(ctx *sql.Context) (sql.RowIter, error) { return nil, err } - return &trackedRowIter{iter, p.Notify}, nil + return &trackedRowIter{iter: iter, onDone: p.Notify}, nil } func (p *QueryProcess) String() string { return p.Child.String() } @@ -48,12 +48,14 @@ func (p *QueryProcess) String() string { return p.Child.String() } // partition is processed. type ProcessIndexableTable struct { sql.IndexableTable - Notify NotifyFunc + OnPartitionDone NamedNotifyFunc + OnPartitionStart NamedNotifyFunc + OnRowNext NamedNotifyFunc } // NewProcessIndexableTable returns a new ProcessIndexableTable. -func NewProcessIndexableTable(t sql.IndexableTable, notify NotifyFunc) *ProcessIndexableTable { - return &ProcessIndexableTable{t, notify} +func NewProcessIndexableTable(t sql.IndexableTable, onPartitionDone, onPartitionStart, OnRowNext NamedNotifyFunc) *ProcessIndexableTable { + return &ProcessIndexableTable{t, onPartitionDone, onPartitionStart, OnRowNext} } // Underlying implements sql.TableWrapper interface. @@ -71,7 +73,7 @@ func (t *ProcessIndexableTable) IndexKeyValues( return nil, err } - return &trackedPartitionIndexKeyValueIter{iter, t.Notify}, nil + return &trackedPartitionIndexKeyValueIter{iter, t.OnPartitionDone, t.OnPartitionStart, t.OnRowNext}, nil } // PartitionRows implements the sql.Table interface. @@ -81,22 +83,46 @@ func (t *ProcessIndexableTable) PartitionRows(ctx *sql.Context, p sql.Partition) return nil, err } - return &trackedRowIter{iter, t.Notify}, nil + partitionName := partitionName(p) + if t.OnPartitionStart != nil { + t.OnPartitionStart(partitionName) + } + + var onDone NotifyFunc + if t.OnPartitionDone != nil { + onDone = func() { + t.OnPartitionDone(partitionName) + } + } + + var onNext NotifyFunc + if t.OnRowNext != nil { + onNext = func() { + t.OnRowNext(partitionName) + } + } + + return &trackedRowIter{iter: iter, onNext: onNext, onDone: onDone}, nil } var _ sql.IndexableTable = (*ProcessIndexableTable)(nil) +// NamedNotifyFunc is a function to notify about some event with a string argument. +type NamedNotifyFunc func(name string) + // ProcessTable is a wrapper for sql.Tables inside a query process. It // notifies the process manager about the status of a query when a partition // is processed. type ProcessTable struct { sql.Table - Notify NotifyFunc + OnPartitionDone NamedNotifyFunc + OnPartitionStart NamedNotifyFunc + OnRowNext NamedNotifyFunc } // NewProcessTable returns a new ProcessTable. -func NewProcessTable(t sql.Table, notify NotifyFunc) *ProcessTable { - return &ProcessTable{t, notify} +func NewProcessTable(t sql.Table, onPartitionDone, onPartitionStart, OnRowNext NamedNotifyFunc) *ProcessTable { + return &ProcessTable{t, onPartitionDone, onPartitionStart, OnRowNext} } // Underlying implements sql.TableWrapper interface. @@ -111,18 +137,38 @@ func (t *ProcessTable) PartitionRows(ctx *sql.Context, p sql.Partition) (sql.Row return nil, err } - return &trackedRowIter{iter, t.Notify}, nil + partitionName := partitionName(p) + if t.OnPartitionStart != nil { + t.OnPartitionStart(partitionName) + } + + var onDone NotifyFunc + if t.OnPartitionDone != nil { + onDone = func() { + t.OnPartitionDone(partitionName) + } + } + + var onNext NotifyFunc + if t.OnRowNext != nil { + onNext = func() { + t.OnRowNext(partitionName) + } + } + + return &trackedRowIter{iter: iter, onNext: onNext, onDone: onDone}, nil } type trackedRowIter struct { iter sql.RowIter - notify NotifyFunc + onDone NotifyFunc + onNext NotifyFunc } func (i *trackedRowIter) done() { - if i.notify != nil { - i.notify() - i.notify = nil + if i.onDone != nil { + i.onDone() + i.onDone = nil } } @@ -134,6 +180,11 @@ func (i *trackedRowIter) Next() (sql.Row, error) { } return nil, err } + + if i.onNext != nil { + i.onNext() + } + return row, nil } @@ -144,7 +195,9 @@ func (i *trackedRowIter) Close() error { type trackedPartitionIndexKeyValueIter struct { sql.PartitionIndexKeyValueIter - notify NotifyFunc + OnPartitionDone NamedNotifyFunc + OnPartitionStart NamedNotifyFunc + OnRowNext NamedNotifyFunc } func (i *trackedPartitionIndexKeyValueIter) Next() (sql.Partition, sql.IndexKeyValueIter, error) { @@ -153,18 +206,38 @@ func (i *trackedPartitionIndexKeyValueIter) Next() (sql.Partition, sql.IndexKeyV return nil, nil, err } - return p, &trackedIndexKeyValueIter{iter, i.notify}, nil + partitionName := partitionName(p) + if i.OnPartitionStart != nil { + i.OnPartitionStart(partitionName) + } + + var onDone NotifyFunc + if i.OnPartitionDone != nil { + onDone = func() { + i.OnPartitionDone(partitionName) + } + } + + var onNext NotifyFunc + if i.OnRowNext != nil { + onNext = func() { + i.OnRowNext(partitionName) + } + } + + return p, &trackedIndexKeyValueIter{iter, onDone, onNext}, nil } type trackedIndexKeyValueIter struct { iter sql.IndexKeyValueIter - notify NotifyFunc + onDone NotifyFunc + onNext NotifyFunc } func (i *trackedIndexKeyValueIter) done() { - if i.notify != nil { - i.notify() - i.notify = nil + if i.onDone != nil { + i.onDone() + i.onDone = nil } } @@ -185,5 +258,16 @@ func (i *trackedIndexKeyValueIter) Next() ([]interface{}, []byte, error) { return nil, nil, err } + if i.onNext != nil { + i.onNext() + } + return v, k, nil } + +func partitionName(p sql.Partition) string { + if n, ok := p.(sql.Nameable); ok { + return n.Name() + } + return string(p.Key()) +} diff --git a/sql/plan/process_test.go b/sql/plan/process_test.go index 2d4e65763..de819edb0 100644 --- a/sql/plan/process_test.go +++ b/sql/plan/process_test.go @@ -61,7 +61,9 @@ func TestProcessTable(t *testing.T) { table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(3))) table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(4))) - var notifications int + var partitionDoneNotifications int + var partitionStartNotifications int + var rowNextNotifications int node := NewProject( []sql.Expression{ @@ -70,8 +72,14 @@ func TestProcessTable(t *testing.T) { NewResolvedTable( NewProcessTable( table, - func() { - notifications++ + func(partitionName string) { + partitionDoneNotifications++ + }, + func(partitionName string) { + partitionStartNotifications++ + }, + func(partitionName string) { + rowNextNotifications++ }, ), ), @@ -91,7 +99,9 @@ func TestProcessTable(t *testing.T) { } require.ElementsMatch(expected, rows) - require.Equal(2, notifications) + require.Equal(2, partitionDoneNotifications) + require.Equal(2, partitionStartNotifications) + require.Equal(4, rowNextNotifications) } func TestProcessIndexableTable(t *testing.T) { @@ -106,12 +116,20 @@ func TestProcessIndexableTable(t *testing.T) { table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(3))) table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(4))) - var notifications int + var partitionDoneNotifications int + var partitionStartNotifications int + var rowNextNotifications int pt := NewProcessIndexableTable( table, - func() { - notifications++ + func(partitionName string) { + partitionDoneNotifications++ + }, + func(partitionName string) { + partitionStartNotifications++ + }, + func(partitionName string) { + rowNextNotifications++ }, ) @@ -144,5 +162,7 @@ func TestProcessIndexableTable(t *testing.T) { } require.ElementsMatch(expectedValues, values) - require.Equal(2, notifications) + require.Equal(2, partitionDoneNotifications) + require.Equal(2, partitionStartNotifications) + require.Equal(4, rowNextNotifications) } diff --git a/sql/plan/processlist.go b/sql/plan/processlist.go index bc9f4d18c..a447fb79d 100644 --- a/sql/plan/processlist.go +++ b/sql/plan/processlist.go @@ -1,7 +1,6 @@ package plan import ( - "fmt" "sort" "strings" @@ -77,21 +76,36 @@ func (p *ShowProcessList) RowIter(ctx *sql.Context) (sql.RowIter, error) { for i, proc := range processes { var status []string - for name, progress := range proc.Progress { - status = append(status, fmt.Sprintf("%s(%s)", name, progress)) + var names []string + for name := range proc.Progress { + names = append(names, name) + } + sort.Strings(names) + + for _, name := range names { + progress := proc.Progress[name] + + printer := sql.NewTreePrinter() + _ = printer.WriteNode("\n" + progress.String()) + children := []string{} + for _, partitionProgress := range progress.PartitionsProgress { + children = append(children, partitionProgress.String()) + } + sort.Strings(children) + _ = printer.WriteChildren(children...) + + status = append(status, printer.String()) } if len(status) == 0 { status = []string{"running"} } - sort.Strings(status) - rows[i] = process{ id: int64(proc.Connection), user: proc.User, time: int64(proc.Seconds()), - state: strings.Join(status, ", "), + state: strings.Join(status, ""), command: proc.Type.String(), host: ctx.Session.Client().Address, info: proc.Query, diff --git a/sql/plan/processlist_test.go b/sql/plan/processlist_test.go index 401661ba2..12ee98b9a 100644 --- a/sql/plan/processlist_test.go +++ b/sql/plan/processlist_test.go @@ -21,19 +21,21 @@ func TestShowProcessList(t *testing.T) { ctx, err := p.AddProcess(ctx, sql.QueryProcess, "SELECT foo") require.NoError(err) - p.AddProgressItem(ctx.Pid(), "a", 5) - p.AddProgressItem(ctx.Pid(), "b", 6) + p.AddTableProgress(ctx.Pid(), "a", 5) + p.AddTableProgress(ctx.Pid(), "b", 6) ctx = sql.NewContext(context.Background(), sql.WithPid(2), sql.WithSession(sess)) ctx, err = p.AddProcess(ctx, sql.CreateIndexProcess, "SELECT bar") require.NoError(err) - p.AddProgressItem(ctx.Pid(), "foo", 2) + p.AddTableProgress(ctx.Pid(), "foo", 2) - p.UpdateProgress(1, "a", 3) - p.UpdateProgress(1, "a", 1) - p.UpdateProgress(1, "b", 2) - p.UpdateProgress(2, "foo", 1) + p.UpdateTableProgress(1, "a", 3) + p.UpdateTableProgress(1, "a", 1) + p.UpdatePartitionProgress(1, "a", "a-1", 7) + p.UpdatePartitionProgress(1, "a", "a-2", 9) + p.UpdateTableProgress(1, "b", 2) + p.UpdateTableProgress(2, "foo", 1) n.ProcessList = p n.Database = "foo" @@ -44,8 +46,15 @@ func TestShowProcessList(t *testing.T) { require.NoError(err) expected := []sql.Row{ - {int64(1), "foo", addr, "foo", "query", int64(0), "a(4/5), b(2/6)", "SELECT foo"}, - {int64(1), "foo", addr, "foo", "create_index", int64(0), "foo(1/2)", "SELECT bar"}, + {int64(1), "foo", addr, "foo", "query", int64(0), + ` +a (4/5 partitions) + ├─ a-1 (7/? rows) + └─ a-2 (9/? rows) + +b (2/6 partitions) +`, "SELECT foo"}, + {int64(1), "foo", addr, "foo", "create_index", int64(0), "\nfoo (1/2 partitions)\n", "SELECT bar"}, } require.ElementsMatch(expected, rows) diff --git a/sql/plan/show_create_table_test.go b/sql/plan/show_create_table_test.go index ff597682c..0da6418bf 100644 --- a/sql/plan/show_create_table_test.go +++ b/sql/plan/show_create_table_test.go @@ -19,6 +19,8 @@ func TestShowCreateTable(t *testing.T) { &sql.Column{Name: "baz", Type: sql.Text, Default: "", Nullable: false}, &sql.Column{Name: "zab", Type: sql.Int32, Default: int32(0), Nullable: true}, &sql.Column{Name: "bza", Type: sql.Uint64, Default: uint64(0), Nullable: true}, + &sql.Column{Name: "foo", Type: sql.VarChar(123), Default: "", Nullable: true}, + &sql.Column{Name: "pok", Type: sql.Char(123), Default: "", Nullable: true}, }) db.AddTable(table.Name(), table) @@ -39,7 +41,10 @@ func TestShowCreateTable(t *testing.T) { table.Name(), "CREATE TABLE `test-table` (\n `baz` text NOT NULL,\n"+ " `zab` integer DEFAULT 0,\n"+ - " `bza` bigint unsigned DEFAULT 0\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", + " `bza` bigint unsigned DEFAULT 0,\n"+ + " `foo` varchar(123),\n"+ + " `pok` char(123)\n"+ + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", ) require.Equal(expected, row) diff --git a/sql/plan/update.go b/sql/plan/update.go new file mode 100644 index 000000000..a29fb7392 --- /dev/null +++ b/sql/plan/update.go @@ -0,0 +1,189 @@ +package plan + +import ( + "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" + "io" +) + +var ErrUpdateNotSupported = errors.NewKind("table doesn't support UPDATE") +var ErrUpdateUnexpectedSetResult = errors.NewKind("attempted to set field but expression returned %T") + +// Update is a node for updating rows on tables. +type Update struct { + sql.Node + UpdateExprs []sql.Expression +} + +// NewUpdate creates an Update node. +func NewUpdate(n sql.Node, updateExprs []sql.Expression) *Update { + return &Update{n, updateExprs} +} + +// Expressions implements the Expressioner interface. +func (p *Update) Expressions() []sql.Expression { + return p.UpdateExprs +} + +// Schema implements the Node interface. +func (p *Update) Schema() sql.Schema { + return sql.Schema{ + { + Name: "matched", + Type: sql.Int64, + Default: int64(0), + Nullable: false, + }, + { + Name: "updated", + Type: sql.Int64, + Default: int64(0), + Nullable: false, + }, + } +} + +// Resolved implements the Resolvable interface. +func (p *Update) Resolved() bool { + if !p.Node.Resolved() { + return false + } + for _, updateExpr := range p.UpdateExprs { + if !updateExpr.Resolved() { + return false + } + } + return true +} + +func (p *Update) Children() []sql.Node { + return []sql.Node{p.Node} +} + +func getUpdatable(node sql.Node) (sql.Updater, error) { + switch node := node.(type) { + case sql.Updater: + return node, nil + case *ResolvedTable: + return getUpdatableTable(node.Table) + } + for _, child := range node.Children() { + updater, _ := getUpdatable(child) + if updater != nil { + return updater, nil + } + } + return nil, ErrUpdateNotSupported.New() +} + +func getUpdatableTable(t sql.Table) (sql.Updater, error) { + switch t := t.(type) { + case sql.Updater: + return t, nil + case sql.TableWrapper: + return getUpdatableTable(t.Underlying()) + default: + return nil, ErrUpdateNotSupported.New() + } +} + +// Execute inserts the rows in the database. +func (p *Update) Execute(ctx *sql.Context) (int, int, error) { + updatable, err := getUpdatable(p.Node) + if err != nil { + return 0, 0, err + } + schema := p.Node.Schema() + + iter, err := p.Node.RowIter(ctx) + if err != nil { + return 0, 0, err + } + + rowsMatched := 0 + rowsUpdated := 0 + for { + oldRow, err := iter.Next() + if err == io.EOF { + break + } + if err != nil { + _ = iter.Close() + return rowsMatched, rowsUpdated, err + } + rowsMatched++ + + newRow, err := p.applyUpdates(ctx, oldRow) + if err != nil { + _ = iter.Close() + return rowsMatched, rowsUpdated, err + } + if equals, err := oldRow.Equals(newRow, schema); err == nil { + if !equals { + err = updatable.Update(ctx, oldRow, newRow) + if err != nil { + _ = iter.Close() + return rowsMatched, rowsUpdated, err + } + rowsUpdated++ + } + } else { + _ = iter.Close() + return rowsMatched, rowsUpdated, err + } + } + + return rowsMatched, rowsUpdated, nil +} + +// RowIter implements the Node interface. +func (p *Update) RowIter(ctx *sql.Context) (sql.RowIter, error) { + matched, updated, err := p.Execute(ctx) + if err != nil { + return nil, err + } + + return sql.RowsToRowIter(sql.NewRow(int64(matched), int64(updated))), nil +} + +// WithChildren implements the Node interface. +func (p *Update) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + return NewUpdate(children[0], p.UpdateExprs), nil +} + +// WithExpressions implements the Expressioner interface. +func (p *Update) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) { + if len(newExprs) != len(p.UpdateExprs) { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(p.UpdateExprs), 1) + } + return NewUpdate(p.Node, newExprs), nil +} + +func (p Update) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("Update") + _ = pr.WriteChildren(p.Node.String()) + for _, updateExpr := range p.UpdateExprs { + _ = pr.WriteChildren(updateExpr.String()) + } + return pr.String() +} + +func (p *Update) applyUpdates(ctx *sql.Context, row sql.Row) (sql.Row, error) { + var ok bool + prev := row + for _, updateExpr := range p.UpdateExprs { + val, err := updateExpr.Eval(ctx, prev) + if err != nil { + return nil, err + } + prev, ok = val.(sql.Row) + if !ok { + return nil, ErrUpdateUnexpectedSetResult.New(val) + } + } + return prev, nil +} diff --git a/sql/processlist.go b/sql/processlist.go index 3a17f50d0..5bc1fecf6 100644 --- a/sql/processlist.go +++ b/sql/processlist.go @@ -12,17 +12,46 @@ import ( // Progress between done items and total items. type Progress struct { + Name string Done int64 Total int64 } -func (p Progress) String() string { +func (p Progress) totalString() string { var total = "?" if p.Total > 0 { total = fmt.Sprint(p.Total) } + return total +} - return fmt.Sprintf("%d/%s", p.Done, total) +// TableProgress keeps track of a table progress, and for each of its partitions +type TableProgress struct { + Progress + PartitionsProgress map[string]PartitionProgress +} + +func NewTableProgress(name string, total int64) TableProgress { + return TableProgress{ + Progress: Progress{ + Name: name, + Total: total, + }, + PartitionsProgress: make(map[string]PartitionProgress), + } +} + +func (p TableProgress) String() string { + return fmt.Sprintf("%s (%d/%s partitions)", p.Name, p.Done, p.totalString()) +} + +// PartitionProgress keeps track of a partition progress +type PartitionProgress struct { + Progress +} + +func (p PartitionProgress) String() string { + return fmt.Sprintf("%s (%d/%s rows)", p.Name, p.Done, p.totalString()) } // ProcessType is the type of process. @@ -53,7 +82,7 @@ type Process struct { User string Type ProcessType Query string - Progress map[string]Progress + Progress map[string]TableProgress StartedAt time.Time Kill context.CancelFunc } @@ -108,7 +137,7 @@ func (pl *ProcessList) AddProcess( Connection: ctx.ID(), Type: typ, Query: query, - Progress: make(map[string]Progress), + Progress: make(map[string]TableProgress), User: ctx.Session.Client().User, StartedAt: time.Now(), Kill: cancel, @@ -117,9 +146,9 @@ func (pl *ProcessList) AddProcess( return ctx, nil } -// UpdateProgress updates the progress of the item with the given name for the +// UpdateTableProgress updates the progress of the table with the given name for the // process with the given pid. -func (pl *ProcessList) UpdateProgress(pid uint64, name string, delta int64) { +func (pl *ProcessList) UpdateTableProgress(pid uint64, name string, delta int64) { pl.mu.Lock() defer pl.mu.Unlock() @@ -130,16 +159,41 @@ func (pl *ProcessList) UpdateProgress(pid uint64, name string, delta int64) { progress, ok := p.Progress[name] if !ok { - progress = Progress{Total: -1} + progress = NewTableProgress(name, -1) } progress.Done += delta p.Progress[name] = progress } -// AddProgressItem adds a new item to track progress from to the process with +// UpdatePartitionProgress updates the progress of the table partition with the +// given name for the process with the given pid. +func (pl *ProcessList) UpdatePartitionProgress(pid uint64, tableName, partitionName string, delta int64) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + tablePg, ok := p.Progress[tableName] + if !ok { + return + } + + partitionPg, ok := tablePg.PartitionsProgress[partitionName] + if !ok { + partitionPg = PartitionProgress{Progress: Progress{Name: partitionName, Total: -1}} + } + + partitionPg.Done += delta + tablePg.PartitionsProgress[partitionName] = partitionPg +} + +// AddTableProgress adds a new item to track progress from to the process with // the given pid. If the pid does not exist, it will do nothing. -func (pl *ProcessList) AddProgressItem(pid uint64, name string, total int64) { +func (pl *ProcessList) AddTableProgress(pid uint64, name string, total int64) { pl.mu.Lock() defer pl.mu.Unlock() @@ -152,8 +206,66 @@ func (pl *ProcessList) AddProgressItem(pid uint64, name string, total int64) { pg.Total = total p.Progress[name] = pg } else { - p.Progress[name] = Progress{Total: total} + p.Progress[name] = NewTableProgress(name, total) + } +} + +// AddPartitionProgress adds a new item to track progress from to the process with +// the given pid. If the pid or the table does not exist, it will do nothing. +func (pl *ProcessList) AddPartitionProgress(pid uint64, tableName, partitionName string, total int64) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + tablePg, ok := p.Progress[tableName] + if !ok { + return + } + + if pg, ok := tablePg.PartitionsProgress[partitionName]; ok { + pg.Total = total + tablePg.PartitionsProgress[partitionName] = pg + } else { + tablePg.PartitionsProgress[partitionName] = + PartitionProgress{Progress: Progress{Name: partitionName, Total: total}} + } +} + +// RemoveTableProgress removes an existing item tracking progress from the +// process with the given pid, if it exists. +func (pl *ProcessList) RemoveTableProgress(pid uint64, name string) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return } + + delete(p.Progress, name) +} + +// RemovePartitionProgress removes an existing item tracking progress from the +// process with the given pid, if it exists. +func (pl *ProcessList) RemovePartitionProgress(pid uint64, tableName, partitionName string) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + tablePg, ok := p.Progress[tableName] + if !ok { + return + } + + delete(tablePg.PartitionsProgress, partitionName) } // Kill terminates all queries for a given connection id. @@ -206,7 +318,7 @@ func (pl *ProcessList) Processes() []Process { for _, proc := range pl.procs { p := *proc - var progress = make(map[string]Progress, len(p.Progress)) + var progress = make(map[string]TableProgress, len(p.Progress)) for n, p := range p.Progress { progress[n] = p } diff --git a/sql/processlist_test.go b/sql/processlist_test.go index a6fb3a44f..198f40b12 100644 --- a/sql/processlist_test.go +++ b/sql/processlist_test.go @@ -20,16 +20,16 @@ func TestProcessList(t *testing.T) { require.Equal(uint64(1), ctx.Pid()) require.Len(p.procs, 1) - p.AddProgressItem(ctx.Pid(), "a", 5) - p.AddProgressItem(ctx.Pid(), "b", 6) + p.AddTableProgress(ctx.Pid(), "a", 5) + p.AddTableProgress(ctx.Pid(), "b", 6) expectedProcess := &Process{ Pid: 1, Connection: 1, Type: QueryProcess, - Progress: map[string]Progress{ - "a": Progress{0, 5}, - "b": Progress{0, 6}, + Progress: map[string]TableProgress{ + "a": {Progress{Name: "a", Done: 0, Total: 5}, map[string]PartitionProgress{}}, + "b": {Progress{Name: "b", Done: 0, Total: 6}, map[string]PartitionProgress{}}, }, User: "foo", Query: "SELECT foo", @@ -39,19 +39,36 @@ func TestProcessList(t *testing.T) { p.procs[ctx.Pid()].Kill = nil require.Equal(expectedProcess, p.procs[ctx.Pid()]) + p.AddPartitionProgress(ctx.Pid(), "b", "b-1", -1) + p.AddPartitionProgress(ctx.Pid(), "b", "b-2", -1) + p.AddPartitionProgress(ctx.Pid(), "b", "b-3", -1) + + p.UpdatePartitionProgress(ctx.Pid(), "b", "b-2", 1) + + p.RemovePartitionProgress(ctx.Pid(), "b", "b-3") + + expectedProgress := map[string]TableProgress{ + "a": {Progress{Name: "a", Total: 5}, map[string]PartitionProgress{}}, + "b": {Progress{Name: "b", Total: 6}, map[string]PartitionProgress{ + "b-1": {Progress{Name: "b-1", Done: 0, Total: -1}}, + "b-2": {Progress{Name: "b-2", Done: 1, Total: -1}}, + }}, + } + require.Equal(expectedProgress, p.procs[ctx.Pid()].Progress) + ctx = NewContext(context.Background(), WithPid(2), WithSession(sess)) ctx, err = p.AddProcess(ctx, CreateIndexProcess, "SELECT bar") require.NoError(err) - p.AddProgressItem(ctx.Pid(), "foo", 2) + p.AddTableProgress(ctx.Pid(), "foo", 2) require.Equal(uint64(2), ctx.Pid()) require.Len(p.procs, 2) - p.UpdateProgress(1, "a", 3) - p.UpdateProgress(1, "a", 1) - p.UpdateProgress(1, "b", 2) - p.UpdateProgress(2, "foo", 1) + p.UpdateTableProgress(1, "a", 3) + p.UpdateTableProgress(1, "a", 1) + p.UpdateTableProgress(1, "b", 2) + p.UpdateTableProgress(2, "foo", 1) require.Equal(int64(4), p.procs[1].Progress["a"].Done) require.Equal(int64(2), p.procs[1].Progress["b"].Done) diff --git a/sql/session.go b/sql/session.go index 8123ed51e..251467ba0 100644 --- a/sql/session.go +++ b/sql/session.go @@ -211,10 +211,11 @@ func NewBaseSession() Session { type Context struct { context.Context Session - Memory *MemoryManager - pid uint64 - query string - tracer opentracing.Tracer + Memory *MemoryManager + pid uint64 + query string + tracer opentracing.Tracer + rootSpan opentracing.Span } // ContextOption is a function to configure the context. @@ -255,6 +256,13 @@ func WithMemoryManager(m *MemoryManager) ContextOption { } } +// WithRootSpan sets the root span of the context. +func WithRootSpan(s opentracing.Span) ContextOption { + return func(ctx *Context) { + ctx.rootSpan = s + } +} + // NewContext creates a new query context. Options can be passed to configure // the context. If some aspect of the context is not configure, the default // value will be used. @@ -264,7 +272,7 @@ func NewContext( ctx context.Context, opts ...ContextOption, ) *Context { - c := &Context{ctx, NewBaseSession(), nil, 0, "", opentracing.NoopTracer{}} + c := &Context{ctx, NewBaseSession(), nil, 0, "", opentracing.NoopTracer{}, nil} for _, opt := range opts { opt(c) } @@ -298,12 +306,17 @@ func (c *Context) Span( span := c.tracer.StartSpan(opName, opts...) ctx := opentracing.ContextWithSpan(c.Context, span) - return span, &Context{ctx, c.Session, c.Memory, c.Pid(), c.Query(), c.tracer} + return span, &Context{ctx, c.Session, c.Memory, c.Pid(), c.Query(), c.tracer, c.rootSpan} } // WithContext returns a new context with the given underlying context. func (c *Context) WithContext(ctx context.Context) *Context { - return &Context{ctx, c.Session, c.Memory, c.Pid(), c.Query(), c.tracer} + return &Context{ctx, c.Session, c.Memory, c.Pid(), c.Query(), c.tracer, c.rootSpan} +} + +// RootSpan returns the root span, if any. +func (c *Context) RootSpan() opentracing.Span { + return c.rootSpan } // Error adds an error as warning to the session. diff --git a/sql/type.go b/sql/type.go index caadd9254..a93e7d83e 100644 --- a/sql/type.go +++ b/sql/type.go @@ -124,6 +124,8 @@ type Column struct { Nullable bool // Source is the name of the table this column came from. Source string + // PrimaryKey is true if the column is part of the primary key for its table. + PrimaryKey bool } // Check ensures the value is correct for this column. @@ -184,6 +186,10 @@ var ( Int16 = numberT{t: sqltypes.Int16} // Uint16 is an unsigned integer of 16 bits Uint16 = numberT{t: sqltypes.Uint16} + // Int24 is an integer of 24 bits. + Int24 = numberT{t: sqltypes.Int24} + // Uint24 is an unsigned integer of 24 bits. + Uint24 = numberT{t: sqltypes.Uint24} // Int32 is an integer of 32 bits. Int32 = numberT{t: sqltypes.Int32} // Uint32 is an unsigned integer of 32 bits. @@ -246,12 +252,16 @@ func MysqlTypeToType(sql query.Type) (Type, error) { return Int16, nil case sqltypes.Uint16: return Uint16, nil + case sqltypes.Int24: + return Int24, nil + case sqltypes.Uint24: + return Uint24, nil case sqltypes.Int32: return Int32, nil - case sqltypes.Int64: - return Int64, nil case sqltypes.Uint32: return Uint32, nil + case sqltypes.Int64: + return Int64, nil case sqltypes.Uint64: return Uint64, nil case sqltypes.Float32: @@ -335,10 +345,18 @@ func (t numberT) SQL(v interface{}) (sqltypes.Value, error) { } switch t.t { + case sqltypes.Int8: + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil + case sqltypes.Int16: + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil case sqltypes.Int32: return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil case sqltypes.Int64: return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil + case sqltypes.Uint8: + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil + case sqltypes.Uint16: + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil case sqltypes.Uint32: return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil case sqltypes.Uint64: @@ -735,7 +753,6 @@ func (t charT) Compare(a interface{}, b interface{}) (int, error) { return strings.Compare(a.(string), b.(string)), nil } - type varCharT struct { length int } @@ -1155,12 +1172,12 @@ func IsNumber(t Type) bool { // IsSigned checks if t is a signed type. func IsSigned(t Type) bool { - return t == Int32 || t == Int64 + return t == Int8 || t == Int16 || t == Int32 || t == Int64 } // IsUnsigned checks if t is an unsigned type. func IsUnsigned(t Type) bool { - return t == Uint64 || t == Uint32 + return t == Uint8 || t == Uint16 || t == Uint32 || t == Uint64 } // IsInteger checks if t is a (U)Int32/64 type. @@ -1249,9 +1266,9 @@ func MySQLTypeName(t Type) string { case sqltypes.Date: return "DATE" case sqltypes.Char: - return "CHAR" + return fmt.Sprintf("CHAR(%v)", t.(charT).Capacity()) case sqltypes.VarChar: - return "VARCHAR" + return fmt.Sprintf("VARCHAR(%v)", t.(varCharT).Capacity()) case sqltypes.Text: return "TEXT" case sqltypes.Bit: diff --git a/sql/type_test.go b/sql/type_test.go index a73bf08be..0dab60260 100644 --- a/sql/type_test.go +++ b/sql/type_test.go @@ -46,16 +46,118 @@ func TestBoolean(t *testing.T) { eq(t, Boolean, false, false) } +// Test conversion of all types of numbers to the specified signed integer type +// in typ, where minusOne, zero and one are the expected values with the +// same type as typ +func testSignedInt(t *testing.T, typ Type, minusOne, zero, one interface{}) { + t.Helper() + + convert(t, typ, -1, minusOne) + convert(t, typ, int8(-1), minusOne) + convert(t, typ, int16(-1), minusOne) + convert(t, typ, int32(-1), minusOne) + convert(t, typ, int64(-1), minusOne) + convert(t, typ, 0, zero) + convert(t, typ, int8(0), zero) + convert(t, typ, int16(0), zero) + convert(t, typ, int32(0), zero) + convert(t, typ, int64(0), zero) + convert(t, typ, uint8(0), zero) + convert(t, typ, uint16(0), zero) + convert(t, typ, uint32(0), zero) + convert(t, typ, uint64(0), zero) + convert(t, typ, 1, one) + convert(t, typ, int8(1), one) + convert(t, typ, int16(1), one) + convert(t, typ, int32(1), one) + convert(t, typ, int64(1), one) + convert(t, typ, uint8(1), one) + convert(t, typ, uint16(1), one) + convert(t, typ, uint32(1), one) + convert(t, typ, uint64(1), one) + convert(t, typ, "-1", minusOne) + convert(t, typ, "0", zero) + convert(t, typ, "1", one) + convertErr(t, typ, "") + + lt(t, Int8, minusOne, one) + eq(t, Int8, zero, zero) + eq(t, Int8, minusOne, minusOne) + eq(t, Int8, one, one) + gt(t, Int8, one, minusOne) +} + +// Test conversion of all types of numbers to the specified unsigned integer +// type in typ, where zero and one are the expected values with the same type +// as typ. The expected errors when converting from negative numbers are also +// tested +func testUnsignedInt(t *testing.T, typ Type, zero, one interface{}) { + t.Helper() + + convertErr(t, typ, -1) + convertErr(t, typ, int8(-1)) + convertErr(t, typ, int16(-1)) + convertErr(t, typ, int32(-1)) + convertErr(t, typ, int64(-1)) + convert(t, typ, 0, zero) + convert(t, typ, int8(0), zero) + convert(t, typ, int16(0), zero) + convert(t, typ, int32(0), zero) + convert(t, typ, int64(0), zero) + convert(t, typ, uint8(0), zero) + convert(t, typ, uint16(0), zero) + convert(t, typ, uint32(0), zero) + convert(t, typ, uint64(0), zero) + convert(t, typ, 1, one) + convert(t, typ, int8(1), one) + convert(t, typ, int16(1), one) + convert(t, typ, int32(1), one) + convert(t, typ, int64(1), one) + convert(t, typ, uint8(1), one) + convert(t, typ, uint16(1), one) + convert(t, typ, uint32(1), one) + convert(t, typ, uint64(1), one) + convertErr(t, typ, "-1") + convert(t, typ, "0", zero) + convert(t, typ, "1", one) + convertErr(t, typ, "") + + lt(t, Int8, zero, one) + eq(t, Int8, zero, zero) + eq(t, Int8, one, one) + gt(t, Int8, one, zero) +} + +func TestInt8(t *testing.T) { + testSignedInt(t, Int8, int8(-1), int8(0), int8(1)) +} + +func TestInt16(t *testing.T) { + testSignedInt(t, Int16, int16(-1), int16(0), int16(1)) +} + func TestInt32(t *testing.T) { - convert(t, Int32, int32(1), int32(1)) - convert(t, Int32, 1, int32(1)) - convert(t, Int32, int64(1), int32(1)) - convert(t, Int32, "5", int32(5)) - convertErr(t, Int32, "") + testSignedInt(t, Int32, int32(-1), int32(0), int32(1)) +} - lt(t, Int32, int32(1), int32(2)) - eq(t, Int32, int32(1), int32(1)) - gt(t, Int32, int32(3), int32(2)) +func TestInt64(t *testing.T) { + testSignedInt(t, Int64, int64(-1), int64(0), int64(1)) +} + +func TestUint8(t *testing.T) { + testUnsignedInt(t, Uint8, uint8(0), uint8(1)) +} + +func TestUint16(t *testing.T) { + testUnsignedInt(t, Uint16, uint16(0), uint16(1)) +} + +func TestUint32(t *testing.T) { + testUnsignedInt(t, Uint32, uint32(0), uint32(1)) +} + +func TestUint64(t *testing.T) { + testUnsignedInt(t, Uint64, uint64(0), uint64(1)) } func TestNumberComparison(t *testing.T) { @@ -90,6 +192,8 @@ func TestNumberComparison(t *testing.T) { {Uint8, uint8(42)}, {Int16, int16(42)}, {Uint16, uint16(42)}, + {Int24, int32(42)}, + {Uint24, uint32(42)}, {Int32, int32(42)}, {Uint32, uint32(42)}, {Int64, int64(42)}, @@ -140,18 +244,6 @@ func TestNumberComparison(t *testing.T) { } } -func TestInt64(t *testing.T) { - convert(t, Int64, int32(1), int64(1)) - convert(t, Int64, 1, int64(1)) - convert(t, Int64, int64(1), int64(1)) - convertErr(t, Int64, "") - convert(t, Int64, "5", int64(5)) - - lt(t, Int64, int64(1), int64(2)) - eq(t, Int64, int64(1), int64(1)) - gt(t, Int64, int64(3), int64(2)) -} - func TestFloat64(t *testing.T) { require := require.New(t) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy