Provable Non-Convex Optimization For ML
Prateek Jain
Microsoft Research India
http://research.microsoft.com/en-us/people/prajain/
Overview
• High-dimensional Machine Learning
• Many many parameters
• Impose structural assumptions
• Requires solving non-convex optimization
• In general NP-hard
• No provable generic optimization tools
Overview
• Most popular approach: convex relaxation
• Solvable in poly-time
• Guarantees under certain assumptions
• Slow in practice
• Goal of this tutorial: practical + theoretically provable algorithms for high-d ML problems
Learning in Large No. of Dimensions
[Figure: a 𝑑-dimensional data vector is mapped through a function 𝑓; learning 𝑓 = optimization over 𝑑 parameters]
Linear Model
𝑓(𝑥) = Σ𝑖 𝑤𝑖 𝑥𝑖 = 〈𝑤, 𝑥〉
• 𝑤: 𝑑 −dimensional vector
• No. of training samples: 𝑛 = 𝑂(𝑑)
• For bi-grams: 𝑛 = 1000𝐵 documents!
• Prediction and storage: O(d)
• Prediction time per query: 1000 secs
• Over-fitting
Another Example: Low-rank Matrix Completion
Key Issues
• Large no. of training samples required
• Large training time
• Large storage and prediction time
Learning with Structure
• Restrict the parameter space
• Linear classification/regression: 𝑓 𝑥 = 〈𝑤, 𝑥〉
• Restrict no. of non-zeros in 𝑤 to 𝑠 ≪ 𝑑
• e.g., 𝑤 = (0, 3, 0, 0, 1, 0, 0, 0, 0, 9): only 3 non-zero entries
Learning with Structure contd…
• Matrix completion: 𝑊 ≅ 𝑈 × 𝑉^𝑇
• 𝑊 is characterized by (𝑈, 𝑉)
• No. of variables: 𝑑1𝑑2 for 𝑊 vs. 𝑑1𝑟 + 𝑑2𝑟 for (𝑈: 𝑑1 × 𝑟, 𝑉: 𝑑2 × 𝑟)
Learning with Structure
min_𝑤 𝐿(𝑤)   (𝐿: data-fidelity function)
𝑠.𝑡. 𝑤 ∈ 𝐶
• Linear classification/regression
• 𝐶 = {𝑤 : ||𝑤||0 ≤ 𝑠}, with 𝑠 log 𝑑 ≪ 𝑑
• ||𝑤||0 is non-convex; computational complexity: NP-hard
• Matrix completion
• 𝐶 = {𝑊 : 𝑟𝑎𝑛𝑘(𝑊) ≤ 𝑟}, with 𝑟(𝑑1 + 𝑑2) ≪ 𝑑1𝑑2
• 𝑟𝑎𝑛𝑘(𝑊) is non-convex; computational complexity: NP-hard
Other Examples
• Low-rank tensor completion
• 𝐶 = {𝑊 : 𝑡𝑒𝑛𝑠𝑜𝑟-𝑟𝑎𝑛𝑘(𝑊) ≤ 𝑟}, with 𝑟(𝑑1 + 𝑑2 + 𝑑3) ≪ 𝑑1𝑑2𝑑3
• 𝑡𝑒𝑛𝑠𝑜𝑟-𝑟𝑎𝑛𝑘(𝑊) is non-convex; complexity: undecidable
• Robust PCA
• 𝐶 = {𝑊 : 𝑊 = 𝐿 + 𝑆, 𝑟𝑎𝑛𝑘(𝐿) ≤ 𝑟, ||𝑆||0 ≤ 𝑠}
• 𝑟(𝑑1 + 𝑑2) + 𝑠 log(𝑑1 + 𝑑2) ≪ 𝑑1𝑑2
• 𝑟𝑎𝑛𝑘(𝐿) and ||𝑆||0 are non-convex; complexity: NP-hard
Convex Relaxations
• Linear classification/regression
• 𝐶 = {𝑤 : ||𝑤||0 ≤ 𝑠}  →  𝐶̃ = {𝑤 : ||𝑤||1 ≤ 𝜆(𝑠)}
• ||𝑤||1 = Σ𝑖 |𝑤𝑖|
• Matrix completion
• 𝐶 = {𝑊 : 𝑟𝑎𝑛𝑘(𝑊) ≤ 𝑟}  →  𝐶̃ = {𝑊 : ||𝑊||∗ ≤ 𝜆(𝑟)}
• ||𝑊||∗ = Σ𝑖 𝜎𝑖, where 𝑊 = 𝑈Σ𝑉^𝑇
Convex Relaxations Contd…
• Robust PCA
• 𝐶 = {𝑊 : 𝑊 = 𝐿 + 𝑆, 𝑟𝑎𝑛𝑘(𝐿) ≤ 𝑟, ||𝑆||0 ≤ 𝑠}  →  𝐶̃ = {𝑊 : 𝑊 = 𝐿 + 𝑆, ||𝐿||∗ ≤ 𝜆(𝑟), ||𝑆||1 ≤ 𝜆(𝑠)}
Convex Relaxation
• Advantage:
• Convex optimization: Polynomial time
• Generic tools available for optimization
• Systematic analysis
• Disadvantage:
• Optimizes over a much bigger set
• Not scalable to large problems
This tutorial’s focus
Don’t Relax!
• Advantage: scalability
• Disadvantage: optimization and its analysis are much harder
• Local minima problems
• Two approaches:
• Projected gradient descent
• Alternating minimization
Approach 1: Projected Gradient Descent
min_𝑤 𝐿(𝑤)
𝑠.𝑡. 𝑤 ∈ 𝐶
• 𝑤𝑡+1 = 𝑤𝑡 − 𝜂 ∇𝐿(𝑤𝑡)
• 𝑤𝑡+1 = 𝑃𝐶(𝑤𝑡+1)
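Since this two-line update recurs throughout the tutorial, here is a minimal sketch in Python/NumPy. The names `grad_L` and `proj_C` are user-supplied placeholders (an assumption, not anything defined on the slides).

```python
import numpy as np

def projected_gradient_descent(grad_L, proj_C, w0, eta=0.1, iters=100):
    """Minimal PGD sketch: repeat w <- P_C(w - eta * grad L(w)).
    grad_L and proj_C are user-supplied callables (placeholders)."""
    w = np.asarray(w0, dtype=float).copy()
    for _ in range(iters):
        w = proj_C(w - eta * grad_L(w))   # gradient step, then projection onto C
    return w
```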
Efficient Projection
• Sparse linear regression/classification
• 𝐶 = {𝑤, ||𝑤||0 ≤ 𝑠}
• 𝑠𝑢𝑝𝑝(𝑃𝑟𝑜𝑗𝐶(𝑧)) = {𝑖1, … , 𝑖𝑠}, where |𝑧𝑖1| ≥ |𝑧𝑖2| ≥ ⋯ ≥ |𝑧𝑖𝑑|
• Cost: 𝑂(𝑑 log 𝑑) (sort by magnitude); see the sketch below
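A minimal sketch of this projection, assuming a NumPy vector `z`; sorting by magnitude gives the 𝑂(𝑑 log 𝑑) cost noted above.

```python
import numpy as np

def project_l0(z, s):
    """Projection onto {x : ||x||_0 <= s}: keep the s largest-magnitude
    entries of z and zero out the rest (O(d log d) via sorting)."""
    idx = np.argsort(np.abs(z))[::-1][:s]   # indices i_1, ..., i_s
    x = np.zeros_like(z)
    x[idx] = z[idx]
    return x

# Example: project onto the set of 2-sparse vectors
print(project_l0(np.array([0.1, -3.0, 0.5, 2.0, 0.0, -0.2]), s=2))
# -> [ 0. -3.  0.  2.  0.  0.]
```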
Approach 2: Alternating Minimization
• Generic technique if each individual subproblem is "easy"
• e.g., EM-style algorithms
Results for Several Problems
• Sparse regression [Jain et al.’14, Garg and Khandekar’09]
• Sparsity
• Robust Regression [Bhatia et al.’15]
• Sparsity+output sparsity
• Dictionary Learning [Agarwal et al.’14]
• Matrix Factorization + Sparsity
• Phase Retrieval [Netrapalli et al.’13]
• System of Quadratic Equations
• Vector-valued Regression [Jain & Tewari’15]
• Sparsity+positive definite matrix
Results Contd…
• Low-rank Matrix Regression [Jain et al.’10, Jain et al.’13]
• Low-rank structure
• Low-rank Matrix Completion [Jain & Netrapalli’15, Jain et al.’13]
• Low-rank structure
• Robust PCA [Netrapalli et al.’14]
• Low-rank ∩ Sparse Matrices
• Tensor Completion [Jain and Oh’14]
• Low-tensor rank
• Low-rank matrix approximation [Bhojanapalli et al.’15]
• Low-rank structure
Sparse Linear Regression
𝑦 = 𝑋𝑤, with 𝑦 ∈ ℝ^𝑛, 𝑋 ∈ ℝ^{𝑛×𝑑}, 𝑤 ∈ ℝ^𝑑
• But: 𝑛 ≪ 𝑑
• 𝑤: 𝑠-sparse (𝑠 non-zeros)
Motivation: Single Pixel Camera
• ||𝑦 − 𝑋𝑤||₂² = Σ𝑖 (𝑦𝑖 − 〈𝑥𝑖, 𝑤〉)²
Non-convexity of Low-rank manifold
0.5 ⋅ (1, 0, 0)^𝑇 + 0.5 ⋅ (0, 1, 0)^𝑇 = (0.5, 0.5, 0)^𝑇
• The average of two 1-sparse vectors is 2-sparse; similarly, a convex combination of low-rank matrices can have higher rank, so these constraint sets are non-convex
Convex Relaxation
min_𝑤 ||𝑦 − 𝑋𝑤||₂²
𝑠.𝑡. ||𝑤||0 ≤ 𝑠
• Relaxed Problem:
min_𝑤 ||𝑦 − 𝑋𝑤||₂²
𝑠.𝑡. ||𝑤||1 ≤ 𝑠
• ||𝑤||1 = Σ𝑖 |𝑤𝑖|
• Known to promote sparsity
• Pros: a) Principled approach, b) Captures correlations between features
• Cons: Slow to optimize
Our Approach : Projected Gradient Descent
min_𝑤 𝑓(𝑤) = ||𝑦 − 𝑋𝑤||₂²
𝑠.𝑡. ||𝑤||0 ≤ 𝑠
• 𝑤𝑡+1 = 𝑤𝑡 − 𝜂 ∇𝑓(𝑤𝑡)
• 𝑤𝑡+1 = 𝑃𝑠(𝑤𝑡+1)
[Jain, Tewari, Kar’2014]
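A minimal iterative-hard-thresholding (IHT) sketch of this update for the least-squares objective, in NumPy. The step size 1/||𝑋||₂² is a conservative generic choice of my own (the RIP analysis later on the slides takes 𝜂 = 1 for RIP-normalized 𝑋).

```python
import numpy as np

def iht(X, y, s, iters=200):
    """IHT sketch for min ||y - Xw||^2 s.t. ||w||_0 <= s."""
    n, d = X.shape
    eta = 1.0 / np.linalg.norm(X, 2) ** 2      # conservative step size (assumption)
    w = np.zeros(d)
    for _ in range(iters):
        z = w - eta * X.T @ (X @ w - y)        # gradient step on 0.5 * ||y - Xw||^2
        idx = np.argsort(np.abs(z))[::-1][:s]  # hard-thresholding projection P_s
        w = np.zeros(d)
        w[idx] = z[idx]
    return w
```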
Projection onto 𝐿0 ball?
min_𝑥 ||𝑥 − 𝑧||₂²
𝑠.𝑡. ||𝑥||0 ≤ 𝑠
Important Properties
A Stronger Result?
||𝑃𝑠(𝑧) − 𝑧||₂² ≤ ((𝑑 − 𝑠)/(𝑑 − 𝑠∗)) ⋅ ||𝑃𝑠∗(𝑧) − 𝑧||₂²
Convex-projections vs Non-convex Projections
• For non-convex sets, we only have the 0-th order condition:
∀𝑌 ∈ 𝐶, ||𝑃𝐶(𝑍) − 𝑍|| ≤ ||𝑌 − 𝑍||
• But for projection onto a convex set 𝐶, we have the 1-st order condition:
∀𝑌 ∈ 𝐶, ||𝑍 − 𝑃𝐶(𝑍)||² ≤ 〈𝑌 − 𝑍, 𝑃𝐶(𝑍) − 𝑍〉
Proof under RIP
• Let 𝑓(𝑤) = ||𝑋(𝑤 − 𝑤∗)||₂²
• Let 𝛿3𝑠 ≤ 1/2
• Let 𝑤𝑡+1 = 𝑃𝐶(𝑤𝑡 − 𝜂 𝑔𝑡), 𝑔𝑡 = 𝑋^𝑇𝑋(𝑤𝑡 − 𝑤∗), 𝜂 = 1
• 𝐶: 𝐿0 ball with 𝑠 non-zeros and 𝑤∗ ∈ 𝐶
Then: ||𝑤𝑡+1 − 𝑤∗|| ≤ (3/4) ||𝑤𝑡 − 𝑤∗||
[Blumensath & Davies’09, Garg & Khandekar’09]
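For reference, a hedged sketch (in LaTeX) of the standard contraction argument behind this claim, in the style of the cited Blumensath & Davies / Garg & Khandekar analyses. It yields a factor 2𝛿₃ₛ; the slide's sharper 3/4 factor under 𝛿₃ₛ ≤ 1/2 follows from a more careful version of the same steps.

```latex
% Sketch only: constants differ from the slide's 3/4.
\begin{align*}
&\text{Let } z_t = w_t - X^\top X (w_t - w^*), \quad
  S = \operatorname{supp}(w_{t+1}) \cup \operatorname{supp}(w^*), \quad
  T = S \cup \operatorname{supp}(w_t), \quad |T| \le 3s. \\
&\text{Best $s$-sparse approximation (both $w_{t+1}$ and $w^*$ vanish off $S$):} \quad
  \| w_{t+1} - (z_t)_S \|_2 \le \| w^* - (z_t)_S \|_2. \\
&\text{Triangle inequality:} \quad
  \| w_{t+1} - w^* \|_2 \le 2 \, \| (z_t)_S - w^* \|_2
  = 2 \, \big\| \big[ (I - X^\top X)(w_t - w^*) \big]_S \big\|_2. \\
&\text{RIP on the $3s$-sparse set $T \supseteq S$:} \quad
  \| w_{t+1} - w^* \|_2 \le 2\,\delta_{3s}\, \| w_t - w^* \|_2 .
\end{align*}
```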
Variations
• Fully corrective version (one step is sketched below):
𝑢𝑡+1 = 𝑃𝐶(𝑤𝑡 − 𝜂 𝑔𝑡)
𝑤𝑡+1 = arg min_𝑤 𝑓(𝑤), 𝑠.𝑡. 𝑠𝑢𝑝𝑝(𝑤) = 𝑠𝑢𝑝𝑝(𝑢𝑡+1)
• Two stage algorithms:
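A minimal sketch of one fully corrective step for the least-squares objective, assuming NumPy; step size and sparsity level are taken as given, as in the plain IHT sketch above.

```python
import numpy as np

def fully_corrective_step(X, y, w, s, eta):
    """One fully corrective step for f(w) = 0.5 * ||y - Xw||^2:
    hard-threshold the gradient step, then re-solve least squares
    restricted to the selected support."""
    z = w - eta * X.T @ (X @ w - y)            # gradient step
    supp = np.argsort(np.abs(z))[::-1][:s]     # support of u_{t+1} = P_C(z)
    sol, *_ = np.linalg.lstsq(X[:, supp], y, rcond=None)
    w_new = np.zeros_like(w)
    w_new[supp] = sol                          # argmin of f over supp(u_{t+1})
    return w_new
```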
Summary so far…
• High-dimensional problems
• 𝑛≪𝑑
• Need to impose structure on 𝑤
• Sparsity
• Projection easy!
• Projected Gradient works (if RIP is satisfied)
• Several variants exist
Which Matrices Satisfy RIP?
(1 − 𝛿𝑠) ||𝑤||₂² ≤ ||𝑋𝑤||₂² ≤ (1 + 𝛿𝑠) ||𝑤||₂²,  for all ||𝑤||0 ≤ 𝑠
Popular RIP Ensembles
• 𝑋 ∈ ℝ^{𝑛×𝑑} with i.i.d. random entries (e.g., Gaussian), 𝑛 = 𝑂(𝑠 log(𝑑/𝑠))
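A small numerical illustration (my own sketch, not from the slides): draw 𝑋 with i.i.d. N(0, 1/𝑛) entries, a standard ensemble satisfying RIP with high probability once 𝑛 ≳ 𝑠 log(𝑑/𝑠), and spot-check the ratio ||𝑋𝑤||²/||𝑤||² on random 𝑠-sparse vectors. A true RIP certificate would need a maximum over all 𝑠-sparse supports.

```python
import numpy as np

rng = np.random.default_rng(0)
d, s = 1000, 10
n = int(8 * s * np.log(d / s))                      # ~ s log(d/s) samples
X = rng.normal(0.0, 1.0 / np.sqrt(n), size=(n, d))  # i.i.d. N(0, 1/n) entries

ratios = []
for _ in range(1000):                               # spot-check random sparse vectors
    w = np.zeros(d)
    w[rng.choice(d, s, replace=False)] = rng.normal(size=s)
    ratios.append(np.linalg.norm(X @ w) ** 2 / np.linalg.norm(w) ** 2)
print(min(ratios), max(ratios))   # both close to 1, i.e., small empirical delta_s
```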
||𝑤𝑡+1 − 𝑤∗|| ≤ (3/4) ||𝑤𝑡 − 𝑤∗||
Proof?
But what if RIP is not possible?
Statistical Guarantees
𝑦𝑖 = 〈𝑥𝑖 , 𝑤 ∗ 〉 + 𝜂𝑖
• 𝑥𝑖 ∼ 𝑁(0, Σ)
• 𝜂𝑖 ∼ 𝑁(0, 𝜎 2 )
• 𝑤 ∗ : 𝑠 −sparse
||ŵ − 𝑤∗|| ≤ 𝜎 ⋅ 𝜅 ⋅ √(𝑠 log 𝑑 / 𝑛)
• 𝜅 = 𝜆1(Σ)/𝜆𝑑(Σ)
[Jain, Tewari, Kar’2014]
Proof?
• 𝑓(𝑤) = (1/2) ||𝑋(𝑤 − 𝑤∗)||₂²
• 𝑋 = [𝑥1; 𝑥2; … ; 𝑥𝑛]
• 𝑥𝑖 ∼ 𝑁(0, Σ), with 𝛼 ⋅ 𝐼𝑑×𝑑 ≼ Σ ≼ 𝐿 ⋅ 𝐼𝑑×𝑑
• 𝑤𝑡+1 = 𝑃𝑠(𝑤𝑡 − 𝜂 𝑔𝑡), 𝜂 = 2/(3𝐿)
• 𝑠 = (𝐿/𝛼)² 𝑠∗
Then: ||𝑤𝑡+1 − 𝑤∗||₂² ≤ (1 − 𝛼/(10 ⋅ 𝐿)) ||𝑤𝑡 − 𝑤∗||₂²
Proof?
General Result for Any Function
• 𝑓: ℝ^𝑑 → ℝ
• 𝑓 satisfies RSC/RSS, i.e., 𝛼𝑠 ⋅ 𝐼𝑑×𝑑 ≼ 𝐻(𝑤) ≼ 𝐿𝑠 ⋅ 𝐼𝑑×𝑑 whenever ||𝑤||0 ≤ 𝑠
• Then: ||ŵ − 𝑤∗|| ≤ 𝜖 + 𝜎𝜅 √(𝑠 log 𝑑 / 𝑛), where 𝜅 = 𝜆1(Σ)/𝜆𝑑(Σ)
𝑦 = 𝑋𝑤∗ + 𝑏
• Typical 𝑏:
a) Deterministic error: ||ŵ − 𝑤∗|| ≤ ||𝑏||
b) Gaussian error: ||ŵ − 𝑤∗|| ≤ ||𝑏||/√𝑛
Robust Regression
• ||𝑏||0 ≤ 𝛽 ⋅ 𝑛
• We want 𝛽 to be a constant
• Entries of 𝑏 can be unbounded!
• ||𝑏||2 can be arbitrarily large
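The slides only state the corruption model here; below is a hedged sketch of one natural hard-thresholding scheme in the spirit of the cited robust-regression work: alternately re-fit least squares on the points with the smallest residuals. The exact algorithm and guarantees in the paper may differ.

```python
import numpy as np

def robust_regression(X, y, beta, iters=20):
    """Sketch: alternately (i) select the (1 - beta) * n points with the
    smallest residuals and (ii) solve least squares on them only."""
    n, d = X.shape
    keep = int((1 - beta) * n)
    w = np.zeros(d)
    for _ in range(iters):
        S = np.argsort(np.abs(y - X @ w))[:keep]   # likely-uncorrupted points
        w, *_ = np.linalg.lstsq(X[S], y[S], rcond=None)
    return w
```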
[Figure: data point 𝑦 ≅ 𝐴𝑥, where 𝐴 ∈ ℝ^{𝑑×𝑟} is the dictionary and 𝑥 is an 𝑟-dim, 𝑘-sparse vector]
Dictionary Learning
𝑌 ≅ 𝐴 × 𝑋, with 𝑌 ∈ ℝ^{𝑑×𝑛}, 𝐴 ∈ ℝ^{𝑑×𝑟}, 𝑋 ∈ ℝ^{𝑟×𝑛}
• Overcomplete dictionaries: 𝑟 ≫ 𝑑
• Goal: Given 𝑌, compute 𝐴, 𝑋
• Using small number of samples 𝑛
Existing Results
• Generalization error bounds [VMB’11, MPR’12, MG’13, TRS’13]
• But assumes that the optimal solution is reached
• Do not cover exact recovery with finitely many samples
• Identifiability of 𝐴, 𝑋 [HS’11]
• Require exponentially many samples
• Exact recovery [SWW’12]
• Restricted to square dictionary (𝑑 = 𝑟)
• In practice, overcomplete dictionary (𝑑 ≪ 𝑟) is more useful
Generating Model
• Generate dictionary 𝐴
• Assume 𝐴 to be incoherent, i.e., |〈𝐴𝑖, 𝐴𝑗〉| ≤ 𝜇/√𝑑
• 𝑟≫𝑑
• Generate random samples 𝑋 = [𝑥1, 𝑥2, … , 𝑥𝑛] ∈ ℝ^{𝑟×𝑛}
• Each 𝑥𝑖 is 𝑘-sparse
• Generate observations: 𝑌 = 𝐴𝑋
Algorithm
• Typical practical algorithm: alternating minimization (a rough sketch follows below)
• 𝑋𝑡+1 = argmin_𝑋 ||𝑌 − 𝐴𝑡 𝑋||𝐹², with 𝑘-sparse columns of 𝑋
• 𝐴𝑡+1 = argmin_𝐴 ||𝑌 − 𝐴 𝑋𝑡+1||𝐹²
• Initialize 𝐴0
• Using clustering+SVD method of [AAN’13] or [AGM’13]
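A rough NumPy sketch of the alternating scheme above. The sparse-coding step here is a crude least-squares + per-column hard-thresholding heuristic and the initialization `A0` is taken as given; the cited analyses use proper sparse recovery and a clustering/SVD-based initialization.

```python
import numpy as np

def dict_learn_altmin(Y, A0, k, iters=20):
    """Alternating minimization sketch for Y ~ A X with k-sparse columns of X."""
    A = A0.copy()
    for _ in range(iters):
        # X-update: crude sparse coding (least squares, keep k largest per column)
        X, *_ = np.linalg.lstsq(A, Y, rcond=None)
        kth = -np.sort(-np.abs(X), axis=0)[k - 1]   # k-th largest magnitude per column
        X[np.abs(X) < kth] = 0.0
        # A-update: least squares in A, then renormalize columns
        A = Y @ np.linalg.pinv(X)
        A /= np.linalg.norm(A, axis=0, keepdims=True) + 1e-12
    return A, X
```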
Results [AAJNT’13]
• Assumptions:
• 𝐴 is 𝜇-incoherent (|〈𝐴𝑖, 𝐴𝑗〉| ≤ 𝜇/√𝑑, ||𝐴𝑖|| = 1)
• Nonzero entries of 𝑋 satisfy 1 ≤ |𝑋𝑖𝑗| ≤ 100
• Sparsity: 𝑘 ≤ 𝑑^{1/6}/𝜇^{1/3} (better result by AGM’13)
• 𝑛 ≥ 𝑂(𝑟² log 𝑟)
• After log(1/𝜖) steps of AltMin: ||𝐴𝑖^{(𝑇)} − 𝐴𝑖||2 ≤ 𝜖
Proof Sketch
• Initialization step ensures that:
||𝐴𝑖^{(0)} − 𝐴𝑖|| ≤ 1/(2√𝑘)
• Lower bound on each element of 𝑋𝑖𝑗 + above bound:
• 𝑠𝑢𝑝𝑝(𝑥𝑖 ) is recovered exactly
• Robustness of compressive sensing!
• 𝐴𝑡+1 can be expressed exactly as:
• 𝐴𝑡+1 = 𝐴 + 𝐸𝑟𝑟𝑜𝑟(𝐴𝑡, 𝑋𝑡 )
• Use randomness in 𝑠𝑢𝑝𝑝(𝑋𝑡 )
Simulations
Empirically: 𝑛 = 𝑂(𝑟)
Known result: 𝑛 = 𝑂(𝑟² log 𝑟)
Summary
• Consider high-dimensional structured problems
• Sparsity
• Block sparsity
• Tree-based sparsity
• Error sparsity
• Iterative hard thresholding style method
• Practical/easy to implement
• Fast convergence
• RIP/RSC/subGaussian data: Provable guarantees
http://research.microsoft.com/en-us/people/prajain/
Collaborators: Purushottam Kar, Kush Bhatia
Next Lecture
• Low-rank Structure
• Matrix Regression
• Matrix Completion
• Robust PCA
• Low-rank Tensor Structure
• Tensor completion
Block-sparse Signals
𝐲1 = Φ1 𝐱1 , 𝐲2 = Φ2 𝒙2 , … , 𝐲𝑟 = Φ𝑟 𝐱 𝑟
• Total no. of measurements: 𝑂(𝑟 ⋅ 𝑘 ⋅ log 𝑛)
• Correlated signals: |𝐽| = |𝑠𝑢𝑝𝑝(𝑥1) ∪ 𝑠𝑢𝑝𝑝(𝑥2) ∪ … ∪ 𝑠𝑢𝑝𝑝(𝑥𝑟)| ≤ 𝑘 ⋅ 𝑟
• Method: group norms (𝐿2,1 or 𝐿2,∞); a sketch of the non-convex analogue follows below
• Improvement in sample complexity if |𝐽| ≪ 𝑘 ⋅ 𝑟
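For illustration only (not the slides' method), a hedged sketch of a non-convex counterpart of these group norms: stack the signals as columns of a matrix 𝑊 and keep only the |𝐽| rows with largest ℓ2 norm, i.e., project onto the set of row-sparse (shared-support) matrices.

```python
import numpy as np

def project_row_sparse(W, J):
    """Keep the J rows of W = [x_1, ..., x_r] with largest l2 norm, zero the rest."""
    row_norms = np.linalg.norm(W, axis=1)
    keep = np.argsort(row_norms)[::-1][:J]
    out = np.zeros_like(W)
    out[keep] = W[keep]
    return out

# Example: 4 rows, 2 signals; keep the 2 strongest shared-support rows (rows 1 and 3)
W = np.array([[0.1, 0.0], [3.0, -2.0], [0.0, 0.2], [1.0, 1.5]])
print(project_row_sparse(W, J=2))
```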