-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Updated the gradient_based_optimization tutorial with my understand…
…ing of what is going on with an explanation at a level that an engineer with a bachelors in comp sci understands - Added/updated comments for the gradient_based_optimization tutorial - Shamelessly added my name to the bottom of that tutorial so that I have something to point to for the time spend with leadership. Feel free to delete it if my explanation is terrible :-p - Added comments to initl.py Signed-off-by: Grant Curell [email protected]
- Loading branch information
1 parent
85449c0
commit fd7c474
Showing
4 changed files
with
229 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
# Defining a simple quadratic function f(x) = x^2 | ||
def f(x): | ||
return x ** 2 | ||
|
||
# Derivative of the function f(x) = x^2 | ||
def df(x): | ||
return 2 * x | ||
|
||
# Initial parameter and learning rate | ||
x_start = 4.5 | ||
learning_rate = 0.15 | ||
# Gradient descent iterations | ||
iterations = 15 | ||
|
||
# Storing the progression of x values for plotting | ||
x_progression = [x_start] | ||
y_progression = [f(x_start)] | ||
|
||
# Performing the gradient descent | ||
for _ in range(iterations): | ||
x_gradient = df(x_progression[-1]) | ||
x_next = x_progression[-1] - learning_rate * x_gradient | ||
x_progression.append(x_next) | ||
y_progression.append(f(x_next)) | ||
|
||
# Creating a range of x values for plotting the function | ||
x_values = np.linspace(-5, 5, 100) | ||
y_values = f(x_values) | ||
|
||
# Plotting the function f(x) | ||
plt.plot(x_values, y_values, label=r'$f(x) = x^2$') | ||
|
||
# Plotting the gradient descent progression | ||
plt.scatter(x_progression, y_progression, color='red', zorder=5) | ||
plt.plot(x_progression, y_progression, linestyle='--', color='red', label='Gradient Descent') | ||
|
||
# Adding details to the plot | ||
plt.title('2D Graph Illustrating Gradient Descent') | ||
plt.xlabel('x') | ||
plt.ylabel('f(x)') | ||
plt.legend() | ||
plt.grid(True) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.