
Learning representations and associations with gradient descent

Jason Lee (Princeton University)
Sloan Center Room 380Y
Apr 16

Machine learning has undergone a paradigm shift with the success of pretrained models. Models pretrained via gradient descent learn transferable representations that adapt to a wide swath of downstream tasks. However, a significant body of prior theoretical work has shown that in many regimes, overparameterized neural networks trained by gradient descent behave like kernel methods and do not learn transferable representations. In this talk, we close this gap by showing that there is a large class of functions that cannot be learned efficiently by kernel methods but can be learned easily by gradient descent on a neural network, which does so by learning representations relevant to the target task. We also show that these representations enable efficient transfer learning, which is impossible in the kernel regime. Finally, we demonstrate how pretraining learns associations for in-context learning with transformers. This leads to a systematic and mechanistic understanding of learning causal structures, including the celebrated induction head identified by Anthropic.