Skip to content

Commit 6d9f4c5

Browse files
committed
feat: opportunistically listen on IPv6 in port-forward
1 parent f6c3f0a commit 6d9f4c5

File tree

3 files changed

+169
-67
lines changed

3 files changed

+169
-67
lines changed

cli/portforward.go

Lines changed: 70 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ import (
2525
"github.com/coder/serpent"
2626
)
2727

28+
var (
29+
// noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify
30+
// when the local address is not specified in port-forward flags.
31+
noAddr netip.Addr
32+
ipv6Loopback = netip.MustParseAddr("::1")
33+
ipv4Loopback = netip.MustParseAddr("127.0.0.1")
34+
)
35+
2836
func (r *RootCmd) portForward() *serpent.Command {
2937
var (
3038
tcpForwards []string // <port>:<port>
@@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command {
122130
// Start all listeners.
123131
var (
124132
wg = new(sync.WaitGroup)
125-
listeners = make([]net.Listener, len(specs))
133+
listeners = make([]net.Listener, 0, len(specs)*2)
126134
closeAllListeners = func() {
127135
logger.Debug(ctx, "closing all listeners")
128136
for _, l := range listeners {
@@ -135,13 +143,25 @@ func (r *RootCmd) portForward() *serpent.Command {
135143
)
136144
defer closeAllListeners()
137145

138-
for i, spec := range specs {
146+
for _, spec := range specs {
147+
if spec.listenHost == noAddr {
148+
// first, opportunistically try to listen on IPv6
149+
spec6 := spec
150+
spec6.listenHost = ipv6Loopback
151+
l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger)
152+
if err6 != nil {
153+
logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6))
154+
} else {
155+
listeners = append(listeners, l6)
156+
}
157+
spec.listenHost = ipv4Loopback
158+
}
139159
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger)
140160
if err != nil {
141161
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
142162
return err
143163
}
144-
listeners[i] = l
164+
listeners = append(listeners, l)
145165
}
146166

147167
stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID)
@@ -206,12 +226,19 @@ func listenAndPortForward(
206226
spec portForwardSpec,
207227
logger slog.Logger,
208228
) (net.Listener, error) {
209-
logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress))
210-
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
229+
logger = logger.With(
230+
slog.F("network", spec.network),
231+
slog.F("listen_host", spec.listenHost),
232+
slog.F("listen_port", spec.listenPort),
233+
)
234+
listenAddress := netip.AddrPortFrom(spec.listenHost, spec.listenPort)
235+
dialAddress := fmt.Sprintf("127.0.0.1:%d", spec.dialPort)
236+
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n",
237+
spec.network, listenAddress, spec.network, dialAddress)
211238

212-
l, err := inv.Net.Listen(spec.listenNetwork, spec.listenAddress)
239+
l, err := inv.Net.Listen(spec.network, listenAddress.String())
213240
if err != nil {
214-
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
241+
return nil, xerrors.Errorf("listen '%s://%s': %w", spec.network, listenAddress.String(), err)
215242
}
216243
logger.Debug(ctx, "listening")
217244

@@ -226,24 +253,31 @@ func listenAndPortForward(
226253
logger.Debug(ctx, "listener closed")
227254
return
228255
}
229-
_, _ = fmt.Fprintf(inv.Stderr, "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err)
256+
_, _ = fmt.Fprintf(inv.Stderr,
257+
"Error accepting connection from '%s://%s': %v\n",
258+
spec.network, listenAddress.String(), err)
230259
_, _ = fmt.Fprintln(inv.Stderr, "Killing listener")
231260
return
232261
}
233-
logger.Debug(ctx, "accepted connection", slog.F("remote_addr", netConn.RemoteAddr()))
262+
logger.Debug(ctx, "accepted connection",
263+
slog.F("remote_addr", netConn.RemoteAddr()))
234264

235265
go func(netConn net.Conn) {
236266
defer netConn.Close()
237-
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
267+
remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress)
238268
if err != nil {
239-
_, _ = fmt.Fprintf(inv.Stderr, "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
269+
_, _ = fmt.Fprintf(inv.Stderr,
270+
"Failed to dial '%s://%s' in workspace: %s\n",
271+
spec.network, dialAddress, err)
240272
return
241273
}
242274
defer remoteConn.Close()
243-
logger.Debug(ctx, "dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
275+
logger.Debug(ctx,
276+
"dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
244277

245278
agentssh.Bicopy(ctx, netConn, remoteConn)
246-
logger.Debug(ctx, "connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
279+
logger.Debug(ctx,
280+
"connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
247281
}(netConn)
248282
}
249283
}(spec)
@@ -252,58 +286,48 @@ func listenAndPortForward(
252286
}
253287

254288
type portForwardSpec struct {
255-
listenNetwork string // tcp, udp
256-
listenAddress string // <ip>:<port> or path
257-
258-
dialNetwork string // tcp, udp
259-
dialAddress string // <ip>:<port> or path
289+
network string // tcp, udp
290+
listenHost netip.Addr
291+
listenPort, dialPort uint16
260292
}
261293

262294
func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
263295
specs := []portForwardSpec{}
264296

265297
for _, specEntry := range tcpSpecs {
266298
for _, spec := range strings.Split(specEntry, ",") {
267-
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
299+
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
268300
if err != nil {
269301
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
270302
}
271303

272-
for _, port := range ports {
273-
specs = append(specs, portForwardSpec{
274-
listenNetwork: "tcp",
275-
listenAddress: port.local.String(),
276-
dialNetwork: "tcp",
277-
dialAddress: port.remote.String(),
278-
})
304+
for _, pfSpec := range pfSpecs {
305+
pfSpec.network = "tcp"
306+
specs = append(specs, pfSpec)
279307
}
280308
}
281309
}
282310

283311
for _, specEntry := range udpSpecs {
284312
for _, spec := range strings.Split(specEntry, ",") {
285-
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
313+
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
286314
if err != nil {
287315
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
288316
}
289317

290-
for _, port := range ports {
291-
specs = append(specs, portForwardSpec{
292-
listenNetwork: "udp",
293-
listenAddress: port.local.String(),
294-
dialNetwork: "udp",
295-
dialAddress: port.remote.String(),
296-
})
318+
for _, pfSpec := range pfSpecs {
319+
pfSpec.network = "udp"
320+
specs = append(specs, pfSpec)
297321
}
298322
}
299323
}
300324

301325
// Check for duplicate entries.
302326
locals := map[string]struct{}{}
303327
for _, spec := range specs {
304-
localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress)
328+
localStr := fmt.Sprintf("%s:%s:%d", spec.network, spec.listenHost, spec.listenPort)
305329
if _, ok := locals[localStr]; ok {
306-
return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress)
330+
return nil, xerrors.Errorf("local %s host:%s port:%d is specified twice", spec.network, spec.listenHost, spec.listenPort)
307331
}
308332
locals[localStr] = struct{}{}
309333
}
@@ -323,10 +347,6 @@ func parsePort(in string) (uint16, error) {
323347
return uint16(port), nil
324348
}
325349

326-
type parsedSrcDestPort struct {
327-
local, remote netip.AddrPort
328-
}
329-
330350
// specRegexp matches port specs. It handles all the following formats:
331351
//
332352
// 8000
@@ -347,21 +367,19 @@ type parsedSrcDestPort struct {
347367
// 9: end or remote port range
348368
var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`)
349369

350-
func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
351-
var (
352-
err error
353-
localAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
354-
remoteAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
355-
)
370+
func parseSrcDestPorts(in string) ([]portForwardSpec, error) {
356371
groups := specRegexp.FindStringSubmatch(in)
357372
if len(groups) == 0 {
358373
return nil, xerrors.Errorf("invalid port specification %q", in)
359374
}
375+
376+
var localAddr netip.Addr
360377
if groups[2] != "" {
361-
localAddr, err = netip.ParseAddr(strings.Trim(groups[2], "[]"))
378+
parsedAddr, err := netip.ParseAddr(strings.Trim(groups[2], "[]"))
362379
if err != nil {
363380
return nil, xerrors.Errorf("invalid IP address %q", groups[2])
364381
}
382+
localAddr = parsedAddr
365383
}
366384

367385
local, err := parsePortRange(groups[3], groups[5])
@@ -378,11 +396,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
378396
if len(local) != len(remote) {
379397
return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote))
380398
}
381-
var out []parsedSrcDestPort
399+
var out []portForwardSpec
382400
for i := range local {
383-
out = append(out, parsedSrcDestPort{
384-
local: netip.AddrPortFrom(localAddr, local[i]),
385-
remote: netip.AddrPortFrom(remoteAddr, remote[i]),
401+
out = append(out, portForwardSpec{
402+
listenHost: localAddr,
403+
listenPort: local[i],
404+
dialPort: remote[i],
386405
})
387406
}
388407
return out, nil

cli/portforward_internal_test.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ func Test_parsePortForwards(t *testing.T) {
2929
},
3030
},
3131
want: []portForwardSpec{
32-
{"tcp", "127.0.0.1:8000", "tcp", "127.0.0.1:8000"},
33-
{"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"},
34-
{"tcp", "127.0.0.1:9000", "tcp", "127.0.0.1:9000"},
35-
{"tcp", "127.0.0.1:9001", "tcp", "127.0.0.1:9001"},
36-
{"tcp", "127.0.0.1:9002", "tcp", "127.0.0.1:9002"},
37-
{"tcp", "127.0.0.1:9003", "tcp", "127.0.0.1:9005"},
38-
{"tcp", "127.0.0.1:9004", "tcp", "127.0.0.1:9006"},
39-
{"tcp", "127.0.0.1:10000", "tcp", "127.0.0.1:10000"},
40-
{"tcp", "127.0.0.1:4444", "tcp", "127.0.0.1:4444"},
32+
{"tcp", noAddr, 8000, 8000},
33+
{"tcp", noAddr, 8080, 8081},
34+
{"tcp", noAddr, 9000, 9000},
35+
{"tcp", noAddr, 9001, 9001},
36+
{"tcp", noAddr, 9002, 9002},
37+
{"tcp", noAddr, 9003, 9005},
38+
{"tcp", noAddr, 9004, 9006},
39+
{"tcp", noAddr, 10000, 10000},
40+
{"tcp", noAddr, 4444, 4444},
4141
},
4242
},
4343
{
@@ -46,7 +46,7 @@ func Test_parsePortForwards(t *testing.T) {
4646
tcpSpecs: []string{"127.0.0.1:8080:8081"},
4747
},
4848
want: []portForwardSpec{
49-
{"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"},
49+
{"tcp", ipv4Loopback, 8080, 8081},
5050
},
5151
},
5252
{
@@ -55,7 +55,7 @@ func Test_parsePortForwards(t *testing.T) {
5555
tcpSpecs: []string{"[::1]:8080:8081"},
5656
},
5757
want: []portForwardSpec{
58-
{"tcp", "[::1]:8080", "tcp", "127.0.0.1:8081"},
58+
{"tcp", ipv6Loopback, 8080, 8081},
5959
},
6060
},
6161
{
@@ -64,9 +64,9 @@ func Test_parsePortForwards(t *testing.T) {
6464
udpSpecs: []string{"8000,8080-8081"},
6565
},
6666
want: []portForwardSpec{
67-
{"udp", "127.0.0.1:8000", "udp", "127.0.0.1:8000"},
68-
{"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8080"},
69-
{"udp", "127.0.0.1:8081", "udp", "127.0.0.1:8081"},
67+
{"udp", noAddr, 8000, 8000},
68+
{"udp", noAddr, 8080, 8080},
69+
{"udp", noAddr, 8081, 8081},
7070
},
7171
},
7272
{
@@ -75,7 +75,7 @@ func Test_parsePortForwards(t *testing.T) {
7575
udpSpecs: []string{"127.0.0.1:8080:8081"},
7676
},
7777
want: []portForwardSpec{
78-
{"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8081"},
78+
{"udp", ipv4Loopback, 8080, 8081},
7979
},
8080
},
8181
{
@@ -84,7 +84,7 @@ func Test_parsePortForwards(t *testing.T) {
8484
udpSpecs: []string{"[::1]:8080:8081"},
8585
},
8686
want: []portForwardSpec{
87-
{"udp", "[::1]:8080", "udp", "127.0.0.1:8081"},
87+
{"udp", ipv6Loopback, 8080, 8081},
8888
},
8989
},
9090
{

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy