Commit 0751a96

feat: make training config fields optional (meta-llama#1861)
# What does this PR do?

Today, `supervised_fine_tune` and the `TrainingConfig` class have a number of required fields that a provider implementation might not need. For example, if a provider handles hyperparameters in its own configuration, as well as dataset retrieval, optimizer, or LoRA settings, a user still has to pass in a virtually empty `DataConfig`, `OptimizerConfig`, and `AlgorithmConfig` in some cases. Many of these fields are aimed specifically at Llama models and at knobs for customizing the inline providers. Adding remote post_training providers requires loosening these arguments; otherwise users are forced to pass empty objects just to satisfy the pydantic models.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
1 parent 70a7e4d commit 0751a96
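To make the effect concrete, here is a minimal sketch of how `TrainingConfig` can be constructed after this change (assuming a checkout with this commit applied; the values are placeholders, not recommendations):

```python
from llama_stack.apis.post_training import TrainingConfig

# Previously, max_steps_per_epoch, gradient_accumulation_steps,
# max_validation_steps, data_config and optimizer_config were all required,
# so callers whose provider manages those internally had to build
# near-empty DataConfig/OptimizerConfig objects just to satisfy pydantic.
#
# With this commit, a minimal config validates on its own:
cfg = TrainingConfig(n_epochs=1)

print(cfg.max_steps_per_epoch)          # 1    (new default)
print(cfg.gradient_accumulation_steps)  # 1    (new default)
print(cfg.max_validation_steps)         # 1    (new default, now Optional)
print(cfg.data_config)                  # None (now Optional)
print(cfg.optimizer_config)             # None (now Optional)
```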

File tree

4 files changed: +29 −21 lines

docs/_static/llama-stack-spec.html

Lines changed: 8 additions & 9 deletions

```diff
@@ -9778,13 +9778,16 @@
             "type": "integer"
           },
           "max_steps_per_epoch": {
-            "type": "integer"
+            "type": "integer",
+            "default": 1
           },
           "gradient_accumulation_steps": {
-            "type": "integer"
+            "type": "integer",
+            "default": 1
           },
           "max_validation_steps": {
-            "type": "integer"
+            "type": "integer",
+            "default": 1
           },
           "data_config": {
             "$ref": "#/components/schemas/DataConfig"
@@ -9804,10 +9807,7 @@
         "required": [
           "n_epochs",
           "max_steps_per_epoch",
-          "gradient_accumulation_steps",
-          "max_validation_steps",
-          "data_config",
-          "optimizer_config"
+          "gradient_accumulation_steps"
         ],
         "title": "TrainingConfig"
       },
@@ -10983,8 +10983,7 @@
           "job_uuid",
           "training_config",
           "hyperparam_search_config",
-          "logger_config",
-          "model"
+          "logger_config"
         ],
         "title": "SupervisedFineTuneRequest"
       },
```

docs/_static/llama-stack-spec.yaml

Lines changed: 3 additions & 4 deletions

```diff
@@ -6744,10 +6744,13 @@ components:
         type: integer
       max_steps_per_epoch:
         type: integer
+        default: 1
       gradient_accumulation_steps:
         type: integer
+        default: 1
       max_validation_steps:
         type: integer
+        default: 1
       data_config:
         $ref: '#/components/schemas/DataConfig'
       optimizer_config:
@@ -6762,9 +6765,6 @@ components:
       - n_epochs
       - max_steps_per_epoch
       - gradient_accumulation_steps
-      - max_validation_steps
-      - data_config
-      - optimizer_config
       title: TrainingConfig
     PreferenceOptimizeRequest:
       type: object
@@ -7498,7 +7498,6 @@ components:
       - training_config
       - hyperparam_search_config
       - logger_config
-      - model
       title: SupervisedFineTuneRequest
     SyntheticDataGenerateRequest:
       type: object
```
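Both spec files encode the same relaxation: `model`, `data_config`, and `optimizer_config` drop out of the required lists, and the integer knobs gain a default of 1. As an illustration (a hypothetical request body, not taken from the repo), a minimal `supervised_fine_tune` payload that is now schema-valid could look like this:

```python
import json

# Hypothetical minimal body for a supervised_fine_tune request under the
# updated spec: "model", "data_config" and "optimizer_config" are omitted.
payload = {
    "job_uuid": "sft-job-001",             # placeholder identifier
    "training_config": {
        "n_epochs": 1,
        "max_steps_per_epoch": 1,          # still listed as required (default 1)
        "gradient_accumulation_steps": 1,  # still listed as required (default 1)
    },
    "hyperparam_search_config": {},
    "logger_config": {},
}
print(json.dumps(payload, indent=2))
```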

llama_stack/apis/post_training/post_training.py

Lines changed: 8 additions & 8 deletions

```diff
@@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel):
 @json_schema_type
 class TrainingConfig(BaseModel):
     n_epochs: int
-    max_steps_per_epoch: int
-    gradient_accumulation_steps: int
-    max_validation_steps: int
-    data_config: DataConfig
-    optimizer_config: OptimizerConfig
+    max_steps_per_epoch: int = 1
+    gradient_accumulation_steps: int = 1
+    max_validation_steps: Optional[int] = 1
+    data_config: Optional[DataConfig] = None
+    optimizer_config: Optional[OptimizerConfig] = None
     efficiency_config: Optional[EfficiencyConfig] = None
     dtype: Optional[str] = "bf16"

@@ -177,9 +177,9 @@ async def supervised_fine_tune(
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
-        model: str = Field(
-            default="Llama3.2-3B-Instruct",
-            description="Model descriptor from `llama model list`",
+        model: Optional[str] = Field(
+            default=None,
+            description="Model descriptor for training if not in provider config`",
         ),
         checkpoint_dir: Optional[str] = None,
         algorithm_config: Optional[AlgorithmConfig] = None,
```
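With `model` now optional (defaulting to `None`) and the config fields above carrying defaults, a caller can leave provider-managed details out entirely. A rough sketch of such a call against any object implementing the `PostTraining` protocol (`impl` and `kick_off` are hypothetical names used only for illustration):

```python
from llama_stack.apis.post_training import TrainingConfig

async def kick_off(impl) -> None:
    # `impl` stands in for any PostTraining implementation (hypothetical here).
    # model, checkpoint_dir, algorithm_config, data_config and optimizer_config
    # are all omitted; a remote provider can resolve them from its own config.
    await impl.supervised_fine_tune(
        job_uuid="sft-job-001",  # placeholder
        training_config=TrainingConfig(n_epochs=1),
        hyperparam_search_config={},
        logger_config={},
    )
```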

llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -38,6 +38,8 @@
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     Checkpoint,
+    DataConfig,
+    EfficiencyConfig,
     LoraFinetuningConfig,
     OptimizerConfig,
     QATFinetuningConfig,
@@ -89,6 +91,10 @@ def __init__(
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
+        assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
+
+        assert isinstance(training_config.efficiency_config, EfficiencyConfig), "EfficiencyConfig must be initialized"
+
         self.job_uuid = job_uuid
         self.training_config = training_config
         if not isinstance(algorithm_config, LoraFinetuningConfig):
@@ -188,13 +194,16 @@ async def setup(self) -> None:
         self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized.")

+        assert isinstance(self.training_config.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized"
         self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
         log.info("Optimizer is initialized.")

         self._loss_fn = CEWithChunkedOutputLoss()
         self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
         log.info("Loss is initialized.")

+        assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
+
         self._training_sampler, self._training_dataloader = await self._setup_data(
             dataset_id=self.training_config.data_config.dataset_id,
             tokenizer=self._tokenizer,
@@ -452,6 +461,7 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
         """
         The core training loop.
         """
+        assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
         # Initialize tokens count and running loss (for grad accumulation)
         t0 = time.perf_counter()
         running_loss: float = 0.0
```
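Because `data_config`, `optimizer_config`, and `efficiency_config` are now Optional at the API level but still mandatory for the inline torchtune recipe, the recipe asserts on them before use: the `isinstance` asserts fail fast with a clear message and also narrow the Optional types for static checkers. A reduced sketch of the same pattern, using simplified stand-in models rather than the real classes:

```python
from typing import Optional

from pydantic import BaseModel


class DataConfig(BaseModel):      # simplified stand-in for the real class
    dataset_id: str


class TrainingConfig(BaseModel):  # simplified stand-in for the real class
    n_epochs: int
    data_config: Optional[DataConfig] = None


def dataset_for(cfg: TrainingConfig) -> str:
    # Without the assert, cfg.data_config has type Optional[DataConfig]; accessing
    # .dataset_id would both upset type checkers and raise AttributeError at runtime
    # when the field is left unset. The assert narrows the type and fails fast.
    assert isinstance(cfg.data_config, DataConfig), "DataConfig must be initialized"
    return cfg.data_config.dataset_id


print(dataset_for(TrainingConfig(n_epochs=1, data_config=DataConfig(dataset_id="demo"))))
```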

0 commit comments
