From 6d2985e75bb1a661b5522e972a27a14155af76d7 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Mon, 2 Jun 2025 14:56:31 +1000 Subject: [PATCH 1/5] feat: add vpn start progress --- .../.idea/projectSettingsUpdater.xml | 1 + App/Models/RpcModel.cs | 19 ++ .../PublishProfiles/win-arm64.pubxml | 12 - App/Properties/PublishProfiles/win-x64.pubxml | 12 - App/Properties/PublishProfiles/win-x86.pubxml | 12 - App/Services/RpcController.cs | 31 ++- App/ViewModels/TrayWindowViewModel.cs | 38 ++- .../Pages/TrayWindowLoginRequiredPage.xaml | 2 +- App/Views/Pages/TrayWindowMainPage.xaml | 11 +- Tests.Vpn.Service/DownloaderTest.cs | 50 +++- Vpn.Proto/vpn.proto | 16 +- Vpn.Service/Downloader.cs | 220 +++++++++++++++--- Vpn.Service/Manager.cs | 67 +++++- Vpn.Service/ManagerRpc.cs | 2 +- Vpn.Service/Program.cs | 11 +- Vpn.Service/TunnelSupervisor.cs | 6 +- Vpn/Speaker.cs | 2 +- 17 files changed, 421 insertions(+), 91 deletions(-) delete mode 100644 App/Properties/PublishProfiles/win-arm64.pubxml delete mode 100644 App/Properties/PublishProfiles/win-x64.pubxml delete mode 100644 App/Properties/PublishProfiles/win-x86.pubxml diff --git a/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml b/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml index 64af657..ef20cb0 100644 --- a/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml +++ b/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml @@ -2,6 +2,7 @@ \ No newline at end of file diff --git a/App/Models/RpcModel.cs b/App/Models/RpcModel.cs index 034f405..33b4647 100644 --- a/App/Models/RpcModel.cs +++ b/App/Models/RpcModel.cs @@ -19,12 +19,30 @@ public enum VpnLifecycle Stopping, } +public class VpnStartupProgress +{ + public double Progress { get; set; } = 0.0; // 0.0 to 1.0 + public string Message { get; set; } = string.Empty; + + public VpnStartupProgress Clone() + { + return new VpnStartupProgress + { + Progress = Progress, + Message = Message, + }; + } +} + public class RpcModel { public RpcLifecycle RpcLifecycle { get; set; } = RpcLifecycle.Disconnected; public VpnLifecycle VpnLifecycle { get; set; } = VpnLifecycle.Unknown; + // Nullable because it is only set when the VpnLifecycle is Starting + public VpnStartupProgress? VpnStartupProgress { get; set; } + public IReadOnlyList Workspaces { get; set; } = []; public IReadOnlyList Agents { get; set; } = []; @@ -35,6 +53,7 @@ public RpcModel Clone() { RpcLifecycle = RpcLifecycle, VpnLifecycle = VpnLifecycle, + VpnStartupProgress = VpnStartupProgress?.Clone(), Workspaces = Workspaces, Agents = Agents, }; diff --git a/App/Properties/PublishProfiles/win-arm64.pubxml b/App/Properties/PublishProfiles/win-arm64.pubxml deleted file mode 100644 index ac9753e..0000000 --- a/App/Properties/PublishProfiles/win-arm64.pubxml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - FileSystem - ARM64 - win-arm64 - bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\ - - diff --git a/App/Properties/PublishProfiles/win-x64.pubxml b/App/Properties/PublishProfiles/win-x64.pubxml deleted file mode 100644 index 942523b..0000000 --- a/App/Properties/PublishProfiles/win-x64.pubxml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - FileSystem - x64 - win-x64 - bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\ - - diff --git a/App/Properties/PublishProfiles/win-x86.pubxml b/App/Properties/PublishProfiles/win-x86.pubxml deleted file mode 100644 index e763481..0000000 --- a/App/Properties/PublishProfiles/win-x86.pubxml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - FileSystem - x86 - win-x86 - bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\ - - diff --git a/App/Services/RpcController.cs b/App/Services/RpcController.cs index 7beff66..b42c058 100644 --- a/App/Services/RpcController.cs +++ b/App/Services/RpcController.cs @@ -161,7 +161,12 @@ public async Task StartVpn(CancellationToken ct = default) throw new RpcOperationException( $"Cannot start VPN without valid credentials, current state: {credentials.State}"); - MutateState(state => { state.VpnLifecycle = VpnLifecycle.Starting; }); + MutateState(state => + { + state.VpnLifecycle = VpnLifecycle.Starting; + // Explicitly clear the startup progress. + state.VpnStartupProgress = null; + }); ServiceMessage reply; try @@ -251,6 +256,9 @@ private void MutateState(Action mutator) using (_stateLock.Lock()) { mutator(_state); + // Unset the startup progress if the VpnLifecycle is not Starting + if (_state.VpnLifecycle != VpnLifecycle.Starting) + _state.VpnStartupProgress = null; newState = _state.Clone(); } @@ -283,15 +291,32 @@ private void ApplyStatusUpdate(Status status) }); } + private void ApplyStartProgressUpdate(StartProgress message) + { + MutateState(state => + { + // MutateState will undo these changes if it doesn't believe we're + // in the "Starting" state. + state.VpnStartupProgress = new VpnStartupProgress + { + Progress = message.Progress, + Message = message.Message, + }; + }); + } + private void SpeakerOnReceive(ReplyableRpcMessage message) { switch (message.Message.MsgCase) { + case ServiceMessage.MsgOneofCase.Start: + case ServiceMessage.MsgOneofCase.Stop: case ServiceMessage.MsgOneofCase.Status: ApplyStatusUpdate(message.Message.Status); break; - case ServiceMessage.MsgOneofCase.Start: - case ServiceMessage.MsgOneofCase.Stop: + case ServiceMessage.MsgOneofCase.StartProgress: + ApplyStartProgressUpdate(message.Message.StartProgress); + break; case ServiceMessage.MsgOneofCase.None: default: // TODO: log unexpected message diff --git a/App/ViewModels/TrayWindowViewModel.cs b/App/ViewModels/TrayWindowViewModel.cs index d8b3182..cd3a641 100644 --- a/App/ViewModels/TrayWindowViewModel.cs +++ b/App/ViewModels/TrayWindowViewModel.cs @@ -29,7 +29,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost { private const int MaxAgents = 5; private const string DefaultDashboardUrl = "https://coder.com"; - private const string DefaultHostnameSuffix = ".coder"; + private const string DefaultStartProgressMessage = "Starting Coder Connect..."; private readonly IServiceProvider _services; private readonly IRpcController _rpcController; @@ -53,6 +53,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost [ObservableProperty] [NotifyPropertyChangedFor(nameof(ShowEnableSection))] + [NotifyPropertyChangedFor(nameof(ShowVpnStartProgressSection))] [NotifyPropertyChangedFor(nameof(ShowWorkspacesHeader))] [NotifyPropertyChangedFor(nameof(ShowNoAgentsSection))] [NotifyPropertyChangedFor(nameof(ShowAgentsSection))] @@ -63,6 +64,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost [ObservableProperty] [NotifyPropertyChangedFor(nameof(ShowEnableSection))] + [NotifyPropertyChangedFor(nameof(ShowVpnStartProgressSection))] [NotifyPropertyChangedFor(nameof(ShowWorkspacesHeader))] [NotifyPropertyChangedFor(nameof(ShowNoAgentsSection))] [NotifyPropertyChangedFor(nameof(ShowAgentsSection))] @@ -70,7 +72,25 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost [NotifyPropertyChangedFor(nameof(ShowFailedSection))] public partial string? VpnFailedMessage { get; set; } = null; - public bool ShowEnableSection => VpnFailedMessage is null && VpnLifecycle is not VpnLifecycle.Started; + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(VpnStartProgressIsIndeterminate))] + [NotifyPropertyChangedFor(nameof(VpnStartProgressValueOrDefault))] + public partial int? VpnStartProgressValue { get; set; } = null; + + public int VpnStartProgressValueOrDefault => VpnStartProgressValue ?? 0; + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(VpnStartProgressMessageOrDefault))] + public partial string? VpnStartProgressMessage { get; set; } = null; + + public string VpnStartProgressMessageOrDefault => + string.IsNullOrEmpty(VpnStartProgressMessage) ? DefaultStartProgressMessage : VpnStartProgressMessage; + + public bool VpnStartProgressIsIndeterminate => VpnStartProgressValueOrDefault == 0; + + public bool ShowEnableSection => VpnFailedMessage is null && VpnLifecycle is not VpnLifecycle.Starting and not VpnLifecycle.Started; + + public bool ShowVpnStartProgressSection => VpnFailedMessage is null && VpnLifecycle is VpnLifecycle.Starting; public bool ShowWorkspacesHeader => VpnFailedMessage is null && VpnLifecycle is VpnLifecycle.Started; @@ -170,6 +190,20 @@ private void UpdateFromRpcModel(RpcModel rpcModel) VpnLifecycle = rpcModel.VpnLifecycle; VpnSwitchActive = rpcModel.VpnLifecycle is VpnLifecycle.Starting or VpnLifecycle.Started; + // VpnStartupProgress is only set when the VPN is starting. + if (rpcModel.VpnLifecycle is VpnLifecycle.Starting && rpcModel.VpnStartupProgress != null) + { + // Convert 0.00-1.00 to 0-100. + var progress = (int)(rpcModel.VpnStartupProgress.Progress * 100); + VpnStartProgressValue = Math.Clamp(progress, 0, 100); + VpnStartProgressMessage = string.IsNullOrEmpty(rpcModel.VpnStartupProgress.Message) ? null : rpcModel.VpnStartupProgress.Message; + } + else + { + VpnStartProgressValue = null; + VpnStartProgressMessage = null; + } + // Add every known agent. HashSet workspacesWithAgents = []; List agents = []; diff --git a/App/Views/Pages/TrayWindowLoginRequiredPage.xaml b/App/Views/Pages/TrayWindowLoginRequiredPage.xaml index c1d69aa..171e292 100644 --- a/App/Views/Pages/TrayWindowLoginRequiredPage.xaml +++ b/App/Views/Pages/TrayWindowLoginRequiredPage.xaml @@ -36,7 +36,7 @@ diff --git a/App/Views/Pages/TrayWindowMainPage.xaml b/App/Views/Pages/TrayWindowMainPage.xaml index 283867d..f488454 100644 --- a/App/Views/Pages/TrayWindowMainPage.xaml +++ b/App/Views/Pages/TrayWindowMainPage.xaml @@ -43,6 +43,8 @@ + + + HorizontalContentAlignment="Left"> diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs index 986ce46..a47ffbc 100644 --- a/Tests.Vpn.Service/DownloaderTest.cs +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -2,6 +2,7 @@ using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Text; +using System.Threading.Channels; using Coder.Desktop.Vpn.Service; using Microsoft.Extensions.Logging.Abstractions; @@ -278,7 +279,7 @@ public async Task Download(CancellationToken ct) NullDownloadValidator.Instance, ct); await dlTask.Task; Assert.That(dlTask.TotalBytes, Is.EqualTo(4)); - Assert.That(dlTask.BytesRead, Is.EqualTo(4)); + Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); Assert.That(dlTask.Progress, Is.EqualTo(1)); Assert.That(dlTask.IsCompleted, Is.True); Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); @@ -301,17 +302,56 @@ public async Task DownloadSameDest(CancellationToken ct) var dlTask0 = await startTask0; await dlTask0.Task; Assert.That(dlTask0.TotalBytes, Is.EqualTo(5)); - Assert.That(dlTask0.BytesRead, Is.EqualTo(5)); + Assert.That(dlTask0.BytesWritten, Is.EqualTo(5)); Assert.That(dlTask0.Progress, Is.EqualTo(1)); Assert.That(dlTask0.IsCompleted, Is.True); var dlTask1 = await startTask1; await dlTask1.Task; Assert.That(dlTask1.TotalBytes, Is.EqualTo(5)); - Assert.That(dlTask1.BytesRead, Is.EqualTo(5)); + Assert.That(dlTask1.BytesWritten, Is.EqualTo(5)); Assert.That(dlTask1.Progress, Is.EqualTo(1)); Assert.That(dlTask1.IsCompleted, Is.True); } + [Test(Description = "Download with X-Original-Content-Length")] + [CancelAfter(30_000)] + public async Task DownloadWithXOriginalContentLength(CancellationToken ct) + { + using var httpServer = new TestHttpServer(async ctx => + { + ctx.Response.StatusCode = 200; + ctx.Response.Headers.Add("X-Original-Content-Length", "6"); // wrong but should be used until complete + ctx.Response.ContentType = "text/plain"; + ctx.Response.ContentLength64 = 4; // This should be ignored. + await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct); + }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + var manager = new Downloader(NullLogger.Instance); + var req = new HttpRequestMessage(HttpMethod.Get, url); + var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct); + + var progressChannel = Channel.CreateUnbounded(); + dlTask.ProgressChanged += (_, args) => + Assert.That(progressChannel.Writer.TryWrite(args), Is.True); + + await dlTask.Task; + Assert.That(dlTask.TotalBytes, Is.EqualTo(4)); // should equal BytesWritten after completion + Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); + progressChannel.Writer.Complete(); + + var list = progressChannel.Reader.ReadAllAsync(ct).ToBlockingEnumerable(ct).ToList(); + Assert.That(list.Count, Is.GreaterThanOrEqualTo(2)); // there may be an item in the middle + // The first item should be the initial progress with 0 bytes written. + Assert.That(list[0].BytesWritten, Is.EqualTo(0)); + Assert.That(list[0].TotalBytes, Is.EqualTo(6)); // from X-Original-Content-Length + Assert.That(list[0].Progress, Is.EqualTo(0.0d)); + // The last item should be final progress with the actual total bytes. + Assert.That(list[^1].BytesWritten, Is.EqualTo(4)); + Assert.That(list[^1].TotalBytes, Is.EqualTo(4)); // from the actual bytes written + Assert.That(list[^1].Progress, Is.EqualTo(1.0d)); + } + [Test(Description = "Download with custom headers")] [CancelAfter(30_000)] public async Task WithHeaders(CancellationToken ct) @@ -347,7 +387,7 @@ public async Task DownloadExisting(CancellationToken ct) var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, NullDownloadValidator.Instance, ct); await dlTask.Task; - Assert.That(dlTask.BytesRead, Is.Zero); + Assert.That(dlTask.BytesWritten, Is.Zero); Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); Assert.That(File.GetLastWriteTime(destPath), Is.LessThan(DateTime.Now - TimeSpan.FromDays(1))); } @@ -368,7 +408,7 @@ public async Task DownloadExistingDifferentContent(CancellationToken ct) var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, NullDownloadValidator.Instance, ct); await dlTask.Task; - Assert.That(dlTask.BytesRead, Is.EqualTo(4)); + Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); Assert.That(File.GetLastWriteTime(destPath), Is.GreaterThan(DateTime.Now - TimeSpan.FromDays(1))); } diff --git a/Vpn.Proto/vpn.proto b/Vpn.Proto/vpn.proto index 2561a4b..fa2f003 100644 --- a/Vpn.Proto/vpn.proto +++ b/Vpn.Proto/vpn.proto @@ -60,7 +60,8 @@ message ServiceMessage { oneof msg { StartResponse start = 2; StopResponse stop = 3; - Status status = 4; // either in reply to a StatusRequest or broadcasted + Status status = 4; // either in reply to a StatusRequest or broadcasted + StartProgress start_progress = 5; // broadcasted during startup } } @@ -218,6 +219,19 @@ message StartResponse { string error_message = 2; } +// StartProgress is sent from the manager to the client to indicate the +// download/startup progress of the tunnel. This will be sent during the +// processing of a StartRequest before the StartResponse is sent. +// +// Note: this is currently a broadcasted message to all clients due to the +// inability to easily send messages to a specific client in the Speaker +// implementation. If clients are not expecting these messages, they +// should ignore them. +message StartProgress { + double progress = 1; // 0.0 to 1.0 + string message = 2; // human-readable status message, must be set +} + // StopRequest is a request from the manager to stop the tunnel. The tunnel replies with a // StopResponse. message StopRequest {} diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index 6a3108b..856a637 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -338,32 +338,81 @@ internal static async Task TaskOrCancellation(Task task, CancellationToken cance } } +public class DownloadProgressEvent +{ + // TODO: speed calculation would be nice + public ulong BytesWritten { get; init; } + public ulong? TotalBytes { get; init; } // null if unknown + public double? Progress { get; init; } // 0.0 - 1.0, null if unknown + + public override string ToString() + { + var s = FriendlyBytes(BytesWritten); + if (TotalBytes != null) + s += $" of {FriendlyBytes(TotalBytes.Value)}"; + else + s += " of unknown"; + if (Progress != null) + s += $" ({Progress:0%})"; + return s; + } + + private static readonly string[] ByteSuffixes = ["B", "KB", "MB", "GB", "TB", "PB", "EB"]; + + // Unfortunately this is copied from FriendlyByteConverter in App. Ideally + // it should go into some shared utilities project, but it's overkill to do + // that for a single tiny function until we have more shared code. + private static string FriendlyBytes(ulong bytes) + { + if (bytes == 0) + return $"0 {ByteSuffixes[0]}"; + + var place = Convert.ToInt32(Math.Floor(Math.Log(bytes, 1024))); + var num = Math.Round(bytes / Math.Pow(1024, place), 1); + return $"{num} {ByteSuffixes[place]}"; + } +} + /// -/// Downloads an Url to a file on disk. The download will be written to a temporary file first, then moved to the final +/// Downloads a Url to a file on disk. The download will be written to a temporary file first, then moved to the final /// destination. The SHA1 of any existing file will be calculated and used as an ETag to avoid downloading the file if /// it hasn't changed. /// public class DownloadTask { private const int BufferSize = 4096; + private const int ProgressUpdateDelayMs = 50; + private const string XOriginalContentLengthHeader = "X-Original-Content-Length"; // overrides Content-Length if available - private static readonly HttpClient HttpClient = new(); + private static readonly HttpClient HttpClient = new(new HttpClientHandler + { + AutomaticDecompression = DecompressionMethods.All, + }); private readonly string _destinationDirectory; private readonly ILogger _logger; private readonly RaiiSemaphoreSlim _semaphore = new(1, 1); private readonly IDownloadValidator _validator; - public readonly string DestinationPath; + private readonly string _destinationPath; + private readonly string _tempDestinationPath; + + // ProgressChanged events are always delayed by up to 50ms to avoid + // flooding. + // + // This will be called: + // - once after the request succeeds but before the read/write routine + // begins + // - occasionally while the file is being downloaded (at least 50ms apart) + // - once when the download is complete + public EventHandler? ProgressChanged; public readonly HttpRequestMessage Request; - public readonly string TempDestinationPath; - public ulong? TotalBytes { get; private set; } - public ulong BytesRead { get; private set; } public Task Task { get; private set; } = null!; // Set in EnsureStartedAsync - - public double? Progress => TotalBytes == null ? null : (double)BytesRead / TotalBytes.Value; + public ulong BytesWritten { get; private set; } + public ulong? TotalBytes { get; private set; } + public double? Progress => TotalBytes == null ? null : (double)BytesWritten / TotalBytes.Value; public bool IsCompleted => Task.IsCompleted; internal DownloadTask(ILogger logger, HttpRequestMessage req, string destinationPath, IDownloadValidator validator) @@ -374,17 +423,17 @@ internal DownloadTask(ILogger logger, HttpRequestMessage req, string destination if (string.IsNullOrWhiteSpace(destinationPath)) throw new ArgumentException("Destination path must not be empty", nameof(destinationPath)); - DestinationPath = Path.GetFullPath(destinationPath); - if (Path.EndsInDirectorySeparator(DestinationPath)) - throw new ArgumentException($"Destination path '{DestinationPath}' must not end in a directory separator", + _destinationPath = Path.GetFullPath(destinationPath); + if (Path.EndsInDirectorySeparator(_destinationPath)) + throw new ArgumentException($"Destination path '{_destinationPath}' must not end in a directory separator", nameof(destinationPath)); - _destinationDirectory = Path.GetDirectoryName(DestinationPath) + _destinationDirectory = Path.GetDirectoryName(_destinationPath) ?? throw new ArgumentException( - $"Destination path '{DestinationPath}' must have a parent directory", + $"Destination path '{_destinationPath}' must have a parent directory", nameof(destinationPath)); - TempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(DestinationPath) + + _tempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(_destinationPath) + ".download-" + Path.GetRandomFileName()); } @@ -406,9 +455,9 @@ private async Task Start(CancellationToken ct = default) // If the destination path exists, generate a Coder SHA1 ETag and send // it in the If-None-Match header to the server. - if (File.Exists(DestinationPath)) + if (File.Exists(_destinationPath)) { - await using var stream = File.OpenRead(DestinationPath); + await using var stream = File.OpenRead(_destinationPath); var etag = Convert.ToHexString(await SHA1.HashDataAsync(stream, ct)).ToLower(); Request.Headers.Add("If-None-Match", "\"" + etag + "\""); } @@ -419,11 +468,11 @@ private async Task Start(CancellationToken ct = default) _logger.LogInformation("File has not been modified, skipping download"); try { - await _validator.ValidateAsync(DestinationPath, ct); + await _validator.ValidateAsync(_destinationPath, ct); } catch (Exception e) { - _logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", DestinationPath); + _logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", _destinationPath); throw new Exception("Existing file failed validation after 304 Not Modified", e); } @@ -448,6 +497,26 @@ private async Task Start(CancellationToken ct = default) if (res.Content.Headers.ContentLength >= 0) TotalBytes = (ulong)res.Content.Headers.ContentLength; + // X-Original-Content-Length overrules Content-Length if set. + if (res.Headers.TryGetValues(XOriginalContentLengthHeader, out var headerValues)) + { + // If there are multiple we only look at the first one. + var headerValue = headerValues.ToList().FirstOrDefault(); + if (!string.IsNullOrEmpty(headerValue) && ulong.TryParse(headerValue, out var originalContentLength)) + TotalBytes = originalContentLength; + else + _logger.LogWarning( + "Failed to parse {XOriginalContentLengthHeader} header value '{HeaderValue}'", + XOriginalContentLengthHeader, headerValue); + } + + SendProgressUpdate(new DownloadProgressEvent + { + BytesWritten = 0, + TotalBytes = TotalBytes, + Progress = 0.0, + }); + await Download(res, ct); } @@ -459,11 +528,11 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct) FileStream tempFile; try { - tempFile = File.Create(TempDestinationPath, BufferSize, FileOptions.SequentialScan); + tempFile = File.Create(_tempDestinationPath, BufferSize, FileOptions.SequentialScan); } catch (Exception e) { - _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", TempDestinationPath); + _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", _tempDestinationPath); throw; } @@ -476,13 +545,31 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct) { await tempFile.WriteAsync(buffer.AsMemory(0, n), ct); sha1?.TransformBlock(buffer, 0, n, null, 0); - BytesRead += (ulong)n; + BytesWritten += (ulong)n; + await QueueProgressUpdate(new DownloadProgressEvent + { + BytesWritten = BytesWritten, + TotalBytes = TotalBytes, + Progress = Progress, + }, ct); } } - if (TotalBytes != null && BytesRead != TotalBytes) + // Clear any pending progress updates to ensure they won't be sent + // after the final update. + await ClearQueuedProgressUpdate(ct); + // Then write the final status update. + TotalBytes = BytesWritten; + SendProgressUpdate(new DownloadProgressEvent + { + BytesWritten = BytesWritten, + TotalBytes = BytesWritten, + Progress = 1.0, + }); + + if (TotalBytes != null && BytesWritten != TotalBytes) throw new IOException( - $"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesRead}"); + $"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesWritten}"); // Verify the ETag if it was sent by the server. if (res.Headers.Contains("ETag") && sha1 != null) @@ -497,26 +584,99 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct) try { - await _validator.ValidateAsync(TempDestinationPath, ct); + await _validator.ValidateAsync(_tempDestinationPath, ct); } catch (Exception e) { _logger.LogWarning(e, "Downloaded file '{TempDestinationPath}' failed custom validation", - TempDestinationPath); + _tempDestinationPath); throw new HttpRequestException("Downloaded file failed validation", e); } - File.Move(TempDestinationPath, DestinationPath, true); + File.Move(_tempDestinationPath, _destinationPath, true); } - finally + catch { #if DEBUG _logger.LogWarning("Not deleting temporary file '{TempDestinationPath}' in debug mode", - TempDestinationPath); + _tempDestinationPath); #else - if (File.Exists(TempDestinationPath)) - File.Delete(TempDestinationPath); + try + { + if (File.Exists(TempDestinationPath)) + File.Delete(TempDestinationPath); + } + catch (Exception e) + { + _logger.LogError(e, "Failed to delete temporary file '{TempDestinationPath}'", _tempDestinationPath); + } #endif + throw; } } + + // _progressEventLock protects _progressUpdateTask and _pendingProgressEvent. + private readonly RaiiSemaphoreSlim _progressEventLock = new(1, 1); + private readonly CancellationTokenSource _progressUpdateCts = new(); + private Task? _progressUpdateTask; + private DownloadProgressEvent? _pendingProgressEvent; + + // Can be called multiple times, but must not be called or in progress while + // SendQueuedProgressUpdateNow is called. + private async Task QueueProgressUpdate(DownloadProgressEvent e, CancellationToken ct) + { + using var _1 = await _progressEventLock.LockAsync(ct); + _pendingProgressEvent = e; + + if (_progressUpdateCts.IsCancellationRequested) + throw new InvalidOperationException("Progress update task was cancelled, cannot queue new progress update"); + + // Start a task with a 50ms delay unless one is already running. + var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _progressUpdateCts.Token); + cts.CancelAfter(TimeSpan.FromSeconds(5)); + _progressUpdateTask ??= Task.Delay(ProgressUpdateDelayMs, cts.Token) + .ContinueWith(t => + { + cts.Cancel(); + using var _2 = _progressEventLock.Lock(); + _progressUpdateTask = null; + if (t.IsFaulted || t.IsCanceled) return; + + var ev = _pendingProgressEvent; + if (ev != null) SendProgressUpdate(ev); + }, cts.Token); + } + + // Must only be called after all QueueProgressUpdate calls have completed. + private async Task ClearQueuedProgressUpdate(CancellationToken ct) + { + Task? t; + using (var _ = _progressEventLock.LockAsync(ct)) + { + await _progressUpdateCts.CancelAsync(); + t = _progressUpdateTask; + } + + // We can't continue to hold the lock here because the continuation + // grabs a lock. We don't need to worry about a new task spawning after + // this because the token is cancelled. + if (t == null) return; + try + { + await t.WaitAsync(ct); + } + catch (TaskCanceledException) + { + // Ignore + } + } + + private void SendProgressUpdate(DownloadProgressEvent e) + { + var handler = ProgressChanged; + if (handler == null) + return; + // Start a new task in the background to invoke the event. + _ = Task.Run(() => handler.Invoke(this, e)); + } } diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs index fc014c0..cf2bb8a 100644 --- a/Vpn.Service/Manager.cs +++ b/Vpn.Service/Manager.cs @@ -26,6 +26,10 @@ public interface IManager : IDisposable /// public class Manager : IManager { + // We scale the download progress to 0.00-0.90, and use 0.90-1.00 for the + // remainder of startup. + private const double DownloadProgressScale = 0.90; + private readonly ManagerConfig _config; private readonly IDownloader _downloader; private readonly ILogger _logger; @@ -131,6 +135,8 @@ private async ValueTask HandleClientMessageStart(ClientMessage me { try { + await BroadcastStartProgress(0.0, "Starting Coder Connect...", ct); + var serverVersion = await CheckServerVersionAndCredentials(message.Start.CoderUrl, message.Start.ApiToken, ct); if (_status == TunnelStatus.Started && _lastStartRequest != null && @@ -151,10 +157,14 @@ private async ValueTask HandleClientMessageStart(ClientMessage me _lastServerVersion = serverVersion; // TODO: each section of this operation needs a timeout + // Stop the tunnel if it's running so we don't have to worry about // permissions issues when replacing the binary. await _tunnelSupervisor.StopAsync(ct); + await DownloadTunnelBinaryAsync(message.Start.CoderUrl, serverVersion.SemVersion, ct); + + await BroadcastStartProgress(DownloadProgressScale, "Starting Coder Connect...", ct); await _tunnelSupervisor.StartAsync(_config.TunnelBinaryPath, HandleTunnelRpcMessage, HandleTunnelRpcError, ct); @@ -237,6 +247,9 @@ private void HandleTunnelRpcMessage(ReplyableRpcMessage CurrentStatus(CancellationToken ct = default) private async Task BroadcastStatus(TunnelStatus? newStatus = null, CancellationToken ct = default) { if (newStatus != null) _status = newStatus.Value; - await _managerRpc.BroadcastAsync(new ServiceMessage + await FallibleBroadcast(new ServiceMessage { Status = await CurrentStatus(ct), }, ct); } + private async Task FallibleBroadcast(ServiceMessage message, CancellationToken ct = default) + { + // Broadcast the messages out with a low timeout. If clients don't + // receive broadcasts in time, it's not a big deal. + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); + //cts.CancelAfter(TimeSpan.FromMilliseconds(100)); + try + { + await _managerRpc.BroadcastAsync(message, cts.Token); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Could not broadcast low priority message to all RPC clients: {Message}", message); + } + } + private void HandleTunnelRpcError(Exception e) { _logger.LogError(e, "Manager<->Tunnel RPC error"); @@ -427,10 +456,44 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected var downloadTask = await _downloader.StartDownloadAsync(req, _config.TunnelBinaryPath, validators, ct); - // TODO: monitor and report progress when we have a mechanism to do so + var progressLock = new RaiiSemaphoreSlim(1, 1); + var progressBroadcastCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + downloadTask.ProgressChanged += (sender, ev) => + { + using var _ = progressLock.Lock(); + if (progressBroadcastCts.IsCancellationRequested) return; + _logger.LogInformation("Download progress: {ev}", ev); + + // Scale the progress value to be between 0.00 and 0.90. + var progress = ev.Progress * DownloadProgressScale ?? 0.0; + var message = $"Downloading Coder Connect binary...\n{ev}"; + BroadcastStartProgress(progress, message, progressBroadcastCts.Token).Wait(progressBroadcastCts.Token); + }; // Awaiting this will check the checksum (via the ETag) if the file // exists, and will also validate the signature and version. await downloadTask.Task; + + // Prevent any lagging progress events from being sent. + // ReSharper disable once PossiblyMistakenUseOfCancellationToken + using (await progressLock.LockAsync(ct)) + await progressBroadcastCts.CancelAsync(); + + // We don't send a broadcast here as we immediately send one in the + // parent routine. + _logger.LogInformation("Completed downloading VPN binary"); + } + + private async Task BroadcastStartProgress(double progress, string message, CancellationToken ct = default) + { + _logger.LogInformation("Start progress: {Progress:0%} - {Message}", progress, message); + await FallibleBroadcast(new ServiceMessage + { + StartProgress = new StartProgress + { + Progress = progress, + Message = message, + }, + }, ct); } } diff --git a/Vpn.Service/ManagerRpc.cs b/Vpn.Service/ManagerRpc.cs index c23752f..d922caf 100644 --- a/Vpn.Service/ManagerRpc.cs +++ b/Vpn.Service/ManagerRpc.cs @@ -133,7 +133,7 @@ public async Task BroadcastAsync(ServiceMessage message, CancellationToken ct) try { var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); - cts.CancelAfter(5 * 1000); + cts.CancelAfter(TimeSpan.FromSeconds(2)); await client.Speaker.SendMessage(message, cts.Token); } catch (ObjectDisposedException) diff --git a/Vpn.Service/Program.cs b/Vpn.Service/Program.cs index fc61247..094875d 100644 --- a/Vpn.Service/Program.cs +++ b/Vpn.Service/Program.cs @@ -16,10 +16,12 @@ public static class Program #if !DEBUG private const string ServiceName = "Coder Desktop"; private const string ConfigSubKey = @"SOFTWARE\Coder Desktop\VpnService"; + private const string DefaultLogLevel = "Information"; #else // This value matches Create-Service.ps1. private const string ServiceName = "Coder Desktop (Debug)"; private const string ConfigSubKey = @"SOFTWARE\Coder Desktop\DebugVpnService"; + private const string DefaultLogLevel = "Debug"; #endif private const string ManagerConfigSection = "Manager"; @@ -81,6 +83,10 @@ private static async Task BuildAndRun(string[] args) builder.Services.AddSingleton(); // Services + builder.Services.AddHostedService(); + builder.Services.AddHostedService(); + + // Either run as a Windows service or a console application if (!Environment.UserInteractive) { MainLogger.Information("Running as a windows service"); @@ -91,9 +97,6 @@ private static async Task BuildAndRun(string[] args) MainLogger.Information("Running as a console application"); } - builder.Services.AddHostedService(); - builder.Services.AddHostedService(); - var host = builder.Build(); Log.Logger = (ILogger)host.Services.GetService(typeof(ILogger))!; MainLogger.Information("Application is starting"); @@ -108,7 +111,7 @@ private static void AddDefaultConfig(IConfigurationBuilder builder) ["Serilog:Using:0"] = "Serilog.Sinks.File", ["Serilog:Using:1"] = "Serilog.Sinks.Console", - ["Serilog:MinimumLevel"] = "Information", + ["Serilog:MinimumLevel"] = DefaultLogLevel, ["Serilog:Enrich:0"] = "FromLogContext", ["Serilog:WriteTo:0:Name"] = "File", diff --git a/Vpn.Service/TunnelSupervisor.cs b/Vpn.Service/TunnelSupervisor.cs index a323cac..6ff4f3b 100644 --- a/Vpn.Service/TunnelSupervisor.cs +++ b/Vpn.Service/TunnelSupervisor.cs @@ -100,17 +100,15 @@ public async Task StartAsync(string binPath, }; // TODO: maybe we should change the log format in the inner binary // to something without a timestamp - var outLogger = Log.ForContext("SourceContext", "coder-vpn.exe[OUT]"); - var errLogger = Log.ForContext("SourceContext", "coder-vpn.exe[ERR]"); _subprocess.OutputDataReceived += (_, args) => { if (!string.IsNullOrWhiteSpace(args.Data)) - outLogger.Debug("{Data}", args.Data); + _logger.LogDebug("stdout: {Data}", args.Data); }; _subprocess.ErrorDataReceived += (_, args) => { if (!string.IsNullOrWhiteSpace(args.Data)) - errLogger.Debug("{Data}", args.Data); + _logger.LogDebug("stderr: {Data}", args.Data); }; // Pass the other end of the pipes to the subprocess and dispose diff --git a/Vpn/Speaker.cs b/Vpn/Speaker.cs index d113a50..37ec554 100644 --- a/Vpn/Speaker.cs +++ b/Vpn/Speaker.cs @@ -123,7 +123,7 @@ public async Task StartAsync(CancellationToken ct = default) // Handshakes should always finish quickly, so enforce a 5s timeout. using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); cts.CancelAfter(TimeSpan.FromSeconds(5)); - await PerformHandshake(ct); + await PerformHandshake(cts.Token); // Start ReceiveLoop in the background. _receiveTask = ReceiveLoop(_cts.Token); From cd966bea1eb2a3b3cc07834e2384423aa597397a Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Mon, 2 Jun 2025 15:09:48 +1000 Subject: [PATCH 2/5] fixup! feat: add vpn start progress --- Vpn.Service/Downloader.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index 856a637..4e7e5b2 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -603,8 +603,8 @@ await QueueProgressUpdate(new DownloadProgressEvent #else try { - if (File.Exists(TempDestinationPath)) - File.Delete(TempDestinationPath); + if (File.Exists(_tempDestinationPath)) + File.Delete(_tempDestinationPath); } catch (Exception e) { From fb46593e59264f1ada25239d18d78c45eade04fd Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 6 Jun 2025 15:22:19 +1000 Subject: [PATCH 3/5] change to enums --- App/Models/RpcModel.cs | 135 +++++++++++++++++- .../PublishProfiles/win-arm64.pubxml | 12 ++ App/Properties/PublishProfiles/win-x64.pubxml | 12 ++ App/Properties/PublishProfiles/win-x86.pubxml | 12 ++ App/Services/RpcController.cs | 9 +- App/ViewModels/TrayWindowViewModel.cs | 5 +- Tests.Vpn.Service/DownloaderTest.cs | 4 +- Vpn.Proto/vpn.proto | 13 +- Vpn.Service/Downloader.cs | 18 ++- Vpn.Service/Manager.cs | 30 ++-- 10 files changed, 207 insertions(+), 43 deletions(-) create mode 100644 App/Properties/PublishProfiles/win-arm64.pubxml create mode 100644 App/Properties/PublishProfiles/win-x64.pubxml create mode 100644 App/Properties/PublishProfiles/win-x86.pubxml diff --git a/App/Models/RpcModel.cs b/App/Models/RpcModel.cs index 33b4647..426863b 100644 --- a/App/Models/RpcModel.cs +++ b/App/Models/RpcModel.cs @@ -1,4 +1,7 @@ +using System; using System.Collections.Generic; +using System.Diagnostics; +using Coder.Desktop.App.Converters; using Coder.Desktop.Vpn.Proto; namespace Coder.Desktop.App.Models; @@ -19,17 +22,141 @@ public enum VpnLifecycle Stopping, } +public enum VpnStartupStage +{ + Unknown, + Initializing, + Downloading, + Finalizing, +} + +public class VpnDownloadProgress +{ + public ulong BytesWritten { get; set; } = 0; + public ulong? BytesTotal { get; set; } = null; // null means unknown total size + + public double Progress + { + get + { + if (BytesTotal is > 0) + { + return (double)BytesWritten / BytesTotal.Value; + } + return 0.0; + } + } + + public override string ToString() + { + // TODO: it would be nice if the two suffixes could match + var s = FriendlyByteConverter.FriendlyBytes(BytesWritten); + if (BytesTotal != null) + s += $" of {FriendlyByteConverter.FriendlyBytes(BytesTotal.Value)}"; + else + s += " of unknown"; + if (BytesTotal != null) + s += $" ({Progress:0%})"; + return s; + } + + public VpnDownloadProgress Clone() + { + return new VpnDownloadProgress + { + BytesWritten = BytesWritten, + BytesTotal = BytesTotal, + }; + } + + public static VpnDownloadProgress FromProto(StartProgressDownloadProgress proto) + { + return new VpnDownloadProgress + { + BytesWritten = proto.BytesWritten, + BytesTotal = proto.HasBytesTotal ? proto.BytesTotal : null, + }; + } +} + public class VpnStartupProgress { - public double Progress { get; set; } = 0.0; // 0.0 to 1.0 - public string Message { get; set; } = string.Empty; + public const string DefaultStartProgressMessage = "Starting Coder Connect..."; + + // Scale the download progress to an overall progress value between these + // numbers. + private const double DownloadProgressMin = 0.05; + private const double DownloadProgressMax = 0.80; + + public VpnStartupStage Stage { get; set; } = VpnStartupStage.Unknown; + public VpnDownloadProgress? DownloadProgress { get; set; } = null; + + // 0.0 to 1.0 + public double Progress + { + get + { + switch (Stage) + { + case VpnStartupStage.Unknown: + case VpnStartupStage.Initializing: + return 0.0; + case VpnStartupStage.Downloading: + var progress = DownloadProgress?.Progress ?? 0.0; + return DownloadProgressMin + (DownloadProgressMax - DownloadProgressMin) * progress; + case VpnStartupStage.Finalizing: + return DownloadProgressMax; + default: + throw new ArgumentOutOfRangeException(); + } + } + } + + public override string ToString() + { + switch (Stage) + { + case VpnStartupStage.Unknown: + case VpnStartupStage.Initializing: + return DefaultStartProgressMessage; + case VpnStartupStage.Downloading: + var s = "Downloading Coder Connect binary..."; + if (DownloadProgress is not null) + { + s += "\n" + DownloadProgress; + } + + return s; + case VpnStartupStage.Finalizing: + return "Finalizing Coder Connect startup..."; + default: + throw new ArgumentOutOfRangeException(); + } + } public VpnStartupProgress Clone() { return new VpnStartupProgress { - Progress = Progress, - Message = Message, + Stage = Stage, + DownloadProgress = DownloadProgress?.Clone(), + }; + } + + public static VpnStartupProgress FromProto(StartProgress proto) + { + return new VpnStartupProgress + { + Stage = proto.Stage switch + { + StartProgressStage.Initializing => VpnStartupStage.Initializing, + StartProgressStage.Downloading => VpnStartupStage.Downloading, + StartProgressStage.Finalizing => VpnStartupStage.Finalizing, + _ => VpnStartupStage.Unknown, + }, + DownloadProgress = proto.Stage is StartProgressStage.Downloading ? + VpnDownloadProgress.FromProto(proto.DownloadProgress) : + null, }; } } diff --git a/App/Properties/PublishProfiles/win-arm64.pubxml b/App/Properties/PublishProfiles/win-arm64.pubxml new file mode 100644 index 0000000..ac9753e --- /dev/null +++ b/App/Properties/PublishProfiles/win-arm64.pubxml @@ -0,0 +1,12 @@ + + + + + FileSystem + ARM64 + win-arm64 + bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\ + + diff --git a/App/Properties/PublishProfiles/win-x64.pubxml b/App/Properties/PublishProfiles/win-x64.pubxml new file mode 100644 index 0000000..942523b --- /dev/null +++ b/App/Properties/PublishProfiles/win-x64.pubxml @@ -0,0 +1,12 @@ + + + + + FileSystem + x64 + win-x64 + bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\ + + diff --git a/App/Properties/PublishProfiles/win-x86.pubxml b/App/Properties/PublishProfiles/win-x86.pubxml new file mode 100644 index 0000000..e763481 --- /dev/null +++ b/App/Properties/PublishProfiles/win-x86.pubxml @@ -0,0 +1,12 @@ + + + + + FileSystem + x86 + win-x86 + bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\ + + diff --git a/App/Services/RpcController.cs b/App/Services/RpcController.cs index b42c058..3345050 100644 --- a/App/Services/RpcController.cs +++ b/App/Services/RpcController.cs @@ -164,8 +164,7 @@ public async Task StartVpn(CancellationToken ct = default) MutateState(state => { state.VpnLifecycle = VpnLifecycle.Starting; - // Explicitly clear the startup progress. - state.VpnStartupProgress = null; + state.VpnStartupProgress = new VpnStartupProgress(); }); ServiceMessage reply; @@ -297,11 +296,7 @@ private void ApplyStartProgressUpdate(StartProgress message) { // MutateState will undo these changes if it doesn't believe we're // in the "Starting" state. - state.VpnStartupProgress = new VpnStartupProgress - { - Progress = message.Progress, - Message = message.Message, - }; + state.VpnStartupProgress = VpnStartupProgress.FromProto(message); }); } diff --git a/App/ViewModels/TrayWindowViewModel.cs b/App/ViewModels/TrayWindowViewModel.cs index cd3a641..820ff12 100644 --- a/App/ViewModels/TrayWindowViewModel.cs +++ b/App/ViewModels/TrayWindowViewModel.cs @@ -29,7 +29,6 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost { private const int MaxAgents = 5; private const string DefaultDashboardUrl = "https://coder.com"; - private const string DefaultStartProgressMessage = "Starting Coder Connect..."; private readonly IServiceProvider _services; private readonly IRpcController _rpcController; @@ -84,7 +83,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost public partial string? VpnStartProgressMessage { get; set; } = null; public string VpnStartProgressMessageOrDefault => - string.IsNullOrEmpty(VpnStartProgressMessage) ? DefaultStartProgressMessage : VpnStartProgressMessage; + string.IsNullOrEmpty(VpnStartProgressMessage) ? VpnStartupProgress.DefaultStartProgressMessage : VpnStartProgressMessage; public bool VpnStartProgressIsIndeterminate => VpnStartProgressValueOrDefault == 0; @@ -196,7 +195,7 @@ private void UpdateFromRpcModel(RpcModel rpcModel) // Convert 0.00-1.00 to 0-100. var progress = (int)(rpcModel.VpnStartupProgress.Progress * 100); VpnStartProgressValue = Math.Clamp(progress, 0, 100); - VpnStartProgressMessage = string.IsNullOrEmpty(rpcModel.VpnStartupProgress.Message) ? null : rpcModel.VpnStartupProgress.Message; + VpnStartProgressMessage = rpcModel.VpnStartupProgress.ToString(); } else { diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs index a47ffbc..b33f510 100644 --- a/Tests.Vpn.Service/DownloaderTest.cs +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -344,11 +344,11 @@ public async Task DownloadWithXOriginalContentLength(CancellationToken ct) Assert.That(list.Count, Is.GreaterThanOrEqualTo(2)); // there may be an item in the middle // The first item should be the initial progress with 0 bytes written. Assert.That(list[0].BytesWritten, Is.EqualTo(0)); - Assert.That(list[0].TotalBytes, Is.EqualTo(6)); // from X-Original-Content-Length + Assert.That(list[0].BytesTotal, Is.EqualTo(6)); // from X-Original-Content-Length Assert.That(list[0].Progress, Is.EqualTo(0.0d)); // The last item should be final progress with the actual total bytes. Assert.That(list[^1].BytesWritten, Is.EqualTo(4)); - Assert.That(list[^1].TotalBytes, Is.EqualTo(4)); // from the actual bytes written + Assert.That(list[^1].BytesTotal, Is.EqualTo(4)); // from the actual bytes written Assert.That(list[^1].Progress, Is.EqualTo(1.0d)); } diff --git a/Vpn.Proto/vpn.proto b/Vpn.Proto/vpn.proto index fa2f003..bace7e0 100644 --- a/Vpn.Proto/vpn.proto +++ b/Vpn.Proto/vpn.proto @@ -227,9 +227,18 @@ message StartResponse { // inability to easily send messages to a specific client in the Speaker // implementation. If clients are not expecting these messages, they // should ignore them. +enum StartProgressStage { + Initializing = 0; + Downloading = 1; + Finalizing = 2; +} +message StartProgressDownloadProgress { + uint64 bytes_written = 1; + optional uint64 bytes_total = 2; // unknown in some situations +} message StartProgress { - double progress = 1; // 0.0 to 1.0 - string message = 2; // human-readable status message, must be set + StartProgressStage stage = 1; + optional StartProgressDownloadProgress download_progress = 2; // only set when stage == Downloading } // StopRequest is a request from the manager to stop the tunnel. The tunnel replies with a diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index 4e7e5b2..a665ec4 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -342,14 +342,15 @@ public class DownloadProgressEvent { // TODO: speed calculation would be nice public ulong BytesWritten { get; init; } - public ulong? TotalBytes { get; init; } // null if unknown - public double? Progress { get; init; } // 0.0 - 1.0, null if unknown + public ulong? BytesTotal { get; init; } // null if unknown + + public double? Progress => BytesTotal == null ? null : (double)BytesWritten / BytesTotal.Value; public override string ToString() { var s = FriendlyBytes(BytesWritten); - if (TotalBytes != null) - s += $" of {FriendlyBytes(TotalBytes.Value)}"; + if (BytesTotal != null) + s += $" of {FriendlyBytes(BytesTotal.Value)}"; else s += " of unknown"; if (Progress != null) @@ -513,8 +514,7 @@ private async Task Start(CancellationToken ct = default) SendProgressUpdate(new DownloadProgressEvent { BytesWritten = 0, - TotalBytes = TotalBytes, - Progress = 0.0, + BytesTotal = TotalBytes, }); await Download(res, ct); @@ -549,8 +549,7 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct) await QueueProgressUpdate(new DownloadProgressEvent { BytesWritten = BytesWritten, - TotalBytes = TotalBytes, - Progress = Progress, + BytesTotal = TotalBytes, }, ct); } } @@ -563,8 +562,7 @@ await QueueProgressUpdate(new DownloadProgressEvent SendProgressUpdate(new DownloadProgressEvent { BytesWritten = BytesWritten, - TotalBytes = BytesWritten, - Progress = 1.0, + BytesTotal = BytesWritten, }); if (TotalBytes != null && BytesWritten != TotalBytes) diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs index cf2bb8a..0324ebb 100644 --- a/Vpn.Service/Manager.cs +++ b/Vpn.Service/Manager.cs @@ -26,10 +26,6 @@ public interface IManager : IDisposable /// public class Manager : IManager { - // We scale the download progress to 0.00-0.90, and use 0.90-1.00 for the - // remainder of startup. - private const double DownloadProgressScale = 0.90; - private readonly ManagerConfig _config; private readonly IDownloader _downloader; private readonly ILogger _logger; @@ -135,7 +131,7 @@ private async ValueTask HandleClientMessageStart(ClientMessage me { try { - await BroadcastStartProgress(0.0, "Starting Coder Connect...", ct); + await BroadcastStartProgress(StartProgressStage.Initializing, cancellationToken: ct); var serverVersion = await CheckServerVersionAndCredentials(message.Start.CoderUrl, message.Start.ApiToken, ct); @@ -164,7 +160,7 @@ private async ValueTask HandleClientMessageStart(ClientMessage me await DownloadTunnelBinaryAsync(message.Start.CoderUrl, serverVersion.SemVersion, ct); - await BroadcastStartProgress(DownloadProgressScale, "Starting Coder Connect...", ct); + await BroadcastStartProgress(StartProgressStage.Finalizing, cancellationToken: ct); await _tunnelSupervisor.StartAsync(_config.TunnelBinaryPath, HandleTunnelRpcMessage, HandleTunnelRpcError, ct); @@ -464,10 +460,14 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected if (progressBroadcastCts.IsCancellationRequested) return; _logger.LogInformation("Download progress: {ev}", ev); - // Scale the progress value to be between 0.00 and 0.90. - var progress = ev.Progress * DownloadProgressScale ?? 0.0; - var message = $"Downloading Coder Connect binary...\n{ev}"; - BroadcastStartProgress(progress, message, progressBroadcastCts.Token).Wait(progressBroadcastCts.Token); + var progress = new StartProgressDownloadProgress + { + BytesWritten = ev.BytesWritten, + }; + if (ev.BytesTotal != null) + progress.BytesTotal = ev.BytesTotal.Value; + BroadcastStartProgress(StartProgressStage.Downloading, progress, progressBroadcastCts.Token) + .Wait(progressBroadcastCts.Token); }; // Awaiting this will check the checksum (via the ETag) if the file @@ -484,16 +484,16 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected _logger.LogInformation("Completed downloading VPN binary"); } - private async Task BroadcastStartProgress(double progress, string message, CancellationToken ct = default) + private async Task BroadcastStartProgress(StartProgressStage stage, StartProgressDownloadProgress? downloadProgress = null, CancellationToken cancellationToken = default) { - _logger.LogInformation("Start progress: {Progress:0%} - {Message}", progress, message); + _logger.LogInformation("Start progress: {stage}", stage); await FallibleBroadcast(new ServiceMessage { StartProgress = new StartProgress { - Progress = progress, - Message = message, + Stage = stage, + DownloadProgress = downloadProgress, }, - }, ct); + }, cancellationToken); } } From 473164dca3c7830f43473438d77d8c26aacaf476 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 6 Jun 2025 15:59:21 +1000 Subject: [PATCH 4/5] rework download progress --- Tests.Vpn.Service/DownloaderTest.cs | 50 ++++----- Vpn.Service/Downloader.cs | 151 ++-------------------------- Vpn.Service/Manager.cs | 51 ++++++---- 3 files changed, 68 insertions(+), 184 deletions(-) diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs index b33f510..bb9b39c 100644 --- a/Tests.Vpn.Service/DownloaderTest.cs +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -2,7 +2,6 @@ using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Text; -using System.Threading.Channels; using Coder.Desktop.Vpn.Service; using Microsoft.Extensions.Logging.Abstractions; @@ -278,7 +277,7 @@ public async Task Download(CancellationToken ct) var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, NullDownloadValidator.Instance, ct); await dlTask.Task; - Assert.That(dlTask.TotalBytes, Is.EqualTo(4)); + Assert.That(dlTask.BytesTotal, Is.EqualTo(4)); Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); Assert.That(dlTask.Progress, Is.EqualTo(1)); Assert.That(dlTask.IsCompleted, Is.True); @@ -301,13 +300,13 @@ public async Task DownloadSameDest(CancellationToken ct) NullDownloadValidator.Instance, ct); var dlTask0 = await startTask0; await dlTask0.Task; - Assert.That(dlTask0.TotalBytes, Is.EqualTo(5)); + Assert.That(dlTask0.BytesTotal, Is.EqualTo(5)); Assert.That(dlTask0.BytesWritten, Is.EqualTo(5)); Assert.That(dlTask0.Progress, Is.EqualTo(1)); Assert.That(dlTask0.IsCompleted, Is.True); var dlTask1 = await startTask1; await dlTask1.Task; - Assert.That(dlTask1.TotalBytes, Is.EqualTo(5)); + Assert.That(dlTask1.BytesTotal, Is.EqualTo(5)); Assert.That(dlTask1.BytesWritten, Is.EqualTo(5)); Assert.That(dlTask1.Progress, Is.EqualTo(1)); Assert.That(dlTask1.IsCompleted, Is.True); @@ -320,9 +319,9 @@ public async Task DownloadWithXOriginalContentLength(CancellationToken ct) using var httpServer = new TestHttpServer(async ctx => { ctx.Response.StatusCode = 200; - ctx.Response.Headers.Add("X-Original-Content-Length", "6"); // wrong but should be used until complete + ctx.Response.Headers.Add("X-Original-Content-Length", "4"); ctx.Response.ContentType = "text/plain"; - ctx.Response.ContentLength64 = 4; // This should be ignored. + // Don't set Content-Length. await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct); }); var url = new Uri(httpServer.BaseUrl + "/test"); @@ -331,25 +330,30 @@ public async Task DownloadWithXOriginalContentLength(CancellationToken ct) var req = new HttpRequestMessage(HttpMethod.Get, url); var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct); - var progressChannel = Channel.CreateUnbounded(); - dlTask.ProgressChanged += (_, args) => - Assert.That(progressChannel.Writer.TryWrite(args), Is.True); - await dlTask.Task; - Assert.That(dlTask.TotalBytes, Is.EqualTo(4)); // should equal BytesWritten after completion + Assert.That(dlTask.BytesTotal, Is.EqualTo(4)); Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); - progressChannel.Writer.Complete(); - - var list = progressChannel.Reader.ReadAllAsync(ct).ToBlockingEnumerable(ct).ToList(); - Assert.That(list.Count, Is.GreaterThanOrEqualTo(2)); // there may be an item in the middle - // The first item should be the initial progress with 0 bytes written. - Assert.That(list[0].BytesWritten, Is.EqualTo(0)); - Assert.That(list[0].BytesTotal, Is.EqualTo(6)); // from X-Original-Content-Length - Assert.That(list[0].Progress, Is.EqualTo(0.0d)); - // The last item should be final progress with the actual total bytes. - Assert.That(list[^1].BytesWritten, Is.EqualTo(4)); - Assert.That(list[^1].BytesTotal, Is.EqualTo(4)); // from the actual bytes written - Assert.That(list[^1].Progress, Is.EqualTo(1.0d)); + } + + [Test(Description = "Download with mismatched Content-Length")] + [CancelAfter(30_000)] + public async Task DownloadWithMismatchedContentLength(CancellationToken ct) + { + using var httpServer = new TestHttpServer(async ctx => + { + ctx.Response.StatusCode = 200; + ctx.Response.Headers.Add("X-Original-Content-Length", "5"); // incorrect + ctx.Response.ContentType = "text/plain"; + await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct); + }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + var manager = new Downloader(NullLogger.Instance); + var req = new HttpRequestMessage(HttpMethod.Get, url); + var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct); + + var ex = Assert.ThrowsAsync(() => dlTask.Task); + Assert.That(ex.Message, Is.EqualTo("Downloaded file size does not match expected response content length: Expected=5, BytesWritten=4")); } [Test(Description = "Download with custom headers")] diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index a665ec4..c4a916f 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -338,42 +338,6 @@ internal static async Task TaskOrCancellation(Task task, CancellationToken cance } } -public class DownloadProgressEvent -{ - // TODO: speed calculation would be nice - public ulong BytesWritten { get; init; } - public ulong? BytesTotal { get; init; } // null if unknown - - public double? Progress => BytesTotal == null ? null : (double)BytesWritten / BytesTotal.Value; - - public override string ToString() - { - var s = FriendlyBytes(BytesWritten); - if (BytesTotal != null) - s += $" of {FriendlyBytes(BytesTotal.Value)}"; - else - s += " of unknown"; - if (Progress != null) - s += $" ({Progress:0%})"; - return s; - } - - private static readonly string[] ByteSuffixes = ["B", "KB", "MB", "GB", "TB", "PB", "EB"]; - - // Unfortunately this is copied from FriendlyByteConverter in App. Ideally - // it should go into some shared utilities project, but it's overkill to do - // that for a single tiny function until we have more shared code. - private static string FriendlyBytes(ulong bytes) - { - if (bytes == 0) - return $"0 {ByteSuffixes[0]}"; - - var place = Convert.ToInt32(Math.Floor(Math.Log(bytes, 1024))); - var num = Math.Round(bytes / Math.Pow(1024, place), 1); - return $"{num} {ByteSuffixes[place]}"; - } -} - /// /// Downloads a Url to a file on disk. The download will be written to a temporary file first, then moved to the final /// destination. The SHA1 of any existing file will be calculated and used as an ETag to avoid downloading the file if @@ -381,8 +345,7 @@ private static string FriendlyBytes(ulong bytes) /// public class DownloadTask { - private const int BufferSize = 4096; - private const int ProgressUpdateDelayMs = 50; + private const int BufferSize = 64 * 1024; private const string XOriginalContentLengthHeader = "X-Original-Content-Length"; // overrides Content-Length if available private static readonly HttpClient HttpClient = new(new HttpClientHandler @@ -398,22 +361,13 @@ public class DownloadTask private readonly string _destinationPath; private readonly string _tempDestinationPath; - // ProgressChanged events are always delayed by up to 50ms to avoid - // flooding. - // - // This will be called: - // - once after the request succeeds but before the read/write routine - // begins - // - occasionally while the file is being downloaded (at least 50ms apart) - // - once when the download is complete - public EventHandler? ProgressChanged; - public readonly HttpRequestMessage Request; public Task Task { get; private set; } = null!; // Set in EnsureStartedAsync + public bool DownloadStarted { get; private set; } // Whether we've received headers yet and started the actual download public ulong BytesWritten { get; private set; } - public ulong? TotalBytes { get; private set; } - public double? Progress => TotalBytes == null ? null : (double)BytesWritten / TotalBytes.Value; + public ulong? BytesTotal { get; private set; } + public double? Progress => BytesTotal == null ? null : (double)BytesWritten / BytesTotal.Value; public bool IsCompleted => Task.IsCompleted; internal DownloadTask(ILogger logger, HttpRequestMessage req, string destinationPath, IDownloadValidator validator) @@ -496,7 +450,7 @@ private async Task Start(CancellationToken ct = default) } if (res.Content.Headers.ContentLength >= 0) - TotalBytes = (ulong)res.Content.Headers.ContentLength; + BytesTotal = (ulong)res.Content.Headers.ContentLength; // X-Original-Content-Length overrules Content-Length if set. if (res.Headers.TryGetValues(XOriginalContentLengthHeader, out var headerValues)) @@ -504,24 +458,19 @@ private async Task Start(CancellationToken ct = default) // If there are multiple we only look at the first one. var headerValue = headerValues.ToList().FirstOrDefault(); if (!string.IsNullOrEmpty(headerValue) && ulong.TryParse(headerValue, out var originalContentLength)) - TotalBytes = originalContentLength; + BytesTotal = originalContentLength; else _logger.LogWarning( "Failed to parse {XOriginalContentLengthHeader} header value '{HeaderValue}'", XOriginalContentLengthHeader, headerValue); } - SendProgressUpdate(new DownloadProgressEvent - { - BytesWritten = 0, - BytesTotal = TotalBytes, - }); - await Download(res, ct); } private async Task Download(HttpResponseMessage res, CancellationToken ct) { + DownloadStarted = true; try { var sha1 = res.Headers.Contains("ETag") ? SHA1.Create() : null; @@ -546,28 +495,13 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct) await tempFile.WriteAsync(buffer.AsMemory(0, n), ct); sha1?.TransformBlock(buffer, 0, n, null, 0); BytesWritten += (ulong)n; - await QueueProgressUpdate(new DownloadProgressEvent - { - BytesWritten = BytesWritten, - BytesTotal = TotalBytes, - }, ct); } } - // Clear any pending progress updates to ensure they won't be sent - // after the final update. - await ClearQueuedProgressUpdate(ct); - // Then write the final status update. - TotalBytes = BytesWritten; - SendProgressUpdate(new DownloadProgressEvent - { - BytesWritten = BytesWritten, - BytesTotal = BytesWritten, - }); - - if (TotalBytes != null && BytesWritten != TotalBytes) + BytesTotal ??= BytesWritten; + if (BytesWritten != BytesTotal) throw new IOException( - $"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesWritten}"); + $"Downloaded file size does not match expected response content length: Expected={BytesTotal}, BytesWritten={BytesWritten}"); // Verify the ETag if it was sent by the server. if (res.Headers.Contains("ETag") && sha1 != null) @@ -612,69 +546,4 @@ await QueueProgressUpdate(new DownloadProgressEvent throw; } } - - // _progressEventLock protects _progressUpdateTask and _pendingProgressEvent. - private readonly RaiiSemaphoreSlim _progressEventLock = new(1, 1); - private readonly CancellationTokenSource _progressUpdateCts = new(); - private Task? _progressUpdateTask; - private DownloadProgressEvent? _pendingProgressEvent; - - // Can be called multiple times, but must not be called or in progress while - // SendQueuedProgressUpdateNow is called. - private async Task QueueProgressUpdate(DownloadProgressEvent e, CancellationToken ct) - { - using var _1 = await _progressEventLock.LockAsync(ct); - _pendingProgressEvent = e; - - if (_progressUpdateCts.IsCancellationRequested) - throw new InvalidOperationException("Progress update task was cancelled, cannot queue new progress update"); - - // Start a task with a 50ms delay unless one is already running. - var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _progressUpdateCts.Token); - cts.CancelAfter(TimeSpan.FromSeconds(5)); - _progressUpdateTask ??= Task.Delay(ProgressUpdateDelayMs, cts.Token) - .ContinueWith(t => - { - cts.Cancel(); - using var _2 = _progressEventLock.Lock(); - _progressUpdateTask = null; - if (t.IsFaulted || t.IsCanceled) return; - - var ev = _pendingProgressEvent; - if (ev != null) SendProgressUpdate(ev); - }, cts.Token); - } - - // Must only be called after all QueueProgressUpdate calls have completed. - private async Task ClearQueuedProgressUpdate(CancellationToken ct) - { - Task? t; - using (var _ = _progressEventLock.LockAsync(ct)) - { - await _progressUpdateCts.CancelAsync(); - t = _progressUpdateTask; - } - - // We can't continue to hold the lock here because the continuation - // grabs a lock. We don't need to worry about a new task spawning after - // this because the token is cancelled. - if (t == null) return; - try - { - await t.WaitAsync(ct); - } - catch (TaskCanceledException) - { - // Ignore - } - } - - private void SendProgressUpdate(DownloadProgressEvent e) - { - var handler = ProgressChanged; - if (handler == null) - return; - // Start a new task in the background to invoke the event. - _ = Task.Run(() => handler.Invoke(this, e)); - } } diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs index 0324ebb..886bb70 100644 --- a/Vpn.Service/Manager.cs +++ b/Vpn.Service/Manager.cs @@ -450,34 +450,46 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected _logger.LogDebug("Skipping tunnel binary version validation"); } + // Note: all ETag, signature and version validation is performed by the + // DownloadTask. var downloadTask = await _downloader.StartDownloadAsync(req, _config.TunnelBinaryPath, validators, ct); - var progressLock = new RaiiSemaphoreSlim(1, 1); - var progressBroadcastCts = CancellationTokenSource.CreateLinkedTokenSource(ct); - downloadTask.ProgressChanged += (sender, ev) => + // Wait for the download to complete, sending progress updates every + // 50ms. + while (true) { - using var _ = progressLock.Lock(); - if (progressBroadcastCts.IsCancellationRequested) return; - _logger.LogInformation("Download progress: {ev}", ev); + // Wait for the download to complete, or for a short delay before + // we send a progress update. + var delayTask = Task.Delay(TimeSpan.FromMilliseconds(50), ct); + var winner = await Task.WhenAny([ + downloadTask.Task, + delayTask, + ]); + if (winner == downloadTask.Task) + break; + + // Task.WhenAny will not throw if the winner was cancelled, so + // check CT afterward and not beforehand. + ct.ThrowIfCancellationRequested(); + + if (!downloadTask.DownloadStarted) + // Don't send progress updates if we don't know what the + // progress is yet. + continue; var progress = new StartProgressDownloadProgress { - BytesWritten = ev.BytesWritten, + BytesWritten = downloadTask.BytesWritten, }; - if (ev.BytesTotal != null) - progress.BytesTotal = ev.BytesTotal.Value; - BroadcastStartProgress(StartProgressStage.Downloading, progress, progressBroadcastCts.Token) - .Wait(progressBroadcastCts.Token); - }; + if (downloadTask.BytesTotal != null) + progress.BytesTotal = downloadTask.BytesTotal.Value; - // Awaiting this will check the checksum (via the ETag) if the file - // exists, and will also validate the signature and version. - await downloadTask.Task; + await BroadcastStartProgress(StartProgressStage.Downloading, progress, ct); + } - // Prevent any lagging progress events from being sent. - // ReSharper disable once PossiblyMistakenUseOfCancellationToken - using (await progressLock.LockAsync(ct)) - await progressBroadcastCts.CancelAsync(); + // Await again to re-throw any exceptions that occurred during the + // download. + await downloadTask.Task; // We don't send a broadcast here as we immediately send one in the // parent routine. @@ -486,7 +498,6 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected private async Task BroadcastStartProgress(StartProgressStage stage, StartProgressDownloadProgress? downloadProgress = null, CancellationToken cancellationToken = default) { - _logger.LogInformation("Start progress: {stage}", stage); await FallibleBroadcast(new ServiceMessage { StartProgress = new StartProgress From 02bc40046adcfd296ca84d05338da278f4070602 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 6 Jun 2025 16:36:23 +1000 Subject: [PATCH 5/5] invariant in startup progress model --- App/Models/RpcModel.cs | 23 +++++++++++++++++++---- App/Services/RpcController.cs | 8 ++------ Vpn.Service/Manager.cs | 2 +- Vpn.Service/ManagerRpc.cs | 15 +++++++++++---- Vpn.Service/TunnelSupervisor.cs | 6 +++--- 5 files changed, 36 insertions(+), 18 deletions(-) diff --git a/App/Models/RpcModel.cs b/App/Models/RpcModel.cs index 426863b..08d2303 100644 --- a/App/Models/RpcModel.cs +++ b/App/Models/RpcModel.cs @@ -88,8 +88,8 @@ public class VpnStartupProgress private const double DownloadProgressMin = 0.05; private const double DownloadProgressMax = 0.80; - public VpnStartupStage Stage { get; set; } = VpnStartupStage.Unknown; - public VpnDownloadProgress? DownloadProgress { get; set; } = null; + public VpnStartupStage Stage { get; init; } = VpnStartupStage.Unknown; + public VpnDownloadProgress? DownloadProgress { get; init; } = null; // 0.0 to 1.0 public double Progress @@ -165,10 +165,25 @@ public class RpcModel { public RpcLifecycle RpcLifecycle { get; set; } = RpcLifecycle.Disconnected; - public VpnLifecycle VpnLifecycle { get; set; } = VpnLifecycle.Unknown; + public VpnLifecycle VpnLifecycle + { + get; + set + { + if (VpnLifecycle != value && value == VpnLifecycle.Starting) + // Reset the startup progress when the VPN lifecycle changes to + // Starting. + VpnStartupProgress = null; + field = value; + } + } // Nullable because it is only set when the VpnLifecycle is Starting - public VpnStartupProgress? VpnStartupProgress { get; set; } + public VpnStartupProgress? VpnStartupProgress + { + get => VpnLifecycle is VpnLifecycle.Starting ? field ?? new VpnStartupProgress() : null; + set; + } public IReadOnlyList Workspaces { get; set; } = []; diff --git a/App/Services/RpcController.cs b/App/Services/RpcController.cs index 3345050..168a1be 100644 --- a/App/Services/RpcController.cs +++ b/App/Services/RpcController.cs @@ -164,7 +164,6 @@ public async Task StartVpn(CancellationToken ct = default) MutateState(state => { state.VpnLifecycle = VpnLifecycle.Starting; - state.VpnStartupProgress = new VpnStartupProgress(); }); ServiceMessage reply; @@ -255,9 +254,6 @@ private void MutateState(Action mutator) using (_stateLock.Lock()) { mutator(_state); - // Unset the startup progress if the VpnLifecycle is not Starting - if (_state.VpnLifecycle != VpnLifecycle.Starting) - _state.VpnStartupProgress = null; newState = _state.Clone(); } @@ -294,8 +290,8 @@ private void ApplyStartProgressUpdate(StartProgress message) { MutateState(state => { - // MutateState will undo these changes if it doesn't believe we're - // in the "Starting" state. + // The model itself will ignore this value if we're not in the + // starting state. state.VpnStartupProgress = VpnStartupProgress.FromProto(message); }); } diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs index 886bb70..fdb62af 100644 --- a/Vpn.Service/Manager.cs +++ b/Vpn.Service/Manager.cs @@ -331,7 +331,7 @@ private async Task FallibleBroadcast(ServiceMessage message, CancellationToken c // Broadcast the messages out with a low timeout. If clients don't // receive broadcasts in time, it's not a big deal. using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); - //cts.CancelAfter(TimeSpan.FromMilliseconds(100)); + cts.CancelAfter(TimeSpan.FromMilliseconds(30)); try { await _managerRpc.BroadcastAsync(message, cts.Token); diff --git a/Vpn.Service/ManagerRpc.cs b/Vpn.Service/ManagerRpc.cs index d922caf..4920570 100644 --- a/Vpn.Service/ManagerRpc.cs +++ b/Vpn.Service/ManagerRpc.cs @@ -127,14 +127,20 @@ public async Task ExecuteAsync(CancellationToken stoppingToken) public async Task BroadcastAsync(ServiceMessage message, CancellationToken ct) { + // Sends messages to all clients simultaneously and waits for them all + // to send or fail/timeout. + // // Looping over a ConcurrentDictionary is exception-safe, but any items // added or removed during the loop may or may not be included. - foreach (var (clientId, client) in _activeClients) + await Task.WhenAll(_activeClients.Select(async item => + { try { - var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); + // Enforce upper bound in case a CT with a timeout wasn't + // supplied. + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); cts.CancelAfter(TimeSpan.FromSeconds(2)); - await client.Speaker.SendMessage(message, cts.Token); + await item.Value.Speaker.SendMessage(message, cts.Token); } catch (ObjectDisposedException) { @@ -142,11 +148,12 @@ public async Task BroadcastAsync(ServiceMessage message, CancellationToken ct) } catch (Exception e) { - _logger.LogWarning(e, "Failed to send message to client {ClientId}", clientId); + _logger.LogWarning(e, "Failed to send message to client {ClientId}", item.Key); // TODO: this should probably kill the client, but due to the // async nature of the client handling, calling Dispose // will not remove the client from the active clients list } + })); } private async Task HandleRpcClientAsync(ulong clientId, Speaker speaker, diff --git a/Vpn.Service/TunnelSupervisor.cs b/Vpn.Service/TunnelSupervisor.cs index 6ff4f3b..7dd6738 100644 --- a/Vpn.Service/TunnelSupervisor.cs +++ b/Vpn.Service/TunnelSupervisor.cs @@ -99,16 +99,16 @@ public async Task StartAsync(string binPath, }, }; // TODO: maybe we should change the log format in the inner binary - // to something without a timestamp + // to something without a timestamp _subprocess.OutputDataReceived += (_, args) => { if (!string.IsNullOrWhiteSpace(args.Data)) - _logger.LogDebug("stdout: {Data}", args.Data); + _logger.LogInformation("stdout: {Data}", args.Data); }; _subprocess.ErrorDataReceived += (_, args) => { if (!string.IsNullOrWhiteSpace(args.Data)) - _logger.LogDebug("stderr: {Data}", args.Data); + _logger.LogInformation("stderr: {Data}", args.Data); }; // Pass the other end of the pipes to the subprocess and dispose pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy