[BE][5/n] simplify pp vs. non-pp setup
ghstack-source-id: 003bfbfbcf1511ddbd18e15d031b39f597d8e7db
Pull Request resolved: pytorch#510
tianyu-l committed Aug 8, 2024
1 parent fa8cdd4 commit d6e3f77
Showing 6 changed files with 73 additions and 85 deletions.
33 changes: 12 additions & 21 deletions estimation.py
@@ -122,33 +122,25 @@ def loss_fn(pred, labels):
         f"Building {model_name} {job_config.model.flavor} with {model_config}"
     )
     with torch.device("meta"):
-        whole_model = model_cls.from_model_args(model_config)
+        model = model_cls.from_model_args(model_config)
 
     # a no-op hander if float8 is not enabled
     float8_handler = Float8Handler(job_config, parallel_dims)
     # swap to Float8Linear based on float8 configs
-    float8_handler.convert_to_float8_training(whole_model)
+    float8_handler.convert_to_float8_training(model)
 
     # apply PT-D DP/TP parallelisms and activation checkpointing
-    model_parts = [whole_model]
-    model_parts = [
-        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-        for m in model_parts
-    ]
-
-    init_device = "cuda"
-    for model in model_parts:
-        model.to_empty(device=init_device)
+    models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
 
+    model.to_empty(device="cuda")
     if not active_fake_mode():
-        whole_model.init_weights()
+        model.init_weights()
+    model.train()
 
     # build optimizer after applying parallelisms to the model
-    optimizers = build_optimizers(model_parts, job_config)
+    optimizers = build_optimizers([model], job_config)
     lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
 
-    for model in model_parts:
-        model.train()
     logger.info(f"Vocab size: {model_config.vocab_size}")
     # Create a dummy batch instead of loading from a dataset
     batch = (
@@ -165,24 +157,23 @@ def loss_fn(pred, labels):
             device="cuda",
         ),
     )
-    fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0])
+    fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0])
     fsdp_memtracker.track_inputs(batch)
 
     with fsdp_memtracker:
         for iter_idx in range(2):
             input_ids, labels = batch
             # train step
             with train_context():
-                pred = whole_model(input_ids)
+                pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 del pred
             loss.backward()
 
             # clip gradients
-            for model in model_parts:
-                torch.nn.utils.clip_grad_norm_(
-                    model.parameters(), job_config.training.max_norm, foreach=True
-                )
+            torch.nn.utils.clip_grad_norm_(
+                model.parameters(), job_config.training.max_norm, foreach=True
+            )
             # sync float8 amaxes and scales
             float8_handler.sync_float8_amax_and_scale_history(model)
             # optimizer step
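
For context on the pattern the hunks above preserve: the model is still built on the meta device and only materialized after parallelization; the difference is that a single model object now replaces the whole_model/model_parts split. Below is a minimal, runnable sketch of that meta-device flow, using a hypothetical TinyModel rather than torchtitan's Llama classes:

# A minimal, runnable sketch (hypothetical TinyModel, not torchtitan code) of
# the meta-device pattern estimation.py keeps: build the model on "meta" so no
# real memory is allocated, then materialize with to_empty() and re-init.
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def init_weights(self):
        # meta tensors carry no data, so initialization must happen after
        # to_empty() has materialized real storage
        nn.init.kaiming_uniform_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        return self.proj(x)

with torch.device("meta"):
    model = TinyModel()  # parameters exist only as shapes/dtypes

device = "cuda" if torch.cuda.is_available() else "cpu"  # estimation.py assumes cuda
model.to_empty(device=device)  # allocate uninitialized storage on the target device
model.init_weights()
model.train()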
2 changes: 0 additions & 2 deletions torchtitan/parallelisms/__init__.py
@@ -8,11 +8,9 @@
 from torchtitan.parallelisms.parallel_dims import ParallelDims
 from torchtitan.parallelisms.parallelize_llama import parallelize_llama
 from torchtitan.parallelisms.pipeline_llama import pipeline_llama
-from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule
 
 
 __all__ = [
-    "build_pipeline_schedule",
     "models_parallelize_fns",
     "models_pipelining_fns",
     "ParallelDims",
22 changes: 8 additions & 14 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -51,7 +51,7 @@ def parallelize_llama(
             and not job_config.training.compile
         ):
             raise RuntimeError("Async TP requires --training.compile")
-        model = apply_tp(
+        apply_tp(
             model,
             world_mesh["tp"],
             loss_parallel=parallel_dims.loss_parallel_enabled,
@@ -60,7 +60,7 @@ def parallelize_llama(
         )
 
     if job_config.activation_checkpoint.mode != "none":
-        model = apply_ac(model, job_config.activation_checkpoint)
+        apply_ac(model, job_config.activation_checkpoint)
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if job_config.training.compile:
@@ -69,14 +69,14 @@ def parallelize_llama(
                 "fused_rmsnorm is not compatible with torch.compile yet. "
                 "Please use rmsnorm or layernorm."
             )
-        model = apply_compile(model)
+        apply_compile(model)
 
     if parallel_dims.dp_enabled:
         if parallel_dims.dp_type == "fsdp":
             dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
             assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
 
-            model = apply_fsdp(
+            apply_fsdp(
                 model,
                 dp_mesh,
                 param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
@@ -88,15 +88,13 @@ def parallelize_llama(
         else:
             if world_mesh.ndim > 1:
                 raise RuntimeError("DDP has not supported > 1D parallelism")
-            model = apply_ddp(
+            apply_ddp(
                 model,
                 world_mesh,
                 enable_compile=job_config.training.compile,
                 enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
             )
 
-    return model
-
 
 def apply_tp(
     model: nn.Module,
@@ -110,7 +108,7 @@ def apply_tp(
     # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
     # 3. Parallelize the final linear output layer
-    model = parallelize_module(
+    parallelize_module(
         model,
         tp_mesh,
         {
@@ -192,7 +190,6 @@ def apply_tp(
         f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
         "Tensor Parallelism to the model"
     )
-    return model
 
 
 # for selective op activation checkpointing
@@ -273,7 +270,6 @@ def apply_ac(model: nn.Module, ac_config):
         model.layers.register_module(layer_id, transformer_block)
 
     logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
-    return model
 
 
 def apply_compile(model: nn.Module):
@@ -286,7 +282,6 @@ def apply_compile(model: nn.Module):
         model.layers.register_module(layer_id, transformer_block)
 
     logger.info("Compiling each TransformerBlock with torch.compile")
-    return model
 
 
 def apply_fsdp(
@@ -329,8 +324,8 @@ def apply_fsdp(
             module._load_state_dict_pre_hooks.clear()
             assert len(module._state_dict_pre_hooks) <= 1
             module._state_dict_pre_hooks.clear()
+
     logger.info("Applied FSDP to the model")
-    return model
 
 
 def apply_ddp(
def apply_ddp(
Expand All @@ -347,7 +342,6 @@ def apply_ddp(
else:
torch._dynamo.config.optimize_ddp = "ddp_optimizer"

model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

logger.info("Applied DDP to the model")
return model
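
The common thread in this file's hunks is a convention change: apply_tp, apply_ac, apply_compile, apply_fsdp, and apply_ddp now modify the module in place and return nothing, so call sites drop the `model = ...` rebinding. A self-contained toy (hypothetical Wrapped class, not torchtitan code) showing why the caller's reference stays valid:

# Helpers that re-register submodules mutate the model in place; the root
# module object is never replaced, so no return value is needed.
import torch.nn as nn

class Wrapped(nn.Module):
    """Hypothetical stand-in for an AC/compile/FSDP wrapper."""
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        return self.inner(x)

def apply_wrapper(model: nn.Module) -> None:
    # re-register each block in place, mirroring how apply_ac and
    # apply_compile swap TransformerBlocks above
    for name, block in list(model.layers.named_children()):
        model.layers.register_module(name, Wrapped(block))

model = nn.Module()
model.layers = nn.ModuleDict({"0": nn.Linear(4, 4), "1": nn.Linear(4, 4)})
apply_wrapper(model)  # no rebinding, matching the new call sites
assert isinstance(model.layers["0"], Wrapped)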
18 changes: 13 additions & 5 deletions torchtitan/parallelisms/pipeline_llama.py
@@ -7,7 +7,7 @@
 # This file applies the PT-D pipeline parallelism to the Llama model.
 
 import copy
-from typing import Union
+from typing import Callable, Union
 
 import torch
 import torch.nn as nn
@@ -18,7 +18,10 @@
 from torchtitan.logging import logger
 from torchtitan.models.llama.model import ModelArgs
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank
+from torchtitan.parallelisms.pipelining_utils import (
+    build_pipeline_schedule,
+    stage_ids_this_rank,
+)
 
 
 DeviceType = Union[int, str, torch.device]
@@ -31,6 +34,7 @@ def pipeline_llama(
     job_config: JobConfig,
     device: DeviceType,
     model_config: ModelArgs,
+    loss_fn: Callable[..., torch.Tensor],
 ):
     split_mode = job_config.experimental.pipeline_parallel_split_mode
     valid_split_modes = ("manual", "tracer")
@@ -39,14 +43,18 @@ def pipeline_llama(
             f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
         )
     if split_mode == "manual":
-        return pipeline_llama_manual(
+        stages, models = pipeline_llama_manual(
             model, pp_mesh, parallel_dims, job_config, device, model_config
         )
     elif split_mode == "tracer":
-        return pipeline_llama_tracer(
+        stages, models = pipeline_llama_tracer(
             model, pp_mesh, parallel_dims, job_config, device, model_config
         )
+
+    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+    return pp_schedule, models
 
 
 def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
     """Get meta tensors with the right input shapes used for tracing"""
@@ -218,4 +226,4 @@ def pipeline_llama_tracer(
             group=pp_mesh.get_group(),
         )
     )
-    return (stages, models)
+    return stages, models
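
With build_pipeline_schedule now invoked inside pipeline_llama, the function hands back a ready-to-run schedule alongside the model chunks instead of bare stages. A self-contained toy (stand-in names, not the torchtitan API) of the simplified caller contract:

# The split function builds the schedule itself, so callers receive
# (schedule, model_parts) and never touch stages directly.
from typing import Callable, List, Tuple

class Schedule:
    """Stand-in for a PT-D pipeline schedule."""
    def __init__(self, stages: List[str], loss_fn: Callable):
        self.stages, self.loss_fn = stages, loss_fn

    def step(self) -> None:
        print(f"running {len(self.stages)} pipeline stages")

def build_schedule(stages: List[str], loss_fn: Callable) -> Schedule:
    return Schedule(stages, loss_fn)  # stand-in for build_pipeline_schedule

def pipeline_model(model: str, loss_fn: Callable) -> Tuple[Schedule, List[str]]:
    stages = [f"{model}.stage0", f"{model}.stage1"]  # split step (manual/tracer)
    models = [f"{model}.part0", f"{model}.part1"]    # per-rank model chunks
    # new: the schedule is assembled here rather than by the trainer
    return build_schedule(stages, loss_fn), models

pp_schedule, model_parts = pipeline_model("llama", loss_fn=lambda pred, y: 0.0)
pp_schedule.step()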
2 changes: 1 addition & 1 deletion torchtitan/parallelisms/pipelining_utils.py
@@ -14,7 +14,7 @@
 from torchtitan.logging import logger
 
 
-def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
+def build_pipeline_schedule(job_config, stages, loss_fn):
     looped_schedule = False
 
     if job_config.experimental.pipeline_parallel_schedule == "1f1b":