diff --git a/3.test_cases/pytorch/grpo/README.md b/3.test_cases/pytorch/grpo/README.md
new file mode 100644
index 000000000..9ad640773
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/README.md
@@ -0,0 +1,73 @@
+# Multi-node large model GRPO training with Hugging Face TRL
+
+## Overview
+
+This test case demonstrates multi-node GRPO (Group Relative Policy Optimization) training of a large language model with Hugging Face TRL. Training runs across several Slurm nodes with `accelerate` and DeepSpeed ZeRO-3, while a dedicated node serves generations through vLLM (`trl vllm-serve`). The example fine-tunes Qwen2.5 Instruct models on the `PrimeIntellect/verifiable-math-problems` dataset.
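+
+At a high level, the training script (`train.py`) builds a TRL `GRPOTrainer` whose rewards check the `<think>`/`<answer>` output format, verify the final answer with `math_verify`, and penalize overly long completions, while generation is offloaded to the vLLM server started by `train.sbatch`. A condensed sketch of that setup (the full script also converts each dataset prompt into a system/user chat format):
+
+```python
+import re
+from datasets import load_dataset
+from trl import GRPOConfig, GRPOTrainer
+
+train_dataset = load_dataset("PrimeIntellect/verifiable-math-problems", split="train")
+
+def format_reward(completions, **kwargs):
+    # 1.0 if the completion follows the <think>...</think><answer>...</answer> format, else 0.0.
+    pattern = r"^<think>(.*?)</think>\s*<answer>(.*?)</answer>$"
+    return [1.0 if re.match(pattern, c[0]["content"], re.DOTALL) else 0.0 for c in completions]
+
+training_args = GRPOConfig(
+    output_dir="Qwen2.5-72B-Instruct-GRPO",
+    use_vllm=True,                       # generation is served by a dedicated vLLM node
+    vllm_server_host="<vllm-node-ip>",   # injected by train.sbatch
+    bf16=True,
+    gradient_checkpointing=True,
+)
+trainer = GRPOTrainer(
+    model="Qwen/Qwen2.5-72B-Instruct",
+    reward_funcs=[format_reward],        # train.py also uses accuracy and length rewards
+    args=training_args,
+    train_dataset=train_dataset,
+)
+trainer.train()
+```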
+
+## Prerequisites
+
+### Docker Image
+
+We define all the dependencies in `grpo.Dockerfile` and build the image with the following command:
+
+```bash
+docker build -f grpo.Dockerfile -t grpo:latest .
+```
+
+### Enroot
+
+To run the container on Slurm, convert the Docker image into a squash file with Enroot:
+
+```bash
+enroot import -o ./grpo.sqsh dockerd://grpo:latest
+```
+
+## Launching GRPO training
+
+`train.sbatch` runs training on all but the last allocated node via `accelerate launch` with the DeepSpeed ZeRO-3 config (`deepspeed_zero3.yaml`), and reserves the last node for the `trl vllm-serve` generation server. Launch training with:
+
+```bash
+sbatch train.sbatch Qwen/Qwen2.5-72B-Instruct
+```
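+
+The tensor-parallel size passed to `trl vllm-serve` is derived in `train.sbatch` from the model's `config.json`: the largest value that does not exceed the GPUs per node and evenly divides both the key/value head count and the vocabulary size. `eval.py` and `inference.py` apply the same rule; a minimal Python sketch of it (function name is illustrative):
+
+```python
+from transformers import AutoConfig
+
+def pick_tensor_parallel_size(model: str, gpus_per_node: int = 8) -> int:
+    config = AutoConfig.from_pretrained(model)
+    heads = getattr(config, "num_key_value_heads", getattr(config, "num_attention_heads", 1))
+    vocab = getattr(config, "vocab_size", 1)
+    # Largest TP degree <= gpus_per_node that divides both the head count and the vocab size.
+    for tp in range(gpus_per_node, 0, -1):
+        if heads % tp == 0 and vocab % tp == 0:
+            return tp
+    return 1
+```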
+
+The logs can be inspected with `tail` (replace `XXX` with the Slurm job ID).
+
+GRPO training logs:
+```bash
+tail -f -n +0 grpo_XXX.out
+```
+Sample output:
+```
+ 1%| | 17/2264 [01:22<2:55:16, 4.68s/it]
+0: {'loss': 0.0785, 'grad_norm': 0.8229517735973697, 'learning_rate': 9.916077738515903e-06, 'num_tokens': 1498339.0, 'completions/mean_length': 134.934765625, 'completions/min_length': 35.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 0.08203125, 'completions/mean_terminated_length': 124.83461303710938, 'completions/min_terminated_length': 35.0, 'completions/max_terminated_length': 253.8, 'rewards/format_reward/mean': 0.90703125, 'rewards/format_reward/std': 0.27258416190743445, 'rewards/accuracy_reward/mean': 0.224609375, 'rewards/accuracy_reward/std': 0.4104481041431427, 'reward': 1.131640625, 'reward_std': 0.34059175848960876, 'kl': 0.2958984375, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.01}
+```
+
+vLLM logs:
+```bash
+tail -f -n +0 vllm_XXX.out
+```
+Sample output:
+```
+0: INFO: 10.4.37.27:41696 - "POST /upda_named_param/ HTTP/1.1" 200 OK
+0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
+0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
+0: INFO 05-14 23:13:00 [block_pool.py:264] Successfully reset prefix cache
+0: INFO: 10.4.37.27:41696 - "POST /reset_prefix_cache/ HTTP/1.1" 200 OK
+Processed prompts: 100%|██████████| 256/256 [00:01<00:00, 176.40it/s, est. speed input: 32916.33 toks/s, output: 13802.34 toks/s]
+0: INFO: 10.4.37.27:41696 - "POST /generate/ HTTP/1.1" 200 OK
+0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
+0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
+0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
+```
+
+## Inference
+
+`inference.py` loads a trained checkpoint with vLLM and prints prompts, generations, and `math_verify` verification results for a sample of the held-out test split:
+
+```bash
+srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=inference.err python /grpo/inference.py --model /grpo/.../Qwen/Qwen2.5-14B-Instruct-GRPO/checkpoint-700/
+```
+
+## Evaluation
+
+`eval.py` runs batched greedy generation over the held-out test split and reports the percentage of answers that `math_verify` confirms as correct:
+
+```bash
+srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=eval.err python /grpo/eval.py --model /grpo/.../Qwen/Qwen2.5-14B-Instruct-GRPO/checkpoint-700/
+```
\ No newline at end of file
diff --git a/3.test_cases/pytorch/grpo/deepspeed_zero3.yaml b/3.test_cases/pytorch/grpo/deepspeed_zero3.yaml
new file mode 100644
index 000000000..b5a1201f8
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/deepspeed_zero3.yaml
@@ -0,0 +1,22 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ offload_optimizer_device: none
+ offload_param_device: none
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/3.test_cases/pytorch/grpo/eval.py b/3.test_cases/pytorch/grpo/eval.py
new file mode 100644
index 000000000..fb9d5e8de
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/eval.py
@@ -0,0 +1,100 @@
+import argparse
+import torch
+from datasets import load_dataset
+from vllm import LLM, SamplingParams
+from transformers import AutoConfig, AutoTokenizer
+from tqdm import tqdm
+from math_verify import parse, verify
+import re
+import sys
+
+
+def get_tensor_parallel_size(model: str) -> int:
+ config = AutoConfig.from_pretrained(model)
+ num_key_value_heads = getattr(
+ config, "num_key_value_heads", getattr(config, "num_attention_heads", 1)
+ )
+ vocab_size = getattr(config, "vocab_size", 1)
+ gpus_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
+ tensor_parallel_size = 1
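+    # Pick the largest TP degree (capped at the visible GPU count) that evenly divides both the KV-head count and the vocab size.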
+ for tp in reversed(range(1, gpus_count + 1)):
+ if num_key_value_heads % tp == 0 and vocab_size % tp == 0:
+ tensor_parallel_size = tp
+ break
+ return tensor_parallel_size
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="Qwen/Qwen2.5-0.5B-Instruct",
+ help="The model to use",
+ )
+ args = parser.parse_args()
+
+ dataset_id = "PrimeIntellect/verifiable-math-problems"
+ dataset = load_dataset(dataset_id, split="train")
+
+ dataset = dataset.train_test_split(test_size=0.01, seed=42)
+ test_dataset = dataset["test"]
+
+ SYSTEM_PROMPT = (
+ "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+ "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+ "process and answer are enclosed within and tags, respectively, i.e., "
+ "reasoning process hereanswer here"
+ )
+
+ ending = "Return your final response as 'Final Answer: \\boxed{}', where is the number or mathematical expression of the solution."
+
+ def make_conversation(example):
+ return {
+ "prompt": [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": example["prompt"][: -len(ending)].strip()},
+ ],
+ }
+
+ test_dataset = test_dataset.map(make_conversation)
+
+ tensor_parallel_size = get_tensor_parallel_size(args.model)
+ print(f"{tensor_parallel_size=}")
+
+ llm = LLM(model=args.model, tensor_parallel_size=tensor_parallel_size)
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
+
+ prompts_and_solutions = [
+ (
+ tokenizer.apply_chat_template(sample["prompt"], tokenize=False),
+ sample["gold_standard_solution"],
+ )
+ for sample in tqdm(
+ test_dataset, desc="Loading prompts and solutions", file=sys.stdout
+ )
+ ]
+ prompts = [prompt for prompt, _ in prompts_and_solutions]
+ solutions = [solution for _, solution in prompts_and_solutions]
+
+ outputs = llm.generate(
+ prompts, sampling_params=SamplingParams(max_tokens=1000, temperature=0.0)
+ )
+
+ generated_texts = [output.outputs[0].text for output in outputs]
+ results = [
+        verify(parse(solution), parse(generated_text))  # math_verify expects (gold, target)
+ for generated_text, solution in tqdm(
+ zip(generated_texts, solutions),
+ total=len(generated_texts),
+ desc="Verifying answers",
+ file=sys.stdout,
+ )
+ ]
+ score = sum(results) / len(results)
+ print(f"Percentage of correct answers: {score:.2%}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/3.test_cases/pytorch/grpo/grpo.Dockerfile b/3.test_cases/pytorch/grpo/grpo.Dockerfile
new file mode 100644
index 000000000..421fe17d6
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/grpo.Dockerfile
@@ -0,0 +1,24 @@
+FROM public.ecr.aws/hpc-cloud/nccl-tests:latest
+
+# Install Miniconda to not depend on the base image python
+RUN mkdir -p /opt/miniconda3 \
+ && curl -L https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o /tmp/Miniconda3-latest-Linux-x86_64.sh \
+ && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -f -p /opt/miniconda3 \
+ && rm /tmp/Miniconda3-latest-Linux-x86_64.sh \
+ && /opt/miniconda3/bin/conda init bash
+
+ENV PATH="/opt/miniconda3/bin:${PATH}"
+
+# Install Rust, which is required by TRL's dependency 'outlines'
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+
+ENV PATH="/root/.cargo/bin:${PATH}"
+
+# Install Python dependencies before installing TRL with VLLM backend
+RUN pip install torch==2.6.0 transformers datasets accelerate peft deepspeed wandb math_verify flashinfer-python
+
+# Note: installing FlashInfer from the wheel index (https://flashinfer.ai/whl/cu126/torch2.6/) did not work
+# in this image, so flashinfer-python is installed from PyPI above instead.
+
+# Install TRL with VLLM backend
+RUN PKG_CONFIG_PATH=/opt/miniconda3/lib/pkgconfig pip install trl[vllm]
diff --git a/3.test_cases/pytorch/grpo/inference.py b/3.test_cases/pytorch/grpo/inference.py
new file mode 100644
index 000000000..483141e60
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/inference.py
@@ -0,0 +1,89 @@
+import argparse
+import torch
+from datasets import load_dataset
+from vllm import LLM, SamplingParams
+from transformers import AutoConfig, AutoTokenizer
+from math_verify import parse, verify
+
+def get_tensor_parallel_size(model: str) -> int:
+ config = AutoConfig.from_pretrained(model)
+ num_key_value_heads = getattr(
+ config, "num_key_value_heads", getattr(config, "num_attention_heads", 1)
+ )
+ vocab_size = getattr(config, "vocab_size", 1)
+ gpus_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
+ tensor_parallel_size = 1
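+    # Pick the largest TP degree (capped at the visible GPU count) that evenly divides both the KV-head count and the vocab size.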
+ for tp in reversed(range(1, gpus_count + 1)):
+ if num_key_value_heads % tp == 0 and vocab_size % tp == 0:
+ tensor_parallel_size = tp
+ break
+ return tensor_parallel_size
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="Qwen/Qwen2.5-0.5B-Instruct",
+ help="The model to use",
+ )
+ args = parser.parse_args()
+
+ dataset_id = "PrimeIntellect/verifiable-math-problems"
+ dataset = load_dataset(dataset_id, split="train")
+
+ dataset = dataset.train_test_split(test_size=0.1, seed=42)
+ test_dataset = dataset["test"]
+
+ SYSTEM_PROMPT = (
+ "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+ "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+ "process and answer are enclosed within and tags, respectively, i.e., "
+ "reasoning process hereanswer here"
+ )
+
+ ending = "Return your final response as 'Final Answer: \\boxed{}', where is the number or mathematical expression of the solution."
+
+ def make_conversation(example):
+ return {
+ "prompt": [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": example["prompt"][: -len(ending)].strip()},
+ ],
+ }
+
+ test_dataset = test_dataset.map(make_conversation)
+
+ tensor_parallel_size = get_tensor_parallel_size(args.model)
+ print(f"{tensor_parallel_size=}")
+
+ llm = LLM(model=args.model, tensor_parallel_size=tensor_parallel_size)
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
+
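+    # Generate and verify answers for roughly the first 100 test examples, printing each for manual inspection.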
+ i = 0
+ for example in test_dataset:
+ prompt = example["prompt"]
+ prompt = tokenizer.apply_chat_template(prompt, tokenize=False)
+ response = llm.generate(prompt, sampling_params=SamplingParams(max_tokens=1000, temperature=0.0))
+ generated_texts = response[0].outputs[0].text
+ parsed_generated_texts = parse(generated_texts)
+ verification_info = example['verification_info']
+ parsed_gold_standard_solution = parse(example['gold_standard_solution'])
+ result = verify(parsed_gold_standard_solution, parsed_generated_texts)
+
+ print(f"{prompt=}")
+ print(f"{generated_texts=}")
+ print(f"{parsed_generated_texts=}")
+ print(f"{verification_info=}")
+ print(f"{parsed_gold_standard_solution=}")
+ print(f"{result=}")
+ print("-" * 100)
+ i += 1
+ if i > 100:
+ break
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/3.test_cases/pytorch/grpo/train.py b/3.test_cases/pytorch/grpo/train.py
new file mode 100644
index 000000000..b16661bf0
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/train.py
@@ -0,0 +1,203 @@
+import argparse
+import os
+import re
+from datasets import load_dataset
+from trl import GRPOConfig, GRPOTrainer
+from math_verify import parse, verify
+from datetime import datetime
+import accelerate
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--vllm_server_host", type=str, default="", help="The server IP"
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="Qwen/Qwen2.5-0.5B-Instruct",
+ help="The model to use",
+ )
+ args = parser.parse_args()
+
+ dataset_id = "PrimeIntellect/verifiable-math-problems"
+ dataset = load_dataset(dataset_id, split="train")
+
+ dataset = dataset.train_test_split(test_size=0.1, seed=42)
+ train_dataset = dataset["train"]
+ test_dataset = dataset["test"]
+
+ SYSTEM_PROMPT = (
+ "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+ "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+ "process and answer are enclosed within and tags, respectively, i.e., "
+ "reasoning process hereanswer here"
+ )
+
+ ending = "Return your final response as 'Final Answer: \\boxed{}', where is the number or mathematical expression of the solution."
+
+ def make_conversation(example):
+ return {
+ "prompt": [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": example["prompt"][: -len(ending)].strip()},
+ ],
+ }
+
+ train_dataset = train_dataset.map(make_conversation)
+
+ # def format_reward(completions, **kwargs):
+    #     pattern = r"^<think>(.*?)</think>\s*<answer>(.*?)</answer>$"
+ # rewards = []
+ # for completion in completions:
+ # match = re.search(pattern, completion[0]["content"], re.DOTALL)
+ # if match:
+ # # think_content = match.group(1).strip()
+ # # answer_content = match.group(2).strip()
+ # # if len(think_content) > len(answer_content) > 0:
+ # # rewards.append(1.0)
+ # # else:
+ # # rewards.append(0.5)
+ # rewards.append(1.0)
+ # else:
+ # rewards.append(0.0)
+ # return rewards
+
+    def simple_format_reward(completions, **kwargs):
+        """Reward function that grants 0.25 for each of the four expected tags present in the completion."""
+        completion_contents = [completion[0]["content"] for completion in completions]
+        rewards = []
+        for content in completion_contents:
+            reward = 0.0
+            if "<think>" in content:
+                reward += 0.25
+            if "</think>" in content:
+                reward += 0.25
+            if "<answer>" in content:
+                reward += 0.25
+            if "</answer>" in content:
+                reward += 0.25
+ rewards.append(reward)
+ return rewards
+
+ def format_reward(completions, **kwargs):
+ """Reward function that checks if the reasoning process is enclosed within and tags, while the final answer is enclosed within and tags."""
+ pattern = r"^(.*?)\s*(.*?)$"
+ completion_contents = [completion[0]["content"] for completion in completions]
+ matches = [
+ re.match(pattern, content, re.DOTALL) for content in completion_contents
+ ]
+ return [1.0 if match else 0.0 for match in matches]
+
+ def accuracy_reward(completions, **kwargs):
+ """Reward function that checks if the completion is the same as the ground truth."""
+ solutions = kwargs["gold_standard_solution"]
+ completion_contents = [completion[0]["content"] for completion in completions]
+ rewards = []
+ for content, solution in zip(completion_contents, solutions):
+ gold_parsed = parse(solution)
+ if len(gold_parsed) != 0:
+ answer_parsed = parse(content)
+ if verify(gold_parsed, answer_parsed):
+ rewards.append(1.0)
+ else:
+ rewards.append(0.0)
+ else:
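+                # Gold solution could not be parsed; return None so TRL skips this reward for the sample.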
+ rewards.append(None)
+ return rewards
+
+ def len_reward(completions, **kwargs):
+ """Compute length-based rewards to discourage overthinking and promote token efficiency.
+
+ Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
+
+ Args:
+ completions: List of model completions
+ solution: List of ground truth solutions
+
+ Returns:
+ List of rewards where:
+ - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
+ - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
+ """
+ solution = kwargs["gold_standard_solution"]
+ contents = [completion[0]["content"] for completion in completions]
+
+ # First check correctness of answers
+ correctness = []
+ for content, sol in zip(contents, solution):
+ gold_parsed = parse(sol)
+ if len(gold_parsed) == 0:
+ # Skip unparseable examples
+ correctness.append(True) # Treat as correct to avoid penalizing
+ print("Failed to parse gold solution: ", sol)
+ continue
+
+ answer_parsed = parse(content)
+ correctness.append(verify(answer_parsed, gold_parsed))
+
+ # Calculate lengths
+ lengths = [len(content) for content in contents]
+ min_len = min(lengths)
+ max_len = max(lengths)
+
+ # If all responses have the same length, return zero rewards
+ if max_len == min_len:
+ return [0.0] * len(completions)
+
+ rewards = []
+ for length, is_correct in zip(lengths, correctness):
+ lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
+
+ if is_correct:
+ reward = lambda_val
+ else:
+ reward = min(0, lambda_val)
+
+ rewards.append(float(reward))
+
+ return rewards
+
+ parent_dir = os.path.dirname(__file__)
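+    # Each rank builds a timestamped output dir; broadcast rank 0's value so every rank writes to the same path.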
+ date_time_dir = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+ output_dir = os.path.join(parent_dir, date_time_dir, args.model + "-GRPO")
+ output_dir_list = [output_dir]
+ accelerate.utils.broadcast_object_list(output_dir_list)
+ output_dir = output_dir_list[0]
+ assert output_dir is not None
+
+ training_args = GRPOConfig(
+ output_dir=output_dir,
+ learning_rate=1e-5,
+ remove_unused_columns=False,
+ num_train_epochs=1,
+ per_device_train_batch_size=96, # 96 for 14B, 32 for 72B
+ bf16=True,
+ gradient_checkpointing=True,
+ logging_steps=10,
+ report_to="wandb",
+ use_vllm=True,
+ vllm_server_host=args.vllm_server_host,
+ vllm_server_timeout=600,
+ save_strategy="steps",
+ save_steps=100,
+ torch_empty_cache_steps=10,
+ # Parameters related to evaluation
+ # eval_strategy="steps",
+ # eval_steps=1000,
+ # eval_on_start=True,
+ )
+
+ trainer = GRPOTrainer(
+ model=args.model,
+ reward_funcs=[simple_format_reward, format_reward, accuracy_reward, len_reward],
+ args=training_args,
+ train_dataset=train_dataset,
+ eval_dataset=test_dataset,
+ )
+
+ trainer.train()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/3.test_cases/pytorch/grpo/train.sbatch b/3.test_cases/pytorch/grpo/train.sbatch
new file mode 100644
index 000000000..596044bf1
--- /dev/null
+++ b/3.test_cases/pytorch/grpo/train.sbatch
@@ -0,0 +1,78 @@
+#!/bin/bash
+#SBATCH --job-name=grpo
+#SBATCH --nodes=9
+#SBATCH --ntasks-per-node 1
+
+## Set libfabric flags to use EFA
+export FI_PROVIDER=efa
+export FI_EFA_USE_DEVICE_RDMA=1 # use for p4d
+export FI_EFA_FORK_SAFE=1
+
+## Set this flag for debugging EFA
+#export FI_LOG_LEVEL=warn
+
+## NCCL Environment variables
+# export NCCL_DEBUG=INFO
+
+### Increase the send queue depth, which can make NCCL communications non-blocking.
+### https://www.usenix.org/system/files/atc23-choi.pdf
+export NCCL_BUFFSIZE=8388608
+### Improve performance by increasing buffer size for Send/Recv, Gather, Scatter and Alltoall communications
+### https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html
+export NCCL_P2P_NET_CHUNKSIZE=524288
+
+### Improve performance for AllReduce by selecting specific protocol and algorithm for specific
+### message size and number of ranks.
+### More information https://github.com/aws/aws-ofi-nccl/wiki/Algorithm-and-Protocol-Tuner-for-AWS.
+export NCCL_TUNER_PLUGIN=/opt/aws-ofi-nccl/install/lib/libnccl-ofi-tuner.so
+
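+# Split the allocation: the first N-1 nodes run training; the last node runs the vLLM generation server.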
+NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
+TRAIN_NODES_NUM=$((SLURM_NNODES - 1))
+TRAIN_NODES="${NODELIST[@]:0:$TRAIN_NODES_NUM}"
+VLLM_NODE="${NODELIST[$TRAIN_NODES_NUM]}"
+head_node_ip=${NODELIST[0]}
+GPUS_PER_NODE=8
+
+LAUNCHER="accelerate launch \
+ --config_file /grpo/deepspeed_zero3.yaml \
+ --num_processes $((TRAIN_NODES_NUM * GPUS_PER_NODE)) \
+ --num_machines ${TRAIN_NODES_NUM} \
+ --rdzv_backend c10d \
+ --main_process_ip $head_node_ip \
+ --main_process_port 29500 \
+ --machine_rank \$SLURM_NODEID "
+
+MODEL="${1:-'Qwen/Qwen2.5-0.5B-Instruct'}"
+
+CMD="/grpo/train.py --model $MODEL --vllm_server_host $VLLM_NODE"
+
+# Fetch model config and get number of heads for tensor parallel size
+CONFIG_URL="https://huggingface.co/$MODEL/raw/main/config.json"
+CONFIG_JSON=$(curl -s $CONFIG_URL)
+NUM_HEADS=$(echo "$CONFIG_JSON" | python3 -c "import sys, json; config = json.load(sys.stdin); print(config.get('num_key_value_heads', config.get('num_attention_heads', 1)))")
+VOCAB_SIZE=$(echo "$CONFIG_JSON" | python3 -c "import sys, json; config = json.load(sys.stdin); print(config.get('vocab_size', 1))")
+
+# Find largest tensor parallel size that divides NUM_HEADS and is <= GPUS_PER_NODE
+TENSOR_PARALLEL=1
+for ((i=GPUS_PER_NODE; i>=1; i--)); do
+ if [ $((NUM_HEADS % i)) -eq 0 ] && [ $((VOCAB_SIZE % i)) -eq 0 ]; then
+ TENSOR_PARALLEL=$i
+ break
+ fi
+done
+
+echo "Using tensor parallel size: $TENSOR_PARALLEL"
+
+srun -l --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh \
+ --output=grpo_%j.out --error=grpo_%j.err \
+ --container-mounts=.:/grpo,$HF_HOME:$HF_HOME \
+ --nodes=$TRAIN_NODES_NUM --ntasks=$TRAIN_NODES_NUM --nodelist="${TRAIN_NODES}" \
+ bash -c "$LAUNCHER $CMD" &
+
+srun -l --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh \
+ --output=vllm_%j.out --error=vllm_%j.out \
+ --container-mounts=.:/grpo,$HF_HOME:$HF_HOME \
+ --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \
+ trl vllm-serve --model $MODEL --tensor_parallel_size $TENSOR_PARALLEL &
+
+wait
\ No newline at end of file