Commit
Revert "merge upstream changes" (pytorch#570)
Reverts pytorch#569

Sorry, accidentally committed things to the wrong fork.
tianyu-l authored Sep 8, 2024
1 parent a09cde3 commit 1923ce4
Showing 11 changed files with 21 additions and 392 deletions.
3 changes: 0 additions & 3 deletions .gitignore
@@ -17,6 +17,3 @@ torchtitan/datasets/**/*.model
 *.log
 error.json
 _remote_module_non_scriptable.py
-
-# torch compile debug related
-torch_compile_debug/*
232 changes: 0 additions & 232 deletions benchmark.py

This file was deleted.

24 changes: 0 additions & 24 deletions run_benchmark_train.sh

This file was deleted.

2 changes: 0 additions & 2 deletions run_llama_train.sh
@@ -19,8 +19,6 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
-# TORCH_TRACE="./outputs/trace" \
-TORCH_NCCL_AVOID_RECORD_STREAMS=1 \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --job.config_file ${CONFIG_FILE} $overrides
8 changes: 0 additions & 8 deletions torchtitan/config_manager.py
@@ -241,14 +241,6 @@ def __init__(self):
             action="store_true",
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
-
-        # experimental configs
-        self.parser.add_argument(
-            "--experimental.torch_spmd",
-            default=False,
-            action="store_true",
-            help="Whether to use the experimental torch_spmd style parallelism",
-        )
         self.parser.add_argument(
             "--experimental.enable_async_tensor_parallel",
             default=False,
60 changes: 4 additions & 56 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -29,55 +29,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-
-
-# NOTE(lty): experimental for the PT-D 24 research internship project
-def torch_spmd_parallelize(
-    model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-    job_config: JobConfig,
-):
-    torch._inductor.config.simplefsdp.enable_reorder = True
-    torch._inductor.config.simplefsdp.enable_bucket = True
-
-    if parallel_dims.tp_enabled:
-        apply_tp(
-            model,
-            world_mesh["tp"],
-            loss_parallel=parallel_dims.loss_parallel_enabled,
-            enable_float8=job_config.float8.enable_float8_linear,
-            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
-        )
-
-    ac_config = job_config.activation_checkpoint
-    if ac_config.mode != "none":
-        apply_ac(model, ac_config)
-        logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
-
-    if parallel_dims.dp_enabled:
-        from torch_spmd.data_parallel import data_parallel, MixedPrecisionPolicy
-
-        mp_policy = MixedPrecisionPolicy(
-            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
-        )
-        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-
-        model = data_parallel(
-            model,
-            dp_mesh,
-            mode="fully_shard",
-            ac_mode=ac_config.mode,
-            mp_policy=mp_policy,
-        )
-        logger.info("Applied Simple FSDP to the model")
-
-    if job_config.training.compile:
-        model = torch.compile(model, fullgraph=True)
-        logger.info("Compiling with torch.compile")
-
-    return model
+from torchtitan.parallelisms.utils import check_strided_sharding_enabled
 
 
 def parallelize_llama(
@@ -93,9 +45,6 @@ def parallelize_llama(
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
-    # NOTE(lty): experimental for the PT-D 24 research internship project
-    if job_config.experimental.torch_spmd:
-        return torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)
 
     if parallel_dims.tp_enabled:
         if (
@@ -351,12 +300,11 @@ def apply_fsdp(
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
 
-    # TODO(lty): the check below requires the latest PyTorch nightly; remove for now
     # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
     # that users won't use a nightly build which is older than 20240809 by then.
-    # if tp_enabled:
-    #     # check if strided sharding is enabled, which is necessary for 2D/3D DCP
-    #     check_strided_sharding_enabled()
+    if tp_enabled:
+        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
+        check_strided_sharding_enabled()
 
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled: