1 2 3 4 5 6 7 8 9 10 11 12--- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -501,7 +501,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): } metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop("entropys") + # Keep entropys in batch so advantage_fn (e.g. Clip_B) can use it + # old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) if "rollout_log_probs" in batch.batch.keys(): # TODO: we may want to add diff of probs too.