Generalization and Stability in Interpolating Neural Networks
- Hossein Taheri (UCSB)
Abstract
Neural networks are renowned for their ability to memorize datasets, often achieving near-zero training loss via gradient descent. Despite this capacity for interpolation, they also generalize remarkably well to new data. This paper studies the generalization behavior of neural networks trained with the logistic loss through the lens of algorithmic stability. Our focus is the neural tangent regime, in which the network weights move only a constant distance from initialization while the training loss is driven to its minimum. Our main finding is that, under NTK-separability, optimal test loss bounds are achievable provided the network width is at least poly-logarithmic in the number of training samples. This stands in contrast to existing stability-based generalization results, which typically require polynomially large width and yield suboptimal rates, and it underscores the significance of our approach. Moreover, our analysis yields improved generalization bounds and milder width requirements compared to prior works that employ alternative techniques such as uniform convergence via Rademacher complexity. The key to this improvement is leveraging Hessian information of the objective function along the gradient descent iterates: we show that neural networks of sufficiently large width trained with the logistic loss satisfy an approximate quasi-convexity property along the gradient descent path. To illustrate the practical implications of our findings, we specialize the analysis to an XOR dataset, for which we present refined width conditions.
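As background, a standard formulation of NTK-separability from the neural-tangent-kernel literature (the talk's precise assumption may differ in its constants and norm choices) is the following: a training set $\{(x_i, y_i)\}_{i=1}^{n}$ with labels $y_i \in \{\pm 1\}$ is NTK-separable with margin $\gamma > 0$ if there exists a direction $v$ in parameter space with $\|v\| \le 1$ such that
$$ y_i \, \langle v, \nabla_w f(w_0; x_i) \rangle \;\ge\; \gamma \qquad \text{for all } i = 1, \dots, n, $$
where $f(w_0; x_i)$ denotes the network output at initialization $w_0$. Separability of this form is what allows gradient descent on the logistic loss to interpolate the data while the weights remain within a constant distance of $w_0$.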
Biography: Hossein Taheri received the B.Sc. degree in electrical engineering and mathematics from the Sharif University of Technology, Tehran, Iran, in 2018. He is currently pursuing the Ph.D. degree in electrical and computer engineering under the guidance of Christos Thrampoulidis at the University of California, Santa Barbara. His main research interests are statistical learning and optimization.