Multi-node LLM post-training with GRPO #684

73 changes: 73 additions & 0 deletions 3.test_cases/pytorch/grpo/README.md
@@ -0,0 +1,73 @@
# Multi-node large-model GRPO training using Hugging Face TRL

## Overview

This test case demonstrates multi-node GRPO (Group Relative Policy Optimization) post-training of large language models with Hugging Face TRL on a Slurm cluster. Training runs with DeepSpeed ZeRO-3 via Accelerate, rollouts are generated by a separate vLLM server, and the accompanying `inference.py` and `eval.py` scripts exercise the resulting checkpoints.

## Prerequisites

### Docker Image

We define all the dependencies in `grpo.Dockerfile` and build the image with the following command:

```bash
docker build -f grpo.Dockerfile -t grpo:latest .
```

### Enroot

To run the container on Slurm, we convert the Docker image into a SquashFS file using Enroot:

```bash
enroot import -o ./grpo.sqsh dockerd://grpo:latest
```

## Launching GRPO training

We launch GRPO training by submitting the Slurm batch script, passing the Hugging Face model ID as an argument:

```bash
sbatch train.sbatch Qwen/Qwen2.5-72B-Instruct
```
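
`train.sbatch` takes the Hugging Face model ID as its only argument and, as the separate `grpo_*.out` and `vllm_*.out` logs below suggest, runs two cooperating pieces: the Accelerate/DeepSpeed training processes (configured by `deepspeed_zero3.yaml`) and a vLLM server that generates the rollouts. The training script itself is not reproduced in this excerpt; as rough orientation only, a minimal GRPO entry point built on TRL's `GRPOTrainer` might look like the sketch below. The reward function and output directory here are placeholders, and the exact vLLM-related options depend on the TRL version.

```python
# Hypothetical sketch of a GRPO training entry point using TRL;
# the actual training script in this test case may differ.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


def simple_format_reward(completions, **kwargs):
    # Placeholder reward: 1.0 if the completion contains an <answer> tag.
    # The real test case logs separate format and accuracy rewards
    # (see the sketch after the training-log excerpt below).
    return [1.0 if "<answer>" in completion else 0.0 for completion in completions]


def main():
    dataset = load_dataset("PrimeIntellect/verifiable-math-problems", split="train")

    training_args = GRPOConfig(
        output_dir="Qwen2.5-72B-Instruct-GRPO",  # illustrative output path
        bf16=True,
        use_vllm=True,  # generate rollouts with an external vLLM server
    )
    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-72B-Instruct",
        reward_funcs=[simple_format_reward],
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()


if __name__ == "__main__":
    main()
```

In the actual job, Accelerate launches the training script across the training nodes using `deepspeed_zero3.yaml`, while vLLM runs as a separate server process.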

The logs can be inspected with the `tail` command.

GRPO training logs:
```bash
tail -f -n +0 grpo_XXX.out
```
sample output:
```
1%| | 17/2264 [01:22<2:55:16, 4.68s/it]
0: {'loss': 0.0785, 'grad_norm': 0.8229517735973697, 'learning_rate': 9.916077738515903e-06, 'num_tokens': 1498339.0, 'completions/mean_length': 134.934765625, 'completions/min_length': 35.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 0.08203125, 'completions/mean_terminated_length': 124.83461303710938, 'completions/min_terminated_length': 35.0, 'completions/max_terminated_length': 253.8, 'rewards/format_reward/mean': 0.90703125, 'rewards/format_reward/std': 0.27258416190743445, 'rewards/accuracy_reward/mean': 0.224609375, 'rewards/accuracy_reward/std': 0.4104481041431427, 'reward': 1.131640625, 'reward_std': 0.34059175848960876, 'kl': 0.2958984375, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.01}
```
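
The `rewards/format_reward` and `rewards/accuracy_reward` columns come from the reward functions registered with the trainer, and the overall `reward` in the sample line is consistent with their sum. Their implementations are not shown in this excerpt; the sketch below illustrates what such rewards typically look like for this dataset (the tag regex, the 0/1 reward values, and the handling of conversational completions are assumptions):

```python
# Illustrative reward functions; the actual implementations in this test case
# are not shown in the diff and may differ.
import re

from math_verify import parse, verify

THINK_ANSWER_RE = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)


def _completion_text(completion):
    # GRPOTrainer passes conversational completions as lists of message dicts.
    if isinstance(completion, list):
        return completion[0]["content"]
    return completion


def format_reward(completions, **kwargs):
    # 1.0 if the completion follows the <think>...</think><answer>...</answer> format.
    return [1.0 if THINK_ANSWER_RE.search(_completion_text(c)) else 0.0 for c in completions]


def accuracy_reward(completions, gold_standard_solution, **kwargs):
    # 1.0 if math_verify judges the answer equivalent to the gold solution.
    rewards = []
    for completion, solution in zip(completions, gold_standard_solution):
        ok = verify(parse(solution), parse(_completion_text(completion)))
        rewards.append(1.0 if ok else 0.0)
    return rewards


# These would be registered on the trainer as:
#   GRPOTrainer(..., reward_funcs=[format_reward, accuracy_reward], ...)
```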

vLLM server logs (the trainer pushes updated policy weights to the server via `/update_named_param/` and requests generations via `/generate/`):
```bash
tail -f -n +0 vllm_XXX.out
```
sample output:
```
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO 05-14 23:13:00 [block_pool.py:264] Successfully reset prefix cache
0: INFO: 10.4.37.27:41696 - "POST /reset_prefix_cache/ HTTP/1.1" 200 OK
Processed prompts: 100%|██████████| 256/256 [00:01<00:00, 176.40it/s, est. speed input: 32916.33 toks/s, output: 13802.34 toks/s]
0: INFO: 10.4.37.27:41696 - "POST /generate/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
```

## Inference

Spot-check a trained checkpoint with the inference script (the checkpoint path below is abbreviated; substitute your actual output directory):

```bash
srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=infer.err python /grpo/inference.py --model /grpo/.../Qwen/Qwen2.5-14B-Instruct-GRPO/checkpoint-700/
```

## Evaluation

Measure accuracy on the held-out test split with the evaluation script:

```bash
srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=eval.err python /grpo/eval.py --model /grpo/.../Qwen/Qwen2.5-14B-Instruct-GRPO/checkpoint-700/
```
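
Both `inference.py` and `eval.py` rely on `math_verify` to decide whether a generated answer matches the gold solution: answers are parsed into mathematical objects and compared for equivalence rather than as strings. A small illustration of the idea (assumed behaviour; exact parsing depends on the `math_verify` version):

```python
from math_verify import parse, verify

# "1/2" and "0.5" differ as text but are mathematically equivalent,
# so verify() should accept the prediction (assumed behaviour).
gold = parse(r"Final Answer: \boxed{\frac{1}{2}}")
prediction = parse(r"The probability is \boxed{0.5}.")
print(verify(gold, prediction))  # expected: True
```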
22 changes: 22 additions & 0 deletions 3.test_cases/pytorch/grpo/deepspeed_zero3.yaml
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
100 changes: 100 additions & 0 deletions 3.test_cases/pytorch/grpo/eval.py
@@ -0,0 +1,100 @@
import argparse
import torch
from datasets import load_dataset
from vllm import LLM, SamplingParams
from transformers import AutoConfig, AutoTokenizer
from tqdm import tqdm
from math_verify import parse, verify
import re
import sys


def get_tensor_parallel_size(model: str) -> int:
    """Choose the largest tensor-parallel degree, up to the number of visible GPUs,
    that evenly divides both the KV-head count and the vocabulary size."""
    config = AutoConfig.from_pretrained(model)
    num_key_value_heads = getattr(
        config, "num_key_value_heads", getattr(config, "num_attention_heads", 1)
    )
    vocab_size = getattr(config, "vocab_size", 1)
    gpus_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
    tensor_parallel_size = 1
    for tp in reversed(range(1, gpus_count + 1)):
        if num_key_value_heads % tp == 0 and vocab_size % tp == 0:
            tensor_parallel_size = tp
            break
    return tensor_parallel_size


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen2.5-0.5B-Instruct",
help="The model to use",
)
args = parser.parse_args()

dataset_id = "PrimeIntellect/verifiable-math-problems"
dataset = load_dataset(dataset_id, split="train")

dataset = dataset.train_test_split(test_size=0.01, seed=42)
test_dataset = dataset["test"]

SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think>reasoning process here</think><answer>answer here</answer>"
)

ending = "Return your final response as 'Final Answer: \\boxed{<answer>}', where <answer> is the number or mathematical expression of the solution."

    def make_conversation(example):
        # Drop the dataset's trailing formatting instruction (the `ending` string
        # above); the system prompt defines the expected <think>/<answer> format.
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["prompt"][: -len(ending)].strip()},
            ],
        }

test_dataset = test_dataset.map(make_conversation)

tensor_parallel_size = get_tensor_parallel_size(args.model)
print(f"{tensor_parallel_size=}")

llm = LLM(model=args.model, tensor_parallel_size=tensor_parallel_size)

tokenizer = AutoTokenizer.from_pretrained(args.model)

prompts_and_solutions = [
(
tokenizer.apply_chat_template(sample["prompt"], tokenize=False),
sample["gold_standard_solution"],
)
for sample in tqdm(
test_dataset, desc="Loading prompts and solutions", file=sys.stdout
)
]
prompts = [prompt for prompt, _ in prompts_and_solutions]
solutions = [solution for _, solution in prompts_and_solutions]

outputs = llm.generate(
prompts, sampling_params=SamplingParams(max_tokens=1000, temperature=0.0)
)

generated_texts = [output.outputs[0].text for output in outputs]
results = [
verify(parse(generated_text), parse(solution))
for generated_text, solution in tqdm(
zip(generated_texts, solutions),
total=len(generated_texts),
desc="Verifying answers",
file=sys.stdout,
)
]
score = sum(results) / len(results)
print(f"Percentage of correct answers: {score:.2%}")


if __name__ == "__main__":
main()
24 changes: 24 additions & 0 deletions 3.test_cases/pytorch/grpo/grpo.Dockerfile
@@ -0,0 +1,24 @@
FROM public.ecr.aws/hpc-cloud/nccl-tests:latest

# Install Miniconda so we do not depend on the base image's Python
RUN mkdir -p /opt/miniconda3 \
&& curl -L https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o /tmp/Miniconda3-latest-Linux-x86_64.sh \
&& bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -f -p /opt/miniconda3 \
&& rm /tmp/Miniconda3-latest-Linux-x86_64.sh \
&& /opt/miniconda3/bin/conda init bash

ENV PATH="/opt/miniconda3/bin:${PATH}"

# Install Rust, which is required to build TRL's dependency 'outlines'
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y

ENV PATH="/root/.cargo/bin:${PATH}"

# Install Python dependencies before installing TRL with the vLLM backend
RUN pip install torch==2.6.0 transformers datasets accelerate peft deepspeed wandb math_verify flashinfer-python

# Note: installing FlashInfer from the wheel index
# (pip install flashinfer-python -i https://flashinfer.ai/whl/cu126/torch2.6/) does not work in this image,
# so flashinfer-python is installed from PyPI above instead.

# Install TRL with the vLLM backend
RUN PKG_CONFIG_PATH=/opt/miniconda3/lib/pkgconfig pip install trl[vllm]
89 changes: 89 additions & 0 deletions 3.test_cases/pytorch/grpo/inference.py
@@ -0,0 +1,89 @@
import argparse
import torch
from datasets import load_dataset
from vllm import LLM, SamplingParams
from transformers import AutoConfig, AutoTokenizer
from math_verify import parse, verify

def get_tensor_parallel_size(model: str) -> int:
    """Choose the largest tensor-parallel degree, up to the number of visible GPUs,
    that evenly divides both the KV-head count and the vocabulary size."""
    config = AutoConfig.from_pretrained(model)
    num_key_value_heads = getattr(
        config, "num_key_value_heads", getattr(config, "num_attention_heads", 1)
    )
    vocab_size = getattr(config, "vocab_size", 1)
    gpus_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
    tensor_parallel_size = 1
    for tp in reversed(range(1, gpus_count + 1)):
        if num_key_value_heads % tp == 0 and vocab_size % tp == 0:
            tensor_parallel_size = tp
            break
    return tensor_parallel_size


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen2.5-0.5B-Instruct",
help="The model to use",
)
args = parser.parse_args()

dataset_id = "PrimeIntellect/verifiable-math-problems"
dataset = load_dataset(dataset_id, split="train")

dataset = dataset.train_test_split(test_size=0.1, seed=42)
test_dataset = dataset["test"]

SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think>reasoning process here</think><answer>answer here</answer>"
)

ending = "Return your final response as 'Final Answer: \\boxed{<answer>}', where <answer> is the number or mathematical expression of the solution."

def make_conversation(example):
return {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["prompt"][: -len(ending)].strip()},
],
}

test_dataset = test_dataset.map(make_conversation)

tensor_parallel_size = get_tensor_parallel_size(args.model)
print(f"{tensor_parallel_size=}")

llm = LLM(model=args.model, tensor_parallel_size=tensor_parallel_size)

tokenizer = AutoTokenizer.from_pretrained(args.model)

    # Spot-check the first 100 test examples: generate an answer greedily,
    # parse it, and verify it against the gold-standard solution.
    for i, example in enumerate(test_dataset):
        if i >= 100:
            break
        prompt = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
        response = llm.generate(
            prompt, sampling_params=SamplingParams(max_tokens=1000, temperature=0.0)
        )
        generated_text = response[0].outputs[0].text
        parsed_generated_text = parse(generated_text)
        verification_info = example["verification_info"]
        parsed_gold_standard_solution = parse(example["gold_standard_solution"])
        result = verify(parsed_gold_standard_solution, parsed_generated_text)

        print(f"{prompt=}")
        print(f"{generated_text=}")
        print(f"{parsed_generated_text=}")
        print(f"{verification_info=}")
        print(f"{parsed_gold_standard_solution=}")
        print(f"{result=}")
        print("-" * 100)


if __name__ == "__main__":
main()