diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index 5bee0d3f07065..3e09ed02a5fec 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -380,7 +380,7 @@ func (server *provisionerdServer) UpdateJob(ctx context.Context, request *proto. return nil, xerrors.Errorf("insert job logs: %w", err) } server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID)) - data, err := json.Marshal(logs) + data, err := json.Marshal(provisionerJobLogsMessage{Logs: logs}) if err != nil { return nil, xerrors.Errorf("marshal job log: %w", err) } @@ -549,6 +549,16 @@ func (server *provisionerdServer) FailJob(ctx context.Context, failJob *proto.Fa } case *proto.FailedJob_TemplateImport_: } + + data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true}) + if err != nil { + return nil, xerrors.Errorf("marshal job log: %w", err) + } + err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data) + if err != nil { + server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err)) + return nil, xerrors.Errorf("publish end of job logs: %w", err) + } return &proto.Empty{}, nil } @@ -711,6 +721,16 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *pr reflect.TypeOf(completed.Type).String()) } + data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true}) + if err != nil { + return nil, xerrors.Errorf("marshal job log: %w", err) + } + err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data) + if err != nil { + server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err)) + return nil, xerrors.Errorf("publish end of job logs: %w", err) + } + server.Logger.Debug(ctx, "CompleteJob done", slog.F("job_id", jobID)) return &proto.Empty{}, nil } diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 8b163412f0ff4..97aafc95909d1 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -28,6 +28,7 @@ import ( // The combination of these responses should provide all current logs // to the consumer, and future logs are streamed in the follow request. func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) { + logger := api.Logger.With(slog.F("job_id", job.ID)) follow := r.URL.Query().Has("follow") afterRaw := r.URL.Query().Get("after") beforeRaw := r.URL.Query().Get("before") @@ -38,6 +39,37 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job return } + // if we are following logs, start the subscription before we query the database, so that we don't miss any logs + // between the end of our query and the start of the subscription. We might get duplicates, so we'll keep track + // of processed IDs. + var bufferedLogs <-chan database.ProvisionerJobLog + if follow { + bl, closeFollow, err := api.followLogs(job.ID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: "Internal error watching provisioner logs.", + Detail: err.Error(), + }) + return + } + defer closeFollow() + bufferedLogs = bl + + // Next query the job itself to see if it is complete. If so, the historical query to the database will return + // the full set of logs. 
It's a little sad to have to query the job again, given that our caller definitely + // has, but we need to query it *after* we start following the pubsub to avoid a race condition where the job + // completes between the prior query and the start of following the pubsub. A more substantial refactor could + // avoid this, but not worth it for one fewer query at this point. + job, err = api.Database.GetProvisionerJobByID(r.Context(), job.ID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: "Internal error querying job.", + Detail: err.Error(), + }) + return + } + } + var after time.Time // Only fetch logs created after the time provided. if afterRaw != "" { @@ -78,26 +110,27 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job } } - if !follow { - logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{ - JobID: job.ID, - CreatedAfter: after, - CreatedBefore: before, + logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{ + JobID: job.ID, + CreatedAfter: after, + CreatedBefore: before, + }) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: "Internal error fetching provisioner logs.", + Detail: err.Error(), }) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: "Internal error fetching provisioner logs.", - Detail: err.Error(), - }) - return - } - if logs == nil { - logs = []database.ProvisionerJobLog{} - } + return + } + if logs == nil { + logs = []database.ProvisionerJobLog{} + } + if !follow { + logger.Debug(r.Context(), "Finished non-follow job logs") httpapi.Write(rw, http.StatusOK, convertProvisionerJobLogs(logs)) return } @@ -118,82 +151,43 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageText) defer wsNetConn.Close() // Also closes conn. - bufferedLogs := make(chan database.ProvisionerJobLog, 128) - closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(job.ID), func(ctx context.Context, message []byte) { - var logs []database.ProvisionerJobLog - err := json.Unmarshal(message, &logs) - if err != nil { - api.Logger.Warn(ctx, fmt.Sprintf("invalid provisioner job log on channel %q: %s", provisionerJobLogsChannel(job.ID), err.Error())) - return - } - - for _, log := range logs { - select { - case bufferedLogs <- log: - api.Logger.Debug(r.Context(), "subscribe buffered log", slog.F("job_id", job.ID), slog.F("stage", log.Stage)) - default: - // If this overflows users could miss logs streaming. This can happen - // if a database request takes a long amount of time, and we get a lot of logs. 
- api.Logger.Warn(ctx, "provisioner job log overflowing channel") - } - } - }) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: "Internal error watching provisioner logs.", - Detail: err.Error(), - }) - return - } - defer closeSubscribe() - - provisionerJobLogs, err := api.Database.GetProvisionerLogsByIDBetween(ctx, database.GetProvisionerLogsByIDBetweenParams{ - JobID: job.ID, - CreatedAfter: after, - CreatedBefore: before, - }) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: "Internal error fetching provisioner logs.", - Detail: err.Error(), - }) - return - } + logIdsDone := make(map[uuid.UUID]bool) // The Go stdlib JSON encoder appends a newline character after message write. encoder := json.NewEncoder(wsNetConn) - for _, provisionerJobLog := range provisionerJobLogs { + for _, provisionerJobLog := range logs { + logIdsDone[provisionerJobLog.ID] = true err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog)) if err != nil { return } } + if job.CompletedAt.Valid { + // job was complete before we queried the database for historical logs, meaning we got everything. No need + // to stream anything from the bufferedLogs. + return + } - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() for { select { - case <-r.Context().Done(): - api.Logger.Debug(context.Background(), "job logs context canceled", slog.F("job_id", job.ID)) + case <-ctx.Done(): + logger.Debug(context.Background(), "job logs context canceled") return - case log := <-bufferedLogs: - api.Logger.Debug(r.Context(), "subscribe encoding log", slog.F("job_id", job.ID), slog.F("stage", log.Stage)) - err = encoder.Encode(convertProvisionerJobLog(log)) - if err != nil { + case log, ok := <-bufferedLogs: + if !ok { + logger.Debug(context.Background(), "done with published logs") return } - case <-ticker.C: - job, err := api.Database.GetProvisionerJobByID(r.Context(), job.ID) - if err != nil { - api.Logger.Warn(r.Context(), "streaming job logs; checking if completed", slog.Error(err), slog.F("job_id", job.ID.String())) - continue - } - if job.CompletedAt.Valid { - api.Logger.Debug(context.Background(), "streaming job logs done; job done", slog.F("job_id", job.ID)) - return + if logIdsDone[log.ID] { + logger.Debug(r.Context(), "subscribe duplicated log", + slog.F("stage", log.Stage)) + } else { + logger.Debug(r.Context(), "subscribe encoding log", + slog.F("stage", log.Stage)) + err = encoder.Encode(convertProvisionerJobLog(log)) + if err != nil { + return + } } } } @@ -343,3 +337,43 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov func provisionerJobLogsChannel(jobID uuid.UUID) string { return fmt.Sprintf("provisioner-log-logs:%s", jobID) } + +// provisionerJobLogsMessage is the message type published on the provisionerJobLogsChannel() channel +type provisionerJobLogsMessage struct { + EndOfLogs bool `json:"end_of_logs,omitempty"` + Logs []database.ProvisionerJobLog `json:"logs,omitempty"` +} + +func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, func(), error) { + logger := api.Logger.With(slog.F("job_id", jobID)) + bufferedLogs := make(chan database.ProvisionerJobLog, 128) + closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(jobID), + func(ctx context.Context, message []byte) { + jlMsg := provisionerJobLogsMessage{} + err := json.Unmarshal(message, &jlMsg) + if err != nil { + 
logger.Warn(ctx, "invalid provisioner job log on channel", slog.Error(err)) + return + } + + for _, log := range jlMsg.Logs { + select { + case bufferedLogs <- log: + logger.Debug(ctx, "subscribe buffered log", slog.F("stage", log.Stage)) + default: + // If this overflows users could miss logs streaming. This can happen + // we get a lot of logs and consumer isn't keeping up. We don't want to block the pubsub, + // so just drop them. + logger.Warn(ctx, "provisioner job log overflowing channel") + } + } + if jlMsg.EndOfLogs { + logger.Debug(ctx, "got End of Logs") + close(bufferedLogs) + } + }) + if err != nil { + return nil, nil, err + } + return bufferedLogs, closeSubscribe, nil +} diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go new file mode 100644 index 0000000000000..4901f2f1ea9a4 --- /dev/null +++ b/coderd/provisionerjobs_internal_test.go @@ -0,0 +1,183 @@ +package coderd + +import ( + "context" + "crypto/sha256" + "encoding/json" + "net/http/httptest" + "net/url" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/codersdk" +) + +func TestProvisionerJobLogs_Unit(t *testing.T) { + t.Parallel() + + t.Run("QueryPubSubDupes", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + // mDB := mocks.NewStore(t) + fDB := databasefake.New() + fPubsub := &fakePubSub{t: t, cond: sync.NewCond(&sync.Mutex{})} + opts := Options{ + Logger: logger, + Database: fDB, + Pubsub: fPubsub, + } + api := New(&opts) + server := httptest.NewServer(api.Handler) + t.Cleanup(server.Close) + userID := uuid.New() + keyID, keySecret, err := generateAPIKeyIDSecret() + require.NoError(t, err) + hashed := sha256.Sum256([]byte(keySecret)) + + u, err := url.Parse(server.URL) + require.NoError(t, err) + client := codersdk.Client{ + HTTPClient: server.Client(), + SessionToken: keyID + "-" + keySecret, + URL: u, + } + + buildID := uuid.New() + workspaceID := uuid.New() + jobID := uuid.New() + + expectedLogs := []database.ProvisionerJobLog{ + {ID: uuid.New(), JobID: jobID, Stage: "Stage0"}, + {ID: uuid.New(), JobID: jobID, Stage: "Stage1"}, + {ID: uuid.New(), JobID: jobID, Stage: "Stage2"}, + {ID: uuid.New(), JobID: jobID, Stage: "Stage3"}, + } + + // wow there are a lot of DB rows we touch... 
+ _, err = fDB.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: keyID, + HashedSecret: hashed[:], + UserID: userID, + ExpiresAt: time.Now().Add(5 * time.Hour), + }) + require.NoError(t, err) + _, err = fDB.InsertUser(ctx, database.InsertUserParams{ + ID: userID, + RBACRoles: []string{"admin"}, + }) + require.NoError(t, err) + _, err = fDB.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ + ID: buildID, + WorkspaceID: workspaceID, + JobID: jobID, + }) + require.NoError(t, err) + _, err = fDB.InsertWorkspace(ctx, database.InsertWorkspaceParams{ + ID: workspaceID, + }) + require.NoError(t, err) + _, err = fDB.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: jobID, + }) + require.NoError(t, err) + for _, l := range expectedLogs[:2] { + _, err := fDB.InsertProvisionerJobLogs(ctx, database.InsertProvisionerJobLogsParams{ + ID: []uuid.UUID{l.ID}, + JobID: jobID, + Stage: []string{l.Stage}, + }) + require.NoError(t, err) + } + + logs, err := client.WorkspaceBuildLogsAfter(ctx, buildID, time.Now()) + require.NoError(t, err) + + // when the endpoint calls subscribe, we get the listener here. + fPubsub.cond.L.Lock() + for fPubsub.listener == nil { + fPubsub.cond.Wait() + } + + // endpoint should now be listening + assert.False(t, fPubsub.canceled) + assert.False(t, fPubsub.closed) + + // send all the logs in two batches, duplicating what we already returned on the DB query. + msg := provisionerJobLogsMessage{} + msg.Logs = expectedLogs[:2] + data, err := json.Marshal(msg) + require.NoError(t, err) + fPubsub.listener(ctx, data) + msg.Logs = expectedLogs[2:] + data, err = json.Marshal(msg) + require.NoError(t, err) + fPubsub.listener(ctx, data) + + // send end of logs + msg.Logs = nil + msg.EndOfLogs = true + data, err = json.Marshal(msg) + require.NoError(t, err) + fPubsub.listener(ctx, data) + + var stages []string + for l := range logs { + logger.Info(ctx, "got log", + slog.F("id", l.ID), + slog.F("stage", l.Stage)) + stages = append(stages, l.Stage) + } + assert.Equal(t, []string{"Stage0", "Stage1", "Stage2", "Stage3"}, stages) + for !fPubsub.canceled { + fPubsub.cond.Wait() + } + assert.False(t, fPubsub.closed) + }) +} + +type fakePubSub struct { + t *testing.T + cond *sync.Cond + listener database.Listener + canceled bool + closed bool +} + +func (f *fakePubSub) Subscribe(_ string, listener database.Listener) (cancel func(), err error) { + f.cond.L.Lock() + defer f.cond.L.Unlock() + f.listener = listener + f.cond.Signal() + return f.cancel, nil +} + +func (f *fakePubSub) Publish(_ string, _ []byte) error { + f.t.Fail() + return nil +} + +func (f *fakePubSub) Close() error { + f.cond.L.Lock() + defer f.cond.L.Unlock() + f.closed = true + f.cond.Signal() + return nil +} + +func (f *fakePubSub) cancel() { + f.cond.L.Lock() + defer f.cond.L.Unlock() + f.canceled = true + f.cond.Signal() +} diff --git a/coderd/provisionerjobs_test.go b/coderd/provisionerjobs_test.go index 404cd53683aa3..9d35f482dadc6 100644 --- a/coderd/provisionerjobs_test.go +++ b/coderd/provisionerjobs_test.go @@ -45,7 +45,8 @@ func TestProvisionerJobLogs(t *testing.T) { logs, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before) require.NoError(t, err) for { - _, ok := <-logs + log, ok := <-logs + t.Logf("got log: [%s] %s %s", log.Level, log.Stage, log.Output) if !ok { return }
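
For reference, the subscribe-then-query ordering the PR relies on reads more clearly outside the handler. The sketch below is not the PR's code: it reuses the `provisionerJobLogsMessage` shape from the diff but swaps the coderd-specific pieces (the database store, the `Pubsub` interface, the websocket JSON encoder) for hypothetical `subscribe`/`query`/`send` callbacks, and it omits the `job.CompletedAt` re-check that lets the handler return right after replaying history for a finished job.

```go
package joblogs

import (
	"context"
	"encoding/json"

	"github.com/google/uuid"
)

// jobLog and jobLogsMessage mirror the shapes in the diff
// (database.ProvisionerJobLog and provisionerJobLogsMessage), trimmed to the
// fields the streaming logic actually needs.
type jobLog struct {
	ID    uuid.UUID `json:"id"`
	Stage string    `json:"stage"`
}

type jobLogsMessage struct {
	EndOfLogs bool     `json:"end_of_logs,omitempty"`
	Logs      []jobLog `json:"logs,omitempty"`
}

// streamJobLogs illustrates the ordering the handler depends on. The three
// callbacks are hypothetical stand-ins: subscribe plays the role of
// Pubsub.Subscribe, query the role of GetProvisionerLogsByIDBetween, and send
// is whatever writes one log to the client.
func streamJobLogs(
	ctx context.Context,
	subscribe func(listener func(ctx context.Context, message []byte)) (cancel func(), err error),
	query func(ctx context.Context) ([]jobLog, error),
	send func(jobLog) error,
) error {
	buffered := make(chan jobLog, 128)

	// 1. Subscribe first, so nothing published after the historical query is missed.
	cancel, err := subscribe(func(_ context.Context, message []byte) {
		var msg jobLogsMessage
		if err := json.Unmarshal(message, &msg); err != nil {
			return // malformed message; the real code logs a warning and drops it
		}
		for _, l := range msg.Logs {
			select {
			case buffered <- l:
			default: // never block the pubsub; drop on overflow, as the diff does
			}
		}
		if msg.EndOfLogs {
			close(buffered) // lets the consumer loop below exit cleanly
		}
	})
	if err != nil {
		return err
	}
	defer cancel()

	// 2. Replay history. Anything published between the subscribe and this
	//    query shows up in both places, so remember what was already sent.
	history, err := query(ctx)
	if err != nil {
		return err
	}
	sent := make(map[uuid.UUID]bool, len(history))
	for _, l := range history {
		sent[l.ID] = true
		if err := send(l); err != nil {
			return err
		}
	}

	// 3. Drain the subscription, filtering duplicates, until end-of-logs.
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case l, ok := <-buffered:
			if !ok {
				return nil // the EndOfLogs message closed the channel
			}
			if sent[l.ID] {
				continue
			}
			if err := send(l); err != nil {
				return err
			}
		}
	}
}
```

Closing the buffered channel on the `EndOfLogs` message is what lets the consumer exit without the 250 ms ticker and repeated `GetProvisionerJobByID` polling that the diff deletes; the trade-off, accepted here rather than blocking the publisher, is that a slow consumer can still drop logs on channel overflow.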