Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 70 additions & 51 deletions cli/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ import (
"github.com/coder/serpent"
)

var (
// noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify
// when the local address is not specified in port-forward flags.
noAddr netip.Addr
ipv6Loopback = netip.MustParseAddr("::1")
ipv4Loopback = netip.MustParseAddr("127.0.0.1")
)

func (r *RootCmd) portForward() *serpent.Command {
var (
tcpForwards []string // <port>:<port>
Expand Down Expand Up @@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command {
// Start all listeners.
var (
wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs))
listeners = make([]net.Listener, 0, len(specs)*2)
closeAllListeners = func() {
logger.Debug(ctx, "closing all listeners")
for _, l := range listeners {
Expand All @@ -135,13 +143,25 @@ func (r *RootCmd) portForward() *serpent.Command {
)
defer closeAllListeners()

for i, spec := range specs {
for _, spec := range specs {
if spec.listenHost == noAddr {
// first, opportunistically try to listen on IPv6
spec6 := spec
spec6.listenHost = ipv6Loopback
l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger)
if err6 != nil {
logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6))
} else {
listeners = append(listeners, l6)
}
spec.listenHost = ipv4Loopback
}
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger)
if err != nil {
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
return err
}
listeners[i] = l
listeners = append(listeners, l)
}

stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID)
Expand Down Expand Up @@ -206,12 +226,19 @@ func listenAndPortForward(
spec portForwardSpec,
logger slog.Logger,
) (net.Listener, error) {
logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress))
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
logger = logger.With(
slog.F("network", spec.network),
slog.F("listen_host", spec.listenHost),
slog.F("listen_port", spec.listenPort),
)
listenAddress := netip.AddrPortFrom(spec.listenHost, spec.listenPort)
dialAddress := fmt.Sprintf("127.0.0.1:%d", spec.dialPort)
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n",
spec.network, listenAddress, spec.network, dialAddress)

l, err := inv.Net.Listen(spec.listenNetwork, spec.listenAddress)
l, err := inv.Net.Listen(spec.network, listenAddress.String())
if err != nil {
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
return nil, xerrors.Errorf("listen '%s://%s': %w", spec.network, listenAddress.String(), err)
}
logger.Debug(ctx, "listening")

Expand All @@ -226,24 +253,31 @@ func listenAndPortForward(
logger.Debug(ctx, "listener closed")
return
}
_, _ = fmt.Fprintf(inv.Stderr, "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err)
_, _ = fmt.Fprintf(inv.Stderr,
"Error accepting connection from '%s://%s': %v\n",
spec.network, listenAddress.String(), err)
_, _ = fmt.Fprintln(inv.Stderr, "Killing listener")
return
}
logger.Debug(ctx, "accepted connection", slog.F("remote_addr", netConn.RemoteAddr()))
logger.Debug(ctx, "accepted connection",
slog.F("remote_addr", netConn.RemoteAddr()))

go func(netConn net.Conn) {
defer netConn.Close()
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress)
if err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
_, _ = fmt.Fprintf(inv.Stderr,
"Failed to dial '%s://%s' in workspace: %s\n",
spec.network, dialAddress, err)
return
}
defer remoteConn.Close()
logger.Debug(ctx, "dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
logger.Debug(ctx,
"dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))

agentssh.Bicopy(ctx, netConn, remoteConn)
logger.Debug(ctx, "connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
logger.Debug(ctx,
"connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
}(netConn)
}
}(spec)
Expand All @@ -252,58 +286,48 @@ func listenAndPortForward(
}

type portForwardSpec struct {
listenNetwork string // tcp, udp
listenAddress string // <ip>:<port> or path

dialNetwork string // tcp, udp
dialAddress string // <ip>:<port> or path
network string // tcp, udp
listenHost netip.Addr
listenPort, dialPort uint16
}

func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
specs := []portForwardSpec{}

for _, specEntry := range tcpSpecs {
for _, spec := range strings.Split(specEntry, ",") {
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
if err != nil {
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
}

for _, port := range ports {
specs = append(specs, portForwardSpec{
listenNetwork: "tcp",
listenAddress: port.local.String(),
dialNetwork: "tcp",
dialAddress: port.remote.String(),
})
for _, pfSpec := range pfSpecs {
pfSpec.network = "tcp"
specs = append(specs, pfSpec)
}
}
}

for _, specEntry := range udpSpecs {
for _, spec := range strings.Split(specEntry, ",") {
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
if err != nil {
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
}

for _, port := range ports {
specs = append(specs, portForwardSpec{
listenNetwork: "udp",
listenAddress: port.local.String(),
dialNetwork: "udp",
dialAddress: port.remote.String(),
})
for _, pfSpec := range pfSpecs {
pfSpec.network = "udp"
specs = append(specs, pfSpec)
}
}
}

// Check for duplicate entries.
locals := map[string]struct{}{}
for _, spec := range specs {
localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress)
localStr := fmt.Sprintf("%s:%s:%d", spec.network, spec.listenHost, spec.listenPort)
if _, ok := locals[localStr]; ok {
return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress)
return nil, xerrors.Errorf("local %s host:%s port:%d is specified twice", spec.network, spec.listenHost, spec.listenPort)
}
locals[localStr] = struct{}{}
}
Expand All @@ -323,10 +347,6 @@ func parsePort(in string) (uint16, error) {
return uint16(port), nil
}

type parsedSrcDestPort struct {
local, remote netip.AddrPort
}

// specRegexp matches port specs. It handles all the following formats:
//
// 8000
Expand All @@ -347,21 +367,19 @@ type parsedSrcDestPort struct {
// 9: end or remote port range
var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`)

func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
var (
err error
localAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
remoteAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
)
func parseSrcDestPorts(in string) ([]portForwardSpec, error) {
groups := specRegexp.FindStringSubmatch(in)
if len(groups) == 0 {
return nil, xerrors.Errorf("invalid port specification %q", in)
}

var localAddr netip.Addr
if groups[2] != "" {
localAddr, err = netip.ParseAddr(strings.Trim(groups[2], "[]"))
parsedAddr, err := netip.ParseAddr(strings.Trim(groups[2], "[]"))
if err != nil {
return nil, xerrors.Errorf("invalid IP address %q", groups[2])
}
localAddr = parsedAddr
}

local, err := parsePortRange(groups[3], groups[5])
Expand All @@ -378,11 +396,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
if len(local) != len(remote) {
return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote))
}
var out []parsedSrcDestPort
var out []portForwardSpec
for i := range local {
out = append(out, parsedSrcDestPort{
local: netip.AddrPortFrom(localAddr, local[i]),
remote: netip.AddrPortFrom(remoteAddr, remote[i]),
out = append(out, portForwardSpec{
listenHost: localAddr,
listenPort: local[i],
dialPort: remote[i],
})
}
return out, nil
Expand Down
32 changes: 16 additions & 16 deletions cli/portforward_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ func Test_parsePortForwards(t *testing.T) {
},
},
want: []portForwardSpec{
{"tcp", "127.0.0.1:8000", "tcp", "127.0.0.1:8000"},
{"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"},
{"tcp", "127.0.0.1:9000", "tcp", "127.0.0.1:9000"},
{"tcp", "127.0.0.1:9001", "tcp", "127.0.0.1:9001"},
{"tcp", "127.0.0.1:9002", "tcp", "127.0.0.1:9002"},
{"tcp", "127.0.0.1:9003", "tcp", "127.0.0.1:9005"},
{"tcp", "127.0.0.1:9004", "tcp", "127.0.0.1:9006"},
{"tcp", "127.0.0.1:10000", "tcp", "127.0.0.1:10000"},
{"tcp", "127.0.0.1:4444", "tcp", "127.0.0.1:4444"},
{"tcp", noAddr, 8000, 8000},
{"tcp", noAddr, 8080, 8081},
{"tcp", noAddr, 9000, 9000},
{"tcp", noAddr, 9001, 9001},
{"tcp", noAddr, 9002, 9002},
{"tcp", noAddr, 9003, 9005},
{"tcp", noAddr, 9004, 9006},
{"tcp", noAddr, 10000, 10000},
{"tcp", noAddr, 4444, 4444},
},
},
{
Expand All @@ -46,7 +46,7 @@ func Test_parsePortForwards(t *testing.T) {
tcpSpecs: []string{"127.0.0.1:8080:8081"},
},
want: []portForwardSpec{
{"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"},
{"tcp", ipv4Loopback, 8080, 8081},
},
},
{
Expand All @@ -55,7 +55,7 @@ func Test_parsePortForwards(t *testing.T) {
tcpSpecs: []string{"[::1]:8080:8081"},
},
want: []portForwardSpec{
{"tcp", "[::1]:8080", "tcp", "127.0.0.1:8081"},
{"tcp", ipv6Loopback, 8080, 8081},
},
},
{
Expand All @@ -64,9 +64,9 @@ func Test_parsePortForwards(t *testing.T) {
udpSpecs: []string{"8000,8080-8081"},
},
want: []portForwardSpec{
{"udp", "127.0.0.1:8000", "udp", "127.0.0.1:8000"},
{"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8080"},
{"udp", "127.0.0.1:8081", "udp", "127.0.0.1:8081"},
{"udp", noAddr, 8000, 8000},
{"udp", noAddr, 8080, 8080},
{"udp", noAddr, 8081, 8081},
},
},
{
Expand All @@ -75,7 +75,7 @@ func Test_parsePortForwards(t *testing.T) {
udpSpecs: []string{"127.0.0.1:8080:8081"},
},
want: []portForwardSpec{
{"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8081"},
{"udp", ipv4Loopback, 8080, 8081},
},
},
{
Expand All @@ -84,7 +84,7 @@ func Test_parsePortForwards(t *testing.T) {
udpSpecs: []string{"[::1]:8080:8081"},
},
want: []portForwardSpec{
{"udp", "[::1]:8080", "udp", "127.0.0.1:8081"},
{"udp", ipv6Loopback, 8080, 8081},
},
},
{
Expand Down
Loading
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy