Skip to content

Commit

Permalink
update sgd section
Browse files Browse the repository at this point in the history
  • Loading branch information
nsreddy16 committed Oct 21, 2024
1 parent 3bc655c commit 9be6715
Show file tree
Hide file tree
Showing 85 changed files with 904 additions and 974 deletions.
28 changes: 14 additions & 14 deletions docs/constant_model_loss_transformations/loss_transformations.html
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ <h3 data-number="11.1.2" class="anchored" data-anchor-id="comparing-two-differen
</table>
<p>(Notice how the points for our SLR scatter plot are visually not a great linear fit. We’ll come back to this).</p>
<p>The code for generating the graphs and models is included below, but we won’t go over it in too much depth.</p>
<div id="5b2235c2" class="cell" data-execution_count="1">
<div id="16bb6d33" class="cell" data-execution_count="1">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
Expand All @@ -492,7 +492,7 @@ <h3 data-number="11.1.2" class="anchored" data-anchor-id="comparing-two-differen
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>data_linear <span class="op">=</span> dugongs[[<span class="st">"Length"</span>, <span class="st">"Age"</span>]]</span></code><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></pre></div>
</details>
</div>
<div id="8bc9c50a" class="cell" data-execution_count="2">
<div id="554d069c" class="cell" data-execution_count="2">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Big font helper</span></span>
Expand All @@ -514,7 +514,7 @@ <h3 data-number="11.1.2" class="anchored" data-anchor-id="comparing-two-differen
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>plt.style.use(<span class="st">"default"</span>) <span class="co"># Revert style to default mpl</span></span></code><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></pre></div>
</details>
</div>
<div id="6bf5646d" class="cell" data-execution_count="3">
<div id="a66b6837" class="cell" data-execution_count="3">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Constant Model + MSE</span></span>
Expand Down Expand Up @@ -547,7 +547,7 @@ <h3 data-number="11.1.2" class="anchored" data-anchor-id="comparing-two-differen
</div>
</div>
</div>
<div id="93197b9e" class="cell" data-execution_count="4">
<div id="04eb9823" class="cell" data-execution_count="4">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="co"># SLR + MSE</span></span>
Expand Down Expand Up @@ -610,7 +610,7 @@ <h3 data-number="11.1.2" class="anchored" data-anchor-id="comparing-two-differen
</div>
</div>
</div>
<div id="58131efb" class="cell" data-execution_count="5">
<div id="d2609d98" class="cell" data-execution_count="5">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Predictions</span></span>
Expand All @@ -622,7 +622,7 @@ <h3 data-number="11.1.2" class="anchored" data-anchor-id="comparing-two-differen
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>yhats_linear <span class="op">=</span> [theta_0_hat <span class="op">+</span> theta_1_hat <span class="op">*</span> x <span class="cf">for</span> x <span class="kw">in</span> xs]</span></code><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></pre></div>
</details>
</div>
<div id="f6e42346" class="cell" data-execution_count="6">
<div id="1164e17e" class="cell" data-execution_count="6">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Constant Model Rug Plot</span></span>
Expand Down Expand Up @@ -652,7 +652,7 @@ <h3 data-number="11.1.2" class="anchored" data-anchor-id="comparing-two-differen
</div>
</div>
</div>
<div id="fd949995" class="cell" data-execution_count="7">
<div id="7143fdf6" class="cell" data-execution_count="7">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># SLR model scatter plot </span></span>
Expand Down Expand Up @@ -766,15 +766,15 @@ <h2 data-number="11.3" class="anchored" data-anchor-id="summary-loss-optimizatio
<h2 data-number="11.4" class="anchored" data-anchor-id="comparing-loss-functions"><span class="header-section-number">11.4</span> Comparing Loss Functions</h2>
<p>We’ve now tried our hand at fitting a model under both MSE and MAE cost functions. How do the two results compare?</p>
<p>Let’s consider a dataset where each entry represents the number of drinks sold at a bubble tea store each day. We’ll fit a constant model to predict the number of drinks that will be sold tomorrow.</p>
<div id="2b1da68d" class="cell" data-execution_count="8">
<div id="4c550f61" class="cell" data-execution_count="8">
<div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a>drinks <span class="op">=</span> np.array([<span class="dv">20</span>, <span class="dv">21</span>, <span class="dv">22</span>, <span class="dv">29</span>, <span class="dv">33</span>])</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>drinks</span></code><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></pre></div>
<div class="cell-output cell-output-display" data-execution_count="8">
<pre><code>array([20, 21, 22, 29, 33])</code></pre>
</div>
</div>
<p>From our derivations above, we know that the optimal model parameter under MSE cost is the mean of the dataset. Under MAE cost, the optimal parameter is the median of the dataset.</p>
<div id="420bce20" class="cell" data-execution_count="9">
<div id="dfadd5d4" class="cell" data-execution_count="9">
<div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a>np.mean(drinks), np.median(drinks)</span></code><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></pre></div>
<div class="cell-output cell-output-display" data-execution_count="9">
<pre><code>(np.float64(25.0), np.float64(22.0))</code></pre>
Expand All @@ -784,7 +784,7 @@ <h2 data-number="11.4" class="anchored" data-anchor-id="comparing-loss-functions
<p><img src="images/error.png" alt="error" width="600"></p>
<p>Notice that the MSE above is a <strong>smooth</strong> function – it is differentiable at all points, making it easy to minimize using numerical methods. The MAE, in contrast, is not differentiable at each of its “kinks.” We’ll explore how the smoothness of the cost function can impact our ability to apply numerical optimization in a few weeks.</p>
<p>How do outliers affect each cost function? Imagine we replace the largest value in the dataset with 1000. The mean of the data increases substantially, while the median is nearly unaffected.</p>
<div id="f63ccf25" class="cell" data-execution_count="10">
<div id="e4129863" class="cell" data-execution_count="10">
<div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a>drinks_with_outlier <span class="op">=</span> np.append(drinks, <span class="dv">1033</span>)</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>display(drinks_with_outlier)</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>np.mean(drinks_with_outlier), np.median(drinks_with_outlier)</span></code><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></pre></div>
Expand All @@ -798,7 +798,7 @@ <h2 data-number="11.4" class="anchored" data-anchor-id="comparing-loss-functions
<p><img src="images/outliers.png" alt="outliers" width="700"></p>
<p>This means that under the MSE, the optimal model parameter <span class="math inline">\(\hat{\theta}\)</span> is strongly affected by the presence of outliers. Under the MAE, the optimal parameter is not as influenced by outlying data. We can generalize this by saying that the MSE is <strong>sensitive</strong> to outliers, while the MAE is <strong>robust</strong> to outliers.</p>
<p>Let’s try another experiment. This time, we’ll add an additional, non-outlying datapoint to the data.</p>
<div id="e2427f8c" class="cell" data-execution_count="11">
<div id="1c99a887" class="cell" data-execution_count="11">
<div class="sourceCode cell-code" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a>drinks_with_additional_observation <span class="op">=</span> np.append(drinks, <span class="dv">35</span>)</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>drinks_with_additional_observation</span></code><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></pre></div>
<div class="cell-output cell-output-display" data-execution_count="11">
Expand Down Expand Up @@ -870,7 +870,7 @@ <h2 data-number="11.5" class="anchored" data-anchor-id="transformations-to-fit-l
</ul>
<p>Other goals in addition to linearity are possible, for example, making data appear more symmetric. Linearity allows us to fit lines to the transformed data.</p>
<p>Let’s revisit our dugongs example. The lengths and ages are plotted below:</p>
<div id="df9473bd" class="cell" data-execution_count="12">
<div id="e9258db0" class="cell" data-execution_count="12">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="co"># `corrcoef` computes the correlation coefficient between two variables</span></span>
Expand Down Expand Up @@ -902,7 +902,7 @@ <h2 data-number="11.5" class="anchored" data-anchor-id="transformations-to-fit-l
<p>Looking at the plot on the left, we see that there is a slight curvature to the data points. Plotting the SLR curve on the right results in a poor fit.</p>
<p>For SLR to perform well, we’d like there to be a rough linear trend relating <code>"Age"</code> and <code>"Length"</code>. What is making the raw data deviate from a linear relationship? Notice that the data points with <code>"Length"</code> greater than 2.6 have disproportionately high values of <code>"Age"</code> relative to the rest of the data. If we could manipulate these data points to have lower <code>"Age"</code> values, we’d “shift” these points downwards and reduce the curvature in the data. Applying a logarithmic transformation to <span class="math inline">\(y_i\)</span> (that is, taking <span class="math inline">\(\log(\)</span> <code>"Age"</code> <span class="math inline">\()\)</span> ) would achieve just that.</p>
<p>An important word on <span class="math inline">\(\log\)</span>: in Data 100 (and most upper-division STEM courses), <span class="math inline">\(\log\)</span> denotes the natural logarithm with base <span class="math inline">\(e\)</span>. The base-10 logarithm, where relevant, is indicated by <span class="math inline">\(\log_{10}\)</span>.</p>
<div id="c90a089f" class="cell" data-execution_count="13">
<div id="ba4042f3" class="cell" data-execution_count="13">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a>z <span class="op">=</span> np.log(y)</span>
Expand Down Expand Up @@ -937,7 +937,7 @@ <h2 data-number="11.5" class="anchored" data-anchor-id="transformations-to-fit-l
<p><span class="math display">\[\log{(y)} = \theta_0 + \theta_1 x\]</span> <span class="math display">\[y = e^{\theta_0 + \theta_1 x}\]</span> <span class="math display">\[y = (e^{\theta_0})e^{\theta_1 x}\]</span> <span class="math display">\[y_i = C e^{k x}\]</span></p>
<p>For some constants <span class="math inline">\(C\)</span> and <span class="math inline">\(k\)</span>.</p>
<p><span class="math inline">\(y\)</span> is an <em>exponential</em> function of <span class="math inline">\(x\)</span>. Applying an exponential fit to the untransformed variables corroborates this finding.</p>
<div id="39aaefe9" class="cell" data-execution_count="14">
<div id="182e7e2b" class="cell" data-execution_count="14">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a>plt.figure(dpi<span class="op">=</span><span class="dv">120</span>, figsize<span class="op">=</span>(<span class="dv">4</span>, <span class="dv">3</span>))</span>
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit 9be6715

Please sign in to comment.