@@ -29,12 +29,7 @@ func pruneColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
29
29
30
30
findUsedColumns (columns , n )
31
31
32
- n , err := addSubqueryBarriers (n )
33
- if err != nil {
34
- return nil , err
35
- }
36
-
37
- n , err = pruneUnusedColumns (n , columns )
32
+ n , err := pruneUnusedColumns (n , columns )
38
33
if err != nil {
39
34
return nil , err
40
35
}
@@ -81,12 +76,7 @@ func pruneSubqueryColumns(
81
76
82
77
findUsedColumns (columns , n .Child )
83
78
84
- node , err := addSubqueryBarriers (n .Child )
85
- if err != nil {
86
- return nil , err
87
- }
88
-
89
- node , err = pruneUnusedColumns (node , columns )
79
+ node , err := pruneUnusedColumns (n .Child , columns )
90
80
if err != nil {
91
81
return nil , err
92
82
}
@@ -126,30 +116,19 @@ func findUsedColumns(columns usedColumns, n sql.Node) {
126
116
})
127
117
}
128
118
129
- func addSubqueryBarriers (n sql.Node ) (sql.Node , error ) {
130
- return n .TransformUp (func (n sql.Node ) (sql.Node , error ) {
131
- sq , ok := n .(* plan.SubqueryAlias )
132
- if ! ok {
133
- return n , nil
134
- }
135
-
136
- return & subqueryBarrier {sq }, nil
137
- })
138
- }
139
-
140
119
func pruneSubqueries (
141
120
ctx * sql.Context ,
142
121
a * Analyzer ,
143
122
n sql.Node ,
144
123
parentColumns usedColumns ,
145
124
) (sql.Node , error ) {
146
125
return n .TransformUp (func (n sql.Node ) (sql.Node , error ) {
147
- barrier , ok := n .(* subqueryBarrier )
126
+ subq , ok := n .(* plan. SubqueryAlias )
148
127
if ! ok {
149
128
return n , nil
150
129
}
151
130
152
- return pruneSubqueryColumns (ctx , a , barrier . SubqueryAlias , parentColumns )
131
+ return pruneSubqueryColumns (ctx , a , subq , parentColumns )
153
132
})
154
133
}
155
134
@@ -173,39 +152,53 @@ type tableColumnPair struct {
173
152
174
153
func fixRemainingFieldsIndexes (n sql.Node ) (sql.Node , error ) {
175
154
return n .TransformUp (func (n sql.Node ) (sql.Node , error ) {
176
- exp , ok := n .(sql.Expressioner )
177
- if ! ok {
178
- return n , nil
179
- }
180
-
181
- var schema sql.Schema
182
- for _ , c := range n .Children () {
183
- schema = append (schema , c .Schema ()... )
184
- }
155
+ switch n := n .(type ) {
156
+ case * plan.SubqueryAlias :
157
+ child , err := fixRemainingFieldsIndexes (n .Child )
158
+ if err != nil {
159
+ return nil , err
160
+ }
185
161
186
- if len (schema ) == 0 {
187
- return n , nil
188
- }
162
+ return plan .NewSubqueryAlias (n .Name (), child ), nil
163
+ default :
164
+ exp , ok := n .(sql.Expressioner )
165
+ if ! ok {
166
+ return n , nil
167
+ }
189
168
190
- indexes := make ( map [ tableColumnPair ] int )
191
- for i , col := range schema {
192
- indexes [ tableColumnPair { col . Source , col . Name }] = i
193
- }
169
+ var schema sql. Schema
170
+ for _ , c := range n . Children () {
171
+ schema = append ( schema , c . Schema () ... )
172
+ }
194
173
195
- return exp .TransformExpressions (func (e sql.Expression ) (sql.Expression , error ) {
196
- gf , ok := e .(* expression.GetField )
197
- if ! ok {
198
- return e , nil
174
+ if len (schema ) == 0 {
175
+ return n , nil
199
176
}
200
177
201
- idx , ok := indexes [tableColumnPair { gf . Table (), gf . Name ()}]
202
- if ! ok {
203
- return nil , fmt . Errorf ( "unable to find column %q of table %q" , gf .Name (), gf . Table ())
178
+ indexes := make ( map [tableColumnPair ] int )
179
+ for i , col := range schema {
180
+ indexes [ tableColumnPair { col . Source , col .Name }] = i
204
181
}
205
182
206
- ngf := * gf
207
- return ngf .WithIndex (idx ), nil
208
- })
183
+ return exp .TransformExpressions (func (e sql.Expression ) (sql.Expression , error ) {
184
+ gf , ok := e .(* expression.GetField )
185
+ if ! ok {
186
+ return e , nil
187
+ }
188
+
189
+ idx , ok := indexes [tableColumnPair {gf .Table (), gf .Name ()}]
190
+ if ! ok {
191
+ return nil , fmt .Errorf ("unable to find column %q of table %q" , gf .Name (), gf .Table ())
192
+ }
193
+
194
+ if idx == gf .Index () {
195
+ return gf , nil
196
+ }
197
+
198
+ ngf := * gf
199
+ return ngf .WithIndex (idx ), nil
200
+ })
201
+ }
209
202
})
210
203
}
211
204
@@ -290,11 +283,3 @@ func shouldPruneExpr(e sql.Expression, cols usedColumns) bool {
290
283
291
284
return true
292
285
}
293
-
294
- type subqueryBarrier struct {
295
- * plan.SubqueryAlias
296
- }
297
-
298
- func (b * subqueryBarrier ) TransformUp (f sql.TransformNodeFunc ) (sql.Node , error ) {
299
- return f (b )
300
- }
0 commit comments