Current research projects

The goal of my research is to develop and analyze efficient algorithms for machine learning, with a focus on statistical estimation and interactive machine learning.

Non-linear prediction methods

During the Foundations of Machine Learning program at the Simons Institute for the Theory of Computing in 2017, Misha Belkin and I began discussing how to provide a generalization theory for “modern” non-linear prediction methods (like neural networks and kernel machines). This was spurred by the apparent lack of theory explaining some statistical phenomena observed with neural networks in modern applications, and it has led to interesting research questions about neural networks as well as other commonly used non-linear prediction methods.

Risk bounds for prediction rules that interpolate

Interpolation of training data is a property shared by many recent practical approaches to fitting large, complex models like deep neural networks and kernel machines. Conventional wisdom in machine learning and statistics would predict that the fitted models would not perform well out-of-sample, especially when the training data are noisy. However, in practice, these fitted models often perform well on new data. We proved that certain non-parametric interpolating methods achieve statistical consistency (or “near consistency” in high dimensions) and, in some cases, (optimal) rates of convergence. These methods take advantage of local averaging in high dimensions, and we conjecture that commonly used methods like deep neural networks and kernel machines that interpolate training data are effective for similar reasons.
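
As a concrete illustration of interpolation via local averaging, the sketch below implements a weighted nearest-neighbor rule with singular weights, one simple scheme of this kind. The weight form ||x − x_i||^(−δ), the defaults k = 5 and δ = 1, and the name `interpolating_knn_predict` are assumptions for this example, not the exact estimators analyzed in the papers. Because a weight diverges as the query approaches a training point, the rule reproduces every (noisy) training label, while away from the data it behaves like an ordinary k-NN average.

```python
import numpy as np


def interpolating_knn_predict(X_train, y_train, X_query, k=5, delta=1.0):
    """Weighted k-NN regression with singular weights ||x - x_i||^(-delta).

    Hypothetical illustration: the weights diverge at training points, so the
    fitted function interpolates the training labels.
    """
    preds = np.empty(len(X_query))
    for q, x in enumerate(X_query):
        dists = np.linalg.norm(X_train - x, axis=1)
        nearest = np.argsort(dists)[:k]
        d = dists[nearest]
        if d[0] == 0.0:
            # Query coincides with a training point: return its label exactly.
            preds[q] = y_train[nearest[0]]
            continue
        w = d ** (-delta)
        preds[q] = np.dot(w, y_train[nearest]) / w.sum()
    return preds


rng = np.random.default_rng(0)
X = rng.uniform(size=(200, 3))
y = X[:, 0] + 0.3 * rng.normal(size=200)  # noisy labels

train_fit = interpolating_knn_predict(X, y, X)
print(np.max(np.abs(train_fit - y)))  # 0.0: zero training error despite the noise
```

Away from the training points each prediction is a convex combination of the k nearest labels, which is the "local averaging" that the consistency arguments rely on.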

Interaction of over-parameterization and inductive bias

Using a combination of empirical and theoretical analysis, we extend the classical “bias-variance” trade-off curve beyond the point of interpolation for a broad class of commonly-used machine learning models (including certain neural networks).

When the number of parameters p is below the sample size n, the test risk is governed by the usual bias-variance decomposition. As p is increased towards n, the training risk (i.e., in-sample error) is driven to zero, but the test risk shoots up towards infinity. The classical bias-variance analysis identifies a “sweet spot” value of p ∈ [0, n] at which the bias and variance are balanced to achieve low test risk. However, as p grows beyond n, the test risk decreases again, provided that the model is fit using a suitable inductive bias (e.g., choosing the least-norm solution). In many (but not all) cases, the limiting risk as p → ∞ is lower than the risk achieved at the “sweet spot” value of p.
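
The p > n regime is easy to explore numerically. The sketch below is an assumed toy setup (Gaussian features, a coefficient vector spread evenly over D = 400 features, noise level 0.1), not the experiments from the papers. It fits least squares on the first p features; for p > n the system is underdetermined and `np.linalg.lstsq` returns the minimum-norm interpolating solution. The printed training error drops to (numerically) zero once p ≥ n; the test errors depend on the random draw, so no particular curve is claimed here.

```python
import numpy as np

rng = np.random.default_rng(1)
n, D, n_test = 40, 400, 2000
beta = rng.normal(size=D) / np.sqrt(D)  # signal spread evenly over all D features (assumption)


def sample(m):
    X = rng.normal(size=(m, D))
    return X, X @ beta + 0.1 * rng.normal(size=m)


X_tr, y_tr = sample(n)
X_te, y_te = sample(n_test)


def risks(p):
    """Fit least squares on the first p features (minimum-norm when p > n).

    Returns (training MSE, test MSE).
    """
    coef, *_ = np.linalg.lstsq(X_tr[:, :p], y_tr, rcond=None)
    train = np.mean((X_tr[:, :p] @ coef - y_tr) ** 2)
    test = np.mean((X_te[:, :p] @ coef - y_te) ** 2)
    return train, test


for p in [10, 20, 35, 40, 45, 80, 200, 400]:
    train, test = risks(p)
    print(f"p={p:4d}  train={train:.2e}  test={test:.3f}")
```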

We empirically show that this “double descent” risk curve manifests for many machine learning models on a variety of data sets, and also rigorously prove its emergence in some simple models with so-called “weak features”.

Work by Muthukumar, Vodrahalli, and Sahai, and by Hastie, Montanari, Rosset, and Tibshirani, also explores this phenomenon (concurrently with the “Two models” paper).

Algorithmic statistics

A large chunk of my research (over the past decade) has focused on algorithms for statistical estimation, especially for hidden variable models. My research has contributed new algorithmic techniques and analyses for various estimation problems where maximum likelihood is non-trivial or intractable.

Method-of-moments via tensor decomposition

Building on work started in my doctoral research on efficient algorithms for learning hidden Markov models (HMMs), we developed a general method for parameter estimation in hidden variable models based on decompositions of higher-order moment tensors. The method is a special case of the method-of-moments in which the moment equations are approximately solved using an efficient algorithm for decomposing a moment tensor into its atomic (“rank-1”) components. We have applied this method (in its various forms) to obtain computationally efficient parameter estimation algorithms for Gaussian mixture models, topic models, and many other models.
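
The decomposition step can be illustrated on an idealized, noise-free third-order moment tensor. This sketch assumes the tensor is already orthogonally decomposable, T = Σ_i w_i a_i ⊗ a_i ⊗ a_i with orthonormal a_i and positive weights w_i (in practice a whitening step arranges this; it is omitted here), and recovers the rank-1 components by tensor power iteration with deflation. The restart and iteration counts are arbitrary choices.

```python
import numpy as np


def tensor_apply(T, v):
    """Compute the vector T(I, v, v) = sum_{j,k} T[:, j, k] v[j] v[k]."""
    return np.einsum("ijk,j,k->i", T, v, v)


def tensor_power_method(T, rank, n_restarts=10, n_iters=100, seed=0):
    """Recover (w_i, a_i) from an orthogonally decomposable symmetric tensor."""
    rng = np.random.default_rng(seed)
    T = T.copy()
    weights, vectors = [], []
    for _ in range(rank):
        best_v, best_val = None, -np.inf
        for _ in range(n_restarts):
            v = rng.normal(size=T.shape[0])
            v /= np.linalg.norm(v)
            for _ in range(n_iters):
                v = tensor_apply(T, v)
                v /= np.linalg.norm(v)
            val = np.einsum("ijk,i,j,k->", T, v, v, v)  # T(v, v, v)
            if val > best_val:
                best_v, best_val = v, val
        weights.append(best_val)
        vectors.append(best_v)
        # Deflate: subtract the recovered rank-1 component.
        T -= best_val * np.einsum("i,j,k->ijk", best_v, best_v, best_v)
    return np.array(weights), np.array(vectors)


# Synthetic example: three orthonormal components in 5 dimensions.
rng = np.random.default_rng(123)
A, _ = np.linalg.qr(rng.normal(size=(5, 3)))  # columns are a_1, a_2, a_3
w_true = np.array([3.0, 2.0, 1.0])
T = np.einsum("r,ir,jr,kr->ijk", w_true, A, A, A)

w_hat, V_hat = tensor_power_method(T, rank=3)
print(np.sort(w_hat))
```

With real data, T would be an empirical moment estimate, and the robustness of this step to estimation error is exactly what the analyses address.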

We have also developed iterative schemes for tensor decomposition that permit the use of generic optimization methods in each step of the iterative process. These schemes are more practical than sum-of-squares methods, yet they improve on the robustness of previous iterative methods (like the tensor power iteration analyzed in my previous work).

More recently, we applied tensor decomposition methods to estimation in single-index models, a popular class of semi-parametric models. Special cases of single-index models are well studied in many contexts (e.g., generalized linear models, phase retrieval), and numerous procedures have been proposed for estimating the parametric component of these models. We showed that several of these procedures are special cases of a more general procedure related to higher-order tensor decompositions, which emerges when one expands the (unknown) link function in an orthogonal polynomial basis. Our result establishes consistent semi-parametric estimation in single-index models under very weak but natural conditions on the unknown link function. This improves on earlier work on single-index models that makes ad hoc assumptions about the link function, assumptions that may not be necessary in general. Our work provides a new basis for understanding estimation in index models and related problems concerning effective dimensionality reduction.
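
A small numerical illustration of why higher-order (Hermite) moments matter, under assumptions made only for this sketch: Gaussian inputs x ~ N(0, I), a unit direction β, and the link g(t) = t², as in noiseless phase retrieval. For this even link the first-order moment E[y x] = E[g′(βᵀx)] β vanishes, so a purely first-order estimator finds nothing, but the second-order Hermite moment E[y (x xᵀ − I)] = E[g″(βᵀx)] β βᵀ = 2 β βᵀ is nonzero, and its top eigenvector recovers ±β.

```python
import numpy as np

rng = np.random.default_rng(2)
n, d = 20000, 5
beta = rng.normal(size=d)
beta /= np.linalg.norm(beta)

X = rng.normal(size=(n, d))
y = (X @ beta) ** 2  # link g(t) = t^2 (assumed for illustration)

# First-order moment: E[y x] = E[g'(beta.x)] beta, which is 0 for this even link.
first_order = (y[:, None] * X).mean(axis=0)

# Second-order Hermite moment: E[y (x x^T - I)] = E[g''(beta.x)] beta beta^T.
M = (X.T * y) @ X / n - y.mean() * np.eye(d)
eigvals, eigvecs = np.linalg.eigh(M)  # eigenvalues in ascending order
beta_hat = eigvecs[:, -1]             # eigenvector of the largest eigenvalue

print(np.linalg.norm(first_order))  # small: no signal at first order
print(abs(beta_hat @ beta))         # close to 1: direction recovered up to sign
```

For a general link, the first Hermite coefficient of g that is nonzero determines which moment tensor carries the signal; this example only shows the second-order case.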

Expectation-Maximization

Local optimization procedures based on the maximum likelihood principle are often preferred by practitioners for a variety of reasons, including algorithmic simplicity and interpretability of the objective function. It is therefore important to understand the behavior of such procedures. To this end, we have studied the ubiquitous Expectation-Maximization (EM) algorithm for some important Gaussian mixture models.

First, we considered uniform mixtures of two multivariate Gaussians with a shared (and fixed) covariance. We proved that for almost any initialization (i.e., except on a set of measure zero that we explicitly characterize), the fixed point of EM is a consistent estimator of the model parameters. In contrast to previous work on EM, our result does not require strong distributional or initialization assumptions. This is the first such “global” consistency result for EM for a non-trivial model. Analogous results do not hold for even slightly richer mixture models, including uniform mixtures of three or more isotropic Gaussians and certain non-uniform mixtures of two Gaussians.
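
For the symmetric special case x ~ ½ N(θ*, I) + ½ N(−θ*, I), with identity covariance assumed for this sketch, the E-step and M-step combine into the closed-form update θ ← (1/n) Σ_i tanh(⟨θ, x_i⟩) x_i. The sketch below runs this sample-based iteration from a random start; in this special case, starts orthogonal to θ* form the exceptional measure-zero set, and any other start should approach +θ* or −θ* (the two labelings describe the same mixture). The function name and sizes are illustrative.

```python
import numpy as np


def em_symmetric_mixture(X, theta0, n_iters=200):
    """EM for the model 0.5*N(theta, I) + 0.5*N(-theta, I).

    The posterior of the +theta component is sigmoid(2 <theta, x>), so the
    M-step mean(2r - 1) x simplifies to mean_i tanh(<theta, x_i>) x_i.
    """
    theta = np.asarray(theta0, dtype=float)
    for _ in range(n_iters):
        theta = (np.tanh(X @ theta)[:, None] * X).mean(axis=0)
    return theta


rng = np.random.default_rng(3)
n, theta_star = 20000, np.array([2.0, 0.0])
signs = rng.choice([-1.0, 1.0], size=n)
X = signs[:, None] * theta_star + rng.normal(size=(n, 2))

theta_hat = em_symmetric_mixture(X, theta0=rng.normal(size=2))
print(theta_hat)  # close to +theta_star or -theta_star
```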

We showed how over-parameterization can be used to overcome the failure of EM in the case of non-uniform mixtures of Gaussians. Specifically, we showed that even when data are generated by a non-uniform mixture of two isotropic Gaussians with known mixing weights, it is computationally beneficial to use the larger statistical model in which the mixing weights are treated as unknown parameters to be estimated alongside the component means. For this over-parameterized model, the log-likelihood landscape no longer has the spurious fixed points present in that of the “correctly parameterized” model. This is the first theoretical explanation for the computational benefits of over-parameterization in hidden variable models, and the result has implications for model selection in statistical practice.
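
The over-parameterized variant adds one step to EM: the mixing weight w is also re-estimated. Assuming the model w N(θ, I) + (1 − w) N(−θ, I), the E-step responsibility of the first component is r = σ(2⟨θ, x⟩ + log(w/(1 − w))), and the M-step sets θ to the average of (2r − 1) x and w to the average of r. The flag `update_weight`, the function name, and the numbers below are hypothetical choices for this sketch; with `update_weight=False` the code instead runs the known-weight EM, holding w at `w0`.

```python
import numpy as np


def em_two_gaussians(X, theta0, w0, update_weight=True, n_iters=500):
    """EM for w*N(theta, I) + (1 - w)*N(-theta, I).

    If update_weight is False, w stays fixed at w0 (the "correctly
    parameterized" model when w0 is the true mixing weight).
    """
    theta, w = np.asarray(theta0, dtype=float), float(w0)
    for _ in range(n_iters):
        # E-step: posterior probability that x_i came from N(theta, I),
        # computed as sigmoid(logits) = 0.5 * (1 + tanh(logits / 2)) for stability.
        logits = 2.0 * (X @ theta) + np.log(w / (1.0 - w))
        r = 0.5 * (1.0 + np.tanh(0.5 * logits))
        # M-step for the means (constrained to +theta and -theta).
        theta = ((2.0 * r - 1.0)[:, None] * X).mean(axis=0)
        if update_weight:
            w = r.mean()
    return theta, w


rng = np.random.default_rng(4)
n, theta_star, w_star = 20000, np.array([2.0, 0.0]), 0.3
from_first = rng.random(n) < w_star
X = np.where(from_first[:, None], theta_star, -theta_star) + rng.normal(size=(n, 2))

# Start on the "wrong" side of the origin; (theta, w) and (-theta, 1 - w)
# describe the same mixture, so either labeling counts as a correct answer.
theta_hat, w_hat = em_two_gaussians(X, theta0=[-1.0, 0.3], w0=0.5)
print(theta_hat, w_hat)
```

Calling the same function with `update_weight=False, w0=w_star` gives the known-weight EM for comparison from the same start.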

Non-statistical algorithms

We have been developing a new algorithmic technique for certain signal recovery problems that appears to go beyond the capabilities of previous efficient algorithms, including those based on tensor decompositions.

The first problem we consider is a generalization of the phase retrieval problem: there are k unknown signals to be recovered from their (possibly noisy) linear measurements, but the correspondence between the measurements and the signals is missing. We first showed that this “correspondence retrieval problem” (as we call it) admits an estimation procedure based on tensor decompositions, but that procedure requires the sample size to grow polynomially with the number of unknown signals. Surprisingly, however, when the measurements are noiseless, we found an efficient algorithm that recovers the unknown signals from the minimum number of measurements (d + 1 for signals in d-dimensional space).

The algorithm is based on reducing the recovery problem to a random instance of a subset sum problem; we then use an argument similar to those of Lagarias and Odlyzko and of Frieze, which reduces the subset sum problem to an approximate shortest vector problem that can be solved using the lattice basis reduction algorithm of Lenstra, Lenstra, and Lovász. The resulting algorithm is “non-statistical” in the sense that it does not appear to permit an implementation based solely on statistical queries. (The other well-known example of such a “non-statistical” algorithm is the Gaussian elimination method for learning parity functions.)
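
Lattice basis reduction is the computational workhorse here. Below is a textbook sketch of the LLL algorithm with δ = 3/4, using exact rational arithmetic from Python's `fractions` module and recomputing the Gram-Schmidt data after every basis change (simple but slow, which is acceptable for tiny examples). It is a generic illustration of the reduction step only; the reduction from correspondence retrieval to subset sum, and the specific lattice built from a subset sum instance, are not shown.

```python
from fractions import Fraction


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def gram_schmidt(basis):
    """Return Gram-Schmidt vectors b* and coefficients mu[i][j] as exact rationals."""
    n = len(basis)
    ortho = []
    mu = [[Fraction(0)] * n for _ in range(n)]
    for i in range(n):
        v = [Fraction(x) for x in basis[i]]
        for j in range(i):
            mu[i][j] = dot(basis[i], ortho[j]) / dot(ortho[j], ortho[j])
            v = [a - mu[i][j] * b for a, b in zip(v, ortho[j])]
        ortho.append(v)
    return ortho, mu


def lll_reduce(basis, delta=Fraction(3, 4)):
    """LLL-reduce a basis of linearly independent integer vectors."""
    b = [[int(x) for x in v] for v in basis]
    n = len(b)
    ortho, mu = gram_schmidt(b)
    k = 1
    while k < n:
        # Size reduction: make |mu[k][j]| <= 1/2 for all j < k.
        for j in range(k - 1, -1, -1):
            q = round(mu[k][j])
            if q != 0:
                b[k] = [x - q * y for x, y in zip(b[k], b[j])]
                ortho, mu = gram_schmidt(b)
        # Lovász condition: advance if it holds, otherwise swap and step back.
        lhs = dot(ortho[k], ortho[k])
        rhs = (delta - mu[k][k - 1] ** 2) * dot(ortho[k - 1], ortho[k - 1])
        if lhs >= rhs:
            k += 1
        else:
            b[k], b[k - 1] = b[k - 1], b[k]
            ortho, mu = gram_schmidt(b)
            k = max(k - 1, 1)
    return b


B = [[1, 1, 1], [-1, 0, 2], [3, 5, 6]]
R = lll_reduce(B)
print(R)  # a basis of the same lattice with short, nearly orthogonal vectors
```

In a Lagarias-Odlyzko style argument, a short vector in the reduced basis of a lattice built from the subset sum weights encodes the hidden subset.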

The second problem we consider is that of linear regression in a scenario where the correspondence between the inputs and outputs is missing. This problem arises in many signal processing applications. We again showed that in a noise-free version of the problem, there is an efficient algorithm based on lattice basis reduction. All previous algorithms for this problem were either restricted to scalar inputs or had running time exponential in both the sample size and the dimension. Our work also gives new results on the worst-case computational complexity and the average-case information-theoretic complexity of the problem. (This work was completed while participating in the Foundations of Machine Learning program in 2017.)
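
To make the problem statement concrete: in the noise-free setting one observes inputs x_1, …, x_n and the outputs y_i = ⟨β, x_i⟩ in an unknown shuffled order. The naive baseline below (an illustration of the problem, not the lattice-based algorithm) tries all n! ways of matching outputs to inputs and keeps the one whose least-squares fit has the smallest residual, which makes the exponential dependence on the sample size explicit. The function name and sizes are hypothetical.

```python
import itertools

import numpy as np


def brute_force_shuffled_regression(X, y_shuffled):
    """Try every matching of outputs to inputs; return the best-fitting beta."""
    best_beta, best_res = None, np.inf
    for perm in itertools.permutations(range(len(y_shuffled))):
        y_matched = y_shuffled[list(perm)]
        beta, *_ = np.linalg.lstsq(X, y_matched, rcond=None)
        res = np.sum((X @ beta - y_matched) ** 2)
        if res < best_res:
            best_beta, best_res = beta, res
    return best_beta, best_res


rng = np.random.default_rng(5)
n, d = 6, 2
X = rng.normal(size=(n, d))
beta_true = rng.normal(size=d)
y_shuffled = rng.permutation(X @ beta_true)  # the correspondence is lost

beta_hat, residual = brute_force_shuffled_regression(X, y_shuffled)
print(beta_hat, beta_true)  # equal up to floating-point error
```

Already at n = 6 this performs 720 least-squares fits; the count grows as n!, which is what an efficient algorithm must avoid.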

Interactive learning

Interactive machine learning involves agents that both learn from data and collect data. For example, the agent may be the learning procedure used by a data scientist: it interacts with the data scientist to jointly construct an accurate classifier (e.g., active learning). As another example, the agent may be responsible for selecting the content for a news website: it interacts with visitors of the website both to recommend interesting news articles and to learn about the preferences of the visitors (e.g., contextual bandits). These scenarios extend beyond the standard frameworks for understanding machine learning algorithms, and thus require new frameworks, algorithms, and analytic techniques.

Contextual bandits

The contextual bandit problem is an online decision problem where an agent repeatedly (1) observes the current context, (2) chooses an action to take, and (3) observes and collects a reward associated with the chosen action. Importantly, the agent does not observe the rewards for actions that were not taken. The challenge is to control the agent in a way that guarantees a small difference in cumulative reward (i.e., low regret) relative to that of the best policy from a fixed reference class of functions mapping contexts to actions. The contextual bandit problem is at the heart of many applications, ranging from online advertising (selecting the ad to display to a website visitor) to personalized medicine (selecting the treatment for a patient), and can be regarded as a tractable special case of reinforcement learning.

We have developed and analyzed new algorithms for contextual bandits based on supervised learning oracles. Previous algorithms for contextual bandits either required explicit enumeration over the reference policy class or had suboptimal regret bounds. The new algorithms access policies only through the supervised learning oracle, have near-optimal regret bounds, and call the oracle sparingly. The use of supervised learning oracles in these works builds on techniques from my earlier work on active learning, and it elucidates the role of supervised learning technology in contextual bandit problems. This general approach to algorithm design has since been used in several other works in machine learning.
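
To illustrate what accessing policies only through a supervised learning oracle means, here is a sketch of the simple ε-greedy strategy, a well-known baseline with suboptimal regret and not one of the near-optimal algorithms described above. The policy class (21 threshold rules on a scalar context), the reward model, and all names are hypothetical. Here the "oracle" happens to scan the small policy list internally, standing in for a supervised learner; the bandit loop itself only ever calls `oracle`. Each observed reward is converted into inverse-propensity-weighted estimates for all actions, the standard way to turn partial feedback into supervised learning data.

```python
import numpy as np

rng = np.random.default_rng(6)
N_ACTIONS, T, EPSILON = 2, 4096, 0.2
THRESHOLDS = np.linspace(-1.0, 1.0, 21)  # hypothetical finite policy class


def act(threshold, x):
    """Policy indexed by a threshold: choose action 1 iff x > threshold."""
    return int(x > threshold)


def oracle(contexts, reward_vectors):
    """Stand-in for a supervised learner: return the policy with the highest
    total (estimated) reward on the data it is given."""
    totals = [sum(r[act(th, x)] for x, r in zip(contexts, reward_vectors))
              for th in THRESHOLDS]
    return THRESHOLDS[int(np.argmax(totals))]


def draw_reward(x, a):
    """Hypothetical environment: the right action (1 iff x > 0.3) pays 1 with
    probability 0.9; the other action pays 1 with probability 0.1."""
    p = 0.9 if (a == 1) == (x > 0.3) else 0.1
    return float(rng.random() < p)


contexts, reward_vectors = [], []
current = THRESHOLDS[0]
for t in range(1, T + 1):
    x = rng.uniform(-1.0, 1.0)                   # (1) observe the context
    probs = np.full(N_ACTIONS, EPSILON / N_ACTIONS)
    probs[act(current, x)] += 1.0 - EPSILON      # explore with probability EPSILON
    a = int(rng.choice(N_ACTIONS, p=probs))      # (2) choose an action
    r = draw_reward(x, a)                        # (3) reward of the chosen action only
    r_hat = np.zeros(N_ACTIONS)
    r_hat[a] = r / probs[a]                      # inverse propensity weighting
    contexts.append(x)
    reward_vectors.append(r_hat)
    if (t & (t - 1)) == 0:                       # call the oracle only when t is a power of two
        current = oracle(contexts, reward_vectors)

print(current)  # learned threshold; the best policy in this class uses 0.3
```

Calling the oracle on a doubling schedule keeps the number of calls logarithmic in T, a crude version of using the oracle sparingly.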

Active learning

Active learning models the interaction between an annotator and a learning algorithm: the goal is to (adaptively) label only a few data points but still learn an accurate predictor. Active learning is essential for machine learning applications where unlabeled data are plentiful or easily acquired, but labels are costly to obtain (e.g., due to the required human effort). My work (which includes my doctoral research) provides theoretical foundations for active learning; these foundations are far more intricate than those of standard supervised learning because of the adaptive way in which data are acquired.
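
The classic example behind the promise of active learning: for one-dimensional threshold classifiers with noise-free labels, binary search over a sorted pool of n unlabeled points finds a threshold consistent with every pool label using at most ⌈log₂(n + 1)⌉ label queries, whereas labeling the whole pool costs n. This is a textbook illustration under a realizability assumption, not the algorithms developed in this work; `label_oracle` is a hypothetical stand-in for the annotator.

```python
import numpy as np


def active_learn_threshold(pool, label_oracle):
    """Binary search for the boundary of a threshold classifier.

    Assumes noise-free labels with label(x) = 1 iff x >= some unknown t.
    Returns (learned threshold, number of label queries).
    """
    xs = np.sort(pool)
    lo, hi = 0, len(xs)  # invariant: the first positive index lies in [lo, hi]
    queries = 0
    while lo < hi:
        mid = (lo + hi) // 2
        queries += 1
        if label_oracle(xs[mid]) == 1:
            hi = mid
        else:
            lo = mid + 1
    # Pool points at sorted index >= lo are positive.
    threshold = xs[lo] if lo < len(xs) else np.inf
    return threshold, queries


rng = np.random.default_rng(7)
pool = rng.uniform(size=1000)
true_t = 0.42


def label_oracle(x):
    """Hypothetical annotator with a hidden threshold."""
    return int(x >= true_t)


t_hat, n_queries = active_learn_threshold(pool, label_oracle)
print(t_hat, n_queries)  # at most ceil(log2(1001)) = 10 queries
```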

In a recent work, we showed how label querying decisions can be formulated as an optimization problem, and developed and analyzed an algorithm based on this formulation.

We also formalized another type of interaction in which the learning algorithm may request labeled examples that satisfy certain criteria (a positively-labeled example, an example that contradicts a set of candidate classifiers, etc.). We call such requests “search queries”, and developed an active learning strategy for using both search and label queries.

Beyond learning

More recently, I have been excited about the use of interaction to facilitate other tasks beyond concept learning.

During our time at the Foundations of Machine Learning program in 2017 (quite a productive semester!), Sanjoy Dasgupta and I worked on formalizing the benefit of interaction in machine teaching. We show some inherent limitations to non-interactive teaching when the teacher does not know the hypothesis class used by the learner. To remedy this, we also show that an interactive teacher, who is permitted to “quiz” the learner, can identify a “teaching set” of size comparable to that achievable by a teacher who knows the learner’s hypothesis class.

Data science

I have collaborations with researchers across the university who are applying “data science” to problems in their domains.

Information extraction and public health

Together with my Computer Science colleague Luis Gravano, I have been working with the New York City Department of Health and Mental Hygiene (DOHMH) on using social media to identify evidence of foodborne illness incidents in New York City restaurants that warrant investigation by public health officials. We use machine learning methods to improve information extraction from noisy and informal text data. This project is supported by the NSF under Grant IIS-15-63785.

Fairness, accountability, & transparency

With other colleagues in Computer Science, we have applied machine learning and statistical inference to help make data-driven applications more transparent. In “Sunlight”, we showed how randomized experimental designs can be used to infer causal effects of online profile attributes in ad targeting systems. In “FairTest”, we showed how supervised machine learning techniques can be used to discover and measure unwarranted / potentially discriminatory associations encoded in predictive models.

Weak gravitational lensing

I have been collaborating with Zoltan Haiman from the Columbia Astronomy Department on applications of machine learning to cosmology. One high-level goal is to use machine learning to tackle inferential tasks in cosmology, which we anticipate will lead to a better understanding of the nature of dark energy. Towards this end, we trained deep neural networks to extract non-Gaussian information from weak lensing observations. These networks provide alternatives to previously used descriptors (i.e., statistics) such as two-point correlations and peak counts. Using weak lensing maps obtained from a suite of N-body simulations, we found that the deep neural networks provided tighter constraints on cosmological parameters than previously proposed statistics. Our approach has since been adopted by other research groups working on cosmology.