TRL documentation
Aligning Text-to-Image Diffusion Models with Reward Backpropagation
The why
If your reward function is differentiable, directly backpropagating gradients from the reward model to the diffusion model is significantly more sample- and compute-efficient (25x) than using a policy gradient algorithm such as DDPO. AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
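A minimal, illustrative PyTorch sketch of this idea is given below (this is not the TRL implementation): unet, scheduler, prompt_embeds, latents, reward_fn, and optimizer are assumed placeholders, and the real trainer additionally decodes latents to images and supports randomized truncation.

import torch

def reward_backprop_step(unet, scheduler, prompt_embeds, latents, reward_fn, optimizer, K=1):
    # Run the denoising loop, keeping the autograd graph only through the
    # last K steps (truncated backpropagation through time).
    num_steps = len(scheduler.timesteps)
    for i, t in enumerate(scheduler.timesteps):
        keep_grad = i >= num_steps - K
        with torch.set_grad_enabled(keep_grad):
            noise_pred = unet(latents, t, encoder_hidden_states=prompt_embeds).sample
            latents = scheduler.step(noise_pred, t, latents).prev_sample
        if not keep_grad:
            latents = latents.detach()
    # Because the reward is differentiable, its gradient flows directly into
    # the UNet weights; no policy-gradient estimate is needed.
    loss = -reward_fn(latents).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()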

Getting started with examples/scripts/alignprop.py
The alignprop.py script is a working example of using the AlignProp trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (AlignPropConfig).
Note: one A100 GPU is recommended to get this running. For lower-memory settings, consider setting truncated_backprop_rand to False. With default settings, this does truncated backpropagation with K=1.
Almost every configuration parameter has a default. Only one command-line flag is required to get things up and running: a Hugging Face user access token, which is used to upload the finetuned model to the Hugging Face Hub. Run the following bash command to get things running:
python alignprop.py --hf_user_access_token <token>
To obtain the documentation of alignprop.py, please run python alignprop.py --help
The following are things to keep in mind (the code checks this for you as well) in general while configuring the trainer, beyond the use case of the example script; see the example configuration after this list:
- The configurable randomized truncation range (--alignprop_config.truncated_rand_backprop_minmax=(0,50)): the first number should be greater than or equal to 0, while the second number should be less than or equal to the number of diffusion timesteps (sample_num_steps).
- The configurable truncation backprop absolute step (--alignprop_config.truncated_backprop_timestep=49): this number should be less than the number of diffusion timesteps (sample_num_steps); it only matters when truncated_backprop_rand is set to False.
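For instance, a configuration that satisfies both checks might look like the following (illustrative values only):

from trl import AlignPropConfig

# With sample_num_steps=50, the random truncation range must lie within
# [0, 50], and the absolute truncation timestep must be below 50.
config = AlignPropConfig(
    sample_num_steps=50,
    truncated_backprop_rand=True,
    truncated_rand_backprop_minmax=(0, 50),
    truncated_backprop_timestep=49,  # only used when truncated_backprop_rand=False
)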
Setting up the image logging hook function
Expect the function to be given a dictionary with keys ['image', 'prompt', 'prompt_metadata', 'rewards'], where image, prompt, prompt_metadata, and rewards are batched. You are free to log however you want; the use of wandb or tensorboard is recommended.
Key terms
- rewards: a numerical score associated with the generated image; it is key to steering the RL process.
- prompt: the text that is used to generate the image.
- prompt_metadata: the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises a FLAVA setup, where questions and ground-truth answers (linked to the generated image) are expected alongside the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45).
- image: the image generated by the Stable Diffusion model.
Example code for logging sampled images with wandb is given below.
# for logging these images to wandb
import numpy as np
from PIL import Image

def image_outputs_hook(image_data, global_step, accelerate_logger):
    # For the sake of this example, we only log the batch of images
    # together with their prompts and rewards
    result = {}
    images, prompts, rewards = image_data["images"], image_data["prompts"], image_data["rewards"]
    for i, image in enumerate(images):
        # Convert the CHW tensor to a resized PIL image
        pil = Image.fromarray(
            (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        )
        pil = pil.resize((256, 256))
        result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
    accelerate_logger.log_images(
        result,
        step=global_step,
    )
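The hook is then passed to the trainer through the image_samples_hook argument. A sketch of wiring everything together follows; the reward and prompt functions below are toy placeholders (a real setup would use an aesthetic or CLIP-based reward), and the full trainer signature is documented in the AlignPropTrainer reference below.

import torch
from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline

def brightness_reward(images: torch.Tensor, prompts, prompt_metadata) -> torch.Tensor:
    # Toy, fully differentiable reward: mean pixel brightness per image.
    # Replace with a real (differentiable) reward model in practice.
    return images.mean(dim=(1, 2, 3))

def prompt_fn():
    # Return a (prompt, prompt_metadata) pair.
    return "a photo of a squirrel", {}

pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5")
config = AlignPropConfig(num_epochs=200, log_with="wandb")

trainer = AlignPropTrainer(
    config,
    brightness_reward,
    prompt_fn,
    pipeline,
    image_samples_hook=image_outputs_hook,
)
trainer.train()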
Using the finetuned model
Assuming you have finished all the epochs and pushed your model up to the Hub, you can use the finetuned model as follows:
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.to("cuda")
pipeline.load_lora_weights("mihirpd/alignprop-trl-aesthetics")

prompts = ["squirrel", "crab", "starfish", "whale", "sponge", "plankton"]
results = pipeline(prompts)

# Save one image per prompt (the "dump" directory must already exist)
for prompt, image in zip(prompts, results.images):
    image.save(f"dump/{prompt}.png")
Credits
This work is heavily influenced by the repo here (https://github.com/mihirp1998/AlignProp/) and the associated paper Aligning Text-to-Image Diffusion Models with Reward Backpropagation by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki.
AlignPropTrainer
class trl.AlignPropTrainer
< source >( config: AlignPropConfig reward_function: typing.Callable[[torch.Tensor, tuple[str], tuple[typing.Any]], torch.Tensor] prompt_function: typing.Callable[[], tuple[str, typing.Any]] sd_pipeline: DDPOStableDiffusionPipeline image_samples_hook: typing.Optional[typing.Callable[[typing.Any, typing.Any, typing.Any], typing.Any]] = None )
Parameters
- config (AlignPropConfig) — Configuration object for the AlignPropTrainer. Check the documentation of PPOConfig for more details.
- reward_function (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) — Reward function to be used.
- prompt_function (Callable[[], tuple[str, Any]]) — Function to generate prompts to guide the model.
- sd_pipeline (DDPOStableDiffusionPipeline) — Stable Diffusion pipeline to be used for training.
- image_samples_hook (Optional[Callable[[Any, Any, Any], Any]]) — Hook to be called to log images.
The AlignPropTrainer uses AlignProp (reward backpropagation) to optimise diffusion models. Note that this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/. As of now, only Stable Diffusion based pipelines are supported.
train: Train the model for a given number of epochs.
AlignPropConfig
class trl.AlignPropConfig
< source >( exp_name: str = 'doc-buil' run_name: str = '' seed: int = 0 log_with: typing.Optional[str] = None log_image_freq: int = 1 tracker_kwargs: dict = <factory> accelerator_kwargs: dict = <factory> project_kwargs: dict = <factory> tracker_project_name: str = 'trl' logdir: str = 'logs' num_epochs: int = 100 save_freq: int = 1 num_checkpoint_limit: int = 5 mixed_precision: str = 'fp16' allow_tf32: bool = True resume_from: str = '' sample_num_steps: int = 50 sample_eta: float = 1.0 sample_guidance_scale: float = 5.0 train_batch_size: int = 1 train_use_8bit_adam: bool = False train_learning_rate: float = 0.001 train_adam_beta1: float = 0.9 train_adam_beta2: float = 0.999 train_adam_weight_decay: float = 0.0001 train_adam_epsilon: float = 1e-08 train_gradient_accumulation_steps: int = 1 train_max_grad_norm: float = 1.0 negative_prompts: typing.Optional[str] = None truncated_backprop_rand: bool = True truncated_backprop_timestep: int = 49 truncated_rand_backprop_minmax: tuple = (0, 50) push_to_hub: bool = False )
Parameters
- exp_name (str, optional, defaults to os.path.basename(sys.argv[0])[: -len(".py")]) — Name of this experiment (defaults to the file name without the extension).
- run_name (str, optional, defaults to "") — Name of this run.
- seed (int, optional, defaults to 0) — Random seed for reproducibility.
- log_with (str or None, optional, defaults to None) — Log with either "wandb" or "tensorboard". Check tracking for more details.
- log_image_freq (int, optional, defaults to 1) — Frequency for logging images.
- tracker_kwargs (dict[str, Any], optional, defaults to {}) — Keyword arguments for the tracker (e.g., wandb_project).
- accelerator_kwargs (dict[str, Any], optional, defaults to {}) — Keyword arguments for the accelerator.
- project_kwargs (dict[str, Any], optional, defaults to {}) — Keyword arguments for the accelerator project config (e.g., logging_dir).
- tracker_project_name (str, optional, defaults to "trl") — Name of project to use for tracking.
- logdir (str, optional, defaults to "logs") — Top-level logging directory for checkpoint saving.
- num_epochs (int, optional, defaults to 100) — Number of epochs to train.
- save_freq (int, optional, defaults to 1) — Number of epochs between saving model checkpoints.
- num_checkpoint_limit (int, optional, defaults to 5) — Number of checkpoints to keep before overwriting old ones.
- mixed_precision (str, optional, defaults to "fp16") — Mixed precision training.
- allow_tf32 (bool, optional, defaults to True) — Allow tf32 on Ampere GPUs.
- resume_from (str, optional, defaults to "") — Path to resume training from a checkpoint.
- sample_num_steps (int, optional, defaults to 50) — Number of sampler inference steps.
- sample_eta (float, optional, defaults to 1.0) — Eta parameter for the DDIM sampler.
- sample_guidance_scale (float, optional, defaults to 5.0) — Classifier-free guidance weight.
- train_batch_size (int, optional, defaults to 1) — Batch size for training.
- train_use_8bit_adam (bool, optional, defaults to False) — Whether to use the 8bit Adam optimizer from bitsandbytes.
- train_learning_rate (float, optional, defaults to 1e-3) — Learning rate.
- train_adam_beta1 (float, optional, defaults to 0.9) — Beta1 for the Adam optimizer.
- train_adam_beta2 (float, optional, defaults to 0.999) — Beta2 for the Adam optimizer.
- train_adam_weight_decay (float, optional, defaults to 1e-4) — Weight decay for the Adam optimizer.
- train_adam_epsilon (float, optional, defaults to 1e-8) — Epsilon value for the Adam optimizer.
- train_gradient_accumulation_steps (int, optional, defaults to 1) — Number of gradient accumulation steps.
- train_max_grad_norm (float, optional, defaults to 1.0) — Maximum gradient norm for gradient clipping.
- negative_prompts (str or None, optional, defaults to None) — Comma-separated list of prompts to use as negative examples.
- truncated_backprop_rand (bool, optional, defaults to True) — If True, randomized truncation to different diffusion timesteps is used.
- truncated_backprop_timestep (int, optional, defaults to 49) — Absolute timestep to which the gradients are backpropagated. Used only if truncated_backprop_rand=False.
- truncated_rand_backprop_minmax (tuple[int, int], optional, defaults to (0, 50)) — Range of diffusion timesteps for randomized truncated backpropagation.
- push_to_hub (bool, optional, defaults to False) — Whether to push the final model to the Hub.
Configuration class for the AlignPropTrainer.
Using HfArgumentParser, we can turn this class into argparse arguments that can be specified on the command line.
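A minimal sketch of that usage (assuming a script whose only command-line arguments are the config fields):

from transformers import HfArgumentParser
from trl import AlignPropConfig

# Turn every AlignPropConfig field into a command-line flag, e.g.
#   python my_script.py --num_epochs 200 --train_learning_rate 1e-4
parser = HfArgumentParser(AlignPropConfig)
(config,) = parser.parse_args_into_dataclasses()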