Skip to content

Commit 9ff3342

Browse files
committed
feat: add --network-info-dir and --network-info-interval flags to coder ssh
This is the first in a series of PRs to enable "coder ssh" to replace "coder vscodessh". This change adds --network-info-dir and --network-info-interval flags to the ssh subcommand. These were formerly only available with the vscodessh subcommand. Subsequent PRs will add a --ssh-host-prefix flag to the ssh subcommand, and adjust the log file naming to contain the parent PID.
1 parent 6ca1e59 commit 9ff3342

File tree

5 files changed

+309
-163
lines changed

5 files changed

+309
-163
lines changed

cli/ssh.go

Lines changed: 209 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cli
33
import (
44
"bytes"
55
"context"
6+
"encoding/json"
67
"errors"
78
"fmt"
89
"io"
@@ -13,6 +14,7 @@ import (
1314
"os/exec"
1415
"path/filepath"
1516
"slices"
17+
"strconv"
1618
"strings"
1719
"sync"
1820
"time"
@@ -21,11 +23,14 @@ import (
2123
"github.com/gofrs/flock"
2224
"github.com/google/uuid"
2325
"github.com/mattn/go-isatty"
26+
"github.com/spf13/afero"
2427
gossh "golang.org/x/crypto/ssh"
2528
gosshagent "golang.org/x/crypto/ssh/agent"
2629
"golang.org/x/term"
2730
"golang.org/x/xerrors"
2831
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
32+
"tailscale.com/tailcfg"
33+
"tailscale.com/types/netlogtype"
2934

3035
"cdr.dev/slog"
3136
"cdr.dev/slog/sloggers/sloghuman"
@@ -55,19 +60,21 @@ var (
5560

5661
func (r *RootCmd) ssh() *serpent.Command {
5762
var (
58-
stdio bool
59-
forwardAgent bool
60-
forwardGPG bool
61-
identityAgent string
62-
wsPollInterval time.Duration
63-
waitEnum string
64-
noWait bool
65-
logDirPath string
66-
remoteForwards []string
67-
env []string
68-
usageApp string
69-
disableAutostart bool
70-
appearanceConfig codersdk.AppearanceConfig
63+
stdio bool
64+
forwardAgent bool
65+
forwardGPG bool
66+
identityAgent string
67+
wsPollInterval time.Duration
68+
waitEnum string
69+
noWait bool
70+
logDirPath string
71+
remoteForwards []string
72+
env []string
73+
usageApp string
74+
disableAutostart bool
75+
appearanceConfig codersdk.AppearanceConfig
76+
networkInfoDir string
77+
networkInfoInterval time.Duration
7178
)
7279
client := new(codersdk.Client)
7380
cmd := &serpent.Command{
@@ -284,13 +291,21 @@ func (r *RootCmd) ssh() *serpent.Command {
284291
return err
285292
}
286293

294+
var errCh <-chan error
295+
if networkInfoDir != "" {
296+
errCh, err = setStatsCallback(ctx, conn, logger, networkInfoDir, networkInfoInterval)
297+
if err != nil {
298+
return err
299+
}
300+
}
301+
287302
wg.Add(1)
288303
go func() {
289304
defer wg.Done()
290305
watchAndClose(ctx, func() error {
291306
stack.close(xerrors.New("watchAndClose"))
292307
return nil
293-
}, logger, client, workspace)
308+
}, logger, client, workspace, errCh)
294309
}()
295310
copier.copy(&wg)
296311
return nil
@@ -312,6 +327,14 @@ func (r *RootCmd) ssh() *serpent.Command {
312327
return err
313328
}
314329

330+
var errCh <-chan error
331+
if networkInfoDir != "" {
332+
errCh, err = setStatsCallback(ctx, conn, logger, networkInfoDir, networkInfoInterval)
333+
if err != nil {
334+
return err
335+
}
336+
}
337+
315338
wg.Add(1)
316339
go func() {
317340
defer wg.Done()
@@ -324,6 +347,7 @@ func (r *RootCmd) ssh() *serpent.Command {
324347
logger,
325348
client,
326349
workspace,
350+
errCh,
327351
)
328352
}()
329353

@@ -540,6 +564,17 @@ func (r *RootCmd) ssh() *serpent.Command {
540564
Value: serpent.StringOf(&usageApp),
541565
Hidden: true,
542566
},
567+
{
568+
Flag: "network-info-dir",
569+
Description: "Specifies a directory to write network information periodically.",
570+
Value: serpent.StringOf(&networkInfoDir),
571+
},
572+
{
573+
Flag: "network-info-interval",
574+
Description: "Specifies the interval to update network information.",
575+
Default: "5s",
576+
Value: serpent.DurationOf(&networkInfoInterval),
577+
},
543578
sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)),
544579
}
545580
return cmd
@@ -555,7 +590,7 @@ func (r *RootCmd) ssh() *serpent.Command {
555590
// will usually not propagate.
556591
//
557592
// See: https://github.com/coder/coder/issues/6180
558-
func watchAndClose(ctx context.Context, closer func() error, logger slog.Logger, client *codersdk.Client, workspace codersdk.Workspace) {
593+
func watchAndClose(ctx context.Context, closer func() error, logger slog.Logger, client *codersdk.Client, workspace codersdk.Workspace, errCh <-chan error) {
559594
// Ensure session is ended on both context cancellation
560595
// and workspace stop.
561596
defer func() {
@@ -606,6 +641,9 @@ startWatchLoop:
606641
logger.Info(ctx, "workspace stopped")
607642
return
608643
}
644+
case err := <-errCh:
645+
logger.Error(ctx, "failed to collect network stats", slog.Error(err))
646+
return
609647
}
610648
}
611649
}
@@ -1144,3 +1182,159 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName {
11441182

11451183
return codersdk.UsageAppNameSSH
11461184
}
1185+
1186+
func setStatsCallback(
1187+
ctx context.Context,
1188+
agentConn *workspacesdk.AgentConn,
1189+
logger slog.Logger,
1190+
networkInfoDir string,
1191+
networkInfoInterval time.Duration,
1192+
) (<-chan error, error) {
1193+
fs, ok := ctx.Value("fs").(afero.Fs)
1194+
if !ok {
1195+
fs = afero.NewOsFs()
1196+
}
1197+
if err := fs.MkdirAll(networkInfoDir, 0o700); err != nil {
1198+
return nil, xerrors.Errorf("mkdir: %w", err)
1199+
}
1200+
1201+
// The VS Code extension obtains the PID of the SSH process to
1202+
// read files to display logs and network info.
1203+
//
1204+
// We get the parent PID because it's assumed `ssh` is calling this
1205+
// command via the ProxyCommand SSH option.
1206+
pid := os.Getppid()
1207+
1208+
// The VS Code extension obtains the PID of the SSH process to
1209+
// read the file below which contains network information to display.
1210+
//
1211+
// We get the parent PID because it's assumed `ssh` is calling this
1212+
// command via the ProxyCommand SSH option.
1213+
networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", pid))
1214+
1215+
var (
1216+
firstErrTime time.Time
1217+
errCh = make(chan error, 1)
1218+
)
1219+
cb := func(start, end time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
1220+
sendErr := func(tolerate bool, err error) {
1221+
logger.Error(ctx, "collect network stats", slog.Error(err))
1222+
// Tolerate up to 1 minute of errors.
1223+
if tolerate {
1224+
if firstErrTime.IsZero() {
1225+
logger.Info(ctx, "tolerating network stats errors for up to 1 minute")
1226+
firstErrTime = time.Now()
1227+
}
1228+
if time.Since(firstErrTime) < time.Minute {
1229+
return
1230+
}
1231+
}
1232+
1233+
select {
1234+
case errCh <- err:
1235+
default:
1236+
}
1237+
}
1238+
1239+
stats, err := collectNetworkStats(ctx, agentConn, start, end, virtual)
1240+
if err != nil {
1241+
sendErr(true, err)
1242+
return
1243+
}
1244+
1245+
rawStats, err := json.Marshal(stats)
1246+
if err != nil {
1247+
sendErr(false, err)
1248+
return
1249+
}
1250+
err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600)
1251+
if err != nil {
1252+
sendErr(false, err)
1253+
return
1254+
}
1255+
1256+
firstErrTime = time.Time{}
1257+
}
1258+
1259+
now := time.Now()
1260+
cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{})
1261+
agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb)
1262+
return errCh, nil
1263+
}
1264+
1265+
// sshNetworkStats is the JSON document describing the state of an SSH
// session's connection to the workspace agent. It is produced by
// collectNetworkStats.
type sshNetworkStats struct {
	// P2P is the peer-to-peer flag returned by the agent ping.
	P2P bool `json:"p2p"`
	// Latency is the agent ping round trip in milliseconds.
	Latency float64 `json:"latency"`
	// PreferredDERP is the display name of the DERP region in use, or
	// "Unnamed <id>" when the region is not in the DERP map.
	PreferredDERP string `json:"preferred_derp"`
	// DERPLatency maps DERP region display names to their latency, scaled
	// by 1000 from the values the node reports.
	DERPLatency map[string]float64 `json:"derp_latency"`
	// UploadBytesSec is the transmitted bytes per second over the
	// reporting window.
	UploadBytesSec int64 `json:"upload_bytes_sec"`
	// DownloadBytesSec is the received bytes per second over the
	// reporting window.
	DownloadBytesSec int64 `json:"download_bytes_sec"`
}
1273+
1274+
func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
1275+
latency, p2p, pingResult, err := agentConn.Ping(ctx)
1276+
if err != nil {
1277+
return nil, err
1278+
}
1279+
node := agentConn.Node()
1280+
derpMap := agentConn.DERPMap()
1281+
derpLatency := map[string]float64{}
1282+
1283+
// Convert DERP region IDs to friendly names for display in the UI.
1284+
for rawRegion, latency := range node.DERPLatency {
1285+
regionParts := strings.SplitN(rawRegion, "-", 2)
1286+
regionID, err := strconv.Atoi(regionParts[0])
1287+
if err != nil {
1288+
continue
1289+
}
1290+
region, found := derpMap.Regions[regionID]
1291+
if !found {
1292+
// It's possible that a workspace agent is using an old DERPMap
1293+
// and reports regions that do not exist. If that's the case,
1294+
// report the region as unknown!
1295+
region = &tailcfg.DERPRegion{
1296+
RegionID: regionID,
1297+
RegionName: fmt.Sprintf("Unnamed %d", regionID),
1298+
}
1299+
}
1300+
// Convert the microseconds to milliseconds.
1301+
derpLatency[region.RegionName] = latency * 1000
1302+
}
1303+
1304+
totalRx := uint64(0)
1305+
totalTx := uint64(0)
1306+
for _, stat := range counts {
1307+
totalRx += stat.RxBytes
1308+
totalTx += stat.TxBytes
1309+
}
1310+
// Tracking the time since last request is required because
1311+
// ExtractTrafficStats() resets its counters after each call.
1312+
dur := end.Sub(start)
1313+
uploadSecs := float64(totalTx) / dur.Seconds()
1314+
downloadSecs := float64(totalRx) / dur.Seconds()
1315+
1316+
// Sometimes the preferred DERP doesn't match the one we're actually
1317+
// connected with. Perhaps because the agent prefers a different DERP and
1318+
// we're using that server instead.
1319+
preferredDerpID := node.PreferredDERP
1320+
if pingResult.DERPRegionID != 0 {
1321+
preferredDerpID = pingResult.DERPRegionID
1322+
}
1323+
preferredDerp, ok := derpMap.Regions[preferredDerpID]
1324+
preferredDerpName := fmt.Sprintf("Unnamed %d", preferredDerpID)
1325+
if ok {
1326+
preferredDerpName = preferredDerp.RegionName
1327+
}
1328+
if _, ok := derpLatency[preferredDerpName]; !ok {
1329+
derpLatency[preferredDerpName] = 0
1330+
}
1331+
1332+
return &sshNetworkStats{
1333+
P2P: p2p,
1334+
Latency: float64(latency.Microseconds()) / 1000,
1335+
PreferredDERP: preferredDerpName,
1336+
DERPLatency: derpLatency,
1337+
UploadBytesSec: int64(uploadSecs),
1338+
DownloadBytesSec: int64(downloadSecs),
1339+
}, nil
1340+
}

cli/ssh_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"time"
2525

2626
"github.com/google/uuid"
27+
"github.com/spf13/afero"
2728
"github.com/stretchr/testify/assert"
2829
"github.com/stretchr/testify/require"
2930
"golang.org/x/crypto/ssh"
@@ -438,6 +439,78 @@ func TestSSH(t *testing.T) {
438439
<-cmdDone
439440
})
440441

442+
t.Run("NetworkInfo", func(t *testing.T) {
443+
t.Parallel()
444+
client, workspace, agentToken := setupWorkspaceForAgent(t)
445+
_, _ = tGoContext(t, func(ctx context.Context) {
446+
// Run this async so the SSH command has to wait for
447+
// the build and agent to connect!
448+
_ = agenttest.New(t, client.URL, agentToken)
449+
<-ctx.Done()
450+
})
451+
452+
clientOutput, clientInput := io.Pipe()
453+
serverOutput, serverInput := io.Pipe()
454+
defer func() {
455+
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
456+
_ = c.Close()
457+
}
458+
}()
459+
460+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
461+
defer cancel()
462+
463+
fs := afero.NewMemMapFs()
464+
//nolint:revive,staticcheck
465+
ctx = context.WithValue(ctx, "fs", fs)
466+
467+
inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name, "--network-info-dir", "/net", "--network-info-interval", "25ms")
468+
clitest.SetupConfig(t, client, root)
469+
inv.Stdin = clientOutput
470+
inv.Stdout = serverInput
471+
inv.Stderr = io.Discard
472+
473+
cmdDone := tGo(t, func() {
474+
err := inv.WithContext(ctx).Run()
475+
assert.NoError(t, err)
476+
})
477+
478+
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
479+
Reader: serverOutput,
480+
Writer: clientInput,
481+
}, "", &ssh.ClientConfig{
482+
// #nosec
483+
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
484+
})
485+
require.NoError(t, err)
486+
defer conn.Close()
487+
488+
sshClient := ssh.NewClient(conn, channels, requests)
489+
session, err := sshClient.NewSession()
490+
require.NoError(t, err)
491+
defer session.Close()
492+
493+
command := "sh -c exit"
494+
if runtime.GOOS == "windows" {
495+
command = "cmd.exe /c exit"
496+
}
497+
err = session.Run(command)
498+
require.NoError(t, err)
499+
err = sshClient.Close()
500+
require.NoError(t, err)
501+
_ = clientOutput.Close()
502+
503+
assert.Eventually(t, func() bool {
504+
entries, err := afero.ReadDir(fs, "/net")
505+
if err != nil {
506+
return false
507+
}
508+
return len(entries) > 0
509+
}, testutil.WaitLong, testutil.IntervalFast)
510+
511+
<-cmdDone
512+
})
513+
441514
t.Run("Stdio_StartStoppedWorkspace_CleanStdout", func(t *testing.T) {
442515
t.Parallel()
443516

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy