Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

sql/analyzer: refactor resolve_natural_joins rule #750

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ require (
github.com/stretchr/testify v1.2.2
go.etcd.io/bbolt v1.3.2
golang.org/x/net v0.0.0-20190227022144-312bce6e941f // indirect
google.golang.org/genproto v0.0.0-20180831171423-11092d34479b // indirect
google.golang.org/grpc v1.19.0 // indirect
gopkg.in/src-d/go-errors.v1 v1.0.0
gopkg.in/yaml.v2 v2.2.2
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekf
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.0 h1:kbxbvI4Un1LUWKxufD+BiE6AEExYYgkQLQmLFqA1LFk=
github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0=
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c h1:964Od4U6p2jUkFxvCydnIczKteheJEzHRToSGK3Bnlw=
Expand Down Expand Up @@ -137,6 +139,7 @@ golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190227022144-312bce6e941f h1:tbtX/qtlxzhZjgQue/7u7ygFwDEckd+DmS5+t8FgeKE=
golang.org/x/net v0.0.0-20190227022144-312bce6e941f/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
Expand Down
321 changes: 96 additions & 225 deletions sql/analyzer/resolve_natural_joins.go
Original file line number Diff line number Diff line change
@@ -1,266 +1,137 @@
package analyzer

import (
"reflect"
"strings"

"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/plan"
)

// transformedJoin records one natural join that has been rewritten by
// resolveNaturalJoins, together with the column bookkeeping needed to fix up
// references to that join's columns afterwards.
type transformedJoin struct {
// node is the plan node that replaced the natural join: a projection placed
// on top of an inner join of the original left and right children.
node sql.Node
// condCols is keyed by column name and holds, for every column shared by
// both sides of the join, which source is kept and which ones are dropped.
condCols map[string]*transformedSource
}

// transformedSource describes where a column shared by both sides of a
// natural join must be read from once the join has been rewritten.
type transformedSource struct {
// correct is the source (table name) of the left-hand column; references to
// the shared column are re-pointed to this source.
correct string
// wrong lists the sources of the matching right-hand columns; a reference
// qualified with one of these (or with an alias of one) has to be
// unresolved and re-qualified with correct.
wrong []string
}

func resolveNaturalJoins(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
span, _ := ctx.Span("resolve_natural_joins")
defer span.Finish()

if n.Resolved() {
return n, nil
}

var transformed []*transformedJoin
var aliasTables = map[string][]string{}
var colsToUnresolve = map[string]*transformedSource{}
a.Log("resolving natural joins, node of type %T", n)
node, err := n.TransformUp(func(n sql.Node) (sql.Node, error) {
a.Log("transforming node of type: %T", n)
var replacements = make(map[tableCol]tableCol)
var tableAliases = make(map[string]string)

if alias, ok := n.(*plan.TableAlias); ok {
table := alias.Child.(*plan.ResolvedTable).Name()
aliasTables[alias.Name()] = append(aliasTables[alias.Name()], table)
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.TableAlias:
alias := n.Name()
table := n.Child.(*plan.ResolvedTable).Name()
tableAliases[strings.ToLower(alias)] = table
return n, nil
}

if n.Resolved() {
case *plan.NaturalJoin:
return resolveNaturalJoin(n, replacements)
case sql.Expressioner:
return replaceExpressions(n, replacements, tableAliases)
default:
return n, nil
}
})
}

join, ok := n.(*plan.NaturalJoin)
if !ok {
return n, nil
}

// we need both leaves resolved before resolving this one
if !join.Left.Resolved() || !join.Right.Resolved() {
return n, nil
}

leftSchema, rightSchema := join.Left.Schema(), join.Right.Schema()

var conditions, common, left, right []sql.Expression
var seen = make(map[string]struct{})

for i, lcol := range leftSchema {
var found bool
leftCol := expression.NewGetFieldWithTable(
i,
lcol.Type,
lcol.Source,
lcol.Name,
lcol.Nullable,
)

for j, rcol := range rightSchema {
if lcol.Name == rcol.Name {
common = append(common, leftCol)

conditions = append(
conditions,
expression.NewEquals(
leftCol,
expression.NewGetFieldWithTable(
len(leftSchema)+j,
rcol.Type,
rcol.Source,
rcol.Name,
rcol.Nullable,
),
),
)

found = true
seen[lcol.Name] = struct{}{}
if source, ok := colsToUnresolve[lcol.Name]; ok {
source.correct = lcol.Source
source.wrong = append(source.wrong, rcol.Source)
} else {
colsToUnresolve[lcol.Name] = &transformedSource{
correct: lcol.Source,
wrong: []string{rcol.Source},
}
}

break
}
}
func resolveNaturalJoin(
n *plan.NaturalJoin,
replacements map[tableCol]tableCol,
) (sql.Node, error) {
// Both sides of the natural join need to be resolved in order to resolve
// the natural join itself.
if !n.Left.Resolved() || !n.Right.Resolved() {
return n, nil
}

if !found {
left = append(left, leftCol)
leftSchema := n.Left.Schema()
rightSchema := n.Right.Schema()

var conditions, common, left, right []sql.Expression
for i, lcol := range leftSchema {
leftCol := expression.NewGetFieldWithTable(
i,
lcol.Type,
lcol.Source,
lcol.Name,
lcol.Nullable,
)
if idx, rcol := findCol(rightSchema, lcol.Name); rcol != nil {
common = append(common, leftCol)
replacements[tableCol{strings.ToLower(rcol.Source), strings.ToLower(rcol.Name)}] = tableCol{
strings.ToLower(lcol.Source), strings.ToLower(lcol.Name),
}
}

if len(conditions) == 0 {
return plan.NewCrossJoin(join.Left, join.Right), nil
}

for i, col := range rightSchema {
if _, ok := seen[col.Name]; !ok {
right = append(
right,
conditions = append(
conditions,
expression.NewEquals(
leftCol,
expression.NewGetFieldWithTable(
len(leftSchema)+i,
col.Type,
col.Source,
col.Name,
col.Nullable,
len(leftSchema)+idx,
rcol.Type,
rcol.Source,
rcol.Name,
rcol.Nullable,
),
)
}
}

projections := append(append(common, left...), right...)

tj := &transformedJoin{
node: plan.NewProject(
projections,
plan.NewInnerJoin(
join.Left,
join.Right,
expression.JoinAnd(conditions...),
),
),
condCols: colsToUnresolve,
)
} else {
left = append(left, leftCol)
}

transformed = append(transformed, tj)

return tj.node, nil
})

if err != nil || len(transformed) == 0 {
return node, err
}

var transformedSeen bool
return node.TransformUp(func(node sql.Node) (sql.Node, error) {
if ok, _ := isTransformedNode(node, transformed); ok {
transformedSeen = true
return node, nil
}

if !transformedSeen {
return node, nil
}

expressioner, ok := node.(sql.Expressioner)
if !ok {
return node, nil
}

return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
var col, table string
switch e := e.(type) {
case *expression.GetField:
col, table = e.Name(), e.Table()
case *expression.UnresolvedColumn:
col, table = e.Name(), e.Table()
default:
return e, nil
}

sources, ok := colsToUnresolve[col]
if !ok {
return e, nil
}

if !mustUnresolve(aliasTables, table, sources.wrong) {
return e, nil
}

return expression.NewUnresolvedQualifiedColumn(
sources.correct,
col,
), nil
})
})
}

func isTransformedNode(node sql.Node, transformed []*transformedJoin) (is bool, colsToUnresolve map[string]*transformedSource) {
var project *plan.Project
var join *plan.InnerJoin
switch n := node.(type) {
case *plan.Project:
var ok bool
join, ok = n.Child.(*plan.InnerJoin)
if !ok {
return
}

project = n
case *plan.InnerJoin:
join = n

default:
return
if len(conditions) == 0 {
return plan.NewCrossJoin(n.Left, n.Right), nil
}

for _, t := range transformed {
tproject, ok := t.node.(*plan.Project)
if !ok {
return
}

tjoin, ok := tproject.Child.(*plan.InnerJoin)
if !ok {
return
}

if project != nil && !reflect.DeepEqual(project.Projections, tproject.Projections) {
continue
}

if reflect.DeepEqual(join.Cond, tjoin.Cond) {
is = true
colsToUnresolve = t.condCols
for i, col := range rightSchema {
source := strings.ToLower(col.Source)
name := strings.ToLower(col.Name)
if _, ok := replacements[tableCol{source, name}]; !ok {
right = append(
right,
expression.NewGetFieldWithTable(
len(leftSchema)+i,
col.Type,
col.Source,
col.Name,
col.Nullable,
),
)
}
}

return
return plan.NewProject(
append(append(common, left...), right...),
plan.NewInnerJoin(n.Left, n.Right, expression.JoinAnd(conditions...)),
), nil
}

func mustUnresolve(aliasTable map[string][]string, table string, wrongSources []string) bool {
return isIn(table, wrongSources) || isAliasFor(aliasTable, table, wrongSources)
}

func isIn(s string, l []string) bool {
for _, e := range l {
if s == e {
return true
func findCol(s sql.Schema, name string) (int, *sql.Column) {
for i, c := range s {
if strings.ToLower(c.Name) == strings.ToLower(name) {
return i, c
}
}

return false
return -1, nil
}

func isAliasFor(aliasTable map[string][]string, table string, wrongSources []string) bool {
tables, ok := aliasTable[table]
if !ok {
return false
}
func replaceExpressions(
n sql.Expressioner,
replacements map[tableCol]tableCol,
tableAliases map[string]string,
) (sql.Node, error) {
return n.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
switch e := e.(type) {
case *expression.GetField, *expression.UnresolvedColumn:
var tableName = e.(sql.Tableable).Table()
if t, ok := tableAliases[strings.ToLower(tableName)]; ok {
tableName = t
}

for _, t := range tables {
if isIn(t, wrongSources) {
return true
name := e.(sql.Nameable).Name()
if col, ok := replacements[tableCol{strings.ToLower(tableName), strings.ToLower(name)}]; ok {
return expression.NewUnresolvedQualifiedColumn(col.table, col.col), nil
}
}
}

return false
return e, nil
})
}
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy