
Commit 3ed4316

feat: Implement async job execution for torchtune training (meta-llama#1437)
# What does this PR do?

A separate thread is now started to execute training jobs, so training requests return a job ID before the job completes. This fixes API timeouts for any job that takes longer than a minute.

Note: the scheduler code is meant to be spun out in the future into a common provider service that can be reused for different APIs and providers. It is also expected to back the /jobs API proposed in meta-llama#1238. Hence its somewhat generalized form, which should simplify its adoption elsewhere later on.

Note: this patch doesn't attempt to implement the missing APIs (e.g. cancel or job removal); that work belongs to follow-up PRs.

## Test Plan

Added unit tests for the scheduler module. For API coverage, I tested manually and was able to run a training cycle on a GPU. The initial call returned the job ID before training completed, as (now) expected, and artifacts are returned as expected:

```
JobArtifactsResponse(checkpoints=[{'identifier': 'meta-llama/Llama-3.2-3B-Instruct-sft-0', 'created_at': '2025-03-07T22:45:19.892714', 'epoch': 0, 'post_training_job_id': 'test-job2ee77104-2fd3-4a4e-84cf-f83f8b8f1f50', 'path': '/home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0', 'training_metrics': None}], job_uuid='test-job2ee77104-2fd3-4a4e-84cf-f83f8b8f1f50')
```

The integration test is currently disabled for the provider. I will look into how it can be enabled in a separate PR / issue.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
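To make the new asynchronous flow concrete, here is a rough sketch of how a caller might drive it against the provider in this patch. The `wait_for_job` helper, the `impl` handle, and the polling interval are illustrative only and are not part of the PR; they rely solely on methods that appear in the diff below.

```python
import asyncio

from llama_stack.apis.post_training import JobStatus


async def wait_for_job(impl, job_uuid: str, poll_seconds: float = 5.0):
    """Poll the job status route until the training job reaches a terminal state.

    supervised_fine_tune() now returns PostTrainingJob(job_uuid=...) immediately,
    so the caller polls for completion instead of blocking on the initial request.
    """
    while True:
        status = await impl.get_training_job_status(job_uuid)
        if status.status in (JobStatus.completed, JobStatus.failed):
            return status
        await asyncio.sleep(poll_seconds)
```

Once the job reports `completed`, the checkpoints shown above can be fetched via `get_training_job_artifacts(job_uuid)`.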
1 parent 7641a5c commit 3ed4316

File tree

3 files changed: +472 −39 lines changed

llama_stack/providers/inline/post_training/torchtune/post_training.py

Lines changed: 87 additions & 39 deletions
```diff
@@ -3,13 +3,14 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from datetime import datetime, timezone
+from enum import Enum
 from typing import Any, Dict, Optional
 
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     AlgorithmConfig,
+    Checkpoint,
     DPOAlignmentConfig,
     JobStatus,
     ListPostTrainingJobsResponse,
@@ -25,9 +26,19 @@
 from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
     LoraFinetuningSingleDevice,
 )
+from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
+from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack.schema_utils import webmethod
 
 
+class TrainingArtifactType(Enum):
+    CHECKPOINT = "checkpoint"
+    RESOURCES_STATS = "resources_stats"
+
+
+_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
+
+
 class TorchtunePostTrainingImpl:
     def __init__(
         self,
@@ -38,13 +49,27 @@ def __init__(
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets
+        self._scheduler = Scheduler()
+
+    async def shutdown(self) -> None:
+        await self._scheduler.shutdown()
+
+    @staticmethod
+    def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
+        return JobArtifact(
+            type=TrainingArtifactType.CHECKPOINT.value,
+            name=checkpoint.identifier,
+            uri=checkpoint.path,
+            metadata=dict(checkpoint),
+        )
 
-        # TODO: assume sync job, will need jobs API for async scheduling
-        self.jobs = {}
-        self.checkpoints_dict = {}
-
-    async def shutdown(self):
-        pass
+    @staticmethod
+    def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
+        return JobArtifact(
+            type=TrainingArtifactType.RESOURCES_STATS.value,
+            name=TrainingArtifactType.RESOURCES_STATS.value,
+            metadata=resources_stats,
+        )
 
     async def supervised_fine_tune(
         self,
@@ -56,20 +81,11 @@ async def supervised_fine_tune(
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig],
     ) -> PostTrainingJob:
-        if job_uuid in self.jobs:
-            raise ValueError(f"Job {job_uuid} already exists")
-
-        post_training_job = PostTrainingJob(job_uuid=job_uuid)
+        if isinstance(algorithm_config, LoraFinetuningConfig):
 
-        job_status_response = PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
-            status=JobStatus.scheduled,
-            scheduled_at=datetime.now(timezone.utc),
-        )
-        self.jobs[job_uuid] = job_status_response
+            async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
+                on_log_message_cb("Starting Lora finetuning")
 
-        if isinstance(algorithm_config, LoraFinetuningConfig):
-            try:
                 recipe = LoraFinetuningSingleDevice(
                     self.config,
                     job_uuid,
@@ -82,26 +98,22 @@ async def supervised_fine_tune(
                     self.datasetio_api,
                     self.datasets_api,
                 )
-
-                job_status_response.status = JobStatus.in_progress
-                job_status_response.started_at = datetime.now(timezone.utc)
-
                 await recipe.setup()
+
                 resources_allocated, checkpoints = await recipe.train()
 
-                self.checkpoints_dict[job_uuid] = checkpoints
-                job_status_response.resources_allocated = resources_allocated
-                job_status_response.checkpoints = checkpoints
-                job_status_response.status = JobStatus.completed
-                job_status_response.completed_at = datetime.now(timezone.utc)
+                on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
+                for checkpoint in checkpoints:
+                    artifact = self._checkpoint_to_artifact(checkpoint)
+                    on_artifact_collected_cb(artifact)
 
-            except Exception:
-                job_status_response.status = JobStatus.failed
-                raise
+                on_status_change_cb(SchedulerJobStatus.completed)
+                on_log_message_cb("Lora finetuning completed")
         else:
             raise NotImplementedError()
 
-        return post_training_job
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        return PostTrainingJob(job_uuid=job_uuid)
 
     async def preference_optimize(
         self,
@@ -114,19 +126,55 @@ async def preference_optimize(
     ) -> PostTrainingJob: ...
 
     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
-        return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
+        return ListPostTrainingJobsResponse(
+            data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
+        )
+
+    @staticmethod
+    def _get_artifacts_metadata_by_type(job, artifact_type):
+        return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
+
+    @classmethod
+    def _get_checkpoints(cls, job):
+        return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
+
+    @classmethod
+    def _get_resources_allocated(cls, job):
+        data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
+        return data[0] if data else None
 
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
-        return self.jobs.get(job_uuid, None)
+        job = self._scheduler.get_job(job_uuid)
+
+        match job.status:
+            # TODO: Add support for other statuses to API
+            case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
+                status = JobStatus.scheduled
+            case SchedulerJobStatus.running:
+                status = JobStatus.in_progress
+            case SchedulerJobStatus.completed:
+                status = JobStatus.completed
+            case SchedulerJobStatus.failed:
+                status = JobStatus.failed
+            case _:
+                raise NotImplementedError()
+
+        return PostTrainingJobStatusResponse(
+            job_uuid=job_uuid,
+            status=status,
+            scheduled_at=job.scheduled_at,
+            started_at=job.started_at,
+            completed_at=job.completed_at,
+            checkpoints=self._get_checkpoints(job),
+            resources_allocated=self._get_resources_allocated(job),
+        )
 
     @webmethod(route="/post-training/job/cancel")
     async def cancel_training_job(self, job_uuid: str) -> None:
-        raise NotImplementedError("Job cancel is not implemented yet")
+        self._scheduler.cancel(job_uuid)
 
     @webmethod(route="/post-training/job/artifacts")
    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
-        if job_uuid in self.checkpoints_dict:
-            checkpoints = self.checkpoints_dict.get(job_uuid, [])
-            return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
-        return None
+        job = self._scheduler.get_job(job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
```
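The `Scheduler` itself lives in `llama_stack/providers/utils/scheduler.py`, one of the other files in this commit that is not shown in this excerpt. Based purely on how the provider uses it above, a simplified in-process sketch of the contract it exposes might look like the following. This is an illustration of the job/callback/artifact flow only, not the actual scheduler code: field names and behaviors are inferred from the diff, and the real implementation runs handlers off the API thread (per the PR description) rather than as plain asyncio tasks.

```python
# Toy sketch of a Scheduler-like contract, inferred from the provider code above.
import asyncio
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Dict, List, Optional


class JobStatus(Enum):
    new = "new"
    scheduled = "scheduled"
    running = "running"
    completed = "completed"
    failed = "failed"


@dataclass
class JobArtifact:
    type: str
    name: str
    uri: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class Job:
    id: str
    status: JobStatus = JobStatus.new
    scheduled_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    artifacts: List[JobArtifact] = field(default_factory=list)


class Scheduler:
    """Runs each handler as a background task and records its progress.

    The real module in this PR executes handlers off the request path so that
    long-running training does not block the API; this sketch just uses
    asyncio tasks to show the callback contract.
    """

    def __init__(self) -> None:
        self._jobs: Dict[str, Job] = {}
        self._tasks: List[asyncio.Task] = []

    def schedule(self, job_type: str, job_id: str, handler: Callable[..., Any]) -> str:
        job = Job(id=job_id, status=JobStatus.scheduled, scheduled_at=datetime.now(timezone.utc))
        self._jobs[job_id] = job

        async def run() -> None:
            job.status = JobStatus.running
            job.started_at = datetime.now(timezone.utc)
            try:
                await handler(
                    lambda msg: None,                               # on_log_message_cb
                    lambda status: setattr(job, "status", status),  # on_status_change_cb
                    job.artifacts.append,                           # on_artifact_collected_cb
                )
            except Exception:
                job.status = JobStatus.failed
            finally:
                job.completed_at = datetime.now(timezone.utc)

        # Return the job ID immediately; the handler keeps running in the background.
        # (Assumes a running event loop, which holds when called from an async API method.)
        self._tasks.append(asyncio.create_task(run()))
        return job_id

    def get_job(self, job_id: str) -> Job:
        return self._jobs[job_id]

    def get_jobs(self) -> List[Job]:
        return list(self._jobs.values())

    def cancel(self, job_id: str) -> None:
        # Cancellation is explicitly left to follow-up PRs in this patch.
        raise NotImplementedError()

    async def shutdown(self) -> None:
        for task in self._tasks:
            task.cancel()
```

The key design point is that `schedule()` hands back the job ID synchronously while the handler reports progress through the three callbacks, which is what lets `supervised_fine_tune` respond before training finishes and lets the status and artifacts routes answer from recorded state.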
