fix float8 after the HSDP PR (pytorch#575)
tianyu-l committed Sep 12, 2024
1 parent f2a1551 commit 2a25f4d
Showing 1 changed file with 1 addition and 2 deletions.
torchtitan/float8.py: 3 changes (1 addition & 2 deletions)
@@ -49,8 +49,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
-            parallel_dims.dp_enabled
-            and parallel_dims.dp_type == "fsdp"
+            parallel_dims.dp_shard_enabled
             and float8_config.enable_fsdp_float8_all_gather
         )
         scaling_type_input = ScalingType(float8_config.scaling_type_input)
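
Why the gate changed: after the HSDP PR, data parallelism is presumably described by separate replicate and shard degrees rather than a single dp_type string, so the old `dp_type == "fsdp"` check would disable float8 all-gather under HSDP even though parameters are still sharded. The sketch below is a minimal illustration, not the torchtitan API: ParallelDims and Float8Config are hypothetical, simplified stand-ins, and dp_replicate, dp_shard, and fsdp_float8_all_gather_enabled are illustrative names.

from dataclasses import dataclass


@dataclass
class ParallelDims:
    """Hypothetical stand-in for torchtitan's ParallelDims (simplified)."""
    dp_replicate: int  # data-parallel replicate degree (DDP-like axis)
    dp_shard: int      # data-parallel shard degree (FSDP axis)

    @property
    def dp_shard_enabled(self) -> bool:
        # Sharding is active for plain FSDP (replicate == 1, shard > 1)
        # and for HSDP (replicate > 1, shard > 1) alike.
        return self.dp_shard > 1


@dataclass
class Float8Config:
    """Hypothetical stand-in for the job config's float8 section."""
    enable_fsdp_float8_all_gather: bool = False


def fsdp_float8_all_gather_enabled(dims: ParallelDims, cfg: Float8Config) -> bool:
    # Mirrors the patched condition: float8 all-gather applies whenever
    # parameters are sharded, with or without an extra replicate axis.
    return dims.dp_shard_enabled and cfg.enable_fsdp_float8_all_gather


# Under HSDP (replicate=2, shard=4), a check like dp_type == "fsdp" would
# report False; the shard-based check keeps float8 all-gather enabled.
hsdp = ParallelDims(dp_replicate=2, dp_shard=4)
assert fsdp_float8_all_gather_enabled(hsdp, Float8Config(True)) is True

Keying on the shard degree rather than a dp-type label means the same condition covers both FSDP and HSDP without enumerating mode names.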
