TRL documentation
PAPO Trainer
TRL supports Perception-Aware Policy Optimization (PAPO), as described in the paper Perception-Aware Policy Optimization for Multimodal Reasoning by Zhenhailong Wang, Xuehang Guo, Sofia Stoica, Haiyang Xu, Hongru Wang, Hyeonjeong Ha, Xiusi Chen, Yangyi Chen, Ming Yan, Fei Huang, and Heng Ji.
The abstract from the paper is the following:
Reinforcement Learning with Verifiable Rewards (RLVR) has proven to be a highly effective strategy for endowing Large Language Models (LLMs) with robust multi-step reasoning abilities. However, its design and optimizations remain tailored to purely textual domains, resulting in suboptimal performance when applied to multimodal reasoning tasks. In particular, we observe that a major source of error in current multimodal reasoning lies in the perception of visual inputs. To address this bottleneck, we propose Perception-Aware Policy Optimization (PAPO), a simple yet effective extension of GRPO that encourages the model to learn to perceive while learning to reason, entirely from internal supervision signals. Notably, PAPO does not rely on additional data curation, external reward models, or proprietary models. Specifically, we introduce the Implicit Perception Loss in the form of a KL divergence term to the GRPO objective, which, despite its simplicity, yields significant overall improvements (4.4%) on diverse multimodal benchmarks. The improvements are more pronounced, approaching 8.0%, on tasks with high vision dependency. We also observe a substantial reduction (30.5%) in perception errors, indicating improved perceptual capabilities with PAPO. We conduct comprehensive analysis of PAPO and identify a unique loss hacking issue, which we rigorously analyze and mitigate through a Double Entropy Loss. Overall, our work introduces a deeper integration of perception-aware supervision into RLVR learning objectives and lays the groundwork for a new RL framework that encourages visually grounded reasoning. Project page: https://mikewangwzhl.github.io/PAPO.
PAPOTrainer
class trl.experimental.papo.PAPOTrainer
< source >( model: typing.Union[str, transformers.modeling_utils.PreTrainedModel] reward_funcs: typing.Union[str, transformers.modeling_utils.PreTrainedModel, typing.Callable[[list, list], list[float]], list[typing.Union[str, transformers.modeling_utils.PreTrainedModel, typing.Callable[[list, list], list[float]]]]] args: typing.Optional[trl.experimental.papo.papo_config.PAPOConfig] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, dict[str, typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset]], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.processing_utils.ProcessorMixin, NoneType] = None reward_processing_classes: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, list[transformers.tokenization_utils_base.PreTrainedTokenizerBase], NoneType] = None callbacks = None optimizers = (None, None) peft_config = None )
Parameters
- model (`Union[str, PreTrainedModel]`) — Model to be trained (must be a vision-language model).
- reward_funcs (`Union[RewardFunc, list[RewardFunc]]`) — Reward functions for computing rewards (same as GRPO).
- args (`PAPOConfig`, optional, defaults to `None`) — Configuration for this trainer. If `None`, a default configuration is used.
- train_dataset (`Dataset` or `IterableDataset`) — Dataset to use for training. Must include "prompt" and "image" columns.
- eval_dataset — Same requirements as `train_dataset`.
- processing_class — Processing class (tokenizer/processor) for the model.
- reward_processing_classes — Processing classes for reward models.
- callbacks — Training callbacks.
- optimizers — Tuple of optimizer and learning-rate scheduler.
- peft_config — PEFT configuration if using parameter-efficient fine-tuning.
Trainer for Perception-Aware Policy Optimization (PAPO).
PAPO extends GRPO/DAPO for multimodal reasoning by adding an implicit perception loss that encourages the model to better utilize visual information. The key innovation is computing KL divergence between model outputs on original vs. corrupted (masked) images.
Two variants are supported:
- PAPO-G: PAPO + GRPO (use `loss_type="grpo"`)
- PAPO-D: PAPO + DAPO (use `loss_type="dapo"`)
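The core idea, a KL divergence between the policy's token distributions on the original image and on a corrupted one, can be sketched as follows. This is an illustrative sketch, not the trainer's internal implementation; the function name and exact normalization are assumptions.

```python
import torch
import torch.nn.functional as F

def implicit_perception_loss(logits_original, logits_masked):
    """Per-token KL(p_original || p_masked), averaged over all positions.

    A large divergence means the policy's predictions change when the
    image is corrupted, i.e. the model is actually using visual input.
    Sketch only: the trainer's sign convention and reduction may differ.
    """
    log_p = F.log_softmax(logits_original, dim=-1)
    log_q = F.log_softmax(logits_masked, dim=-1)
    kl = (log_p.exp() * (log_p - log_q)).sum(-1)  # KL per token position
    return kl.mean()

# Toy shapes: batch of 2 sequences, 4 tokens each, vocabulary of 8.
logits_o = torch.randn(2, 4, 8)
logits_m = torch.randn(2, 4, 8)
loss = implicit_perception_loss(logits_o, logits_m)
```

Note that KL(p || p) is zero, so a model that ignores the image entirely gets no perception signal, which is exactly the failure mode PAPO penalizes.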
Example:
from datasets import load_dataset
from trl.experimental.papo import PAPOConfig, PAPOTrainer

dataset = load_dataset("your-vlm-dataset", split="train")

def reward_func(completions, **kwargs):
    # Your reward function for multimodal reasoning
    return [compute_reward(c) for c in completions]

# PAPO-G
config = PAPOConfig(
    loss_type="grpo",  # Use GRPO as base
    perception_loss_weight=0.1,
    mask_ratio=0.3,
)

# PAPO-D
config = PAPOConfig(
    loss_type="dapo",  # Use DAPO as base
    perception_loss_weight=0.1,
    mask_ratio=0.3,
)

trainer = PAPOTrainer(
    model="Qwen/Qwen2-VL-2B-Instruct",
    reward_funcs=reward_func,
    args=config,
    train_dataset=dataset,
)
trainer.train()

train
< source >( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs: typing.Any )
Parameters
- resume_from_checkpoint (`str` or `bool`, optional) — If a `str`, local path to a saved checkpoint as saved by a previous instance of `Trainer`. If a `bool` and equals `True`, load the last checkpoint in `args.output_dir` as saved by a previous instance of `Trainer`. If present, training will resume from the model/optimizer/scheduler states loaded here.
- trial (`optuna.Trial` or `dict[str, Any]`, optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
- ignore_keys_for_eval (`list[str]`, optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during training.
- kwargs (`dict[str, Any]`, optional) — Additional keyword arguments used to hide deprecated arguments.
Main training entry point.
Will save the model, so you can reload it using from_pretrained().
Will only save from the main process.
push_to_hub
< source >( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )
Parameters
- commit_message (`str`, optional, defaults to `"End of training"`) — Message to commit while pushing.
- blocking (`bool`, optional, defaults to `True`) — Whether the function should return only when the `git push` has finished.
- token (`str`, optional, defaults to `None`) — Token with write permission to overwrite Trainer's original args.
- revision (`str`, optional) — The git revision to commit from. Defaults to the head of the "main" branch.
- kwargs (`dict[str, Any]`, optional) — Additional keyword arguments passed along to `Trainer.create_model_card`.
Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.
PAPOConfig
class trl.experimental.papo.PAPOConfig
< source >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: float = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 1e-06 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: bool = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: 
typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = False label_names: typing.Optional[list[str]] = None load_best_model_at_end: bool = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = None fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None parallelism_config: typing.Optional[accelerate.parallelism_config.ParallelismConfig] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch_fused' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: str = 'length' report_to: typing.Union[NoneType, str, list[str]] = None project: str = 'huggingface' trackio_space_id: typing.Optional[str] = 'trackio' ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: 
typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = True gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: bool = False include_num_input_tokens_seen: typing.Union[str, bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: bool = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: bool = False average_tokens_across_devices: bool = True model_init_kwargs: typing.Union[dict, str, NoneType] = None disable_dropout: bool = False cast_lm_head_to_fp32: bool = False max_prompt_length: typing.Optional[int] = 512 num_generations: typing.Optional[int] = 8 max_completion_length: typing.Optional[int] = 256 ds3_gather_for_generation: bool = True shuffle_dataset: typing.Optional[bool] = True generation_batch_size: typing.Optional[int] = None steps_per_generation: typing.Optional[int] = None temperature: float = 1.0 top_p: float = 
1.0 top_k: typing.Optional[int] = None min_p: typing.Optional[float] = None generation_kwargs: typing.Optional[dict] = None chat_template_kwargs: typing.Optional[dict] = None repetition_penalty: float = 1.0 use_transformers_paged: bool = False cache_implementation: typing.Optional[str] = None use_vllm: bool = False vllm_mode: str = 'server' vllm_model_impl: str = 'vllm' vllm_enable_sleep_mode: bool = False vllm_guided_decoding_regex: typing.Optional[str] = None vllm_server_base_url: typing.Optional[str] = None vllm_server_host: str = '0.0.0.0' vllm_server_port: int = 8000 vllm_server_timeout: float = 240.0 vllm_gpu_memory_utilization: float = 0.3 vllm_tensor_parallel_size: int = 1 beta: float = 0.0 num_iterations: int = 1 epsilon: float = 0.2 delta: typing.Optional[float] = None epsilon_high: typing.Optional[float] = None importance_sampling_level: str = 'token' reward_weights: typing.Optional[list[float]] = None scale_rewards: str = 'group' loss_type: str = 'dapo' mask_truncated_completions: bool = False sync_ref_model: bool = False ref_model_mixup_alpha: float = 0.6 ref_model_sync_steps: int = 512 top_entropy_quantile: float = 1.0 use_liger_loss: bool = None vllm_importance_sampling_correction: bool = True vllm_importance_sampling_cap: float = 2.0 log_completions: bool = False num_completions_to_print: typing.Optional[int] = None wandb_log_unique_prompts: typing.Optional[bool] = False perception_loss_weight: float = 0.1 mask_ratio: float = 0.3 mask_type: typing.Literal['random', 'patch', 'grid'] = 'random' der_loss_weight1: float = 0.03 der_loss_weight2: float = 0.03 )
Parameters
- perception_loss_weight (`float`, optional, defaults to `0.1`) — Weight coefficient (gamma) for the implicit perception loss term. This encourages the model to be sensitive to visual changes.
- mask_ratio (`float`, optional, defaults to `0.3`) — Ratio of the image to mask when computing the perception loss.
- mask_type (`Literal["random", "patch", "grid"]`, optional, defaults to `"random"`) — Type of masking strategy to use.
- der_loss_weight1 (`float`, optional, defaults to `0.03`) — Weight coefficient (eta1) for the first Double Entropy Regularization (DER) term, which encourages confident predictions (low entropy) on original images.
- der_loss_weight2 (`float`, optional, defaults to `0.03`) — Weight coefficient (eta2) for the second Double Entropy Regularization (DER) term, which encourages uncertain predictions (high entropy) on masked images.
- loss_type (`Literal["grpo", "dapo"]`, inherited from GRPOConfig) — Base loss type to use. Set to `"grpo"` for PAPO-G or `"dapo"` for PAPO-D.
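To make `mask_ratio` concrete, a random patch-masking corruption of the kind `mask_type="random"` suggests could look like the sketch below. `random_patch_mask` is a hypothetical helper; the trainer's actual corruption may differ in patch size and strategy.

```python
import torch

def random_patch_mask(pixel_values, mask_ratio=0.3, patch_size=14):
    """Zero out a random fraction of non-overlapping image patches.

    pixel_values: tensor of shape (C, H, W), with H and W divisible
    by patch_size. Returns a corrupted copy; the input is untouched.
    """
    _, h, w = pixel_values.shape
    nh, nw = h // patch_size, w // patch_size
    n_mask = int(nh * nw * mask_ratio)
    masked = pixel_values.clone()
    for i in torch.randperm(nh * nw)[:n_mask].tolist():
        r, col = divmod(i, nw)  # patch grid coordinates
        masked[:, r * patch_size:(r + 1) * patch_size,
               col * patch_size:(col + 1) * patch_size] = 0.0
    return masked

# A 28x28 "image" split into four 14x14 patches; mask half of them.
img = torch.ones(3, 28, 28)
corrupted = random_patch_mask(img, mask_ratio=0.5, patch_size=14)
```

The corrupted image is what the policy sees when computing the second set of logits for the perception KL term.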
Configuration class for PAPOTrainer.
PAPO (Perception-Aware Policy Optimization) extends GRPO/DAPO for multimodal reasoning by adding an implicit perception loss and double entropy regularization.
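The double entropy regularization can be sketched as follows, with `der_loss_weight1` (eta1) penalizing entropy on original images and `der_loss_weight2` (eta2) rewarding entropy on masked images. This is a sketch of the idea as described above; the exact formulation inside the trainer may differ.

```python
import torch
import torch.nn.functional as F

def double_entropy_regularizer(logits_original, logits_masked,
                               eta1=0.03, eta2=0.03):
    """Sketch of Double Entropy Regularization (DER): encourage low
    entropy (confidence) on original images and high entropy
    (uncertainty) on masked images. Signs follow that reading of the
    config description, not necessarily the trainer's exact code."""
    def mean_token_entropy(logits):
        log_p = F.log_softmax(logits, dim=-1)
        return -(log_p.exp() * log_p).sum(-1).mean()

    h_orig = mean_token_entropy(logits_original)
    h_mask = mean_token_entropy(logits_masked)
    # Penalize entropy on the original image, reward it on the masked one.
    return eta1 * h_orig - eta2 * h_mask

# With identical logits and equal weights, the two terms cancel exactly.
logits = torch.zeros(2, 4, 8)
reg = double_entropy_regularizer(logits, logits)
```

This term counters the loss-hacking failure mode noted in the paper, where the model could inflate the perception KL by becoming uniformly uncertain on both inputs.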