{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NQUk3Y0WwYZ4"
   },
   "source": [
    "# 🤗 x 🦾: Training SmolVLA with LeRobot Notebook\n",
    "\n",
    "Welcome to the **LeRobot SmolVLA training notebook**! This notebook provides a ready-to-run setup for training imitation learning policies using the [🤗 LeRobot](https://github.com/huggingface/lerobot) library.\n",
    "\n",
    "In this example, we fine-tune a `SmolVLA` policy on a dataset hosted on the [Hugging Face Hub](https://huggingface.co/), and optionally track training metrics with [Weights & Biases (wandb)](https://wandb.ai/).\n",
    "\n",
    "## ⚙️ Requirements\n",
    "- The repo ID of a Hugging Face Hub dataset containing your training data (`--dataset.repo_id=YOUR_HF_USERNAME/YOUR_DATASET`)\n",
    "- Optional: a [wandb](https://wandb.ai/) account if you want to visualize training progress\n",
    "- Recommended: a GPU runtime (e.g., NVIDIA A100) for faster training\n",
    "\n",
    "## ⏱️ Expected Training Time\n",
    "Training the `SmolVLA` policy for 20,000 steps typically takes **about 5 hours on an NVIDIA A100** GPU. On less powerful GPUs or on CPU, training can take significantly longer.\n",
    "\n",
    "## Example Output\n",
    "Model checkpoints and training logs are saved to the directory given by `--output_dir`. If `wandb` is enabled, training progress is also visualized in your wandb project dashboard.\n"
   ]
  },
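  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The time estimate above can be turned into rough throughput numbers. This optional cell is a back-of-the-envelope sketch that assumes the quoted figures (about 5 hours for 20,000 steps on an A100) and the `--batch_size=64` used in the training cell; real throughput depends on your GPU, dataset, and LeRobot version. `estimate_hours` and its `slowdown` factor are hypothetical helpers for illustration, not part of LeRobot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference figures quoted above: ~5 h for 20,000 steps on an NVIDIA A100.\n",
    "reference_steps = 20_000\n",
    "reference_hours = 5.0\n",
    "batch_size = 64  # value passed via --batch_size in the training cell\n",
    "\n",
    "seconds_per_step = reference_hours * 3600 / reference_steps\n",
    "print(f'~{seconds_per_step:.2f} s per step, ~{reference_steps / reference_hours:,.0f} steps per hour')\n",
    "print(f'{reference_steps * batch_size:,} training samples processed in total')\n",
    "\n",
    "\n",
    "def estimate_hours(num_steps, slowdown=1.0):\n",
    "    '''Hypothetical helper: scale the A100 reference time to another run.\n",
    "\n",
    "    slowdown is an assumed factor for slower hardware (e.g. 3.0 means 3x slower).\n",
    "    '''\n",
    "    return num_steps * seconds_per_step * slowdown / 3600\n",
    "\n",
    "\n",
    "print(f'10,000 steps on an A100: ~{estimate_hours(10_000):.1f} h')\n",
    "print(f'20,000 steps on a GPU assumed 3x slower: ~{estimate_hours(20_000, slowdown=3.0):.1f} h')"
   ]
  },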
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MOJyX0CnwA5m"
   },
   "source": [
    "## Install conda\n",
    "This cell uses `condacolab` to bootstrap a full Conda environment inside Google Colab. The runtime restarts automatically when the installation finishes; this is expected, so wait for it to reconnect before running the next cell.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QlKjL1X5t_zM"
   },
   "outputs": [],
   "source": [
    "!pip install -q condacolab\n",
    "import condacolab\n",
    "condacolab.install()"
   ]
  },
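  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the runtime has restarted, you can optionally confirm the Conda setup before continuing. This is a small sketch: it calls `condacolab.check()`, the verification helper shipped with the `condacolab` package, and then looks up the `conda` executable on the `PATH`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "import condacolab\n",
    "\n",
    "# Verify the Conda installation performed by condacolab.install().\n",
    "condacolab.check()\n",
    "\n",
    "# After a successful install, the conda executable should be on PATH.\n",
    "print('conda found at:', shutil.which('conda'))"
   ]
  },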
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DxCc3CARwUjN"
   },
   "source": [
    "## Install LeRobot\n",
    "This cell clones the `lerobot` repository from GitHub, installs FFmpeg 7.1.1 from conda-forge, and installs the package in editable mode.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dgLu7QT5tUik"
   },
   "outputs": [],
   "source": [
    "!git clone https://github.com/huggingface/lerobot.git\n",
    "!conda install -y ffmpeg=7.1.1 -c conda-forge\n",
    "!cd lerobot && pip install -e ."
   ]
  },
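  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, check which FFmpeg the runtime will use. This is a minimal sketch that assumes the conda-installed FFmpeg comes first on the `PATH`; Colab may also provide a system FFmpeg, so the resolved binary can differ. If the conda build is picked up, the version banner should mention 7.1.1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "import subprocess\n",
    "\n",
    "ffmpeg_path = shutil.which('ffmpeg')\n",
    "print('ffmpeg binary:', ffmpeg_path)\n",
    "\n",
    "if ffmpeg_path is None:\n",
    "    print('ffmpeg not found on PATH; re-run the install cell above.')\n",
    "else:\n",
    "    # `ffmpeg -version` prints its version banner on stdout.\n",
    "    result = subprocess.run([ffmpeg_path, '-version'], capture_output=True, text=True, check=True)\n",
    "    print(result.stdout.splitlines()[0])"
   ]
  },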
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q8Sn2wG4wldo"
   },
   "source": [
    "## Weights & Biases login\n",
    "This cell logs you into Weights & Biases (wandb) to enable experiment tracking and logging; you will be prompted for your API key. If you don't plan to use wandb, skip this cell and set `--wandb.enable=false` in the training command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PolVM_movEvp"
   },
   "outputs": [],
   "source": [
    "!wandb login"
   ]
  },
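  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you prefer not to paste your API key interactively, a possible Colab-only alternative is to log in from Python with `wandb.login`. This sketch assumes you stored the key in Colab's Secrets panel under the name `WANDB_API_KEY` (a name chosen here for illustration) and granted this notebook access to it. Run either this cell or the previous one, not both."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "from google.colab import userdata\n",
    "\n",
    "# Assumption: the API key is stored as a Colab secret named WANDB_API_KEY.\n",
    "api_key = userdata.get('WANDB_API_KEY')\n",
    "\n",
    "# wandb.login returns True once a key has been configured.\n",
    "logged_in = wandb.login(key=api_key)\n",
    "print('wandb login succeeded:', logged_in)"
   ]
  },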
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zTWQAgX9xseE"
   },
   "source": [
    "## Install SmolVLA dependencies\n",
    "This cell installs the optional `smolvla` extra of the `lerobot` package, which provides the additional dependencies needed to train the SmolVLA policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DiHs0BKwxseE"
   },
   "outputs": [],
   "source": [
    "!cd lerobot && pip install -e \".[smolvla]\""
   ]
  },
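  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To confirm the installation before launching a long training run, you can print the installed package versions. This sketch uses the standard-library `importlib.metadata`; checking `torch` and `transformers` assumes they are pulled in as dependencies of `lerobot` and its `smolvla` extra."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from importlib.metadata import PackageNotFoundError, version\n",
    "\n",
    "# lerobot comes from the editable install; torch and transformers are\n",
    "# assumed to be installed as dependencies (transformers via the smolvla extra).\n",
    "for package in ('lerobot', 'torch', 'transformers'):\n",
    "    try:\n",
    "        print(f'{package}: {version(package)}')\n",
    "    except PackageNotFoundError:\n",
    "        print(f'{package}: not installed')"
   ]
  },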
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IkzTo4mNwxaC"
   },
   "source": [
    "## Start training SmolVLA with LeRobot\n",
    "\n",
    "This cell runs the `train.py` script from the `lerobot` repository to fine-tune the pretrained `lerobot/smolvla_base` checkpoint on your dataset.\n",
    "\n",
    "The command uses the following arguments; adjust them to your setup where needed:\n",
    "\n",
    "1. `--policy.path=lerobot/smolvla_base`:\n",
    "   The pretrained SmolVLA checkpoint on the Hugging Face Hub that training starts from.\n",
    "\n",
    "2. `--dataset.repo_id=YOUR_HF_USERNAME/YOUR_DATASET`:\n",
    "   Replace this with the Hugging Face Hub repo ID where your dataset is stored, e.g., `pepijn223/il_gym0`.\n",
    "\n",
    "3. `--batch_size=64`:\n",
    "   The number of training samples processed together in each step; the model makes one gradient update per batch. Reduce this number if your GPU runs out of memory.\n",
    "\n",
    "4. `--steps=20000`:\n",
    "   The total number of training steps (gradient updates).\n",
    "\n",
    "5. `--output_dir=outputs/train/...`:\n",
    "   Directory where training logs and model checkpoints are saved.\n",
    "\n",
    "6. `--job_name=...`:\n",
    "   A name for this training job, used for logging and in Weights & Biases.\n",
    "\n",
    "7. `--policy.device=cuda`:\n",
    "   Use `cuda` if training on an NVIDIA GPU, `mps` on Apple Silicon, or `cpu` if no GPU is available.\n",
    "\n",
    "8. `--wandb.enable=true`:\n",
    "   Enables Weights & Biases for visualizing training progress. You must be logged in via `wandb login` before running this; set it to `false` if you skipped the wandb login."
   ]
  },
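  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To choose the value for `--policy.device`, this small sketch uses standard PyTorch calls to report which backend is available in the current runtime. It only prints a suggestion; it does not modify the training command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Pick the first available backend, in the order listed above: cuda, mps, cpu.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "    print('CUDA GPU detected:', torch.cuda.get_device_name(0))\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = 'mps'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "print(f'Suggested flag: --policy.device={device}')"
   ]
  },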
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZO52lcQtxseE"
   },
   "outputs": [],
   "source": [
    "!cd lerobot && python lerobot/scripts/train.py \\\n",
    "  --policy.path=lerobot/smolvla_base \\\n",
    "  --dataset.repo_id=YOUR_HF_USERNAME/YOUR_DATASET \\\n",
    "  --batch_size=64 \\\n",
    "  --steps=20000 \\\n",
    "  --output_dir=outputs/train/my_smolvla \\\n",
    "  --job_name=my_smolvla_training \\\n",
    "  --policy.device=cuda \\\n",
    "  --wandb.enable=true"
   ]
  },
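  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Together, `--steps` and `--batch_size` determine how much data the run sees: the number of steps multiplied by the batch size. Dividing that by the number of frames in your dataset gives an approximate number of passes (epochs) over it. In this sketch the frame count is a made-up placeholder; substitute the size of your own dataset (for LeRobot-format datasets, typically the `total_frames` value in `meta/info.json`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "steps = 20_000       # --steps in the training cell\n",
    "batch_size = 64      # --batch_size in the training cell\n",
    "num_frames = 50_000  # HYPOTHETICAL dataset size in frames; replace with yours\n",
    "\n",
    "samples_seen = steps * batch_size\n",
    "epochs = samples_seen / num_frames\n",
    "print(f'{samples_seen:,} samples seen, about {epochs:.1f} passes over the dataset')"
   ]
  },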
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2PBu7izpxseF"
   },
   "source": [
    "## Log in to the Hugging Face Hub\n",
    "Once training has finished, log in to the Hugging Face Hub and upload the last checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8yu5khQGIHi6"
   },
   "outputs": [],
   "source": [
    "!huggingface-cli login"
   ]
  },
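  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an alternative to the CLI upload below, the same checkpoint can be pushed from Python with `HfApi.create_repo` and `HfApi.upload_folder` from `huggingface_hub`. This is a sketch: `YOUR_HF_USERNAME/my_smolvla` is a placeholder repo ID, and the path assumes the `--output_dir` used in the training cell. Use either this cell or the CLI cell, not both."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from huggingface_hub import HfApi\n",
    "\n",
    "# Placeholder repo ID: replace with your Hub username and a model name.\n",
    "repo_id = 'YOUR_HF_USERNAME/my_smolvla'\n",
    "# Matches --output_dir=outputs/train/my_smolvla, relative to the cloned lerobot repo.\n",
    "checkpoint_dir = Path('/content/lerobot/outputs/train/my_smolvla/checkpoints/last/pretrained_model')\n",
    "\n",
    "if not checkpoint_dir.is_dir():\n",
    "    raise FileNotFoundError(f'No checkpoint found at {checkpoint_dir}')\n",
    "\n",
    "api = HfApi()  # uses the token saved by the login cell above\n",
    "api.create_repo(repo_id=repo_id, repo_type='model', exist_ok=True)\n",
    "api.upload_folder(folder_path=str(checkpoint_dir), repo_id=repo_id, repo_type='model')\n",
    "print(f'Uploaded {checkpoint_dir} to https://huggingface.co/{repo_id}')"
   ]
  },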
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zFMLGuVkH7UN"
   },
   "outputs": [],
   "source": [
    "!huggingface-cli upload YOUR_HF_USERNAME/my_smolvla \\\n",
    "  /content/lerobot/outputs/train/my_smolvla/checkpoints/last/pretrained_model"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}