Skip to content

GRPO on Sagemaker using TRL Tutorial#2596

Draft
dwarez wants to merge 4 commits into
mainfrom
grpo-llm-trl-tutorial
Draft

GRPO on Sagemaker using TRL Tutorial#2596
dwarez wants to merge 4 commits into
mainfrom
grpo-llm-trl-tutorial

Conversation

@dwarez

@dwarez dwarez commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

The purpose of this notebook is to guide users in using Sagemaker to run a GRPO training with verifiable rewards using TRL.

Note:

  • I selected SmolLM3 as the model to fine-tune because more recent, strong model of approximately the same size tend to ace the dataset used (Salesforce/xlam-function-calling-60k) therefore the advantages computed during GRPO would be mostly null
  • The image used was purposely build for this and for TRL v1.7 (which should be released today/next few days) therefore the image uri it's temporary, we need to replace it with the actual image uri that we will publish
  • I had many issues with running this in Sagemaker jobs (while on EC2 worked perfectly fine), this is the reason behind the launcher.sh workaround you can find in the notebook. I guess it's somewhat connected to the fact that I didn't manage to make it work using vLLM as generation backend (one of the issues GRPO Trainer crashes with Could not infer dtype of NoneType when vLLM returns a NaN token logprob trl#6166)

cc @alvarobartt


Note

Low Risk
Documentation-only addition (notebook); no changes to library or runtime code paths.

Overview
Adds a new SageMaker SDK v3 tutorial notebook that walks through verifiable-reward GRPO on HuggingFaceTB/SmolLM3-3B for single-turn tool calling on Salesforce/xlam-function-calling-60k.

The notebook materializes the full job bundle via %%writefile: rewards.py (exact-match + format rewards on <tool_call> JSON), train.py (GRPOTrainer, DAPO loss, collapse early-stop, JSONL metrics), launch.sh (torchrun per GPU as the SageMaker entry), and DeepSpeed ZeRO-2 config. It covers dataset prep (chat prompts + hidden answers), S3 upload, ModelTrainer launch on ml.g6.12xlarge with a temporary TRL GRPO ECR image, and post-job reward-curve plotting from S3/CloudWatch/notebook output.

The documented run is intentionally short (smoke / wiring validation); DEBUG_NO_UPDATE supports a no-gradient smoke test.

Reviewed by Cursor Bugbot for commit cc6dcec. Bugbot is set up for automated code reviews on this repo. Configure here.

Signed-off-by: DWarez <dario.salvati@huggingface.co>
@dwarez dwarez marked this pull request as draft June 25, 2026 16:27

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Want higher recall? High effort reviews run extra passes and find more bugs. A team admin can switch effort levels in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit cc6dcec. Configure here.

"elif \"trainer\" in globals() and hasattr(trainer, \"latest_training_job\"):\n",
" job = trainer.latest_training_job.name\n",
"elif \"base_job_name\" in globals():\n",
" job = base_job_name\n",

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong SageMaker job name

Medium Severity

When resolving the training job, the notebook sets job to base_job_name if trainer.latest_training_job is missing. SageMaker assigns a unique suffix to the actual TrainingJobName, so describe_training_job with only base_job_name fails even right after launch.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit cc6dcec. Configure here.

" if e.response.get(\"Error\", {}).get(\"Code\") != \"AccessDeniedException\":\n",
" raise\n",
" print(\"CloudWatch Logs access denied; parsing saved notebook output instead.\")\n",
" records = notebook_metric_records()\n",

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong notebook fallback path

Low Severity

The CloudWatch fallback calls notebook_metric_records() with default path notebook.ipynb, but this tutorial file is sagemaker-notebook.ipynb. When S3 and CloudWatch both fail, the cell raises FileNotFoundError instead of parsing saved outputs.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit cc6dcec. Configure here.

"MODEL_ID = \"HuggingFaceTB/SmolLM3-3B\" # SmolLM3, HF's 3B instruct model\n",
"INSTANCE_TYPE = \"ml.g6.12xlarge\" # 4xL4 24GB\n",
"\n",
"TRAINING_IMAGE = \"754289655784.dkr.ecr.us-east-1.amazonaws.com/hf-trl-grpo:sagemaker-trl-dev-e63f67e\""

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this must be changed before merging

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

dwarez added 2 commits June 26, 2026 10:44
change: using zero3 instead of zero2

Signed-off-by: DWarez <dario.salvati@huggingface.co>
The workaround I did before is totally unnecessary for ml.p4d.24xlarge
therefore I think it was due to some hardware related issue (like high
memory pressure or comms bandwidth). Since we're keeping ml.p4d.24xlarge
as instance type in order to also run the reference model, we can safely
revert back to not use the workaround

Signed-off-by: DWarez <dario.salvati@huggingface.co>

@alvarobartt alvarobartt left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's crazy that in big 2026 we cannot review / comment on Jupyter Notebooks in GitHub without third-parties...

So here's the review with minor nits:

import os
from huggingface_hub import login, get_token

if not os.environ.get("HF_TOKEN"):
    login()
HF_TOKEN = get_token()
assert HF_TOKEN, "No HF token found — set HF_TOKEN or run login()"
print("HF token loaded")
  • # The training image lives in us-east-1, so keep the job, the S3 bucket and the image in one region. AFAIK AWS SageMaker AI is the one creating the bucket so no need for this comment IMO

  • TRAINING_IMAGE = "754289655784.dkr.ecr.us-east-1.amazonaws.com/hf-trl-grpo:sagemaker-trl-dev-e63f67e" is this container public at the end? Or temporarily until the official one is released?

  • For the following cell, either remove the print or keep the output in the Jupyter Notebook cell IMO

import json
from datasets import load_dataset

raw = load_dataset("Salesforce/xlam-function-calling-60k", split="train")
print(raw)
print(json.dumps({k: raw[0][k] for k in ("query", "tools", "answers")}, indent=2)[:1200])
  • With the following Generation uses the Transformers backend. you mean instead of TRL's default vLLM backend? If so, I'd add the clarification.

  • ZeRO-3 shards the model, ... with DeepSpeed ZeRO-3 shards the model, ... to make it explicit.

  • There are some occurrences of eight A100 which I'd replace with 8 x A100 40GB for consistency i.e., as A100 comes in both 40GB and 80GB flavors I'd write it that way + using the "N x A100" notation as I feel like it's more readable than "eight A100".

  • Not sure if not displayed because Jupyter Notebook rendering is broken on GitHub, but I don't see the plots, so it'd be nice to include those too for the reader to understand how you interpret those metrics.

  • Maybe add a ## References section at the end?

Great work @dwarez, I'd maybe review the nits / comments above and add a simple explanation (or even diagram) on what GRPO is and how it works, with maybe mentions to DeepSeek R1? Nonetheless the example is great, and I really think that extending our collection on training will be great 🔥

add: references

Signed-off-by: DWarez <dario.salvati@huggingface.co>
@dwarez

dwarez commented Jun 27, 2026

Copy link
Copy Markdown
Contributor Author

hi @alvarobartt, thanks a lot for the feedback 🤗

I applied your suggestions, here some other points:

  • the image is still in our private ECR, I'm waiting an answer from the infra team regarding the possibility to get a public ECR. In this way we can release this (and other) tutorials without the need to wait for DLCs releases (which should become faster when we will get access to the v2 of the build system)
  • in GRPO trainer, use_vllm defaults to false, that's why I phrased that way. I made some tests with vllm but for some reason the compilation of SmolLM3 broke and I didn't investigate much on that
  • the curve plots are visible on my and when I commit, but yeah I also don't see them when looking at the docs preview, I'm not sure why
  • I added a markdown cell for a short explanation about GRPO, let me know what you think. If that's too brief I can extend that

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants