Analyses of Policy Gradient for Language Model Finetuning and Optimal Control
- Noam Razin (Tel Aviv University)
Abstract
Gradient-based methods are the workhorse behind modern machine learning. While they have been extensively studied in the basic framework of supervised learning, they are far less understood in the framework of optimal control, which in its broadest form is equivalent to reinforcement learning. There, algorithms that learn a policy via gradient updates are known as policy gradient methods. In this talk, I will present two recent works analyzing the optimization dynamics and implicit bias of policy gradient in different contexts. The first work identifies a vanishing gradients problem that occurs when using policy gradient to finetune language models. I will demonstrate the detrimental effects of this phenomenon and present possible solutions. The second work characterizes how the implicit bias of policy gradient affects extrapolation to initial states unseen in training, focusing on the fundamental Linear Quadratic Regulator (LQR) control problem. Overall, our results highlight that the optimization dynamics and implicit bias of policy gradient can substantially differ from those of gradient-based methods in supervised learning, and hence require dedicated study.
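For readers unfamiliar with the terminology, the sketch below recalls the standard (REINFORCE-style) policy gradient identity and the classical discrete-time LQR formulation with a linear state-feedback policy; the notation ($\theta$, $A$, $B$, $Q$, $R$, $K$) is generic background and not taken from the talk, whose exact setting and assumptions may differ.

```latex
% Policy gradient: ascend the expected return by following the score-weighted gradient.
\nabla_{\theta} J(\theta)
  = \mathbb{E}_{\tau \sim \pi_{\theta}}\!\left[ R(\tau)\, \nabla_{\theta} \log \pi_{\theta}(\tau) \right]

% Standard LQR with a linear state-feedback policy u_t = K x_t,
% where policy gradient performs gradient descent on the cost J(K) directly over K.
\begin{aligned}
\text{Dynamics:} \quad & x_{t+1} = A x_t + B u_t, \qquad x_0 \sim \mathcal{D}, \\
\text{Policy:}   \quad & u_t = K x_t, \\
\text{Cost:}     \quad & J(K) = \mathbb{E}_{x_0 \sim \mathcal{D}}
                          \Big[ \textstyle\sum_{t \ge 0} \big( x_t^{\top} Q x_t + u_t^{\top} R u_t \big) \Big], \\
\text{Update:}   \quad & K \leftarrow K - \eta\, \nabla_{K} J(K).
\end{aligned}
```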
The works covered in this talk are joint with Nadav Cohen, Yotam Alexander, Edo Cohen-Karlik, Raja Giryes, Amir Globerson, Hattie Zhou, Omid Saremi, Vimal Thilak, Arwen Bradley, Preetum Nakkiran, Joshua Susskind, and Etai Littwin.