Machine Learning Overview

Day 5: The Mathematical Explanation Behind the SGD Algorithm in Machine Learning

Understanding Stochastic Gradient Descent (SGD) in Machine Learning

In our previous blog post (Day 4), we applied the SGD algorithm to the MNIST dataset. But what exactly is Stochastic Gradient Descent?

Stochastic Gradient Descent (SGD) is an iterative method for optimizing an objective function that is written as a sum of differentiable functions. It’s a variant of the traditional gradient descent algorithm but with a twist: instead of computing the gradient of the whole dataset, it approximates the gradient using a single data point or a small batch of data points. This makes SGD much faster and more scalable, especially for large datasets.

Why is SGD Important?

  • Efficiency: By updating the parameters using only a subset of data, SGD reduces computation time, making it faster than batch gradient descent for large datasets.
  • Online Learning: SGD can be used in online learning scenarios where the model is updated continuously as new data comes in.
  • Convergence: Although SGD introduces more noise into the optimization process, this can help in escaping local minima and finding a better global minimum.

The SGD Algorithm

    The goal of SGD is to minimize an objective function \( J(\theta) \) with respect to the parameters \( \theta \). Here’s the general procedure:

    1. Initialize: Randomly initialize the parameters \( \theta \).
    2. Iterate: For each data point \( (x_i, y_i) \):
      • Compute the gradient of the loss function with respect to \( \theta \) using \( (x_i, y_i) \).
      • Update the parameters: \( \theta \leftarrow \theta - \eta \nabla_{\theta} J(\theta; x_i, y_i) \), where \( \eta \) is the learning rate.
    3. Repeat: Continue iterating until convergence or for a specified number of epochs.
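    The loop above can be sketched in plain Python. Here `grad(theta, x, y)` is a placeholder for whatever per-sample gradient your model defines; the function name and signature are our own:

```python
import random

def sgd(theta, data, grad, eta=0.1, epochs=1, shuffle=True, seed=0):
    """Generic SGD: one parameter update per data point per epoch.

    grad(theta, x, y) must return the gradient of the per-sample loss
    with respect to theta (a list of floats).
    """
    rng = random.Random(seed)
    data = list(data)
    theta = list(theta)
    for _ in range(epochs):
        if shuffle:
            # Visiting the points in random order is the "stochastic" part.
            rng.shuffle(data)
        for x, y in data:
            g = grad(theta, x, y)
            theta = [t - eta * g_i for t, g_i in zip(theta, g)]
    return theta
```

    In practice the data is usually reshuffled each epoch, as above; the worked example later in this post visits the points in a fixed order so the arithmetic is easy to follow.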

    A Detailed Example of SGD

    Let’s walk through a simple example of SGD using a linear regression problem with just two data points. This will help illustrate each step of the algorithm.

    Problem Setup

    We have two data points:

    (x_1, y_1) = (1, 2)
    (x_2, y_2) = (2, 3)

    We aim to find a linear function \( y = wx + b \) that best fits these data points using SGD to minimize the Mean Squared Error (MSE) loss function.

    Objective Function

    The Mean Squared Error (MSE) loss function measures how well the line \( y = wx + b \) fits the data points. It is defined as:

    \[ J(w, b) = \frac{1}{n} \sum_{i=1}^{n} (y_i - (wx_i + b))^2 \]

    Here:

    • J(w, b): The loss function we want to minimize.
    • n: The number of data points.
    • (x_i, y_i): The data points.
    • wx_i + b: The predicted value for \( x_i \).

    For our two data points, \( n = 2 \). So, the loss function becomes:

    \[ J(w, b) = \frac{1}{2} \left[ (2 - (w \cdot 1 + b))^2 + (3 - (w \cdot 2 + b))^2 \right] \]

    This loss function represents the average squared difference between the actual \( y \) values and the predicted \( y \) values. Minimizing this function means finding the line that best fits the data.
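    As a quick sanity check, this loss can be evaluated directly in Python (a minimal sketch; the function name is ours):

```python
def mse(w, b, points):
    """Mean squared error of the line y = w*x + b over (x, y) pairs."""
    return sum((y - (w * x + b)) ** 2 for x, y in points) / len(points)

points = [(1, 2), (2, 3)]

# At the starting point w = 0, b = 0 every prediction is 0,
# so J(0, 0) = (2^2 + 3^2) / 2 = 6.5
print(mse(0, 0, points))  # 6.5
```

    Note that \( w = 1 \), \( b = 1 \) gives a loss of exactly zero here, since both data points lie on the line \( y = x + 1 \).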

    Gradients

    To minimize the loss function, we need to compute its gradients (partial derivatives) with respect to the parameters \( w \) and \( b \). The gradients indicate the direction and rate of change of the loss function with respect to each parameter.

    The general formulas for the gradients of the loss function with respect to \( w \) and \( b \) are:

    \[ \frac{\partial J}{\partial w} = -\frac{2}{n} \sum_{i=1}^{n} x_i (y_i - (wx_i + b)) \]

    \[ \frac{\partial J}{\partial b} = -\frac{2}{n} \sum_{i=1}^{n} (y_i - (wx_i + b)) \]

    In Stochastic Gradient Descent (SGD), we update the parameters using only one data point at a time, rather than the entire dataset. So, for each individual data point, the gradients become:

    \[ \frac{\partial J}{\partial w} = -2 x_i (y_i - (wx_i + b)) \]

    \[ \frac{\partial J}{\partial b} = -2 (y_i - (wx_i + b)) \]

    These formulas derive from the chain rule in calculus and show how changes in \( w \) and \( b \) affect the overall error.
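    These per-sample gradients translate directly into code (a sketch; the helper name is ours):

```python
def grads(w, b, x, y):
    """Per-sample MSE gradients for y = w*x + b:
    dJ/dw = -2 * x * (y - (w*x + b)),  dJ/db = -2 * (y - (w*x + b))."""
    error = y - (w * x + b)
    return -2 * x * error, -2 * error

# At the initial w = 0, b = 0, the first data point (1, 2) gives
# dJ/dw = -2 * 1 * 2 = -4 and dJ/db = -2 * 2 = -4.
print(grads(0, 0, 1, 2))  # (-4, -4)
```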

    SGD Steps

    1. Initialize Parameters: Start with initial guesses for \( w \) and \( b \). Let’s initialize \( w = 0 \) and \( b = 0 \). We also set a learning rate \( \eta = 0.1 \).
    2. Iterate Over Data Points:
      • Compute the predicted value \( \hat{y} \) using the current values of \( w \) and \( b \).
      • Calculate the error between the actual value \( y \) and the predicted value \( \hat{y} \).
      • Compute the gradients of the loss function with respect to \( w \) and \( b \).
      • Update the parameters \( w \) and \( b \) using the computed gradients and the learning rate.

    First Epoch (First Pass Over the Data Points)

    Data Point 1: (x_1, y_1) = (1, 2)

    1. Predict the Value: \[ \hat{y}_i = w \cdot x_i + b \] Using the initial values \( w = 0 \) and \( b = 0 \): \[ \hat{y}_1 = 0 \cdot 1 + 0 = 0 \]
    2. Calculate the Error: \[ \text{Error} = y_i - \hat{y}_i \] \[ \text{Error} = 2 - 0 = 2 \]
    3. Compute the Gradients: \[ \frac{\partial J}{\partial w} = -2 \cdot x_i \cdot \text{Error} \] \[ \frac{\partial J}{\partial b} = -2 \cdot \text{Error} \] Substituting the values for the first data point: \[ \frac{\partial J}{\partial w} = -2 \cdot 1 \cdot 2 = -4 \] \[ \frac{\partial J}{\partial b} = -2 \cdot 2 = -4 \]

      These gradients tell us how much the parameters \( w \) and \( b \) need to change to reduce the error.

    4. Update the Parameters: \[ w \leftarrow w - \eta \cdot \frac{\partial J}{\partial w} \] \[ b \leftarrow b - \eta \cdot \frac{\partial J}{\partial b} \] Applying the updates: \[ w \leftarrow 0 - 0.1 \cdot (-4) = 0.4 \] \[ b \leftarrow 0 - 0.1 \cdot (-4) = 0.4 \]

      The learning rate \( \eta \) controls how much we adjust the parameters in each step. A smaller learning rate means smaller adjustments, and a larger learning rate means larger adjustments.

    Data Point 2: (x_2, y_2) = (2, 3)

    1. Predict the Value: Using the updated values \( w = 0.4 \) and \( b = 0.4 \): \[ \hat{y}_2 = w \cdot x_2 + b \] \[ \hat{y}_2 = 0.4 \cdot 2 + 0.4 = 1.2 \]
    2. Calculate the Error: \[ \text{Error} = y_2 - \hat{y}_2 \] \[ \text{Error} = 3 - 1.2 = 1.8 \]
    3. Compute the Gradients: Substituting the values for the second data point: \[ \frac{\partial J}{\partial w} = -2 \cdot 2 \cdot 1.8 = -7.2 \] \[ \frac{\partial J}{\partial b} = -2 \cdot 1.8 = -3.6 \] These gradients indicate how much the parameters \( w \) and \( b \) need to change to reduce the error further.
    4. Update the Parameters: Applying the updates: \[ w \leftarrow w - \eta \cdot \frac{\partial J}{\partial w} \] \[ w \leftarrow 0.4 - 0.1 \cdot (-7.2) = 1.12 \] \[ b \leftarrow b - \eta \cdot \frac{\partial J}{\partial b} \] \[ b \leftarrow 0.4 - 0.1 \cdot (-3.6) = 0.76 \]
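    The first epoch can be replayed in a few lines of Python (variable names are ours):

```python
w, b, eta = 0.0, 0.0, 0.1

for x, y in [(1, 2), (2, 3)]:    # one pass over the data = one epoch
    error = y - (w * x + b)      # prediction error for this point
    w -= eta * (-2 * x * error)  # dJ/dw = -2 * x * error
    b -= eta * (-2 * error)      # dJ/db = -2 * error

# After the pass: w ≈ 1.12, b ≈ 0.76, matching the hand calculation above.
```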

    Second Epoch (Second Pass Over the Data Points)

    Data Point 1: (x_1, y_1) = (1, 2)

    1. Predict the Value: Using the updated values \( w = 1.12 \) and \( b = 0.76 \): \[ \hat{y}_1 = w \cdot x_1 + b \] \[ \hat{y}_1 = 1.12 \cdot 1 + 0.76 = 1.88 \]
    2. Calculate the Error: \[ \text{Error} = y_1 - \hat{y}_1 \] \[ \text{Error} = 2 - 1.88 = 0.12 \]
    3. Compute the Gradients: Substituting the values for the first data point: \[ \frac{\partial J}{\partial w} = -2 \cdot 1 \cdot 0.12 = -0.24 \] \[ \frac{\partial J}{\partial b} = -2 \cdot 0.12 = -0.24 \]
    4. Update the Parameters: Applying the updates: \[ w \leftarrow w - \eta \cdot \frac{\partial J}{\partial w} \] \[ w \leftarrow 1.12 - 0.1 \cdot (-0.24) = 1.144 \] \[ b \leftarrow b - \eta \cdot \frac{\partial J}{\partial b} \] \[ b \leftarrow 0.76 - 0.1 \cdot (-0.24) = 0.784 \]

    Data Point 2: (x_2, y_2) = (2, 3)

    1. Predict the Value: Using the updated values \( w = 1.144 \) and \( b = 0.784 \): \[ \hat{y}_2 = w \cdot x_2 + b \] \[ \hat{y}_2 = 1.144 \cdot 2 + 0.784 = 3.072 \]
    2. Calculate the Error: \[ \text{Error} = y_2 - \hat{y}_2 \] \[ \text{Error} = 3 - 3.072 = -0.072 \]
    3. Compute the Gradients: Substituting the values for the second data point: \[ \frac{\partial J}{\partial w} = -2 \cdot 2 \cdot (-0.072) = 0.288 \] \[ \frac{\partial J}{\partial b} = -2 \cdot (-0.072) = 0.144 \]
    4. Update the Parameters: Applying the updates: \[ w \leftarrow w - \eta \cdot \frac{\partial J}{\partial w} \] \[ w \leftarrow 1.144 - 0.1 \cdot (0.288) = 1.1152 \] \[ b \leftarrow b - \eta \cdot \frac{\partial J}{\partial b} \] \[ b \leftarrow 0.784 - 0.1 \cdot (0.144) = 0.7696 \]

    Summary of Updates

    After two epochs (two passes over the data points), the parameters \( w \) and \( b \) are updated as follows:

    • After the first epoch: \( w = 1.12 \), \( b = 0.76 \)
    • After the second epoch: \( w = 1.1152 \), \( b = 0.7696 \)

    By continuing this process for more epochs, the parameters converge toward the exact fit \( w = 1 \), \( b = 1 \) (both points lie on the line \( y = x + 1 \)), which minimizes the MSE loss.
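    The entire walkthrough fits in a short script that reproduces the numbers above (a sketch; names are ours):

```python
def run_sgd(points, w=0.0, b=0.0, eta=0.1, epochs=2):
    """One-point-at-a-time SGD for y = w*x + b under MSE loss,
    visiting the points in a fixed order. Returns (w, b) after each epoch."""
    history = []
    for _ in range(epochs):
        for x, y in points:
            error = y - (w * x + b)
            w -= eta * (-2 * x * error)  # dJ/dw for this point
            b -= eta * (-2 * error)      # dJ/db for this point
        history.append((w, b))
    return history

history = run_sgd([(1, 2), (2, 3)])
# history[0] ≈ (1.12, 0.76) and history[1] ≈ (1.1152, 0.7696), as derived above.
# With more epochs, (w, b) approaches the exact fit (1, 1).
```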

    Conclusion

    This detailed breakdown shows each step of Stochastic Gradient Descent (SGD) for a simple linear regression problem, explaining the formulas used and their purpose. SGD updates the model parameters iteratively based on the gradients computed from individual data points, making it efficient for large-scale data and suitable for online learning. Understanding these steps and formulas helps in comprehending how SGD works to minimize the error and improve the model.

    Second Epoch: Summary Tables

    Data Point 1: (x_1, y_1) = (1, 2)

    | Step | Formula | Value |
    | --- | --- | --- |
    | Predict the Value | \(\hat{y}_1 = w \cdot x_1 + b\) | 1.88 |
    | Calculate the Error | \(\text{Error} = y_1 - \hat{y}_1\) | 0.12 |
    | Compute the Gradient (w) | \(\frac{\partial J}{\partial w} = -2 \cdot x_1 \cdot \text{Error}\) | -0.24 |
    | Compute the Gradient (b) | \(\frac{\partial J}{\partial b} = -2 \cdot \text{Error}\) | -0.24 |
    | Update w | \(w \leftarrow w - \eta \cdot \frac{\partial J}{\partial w}\) | 1.144 |
    | Update b | \(b \leftarrow b - \eta \cdot \frac{\partial J}{\partial b}\) | 0.784 |

    Data Point 2: (x_2, y_2) = (2, 3)

    | Step | Formula | Value |
    | --- | --- | --- |
    | Predict the Value | \(\hat{y}_2 = w \cdot x_2 + b\) | 3.072 |
    | Calculate the Error | \(\text{Error} = y_2 - \hat{y}_2\) | -0.072 |
    | Compute the Gradient (w) | \(\frac{\partial J}{\partial w} = -2 \cdot x_2 \cdot \text{Error}\) | 0.288 |
    | Compute the Gradient (b) | \(\frac{\partial J}{\partial b} = -2 \cdot \text{Error}\) | 0.144 |
    | Update w | \(w \leftarrow w - \eta \cdot \frac{\partial J}{\partial w}\) | 1.1152 |
    | Update b | \(b \leftarrow b - \eta \cdot \frac{\partial J}{\partial b}\) | 0.7696 |