diff --git a/3.test_cases/pytorch/grpo/README.md b/3.test_cases/pytorch/grpo/README.md
new file mode 100644
index 000000000..9ad640773
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/README.md
@@ -0,0 +1,73 @@
+# Multi-node large model GRPO training using Hugging Face TRL
+
+## Overview
+
+This test case demonstrates multi-node GRPO training of large models with Hugging Face TRL. The Slurm job splits the allocation into a group of training nodes, which run the TRL `GRPOTrainer` with DeepSpeed ZeRO-3 through Accelerate, and one dedicated node that serves generations with vLLM.
+
+## Prerequisites
+
+### Docker Image
+
+All dependencies are defined in `grpo.Dockerfile`. We build the image with the following command:
+
+```bash
+docker build -f grpo.Dockerfile -t grpo:latest .
+```
+
+### Enroot
+
+To run the container on Slurm, we convert it into a SquashFS file using Enroot:
+
+```bash
+enroot import -o ./grpo.sqsh dockerd://grpo:latest
+```
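+
+Note: `train.sbatch` mounts the current directory into the container as `/grpo` and looks for `grpo.sqsh` there, and it also mounts `$HF_HOME` so that models and datasets are cached on shared storage. Submit the job from this directory and point `HF_HOME` at a filesystem that is visible from every compute node; the path below is only an example:
+
+```bash
+export HF_HOME=/fsx/hf-cache   # example: a shared FSx/NFS path reachable from all nodes
+```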
+
+## Launching GRPO training
+
+We launch the GRPO training with the following command:
+
+```bash
+sbatch train.sbatch Qwen/Qwen2.5-72B-Instruct
+```
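+
+The positional argument selects the model to train; `train.sbatch` falls back to `Qwen/Qwen2.5-0.5B-Instruct` when it is omitted. For example, to train the 14B model instead:
+
+```bash
+sbatch train.sbatch Qwen/Qwen2.5-14B-Instruct
+```
+
+When switching model sizes you may also need to adjust `per_device_train_batch_size` in `train.py` (96 was used for the 14B model and 32 for the 72B model).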
+
+The logs can be inspected with the `tail` command.
+
+GRPO training logs:
+```bash
+tail -f -n +0 grpo_XXX.out
+```
+Sample output:
+```
+ 1%| | 17/2264 [01:22<2:55:16, 4.68s/it]
+0: {'loss': 0.0785, 'grad_norm': 0.8229517735973697, 'learning_rate': 9.916077738515903e-06, 'num_tokens': 1498339.0, 'completions/mean_length': 134.934765625, 'completions/min_length': 35.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 0.08203125, 'completions/mean_terminated_length': 124.83461303710938, 'completions/min_terminated_length': 35.0, 'completions/max_terminated_length': 253.8, 'rewards/format_reward/mean': 0.90703125, 'rewards/format_reward/std': 0.27258416190743445, 'rewards/accuracy_reward/mean': 0.224609375, 'rewards/accuracy_reward/std': 0.4104481041431427, 'reward': 1.131640625, 'reward_std': 0.34059175848960876, 'kl': 0.2958984375, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.01}
+```
+
+vLLM logs:
+```bash
+tail -f -n +0 vllm_XXX.out
+```
+Sample output:
+```
+0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
+0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
+0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
+0: INFO 05-14 23:13:00 [block_pool.py:264] Successfully reset prefix cache
+0: INFO: 10.4.37.27:41696 - "POST /reset_prefix_cache/ HTTP/1.1" 200 OK
+Processed prompts: 100%|██████████| 256/256 [00:01<00:00, 176.40it/s, est. speed input: 32916.33 toks/s, output: 13802.34 toks/s]
+0: INFO: 10.4.37.27:41696 - "POST /generate/ HTTP/1.1" 200 OK
+0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
+0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
+0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
+```
+
+## Inference
+
+Run `inference.py` against a saved checkpoint to inspect individual prompts, generations and their verification results:
+
+```bash
+srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=infer.err python /grpo/inference.py --model /grpo/.../Qwen/Qwen2.5-14B-Instruct-GRPO/checkpoint-700/
+```
+
+## Evaluation
+
+Run `eval.py` to measure the share of correct answers on the held-out test split:
+
+```bash
+srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=eval.err python /grpo/eval.py --model /grpo/.../Qwen/Qwen2.5-14B-Instruct-GRPO/checkpoint-700/
+```
\ No newline at end of file
diff --git a/3.test_cases/pytorch/grpo/deepspeed_zero3.yaml b/3.test_cases/pytorch/grpo/deepspeed_zero3.yaml
new file mode 100644
index 000000000..b5a1201f8
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/deepspeed_zero3.yaml
@@ -0,0 +1,22 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  deepspeed_multinode_launcher: standard
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/3.test_cases/pytorch/grpo/eval.py b/3.test_cases/pytorch/grpo/eval.py
new file mode 100644
index 000000000..fb9d5e8de
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/eval.py
@@ -0,0 +1,100 @@
+import argparse
+import torch
+from datasets import load_dataset
+from vllm import LLM, SamplingParams
+from transformers import AutoConfig, AutoTokenizer
+from tqdm import tqdm
+from math_verify import parse, verify
+import re
+import sys
+
+
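+# Pick the largest tensor-parallel size (at most the local GPU count) that evenly divides
+# both the KV-head count and the vocabulary size, since vLLM shards those dimensions
+# across tensor-parallel ranks.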
+def get_tensor_parallel_size(model: str) -> int:
+    config = AutoConfig.from_pretrained(model)
+    num_key_value_heads = getattr(
+        config, "num_key_value_heads", getattr(config, "num_attention_heads", 1)
+    )
+    vocab_size = getattr(config, "vocab_size", 1)
+    gpus_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
+    tensor_parallel_size = 1
+    for tp in reversed(range(1, gpus_count + 1)):
+        if num_key_value_heads % tp == 0 and vocab_size % tp == 0:
+            tensor_parallel_size = tp
+            break
+    return tensor_parallel_size
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="Qwen/Qwen2.5-0.5B-Instruct",
+        help="The model to use",
+    )
+    args = parser.parse_args()
+
+    dataset_id = "PrimeIntellect/verifiable-math-problems"
+    dataset = load_dataset(dataset_id, split="train")
+
+    dataset = dataset.train_test_split(test_size=0.01, seed=42)
+    test_dataset = dataset["test"]
+
+    SYSTEM_PROMPT = (
+        "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+        "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+        "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
+        "<think>reasoning process here</think><answer>answer here</answer>"
+    )
+
+    ending = "Return your final response as 'Final Answer: \\boxed{<answer>}', where <answer> is the number or mathematical expression of the solution."
+
+    def make_conversation(example):
+        return {
+            "prompt": [
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": example["prompt"][: -len(ending)].strip()},
+            ],
+        }
+
+    test_dataset = test_dataset.map(make_conversation)
+
+    tensor_parallel_size = get_tensor_parallel_size(args.model)
+    print(f"{tensor_parallel_size=}")
+
+    llm = LLM(model=args.model, tensor_parallel_size=tensor_parallel_size)
+
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+
+    prompts_and_solutions = [
+        (
+            tokenizer.apply_chat_template(sample["prompt"], tokenize=False),
+            sample["gold_standard_solution"],
+        )
+        for sample in tqdm(
+            test_dataset, desc="Loading prompts and solutions", file=sys.stdout
+        )
+    ]
+    prompts = [prompt for prompt, _ in prompts_and_solutions]
+    solutions = [solution for _, solution in prompts_and_solutions]
+
+    outputs = llm.generate(
+        prompts, sampling_params=SamplingParams(max_tokens=1000, temperature=0.0)
+    )
+
+    generated_texts = [output.outputs[0].text for output in outputs]
+    results = [
+        verify(parse(generated_text), parse(solution))
+        for generated_text, solution in tqdm(
+            zip(generated_texts, solutions),
+            total=len(generated_texts),
+            desc="Verifying answers",
+            file=sys.stdout,
+        )
+    ]
+    score = sum(results) / len(results)
+    print(f"Percentage of correct answers: {score:.2%}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/3.test_cases/pytorch/grpo/grpo.Dockerfile b/3.test_cases/pytorch/grpo/grpo.Dockerfile
new file mode 100644
index 000000000..421fe17d6
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/grpo.Dockerfile
@@ -0,0 +1,24 @@
+FROM public.ecr.aws/hpc-cloud/nccl-tests:latest
+
+# Install Miniconda so that we do not depend on the base image's Python
+RUN mkdir -p /opt/miniconda3 \
+    && curl -L https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o /tmp/Miniconda3-latest-Linux-x86_64.sh \
+    && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -f -p /opt/miniconda3 \
+    && rm /tmp/Miniconda3-latest-Linux-x86_64.sh \
+    && /opt/miniconda3/bin/conda init bash
+
+ENV PATH="/opt/miniconda3/bin:${PATH}"
+
+# Install Rust, which is required by TRL's dependency 'outlines'
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+
+ENV PATH="/root/.cargo/bin:${PATH}"
+
+# Install Python dependencies before installing TRL with the vLLM backend
+RUN pip install torch==2.6.0 transformers datasets accelerate peft deepspeed wandb math_verify flashinfer-python
+
+# Installing FlashInfer from its wheel index did not work; flashinfer-python is installed from PyPI above instead
+# RUN pip install flashinfer-python -i https://flashinfer.ai/whl/cu126/torch2.6/
+
+# Install TRL with the vLLM backend
+RUN PKG_CONFIG_PATH=/opt/miniconda3/lib/pkgconfig pip install trl[vllm]
diff --git a/3.test_cases/pytorch/grpo/inference.py b/3.test_cases/pytorch/grpo/inference.py
new file mode 100644
index 000000000..483141e60
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/inference.py
@@ -0,0 +1,89 @@
+import argparse
+import torch
+from datasets import load_dataset
+from vllm import LLM, SamplingParams
+from transformers import AutoConfig, AutoTokenizer
+from math_verify import parse, verify
+
+def get_tensor_parallel_size(model: str) -> int:
+    config = AutoConfig.from_pretrained(model)
+    num_key_value_heads = getattr(
+        config, "num_key_value_heads", getattr(config, "num_attention_heads", 1)
+    )
+    vocab_size = getattr(config, "vocab_size", 1)
+    gpus_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
+    tensor_parallel_size = 1
+    for tp in reversed(range(1, gpus_count + 1)):
+        if num_key_value_heads % tp == 0 and vocab_size % tp == 0:
+            tensor_parallel_size = tp
+            break
+    return tensor_parallel_size
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="Qwen/Qwen2.5-0.5B-Instruct",
+        help="The model to use",
+    )
+    args = parser.parse_args()
+
+    dataset_id = "PrimeIntellect/verifiable-math-problems"
+    dataset = load_dataset(dataset_id, split="train")
+
+    dataset = dataset.train_test_split(test_size=0.1, seed=42)
+    test_dataset = dataset["test"]
+
+    SYSTEM_PROMPT = (
+        "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+        "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+        "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
+        "<think>reasoning process here</think><answer>answer here</answer>"
+    )
+
+    ending = "Return your final response as 'Final Answer: \\boxed{<answer>}', where <answer> is the number or mathematical expression of the solution."
+
+    def make_conversation(example):
+        return {
+            "prompt": [
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": example["prompt"][: -len(ending)].strip()},
+            ],
+        }
+
+    test_dataset = test_dataset.map(make_conversation)
+
+    tensor_parallel_size = get_tensor_parallel_size(args.model)
+    print(f"{tensor_parallel_size=}")
+
+    llm = LLM(model=args.model, tensor_parallel_size=tensor_parallel_size)
+
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+
+    i = 0
+    for example in test_dataset:
+        prompt = example["prompt"]
+        prompt = tokenizer.apply_chat_template(prompt, tokenize=False)
+        response = llm.generate(prompt, sampling_params=SamplingParams(max_tokens=1000, temperature=0.0))
+        generated_texts = response[0].outputs[0].text
+        parsed_generated_texts = parse(generated_texts)
+        verification_info = example['verification_info']
+        parsed_gold_standard_solution = parse(example['gold_standard_solution'])
+        result = verify(parsed_gold_standard_solution, parsed_generated_texts)
+
+        print(f"{prompt=}")
+        print(f"{generated_texts=}")
+        print(f"{parsed_generated_texts=}")
+        print(f"{verification_info=}")
+        print(f"{parsed_gold_standard_solution=}")
+        print(f"{result=}")
+        print("-" * 100)
+        i += 1
+        if i > 100:
+            break
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/3.test_cases/pytorch/grpo/train.py b/3.test_cases/pytorch/grpo/train.py
new file mode 100644
index 000000000..b16661bf0
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/train.py
@@ -0,0 +1,203 @@
+import argparse
+import os
+import re
+from datasets import load_dataset
+from trl import GRPOConfig, GRPOTrainer
+from math_verify import parse, verify
+from datetime import datetime
+import accelerate
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--vllm_server_host", type=str, default="", help="The server IP"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="Qwen/Qwen2.5-0.5B-Instruct",
+        help="The model to use",
+    )
+    args = parser.parse_args()
+
+    dataset_id = "PrimeIntellect/verifiable-math-problems"
+    dataset = load_dataset(dataset_id, split="train")
+
+    dataset = dataset.train_test_split(test_size=0.1, seed=42)
+    train_dataset = dataset["train"]
+    test_dataset = dataset["test"]
+
+    SYSTEM_PROMPT = (
+        "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+        "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+        "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
+        "<think>reasoning process here</think><answer>answer here</answer>"
+    )
+
+    ending = "Return your final response as 'Final Answer: \\boxed{<answer>}', where <answer> is the number or mathematical expression of the solution."
+
+    def make_conversation(example):
+        return {
+            "prompt": [
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": example["prompt"][: -len(ending)].strip()},
+            ],
+        }
+
+    train_dataset = train_dataset.map(make_conversation)
+
+    # def format_reward(completions, **kwargs):
+    #     pattern = r"^<think>(.*?)</think>\s*<answer>(.*?)</answer>$"
+    #     rewards = []
+    #     for completion in completions:
+    #         match = re.search(pattern, completion[0]["content"], re.DOTALL)
+    #         if match:
+    #             # think_content = match.group(1).strip()
+    #             # answer_content = match.group(2).strip()
+    #             # if len(think_content) > len(answer_content) > 0:
+    #             #     rewards.append(1.0)
+    #             # else:
+    #             #     rewards.append(0.5)
+    #             rewards.append(1.0)
+    #         else:
+    #             rewards.append(0.0)
+    #     return rewards
+
+    def simple_format_reward(completions, **kwargs):
+        completion_contents = [completion[0]["content"] for completion in completions]
+        rewards = []
+        for content in completion_contents:
+            reward = 0.0
+            if "<think>" in content:
+                reward += 0.25
+            if "</think>" in content:
+                reward += 0.25
+            if "<answer>" in content:
+                reward += 0.25
+            if "</answer>" in content:
+                reward += 0.25
+            rewards.append(reward)
+        return rewards
+
+    def format_reward(completions, **kwargs):
+        """Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags."""
+        pattern = r"^<think>(.*?)</think>\s*<answer>(.*?)</answer>$"
+        completion_contents = [completion[0]["content"] for completion in completions]
+        matches = [
+            re.match(pattern, content, re.DOTALL) for content in completion_contents
+        ]
+        return [1.0 if match else 0.0 for match in matches]
+
+    def accuracy_reward(completions, **kwargs):
+        """Reward function that checks if the completion is the same as the ground truth."""
+        solutions = kwargs["gold_standard_solution"]
+        completion_contents = [completion[0]["content"] for completion in completions]
+        rewards = []
+        for content, solution in zip(completion_contents, solutions):
+            gold_parsed = parse(solution)
+            if len(gold_parsed) != 0:
+                answer_parsed = parse(content)
+                if verify(gold_parsed, answer_parsed):
+                    rewards.append(1.0)
+                else:
+                    rewards.append(0.0)
+            else:
+                rewards.append(None)
+        return rewards
+
+    def len_reward(completions, **kwargs):
+        """Compute length-based rewards to discourage overthinking and promote token efficiency.
+
+        Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
+
+        Args:
+            completions: List of model completions
+            solution: List of ground truth solutions
+
+        Returns:
+            List of rewards where:
+            - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
+            - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
+        """
+        solution = kwargs["gold_standard_solution"]
+        contents = [completion[0]["content"] for completion in completions]
+
+        # First check correctness of answers
+        correctness = []
+        for content, sol in zip(contents, solution):
+            gold_parsed = parse(sol)
+            if len(gold_parsed) == 0:
+                # Skip unparseable examples
+                correctness.append(True)  # Treat as correct to avoid penalizing
+                print("Failed to parse gold solution: ", sol)
+                continue
+
+            answer_parsed = parse(content)
+            correctness.append(verify(answer_parsed, gold_parsed))
+
+        # Calculate lengths
+        lengths = [len(content) for content in contents]
+        min_len = min(lengths)
+        max_len = max(lengths)
+
+        # If all responses have the same length, return zero rewards
+        if max_len == min_len:
+            return [0.0] * len(completions)
+
+        rewards = []
+        for length, is_correct in zip(lengths, correctness):
+            lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
+
+            if is_correct:
+                reward = lambda_val
+            else:
+                reward = min(0, lambda_val)
+
+            rewards.append(float(reward))
+
+        return rewards
+
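+    # Rank 0 builds a timestamped run directory name and broadcasts it so that every
+    # process saves checkpoints to the same output_dir.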
+    parent_dir = os.path.dirname(__file__)
+    date_time_dir = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    output_dir = os.path.join(parent_dir, date_time_dir, args.model + "-GRPO")
+    output_dir_list = [output_dir]
+    accelerate.utils.broadcast_object_list(output_dir_list)
+    output_dir = output_dir_list[0]
+    assert output_dir is not None
+
+    training_args = GRPOConfig(
+        output_dir=output_dir,
+        learning_rate=1e-5,
+        remove_unused_columns=False,
+        num_train_epochs=1,
+        per_device_train_batch_size=96,  # 96 for 14B, 32 for 72B
+        bf16=True,
+        gradient_checkpointing=True,
+        logging_steps=10,
+        report_to="wandb",
+        use_vllm=True,
+        vllm_server_host=args.vllm_server_host,
+        vllm_server_timeout=600,
+        save_strategy="steps",
+        save_steps=100,
+        torch_empty_cache_steps=10,
+        # Parameters related to evaluation
+        # eval_strategy="steps",
+        # eval_steps=1000,
+        # eval_on_start=True,
+    )
+
+    trainer = GRPOTrainer(
+        model=args.model,
+        reward_funcs=[simple_format_reward, format_reward, accuracy_reward, len_reward],
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=test_dataset,
+    )
+
+    trainer.train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/3.test_cases/pytorch/grpo/train.sbatch b/3.test_cases/pytorch/grpo/train.sbatch
new file mode 100644
index 000000000..596044bf1
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/train.sbatch
@@ -0,0 +1,78 @@
+#!/bin/bash
+#SBATCH --job-name=grpo
+#SBATCH --nodes=9
+#SBATCH --ntasks-per-node 1
+
+## Set libfabric flags to use EFA
+export FI_PROVIDER=efa
+export FI_EFA_USE_DEVICE_RDMA=1 # use for p4d
+export FI_EFA_FORK_SAFE=1
+
+## Set this flag for debugging EFA
+#export FI_LOG_LEVEL=warn
+
+## NCCL Environment variables
+# export NCCL_DEBUG=INFO
+
+### Increase the send queue depth, which can turn NCCL communications into non-blocking ones.
+### https://www.usenix.org/system/files/atc23-choi.pdf
+export NCCL_BUFFSIZE=8388608
+### Improve performance by increasing buffer size for Send/Recv, Gather, Scatter and Alltoall communications
+### https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html
+export NCCL_P2P_NET_CHUNKSIZE=524288
+
+### Improve performance for AllReduce by selecting specific protocol and algorithm for specific
+### message size and number of ranks.
+### More information https://github.com/aws/aws-ofi-nccl/wiki/Algorithm-and-Protocol-Tuner-for-AWS.
+export NCCL_TUNER_PLUGIN=/opt/aws-ofi-nccl/install/lib/libnccl-ofi-tuner.so
+
+# The first SLURM_NNODES-1 nodes run training; the last node serves generations with vLLM.
+NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
+TRAIN_NODES_NUM=$((SLURM_NNODES - 1))
+TRAIN_NODES="${NODELIST[@]:0:$TRAIN_NODES_NUM}"
+VLLM_NODE="${NODELIST[$TRAIN_NODES_NUM]}"
+head_node_ip=${NODELIST[0]}
+GPUS_PER_NODE=8
+
+LAUNCHER="accelerate launch \
+    --config_file /grpo/deepspeed_zero3.yaml \
+    --num_processes $((TRAIN_NODES_NUM * GPUS_PER_NODE)) \
+    --num_machines ${TRAIN_NODES_NUM} \
+    --rdzv_backend c10d \
+    --main_process_ip $head_node_ip \
+    --main_process_port 29500 \
+    --machine_rank \$SLURM_NODEID "
+
+MODEL="${1:-Qwen/Qwen2.5-0.5B-Instruct}"
+
+CMD="/grpo/train.py --model $MODEL --vllm_server_host $VLLM_NODE"
+
+# Fetch the model config and read the head and vocab counts used to size tensor parallelism
+CONFIG_URL="https://huggingface.co/$MODEL/raw/main/config.json"
+CONFIG_JSON=$(curl -s $CONFIG_URL)
+NUM_HEADS=$(echo "$CONFIG_JSON" | python3 -c "import sys, json; config = json.load(sys.stdin); print(config.get('num_key_value_heads', config.get('num_attention_heads', 1)))")
+VOCAB_SIZE=$(echo "$CONFIG_JSON" | python3 -c "import sys, json; config = json.load(sys.stdin); print(config.get('vocab_size', 1))")
+
+# Find the largest tensor parallel size that divides both NUM_HEADS and VOCAB_SIZE and is <= GPUS_PER_NODE
+TENSOR_PARALLEL=1
+for ((i=GPUS_PER_NODE; i>=1; i--)); do
+    if [ $((NUM_HEADS % i)) -eq 0 ] && [ $((VOCAB_SIZE % i)) -eq 0 ]; then
+        TENSOR_PARALLEL=$i
+        break
+    fi
+done
+
+echo "Using tensor parallel size: $TENSOR_PARALLEL"
+
+srun -l --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh \
+    --output=grpo_%j.out --error=grpo_%j.err \
+    --container-mounts=.:/grpo,$HF_HOME:$HF_HOME \
+    --nodes=$TRAIN_NODES_NUM --ntasks=$TRAIN_NODES_NUM --nodelist="${TRAIN_NODES}" \
+    bash -c "$LAUNCHER $CMD" &
+
+srun -l --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh \
+    --output=vllm_%j.out --error=vllm_%j.out \
+    --container-mounts=.:/grpo,$HF_HOME:$HF_HOME \
+    --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \
+    trl vllm-serve --model $MODEL --tensor_parallel_size $TENSOR_PARALLEL &
+
+wait
\ No newline at end of file