Understanding Stochastic Gradient Descent (SGD) in Machine Learning
In our previous blog post (day 4), we talked about using the SGD algorithm on the MNIST dataset. But what exactly is Stochastic Gradient Descent?
Stochastic Gradient Descent (SGD) is an iterative method for optimizing an objective function that is written as a sum of differentiable functions. It’s a variant of the traditional gradient descent algorithm but with a twist: instead of computing the gradient of the whole dataset, it approximates the gradient using a single data point or a small batch of data points. This makes SGD much faster and more scalable, especially for large datasets.
Why is SGD Important?
Because it approximates the gradient from a single data point or a small batch instead of the full dataset, each SGD update is cheap to compute. This makes it practical for large-scale datasets and for online learning, where the model is updated as new data arrives.
The SGD Algorithm
The goal of SGD is to minimize an objective function \( J(\theta) \) with respect to the parameters \( \theta \). Here’s the general procedure:
- Initialize: Randomly initialize the parameters \( \theta \).
- Iterate: For each data point \( (x_i, y_i) \):
  - Compute the gradient of the loss function with respect to \( \theta \) using \( (x_i, y_i) \).
  - Update the parameters: \( \theta \leftarrow \theta - \eta \nabla_{\theta} J(\theta; x_i, y_i) \), where \( \eta \) is the learning rate.
- Repeat: Continue iterating until convergence or for a specified number of epochs.
A Detailed Example of SGD
Let’s walk through a simple example of SGD using a linear regression problem with just two data points. This will help illustrate each step of the algorithm.
Problem Setup
We have two data points:
\( (x_1, y_1) = (1, 2) \) and \( (x_2, y_2) = (2, 3) \)
We aim to find a linear function \( y = wx + b \) that best fits these data points using SGD to minimize the Mean Squared Error (MSE) loss function.
Objective Function
The Mean Squared Error (MSE) loss function measures how well the line \( y = wx + b \) fits the data points. It is defined as:
\[ J(w, b) = \frac{1}{n} \sum_{i=1}^{n} (y_i - (wx_i + b))^2 \]
Here:
- \( J(w, b) \): The loss function we want to minimize.
- \( n \): The number of data points.
- \( (x_i, y_i) \): The data points.
- \( wx_i + b \): The predicted value for \( x_i \).
For our two data points, \( n = 2 \). So, the loss function becomes:
\[ J(w, b) = \frac{1}{2} \left[ (2 - (w \cdot 1 + b))^2 + (3 - (w \cdot 2 + b))^2 \right] \]
This loss function represents the average squared difference between the actual \( y \) values and the predicted \( y \) values. Minimizing this function means finding the line that best fits the data.
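As a sanity check, the loss can be computed directly. The sketch below uses the two data points from the example:

```python
def mse_loss(w, b, points):
    """Mean squared error of the line y = w*x + b over the data points."""
    return sum((y - (w * x + b)) ** 2 for x, y in points) / len(points)

points = [(1, 2), (2, 3)]
# With w = 0 and b = 0 every prediction is 0, so the loss is (2**2 + 3**2) / 2 = 6.5
```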
Gradients
To minimize the loss function, we need to compute its gradients (partial derivatives) with respect to the parameters \( w \) and \( b \). The gradients indicate the direction and rate of change of the loss function with respect to each parameter.
The general formulas for the gradients of the loss function with respect to \( w \) and \( b \) are:
\[ \frac{\partial J}{\partial w} = -\frac{2}{n} \sum_{i=1}^{n} x_i (y_i - (wx_i + b)) \]
\[ \frac{\partial J}{\partial b} = -\frac{2}{n} \sum_{i=1}^{n} (y_i - (wx_i + b)) \]
In Stochastic Gradient Descent (SGD), we update the parameters using only one data point at a time, rather than the entire dataset. So, for each individual data point, the gradients become:
\[ \frac{\partial J}{\partial w} = -2 x_i (y_i - (wx_i + b)) \]
\[ \frac{\partial J}{\partial b} = -2 (y_i - (wx_i + b)) \]
These formulas derive from the chain rule in calculus and show how changes in \( w \) and \( b \) affect the overall error.
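The per-point gradient formulas translate directly into code. A small sketch:

```python
def point_gradients(w, b, x_i, y_i):
    """Gradients of the squared error for a single data point (x_i, y_i)."""
    error = y_i - (w * x_i + b)  # prediction error for this point
    grad_w = -2 * x_i * error    # dJ/dw for this single point
    grad_b = -2 * error          # dJ/db for this single point
    return grad_w, grad_b
```

These are exactly the quantities computed by hand in the steps below.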
SGD Steps
- Initialize Parameters: Start with initial guesses for \( w \) and \( b \). Let’s initialize \( w = 0 \) and \( b = 0 \). We also set a learning rate \( \eta = 0.1 \).
- Iterate Over Data Points: For each data point:
  - Compute the predicted value \( \hat{y} \) using the current values of \( w \) and \( b \).
  - Calculate the error between the actual value \( y \) and the predicted value \( \hat{y} \).
  - Compute the gradients of the loss function with respect to \( w \) and \( b \).
  - Update the parameters \( w \) and \( b \) using the computed gradients and the learning rate.
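The whole procedure fits in a few lines of Python. This sketch reproduces the hand calculation that follows, visiting the two points in order for two epochs:

```python
def run_sgd(points, w=0.0, b=0.0, eta=0.1, epochs=2):
    """SGD for the line y = w*x + b, visiting the points in order each epoch."""
    for _ in range(epochs):
        for x_i, y_i in points:
            error = y_i - (w * x_i + b)    # prediction error for this point
            w -= eta * (-2 * x_i * error)  # gradient step for w
            b -= eta * (-2 * error)        # gradient step for b
    return w, b

w, b = run_sgd([(1, 2), (2, 3)])
# After two epochs w and b land at about 1.1152 and 0.7696,
# matching the hand calculation below.
```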
First Epoch (First Pass Over the Data Points)
Data Point 1: \( (x_1, y_1) = (1, 2) \)
- Predict the Value: \[ \hat{y}_i = w \cdot x_i + b \] Using the initial values \( w = 0 \) and \( b = 0 \): \[ \hat{y}_1 = 0 \cdot 1 + 0 = 0 \]
- Calculate the Error: \[ \text{Error} = y_i - \hat{y}_i \] \[ \text{Error} = 2 - 0 = 2 \]
- Compute the Gradients: \[ \frac{\partial J}{\partial w} = -2 \cdot x_i \cdot \text{Error} \] \[ \frac{\partial J}{\partial b} = -2 \cdot \text{Error} \] Substituting the values for the first data point: \[ \frac{\partial J}{\partial w} = -2 \cdot 1 \cdot 2 = -4 \] \[ \frac{\partial J}{\partial b} = -2 \cdot 2 = -4 \]
These gradients tell us how much the parameters \( w \) and \( b \) need to change to reduce the error.
- Update the Parameters: \[ w \leftarrow w - \eta \cdot \frac{\partial J}{\partial w} \] \[ b \leftarrow b - \eta \cdot \frac{\partial J}{\partial b} \] Applying the updates: \[ w \leftarrow 0 - 0.1 \cdot (-4) = 0.4 \] \[ b \leftarrow 0 - 0.1 \cdot (-4) = 0.4 \]
The learning rate \( \eta \) controls how much we adjust the parameters in each step. A smaller learning rate means smaller adjustments, and a larger learning rate means larger adjustments.
Data Point 2: \( (x_2, y_2) = (2, 3) \)
- Predict the Value: Using the updated values \( w = 0.4 \) and \( b = 0.4 \): \[ \hat{y}_2 = w \cdot x_2 + b \] \[ \hat{y}_2 = 0.4 \cdot 2 + 0.4 = 1.2 \]
- Calculate the Error: \[ \text{Error} = y_2 - \hat{y}_2 \] \[ \text{Error} = 3 - 1.2 = 1.8 \]
- Compute the Gradients: Substituting the values for the second data point: \[ \frac{\partial J}{\partial w} = -2 \cdot 2 \cdot 1.8 = -7.2 \] \[ \frac{\partial J}{\partial b} = -2 \cdot 1.8 = -3.6 \] These gradients indicate how much the parameters \( w \) and \( b \) need to change to reduce the error further.
- Update the Parameters: Applying the updates: \[ w \leftarrow w - \eta \cdot \frac{\partial J}{\partial w} \] \[ w \leftarrow 0.4 - 0.1 \cdot (-7.2) = 1.12 \] \[ b \leftarrow b - \eta \cdot \frac{\partial J}{\partial b} \] \[ b \leftarrow 0.4 - 0.1 \cdot (-3.6) = 0.76 \]
Second Epoch (Second Pass Over the Data Points)
Data Point 1: \( (x_1, y_1) = (1, 2) \)
- Predict the Value: Using the updated values \( w = 1.12 \) and \( b = 0.76 \): \[ \hat{y}_1 = w \cdot x_1 + b \] \[ \hat{y}_1 = 1.12 \cdot 1 + 0.76 = 1.88 \]
- Calculate the Error: \[ \text{Error} = y_1 - \hat{y}_1 \] \[ \text{Error} = 2 - 1.88 = 0.12 \]
- Compute the Gradients: Substituting the values for the first data point: \[ \frac{\partial J}{\partial w} = -2 \cdot 1 \cdot 0.12 = -0.24 \] \[ \frac{\partial J}{\partial b} = -2 \cdot 0.12 = -0.24 \]
- Update the Parameters: Applying the updates: \[ w \leftarrow w - \eta \cdot \frac{\partial J}{\partial w} \] \[ w \leftarrow 1.12 - 0.1 \cdot (-0.24) = 1.144 \] \[ b \leftarrow b - \eta \cdot \frac{\partial J}{\partial b} \] \[ b \leftarrow 0.76 - 0.1 \cdot (-0.24) = 0.784 \]
Data Point 2: \( (x_2, y_2) = (2, 3) \)
- Predict the Value: Using the updated values \( w = 1.144 \) and \( b = 0.784 \): \[ \hat{y}_2 = w \cdot x_2 + b \] \[ \hat{y}_2 = 1.144 \cdot 2 + 0.784 = 3.072 \]
- Calculate the Error: \[ \text{Error} = y_2 - \hat{y}_2 \] \[ \text{Error} = 3 - 3.072 = -0.072 \]
- Compute the Gradients: Substituting the values for the second data point: \[ \frac{\partial J}{\partial w} = -2 \cdot 2 \cdot (-0.072) = 0.288 \] \[ \frac{\partial J}{\partial b} = -2 \cdot (-0.072) = 0.144 \]
- Update the Parameters: Applying the updates: \[ w \leftarrow w - \eta \cdot \frac{\partial J}{\partial w} \] \[ w \leftarrow 1.144 - 0.1 \cdot 0.288 = 1.1152 \] \[ b \leftarrow b - \eta \cdot \frac{\partial J}{\partial b} \] \[ b \leftarrow 0.784 - 0.1 \cdot 0.144 = 0.7696 \]
Summary of Updates
After two epochs (two passes over the data points), the parameters \( w \) and \( b \) are updated as follows:
- After the first epoch: \( w = 1.12 \), \( b = 0.76 \)
- After the second epoch: \( w = 1.1152 \), \( b = 0.7696 \)
By continuing this process for more epochs, the parameters converge toward the optimal solution that minimizes the MSE loss. In this example both data points lie exactly on the line \( y = x + 1 \), so \( w \) and \( b \) both approach 1.
Conclusion
This detailed breakdown shows each step of Stochastic Gradient Descent (SGD) for a simple linear regression problem, explaining the formulas used and their purpose. SGD updates the model parameters iteratively based on the gradients computed from individual data points, making it efficient for large-scale data and suitable for online learning. Understanding these steps and formulas helps in comprehending how SGD works to minimize the error and improve the model.