
Commit 6002abb

torchtune usecase
1 parent e51c64a commit 6002abb

File tree

5 files changed: +182 −0 lines

3.test_cases/torchtune/.gitignore

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
checkpoints
models
miniconda3
pt_torchtune
torchtune
Miniconda3-latest-Linux-x86_64.sh
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
#!/usr/bin/env bash
set -ex

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
chmod +x Miniconda3-latest-Linux-x86_64.sh
./Miniconda3-latest-Linux-x86_64.sh -b -f -p ./miniconda3

source ./miniconda3/bin/activate

conda create -y -p ./pt_torchtune python=3.10

source activate ./pt_torchtune/

# Install AWS PyTorch, see https://aws-pytorch-doc.com/
# conda install -y pytorch=2.2.0 torchvision torchaudio torchtriton=2.2.0 pytorch-cuda=12.1 transformers datasets --strict-channel-priority --override-channels -c https://aws-ml-conda.s3.us-west-2.amazonaws.com -c nvidia -c conda-forge
conda install -y pytorch torchvision torchaudio pytorch-cuda=12.1 transformers datasets -c pytorch -c nvidia

git clone https://github.com/pytorch/torchtune.git
pip install -e ./torchtune

# Create the checkpoint directory
mkdir checkpoints
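
A minimal usage sketch for the setup script above, assuming it is saved as create_env.sh (a hypothetical name; the diff view does not show it). Fresh shells need to re-activate the environment before calling tune:

bash ./create_env.sh    # hypothetical file name, run from the test-case directory

# In a new shell, re-activate the environment first
source ./miniconda3/bin/activate
source activate ./pt_torchtune/
tune --help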
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
#!/bin/bash

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

# set -ex;

# Default value for HF_MODEL
DEFAULT_HF_MODEL="meta-llama/Llama-2-7b"
read -p "Please enter a Hugging Face model ($DEFAULT_HF_MODEL): " HF_MODEL
if [ -z "$HF_MODEL" ]; then
    HF_MODEL="$DEFAULT_HF_MODEL"
fi

read -p "Please enter your Hugging Face access token: " HF_TOKEN

mkdir -p models/${HF_MODEL}

tune download \
    ${HF_MODEL} \
    --output-dir models/${HF_MODEL} \
    --hf-token ${HF_TOKEN}
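
The interactive prompts above are equivalent to calling tune download directly; a sketch using the default model (replace <HF_TOKEN> with a real Hugging Face access token):

mkdir -p models/meta-llama/Llama-2-7b
tune download \
    meta-llama/Llama-2-7b \
    --output-dir models/meta-llama/Llama-2-7b \
    --hf-token <HF_TOKEN>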
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
#!/bin/bash

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

#SBATCH --nodes=1           # number of nodes to use
#SBATCH --job-name=full_ft  # name of your job
#SBATCH --exclusive         # job has exclusive use of the resource, no sharing

set -ex;

###########################
###### User Variables #####
###########################

GPUS_PER_NODE=4 # 4 for G5.12x, 8 for P4/P5

###########################
## Environment Variables ##
###########################

## Plenty of EFA-level variables
## Comment out for non-EFA instances (G4dn, P3)
## For G5.12x, comment out RDMA and fork-safe
## For G4dn and other G5, comment out all
# export FI_EFA_USE_DEVICE_RDMA=1 # use for p4d
# export FI_EFA_FORK_SAFE=1
export FI_LOG_LEVEL=1
export FI_PROVIDER=efa
export NCCL_DEBUG=INFO
## Switching SYNC_MEMOPS to zero can boost throughput with FSDP
## Disables CU_POINTER_ATTRIBUTE_SYNC_MEMOPS
## Reduces memory synchronizations
## https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__UNIFIED.html
export FI_EFA_SET_CUDA_SYNC_MEMOPS=0

###########################
####### Torch Dist ########
###########################

declare -a TORCHRUN_ARGS=(
    --nproc_per_node=$GPUS_PER_NODE
    --nnodes=$SLURM_JOB_NUM_NODES
    --rdzv_id=$SLURM_JOB_ID
    --rdzv_backend=c10d
    --rdzv_endpoint=$(hostname)
)

export TORCHTUNE=./pt_torchtune/bin/tune
export TRAIN_CONFIG=./llama2_7B_full.yaml

srun -l ${TORCHTUNE} run "${TORCHRUN_ARGS[@]}" full_finetune_distributed --config ${TRAIN_CONFIG}
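
A short submission sketch, assuming the Slurm script above is saved as full_finetune_distributed.sbatch (a hypothetical name; the diff view does not show it):

sbatch full_finetune_distributed.sbatch    # hypothetical file name

# Follow training output; by default Slurm writes it to slurm-<jobid>.out
tail -f slurm-<jobid>.out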
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Llama2 7B model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Llama-2-7b \
#     --hf-token <HF_TOKEN> \
#     --output-dir /tmp/llama2
#
# To launch on 4 devices, run the following command from root:
#   tune run --nproc_per_node 4 full_finetune_distributed \
#     --config llama2/7B_full
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
#     --config llama2/7B_full \
#     checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# Single-device full finetuning requires more memory optimizations. It's
# best to use 7B_full_single_device.yaml for those cases.


# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: models/meta-llama/Llama-2-7b/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama2.llama2_7b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: models/meta-llama/Llama-2-7b
  checkpoint_files: [consolidated.00.pth]
  recipe_checkpoint: null
  output_dir: models/meta-llama/Llama-2-7b
  model_type: LLAMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1


# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: null
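
Tying the pieces together: the Slurm script above points TRAIN_CONFIG at this file, and any key can be overridden from the command line as the header comments describe. A sketch (the output_dir value here is an illustrative choice, not part of the commit):

tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
    --config ./llama2_7B_full.yaml \
    output_dir=./checkpoints/alpaca-llama2-finetune    # illustrative override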
