diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index c7f92b86..0f4f17b3 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -84,7 +84,7 @@ def reparametrize( module: nn.Module, named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]], allow_missing: bool = False, -) -> Generator[nn.Module, None, None]: +) -> Generator[nn.Module]: """Reparameterize the module parameters and/or buffers.""" if not isinstance(named_tensors, dict): named_tensors = dict(named_tensors) diff --git a/tutorials/1_Functional_Optimizer.ipynb b/tutorials/1_Functional_Optimizer.ipynb index afc55f38..3f21465b 100644 --- a/tutorials/1_Functional_Optimizer.ipynb +++ b/tutorials/1_Functional_Optimizer.ipynb @@ -1,588 +1,539 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# TorchOpt as Functional Optimizer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/1_Functional_Optimizer.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this tutorial, we will introduce how TorchOpt can be treated as functional optimizer to conduct normal optimization with functional programming style. We will also illustrate how to conduct differentiable optimization with functional programming in PyTorch." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Basic API\n", - "\n", - "In this first part, we will illustrate how TorchOpt can be used as a functional optimizer. We compare it with different API in [JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org) to help understand the similarity and dissimilarity. We use simple network, Adam optimizer and MSE loss objective." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from collections import OrderedDict\n", - "\n", - "import functorch\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import optax\n", - "import torch\n", - "import torch.autograd\n", - "import torch.nn as nn\n", - "\n", - "import torchopt\n", - "\n", - "\n", - "class Net(nn.Module):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, 1, bias=True)\n", - " nn.init.ones_(self.fc.weight)\n", - " nn.init.zeros_(self.fc.bias)\n", - "\n", - " def forward(self, x):\n", - " return self.fc(x)\n", - "\n", - "\n", - "def mse(inputs, targets):\n", - " return ((inputs - targets) ** 2).mean()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.1 Original JAX implementation\n", - "\n", - "The first example is JAX implementation coupled with [Optax](https://github.com/deepmind/optax), which belongs to functional programming style." 
- ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "def origin_jax():\n", - " batch_size = 1\n", - " dim = 1\n", - " params = OrderedDict([('weight', jnp.ones((dim, 1))), ('bias', jnp.zeros((1,)))])\n", - "\n", - " def model(params, x):\n", - " return jnp.matmul(x, params['weight']) + params['bias']\n", - "\n", - " # Obtain the `opt_state` that contains statistics for the optimizer\n", - " learning_rate = 1.0\n", - " optimizer = optax.adam(learning_rate)\n", - " opt_state = optimizer.init(params)\n", - "\n", - " def compute_loss(params, x, y):\n", - " pred = model(params, x)\n", - " return mse(pred, y)\n", - "\n", - " xs = 2 * jnp.ones((batch_size, dim))\n", - " ys = jnp.ones((batch_size, 1))\n", - "\n", - " grads = jax.grad(compute_loss)(params, xs, ys)\n", - " updates, opt_state = optimizer.update(grads, opt_state)\n", - "\n", - " print('Parameters before update:', params)\n", - " params = optax.apply_updates(params, updates)\n", - " print('Parameters after update:', params)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "OrderedDict([\n", - " ('weight', DeviceArray([[1.]], dtype=float32)),\n", - " ('bias', DeviceArray([0.], dtype=float32))\n", - "])\n", - "Parameters after update:\n", - "OrderedDict([\n", - " ('weight', DeviceArray([[6.735325e-06]], dtype=float32)),\n", - " ('bias', DeviceArray([-0.99999326], dtype=float32))\n", - "])\n" - ] - } - ], - "source": [ - "origin_jax()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.2 `functorch` with TorchOpt\n", - "\n", - "The second example is [`functorch`](https://pytorch.org/functorch) coupled with TorchOpt. It basically follows the same structure with the JAX example." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def interact_with_functorch():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - " model, params = functorch.make_functional(net) # get the functional version of the model\n", - "\n", - " # Obtain the `opt_state` that contains statistics for the optimizer\n", - " learning_rate = 1.0\n", - " optimizer = torchopt.adam(learning_rate)\n", - " opt_state = optimizer.init(params)\n", - "\n", - " xs = 2 * torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = model(params, xs)\n", - " loss = mse(pred, ys)\n", - "\n", - " grads = torch.autograd.grad(loss, params)\n", - " updates, opt_state = optimizer.update(grads, opt_state)\n", - "\n", - " print('Parameters before update:', params)\n", - " params = torchopt.apply_updates(params, updates)\n", - " print('Parameters after update:', params)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "(\n", - " Parameter containing: tensor([[1.]], requires_grad=True),\n", - " Parameter containing: tensor([0.], requires_grad=True)\n", - ")\n", - "Parameters after update:\n", - "(\n", - " Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", - " Parameter containing: tensor([-1.0000], requires_grad=True)\n", - ")\n" - ] - } - ], - "source": [ - "interact_with_functorch()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "TorchOpt also offers a wrapper `torchopt.FuncOptimizer` to make it easier to maintain the optimizer states." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "def interact_with_functorch_with_wrapper():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - " model, params = functorch.make_functional(net) # get the functional version of the model\n", - "\n", - " learning_rate = 1.0\n", - " optimizer = torchopt.FuncOptimizer(torchopt.adam(learning_rate))\n", - "\n", - " xs = 2 * torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = model(params, xs)\n", - " loss = mse(pred, ys)\n", - "\n", - " print('Parameters before update:', params)\n", - " params = optimizer.step(loss, params)\n", - " print('Parameters after update:', params)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "(\n", - " Parameter containing: tensor([[1.]], requires_grad=True),\n", - " Parameter containing: tensor([0.], requires_grad=True)\n", - ")\n", - "Parameters after update:\n", - "(\n", - " tensor([[6.6757e-06]], grad_fn=),\n", - " tensor([-1.0000], grad_fn=)\n", - ")\n" - ] - } - ], - "source": [ - "interact_with_functorch_with_wrapper()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.3 Full TorchOpt\n", - "\n", - "`torchopt.Optimizer` is the base class for our PyTorch-like optimizer. Combined with the functional optimizer `torchopt.sgd` and `torchopt.adam`, we can define our high-level API `torchopt.SGD` and `torchopt.Adam`. The third example is to illustrate that TorchOpt can also directly replace `torch.optim` with exactly the same usage. Note the API difference happens between `torchopt.adam()` and `torchopt.Adam()`." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "def full_torchopt():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - "\n", - " learning_rate = 1.0\n", - " # High-level API\n", - " optim = torchopt.Adam(net.parameters(), lr=learning_rate)\n", - " # Low-level API\n", - " optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate))\n", - "\n", - " xs = 2 * torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = net(xs)\n", - " loss = mse(pred, ys)\n", - "\n", - " print('Parameters before update:', dict(net.named_parameters()))\n", - " optim.zero_grad()\n", - " loss.backward()\n", - " optim.step()\n", - " print('Parameters after update:', dict(net.named_parameters()))" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "{\n", - " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", - "}\n", - "Parameters after update:\n", - "{\n", - " 'fc.weight': Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", - "}\n" - ] - } - ], - "source": [ - "full_torchopt()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.4 Original PyTorch\n", - "\n", - "The final example is to original PyTorch example with `torch.optim`." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "def origin_torch():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - "\n", - " learning_rate = 1.0\n", - " optim = torch.optim.Adam(net.parameters(), lr=learning_rate)\n", - "\n", - " xs = 2 * torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = net(xs)\n", - " loss = mse(pred, ys)\n", - "\n", - " print('Parameters before update:', dict(net.named_parameters()))\n", - " optim.zero_grad()\n", - " loss.backward()\n", - " optim.step()\n", - " print('Parameters after update:', dict(net.named_parameters()))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "{\n", - " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", - "}\n", - "Parameters after update:\n", - "{\n", - " 'fc.weight': Parameter containing: tensor([[1.1921e-07]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", - "}\n" - ] - } - ], - "source": [ - "origin_torch()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Differentiable Optimization with Functional Optimizer\n", - "\n", - "Coupled with functional optimizer, you can conduct differentiable optimization by setting the `inplace` flag as `False` in update and `apply_updates` function. (which might be helpful for meta-learning algorithm implementation with functional programming style). \n", - "\n", - "Note that `torchopt.SGD` and `torchopt.Adam` do not support differentiable optimization. Refer to the Meta-Optimizer notebook for PyTorch-like differentiable optimizers." 
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "def differentiable():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - " model, params = functorch.make_functional(net) # get the functional version of the model\n", - "\n", - " # Meta-parameter\n", - " meta_param = nn.Parameter(torch.ones(1))\n", - "\n", - " # SGD example\n", - " learning_rate = 1.0\n", - " optimizer = torchopt.sgd(learning_rate)\n", - " opt_state = optimizer.init(params)\n", - "\n", - " xs = torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = model(params, xs)\n", - " # Where meta_param is used\n", - " pred = pred + meta_param\n", - " loss = mse(pred, ys)\n", - "\n", - " grads = torch.autograd.grad(loss, params, create_graph=True)\n", - " updates, opt_state = optimizer.update(grads, opt_state, inplace=False)\n", - " # Update parameters with single step SGD update\n", - " params = torchopt.apply_updates(params, updates, inplace=False)\n", - "\n", - " pred = model(params, xs)\n", - " loss = mse(pred, ys)\n", - " loss.backward()\n", - "\n", - " print('Gradient for the meta-parameter:', meta_param.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Gradient for the meta-parameter: tensor([32.])\n" - ] - } - ], - "source": [ - "differentiable()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.1 Track the Gradient of Momentum\n", - "\n", - "Note that most modern optimizers involve momentum term in the gradient update (basically only SGD with `momentum = 0` does not involve). We provide an option for user to choose whether to also track the meta-gradient through momentum term. The default option is `moment_requires_grad=True`." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "optim = torchopt.adam(lr=1.0, moment_requires_grad=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "optim = torchopt.adam(lr=1.0, moment_requires_grad=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "optim = torchopt.sgd(lr=1.0, momentum=0.8, moment_requires_grad=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Accelerated Optimizer\n", - "\n", - "Users can use accelerated optimizer by setting the `use_accelerated_op` as `True`. Currently we only support the Adam optimizer." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Check whether the `accelerated_op` is available:" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] + "cells" : [ + {"cell_type" : "markdown", "metadata" : {}, "source" : ["# TorchOpt as Functional Optimizer"]}, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["[](https://" + "colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/" + "1_Functional_Optimizer.ipynb)"] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["In this tutorial, we will introduce how TorchOpt can be treated as functional " + "optimizer to conduct normal optimization with functional programming style. 
We " + "will also illustrate how to conduct differentiable optimization with functional " + "programming in PyTorch."] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "## 1. Basic API\n", + "\n", + "In this first part, we will illustrate how TorchOpt can be used as a functional " + "optimizer. We compare it with different API in [JAX](https://github.com/google/jax) and " + "[PyTorch](https://pytorch.org) to help understand the similarity and dissimilarity. We " + "use simple network, Adam optimizer and MSE loss objective." + ] + }, + { + "cell_type" : "code", + "execution_count" : 1, + "metadata" : {}, + "outputs" : [], + "source" : [ + "from collections import OrderedDict\n", + "\n", + "import functorch\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import optax\n", + "import torch\n", + "import torch.autograd\n", + "import torch.nn as nn\n", + "\n", + "import torchopt\n", + "\n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1, bias=True)\n", + " nn.init.ones_(self.fc.weight)\n", + " nn.init.zeros_(self.fc.bias)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "def mse(inputs, targets):\n", + " return ((inputs - targets) ** 2).mean()" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "### 1.1 Original JAX implementation\n", + "\n", + "The first example is JAX implementation coupled with " + "[Optax](https://github.com/deepmind/optax), which belongs to functional programming style." + ] + }, + { + "cell_type" : "code", + "execution_count" : 2, + "metadata" : {}, + "outputs" : [], + "source" : [ + "def origin_jax():\n", + " batch_size = 1\n", + " dim = 1\n", + " params = OrderedDict([('weight', jnp.ones((dim, 1))), ('bias', jnp.zeros((1,)))])\n", + "\n", + " def model(params, x):\n", + " return jnp.matmul(x, params['weight']) + params['bias']\n", + "\n", + " # Obtain the `opt_state` that contains statistics for the optimizer\n", + " learning_rate = 1.0\n", + " optimizer = optax.adam(learning_rate)\n", + " opt_state = optimizer.init(params)\n", + "\n", + " def compute_loss(params, x, y):\n", + " pred = model(params, x)\n", + " return mse(pred, y)\n", + "\n", + " xs = 2 * jnp.ones((batch_size, dim))\n", + " ys = jnp.ones((batch_size, 1))\n", + "\n", + " grads = jax.grad(compute_loss)(params, xs, ys)\n", + " updates, opt_state = optimizer.update(grads, opt_state)\n", + "\n", + " print('Parameters before update:', params)\n", + " params = optax.apply_updates(params, updates)\n", + " print('Parameters after update:', params)" + ] + }, + { + "cell_type" : "code", + "execution_count" : 3, + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : [ + "Parameters before update:\n", + "OrderedDict([\n", + " ('weight', DeviceArray([[1.]], dtype=float32)),\n", + " ('bias', DeviceArray([0.], dtype=float32))\n", + "])\n", + "Parameters after update:\n", + "OrderedDict([\n", + " ('weight', DeviceArray([[6.735325e-06]], dtype=float32)),\n", + " ('bias', DeviceArray([-0.99999326], dtype=float32))\n", + "])\n" + ] + } ], + "source" : ["origin_jax()"] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "### 1.2 `functorch` with TorchOpt\n", + "\n", + "The second example is [`functorch`](https://pytorch.org/functorch) coupled with TorchOpt. " + "It basically follows the same structure with the JAX example." 
+ ] + }, + { + "cell_type" : "code", + "execution_count" : 4, + "metadata" : {}, + "outputs" : [], + "source" : [ + "def interact_with_functorch():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + " model, params = functorch.make_functional(net) # get the functional version of the " + "model\n", + "\n", + " # Obtain the `opt_state` that contains statistics for the optimizer\n", + " learning_rate = 1.0\n", + " optimizer = torchopt.adam(learning_rate)\n", + " opt_state = optimizer.init(params)\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = model(params, xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " grads = torch.autograd.grad(loss, params)\n", + " updates, opt_state = optimizer.update(grads, opt_state)\n", + "\n", + " print('Parameters before update:', params)\n", + " params = torchopt.apply_updates(params, updates)\n", + " print('Parameters after update:', params)" + ] + }, + { + "cell_type" : "code", + "execution_count" : 5, + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : [ + "Parameters before update:\n", + "(\n", + " Parameter containing: tensor([[1.]], requires_grad=True),\n", + " Parameter containing: tensor([0.], requires_grad=True)\n", + ")\n", + "Parameters after update:\n", + "(\n", + " Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", + " Parameter containing: tensor([-1.0000], requires_grad=True)\n", + ")\n" + ] + } ], + "source" : ["interact_with_functorch()"] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["TorchOpt also offers a wrapper `torchopt.FuncOptimizer` to make it easier to " + "maintain the optimizer states."] + }, + { + "cell_type" : "code", + "execution_count" : 6, + "metadata" : {}, + "outputs" : [], + "source" : [ + "def interact_with_functorch_with_wrapper():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + " model, params = functorch.make_functional(net) # get the functional version of the " + "model\n", + "\n", + " learning_rate = 1.0\n", + " optimizer = torchopt.FuncOptimizer(torchopt.adam(learning_rate))\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = model(params, xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " print('Parameters before update:', params)\n", + " params = optimizer.step(loss, params)\n", + " print('Parameters after update:', params)" + ] + }, + { + "cell_type" : "code", + "execution_count" : 7, + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : [ + "Parameters before update:\n", + "(\n", + " Parameter containing: tensor([[1.]], requires_grad=True),\n", + " Parameter containing: tensor([0.], requires_grad=True)\n", + ")\n", + "Parameters after update:\n", + "(\n", + " tensor([[6.6757e-06]], grad_fn=),\n", + " tensor([-1.0000], grad_fn=)\n", + ")\n" + ] + } ], + "source" : ["interact_with_functorch_with_wrapper()"] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "### 1.3 Full TorchOpt\n", + "\n", + "`torchopt.Optimizer` is the base class for our PyTorch-like optimizer. Combined with the " + "functional optimizer `torchopt.sgd` and `torchopt.adam`, we can define our high-level API " + "`torchopt.SGD` and `torchopt.Adam`. The third example is to illustrate that TorchOpt can " + "also directly replace `torch.optim` with exactly the same usage. 
Note the API difference " + "happens between `torchopt.adam()` and `torchopt.Adam()`." + ] + }, + { + "cell_type" : "code", + "execution_count" : 8, + "metadata" : {}, + "outputs" : [], + "source" : [ + "def full_torchopt():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + "\n", + " learning_rate = 1.0\n", + " # High-level API\n", + " optim = torchopt.Adam(net.parameters(), lr=learning_rate)\n", + " # Low-level API\n", + " optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate))\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = net(xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " print('Parameters before update:', dict(net.named_parameters()))\n", + " optim.zero_grad()\n", + " loss.backward()\n", + " optim.step()\n", + " print('Parameters after update:', dict(net.named_parameters()))" + ] + }, + { + "cell_type" : "code", + "execution_count" : 9, + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : [ + "Parameters before update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", + "}\n", + "Parameters after update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", + "}\n" + ] + } ], + "source" : ["full_torchopt()"] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "### 1.4 Original PyTorch\n", + "\n", + "The final example is to original PyTorch example with `torch.optim`." + ] + }, + { + "cell_type" : "code", + "execution_count" : 10, + "metadata" : {}, + "outputs" : [], + "source" : [ + "def origin_torch():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + "\n", + " learning_rate = 1.0\n", + " optim = torch.optim.Adam(net.parameters(), lr=learning_rate)\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = net(xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " print('Parameters before update:', dict(net.named_parameters()))\n", + " optim.zero_grad()\n", + " loss.backward()\n", + " optim.step()\n", + " print('Parameters after update:', dict(net.named_parameters()))" + ] + }, + { + "cell_type" : "code", + "execution_count" : 11, + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : [ + "Parameters before update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", + "}\n", + "Parameters after update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[1.1921e-07]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", + "}\n" + ] + } ], + "source" : ["origin_torch()"] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "## 2. Differentiable Optimization with Functional Optimizer\n", + "\n", + "Coupled with functional optimizer, you can conduct differentiable optimization by setting " + "the `inplace` flag as `False` in update and `apply_updates` function. (which might be " + "helpful for meta-learning algorithm implementation with functional programming style). \n", + "\n", + "Note that `torchopt.SGD` and `torchopt.Adam` do not support differentiable optimization. 
" + "Refer to the Meta-Optimizer notebook for PyTorch-like differentiable optimizers." + ] + }, + { + "cell_type" : "code", + "execution_count" : 12, + "metadata" : {}, + "outputs" : [], + "source" : [ + "def differentiable():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + " model, params = functorch.make_functional(net) # get the functional version of the " + "model\n", + "\n", + " # Meta-parameter\n", + " meta_param = nn.Parameter(torch.ones(1))\n", + "\n", + " # SGD example\n", + " learning_rate = 1.0\n", + " optimizer = torchopt.sgd(learning_rate)\n", + " opt_state = optimizer.init(params)\n", + "\n", + " xs = torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = model(params, xs)\n", + " # Where meta_param is used\n", + " pred = pred + meta_param\n", + " loss = mse(pred, ys)\n", + "\n", + " grads = torch.autograd.grad(loss, params, create_graph=True)\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=False)\n", + " # Update parameters with single step SGD update\n", + " params = torchopt.apply_updates(params, updates, inplace=False)\n", + "\n", + " pred = model(params, xs)\n", + " loss = mse(pred, ys)\n", + " loss.backward()\n", + "\n", + " print('Gradient for the meta-parameter:', meta_param.grad)" + ] + }, + { + "cell_type" : "code", + "execution_count" : 13, + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : ["Gradient for the meta-parameter: tensor([32.])\n"] + } ], + "source" : ["differentiable()"] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "### 2.1 Track the Gradient of Momentum\n", + "\n", + "Note that most modern optimizers involve momentum term in the gradient update (basically " + "only SGD with `momentum = 0` does not involve). We provide an option for user to choose " + "whether to also track the meta-gradient through momentum term. The default option is " + "`moment_requires_grad=True`." + ] + }, + { + "cell_type" : "code", + "execution_count" : 14, + "metadata" : {}, + "outputs" : [], + "source" : ["optim = torchopt.adam(lr=1.0, moment_requires_grad=False)"] + }, + { + "cell_type" : "code", + "execution_count" : 15, + "metadata" : {}, + "outputs" : [], + "source" : ["optim = torchopt.adam(lr=1.0, moment_requires_grad=True)"] + }, + { + "cell_type" : "code", + "execution_count" : 16, + "metadata" : {}, + "outputs" : [], + "source" : ["optim = torchopt.sgd(lr=1.0, momentum=0.8, moment_requires_grad=True)"] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "## 3. Accelerated Optimizer\n", + "\n", + "Users can use accelerated optimizer by setting the `use_accelerated_op` as `True`. " + "Currently we only support the Adam optimizer." 
+ ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Check whether the `accelerated_op` is available:"] + }, + { + "cell_type" : "code", + "execution_count" : 17, + "metadata" : {}, + "outputs" : [ {"name" : "stdout", "output_type" : "stream", "text" : ["True\n"]} ], + "source" : ["torchopt.accelerated_op_available(torch.device('cpu'))"] + }, + { + "cell_type" : "code", + "execution_count" : 18, + "metadata" : {}, + "outputs" : [ {"name" : "stdout", "output_type" : "stream", "text" : ["True\n"]} ], + "source" : ["torchopt.accelerated_op_available(torch.device('cuda'))"] + }, + { + "cell_type" : "code", + "execution_count" : 19, + "metadata" : {}, + "outputs" : [], + "source" : [ + "net = Net(1).cuda()\n", + "optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True)" + ] + }, + { + "cell_type" : "code", + "execution_count" : 20, + "metadata" : {}, + "outputs" : [], + "source" : ["optim = torchopt.adam(lr=1.0, use_accelerated_op=True)"] } - ], - "source": [ - "torchopt.accelerated_op_available(torch.device('cpu'))" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] + ], + "metadata" : { + "kernelspec" : + {"display_name" : "Python 3 (ipykernel)", "language" : "python", "name" : "python3"}, + "language_info" : { + "codemirror_mode" : {"name" : "ipython", "version" : 3}, + "file_extension" : ".py", + "mimetype" : "text/x-python", + "name" : "python", + "nbconvert_exporter" : "python", + "pygments_lexer" : "ipython3", + "version" : "3.9.15" + }, + "vscode" : { + "interpreter" : + {"hash" : "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da"} } - ], - "source": [ - "torchopt.accelerated_op_available(torch.device('cuda'))" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "net = Net(1).cuda()\n", - "optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "optim = torchopt.adam(lr=1.0, use_accelerated_op=True)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" }, - "vscode": { - "interpreter": { - "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat" : 4, + "nbformat_minor" : 4 } diff --git a/tutorials/2_Visualization.ipynb b/tutorials/2_Visualization.ipynb index dd58c48d..17c85a7b 100644 --- a/tutorials/2_Visualization.ipynb +++ b/tutorials/2_Visualization.ipynb @@ -1,216 +1,616 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Visualization in TorchOpt" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/2_Visualization.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In [PyTorch](https://pytorch.org), if the attribute `requires_grad` of a tensor is `True`, the computation graph will be created if we use the tensor to do any operations. 
The computation graph is implemented like link-list -- `Tensor`s are nodes and they are linked by their attribute `gran_fn`. [PyTorchViz](https://github.com/szagoruyko/pytorchviz) is a Python package that uses [Graphviz](https://graphviz.org) as a backend for plotting computation graphs. TorchOpt use PyTorchViz as the blueprint and provide more easy-to-use visualization functions on the premise of supporting all its functions." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's start with a simple multiplication computation graph. We declared the variable `x` with flag `requires_grad=True` and compute `y = 2 * x`. Then we visualize the computation graph of `y`." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ + "cells" : [ + {"cell_type" : "markdown", "metadata" : {}, "source" : ["# Visualization in TorchOpt"]}, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["[](https://" + "colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/" + "2_Visualization.ipynb)"] }, { - "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140534064715952\n\ny\n()\n\n\n\n140534064838304\n\nMulBackward0\n\n\n\n140534064838304->140534064715952\n\n\n\n\n\n140534064837776\n\nAccumulateGrad\n\n\n\n140534064837776->140534064838304\n\n\n\n\n\n140534064714832\n\nx\n()\n\n\n\n140534064714832->140534064837776\n\n\n\n\n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from IPython.display import display\n", - "\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "import torchopt\n", - "\n", - "\n", - "x = torch.tensor(1.0, requires_grad=True)\n", - "y = 2 * x\n", - "display(torchopt.visual.make_dot(y, params={'x': x, 'y': y}))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The figure shows `y` is connected by the multiplication edge. The gradient of `y` will flow through the multiplication backward function then accumulated on `x`. Note that we pass a dictionary for adding node labels." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then let's plot a neural network. Note that we can pass the generator returned by method `named_parameters` for adding node labels." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["In [PyTorch](https://pytorch.org), if the attribute `requires_grad` of a tensor " + "is `True`, the computation graph will be created if we use the tensor to do any " + "operations. The computation graph is implemented like link-list -- `Tensor`s " + "are nodes and they are linked by their attribute `gran_fn`. " + "[PyTorchViz](https://github.com/szagoruyko/pytorchviz) is a Python package that " + "uses [Graphviz](https://graphviz.org) as a backend for plotting computation " + "graphs. TorchOpt use PyTorchViz as the blueprint and provide more easy-to-use " + "visualization functions on the premise of supporting all its functions."] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Let's start with a simple multiplication computation graph. We declared the " + "variable `x` with flag `requires_grad=True` and compute `y = 2 * x`. 
Then we " + "visualize the computation graph of `y`."] }, { - "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140534659780336\n\nloss\n()\n\n\n\n140531595570768\n\nMseLossBackward0\n\n\n\n140531595570768->140534659780336\n\n\n\n\n\n140531595570576\n\nAddmmBackward0\n\n\n\n140531595570576->140531595570768\n\n\n\n\n\n140531595570528\n\nAccumulateGrad\n\n\n\n140531595570528->140531595570576\n\n\n\n\n\n140531595583632\n\nfc.bias\n(1)\n\n\n\n140531595583632->140531595570528\n\n\n\n\n\n140531595571104\n\nTBackward0\n\n\n\n140531595571104->140531595570576\n\n\n\n\n\n140531595570432\n\nAccumulateGrad\n\n\n\n140531595570432->140531595571104\n\n\n\n\n\n140531595582816\n\nfc.weight\n(1, 5)\n\n\n\n140531595582816->140531595570432\n\n\n\n\n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "class Net(nn.Module):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, 1, bias=True)\n", - "\n", - " def forward(self, x):\n", - " return self.fc(x)\n", - "\n", - "\n", - "dim = 5\n", - "batch_size = 2\n", - "net = Net(dim)\n", - "xs = torch.ones((batch_size, dim))\n", - "ys = torch.ones((batch_size, 1))\n", - "pred = net(xs)\n", - "loss = F.mse_loss(pred, ys)\n", - "\n", - "display(torchopt.visual.make_dot(loss, params=(net.named_parameters(), {'loss': loss})))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The computation graph of meta-learning algorithms will be much more complex. Our visualization tool allows users take as input the extracted network state for better visualization." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ + "cell_type" : "code", + "execution_count" : 1, + "metadata" : {}, + "outputs" : [ + { + "name" : "stdout", + "output_type" : "stream", + "text" : ["\n"] + }, + { + "data" : { + "image/svg+xml" : + "\n\n\n\n\n\n%3\n\n\n\n140534064715952\n\ny\n()\n\n\n\n140534064838304\n\nMulBackward0\n\n\n\n140534064838304->140534064715952\n\n\n\n\n\n140534064837776\n\nAccumulateGrad\n\n\n\n140534064837776->140534064838304\n\n\n\n\n\n140534064714832\n\nx\n()\n\n\n\n140534064714832->140534064837776\n\n\n\n\n\n" + }, + "metadata" : {}, + "output_type" : "display_data" + } + ], + "source" : [ + "from IPython.display import display\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torchopt\n", + "\n", + "\n", + "x = torch.tensor(1.0, requires_grad=True)\n", + "y = 2 * x\n", + "display(torchopt.visual.make_dot(y, params={'x': x, 'y': y}))" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["The figure shows `y` is connected by the multiplication edge. The gradient of " + "`y` will flow through the multiplication backward function then accumulated on " + "`x`. Note that we pass a dictionary for adding node labels."] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Then let's plot a neural network. 
Note that we can pass the generator returned " + "by method `named_parameters` for adding node labels."] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] + "cell_type" : "code", + "execution_count" : 2, + "metadata" : {}, + "outputs" : [ + { + "name" : "stdout", + "output_type" : "stream", + "text" : ["\n"] + }, + { + "data" : { + "image/svg+xml" : + "\n\n\n\n\n\n%3\n\n\n\n140534659780336\n\nloss\n()\n\n\n\n140531595570768\n\nMseLossBackward0\n\n\n\n140531595570768->140534659780336\n\n\n\n\n\n140531595570576\n\nAddmmBackward0\n\n\n\n140531595570576->140531595570768\n\n\n\n\n\n140531595570528\n\nAccumulateGrad\n\n\n\n140531595570528->140531595570576\n\n\n\n\n\n140531595583632\n\nfc.bias\n(1)\n\n\n\n140531595583632->140531595570528\n\n\n\n\n\n140531595571104\n\nTBackward0\n\n\n\n140531595571104->140531595570576\n\n\n\n\n\n140531595570432\n\nAccumulateGrad\n\n\n\n140531595570432->140531595571104\n\n\n\n\n\n140531595582816\n\nfc.weight\n(1, 5)\n\n\n\n140531595582816->140531595570432\n\n\n\n\n\n" + }, + "metadata" : {}, + "output_type" : "display_data" + } + ], + "source" : [ + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1, bias=True)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "dim = 5\n", + "batch_size = 2\n", + "net = Net(dim)\n", + "xs = torch.ones((batch_size, dim))\n", + "ys = torch.ones((batch_size, 1))\n", + "pred = net(xs)\n", + "loss = F.mse_loss(pred, ys)\n", + "\n", + "display(torchopt.visual.make_dot(loss, params=(net.named_parameters(), {'loss': loss})))" + ] }, { - "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140531595614064\n\nloss\n()\n\n\n\n140531595567168\n\nMseLossBackward0\n\n\n\n140531595567168->140531595614064\n\n\n\n\n\n140531595569232\n\nAddBackward0\n\n\n\n140531595569232->140531595567168\n\n\n\n\n\n140531595568800\n\nAddmmBackward0\n\n\n\n140531595568800->140531595569232\n\n\n\n\n\n140534660247264\n\nAddBackward0\nstep1.fc.bias\n(1)\n\n\n\n140534660247264->140531595568800\n\n\n\n\n\n140534553595376\n\nAccumulateGrad\n\n\n\n140534553595376->140534660247264\n\n\n\n\n\n140534553592832\n\nAddmmBackward0\n\n\n\n140534553595376->140534553592832\n\n\n\n\n\n140534064448352\n\nstep0.fc.bias\n(1)\n\n\n\n140534064448352->140534553595376\n\n\n\n\n\n140534553595616\n\nMulBackward0\n\n\n\n140534553595616->140534660247264\n\n\n\n\n\n140534553594848\n\nViewBackward0\n\n\n\n140534553594848->140534553595616\n\n\n\n\n\n140534553594992\n\nSumBackward1\n\n\n\n140534553594992->140534553594848\n\n\n\n\n\n140534553594800\n\nMseLossBackwardBackward0\n\n\n\n140534553594800->140534553594992\n\n\n\n\n\n140531595617904\n\nTBackward0\n\n\n\n140534553594800->140531595617904\n\n\n\n\n\n140534553593072\n\nAddBackward0\n\n\n\n140534553593072->140534553594800\n\n\n\n\n\n140534553592832->140534553593072\n\n\n\n\n\n140534553593456\n\nTBackward0\n\n\n\n140534553593456->140534553592832\n\n\n\n\n\n140534553593888\n\nAccumulateGrad\n\n\n\n140534553593888->140534553593456\n\n\n\n\n\n140531595572368\n\nAddBackward0\nstep1.fc.weight\n(1, 5)\n\n\n\n140534553593888->140531595572368\n\n\n\n\n\n140531595612944\n\nstep0.fc.weight\n(1, 
5)\n\n\n\n140531595612944->140534553593888\n\n\n\n\n\n140531595567888\n\nAccumulateGrad\n\n\n\n140531595567888->140531595569232\n\n\n\n\n\n140531595567888->140534553593072\n\n\n\n\n\n140531595613184\n\nmeta_param\n()\n\n\n\n140531595613184->140531595567888\n\n\n\n\n\n140534553594272\n\nTBackward0\n\n\n\n140534553594272->140531595568800\n\n\n\n\n\n140531595572368->140534553594272\n\n\n\n\n\n140534553593504\n\nMulBackward0\n\n\n\n140534553593504->140531595572368\n\n\n\n\n\n140534553592976\n\nTBackward0\n\n\n\n140534553592976->140534553593504\n\n\n\n\n\n140534553593216\n\nTBackward0\n\n\n\n140534553593216->140534553592976\n\n\n\n\n\n140534553593552\n\nMmBackward0\n\n\n\n140534553593552->140534553593216\n\n\n\n\n\n140531595617904->140534553593552\n\n\n\n\n\n" - }, - "metadata": {}, - "output_type": "display_data" + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["The computation graph of meta-learning algorithms will be much more complex. " + "Our visualization tool allows users take as input the extracted network state " + "for better visualization."] + }, + { + "cell_type" : "code", + "execution_count" : 3, + "metadata" : {}, + "outputs" : [ + { + "name" : "stdout", + "output_type" : "stream", + "text" : ["\n"] + }, + { + "data" : { + "image/svg+xml" : + "\n\n\n\n\n\n%3\n\n\n\n140531595614064\n\nloss\n()\n\n\n\n140531595567168\n\nMseLossBackward0\n\n\n\n140531595567168->140531595614064\n\n\n\n\n\n140531595569232\n\nAddBackward0\n\n\n\n140531595569232->140531595567168\n\n\n\n\n\n140531595568800\n\nAddmmBackward0\n\n\n\n140531595568800->140531595569232\n\n\n\n\n\n140534660247264\n\nAddBackward0\nstep1.fc.bias\n(1)\n\n\n\n140534660247264->140531595568800\n\n\n\n\n\n140534553595376\n\nAccumulateGrad\n\n\n\n140534553595376->140534660247264\n\n\n\n\n\n140534553592832\n\nAddmmBackward0\n\n\n\n140534553595376->140534553592832\n\n\n\n\n\n140534064448352\n\nstep0.fc.bias\n(1)\n\n\n\n140534064448352->140534553595376\n\n\n\n\n\n140534553595616\n\nMulBackward0\n\n\n\n140534553595616->140534660247264\n\n\n\n\n\n140534553594848\n\nViewBackward0\n\n\n\n140534553594848->140534553595616\n\n\n\n\n\n140534553594992\n\nSumBackward1\n\n\n\n140534553594992->140534553594848\n\n\n\n\n\n140534553594800\n\nMseLossBackwardBackward0\n\n\n\n140534553594800->140534553594992\n\n\n\n\n\n140531595617904\n\nTBackward0\n\n\n\n140534553594800->140531595617904\n\n\n\n\n\n140534553593072\n\nAddBackward0\n\n\n\n140534553593072->140534553594800\n\n\n\n\n\n140534553592832->140534553593072\n\n\n\n\n\n140534553593456\n\nTBackward0\n\n\n\n140534553593456->140534553592832\n\n\n\n\n\n140534553593888\n\nAccumulateGrad\n\n\n\n140534553593888->140534553593456\n\n\n\n\n\n140531595572368\n\nAddBackward0\nstep1.fc.weight\n(1, " + "5)\n\n\n\n140534553593888->140531595572368\n\n\n\n\n\n140531595612944\n\nstep0.fc.weight\n(1, " + 
"5)\n\n\n\n140531595612944->140534553593888\n\n\n\n\n\n140531595567888\n\nAccumulateGrad\n\n\n\n140531595567888->140531595569232\n\n\n\n\n\n140531595567888->140534553593072\n\n\n\n\n\n140531595613184\n\nmeta_param\n()\n\n\n\n140531595613184->140531595567888\n\n\n\n\n\n140534553594272\n\nTBackward0\n\n\n\n140534553594272->140531595568800\n\n\n\n\n\n140531595572368->140534553594272\n\n\n\n\n\n140534553593504\n\nMulBackward0\n\n\n\n140534553593504->140531595572368\n\n\n\n\n\n140534553592976\n\nTBackward0\n\n\n\n140534553592976->140534553593504\n\n\n\n\n\n140534553593216\n\nTBackward0\n\n\n\n140534553593216->140534553592976\n\n\n\n\n\n140534553593552\n\nMmBackward0\n\n\n\n140534553593552->140534553593216\n\n\n\n\n\n140531595617904->140534553593552\n\n\n\n\n\n" + }, + "metadata" : {}, + "output_type" : "display_data" + } + ], + "source" : [ + "class MetaNet(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1, bias=True)\n", + "\n", + " def forward(self, x, meta_param):\n", + " return self.fc(x) + meta_param\n", + "\n", + "\n", + "dim = 5\n", + "batch_size = 2\n", + "net = MetaNet(dim)\n", + "\n", + "xs = torch.ones((batch_size, dim))\n", + "ys = torch.ones((batch_size, 1))\n", + "\n", + "optimizer = torchopt.MetaSGD(net, lr=1e-3)\n", + "meta_param = torch.tensor(1.0, requires_grad=True)\n", + "\n", + "# Set enable_visual\n", + "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, " + "visual_prefix='step0.')\n", + "\n", + "pred = net(xs, meta_param)\n", + "loss = F.mse_loss(pred, ys)\n", + "optimizer.step(loss)\n", + "\n", + "# Set enable_visual\n", + "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, " + "visual_prefix='step1.')\n", + "\n", + "pred = net(xs, meta_param)\n", + "loss = F.mse_loss(pred, torch.ones_like(pred))\n", + "\n", + "# Draw computation graph\n", + "display(\n", + " torchopt.visual.make_dot(\n", + " loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]\n", + " )\n", + ")" + ] + } + ], + "metadata" : { + "kernelspec" : + {"display_name" : "Python 3 (ipykernel)", "language" : "python", "name" : "python3"}, + "language_info" : { + "codemirror_mode" : {"name" : "ipython", "version" : 3}, + "file_extension" : ".py", + "mimetype" : "text/x-python", + "name" : "python", + "nbconvert_exporter" : "python", + "pygments_lexer" : "ipython3", + "version" : "3.9.15" + }, + "vscode" : { + "interpreter" : + {"hash" : "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da"} } - ], - "source": [ - "class MetaNet(nn.Module):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, 1, bias=True)\n", - "\n", - " def forward(self, x, meta_param):\n", - " return self.fc(x) + meta_param\n", - "\n", - "\n", - "dim = 5\n", - "batch_size = 2\n", - "net = MetaNet(dim)\n", - "\n", - "xs = torch.ones((batch_size, dim))\n", - "ys = torch.ones((batch_size, 1))\n", - "\n", - "optimizer = torchopt.MetaSGD(net, lr=1e-3)\n", - "meta_param = torch.tensor(1.0, requires_grad=True)\n", - "\n", - "# Set enable_visual\n", - "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", - "\n", - "pred = net(xs, meta_param)\n", - "loss = F.mse_loss(pred, ys)\n", - "optimizer.step(loss)\n", - "\n", - "# Set enable_visual\n", - "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", - "\n", - "pred = net(xs, meta_param)\n", - "loss = F.mse_loss(pred, torch.ones_like(pred))\n", - "\n", - "# Draw 
computation graph\n", - "display(\n", - " torchopt.visual.make_dot(\n", - " loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]\n", - " )\n", - ")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" }, - "vscode": { - "interpreter": { - "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat" : 4, + "nbformat_minor" : 4 } diff --git a/tutorials/3_Meta_Optimizer.ipynb b/tutorials/3_Meta_Optimizer.ipynb index 69be77ed..23d7a575 100644 --- a/tutorials/3_Meta_Optimizer.ipynb +++ b/tutorials/3_Meta_Optimizer.ipynb @@ -1,713 +1,1584 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# TorchOpt as Meta-Optimizer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/3_Meta_Optimizer.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this tutorial, we will show how to treat TorchOpt as a differentiable optimizer with traditional PyTorch optimization API. In addition, we also provide many other API for easy meta-learning algorithm implementations." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Basic API for Differentiable Optimizer\n", - "\n", - "`MetaOptimizer` is the main class for our differentiable optimizer. Combined with the functional optimizer `torchopt.sgd` and `torchopt.adam` mentioned in the tutorial 1, we can define our high-level API `torchopt.MetaSGD` and `torchopt.MetaAdam`. We will discuss how this combination happens with `torchopt.chain` in Section 3. Let us consider the problem below." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Assume a tensor $x$ is a meta-parameter and $a$ is a normal parameters (such as network parameters). We have inner loss $\\mathcal{L}^{\\textrm{in}} = a_0 \\cdot x^2$ and we update $a$ use the gradient $\\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = x^2$ and $a_1 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x^2$. Then we compute the outer loss $\\mathcal{L}^{\\textrm{out}} = a_1 \\cdot x^2$. So the gradient of outer loss to $x$ would be:\n", - "\n", - "$$\n", - "\\begin{split}\n", - " \\frac{\\partial \\mathcal{L}^{\\textrm{out}}}{\\partial x}\n", - " & = \\frac{\\partial (a_1 \\cdot x^2)}{\\partial x} \\\\\n", - " & = \\frac{\\partial a_1}{\\partial x} \\cdot x^2 + a_1 \\cdot \\frac{\\partial (x^2)}{\\partial x} \\\\\n", - " & = \\frac{\\partial (a_0 - \\eta \\, x^2)}{\\partial x} \\cdot x^2 + (a_0 - \\eta \\, x^2) \\cdot 2 x \\\\\n", - " & = (- \\eta \\cdot 2 x) \\cdot x^2 + (a_0 - \\eta \\, x^2) \\cdot 2 x \\\\\n", - " & = - 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x\n", - "\\end{split}\n", - "$$" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Given the analytical solution above. Let's try to verify it with TorchOpt. Define the net work first." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from IPython.display import display\n", - "\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "import torchopt\n", - "\n", - "\n", - "class Net(nn.Module):\n", - " def __init__(self):\n", - " super().__init__()\n", - " self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", - "\n", - " def forward(self, x):\n", - " return self.a * (x**2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then we declare the network (parameterized by `a`) and the meta-parameter `x`. Do not forget to set flag `requires_grad=True` for `x`." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next we declare the meta-optimizer. Here we show two equivalent ways of defining the meta-optimizer. " - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# Low-level API\n", - "optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0))\n", - "\n", - "# High-level API\n", - "optim = torchopt.MetaSGD(net, lr=1.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The meta-optimizer takes the network as input and use method `step` to update the network (parameterized by `a`). Finally, we show how a bi-level process works." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x.grad = tensor(-28.)\n" - ] - } - ], - "source": [ - "inner_loss = net(x)\n", - "optim.step(inner_loss)\n", - "\n", - "outer_loss = net(x)\n", - "outer_loss.backward()\n", - "# x.grad = - 4 * lr * x^3 + 2 * a_0 * x\n", - "# = - 4 * 1 * 2^3 + 2 * 1 * 2\n", - "# = -32 + 4\n", - "# = -28\n", - "print(f'x.grad = {x.grad!r}')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.1 Track the Gradient of Momentum\n", - "\n", - "Note that most modern optimizers involve moment term in the gradient update (basically only SGD with `momentum=0` does not involve). We provide an option for user to choose whether to also track the meta-gradient through moment term. The default option is `moment_requires_grad=True`." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "- When you do not track the meta-gradient through moment (`moment_requires_grad=False`)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140447553047184\n\nouter_loss\n()\n\n\n\n140447553041216\n\nMseLossBackward0\n\n\n\n140447553041216->140447553047184\n\n\n\n\n\n140447553042896\n\nMulBackward0\n\n\n\n140447553042896->140447553041216\n\n\n\n\n\n140447553019088\n\nAddBackward0\nstep1.a\n()\n\n\n\n140447553019088->140447553042896\n\n\n\n\n\n140447553041072\n\nAccumulateGrad\n\n\n\n140447553041072->140447553019088\n\n\n\n\n\n140447553043664\n\nMulBackward0\n\n\n\n140447553041072->140447553043664\n\n\n\n\n\n140447553045344\n\nstep0.a\n()\n\n\n\n140447553045344->140447553041072\n\n\n\n\n\n140447553041120\n\nMulBackward0\n\n\n\n140447553041120->140447553019088\n\n\n\n\n\n140447553043040\n\nDivBackward0\n\n\n\n140447553043040->140447553041120\n\n\n\n\n\n140447553043184\n\nDivBackward0\n\n\n\n140447553043184->140447553043040\n\n\n\n\n\n140447553043328\n\nAddBackward0\n\n\n\n140447553043328->140447553043184\n\n\n\n\n\n140447553043424\n\nMulBackward0\n\n\n\n140447553043424->140447553043328\n\n\n\n\n\n140447553043856\n\nAddcmulBackward0\n\n\n\n140447553043424->140447553043856\n\n\n\n\n\n140447553043424->140447553043856\n\n\n\n\n\n140447553043520\n\nMseLossBackwardBackward0\n\n\n\n140447553043520->140447553043424\n\n\n\n\n\n140447553043664->140447553043520\n\n\n\n\n\n140447553043472\n\nPowBackward0\n\n\n\n140447553043472->140447553043424\n\n\n\n\n\n140447553043472->140447553043664\n\n\n\n\n\n140447553043808\n\nAccumulateGrad\n\n\n\n140447553043808->140447553043472\n\n\n\n\n\n140447553041264\n\nPowBackward0\n\n\n\n140447553043808->140447553041264\n\n\n\n\n\n140447553045584\n\nx\n()\n\n\n\n140447553045584->140447553043808\n\n\n\n\n\n140447553043136\n\nAddBackward0\n\n\n\n140447553043136->140447553043040\n\n\n\n\n\n140447553043232\n\nSqrtBackward0\n\n\n\n140447553043232->140447553043136\n\n\n\n\n\n140447553043760\n\nAddBackward0\n\n\n\n140447553043760->140447553043232\n\n\n\n\n\n140447553043904\n\nDivBackward0\n\n\n\n140447553043904->140447553043760\n\n\n\n\n\n140447553043856->140447553043904\n\n\n\n\n\n140447553041264->140447553042896\n\n\n\n\n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", - "y = torch.tensor(1.0)\n", - "\n", - "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=False)\n", - "\n", - "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", - "inner_loss = F.mse_loss(net(x), y)\n", - "optim.step(inner_loss)\n", - "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", - "\n", - "outer_loss = F.mse_loss(net(x), y)\n", - "display(\n", - " torchopt.visual.make_dot(\n", - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "- When you track the meta-gradient through moment (`moment_requires_grad=True`, default for `torchopt.MetaAdam`)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - 
"data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140447553148704\n\nouter_loss\n()\n\n\n\n140447553041024\n\nMseLossBackward0\n\n\n\n140447553041024->140447553148704\n\n\n\n\n\n140447553043424\n\nMulBackward0\n\n\n\n140447553043424->140447553041024\n\n\n\n\n\n140450536407152\n\nAddBackward0\nstep1.a\n()\n\n\n\n140450536407152->140447553043424\n\n\n\n\n\n140447553041264\n\nAccumulateGrad\n\n\n\n140447553041264->140450536407152\n\n\n\n\n\n140447553019232\n\nMulBackward0\n\n\n\n140447553041264->140447553019232\n\n\n\n\n\n140447553148064\n\nstep0.a\n()\n\n\n\n140447553148064->140447553041264\n\n\n\n\n\n140447553041216\n\nMulBackward0\n\n\n\n140447553041216->140450536407152\n\n\n\n\n\n140447553041312\n\nDivBackward0\n\n\n\n140447553041312->140447553041216\n\n\n\n\n\n140447553041408\n\nDivBackward0\n\n\n\n140447553041408->140447553041312\n\n\n\n\n\n140447553043376\n\nAddBackward0\n\n\n\n140447553043376->140447553041408\n\n\n\n\n\n140447553041168\n\nMulBackward0\n\n\n\n140447553041168->140447553043376\n\n\n\n\n\n140447553042272\n\nAccumulateGrad\n\n\n\n140447553042272->140447553041168\n\n\n\n\n\n140450290826352\n\n()\n\n\n\n140450290826352->140447553042272\n\n\n\n\n\n140447553044432\n\nMulBackward0\n\n\n\n140447553044432->140447553043376\n\n\n\n\n\n140447553018320\n\nAddcmulBackward0\n\n\n\n140447553044432->140447553018320\n\n\n\n\n\n140447553044432->140447553018320\n\n\n\n\n\n140447553042080\n\nMseLossBackwardBackward0\n\n\n\n140447553042080->140447553044432\n\n\n\n\n\n140447553019232->140447553042080\n\n\n\n\n\n140447553019088\n\nPowBackward0\n\n\n\n140447553019088->140447553044432\n\n\n\n\n\n140447553019088->140447553019232\n\n\n\n\n\n140447553018464\n\nAccumulateGrad\n\n\n\n140447553018464->140447553019088\n\n\n\n\n\n140447553043328\n\nPowBackward0\n\n\n\n140447553018464->140447553043328\n\n\n\n\n\n140447553148144\n\nx\n()\n\n\n\n140447553148144->140447553018464\n\n\n\n\n\n140447553041456\n\nAddBackward0\n\n\n\n140447553041456->140447553041312\n\n\n\n\n\n140447553041360\n\nSqrtBackward0\n\n\n\n140447553041360->140447553041456\n\n\n\n\n\n140447553015920\n\nAddBackward0\n\n\n\n140447553015920->140447553041360\n\n\n\n\n\n140447553018560\n\nDivBackward0\n\n\n\n140447553018560->140447553015920\n\n\n\n\n\n140447553018320->140447553018560\n\n\n\n\n\n140447553018272\n\nMulBackward0\n\n\n\n140447553018272->140447553018320\n\n\n\n\n\n140447553018944\n\nAccumulateGrad\n\n\n\n140447553018944->140447553018272\n\n\n\n\n\n140450290824272\n\n()\n\n\n\n140450290824272->140447553018944\n\n\n\n\n\n140447553043328->140447553043424\n\n\n\n\n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", - "y = torch.tensor(1.0)\n", - "\n", - "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True)\n", - "\n", - "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", - "inner_loss = F.mse_loss(net(x), y)\n", - "optim.step(inner_loss)\n", - "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", - "\n", - "outer_loss = F.mse_loss(net(x), y)\n", - "display(\n", - " torchopt.visual.make_dot(\n", - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can see that the additional moment terms are added into the computational graph when we set `moment_requires_grad=True`." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Extract and Recover" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.1 Basic API\n", - "\n", - "We observe that how to reinitialize the inner-loop parameter in a new bi-level process vary in different meta-learning algorithms. For instance, in algorithm like Model-Agnostic Meta-Learning (MAML) ([arXiv:1703.03400](https://arxiv.org/abs/1703.03400)), every time a new task comes, we need to reset the parameters to the initial ones. In other cases such as Meta-Gradient Reinforcement Learning (MGRL) ([arXiv:1805.09801](https://arxiv.org/abs/1805.09801)), the inner-loop network parameter just inherit previous updated parameter to continue the new bi-level process.\n", - "\n", - "We provide the `torchopt.extract_state_dict` and `torchopt.recover_state_dict` functions to extract and restore the state of network and optimizer. By default, the extracted state dictionary is a reference (this design is for accumulating gradient of multi-task batch training, MAML for example). You can also set `by='copy'` to extract the copy of the state dictionary or set `by='deepcopy'` to have a detached copy." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "a = tensor(-1.0000, grad_fn=)\n", - "a = tensor(-1.0000, grad_fn=)\n" - ] - } - ], - "source": [ - "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", - "\n", - "optim = torchopt.MetaAdam(net, lr=1.0)\n", - "\n", - "# Get the reference of state dictionary\n", - "init_net_state = torchopt.extract_state_dict(net, by='reference')\n", - "init_optim_state = torchopt.extract_state_dict(optim, by='reference')\n", - "# If set `detach_buffers=True`, the parameters are referenced as references while buffers are detached copies\n", - "init_net_state = torchopt.extract_state_dict(net, by='reference', detach_buffers=True)\n", - "\n", - "# Set `copy` to get the copy of the state dictionary\n", - "init_net_state_copy = torchopt.extract_state_dict(net, by='copy')\n", - "init_optim_state_copy = torchopt.extract_state_dict(optim, by='copy')\n", - "\n", - "# Set `deepcopy` to get the detached copy of state dictionary\n", - "init_net_state_deepcopy = torchopt.extract_state_dict(net, by='deepcopy')\n", - "init_optim_state_deepcopy = torchopt.extract_state_dict(optim, by='deepcopy')\n", - "\n", - "# Conduct 2 inner-loop optimization\n", - "for i in range(2):\n", - " inner_loss = net(x)\n", - " optim.step(inner_loss)\n", - "\n", - "print(f'a = {net.a!r}')\n", - "\n", - "# Recover and reconduct 2 inner-loop optimization\n", - "torchopt.recover_state_dict(net, init_net_state)\n", - "torchopt.recover_state_dict(optim, init_optim_state)\n", - "\n", - "for i in range(2):\n", - " inner_loss = net(x)\n", - " optim.step(inner_loss)\n", - "\n", - "print(f'a = {net.a!r}') # the same result" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.2 Multi-task Example with `extract_state_dict` and `recover_state_dict`\n", - "\n", - "Let's move to another more complex setting. Meta-Learning algorithms always fix network on several different tasks and accumulate outer loss of each task to the meta-gradient.\n", - "\n", - "Assume $x$ is a meta-parameter and $a$ is a normal parameter. We firstly update $a$ use inner loss $\\mathcal{L}_1^{\\textrm{in}} = a_0 \\cdot x^2$ to $a_1$. 
Then we use $a_1$ to compute the outer loss $\\mathcal{L}_1^{\\textrm{out}} = a_1 \\cdot x^2$ and backpropagate it. Then we use $a_0$ to compute the inner loss $\\mathcal{L}_2^{\\textrm{in}} = a_0 \\cdot x$ and update $a_0$ to $a_2 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}_2^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x$. Then we compute outer loss $\\mathcal{L}_2^{\\textrm{out}} = a_2 \\cdot x$ and backpropagate it. So the accumulated meta-gradient would be:\n", - "\n", - "$$\n", - "\\begin{split}\n", - " \\frac{\\partial \\mathcal{L}_1^{\\textrm{out}}}{\\partial x} + \\frac{\\partial \\mathcal{L}_2^{\\textrm{out}}}{\\partial x}\n", - " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + \\frac{\\partial (a_2 \\cdot x)}{\\partial x} \\\\\n", - " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + (\\frac{\\partial a_2}{\\partial x} \\cdot x + a_2) \\\\\n", - " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + [\\frac{\\partial (a_0 - \\eta \\, x)}{\\partial x} \\cdot x + (a_0 - \\eta \\, x)] \\\\\n", - " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + [(- \\eta) \\cdot x + (a_0 - \\eta \\, x)] \\\\\n", - " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + (- 2 \\, \\eta \\, x + a_0)\n", - "\\end{split}\n", - "$$" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's define the network and variables first." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "class Net2Tasks(nn.Module):\n", - " def __init__(self):\n", - " super().__init__()\n", - " self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", - "\n", - " def task1(self, x):\n", - " return self.a * x**2\n", - "\n", - " def task2(self, x):\n", - " return self.a * x\n", - "\n", - "\n", - "net = Net2Tasks()\n", - "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", - "\n", - "optim = torchopt.MetaSGD(net, lr=1.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Once we call `step` method of `MetaOptimizer`, the parameters of the network would be changed. We should use `torchopt.extract_state_dict` to extract state and use `torchopt.recover_state_dict` to recover the state. Note that if we use optimizers that have momentum buffers, we should also extract and recover them, vanilla SGD does not have momentum buffers so code `init_optim_state = torchopt.extract_state_dict(optim)` and `torchopt.recover_state_dict(optim, init_optim_state)` have no effect." 
- ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "init_optim_state = ((EmptyState(),),)\n", - "Task 1: x.grad = tensor(-28.)\n", - "Accumulated: x.grad = tensor(-31.)\n" - ] - } - ], - "source": [ - "# Get the reference of state dictionary\n", - "init_net_state = torchopt.extract_state_dict(net, by='reference')\n", - "init_optim_state = torchopt.extract_state_dict(optim, by='reference')\n", - "# The `state_dict` is empty for vanilla SGD optimizer\n", - "print(f'init_optim_state = {init_optim_state!r}')\n", - "\n", - "inner_loss_1 = net.task1(x)\n", - "optim.step(inner_loss_1)\n", - "outer_loss_1 = net.task1(x)\n", - "outer_loss_1.backward()\n", - "print(f'Task 1: x.grad = {x.grad!r}')\n", - "\n", - "torchopt.recover_state_dict(net, init_net_state)\n", - "torchopt.recover_state_dict(optim, init_optim_state)\n", - "inner_loss_2 = net.task2(x)\n", - "optim.step(inner_loss_2)\n", - "outer_loss_2 = net.task2(x)\n", - "outer_loss_2.backward()\n", - "\n", - "# `extract_state_dict`` extracts the reference so gradient accumulates\n", - "# x.grad = (- 4 * lr * x^3 + 2 * a_0 * x) + (- 2 * lr * x + a_0)\n", - "# = (- 4 * 1 * 2^3 + 2 * 1 * 2) + (- 2 * 1 * 2 + 1)\n", - "# = -28 - 3\n", - "# = -31\n", - "print(f'Accumulated: x.grad = {x.grad!r}')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Gradient Transformation in `MetaOptimizer`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can also use some gradient normalization tricks in our `MetaOptimizer`. In fact `MetaOptimizer` decedents like `MetaSGD` are specializations of `MetaOptimizer`. Specifically, `MetaSGD(net, lr=1.)` is `MetaOptimizer(net, alias.sgd(lr=1., moment_requires_grad=True))`, where flag `moment_requires_grad=True` means the momentums are created with flag `requires_grad=True` so the momentums will also be the part of the computation graph.\n", - "\n", - "In the designing of TorchOpt, we treat these functions as derivations of `combine.chain`. 
So we can build our own chain like `combine.chain(clip.clip_grad_norm(max_norm=1.), sgd(lr=1., requires_grad=True))` to clip the gradient and update parameters using `sgd`.\n", - "\n", - "$$\n", - "\\begin{aligned}\n", - " \\frac{\\partial \\mathcal{L}^{\\textrm{out}}}{\\partial x}\n", - " & = \\frac{\\partial (a_1 \\cdot x^2)}{\\partial x} \\\\\n", - " & = \\frac{\\partial a_1}{\\partial x} \\cdot x^2 + a_1 \\cdot \\frac{\\partial (x^2)}{\\partial x} \\\\\n", - " & = \\frac{\\partial (a_0 - \\eta \\, g)}{\\partial x} \\cdot x^2 + (a_0 - \\eta \\, g) \\cdot 2 x & \\qquad (g \\propto \\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = x^2, \\ {\\lVert g \\rVert}_2 \\le G_{\\max}) \\\\\n", - " & = \\frac{\\partial (a_0 - \\eta \\, \\beta^{-1} \\, x^2)}{\\partial x} \\cdot x^2 + (a_0 - \\eta \\, \\beta^{-1} \\, x^2) \\cdot 2 x & \\qquad (g = \\beta^{-1} \\, x^2, \\ \\beta > 0, \\ {\\lVert g \\rVert}_2 \\le G_{\\max}) \\\\\n", - " & = (- \\beta^{-1} \\, \\eta \\cdot 2 x) \\cdot x^2 + (a_0 - \\beta^{-1} \\, \\eta \\, x^2) \\cdot 2 x \\\\\n", - " & = - 4 \\, \\beta^{-1} \\, \\eta \\, x^3 + 2 \\, a_0 \\, x\n", - "\\end{aligned}\n", - "$$" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x.grad = tensor(-12.0000)\n" - ] - } - ], - "source": [ - "net = Net()\n", - "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", - "\n", - "optim_impl = torchopt.combine.chain(\n", - " torchopt.clip.clip_grad_norm(max_norm=2.0),\n", - " torchopt.sgd(lr=1.0, moment_requires_grad=True),\n", - ")\n", - "optim = torchopt.MetaOptimizer(net, optim_impl)\n", - "\n", - "inner_loss = net(x)\n", - "optim.step(inner_loss)\n", - "\n", - "outer_loss = net(x)\n", - "outer_loss.backward()\n", - "# Since `max_norm` is 2 and the gradient is x^2, so the scale = x^2 / 2 = 2^2 / 2 = 2\n", - "# x.grad = - 4 * lr * x^3 / scale + 2 * a_0 * x\n", - "# = - 4 * 1 * 2^3 / 2 + 2 * 1 * 2\n", - "# = -16 + 4\n", - "# = -12\n", - "print(f'x.grad = {x.grad!r}')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Learning Rate Scheduler\n", - "\n", - "TorchOpt also provides implementation of learning rate scheduler, which can be used as:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "functional_adam = torchopt.adam(\n", - " lr=torchopt.schedule.linear_schedule(\n", - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", - " )\n", - ")\n", - "\n", - "adam = torchopt.Adam(\n", - " net.parameters(),\n", - " lr=torchopt.schedule.linear_schedule(\n", - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", - " ),\n", - ")\n", - "\n", - "meta_adam = torchopt.MetaAdam(\n", - " net,\n", - " lr=torchopt.schedule.linear_schedule(\n", - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. Accelerated Optimizer\n", - "\n", - "Users can use accelerated optimizer by setting the `use_accelerated_op=True`. Currently we only support the Adam optimizer." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Check whether the `accelerated_op` is available:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] - } - ], - "source": [ - "torchopt.accelerated_op_available(torch.device('cpu'))" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] + "cells" : [ + {"cell_type" : "markdown", "metadata" : {}, "source" : ["# TorchOpt as Meta-Optimizer"]}, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["[](https://" + "colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/" + "3_Meta_Optimizer.ipynb)"] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["In this tutorial, we will show how to treat TorchOpt as a differentiable " + "optimizer with traditional PyTorch optimization API. In addition, we also " + "provide many other API for easy meta-learning algorithm implementations."] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "## 1. Basic API for Differentiable Optimizer\n", + "\n", + "`MetaOptimizer` is the main class for our differentiable optimizer. Combined with the " + "functional optimizer `torchopt.sgd` and `torchopt.adam` mentioned in the tutorial 1, we " + "can define our high-level API `torchopt.MetaSGD` and `torchopt.MetaAdam`. We will discuss " + "how this combination happens with `torchopt.chain` in Section 3. Let us consider the " + "problem below." + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "Assume a tensor $x$ is a meta-parameter and $a$ is a normal parameters (such as network " + "parameters). We have inner loss $\\mathcal{L}^{\\textrm{in}} = a_0 \\cdot x^2$ and we " + "update $a$ use the gradient $\\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} " + "= x^2$ and $a_1 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial " + "a_0} = a_0 - \\eta \\, x^2$. Then we compute the outer loss $\\mathcal{L}^{\\textrm{out}} " + "= a_1 \\cdot x^2$. So the gradient of outer loss to $x$ would be:\n", + "\n", + "$$\n", + "\\begin{split}\n", + " \\frac{\\partial \\mathcal{L}^{\\textrm{out}}}{\\partial x}\n", + " & = \\frac{\\partial (a_1 \\cdot x^2)}{\\partial x} \\\\\n", + " & = \\frac{\\partial a_1}{\\partial x} \\cdot x^2 + a_1 \\cdot \\frac{\\partial " + "(x^2)}{\\partial x} \\\\\n", + " & = \\frac{\\partial (a_0 - \\eta \\, x^2)}{\\partial x} \\cdot x^2 + (a_0 - \\eta " + "\\, x^2) \\cdot 2 x \\\\\n", + " & = (- \\eta \\cdot 2 x) \\cdot x^2 + (a_0 - \\eta \\, x^2) \\cdot 2 x \\\\\n", + " & = - 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x\n", + "\\end{split}\n", + "$$" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Given the analytical solution above. Let's try to verify it with TorchOpt. 
" + "Define the net work first."] + }, + { + "cell_type" : "code", + "execution_count" : 1, + "metadata" : {}, + "outputs" : [], + "source" : [ + "from IPython.display import display\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torchopt\n", + "\n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", + "\n", + " def forward(self, x):\n", + " return self.a * (x**2)" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Then we declare the network (parameterized by `a`) and the meta-parameter `x`. " + "Do not forget to set flag `requires_grad=True` for `x`."] + }, + { + "cell_type" : "code", + "execution_count" : 2, + "metadata" : {}, + "outputs" : [], + "source" : [ "net = Net()\n", "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)" ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Next we declare the meta-optimizer. Here we show two equivalent ways of " + "defining the meta-optimizer. "] + }, + { + "cell_type" : "code", + "execution_count" : 3, + "metadata" : {}, + "outputs" : [], + "source" : [ + "# Low-level API\n", + "optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0))\n", + "\n", + "# High-level API\n", + "optim = torchopt.MetaSGD(net, lr=1.0)" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : + ["The meta-optimizer takes the network as input and use method `step` to update the " + "network (parameterized by `a`). Finally, we show how a bi-level process works."] + }, + { + "cell_type" : "code", + "execution_count" : 4, + "metadata" : {}, + "outputs" : + [ {"name" : "stdout", "output_type" : "stream", "text" : ["x.grad = tensor(-28.)\n"]} ], + "source" : [ + "inner_loss = net(x)\n", + "optim.step(inner_loss)\n", + "\n", + "outer_loss = net(x)\n", + "outer_loss.backward()\n", + "# x.grad = - 4 * lr * x^3 + 2 * a_0 * x\n", + "# = - 4 * 1 * 2^3 + 2 * 1 * 2\n", + "# = -32 + 4\n", + "# = -28\n", + "print(f'x.grad = {x.grad!r}')" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "### 1.1 Track the Gradient of Momentum\n", + "\n", + "Note that most modern optimizers involve moment term in the gradient update (basically " + "only SGD with `momentum=0` does not involve). We provide an option for user to choose " + "whether to also track the meta-gradient through moment term. The default option is " + "`moment_requires_grad=True`." 
+ ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["- When you do not track the meta-gradient through moment " + "(`moment_requires_grad=False`)"] + }, + { + "cell_type" : "code", + "execution_count" : 5, + "metadata" : {}, + "outputs" : [ + { + "name" : "stdout", + "output_type" : "stream", + "text" : ["\n"] + }, + { + "data" : { + "image/svg+xml" : + "\n\n\n\n\n\n%3\n\n\n\n140447553047184\n\nouter_loss\n()\n\n\n\n140447553041216\n\nMseLossBackward0\n\n\n\n140447553041216->140447553047184\n\n\n\n\n\n140447553042896\n\nMulBackward0\n\n\n\n140447553042896->140447553041216\n\n\n\n\n\n140447553019088\n\nAddBackward0\nstep1.a\n()\n\n\n\n140447553019088->140447553042896\n\n\n\n\n\n140447553041072\n\nAccumulateGrad\n\n\n\n140447553041072->140447553019088\n\n\n\n\n\n140447553043664\n\nMulBackward0\n\n\n\n140447553041072->140447553043664\n\n\n\n\n\n140447553045344\n\nstep0.a\n()\n\n\n\n140447553045344->140447553041072\n\n\n\n\n\n140447553041120\n\nMulBackward0\n\n\n\n140447553041120->140447553019088\n\n\n\n\n\n140447553043040\n\nDivBackward0\n\n\n\n140447553043040->140447553041120\n\n\n\n\n\n140447553043184\n\nDivBackward0\n\n\n\n140447553043184->140447553043040\n\n\n\n\n\n140447553043328\n\nAddBackward0\n\n\n\n140447553043328->140447553043184\n\n\n\n\n\n140447553043424\n\nMulBackward0\n\n\n\n140447553043424->140447553043328\n\n\n\n\n\n140447553043856\n\nAddcmulBackward0\n\n\n\n140447553043424->140447553043856\n\n\n\n\n\n140447553043424->140447553043856\n\n\n\n\n\n140447553043520\n\nMseLossBackwardBackward0\n\n\n\n140447553043520->140447553043424\n\n\n\n\n\n140447553043664->140447553043520\n\n\n\n\n\n140447553043472\n\nPowBackward0\n\n\n\n140447553043472->140447553043424\n\n\n\n\n\n140447553043472->140447553043664\n\n\n\n\n\n140447553043808\n\nAccumulateGrad\n\n\n\n140447553043808->140447553043472\n\n\n\n\n\n140447553041264\n\nPowBackward0\n\n\n\n140447553043808->140447553041264\n\n\n\n\n\n140447553045584\n\nx\n()\n\n\n\n140447553045584->140447553043808\n\n\n\n\n\n140447553043136\n\nAddBackward0\n\n\n\n140447553043136->140447553043040\n\n\n\n\n\n140447553043232\n\nSqrtBackward0\n\n\n\n140447553043232->140447553043136\n\n\n\n\n\n140447553043760\n\nAddBackward0\n\n\n\n140447553043760->140447553043232\n\n\n\n\n\n140447553043904\n\nDivBackward0\n\n\n\n140447553043904->140447553043760\n\n\n\n\n\n140447553043856->140447553043904\n\n\n\n\n\n140447553041264->140447553042896\n\n\n\n\n\n" + }, + "metadata" : {}, + "output_type" : "display_data" + } + ], + "source" : [ + "net = Net()\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", + "y = torch.tensor(1.0)\n", + "\n", + "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=False)\n", + "\n", + "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, " + "visual_prefix='step0.')\n", + "inner_loss = F.mse_loss(net(x), y)\n", + "optim.step(inner_loss)\n", + "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, " + "visual_prefix='step1.')\n", + "\n", + "outer_loss = F.mse_loss(net(x), y)\n", + "display(\n", + " torchopt.visual.make_dot(\n", + " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': " + "outer_loss}]\n", + " )\n", + ")" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["- When you track the meta-gradient through moment (`moment_requires_grad=True`, " + "default for `torchopt.MetaAdam`)"] + }, + { + "cell_type" : "code", + "execution_count" : 6, + "metadata" : {}, + "outputs" : [ + { + "name" : "stdout", + "output_type" : "stream", + 
"text" : ["\n"] + }, + { + "data" : { + "image/svg+xml" : + "\n\n\n\n\n\n%3\n\n\n\n140447553148704\n\nouter_loss\n()\n\n\n\n140447553041024\n\nMseLossBackward0\n\n\n\n140447553041024->140447553148704\n\n\n\n\n\n140447553043424\n\nMulBackward0\n\n\n\n140447553043424->140447553041024\n\n\n\n\n\n140450536407152\n\nAddBackward0\nstep1.a\n()\n\n\n\n140450536407152->140447553043424\n\n\n\n\n\n140447553041264\n\nAccumulateGrad\n\n\n\n140447553041264->140450536407152\n\n\n\n\n\n140447553019232\n\nMulBackward0\n\n\n\n140447553041264->140447553019232\n\n\n\n\n\n140447553148064\n\nstep0.a\n()\n\n\n\n140447553148064->140447553041264\n\n\n\n\n\n140447553041216\n\nMulBackward0\n\n\n\n140447553041216->140450536407152\n\n\n\n\n\n140447553041312\n\nDivBackward0\n\n\n\n140447553041312->140447553041216\n\n\n\n\n\n140447553041408\n\nDivBackward0\n\n\n\n140447553041408->140447553041312\n\n\n\n\n\n140447553043376\n\nAddBackward0\n\n\n\n140447553043376->140447553041408\n\n\n\n\n\n140447553041168\n\nMulBackward0\n\n\n\n140447553041168->140447553043376\n\n\n\n\n\n140447553042272\n\nAccumulateGrad\n\n\n\n140447553042272->140447553041168\n\n\n\n\n\n140450290826352\n\n()\n\n\n\n140450290826352->140447553042272\n\n\n\n\n\n140447553044432\n\nMulBackward0\n\n\n\n140447553044432->140447553043376\n\n\n\n\n\n140447553018320\n\nAddcmulBackward0\n\n\n\n140447553044432->140447553018320\n\n\n\n\n\n140447553044432->140447553018320\n\n\n\n\n\n140447553042080\n\nMseLossBackwardBackward0\n\n\n\n140447553042080->140447553044432\n\n\n\n\n\n140447553019232->140447553042080\n\n\n\n\n\n140447553019088\n\nPowBackward0\n\n\n\n140447553019088->140447553044432\n\n\n\n\n\n140447553019088->140447553019232\n\n\n\n\n\n140447553018464\n\nAccumulateGrad\n\n\n\n140447553018464->140447553019088\n\n\n\n\n\n140447553043328\n\nPowBackward0\n\n\n\n140447553018464->140447553043328\n\n\n\n\n\n140447553148144\n\nx\n()\n\n\n\n140447553148144->140447553018464\n\n\n\n\n\n140447553041456\n\nAddBackward0\n\n\n\n140447553041456->140447553041312\n\n\n\n\n\n140447553041360\n\nSqrtBackward0\n\n\n\n140447553041360->140447553041456\n\n\n\n\n\n140447553015920\n\nAddBackward0\n\n\n\n140447553015920->140447553041360\n\n\n\n\n\n140447553018560\n\nDivBackward0\n\n\n\n140447553018560->140447553015920\n\n\n\n\n\n140447553018320->140447553018560\n\n\n\n\n\n140447553018272\n\nMulBackward0\n\n\n\n140447553018272->140447553018320\n\n\n\n\n\n140447553018944\n\nAccumulateGrad\n\n\n\n140447553018944->140447553018272\n\n\n\n\n\n140450290824272\n\n()\n\n\n\n140450290824272->140447553018944\n\n\n\n\n\n140447553043328->140447553043424\n\n\n\n\n\n" + }, + "metadata" : {}, + "output_type" : "display_data" + } + ], + "source" : [ + "net = Net()\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", + "y = torch.tensor(1.0)\n", + "\n", + "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True)\n", + "\n", + "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, " + "visual_prefix='step0.')\n", + "inner_loss = F.mse_loss(net(x), y)\n", + "optim.step(inner_loss)\n", + "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, " + "visual_prefix='step1.')\n", + "\n", + "outer_loss = F.mse_loss(net(x), y)\n", + "display(\n", + " torchopt.visual.make_dot(\n", + " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': " + "outer_loss}]\n", + " )\n", + ")" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["We can see that the additional moment terms are added into the computational " + "graph when we set 
`moment_requires_grad=True`."] + }, + {"cell_type" : "markdown", "metadata" : {}, "source" : ["## 2. Extract and Recover"]}, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "### 2.1 Basic API\n", + "\n", + "We observe that how to reinitialize the inner-loop parameter in a new bi-level process " + "vary in different meta-learning algorithms. For instance, in algorithm like " + "Model-Agnostic Meta-Learning (MAML) " + "([arXiv:1703.03400](https://arxiv.org/abs/1703.03400)), every time a new task comes, we " + "need to reset the parameters to the initial ones. In other cases such as Meta-Gradient " + "Reinforcement Learning (MGRL) ([arXiv:1805.09801](https://arxiv.org/abs/1805.09801)), the " + "inner-loop network parameter just inherit previous updated parameter to continue the new " + "bi-level process.\n", + "\n", + "We provide the `torchopt.extract_state_dict` and `torchopt.recover_state_dict` functions " + "to extract and restore the state of network and optimizer. By default, the extracted " + "state dictionary is a reference (this design is for accumulating gradient of multi-task " + "batch training, MAML for example). You can also set `by='copy'` to extract the copy of " + "the state dictionary or set `by='deepcopy'` to have a detached copy." + ] + }, + { + "cell_type" : "code", + "execution_count" : 7, + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : [ + "a = tensor(-1.0000, grad_fn=)\n", + "a = tensor(-1.0000, grad_fn=)\n" + ] + } ], + "source" : [ + "net = Net()\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", + "\n", + "optim = torchopt.MetaAdam(net, lr=1.0)\n", + "\n", + "# Get the reference of state dictionary\n", + "init_net_state = torchopt.extract_state_dict(net, by='reference')\n", + "init_optim_state = torchopt.extract_state_dict(optim, by='reference')\n", + "# If set `detach_buffers=True`, the parameters are referenced as references while buffers " + "are detached copies\n", + "init_net_state = torchopt.extract_state_dict(net, by='reference', detach_buffers=True)\n", + "\n", + "# Set `copy` to get the copy of the state dictionary\n", + "init_net_state_copy = torchopt.extract_state_dict(net, by='copy')\n", + "init_optim_state_copy = torchopt.extract_state_dict(optim, by='copy')\n", + "\n", + "# Set `deepcopy` to get the detached copy of state dictionary\n", + "init_net_state_deepcopy = torchopt.extract_state_dict(net, by='deepcopy')\n", + "init_optim_state_deepcopy = torchopt.extract_state_dict(optim, by='deepcopy')\n", + "\n", + "# Conduct 2 inner-loop optimization\n", + "for i in range(2):\n", + " inner_loss = net(x)\n", + " optim.step(inner_loss)\n", + "\n", + "print(f'a = {net.a!r}')\n", + "\n", + "# Recover and reconduct 2 inner-loop optimization\n", + "torchopt.recover_state_dict(net, init_net_state)\n", + "torchopt.recover_state_dict(optim, init_optim_state)\n", + "\n", + "for i in range(2):\n", + " inner_loss = net(x)\n", + " optim.step(inner_loss)\n", + "\n", + "print(f'a = {net.a!r}') # the same result" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "### 2.2 Multi-task Example with `extract_state_dict` and `recover_state_dict`\n", + "\n", + "Let's move to another more complex setting. Meta-Learning algorithms always fix network " + "on several different tasks and accumulate outer loss of each task to the meta-gradient.\n", + "\n", + "Assume $x$ is a meta-parameter and $a$ is a normal parameter. 
We firstly update $a$ use " + "inner loss $\\mathcal{L}_1^{\\textrm{in}} = a_0 \\cdot x^2$ to $a_1$. Then we use $a_1$ " + "to compute the outer loss $\\mathcal{L}_1^{\\textrm{out}} = a_1 \\cdot x^2$ and " + "backpropagate it. Then we use $a_0$ to compute the inner loss " + "$\\mathcal{L}_2^{\\textrm{in}} = a_0 \\cdot x$ and update $a_0$ to $a_2 = a_0 - \\eta \\, " + "\\frac{\\partial \\mathcal{L}_2^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x$. Then " + "we compute outer loss $\\mathcal{L}_2^{\\textrm{out}} = a_2 \\cdot x$ and backpropagate " + "it. So the accumulated meta-gradient would be:\n", + "\n", + "$$\n", + "\\begin{split}\n", + " \\frac{\\partial \\mathcal{L}_1^{\\textrm{out}}}{\\partial x} + \\frac{\\partial " + "\\mathcal{L}_2^{\\textrm{out}}}{\\partial x}\n", + " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + \\frac{\\partial (a_2 \\cdot " + "x)}{\\partial x} \\\\\n", + " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + (\\frac{\\partial a_2}{\\partial x} " + "\\cdot x + a_2) \\\\\n", + " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + [\\frac{\\partial (a_0 - \\eta \\, " + "x)}{\\partial x} \\cdot x + (a_0 - \\eta \\, x)] \\\\\n", + " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + [(- \\eta) \\cdot x + (a_0 - \\eta " + "\\, x)] \\\\\n", + " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + (- 2 \\, \\eta \\, x + a_0)\n", + "\\end{split}\n", + "$$" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Let's define the network and variables first."] + }, + { + "cell_type" : "code", + "execution_count" : 8, + "metadata" : {}, + "outputs" : [], + "source" : [ + "class Net2Tasks(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", + "\n", + " def task1(self, x):\n", + " return self.a * x**2\n", + "\n", + " def task2(self, x):\n", + " return self.a * x\n", + "\n", + "\n", + "net = Net2Tasks()\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", + "\n", + "optim = torchopt.MetaSGD(net, lr=1.0)" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Once we call `step` method of `MetaOptimizer`, the parameters of the network " + "would be changed. We should use `torchopt.extract_state_dict` to extract state " + "and use `torchopt.recover_state_dict` to recover the state. 
Note that if we use " + "optimizers that keep momentum buffers, we should extract and recover those buffers as " + "well. Vanilla SGD has no momentum buffers, so the calls `init_optim_state = " + "torchopt.extract_state_dict(optim)` and `torchopt.recover_state_dict(optim, " + "init_optim_state)` below have no effect."] + }, + { + "cell_type" : "code", + "execution_count" : 9, + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : [ + "init_optim_state = ((EmptyState(),),)\n", + "Task 1: x.grad = tensor(-28.)\n", + "Accumulated: x.grad = tensor(-31.)\n" + ] + } ], + "source" : [ + "# Get a reference to the state dictionary\n", + "init_net_state = torchopt.extract_state_dict(net, by='reference')\n", + "init_optim_state = torchopt.extract_state_dict(optim, by='reference')\n", + "# The state is empty for the vanilla SGD optimizer\n", + "print(f'init_optim_state = {init_optim_state!r}')\n", + "\n", + "inner_loss_1 = net.task1(x)\n", + "optim.step(inner_loss_1)\n", + "outer_loss_1 = net.task1(x)\n", + "outer_loss_1.backward()\n", + "print(f'Task 1: x.grad = {x.grad!r}')\n", + "\n", + "torchopt.recover_state_dict(net, init_net_state)\n", + "torchopt.recover_state_dict(optim, init_optim_state)\n", + "inner_loss_2 = net.task2(x)\n", + "optim.step(inner_loss_2)\n", + "outer_loss_2 = net.task2(x)\n", + "outer_loss_2.backward()\n", + "\n", + "# `extract_state_dict` extracts references, so the gradient accumulates\n", + "# x.grad = (- 4 * lr * x^3 + 2 * a_0 * x) + (- 2 * lr * x + a_0)\n", + "# = (- 4 * 1 * 2^3 + 2 * 1 * 2) + (- 2 * 1 * 2 + 1)\n", + "# = -28 - 3\n", + "# = -31\n", + "print(f'Accumulated: x.grad = {x.grad!r}')" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["## 3. Gradient Transformation in `MetaOptimizer`"] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "We can also use gradient normalization tricks in our `MetaOptimizer`. In fact, " + "`MetaOptimizer` descendants like `MetaSGD` are specializations of `MetaOptimizer`. " + "Specifically, `MetaSGD(net, lr=1.)` is `MetaOptimizer(net, alias.sgd(lr=1., " + "moment_requires_grad=True))`, where the flag `moment_requires_grad=True` means the momentum " + "buffers are created with `requires_grad=True`, so they also become part of the " + "computation graph.\n", + "\n", + "In the design of TorchOpt, we treat these functions as building blocks that are composed with `combine.chain`. 
" + "So we can build our own chain like `combine.chain(clip.clip_grad_norm(max_norm=1.), " + "sgd(lr=1., requires_grad=True))` to clip the gradient and update parameters using " + "`sgd`.\n", + "\n", + "$$\n", + "\\begin{aligned}\n", + " \\frac{\\partial \\mathcal{L}^{\\textrm{out}}}{\\partial x}\n", + " & = \\frac{\\partial (a_1 \\cdot x^2)}{\\partial x} \\\\\n", + " & = \\frac{\\partial a_1}{\\partial x} \\cdot x^2 + a_1 \\cdot \\frac{\\partial " + "(x^2)}{\\partial x} \\\\\n", + " & = \\frac{\\partial (a_0 - \\eta \\, g)}{\\partial x} \\cdot x^2 + (a_0 - \\eta \\, " + "g) \\cdot 2 x & \\qquad (g \\propto \\frac{\\partial " + "\\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = x^2, \\ {\\lVert g \\rVert}_2 \\le " + "G_{\\max}) \\\\\n", + " & = \\frac{\\partial (a_0 - \\eta \\, \\beta^{-1} \\, x^2)}{\\partial x} \\cdot x^2 + " + "(a_0 - \\eta \\, \\beta^{-1} \\, x^2) \\cdot 2 x & \\qquad (g = \\beta^{-1} \\, x^2, \\ " + " \\beta > 0, \\ {\\lVert g \\rVert}_2 \\le G_{\\max}) \\\\\n", + " & = (- \\beta^{-1} \\, \\eta \\cdot 2 x) \\cdot x^2 + (a_0 - \\beta^{-1} \\, \\eta " + "\\, x^2) \\cdot 2 x \\\\\n", + " & = - 4 \\, \\beta^{-1} \\, \\eta \\, x^3 + 2 \\, a_0 \\, x\n", + "\\end{aligned}\n", + "$$" + ] + }, + { + "cell_type" : "code", + "execution_count" : 10, + "metadata" : {}, + "outputs" : [ + {"name" : "stdout", "output_type" : "stream", "text" : ["x.grad = tensor(-12.0000)\n"]} + ], + "source" : [ + "net = Net()\n", + "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", + "\n", + "optim_impl = torchopt.combine.chain(\n", + " torchopt.clip.clip_grad_norm(max_norm=2.0),\n", + " torchopt.sgd(lr=1.0, moment_requires_grad=True),\n", + ")\n", + "optim = torchopt.MetaOptimizer(net, optim_impl)\n", + "\n", + "inner_loss = net(x)\n", + "optim.step(inner_loss)\n", + "\n", + "outer_loss = net(x)\n", + "outer_loss.backward()\n", + "# Since `max_norm` is 2 and the gradient is x^2, so the scale = x^2 / 2 = 2^2 / 2 = 2\n", + "# x.grad = - 4 * lr * x^3 / scale + 2 * a_0 * x\n", + "# = - 4 * 1 * 2^3 / 2 + 2 * 1 * 2\n", + "# = -16 + 4\n", + "# = -12\n", + "print(f'x.grad = {x.grad!r}')" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "## 4. Learning Rate Scheduler\n", + "\n", + "TorchOpt also provides implementation of learning rate scheduler, which can be used as:" + ] + }, + { + "cell_type" : "code", + "execution_count" : 11, + "metadata" : {}, + "outputs" : [], + "source" : [ + "functional_adam = torchopt.adam(\n", + " lr=torchopt.schedule.linear_schedule(\n", + " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " )\n", + ")\n", + "\n", + "adam = torchopt.Adam(\n", + " net.parameters(),\n", + " lr=torchopt.schedule.linear_schedule(\n", + " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " ),\n", + ")\n", + "\n", + "meta_adam = torchopt.MetaAdam(\n", + " net,\n", + " lr=torchopt.schedule.linear_schedule(\n", + " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " ),\n", + ")" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "## 5. Accelerated Optimizer\n", + "\n", + "Users can use accelerated optimizer by setting the `use_accelerated_op=True`. Currently " + "we only support the Adam optimizer." 
+ ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Check whether the `accelerated_op` is available:"] + }, + { + "cell_type" : "code", + "execution_count" : 12, + "metadata" : {}, + "outputs" : [ {"name" : "stdout", "output_type" : "stream", "text" : ["True\n"]} ], + "source" : ["torchopt.accelerated_op_available(torch.device('cpu'))"] + }, + { + "cell_type" : "code", + "execution_count" : 13, + "metadata" : {}, + "outputs" : [ {"name" : "stdout", "output_type" : "stream", "text" : ["True\n"]} ], + "source" : ["torchopt.accelerated_op_available(torch.device('cuda'))"] + }, + { + "cell_type" : "code", + "execution_count" : 14, + "metadata" : {}, + "outputs" : [ + { + "name" : "stdout", + "output_type" : "stream", + "text" : ["\n"] + }, + { + "data" : { + "image/svg+xml" : + "\n\n\n\n\n\n%3\n\n\n\n140450290825712\n\nouter_loss\n()\n\n\n\n140450533650240\n\nMseLossBackward0\n\n\n\n140450533650240->140450290825712\n\n\n\n\n\n140450533648560\n\nMulBackward0\n\n\n\n140450533648560->140450533650240\n\n\n\n\n\n140450533647456\n\nAddBackward0\nstep1.a\n()\n\n\n\n140450533647456->140450533648560\n\n\n\n\n\n140447435136640\n\nAccumulateGrad\n\n\n\n140447435136640->140450533647456\n\n\n\n\n\n140450533648416\n\nMulBackward0\n\n\n\n140447435136640->140450533648416\n\n\n\n\n\n140447435236512\n\nstep0.a\n()\n\n\n\n140447435236512->140447435136640\n\n\n\n\n\n140447435136688\n\nMulBackward0\n\n\n\n140447435136688->140450533647456\n\n\n\n\n\n140447554132144\n\nUpdatesOpBackward\n\n\n\n140447554132144->140447435136688\n\n\n\n\n\n140447554131664\n\nMuOpBackward\n\n\n\n140447554131664->140447554132144\n\n\n\n\n\n140447435134816\n\nMulBackward0\n\n\n\n140447435134816->140447554131664\n\n\n\n\n\n140447554131904\n\nNuOpBackward\n\n\n\n140447435134816->140447554131904\n\n\n\n\n\n140450533648992\n\nMseLossBackwardBackward0\n\n\n\n140450533648992->140447435134816\n\n\n\n\n\n140450533648416->140450533648992\n\n\n\n\n\n140450533646448\n\nPowBackward0\n\n\n\n140450533646448->140447435134816\n\n\n\n\n\n140450533646448->140450533648416\n\n\n\n\n\n140447553018176\n\nAccumulateGrad\n\n\n\n140447553018176->140450533646448\n\n\n\n\n\n140447435135536\n\nPowBackward0\n\n\n\n140447553018176->140447435135536\n\n\n\n\n\n140447553045424\n\nx\n()\n\n\n\n140447553045424->140447553018176\n\n\n\n\n\n140447435136592\n\nAccumulateGrad\n\n\n\n140447435136592->140447554131664\n\n\n\n\n\n140447552973856\n\n()\n\n\n\n140447552973856->140447554131664\n\n\n\n\n\n140447552973856->140447435136592\n\n\n\n\n\n140447553044544\n\n()\n\n\n\n140447553044544->140447554131664\n\n\n\n\n\n140447553044544->140447554131904\n\n\n\n\n\n140447554131904->140447554132144\n\n\n\n\n\n140450533648896\n\nAccumulateGrad\n\n\n\n140450533648896->140447554131904\n\n\n\n\n\n140447435236752\n\n()\n\n\n\n140447435236752->140447554131904\n\n\n\n\n\n140447435236752->140450533648896\n\n\n\n\n\n140447553045904\n\n()\n\n\n\n140447553045904->140447554132144\n\n\n\n\n\n140447435237152\n\n()\n\n\n\n140447435237152->140447554132144\n\n\n\n\n\n140447435237232\n\n()\n\n\n\n140447435237232->140447554132144\n\n\n\n\n\n140447435135536->140450533648560\n\n\n\n\n\n" + }, + "metadata" : {}, + "output_type" : "display_data" + } + ], + "source" : [ + "net = Net().to(device='cuda')\n", + "x = nn.Parameter(torch.tensor(2.0, device=torch.device('cuda')), requires_grad=True)\n", + "y = torch.tensor(1.0, device=torch.device('cuda'))\n", + "\n", + "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True, " + "use_accelerated_op=True)\n", + "\n", + "net_state_0 = 
torchopt.extract_state_dict(\n", + "    net, by='reference', enable_visual=True, visual_prefix='step0.'\n", + ")\n", + "inner_loss = F.mse_loss(net(x), y)\n", + "optim.step(inner_loss)\n", + "net_state_1 = torchopt.extract_state_dict(\n", + "    net, by='reference', enable_visual=True, visual_prefix='step1.'\n", + ")\n", + "\n", + "outer_loss = F.mse_loss(net(x), y)\n", + "display(\n", + "    torchopt.visual.make_dot(\n", + "        outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': " + "outer_loss}]\n", + "    )\n", + ")" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "## 6. Known Issues\n", + "\n", + "Here we record some common issues users face when using the meta-optimizer." + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "**1. Getting a `NaN` error when using `MetaAdam` or other meta-optimizers.**\n", + "\n", + "The `NaN` error is caused by the numerical instability of `Adam` in meta-learning. " + "There is an `sqrt` operation in `Adam`'s update rule. Backpropagating through " + "the `Adam` operator introduces the second derivative of the `sqrt` operation, which is " + "not numerically stable, i.e. ${\\left. \\frac{d^2 \\sqrt{x}}{{dx}^2} \\right\\rvert}_{x = " + "0} = \\texttt{NaN}$. You can also refer to issue " + "[facebookresearch/higher#125](https://github.com/facebookresearch/higher/issues/125).\n", + "\n", + "For this problem, TorchOpt has two recommended solutions.\n", + "\n", + "* Fold the `sqrt` operation into the whole update equation and compute the derivative of " + "the output with respect to the input manually, so that the second derivative of the bare " + "`sqrt` operation is eliminated. You can achieve this by setting the flag " + "`use_accelerated_op=True`; follow the instructions in the notebook " + "[Functional Optimizer](1_Functional_Optimizer.ipynb) and the Meta-Optimizer sections above." + ] + }, + { + "cell_type" : "code", + "execution_count" : 15, + "metadata" : {}, + "outputs" : [], + "source" : ["inner_optim = torchopt.MetaAdam(net, lr=1.0, use_accelerated_op=True)"] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["* Register a hook on the first-order gradients. During backpropagation, the " + "`NaN` gradients will be set to 0, which has a similar effect to the first " + "solution but is much slower."] + }, + { + "cell_type" : "code", + "execution_count" : 16, + "metadata" : {}, + "outputs" : [], + "source" : [ + "impl = torchopt.chain(torchopt.hook.register_hook(torchopt.hook.zero_nan_hook), " + "torchopt.adam(1e-1))\n", + "inner_optim = torchopt.MetaOptimizer(net, impl)" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "**2. Getting a `Trying to backward through the graph a second time` error when conducting " + "multiple meta-optimization steps.**\n", + "\n", + "Please refer to the tutorial notebook [Stop Gradient](4_Stop_Gradient.ipynb) for more " + "guidance."
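Known issue 2 is usually resolved with `torchopt.stop_gradient`, which the Stop Gradient tutorial covers in detail. The following is a hedged sketch of the MGRL-style pattern behind that issue: after every outer update, detach the network parameters and the meta-optimizer state from the old graph before starting the next bi-level iteration. It reuses the toy `Net` from Section 1 and assumes `torchopt.stop_gradient` accepts both an `nn.Module` and a `MetaOptimizer`; treat it as a sketch rather than the library's reference usage.

```python
# Hedged sketch: consecutive bi-level iterations with `torchopt.stop_gradient`.
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def forward(self, x):
        return self.a * (x**2)


net = Net()
x = nn.Parameter(torch.tensor(2.0), requires_grad=True)
y = torch.tensor(1.0)

inner_optim = torchopt.MetaSGD(net, lr=0.01)
meta_optim = torch.optim.SGD([x], lr=0.01)

for iteration in range(3):  # three consecutive bi-level iterations
    inner_loss = F.mse_loss(net(x), y)
    inner_optim.step(inner_loss)  # differentiable inner update

    outer_loss = F.mse_loss(net(x), y)
    meta_optim.zero_grad()
    outer_loss.backward()
    meta_optim.step()

    # Without these two calls, the next `outer_loss.backward()` would try to
    # backpropagate through the previous (already freed) graph and raise the
    # "Trying to backward through the graph a second time" error.
    torchopt.stop_gradient(net)
    torchopt.stop_gradient(inner_optim)

    print(f'iteration {iteration}: x = {x.detach().item():.4f}')
```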
+ ] } - ], - "source": [ - "torchopt.accelerated_op_available(torch.device('cuda'))" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140450290825712\n\nouter_loss\n()\n\n\n\n140450533650240\n\nMseLossBackward0\n\n\n\n140450533650240->140450290825712\n\n\n\n\n\n140450533648560\n\nMulBackward0\n\n\n\n140450533648560->140450533650240\n\n\n\n\n\n140450533647456\n\nAddBackward0\nstep1.a\n()\n\n\n\n140450533647456->140450533648560\n\n\n\n\n\n140447435136640\n\nAccumulateGrad\n\n\n\n140447435136640->140450533647456\n\n\n\n\n\n140450533648416\n\nMulBackward0\n\n\n\n140447435136640->140450533648416\n\n\n\n\n\n140447435236512\n\nstep0.a\n()\n\n\n\n140447435236512->140447435136640\n\n\n\n\n\n140447435136688\n\nMulBackward0\n\n\n\n140447435136688->140450533647456\n\n\n\n\n\n140447554132144\n\nUpdatesOpBackward\n\n\n\n140447554132144->140447435136688\n\n\n\n\n\n140447554131664\n\nMuOpBackward\n\n\n\n140447554131664->140447554132144\n\n\n\n\n\n140447435134816\n\nMulBackward0\n\n\n\n140447435134816->140447554131664\n\n\n\n\n\n140447554131904\n\nNuOpBackward\n\n\n\n140447435134816->140447554131904\n\n\n\n\n\n140450533648992\n\nMseLossBackwardBackward0\n\n\n\n140450533648992->140447435134816\n\n\n\n\n\n140450533648416->140450533648992\n\n\n\n\n\n140450533646448\n\nPowBackward0\n\n\n\n140450533646448->140447435134816\n\n\n\n\n\n140450533646448->140450533648416\n\n\n\n\n\n140447553018176\n\nAccumulateGrad\n\n\n\n140447553018176->140450533646448\n\n\n\n\n\n140447435135536\n\nPowBackward0\n\n\n\n140447553018176->140447435135536\n\n\n\n\n\n140447553045424\n\nx\n()\n\n\n\n140447553045424->140447553018176\n\n\n\n\n\n140447435136592\n\nAccumulateGrad\n\n\n\n140447435136592->140447554131664\n\n\n\n\n\n140447552973856\n\n()\n\n\n\n140447552973856->140447554131664\n\n\n\n\n\n140447552973856->140447435136592\n\n\n\n\n\n140447553044544\n\n()\n\n\n\n140447553044544->140447554131664\n\n\n\n\n\n140447553044544->140447554131904\n\n\n\n\n\n140447554131904->140447554132144\n\n\n\n\n\n140450533648896\n\nAccumulateGrad\n\n\n\n140450533648896->140447554131904\n\n\n\n\n\n140447435236752\n\n()\n\n\n\n140447435236752->140447554131904\n\n\n\n\n\n140447435236752->140450533648896\n\n\n\n\n\n140447553045904\n\n()\n\n\n\n140447553045904->140447554132144\n\n\n\n\n\n140447435237152\n\n()\n\n\n\n140447435237152->140447554132144\n\n\n\n\n\n140447435237232\n\n()\n\n\n\n140447435237232->140447554132144\n\n\n\n\n\n140447435135536->140450533648560\n\n\n\n\n\n" - }, - "metadata": {}, - "output_type": "display_data" + ], + "metadata" : { + "kernelspec" : + {"display_name" : "Python 3 (ipykernel)", "language" : "python", "name" : "python3"}, + "language_info" : { + "codemirror_mode" : {"name" : "ipython", "version" : 3}, + "file_extension" : ".py", + "mimetype" : "text/x-python", + "name" : "python", + "nbconvert_exporter" : "python", + "pygments_lexer" : "ipython3", + "version" : "3.9.15" + }, + "vscode" : { + "interpreter" : + {"hash" : "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da"} } - ], - "source": [ - "net = Net().to(device='cuda')\n", - "x = nn.Parameter(torch.tensor(2.0, device=torch.device('cuda')), requires_grad=True)\n", - "y = torch.tensor(1.0, device=torch.device('cuda'))\n", - "\n", - "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True, use_accelerated_op=True)\n", - "\n", - "net_state_0 = 
torchopt.extract_state_dict(\n", - " net, by='reference', enable_visual=True, visual_prefix='step0.'\n", - ")\n", - "inner_loss = F.mse_loss(net(x), y)\n", - "optim.step(inner_loss)\n", - "net_state_1 = torchopt.extract_state_dict(\n", - " net, by='reference', enable_visual=True, visual_prefix='step1.'\n", - ")\n", - "\n", - "outer_loss = F.mse_loss(net(x), y)\n", - "display(\n", - " torchopt.visual.make_dot(\n", - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Known Issues\n", - "\n", - "Here we record some common issues faced by users when using the meta-optimizer." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**1. Get `NaN` error when using `MetaAdam` or other meta-optimizers.**\n", - "\n", - "The `NaN` error is because of the numerical instability of the `Adam` in meta-learning. There exist an `sqrt` operation in `Adam`'s computation process. Backpropogating through the `Adam` operator introduces the second derivation of the `sqrt` operation, which is not numerical stable, i.e. ${\\left. \\frac{d^2 \\sqrt{x}}{{dx}^2} \\right\\rvert}_{x = 0} = \\texttt{NaN}$. You can also refer to issue [facebookresearch/higher#125](https://github.com/facebookresearch/higher/issues/125).\n", - "\n", - "For this problem, TorchOpt have two recommended solutions.\n", - "\n", - "* Put the `sqrt` operation into the whole equation, and compute the derivation of the output to the input manually. The second derivation of the `sqrt` operation will be eliminated. You can achieve this by setting the flag `use_accelerated_op=True`, you can follow the instructions in notebook [Functional Optimizer](1_Functional_Optimizer.ipynb) and Meta-Optimizer." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "inner_optim = torchopt.MetaAdam(net, lr=1.0, use_accelerated_op=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "* Register hook to the first-order gradients. During the backpropagation, the NaN gradients will be set to 0, which will have a similar effect to the first solution but much slower. " - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "impl = torchopt.chain(torchopt.hook.register_hook(torchopt.hook.zero_nan_hook), torchopt.adam(1e-1))\n", - "inner_optim = torchopt.MetaOptimizer(net, impl)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**2. Get `Trying to backward through the graph a second time` error when conducting multiple meta-optimization.**\n", - "\n", - "Please refer to the tutorial notebook [Stop Gradient](4_Stop_Gradient.ipynb) for more guidance." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" }, - "vscode": { - "interpreter": { - "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat" : 4, + "nbformat_minor" : 4 } diff --git a/tutorials/4_Stop_Gradient.ipynb b/tutorials/4_Stop_Gradient.ipynb index d8c24bc6..e497a9d1 100644 --- a/tutorials/4_Stop_Gradient.ipynb +++ b/tutorials/4_Stop_Gradient.ipynb @@ -1,521 +1,1888 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# `torchopt.stop_gradient` in Meta-Learning" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/4_Stop_Gradient.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this tutorial, we will illustrate the usage of `torchopt.stop_gradient` with a meta-learning example. We use `torchopt.visual` to help us visualize what is going on in automatic differentiation. Firstly, we define a simple network and the objective function for inner- and outer- optimization." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from IPython.display import display\n", - "\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "import torchopt\n", - "\n", - "\n", - "class Net(nn.Module):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, 1, bias=True)\n", - "\n", - " def forward(self, x):\n", - " return self.fc(x)\n", - "\n", - "\n", - "loss_fn = F.mse_loss" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We define the input `x` and output `y`. `y` will be served as the regression target in the following code." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "batch_size = 64\n", - "dim = 16\n", - "\n", - "x = torch.randn((batch_size, dim))\n", - "y = torch.zeros((batch_size, 1))\n", - "net = Net(dim)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let us define the meta-parameter, we use `MetaSGD` as the inner-loop optimizer and `Adam` as the outer-loop optimizer. " - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "meta_parameter = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", - "\n", - "optim = torchopt.MetaSGD(net, lr=1e-1)\n", - "meta_optim = torch.optim.Adam([meta_parameter], lr=1e-1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Define the inner-loop optimization and visualize the inner-loop forward gradient flow." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ + "cells" : [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "inner loss: 0.3472\n", - "\n" - ] + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["# `torchopt.stop_gradient` in Meta-Learning"] }, { - "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140025091550880\n\ninner_loss\n()\n\n\n\n140028156253184\n\nMseLossBackward0\n\n\n\n140028156253184->140025091550880\n\n\n\n\n\n140028156436736\n\nAddmmBackward0\n\n\n\n140028156436736->140028156253184\n\n\n\n\n\n140025091526416\n\nAccumulateGrad\n\n\n\n140025091526416->140028156436736\n\n\n\n\n\n140028155952000\n\nstep0.fc.bias\n(1)\n\n\n\n140028155952000->140025091526416\n\n\n\n\n\n140025091525408\n\nTBackward0\n\n\n\n140025091525408->140028156436736\n\n\n\n\n\n140025091526224\n\nAccumulateGrad\n\n\n\n140025091526224->140025091525408\n\n\n\n\n\n140028155952880\n\nstep0.fc.weight\n(1, 16)\n\n\n\n140028155952880->140025091526224\n\n\n\n\n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "init_net_state = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", - "\n", - "# inner loss\n", - "inner_loss = loss_fn(net(x), y)\n", - "\n", - "print(f'inner loss: {inner_loss:.4f}')\n", - "display(torchopt.visual.make_dot(inner_loss, params=(init_net_state, {'inner_loss': inner_loss})))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Conduct inner-loop optimization with `MetaSGD`, here the meta-parameter is served as a factor controlling the scale of inner-loop loss." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "# inner-step optimization\n", - "loss = inner_loss * meta_parameter\n", - "optim.step(loss)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We compute the outer loss and draw the full computation graph of the first bi-level process. In this graph, three main parts are included.\n", - "\n", - "- Inner-loop: forward process and inner-loss calculation\n", - "- Inner-loop optimization: `MetaSGD` optimization step given inner-loss\n", - "- Outer-loop: forward process and outer-loss calculation" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["[](https://" + "colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/" + "4_Stop_Gradient.ipynb)"] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "outer loss: 0.2039\n", - "\n" - ] + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["In this tutorial, we will illustrate the usage of `torchopt.stop_gradient` with " + "a meta-learning example. We use `torchopt.visual` to help us visualize what is " + "going on in automatic differentiation. 
Firstly, we define a simple network and " + "the objective function for inner- and outer- optimization."] }, { - "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140027829238416\n\nouter_loss\n()\n\n\n\n140025091525072\n\nMseLossBackward0\n\n\n\n140025091525072->140027829238416\n\n\n\n\n\n140025091525216\n\nAddmmBackward0\n\n\n\n140025091525216->140025091525072\n\n\n\n\n\n140025091526128\n\nAddBackward0\nstep1.fc.bias\n(1)\n\n\n\n140025091526128->140025091525216\n\n\n\n\n\n140025091526416\n\nAccumulateGrad\n\n\n\n140025091526416->140025091526128\n\n\n\n\n\n140028156436736\n\nAddmmBackward0\n\n\n\n140025091526416->140028156436736\n\n\n\n\n\n140028155952000\n\nstep0.fc.bias\n(1)\n\n\n\n140028155952000->140025091526416\n\n\n\n\n\n140025091524976\n\nMulBackward0\n\n\n\n140025091524976->140025091526128\n\n\n\n\n\n140025091526560\n\nViewBackward0\n\n\n\n140025091526560->140025091524976\n\n\n\n\n\n140025091525456\n\nSumBackward1\n\n\n\n140025091525456->140025091526560\n\n\n\n\n\n140025091524112\n\nMseLossBackwardBackward0\n\n\n\n140025091524112->140025091525456\n\n\n\n\n\n140024973742672\n\nTBackward0\n\n\n\n140025091524112->140024973742672\n\n\n\n\n\n140024973742288\n\nMulBackward0\n\n\n\n140024973742288->140025091524112\n\n\n\n\n\n140024973742384\n\nAccumulateGrad\n\n\n\n140024973742384->140024973742288\n\n\n\n\n\n140025091549440\n\nmeta_parameter\n()\n\n\n\n140025091549440->140024973742384\n\n\n\n\n\n140028156436736->140025091524112\n\n\n\n\n\n140025091525408\n\nTBackward0\n\n\n\n140025091525408->140028156436736\n\n\n\n\n\n140025091526224\n\nAccumulateGrad\n\n\n\n140025091526224->140025091525408\n\n\n\n\n\n140025091524928\n\nAddBackward0\nstep1.fc.weight\n(1, 16)\n\n\n\n140025091526224->140025091524928\n\n\n\n\n\n140028155952880\n\nstep0.fc.weight\n(1, 16)\n\n\n\n140028155952880->140025091526224\n\n\n\n\n\n140025091524448\n\nTBackward0\n\n\n\n140025091524448->140025091525216\n\n\n\n\n\n140025091524928->140025091524448\n\n\n\n\n\n140025091525600\n\nMulBackward0\n\n\n\n140025091525600->140025091524928\n\n\n\n\n\n140024973742144\n\nTBackward0\n\n\n\n140024973742144->140025091525600\n\n\n\n\n\n140024973742576\n\nTBackward0\n\n\n\n140024973742576->140024973742144\n\n\n\n\n\n140024973742480\n\nMmBackward0\n\n\n\n140024973742480->140024973742576\n\n\n\n\n\n140024973742672->140024973742480\n\n\n\n\n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Extract `state_dict`` for updated network\n", - "one_step_net_state = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", - "one_step_optim_state = torchopt.extract_state_dict(optim)\n", - "\n", - "# Calculate outer loss\n", - "outer_loss = loss_fn(net(x), y)\n", - "print(f'outer loss: {outer_loss:.4f}')\n", - "display(\n", - " torchopt.visual.make_dot(\n", - " outer_loss,\n", - " params=(\n", - " init_net_state,\n", - " one_step_net_state,\n", - " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", - " ),\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then we backward the loss to conduct outer-loop meta-optimization." 
- ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ + "cell_type" : "code", + "execution_count" : 1, + "metadata" : {}, + "outputs" : [], + "source" : [ + "from IPython.display import display\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torchopt\n", + "\n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1, bias=True)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "loss_fn = F.mse_loss" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "meta_parameter.grad = tensor(-0.1205)\n", - "meta_parameter = Parameter containing:\n", - "tensor(1.1000, requires_grad=True)\n" - ] - } - ], - "source": [ - "meta_optim.zero_grad()\n", - "outer_loss.backward()\n", - "print(f'meta_parameter.grad = {meta_parameter.grad!r}')\n", - "meta_optim.step()\n", - "print(f'meta_parameter = {meta_parameter!r}')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We have already conducted one bi-level optimization and optimize our meta-parameters. When you want to conduct the second bi-level optimization, you need to be careful whether you need to use the `stop_gradient` function. For example, if your new inner-loop parameters directly inherits previous inner-loop parameters (which is a common strategy in many meta-learning algorithms like Meta-Gradient Reinforcement Learning (MGRL) ([arXiv:1805.09801](https://arxiv.org/abs/1805.09801))), you might need `stop_gradient` function." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In general, the backpropagation only frees saved tensors (often used as auxiliary data for computing the gradient) but the computation graph remains. Once the outer iteration is finished, if you want to use any intermediate network parameters produced by the inner loop for the next bi-level iteration, you should detach them from the computation graph.\n", - "\n", - "There are two main reasons:\n", - "\n", - "- The network parameters are still connected to the previous computation graph (`.grad_fn` is not `None`). If later the gradient backpropagate to these parameters, the PyTorch backward engine will try to backpropagate through the previous computation graph. This will raise a `RuntimeError`: Trying to backward through the graph a second time...\n", - "- If we do not detach the computation graph, the computation graph connected to these parameters can not be freed by GC (Garbage Collector) until these parameters are collected by GC." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let us see what will happen if we do not use the `stop_gradient` function before we conduct the second bi-level process." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["We define the input `x` and output `y`. 
`y` will be served as the regression " + "target in the following code."] + }, + { + "cell_type" : "code", + "execution_count" : 2, + "metadata" : {}, + "outputs" : [], + "source" : [ + "batch_size = 64\n", + "dim = 16\n", + "\n", + "x = torch.randn((batch_size, dim))\n", + "y = torch.zeros((batch_size, 1))\n", + "net = Net(dim)" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Let us define the meta-parameter, we use `MetaSGD` as the inner-loop optimizer " + "and `Adam` as the outer-loop optimizer. "] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] + "cell_type" : "code", + "execution_count" : 3, + "metadata" : {}, + "outputs" : [], + "source" : [ + "meta_parameter = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", + "\n", + "optim = torchopt.MetaSGD(net, lr=1e-1)\n", + "meta_optim = torch.optim.Adam([meta_parameter], lr=1e-1)" + ] }, { - "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140024973755152\n\nouter_loss\n()\n\n\n\n140027829363232\n\nMseLossBackward0\n\n\n\n140027829363232->140024973755152\n\n\n\n\n\n140027829363616\n\nAddmmBackward0\n\n\n\n140027829363616->140027829363232\n\n\n\n\n\n140027829366544\n\nAddBackward0\n\n\n\n140027829366544->140027829363616\n\n\n\n\n\n140025091526128\n\nAddBackward0\nstep1.fc.bias\n(1)\n\n\n\n140025091526128->140027829366544\n\n\n\n\n\n140025091725152\n\nAddmmBackward0\n\n\n\n140025091526128->140025091725152\n\n\n\n\n\n140025091526416\n\nAccumulateGrad\n\n\n\n140025091526416->140025091526128\n\n\n\n\n\n140028156436736\n\nAddmmBackward0\n\n\n\n140025091526416->140028156436736\n\n\n\n\n\n140028155952000\n\nstep0.fc.bias\n(1)\n\n\n\n140028155952000->140025091526416\n\n\n\n\n\n140025091524976\n\nMulBackward0\n\n\n\n140025091524976->140025091526128\n\n\n\n\n\n140025091526560\n\nViewBackward0\n\n\n\n140025091526560->140025091524976\n\n\n\n\n\n140025091525456\n\nSumBackward1\n\n\n\n140025091525456->140025091526560\n\n\n\n\n\n140025091524112\n\nMseLossBackwardBackward0\n\n\n\n140025091524112->140025091525456\n\n\n\n\n\n140024973742672\n\nTBackward0\n\n\n\n140025091524112->140024973742672\n\n\n\n\n\n140024973742288\n\nMulBackward0\n\n\n\n140024973742288->140025091524112\n\n\n\n\n\n140024973742384\n\nAccumulateGrad\n\n\n\n140024973742384->140024973742288\n\n\n\n\n\n140025091726064\n\nMulBackward0\n\n\n\n140024973742384->140025091726064\n\n\n\n\n\n140025091549440\n\nmeta_parameter\n()\n\n\n\n140025091549440->140024973742384\n\n\n\n\n\n140028156436736->140025091524112\n\n\n\n\n\n140025091525408\n\nTBackward0\n\n\n\n140025091525408->140028156436736\n\n\n\n\n\n140025091526224\n\nAccumulateGrad\n\n\n\n140025091526224->140025091525408\n\n\n\n\n\n140025091524928\n\nAddBackward0\nstep1.fc.weight\n(1, 16)\n\n\n\n140025091526224->140025091524928\n\n\n\n\n\n140028155952880\n\nstep0.fc.weight\n(1, 
16) [remainder of SVG output omitted: Graphviz rendering of the autograd graph in which the second bi-level iteration's backward nodes are chained onto the first iteration's graph (step0.*/step1.* parameters, meta_parameter, and their backward ops)]" - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
╭─────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────╮\n",
-       " /tmp/ipykernel_3962266/4178930003.py:21 in <module>                                                             \n",
-       "                                                                                                                 \n",
-       " [Errno 2] No such file or directory: '/tmp/ipykernel_3962266/4178930003.py'                                     \n",
-       "                                                                                                                 \n",
-       " /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/_tensor.py:487 in backward           \n",
-       "                                                                                                                 \n",
-       "    484 │   │   │   │   create_graph=create_graph,                                                               \n",
-       "    485 │   │   │   │   inputs=inputs,                                                                           \n",
-       "    486 │   │   │   )                                                                                            \n",
-       "  487 │   │   torch.autograd.backward(                                                                         \n",
-       "    488 │   │   │   self, gradient, retain_graph, create_graph, inputs=inputs                                    \n",
-       "    489 │   │   )                                                                                                \n",
-       "    490                                                                                                          \n",
-       "                                                                                                                 \n",
-       " ╭───────────────────────── locals ──────────────────────────╮                                                   \n",
-       "  create_graph = False                                                                                         \n",
-       "      gradient = None                                                                                          \n",
-       "        inputs = None                                                                                          \n",
-       "  retain_graph = None                                                                                          \n",
-       "          self = tensor(0.1203, grad_fn=<MseLossBackward0>)                                                    \n",
-       " ╰───────────────────────────────────────────────────────────╯                                                   \n",
-       "                                                                                                                 \n",
-       " /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/autograd/__init__.py:197 in backward \n",
-       "                                                                                                                 \n",
-       "   194 │   # The reason we repeat same the comment below is that                                                 \n",
-       "   195 │   # some Python versions print out the first line of a multi-line function                              \n",
-       "   196 │   # calls in the traceback and some print out the last line                                             \n",
-       " 197 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the ba                   \n",
-       "   198 │   │   tensors, grad_tensors_, retain_graph, create_graph, inputs,                                       \n",
-       "   199 │   │   allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to r                   \n",
-       "   200                                                                                                           \n",
-       "                                                                                                                 \n",
-       " ╭──────────────────────────── locals ────────────────────────────╮                                              \n",
-       "    create_graph = False                                                                                       \n",
-       "    grad_tensors = None                                                                                        \n",
-       "   grad_tensors_ = (tensor(1.),)                                                                               \n",
-       "  grad_variables = None                                                                                        \n",
-       "          inputs = ()                                                                                          \n",
-       "    retain_graph = False                                                                                       \n",
-       "         tensors = (tensor(0.1203, grad_fn=<MseLossBackward0>),)                                               \n",
-       " ╰────────────────────────────────────────────────────────────────╯                                              \n",
-       "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
-       "RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have \n",
-       "already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().\n",
-       "Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved \n",
-       "tensors after calling backward.\n",
-       "
\n" + "cell_type" : "code", + "execution_count" : 4, + "metadata" : {}, + "outputs" : [ + { + "name" : "stdout", + "output_type" : "stream", + "text" : + [ "inner loss: 0.3472\n", "\n" ] + }, + { + "data" : { + "image/svg+xml" : + "\n\n\n\n\n\n%3\n\n\n\n140025091550880\n\ninner_loss\n()\n\n\n\n140028156253184\n\nMseLossBackward0\n\n\n\n140028156253184->140025091550880\n\n\n\n\n\n140028156436736\n\nAddmmBackward0\n\n\n\n140028156436736->140028156253184\n\n\n\n\n\n140025091526416\n\nAccumulateGrad\n\n\n\n140025091526416->140028156436736\n\n\n\n\n\n140028155952000\n\nstep0.fc.bias\n(1)\n\n\n\n140028155952000->140025091526416\n\n\n\n\n\n140025091525408\n\nTBackward0\n\n\n\n140025091525408->140028156436736\n\n\n\n\n\n140025091526224\n\nAccumulateGrad\n\n\n\n140025091526224->140025091525408\n\n\n\n\n\n140028155952880\n\nstep0.fc.weight\n(1, " + "16)\n\n\n\n140028155952880->140025091526224\n\n\n\n\n\n" + }, + "metadata" : {}, + "output_type" : "display_data" + } ], - "text/plain": [ - "\u001b[31m╭─\u001b[0m\u001b[31m────────────────────────────────────── \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m ──────────────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2;33m/tmp/ipykernel_3962266/\u001b[0m\u001b[1;33m4178930003.py\u001b[0m:\u001b[94m21\u001b[0m in \u001b[92m\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[3;31m[Errno 2] No such file or directory: '/tmp/ipykernel_3962266/4178930003.py'\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/\u001b[0m\u001b[1;33m_tensor.py\u001b[0m:\u001b[94m487\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2m 484 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mcreate_graph=create_graph, \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2m 485 \u001b[0m\u001b[2m│ │ │ │ \u001b[0minputs=inputs, \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2m 486 \u001b[0m\u001b[2m│ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 487 \u001b[2m│ │ \u001b[0mtorch.autograd.backward( \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2m 488 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m, gradient, retain_graph, create_graph, inputs=inputs \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2m 489 \u001b[0m\u001b[2m│ │ \u001b[0m) \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2m 490 \u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m╭─\u001b[0m\u001b[33m──────────────────────── locals ─────────────────────────\u001b[0m\u001b[33m─╮\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m create_graph = \u001b[94mFalse\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m gradient = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m inputs = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m retain_graph = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m self = 
\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m0.1203\u001b[0m, \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mMseLossBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m╰───────────────────────────────────────────────────────────╯\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/autograd/\u001b[0m\u001b[1;33m__init__.py\u001b[0m:\u001b[94m197\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2m194 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# The reason we repeat same the comment below is that\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2m195 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# some Python versions print out the first line of a multi-line function\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2m196 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# calls in the traceback and some print out the last line\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m197 \u001b[2m│ \u001b[0mVariable._execution_engine.run_backward( \u001b[2m# Calls into the C++ engine to run the ba\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2m198 \u001b[0m\u001b[2m│ │ \u001b[0mtensors, grad_tensors_, retain_graph, create_graph, inputs, \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2m199 \u001b[0m\u001b[2m│ │ \u001b[0mallow_unreachable=\u001b[94mTrue\u001b[0m, accumulate_grad=\u001b[94mTrue\u001b[0m) \u001b[2m# Calls into the C++ engine to r\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[2m200 \u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m╭─\u001b[0m\u001b[33m─────────────────────────── locals ───────────────────────────\u001b[0m\u001b[33m─╮\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m create_graph = \u001b[94mFalse\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_tensors = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_tensors_ = \u001b[1m(\u001b[0m\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m1\u001b[0m.\u001b[1m)\u001b[0m,\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_variables = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m inputs = \u001b[1m(\u001b[0m\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m retain_graph = \u001b[94mFalse\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m tensors = \u001b[1m(\u001b[0m\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m0.1203\u001b[0m, \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mMseLossBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m,\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n", - "\u001b[31m│\u001b[0m \u001b[33m╰────────────────────────────────────────────────────────────────╯\u001b[0m \u001b[31m│\u001b[0m\n", - 
"\u001b[31m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", - "\u001b[1;91mRuntimeError: \u001b[0mTrying to backward through the graph a second time \u001b[1m(\u001b[0mor directly access saved tensors after they have \n", - "already been freed\u001b[1m)\u001b[0m. Saved intermediate values of the graph are freed when you call \u001b[1;35m.backward\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m or \u001b[1;35mautograd.grad\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m.\n", - "Specify \u001b[33mretain_graph\u001b[0m=\u001b[3;92mTrue\u001b[0m if you need to backward through the graph a second time or if you need to access saved \n", - "tensors after calling backward.\n" + "source" : [ + "init_net_state = torchopt.extract_state_dict(net, enable_visual=True, " + "visual_prefix='step0.')\n", + "\n", + "# inner loss\n", + "inner_loss = loss_fn(net(x), y)\n", + "\n", + "print(f'inner loss: {inner_loss:.4f}')\n", + "display(torchopt.visual.make_dot(inner_loss, params=(init_net_state, {'inner_loss': " + "inner_loss})))" ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Inner update with attached computation graph\n", - "inner_loss = loss_fn(net(x), y)\n", - "loss = inner_loss * meta_parameter\n", - "optim.step(loss)\n", - "\n", - "# Outer forward process\n", - "outer_loss = loss_fn(net(x), y)\n", - "display(\n", - " torchopt.visual.make_dot(\n", - " outer_loss,\n", - " params=(\n", - " init_net_state,\n", - " one_step_net_state,\n", - " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", - " ),\n", - " )\n", - ")\n", - "\n", - "# Outer update\n", - "meta_optim.zero_grad()\n", - "outer_loss.backward()\n", - "meta_optim.step()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "From the graph we can see, directly conducting the second bi-level process links the graph of first and second bi-level process together. We should manually stop gradient with `torchopt.stop_gradient`. `torchopt.stop_gradient` will detach the node of gradient graph and make it become a leaf node. It allows the input of network, optimizer, or state dictionary and the gradient operation happens in an in-place manner.\n", - "\n", - "Let's use `recover_state_dict` to come back to one-step updated states." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "# Reset to previous one-step updated states\n", - "torchopt.recover_state_dict(net, one_step_net_state)\n", - "torchopt.recover_state_dict(optim, one_step_optim_state)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And finally, Let's conduct the stop-gradient operation before the second meta-optimization step. 
" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Conduct inner-loop optimization with `MetaSGD`, here the meta-parameter is " + "served as a factor controlling the scale of inner-loop loss."] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "meta_parameter.grad = tensor(-0.0635)\n", - "meta_parameter = Parameter containing:\n", - "tensor(1.1940, requires_grad=True)\n", - "\n" - ] + "cell_type" : "code", + "execution_count" : 5, + "metadata" : {}, + "outputs" : [], + "source" : [ + "# inner-step optimization\n", + "loss = inner_loss * meta_parameter\n", + "optim.step(loss)" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "We compute the outer loss and draw the full computation graph of the first bi-level " + "process. In this graph, three main parts are included.\n", + "\n", + "- Inner-loop: forward process and inner-loss calculation\n", + "- Inner-loop optimization: `MetaSGD` optimization step given inner-loss\n", + "- Outer-loop: forward process and outer-loss calculation" + ] + }, + { + "cell_type" : "code", + "execution_count" : 6, + "metadata" : {}, + "outputs" : [ + { + "name" : "stdout", + "output_type" : "stream", + "text" : + [ "outer loss: 0.2039\n", "\n" ] + }, + { + "data" : { + "image/svg+xml" : + "\n\n\n\n\n\n%3\n\n\n\n140027829238416\n\nouter_loss\n()\n\n\n\n140025091525072\n\nMseLossBackward0\n\n\n\n140025091525072->140027829238416\n\n\n\n\n\n140025091525216\n\nAddmmBackward0\n\n\n\n140025091525216->140025091525072\n\n\n\n\n\n140025091526128\n\nAddBackward0\nstep1.fc.bias\n(1)\n\n\n\n140025091526128->140025091525216\n\n\n\n\n\n140025091526416\n\nAccumulateGrad\n\n\n\n140025091526416->140025091526128\n\n\n\n\n\n140028156436736\n\nAddmmBackward0\n\n\n\n140025091526416->140028156436736\n\n\n\n\n\n140028155952000\n\nstep0.fc.bias\n(1)\n\n\n\n140028155952000->140025091526416\n\n\n\n\n\n140025091524976\n\nMulBackward0\n\n\n\n140025091524976->140025091526128\n\n\n\n\n\n140025091526560\n\nViewBackward0\n\n\n\n140025091526560->140025091524976\n\n\n\n\n\n140025091525456\n\nSumBackward1\n\n\n\n140025091525456->140025091526560\n\n\n\n\n\n140025091524112\n\nMseLossBackwardBackward0\n\n\n\n140025091524112->140025091525456\n\n\n\n\n\n140024973742672\n\nTBackward0\n\n\n\n140025091524112->140024973742672\n\n\n\n\n\n140024973742288\n\nMulBackward0\n\n\n\n140024973742288->140025091524112\n\n\n\n\n\n140024973742384\n\nAccumulateGrad\n\n\n\n140024973742384->140024973742288\n\n\n\n\n\n140025091549440\n\nmeta_parameter\n()\n\n\n\n140025091549440->140024973742384\n\n\n\n\n\n140028156436736->140025091524112\n\n\n\n\n\n140025091525408\n\nTBackward0\n\n\n\n140025091525408->140028156436736\n\n\n\n\n\n140025091526224\n\nAccumulateGrad\n\n\n\n140025091526224->140025091525408\n\n\n\n\n\n140025091524928\n\nAddBackward0\nstep1.fc.weight\n(1, " + "16)\n\n\n\n140025091526224->140025091524928\n\n\n\n\n\n140028155952880\n\nstep0.fc.weight\n(1, 
16)\n\n\n\n140028155952880->140025091526224\n\n\n\n\n\n140025091524448\n\nTBackward0\n\n\n\n140025091524448->140025091525216\n\n\n\n\n\n140025091524928->140025091524448\n\n\n\n\n\n140025091525600\n\nMulBackward0\n\n\n\n140025091525600->140025091524928\n\n\n\n\n\n140024973742144\n\nTBackward0\n\n\n\n140024973742144->140025091525600\n\n\n\n\n\n140024973742576\n\nTBackward0\n\n\n\n140024973742576->140024973742144\n\n\n\n\n\n140024973742480\n\nMmBackward0\n\n\n\n140024973742480->140024973742576\n\n\n\n\n\n140024973742672->140024973742480\n\n\n\n\n\n" + }, + "metadata" : {}, + "output_type" : "display_data" + } + ], + "source" : [ + "# Extract `state_dict`` for updated network\n", + "one_step_net_state = torchopt.extract_state_dict(net, enable_visual=True, " + "visual_prefix='step1.')\n", + "one_step_optim_state = torchopt.extract_state_dict(optim)\n", + "\n", + "# Calculate outer loss\n", + "outer_loss = loss_fn(net(x), y)\n", + "print(f'outer loss: {outer_loss:.4f}')\n", + "display(\n", + " torchopt.visual.make_dot(\n", + " outer_loss,\n", + " params=(\n", + " init_net_state,\n", + " one_step_net_state,\n", + " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", + " ),\n", + " )\n", + ")" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Then we backward the loss to conduct outer-loop meta-optimization."] + }, + { + "cell_type" : "code", + "execution_count" : 7, + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : [ + "meta_parameter.grad = tensor(-0.1205)\n", + "meta_parameter = Parameter containing:\n", + "tensor(1.1000, requires_grad=True)\n" + ] + } ], + "source" : [ + "meta_optim.zero_grad()\n", + "outer_loss.backward()\n", + "print(f'meta_parameter.grad = {meta_parameter.grad!r}')\n", + "meta_optim.step()\n", + "print(f'meta_parameter = {meta_parameter!r}')" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : + ["We have already conducted one bi-level optimization and optimize our meta-parameters. " + "When you want to conduct the second bi-level optimization, you need to be careful " + "whether you need to use the `stop_gradient` function. For example, if your new " + "inner-loop parameters directly inherits previous inner-loop parameters (which is a " + "common strategy in many meta-learning algorithms like Meta-Gradient Reinforcement " + "Learning (MGRL) ([arXiv:1805.09801](https://arxiv.org/abs/1805.09801))), you might " + "need `stop_gradient` function."] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "In general, the backpropagation only frees saved tensors (often used as auxiliary data " + "for computing the gradient) but the computation graph remains. Once the outer iteration " + "is finished, if you want to use any intermediate network parameters produced by the inner " + "loop for the next bi-level iteration, you should detach them from the computation " + "graph.\n", + "\n", + "There are two main reasons:\n", + "\n", + "- The network parameters are still connected to the previous computation graph " + "(`.grad_fn` is not `None`). If later the gradient backpropagate to these parameters, the " + "PyTorch backward engine will try to backpropagate through the previous computation graph. 
" + "This will raise a `RuntimeError`: Trying to backward through the graph a second time...\n", + "- If we do not detach the computation graph, the computation graph connected to these " + "parameters can not be freed by GC (Garbage Collector) until these parameters are " + "collected by GC." + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["Now let us see what will happen if we do not use the `stop_gradient` function " + "before we conduct the second bi-level process."] }, { - "data": { - "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n140024973754912\n\nouter_loss\n()\n\n\n\n140024956770528\n\nMseLossBackward0\n\n\n\n140024956770528->140024973754912\n\n\n\n\n\n140024956772112\n\nAddmmBackward0\n\n\n\n140024956772112->140024956770528\n\n\n\n\n\n140024956770720\n\nAddBackward0\n\n\n\n140024956770720->140024956772112\n\n\n\n\n\n140024962101312\n\nAccumulateGrad\n\n\n\n140024962101312->140024956770720\n\n\n\n\n\n140024973745552\n\nAddmmBackward0\n\n\n\n140024962101312->140024973745552\n\n\n\n\n\n140025091547520\n\nstep1.detached.fc.bias\n(1)\n\n\n\n140025091547520->140024962101312\n\n\n\n\n\n140024971586864\n\nMulBackward0\n\n\n\n140024971586864->140024956770720\n\n\n\n\n\n140024973742528\n\nViewBackward0\n\n\n\n140024973742528->140024971586864\n\n\n\n\n\n140024973743968\n\nSumBackward1\n\n\n\n140024973743968->140024973742528\n\n\n\n\n\n140024973742768\n\nMseLossBackwardBackward0\n\n\n\n140024973742768->140024973743968\n\n\n\n\n\n140024973744400\n\nTBackward0\n\n\n\n140024973742768->140024973744400\n\n\n\n\n\n140024973744688\n\nMulBackward0\n\n\n\n140024973744688->140024973742768\n\n\n\n\n\n140024973745264\n\nAccumulateGrad\n\n\n\n140024973745264->140024973744688\n\n\n\n\n\n140025091549440\n\nmeta_parameter\n()\n\n\n\n140025091549440->140024973745264\n\n\n\n\n\n140024973745552->140024973742768\n\n\n\n\n\n140024973745168\n\nTBackward0\n\n\n\n140024973745168->140024973745552\n\n\n\n\n\n140024973744256\n\nAccumulateGrad\n\n\n\n140024973744256->140024973745168\n\n\n\n\n\n140024973745984\n\nAddBackward0\n\n\n\n140024973744256->140024973745984\n\n\n\n\n\n140027828983424\n\nstep1.detached.fc.weight\n(1, 16)\n\n\n\n140027828983424->140024973744256\n\n\n\n\n\n140024956771632\n\nTBackward0\n\n\n\n140024956771632->140024956772112\n\n\n\n\n\n140024973745984->140024956771632\n\n\n\n\n\n140024973743728\n\nMulBackward0\n\n\n\n140024973743728->140024973745984\n\n\n\n\n\n140024973743344\n\nTBackward0\n\n\n\n140024973743344->140024973743728\n\n\n\n\n\n140024973745312\n\nTBackward0\n\n\n\n140024973745312->140024973743344\n\n\n\n\n\n140024973743200\n\nMmBackward0\n\n\n\n140024973743200->140024973745312\n\n\n\n\n\n140024973744400->140024973743200\n\n\n\n\n\n" - }, - "metadata": {}, - "output_type": "display_data" + "cell_type" : "code", + "execution_count" : 8, + "metadata" : {}, + "outputs" : [ + { + "name" : "stdout", + "output_type" : "stream", + "text" : ["\n"] + }, + { + "data" : { + "image/svg+xml" : + 
"\n\n\n\n\n\n%3\n\n\n\n140024973755152\n\nouter_loss\n()\n\n\n\n140027829363232\n\nMseLossBackward0\n\n\n\n140027829363232->140024973755152\n\n\n\n\n\n140027829363616\n\nAddmmBackward0\n\n\n\n140027829363616->140027829363232\n\n\n\n\n\n140027829366544\n\nAddBackward0\n\n\n\n140027829366544->140027829363616\n\n\n\n\n\n140025091526128\n\nAddBackward0\nstep1.fc.bias\n(1)\n\n\n\n140025091526128->140027829366544\n\n\n\n\n\n140025091725152\n\nAddmmBackward0\n\n\n\n140025091526128->140025091725152\n\n\n\n\n\n140025091526416\n\nAccumulateGrad\n\n\n\n140025091526416->140025091526128\n\n\n\n\n\n140028156436736\n\nAddmmBackward0\n\n\n\n140025091526416->140028156436736\n\n\n\n\n\n140028155952000\n\nstep0.fc.bias\n(1)\n\n\n\n140028155952000->140025091526416\n\n\n\n\n\n140025091524976\n\nMulBackward0\n\n\n\n140025091524976->140025091526128\n\n\n\n\n\n140025091526560\n\nViewBackward0\n\n\n\n140025091526560->140025091524976\n\n\n\n\n\n140025091525456\n\nSumBackward1\n\n\n\n140025091525456->140025091526560\n\n\n\n\n\n140025091524112\n\nMseLossBackwardBackward0\n\n\n\n140025091524112->140025091525456\n\n\n\n\n\n140024973742672\n\nTBackward0\n\n\n\n140025091524112->140024973742672\n\n\n\n\n\n140024973742288\n\nMulBackward0\n\n\n\n140024973742288->140025091524112\n\n\n\n\n\n140024973742384\n\nAccumulateGrad\n\n\n\n140024973742384->140024973742288\n\n\n\n\n\n140025091726064\n\nMulBackward0\n\n\n\n140024973742384->140025091726064\n\n\n\n\n\n140025091549440\n\nmeta_parameter\n()\n\n\n\n140025091549440->140024973742384\n\n\n\n\n\n140028156436736->140025091524112\n\n\n\n\n\n140025091525408\n\nTBackward0\n\n\n\n140025091525408->140028156436736\n\n\n\n\n\n140025091526224\n\nAccumulateGrad\n\n\n\n140025091526224->140025091525408\n\n\n\n\n\n140025091524928\n\nAddBackward0\nstep1.fc.weight\n(1, " + "16)\n\n\n\n140025091526224->140025091524928\n\n\n\n\n\n140028155952880\n\nstep0.fc.weight\n(1, 16)\n\n\n\n140028155952880->140025091526224\n\n\n\n\n\n140025091726784\n\nMulBackward0\n\n\n\n140025091726784->140027829366544\n\n\n\n\n\n140025091726688\n\nViewBackward0\n\n\n\n140025091726688->140025091726784\n\n\n\n\n\n140025091725680\n\nSumBackward1\n\n\n\n140025091725680->140025091726688\n\n\n\n\n\n140025091726112\n\nMseLossBackwardBackward0\n\n\n\n140025091726112->140025091725680\n\n\n\n\n\n140025091726880\n\nTBackward0\n\n\n\n140025091726112->140025091726880\n\n\n\n\n\n140025091726064->140025091726112\n\n\n\n\n\n140025091725152->140025091726112\n\n\n\n\n\n140025091725824\n\nTBackward0\n\n\n\n140025091725824->140025091725152\n\n\n\n\n\n140025091524928->140025091725824\n\n\n\n\n\n140025091726016\n\nAddBackward0\n\n\n\n140025091524928->140025091726016\n\n\n\n\n\n140025091525600\n\nMulBackward0\n\n\n\n140025091525600->140025091524928\n\n\n\n\n\n140024973742144\n\nTBackward0\n\n\n\n140024973742144->140025091525600\n\n\n\n\n\n140024973742576\n\nTBackward0\n\n\n\n140024973742576->140024973742144\n\n\n\n\n\n140024973742480\n\nMmBackward0\n\n\n\n140024973742480->140024973742576\n\n\n\n\n\n140024973742672->140024973742480\n\n\n\n\n\n140027829365632\n\nTBackward0\n\n\n\n140027829365632->140027829363616\n\n\n\n\n\n140025091726016->140027829365632\n\n\n\n\n\n140025091726544\n\nMulBackward0\n\n\n\n140025091726544->140025091726016\n\n\n\n\n\n140025091726448\n\nTBackward0\n\n\n\n140025091726448->140025091726544\n\n\n\n\n\n140025091725584\n\nTBackward0\n\n\n\n140025091725584->140025091726448\n\n\n\n\n\n140025091727024\n\nMmBackward0\n\n\n\n140025091727024->140025091725584\n\n\n\n\n\n140025091726880->140025091727024\n\n\n\n\n\n" + }, + 
"metadata" : {}, + "output_type" : "display_data" + }, + { + "data" : { + "text/html" : [ + "
╭─────────────────────────────────────── "
+              "Traceback (most recent call last) "
+              "───────────────────────────────────────╮\n",
+              " /tmp/ipykernel_3962266/4178930003.py:21 in <module>      "
+              "                                                       \n",
+              "             "
+              "                                                                                    "
+              "                \n",
+              " [Errno "
+              "2] No such file or directory: '/tmp/ipykernel_3962266/4178930003.py'         "
+              "                            \n",
+              "             "
+              "                                                                                    "
+              "                \n",
+              " /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/"
+              "torch/_tensor.py:487 in backward           \n",
+              "             "
+              "                                                                                    "
+              "                \n",
+              "    484 │   │   │   │   "
+              "create_graph=create_graph,                                                   "
+              "            \n",
+              "    485 │   │   │   │   "
+              "inputs=inputs,                                                               "
+              "            \n",
+              "    486 │   │   │   )  "
+              "                                                                                    "
+              "      \n",
+              "  487 │   │   "
+              "torch.autograd.backward(                                                     "
+              "                    \n",
+              "    488 │   │   │   "
+              "self, "
+              "gradient, retain_graph, create_graph, inputs=inputs                                 "
+              "   \n",
+              "    489 │   │   )      "
+              "                                                                                    "
+              "      \n",
+              "    490                "
+              "                                                                                    "
+              "      \n",
+              "             "
+              "                                                                                    "
+              "                \n",
+              " ╭───────────────────────── "
+              "locals ──────────────────────────╮                                           "
+              "        \n",
+              "  create_graph = "
+              "False         "
+              "                                                                                \n",
+              "      gradient = "
+              "None          "
+              "                                                                                \n",
+              "        inputs = "
+              "None          "
+              "                                                                                \n",
+              "  retain_graph = "
+              "None          "
+              "                                                                                \n",
+              "          self = "
+              "tensor(0.1203, grad_fn=<MseLossBackward0>)                                                    \n",
+              " ╰───────────────────────────────────────────────────────────╯      "
+              "                                             \n",
+              "             "
+              "                                                                                    "
+              "                \n",
+              " /home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/"
+              "torch/autograd/__init__.py:197 in backward \n",
+              "             "
+              "                                                                                    "
+              "                \n",
+              "   194 │   # The reason we "
+              "repeat same the comment below is that                                        "
+              "         \n",
+              "   195 │   # some Python "
+              "versions print out the first line of a multi-line function                   "
+              "           \n",
+              "   196 │   # calls in the "
+              "traceback and some print out the last line                                   "
+              "          \n",
+              " 197 │   "
+              "Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the ba   "
+              "                \n",
+              "   198 │   │   "
+              "tensors, grad_tensors_, retain_graph, create_graph, inputs,                  "
+              "                     \n",
+              "   199 │   │   "
+              "allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to r            "
+              "       \n",
+              "   200                 "
+              "                                                                                    "
+              "      \n",
+              "             "
+              "                                                                                    "
+              "                \n",
+              " ╭──────────────────────────── locals ────────────────────────────╮ "
+              "                                             \n",
+              "    create_graph = "
+              "False         "
+              "                                                                           "
+              "   \n",
+              "    grad_tensors = "
+              "None          "
+              "                                                                           "
+              "   \n",
+              "   grad_tensors_ = "
+              "(tensor(1.),)                          "
+              "             "
+              "                                        \n",
+              "  grad_variables = "
+              "None          "
+              "                                                                           "
+              "   \n",
+              "          inputs = "
+              "()                                         "
+              "             "
+              "                                    \n",
+              "    retain_graph = "
+              "False         "
+              "                                                                           "
+              "   \n",
+              "         tensors = "
+              "(tensor(0.1203, grad_fn=<MseLossBackward0>),)                                   "
+              "            \n",
+              " ╰────────────────────────────────────────────────────────────────╯ "
+              "                                             \n",
+              ""
+              "╰───────────────────────────────────────────────────────────────────────────────────"
+              "──────────────────────────────╯\n",
+              "RuntimeError: Trying to backward through the graph a second time "
+              "(or directly access saved tensors after "
+              "they have \n",
+              "already been freed). Saved intermediate "
+              "values of the graph are freed when you call .backward() or autograd.grad().\n",
+              "Specify retain_graph=True if you need to backward through the graph "
+              "a second time or if you need to access saved \n",
+              "tensors after calling backward.\n",
+              "
\n" + ], + "text/plain" : [ + "\u001b[31m╭─\u001b[0m\u001b[31m────────────────────────────────────── " + "\u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call " + "last)\u001b[0m\u001b[31m " + "──────────────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", + "\u001b[31m│\u001b[0m " + "\u001b[2;33m/tmp/ipykernel_3962266/" + "\u001b[0m\u001b[1;33m4178930003.py\u001b[0m:\u001b[94m21\u001b[0m in " + "\u001b[92m\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[3;31m[Errno 2] No such file or directory: " + "'/tmp/ipykernel_3962266/4178930003.py'\u001b[0m " + "\u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m " + "\u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/" + "torch/\u001b[0m\u001b[1;33m_tensor.py\u001b[0m:\u001b[94m487\u001b[0m in " + "\u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 484 \u001b[0m\u001b[2m│ │ │ │ " + "\u001b[0mcreate_graph=create_graph, " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 485 \u001b[0m\u001b[2m│ │ │ │ " + "\u001b[0minputs=inputs, " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 486 \u001b[0m\u001b[2m│ │ │ \u001b[0m) " + " " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 487 \u001b[2m│ │ " + "\u001b[0mtorch.autograd.backward( " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 488 \u001b[0m\u001b[2m│ │ │ " + "\u001b[0m\u001b[96mself\u001b[0m, gradient, retain_graph, create_graph, " + "inputs=inputs \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 489 \u001b[0m\u001b[2m│ │ \u001b[0m) " + " " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 490 \u001b[0m " + " " + "\u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m╭─\u001b[0m\u001b[33m──────────────────────── locals " + "─────────────────────────\u001b[0m\u001b[33m─╮\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m create_graph = \u001b[94mFalse\u001b[0m " + " \u001b[33m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m gradient = \u001b[94mNone\u001b[0m " + " \u001b[33m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m inputs = \u001b[94mNone\u001b[0m " + " \u001b[33m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m retain_graph = \u001b[94mNone\u001b[0m " + " \u001b[33m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m self = " + "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m0.1203\u001b[0m, " + "\u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mMseLossBackward0\u001b[" + "0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m " + "\u001b[33m╰───────────────────────────────────────────────────────────╯\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m " + "\u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/" + "torch/autograd/\u001b[0m\u001b[1;33m__init__.py\u001b[0m:\u001b[94m197\u001b[0m in " + "\u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + 
"\u001b[31m│\u001b[0m \u001b[2m194 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# The " + "reason we repeat same the comment below is that\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m195 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# some " + "Python versions print out the first line of a multi-line function\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m196 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# calls " + "in the traceback and some print out the last line\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m197 \u001b[2m│ " + "\u001b[0mVariable._execution_engine.run_backward( \u001b[2m# Calls into the C++ " + "engine to run the ba\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m198 \u001b[0m\u001b[2m│ │ \u001b[0mtensors, " + "grad_tensors_, retain_graph, create_graph, inputs, " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m199 \u001b[0m\u001b[2m│ │ " + "\u001b[0mallow_unreachable=\u001b[94mTrue\u001b[0m, " + "accumulate_grad=\u001b[94mTrue\u001b[0m) \u001b[2m# Calls into the C++ engine to " + "r\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m200 \u001b[0m " + " " + "\u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m╭─\u001b[0m\u001b[33m─────────────────────────── " + "locals ───────────────────────────\u001b[0m\u001b[33m─╮\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m create_graph = \u001b[94mFalse\u001b[0m " + " \u001b[33m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_tensors = \u001b[94mNone\u001b[0m " + " \u001b[33m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_tensors_ = " + "\u001b[1m(\u001b[0m\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m1\u001b[" + "0m.\u001b[1m)\u001b[0m,\u001b[1m)\u001b[0m " + "\u001b[33m│\u001b[0m " + "\u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_variables = \u001b[94mNone\u001b[0m " + " \u001b[33m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m inputs = " + "\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m " + "\u001b[33m│\u001b[0m " + "\u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m retain_graph = \u001b[94mFalse\u001b[0m " + " \u001b[33m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m tensors = " + "\u001b[1m(\u001b[0m\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m0." + "1203\u001b[0m, " + "\u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mMseLossBackward0\u001b[" + "0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m,\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m " + " \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m " + "\u001b[33m╰────────────────────────────────────────────────────────────────╯\u001b[" + "0m \u001b[31m│\u001b[0m\n", + "\u001b[" + "31m╰────────────────────────────────────────────────────────────────────────────────" + "─────────────────────────────────╯\u001b[0m\n", + "\u001b[1;91mRuntimeError: \u001b[0mTrying to backward through the graph a second " + "time \u001b[1m(\u001b[0mor directly access saved tensors after they have \n", + "already been freed\u001b[1m)\u001b[0m. 
Saved intermediate values of the graph are " + "freed when you call " + "\u001b[1;35m.backward\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m or " + "\u001b[1;35mautograd.grad\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m.\n", + "Specify \u001b[33mretain_graph\u001b[0m=\u001b[3;92mTrue\u001b[0m if you need to " + "backward through the graph a second time or if you need to access saved \n", + "tensors after calling backward.\n" + ] + }, + "metadata" : {}, + "output_type" : "display_data" + } + ], + "source" : [ + "# Inner update with attached computation graph\n", + "inner_loss = loss_fn(net(x), y)\n", + "loss = inner_loss * meta_parameter\n", + "optim.step(loss)\n", + "\n", + "# Outer forward process\n", + "outer_loss = loss_fn(net(x), y)\n", + "display(\n", + " torchopt.visual.make_dot(\n", + " outer_loss,\n", + " params=(\n", + " init_net_state,\n", + " one_step_net_state,\n", + " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", + " ),\n", + " )\n", + ")\n", + "\n", + "# Outer update\n", + "meta_optim.zero_grad()\n", + "outer_loss.backward()\n", + "meta_optim.step()" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : [ + "From the graph we can see, directly conducting the second bi-level process links the " + "graph of first and second bi-level process together. We should manually stop gradient " + "with `torchopt.stop_gradient`. `torchopt.stop_gradient` will detach the node of gradient " + "graph and make it become a leaf node. It allows the input of network, optimizer, or state " + "dictionary and the gradient operation happens in an in-place manner.\n", + "\n", + "Let's use `recover_state_dict` to come back to one-step updated states." + ] + }, + { + "cell_type" : "code", + "execution_count" : 9, + "metadata" : {}, + "outputs" : [], + "source" : [ + "# Reset to previous one-step updated states\n", + "torchopt.recover_state_dict(net, one_step_net_state)\n", + "torchopt.recover_state_dict(optim, one_step_optim_state)" + ] + }, + { + "cell_type" : "markdown", + "metadata" : {}, + "source" : ["And finally, Let's conduct the stop-gradient operation before the second " + "meta-optimization step. 
"] + }, + { + "cell_type" : "code", + "execution_count" : 10, + "metadata" : {}, + "outputs" : [ + { + "name" : "stdout", + "output_type" : "stream", + "text" : [ + "meta_parameter.grad = tensor(-0.0635)\n", + "meta_parameter = Parameter containing:\n", + "tensor(1.1940, requires_grad=True)\n", + "\n" + ] + }, + { + "data" : { + "image/svg+xml" : + "\n\n\n\n\n\n%3\n\n\n\n140024973754912\n\nouter_loss\n()\n\n\n\n140024956770528\n\nMseLossBackward0\n\n\n\n140024956770528->140024973754912\n\n\n\n\n\n140024956772112\n\nAddmmBackward0\n\n\n\n140024956772112->140024956770528\n\n\n\n\n\n140024956770720\n\nAddBackward0\n\n\n\n140024956770720->140024956772112\n\n\n\n\n\n140024962101312\n\nAccumulateGrad\n\n\n\n140024962101312->140024956770720\n\n\n\n\n\n140024973745552\n\nAddmmBackward0\n\n\n\n140024962101312->140024973745552\n\n\n\n\n\n140025091547520\n\nstep1.detached.fc.bias\n(1)\n\n\n\n140025091547520->140024962101312\n\n\n\n\n\n140024971586864\n\nMulBackward0\n\n\n\n140024971586864->140024956770720\n\n\n\n\n\n140024973742528\n\nViewBackward0\n\n\n\n140024973742528->140024971586864\n\n\n\n\n\n140024973743968\n\nSumBackward1\n\n\n\n140024973743968->140024973742528\n\n\n\n\n\n140024973742768\n\nMseLossBackwardBackward0\n\n\n\n140024973742768->140024973743968\n\n\n\n\n\n140024973744400\n\nTBackward0\n\n\n\n140024973742768->140024973744400\n\n\n\n\n\n140024973744688\n\nMulBackward0\n\n\n\n140024973744688->140024973742768\n\n\n\n\n\n140024973745264\n\nAccumulateGrad\n\n\n\n140024973745264->140024973744688\n\n\n\n\n\n140025091549440\n\nmeta_parameter\n()\n\n\n\n140025091549440->140024973745264\n\n\n\n\n\n140024973745552->140024973742768\n\n\n\n\n\n140024973745168\n\nTBackward0\n\n\n\n140024973745168->140024973745552\n\n\n\n\n\n140024973744256\n\nAccumulateGrad\n\n\n\n140024973744256->140024973745168\n\n\n\n\n\n140024973745984\n\nAddBackward0\n\n\n\n140024973744256->140024973745984\n\n\n\n\n\n140027828983424\n\nstep1.detached.fc.weight\n(1, " + "16)\n\n\n\n140027828983424->140024973744256\n\n\n\n\n\n140024956771632\n\nTBackward0\n\n\n\n140024956771632->140024956772112\n\n\n\n\n\n140024973745984->140024956771632\n\n\n\n\n\n140024973743728\n\nMulBackward0\n\n\n\n140024973743728->140024973745984\n\n\n\n\n\n140024973743344\n\nTBackward0\n\n\n\n140024973743344->140024973743728\n\n\n\n\n\n140024973745312\n\nTBackward0\n\n\n\n140024973745312->140024973743344\n\n\n\n\n\n140024973743200\n\nMmBackward0\n\n\n\n140024973743200->140024973745312\n\n\n\n\n\n140024973744400->140024973743200\n\n\n\n\n\n" + }, + "metadata" : {}, + "output_type" : "display_data" + } + ], + "source" : [ + "# Stop gradient and make them become the leaf node\n", + "torchopt.stop_gradient(net)\n", + "torchopt.stop_gradient(optim)\n", + "one_step_net_state_detached = torchopt.extract_state_dict(\n", + " net, enable_visual=True, visual_prefix='step1.detached.'\n", + ")\n", + "\n", + "# Inner update\n", + "inner_loss = loss_fn(net(x), y)\n", + "loss = inner_loss * meta_parameter\n", + "optim.step(loss)\n", + "\n", + "# Outer update\n", + "outer_loss = loss_fn(net(x), y)\n", + "meta_optim.zero_grad()\n", + "outer_loss.backward()\n", + "print(f'meta_parameter.grad = {meta_parameter.grad!r}')\n", + "meta_optim.step()\n", + "print(f'meta_parameter = {meta_parameter!r}')\n", + "\n", + "display(\n", + " torchopt.visual.make_dot(\n", + " outer_loss,\n", + " params=(\n", + " one_step_net_state_detached,\n", + " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", + " ),\n", + " )\n", + ")" + ] + }, + { + "cell_type" : "markdown", + 
"metadata" : {}, + "source" : ["The gradient graph is the same with the first meta-optimization's gradient " + "graph and we successfully conduct the second bi-level process."] + } + ], + "metadata" : { + "kernelspec" : + {"display_name" : "Python 3 (ipykernel)", "language" : "python", "name" : "python3"}, + "language_info" : { + "codemirror_mode" : {"name" : "ipython", "version" : 3}, + "file_extension" : ".py", + "mimetype" : "text/x-python", + "name" : "python", + "nbconvert_exporter" : "python", + "pygments_lexer" : "ipython3", + "version" : "3.9.15" + }, + "vscode" : { + "interpreter" : + {"hash" : "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da"} } - ], - "source": [ - "# Stop gradient and make them become the leaf node\n", - "torchopt.stop_gradient(net)\n", - "torchopt.stop_gradient(optim)\n", - "one_step_net_state_detached = torchopt.extract_state_dict(\n", - " net, enable_visual=True, visual_prefix='step1.detached.'\n", - ")\n", - "\n", - "# Inner update\n", - "inner_loss = loss_fn(net(x), y)\n", - "loss = inner_loss * meta_parameter\n", - "optim.step(loss)\n", - "\n", - "# Outer update\n", - "outer_loss = loss_fn(net(x), y)\n", - "meta_optim.zero_grad()\n", - "outer_loss.backward()\n", - "print(f'meta_parameter.grad = {meta_parameter.grad!r}')\n", - "meta_optim.step()\n", - "print(f'meta_parameter = {meta_parameter!r}')\n", - "\n", - "display(\n", - " torchopt.visual.make_dot(\n", - " outer_loss,\n", - " params=(\n", - " one_step_net_state_detached,\n", - " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", - " ),\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The gradient graph is the same with the first meta-optimization's gradient graph and we successfully conduct the second bi-level process." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" }, - "vscode": { - "interpreter": { - "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat" : 4, + "nbformat_minor" : 4 } diff --git a/tutorials/5_Implicit_Differentiation.ipynb b/tutorials/5_Implicit_Differentiation.ipynb index 23407801..a7dc8ed5 100644 --- a/tutorials/5_Implicit_Differentiation.ipynb +++ b/tutorials/5_Implicit_Differentiation.ipynb @@ -1,576 +1,590 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", - "metadata": {}, - "source": [ - "# TorchOpt for Implicit Differentiation" - ] - }, - { - "cell_type": "markdown", - "id": "2b547376", - "metadata": {}, - "source": [ - "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)" - ] - }, - { - "cell_type": "markdown", - "id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd", - "metadata": {}, - "source": [ - "By treating the solution $\\phi^{\\star}$ as an implicit function of $\\theta$, the idea of implicit differentiation is to directly get analytical best-response derivatives $\\partial \\phi^{\\star}(\\theta)/ \\partial \\theta$ by implicit function theorem. This is suitable for algorithms when the inner-level optimal solution is achieved ${\\left. 
\\frac{\\partial F (\\phi, \\theta)}{\\partial \\phi} \\right\\rvert}_{\\phi = \\phi^{\\star}} = 0$ or reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377)." - ] - }, - { - "cell_type": "markdown", - "id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", - "metadata": {}, - "source": [ - "In this tutorial, we will introduce how TorchOpt can be used to conduct implicit differentiation." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "8f13ae67-e328-409f-84a8-1fc425c03a66", - "metadata": {}, - "outputs": [], - "source": [ - "import functorch\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "import torchopt" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", - "metadata": {}, - "source": [ - "## 1. Functional API\n", - "\n", - "The basic functional API is `torchopt.diff.implicit.custom_root`, which is used as the decorator for the forward process implicit gradient procedures. Users are required to implement the stationary conditions for the inner-loop process, which will be used as the input of custom_root decorator. We show the pseudo code in the following part.\n", - "\n", - "```python\n", - "# Functional API for implicit gradient\n", - "def stationary(params, meta_params, data):\n", - " # stationary condition construction\n", - " return stationary condition\n", - "\n", - "# Decorator that wraps the function\n", - "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", - "@torchopt.diff.implicit.custom_root(stationary, solve=linear_solver)\n", - "def solve(params, meta_params, data):\n", - " # Forward optimization process for params\n", - " return optimal_params\n", - "\n", - "# Define params, meta_params and get data\n", - "params, meta_prams, data = ..., ..., ...\n", - "optimal_params = solve(params, meta_params, data)\n", - "loss = outer_loss(optimal_params)\n", - "\n", - "meta_grads = torch.autograd.grad(loss, meta_params)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "dbef87df-2164-4f1d-8919-37a6fbdc5011", - "metadata": {}, - "source": [ - "Here we use the example of [iMAML](https://arxiv.org/abs/1909.04630) as a real example. For iMAML, the inner-loop objective is described by the following equation.\n", - "\n", - "$$\n", - "{\\mathcal{Alg}}^{\\star} \\left( \\boldsymbol{\\theta}, \\mathcal{D}_{i}^{\\text{tr}} \\right) = \\underset{\\phi'}{\\operatorname{\\arg \\min}} ~ G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\triangleq \\mathcal{L} \\left( \\boldsymbol{\\phi}', \\mathcal{D}_{i}^{\\text{tr}} \\right) + \\frac{\\lambda}{2} {\\left\\| \\boldsymbol{\\phi}' - \\boldsymbol{\\theta} \\right\\|}^{2}\n", - "$$\n", - "\n", - "According to this function, we can define the forward function `inner_solver`, where we solve this equation based on sufficient gradient descents. For such inner-loop process, the optimality condition is that the gradient w.r.t inner-loop parameter is $0$.\n", - "\n", - "$$\n", - "{\\left. \\nabla_{\\boldsymbol{\\phi}'} G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\right\\rvert}_{\\boldsymbol{\\phi}' = \\boldsymbol{\\phi}^{\\star}} = 0\n", - "$$\n", - "\n", - "Thus we can define the optimality function by defining `imaml_objective` and make it first-order gradient w.r.t the inner-loop parameter as $0$. 
We achieve so by calling out `functorch.grad(imaml_objective, argnums=0)`. Finally, the forward function is decorated by the `@torchopt.diff.implicit.custom_root` decorator and the optimality condition we define." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", - "metadata": {}, - "outputs": [], - "source": [ - "# Inner-loop objective function\n", - "# The optimality function: grad(imaml_objective)\n", - "def imaml_objective(params, meta_params, data):\n", - " x, y, fmodel = data\n", - " y_pred = fmodel(params, x)\n", - " regularization_loss = 0.0\n", - " for p1, p2 in zip(params, meta_params):\n", - " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", - " loss = F.mse_loss(y_pred, y) + regularization_loss\n", - " return loss\n", - "\n", - "\n", - "# Optimality Condition is: the gradient w.r.t inner-loop optimal params is 0 (we achieve so by\n", - "# specifying argnums=0 in functorch.grad) the argnums=1 specify which meta-parameter we want to\n", - "# backpropogate, in this case we want to backpropogate to the initial parameters so we set it as 1.\n", - "# You can also set argnums as (1, 2) if you want to backpropogate through multiple meta-parameters\n", - "\n", - "\n", - "# Here we pass argnums=1 to the custom_root. That means we want to compute the gradient of\n", - "# optimal_params w.r.t. the 1-indexed argument in inner_solver, i.e., params.\n", - "# torchopt.linear_solve.solve_normal_cg specify that we use the conjugate gradient based linear solver\n", - "@torchopt.diff.implicit.custom_root(\n", - " functorch.grad(imaml_objective, argnums=0), # optimality function\n", - " argnums=1,\n", - " solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", - ")\n", - "def inner_solver(params, meta_params, data):\n", - " # Initial functional optimizer based on TorchOpt\n", - " x, y, fmodel = data\n", - " optimizer = torchopt.sgd(lr=2e-2)\n", - " opt_state = optimizer.init(params)\n", - " with torch.enable_grad():\n", - " # Temporarily enable gradient computation for conducting the optimization\n", - " for i in range(100):\n", - " pred = fmodel(params, x)\n", - " loss = F.mse_loss(pred, y) # compute loss\n", - "\n", - " # Compute regularization loss\n", - " regularization_loss = 0.0\n", - " for p1, p2 in zip(params, meta_params):\n", - " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", - " final_loss = loss + regularization_loss\n", - "\n", - " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", - " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", - " params = torchopt.apply_updates(params, updates, inplace=True)\n", - "\n", - " optimal_params = params\n", - " return optimal_params\n", - "\n", - "\n", - "# torchopt.linear_solve.solve_inv specify that we use the Neumann Series inversion linear solver\n", - "@torchopt.diff.implicit.custom_root(\n", - " functorch.grad(imaml_objective, argnums=0), # optimality function\n", - " argnums=1,\n", - " solve=torchopt.linear_solve.solve_inv(ns=True, maxiter=100, alpha=0.1),\n", - ")\n", - "def inner_solver_inv_ns(params, meta_params, data):\n", - " # Initial functional optimizer based on TorchOpt\n", - " x, y, fmodel = data\n", - " optimizer = torchopt.sgd(lr=2e-2)\n", - " opt_state = optimizer.init(params)\n", - " with torch.enable_grad():\n", - " # Temporarily enable gradient computation for conducting the optimization\n", - " for i in range(100):\n", - " pred 
= fmodel(params, x)\n", - " loss = F.mse_loss(pred, y) # compute loss\n", - "\n", - " # Compute regularization loss\n", - " regularization_loss = 0.0\n", - " for p1, p2 in zip(params, meta_params):\n", - " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", - " final_loss = loss + regularization_loss\n", - "\n", - " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", - " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", - " params = torchopt.apply_updates(params, updates, inplace=True)\n", - "\n", - " optimal_params = params\n", - " return optimal_params" - ] - }, - { - "cell_type": "markdown", - "id": "32a75c81-d479-4120-a73d-5b2b488358d0", - "metadata": {}, - "source": [ - "In the next step, we consider a specific case for one layer neural network to fit the linear data." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "fb95538b-1fd9-4ec8-9f57-6360bedc05b7", - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(0)\n", - "x = torch.randn(20, 4)\n", - "w = torch.randn(4, 1)\n", - "b = torch.randn(1)\n", - "y = x @ w + b + 0.5 * torch.randn(20, 1)" - ] - }, - { - "cell_type": "markdown", - "id": "eeb1823a-2231-4471-bb68-cce7724f2578", - "metadata": {}, - "source": [ - "We instantiate an one layer neural network, where the weights and bias are initialized with constant." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "d50a7bfe-ac69-4089-8cf8-3cbd69d6d4e7", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "class Net(nn.Module):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, 1, bias=True)\n", - " nn.init.ones_(self.fc.weight)\n", - " nn.init.zeros_(self.fc.bias)\n", - "\n", - " def forward(self, x):\n", - " return self.fc(x)\n", - "\n", - "\n", - "model = Net(4)\n", - "fmodel, meta_params = functorch.make_functional(model)\n", - "data = (x, y, fmodel)\n", - "\n", - "\n", - "# Clone function for parameters\n", - "def clone(params):\n", - " cloned = []\n", - " for item in params:\n", - " if isinstance(item, torch.Tensor):\n", - " cloned.append(item.clone().detach_().requires_grad_(True))\n", - " else:\n", - " cloned.append(item)\n", - " return tuple(cloned)" - ] - }, - { - "cell_type": "markdown", - "id": "065c36c4-89e2-4a63-8213-63db6ee3b08e", - "metadata": {}, - "source": [ - "We take the forward process by calling out the forward function, then we pass the optimal params into the outer-loop loss function." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "115e79c6-911f-4743-a2ed-e50a71c3a813", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "optimal_params = inner_solver(clone(meta_params), meta_params, data)\n", - "\n", - "outer_loss = fmodel(optimal_params, x).mean()" - ] - }, - { - "cell_type": "markdown", - "id": "e2812351-f635-496e-9732-c80831ac04a6", - "metadata": {}, - "source": [ - "Finally, we can get the meta-gradient as shown below." 
- ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "6bdcbe8d-2336-4f80-b124-eb43c5a2fc0a", - "metadata": {}, - "outputs": [ + "cells" : [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" - ] - } - ], - "source": [ - "torch.autograd.grad(outer_loss, meta_params)" - ] - }, - { - "cell_type": "markdown", - "id": "926ae8bb", - "metadata": {}, - "source": [ - "Also we can switch to the Neumann Series inversion linear solver." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "43df0374", - "metadata": {}, - "outputs": [ + "cell_type" : "markdown", + "id" : "8850c832-3b54-4971-8ee0-2cd64b585ea8", + "metadata" : {}, + "source" : ["# TorchOpt for Implicit Differentiation"] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" - ] - } - ], - "source": [ - "optimal_params = inner_solver_inv_ns(clone(meta_params), meta_params, data)\n", - "outer_loss = fmodel(optimal_params, x).mean()\n", - "torch.autograd.grad(outer_loss, meta_params)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "c92e67ea-b220-4a14-a1ea-4eb3c5f52b6b", - "metadata": {}, - "source": [ - "## 2. OOP API\n", - "\n", - "The basic OOP class is the class `ImplicitMetaGradientModule`. We make the network as an `nn.Module` following a classical PyTorch style. Users need to define the stationary condition/objective function and the inner-loop solve function to enable implicit gradient computation. We show the pseudo code in the following part.\n", - "\n", - "```python\n", - "from torchopt.nn import ImplicitMetaGradientModule\n", - "\n", - "# Inherited from the class ImplicitMetaGradientModule\n", - "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", - "class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):\n", - " def __init__(self, meta_module):\n", - " ...\n", - "\n", - " def forward(self, batch):\n", - " # Forward process\n", - " ...\n", - "\n", - " def optimality(self, batch, labels):\n", - " # Stationary condition construction for calculating implicit gradient\n", - " # NOTE: If this method is not implemented, it will be automatically derived from the\n", - " # gradient of the `objective` function.\n", - " ...\n", - "\n", - " def objective(self, batch, labels):\n", - " # Define the inner-loop optimization objective\n", - " # NOTE: This method is optional if method `optimality` is implemented.\n", - " ...\n", - "\n", - " def solve(self, batch, labels):\n", - " # Conduct the inner-loop optimization\n", - " ...\n", - " return self # optimized module\n", - "\n", - "# Get meta_params and data\n", - "meta_params, data = ..., ...\n", - "inner_net = InnerNet()\n", - "\n", - "# Solve for inner-loop process related to the meta-parameters\n", - "optimal_inner_net = inner_net.solve(meta_params, *data)\n", - "\n", - "# Get outer-loss and solve for meta-gradient\n", - "loss = outer_loss(optimal_inner_net)\n", - "meta_grad = torch.autograd.grad(loss, meta_params)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "62fbe520-11d0-41ff-9b0a-c6508b1d01cf", - "metadata": {}, - "source": [ - "The class `ImplicitMetaGradientModule` is to enable the gradient flow from `self.parameters()` to `self.meta_parameters()`. In `__init__` function, users need to define the inner parameters and meta-parameters. 
By default, `ImplicitMetaGradientModule` treats all tensors and modules from input as `self.meta_parameters()`, and all tensors and modules defined in the `__init__` are regarded as `self.parameters()`. Users can also register `self.parameters()` and `self.meta_parameters()` by calling `self.register_parameter()` and `self.register_meta_parameter()` respectively." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "c3999684-f4d3-4bc0-86ab-a7e803b2fe80", - "metadata": {}, - "outputs": [ + "cell_type" : "markdown", + "id" : "2b547376", + "metadata" : {}, + "source" : ["[](https://" + "colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/" + "5_Implicit_Differentiation.ipynb)"] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" - ] - } - ], - "source": [ - "class InnerNet(\n", - " torchopt.nn.ImplicitMetaGradientModule,\n", - " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", - "):\n", - " def __init__(self, meta_net, n_inner_iter, reg_param):\n", - " super().__init__()\n", - " # Declaration of the meta-parameter\n", - " self.meta_net = meta_net\n", - " # Get a deepcopy, register inner-parameter\n", - " self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)\n", - " self.n_inner_iter = n_inner_iter\n", - " self.reg_param = reg_param\n", - "\n", - " def forward(self, x):\n", - " return self.net(x)\n", - "\n", - " def objective(self, x, y):\n", - " # We do not implement the optimality conditions, so it will be automatically derived from\n", - " # the gradient of the `objective` function.\n", - " y_pred = self(x)\n", - " loss = F.mse_loss(y_pred, y)\n", - " regularization_loss = 0\n", - " for p1, p2 in zip(\n", - " self.parameters(), # parameters of `self.net`\n", - " self.meta_parameters(), # parameters of `self.meta_net`\n", - " ):\n", - " regularization_loss += (\n", - " 0.5 * self.reg_param * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", - " )\n", - " return loss + regularization_loss\n", - "\n", - " def solve(self, x, y):\n", - " params = tuple(self.parameters())\n", - " inner_optim = torchopt.SGD(params, lr=2e-2)\n", - " with torch.enable_grad():\n", - " # Temporarily enable gradient computation for conducting the optimization\n", - " for _ in range(self.n_inner_iter):\n", - " loss = self.objective(x, y)\n", - " inner_optim.zero_grad()\n", - " # NOTE: The parameter inputs should be explicitly specified in `backward` function\n", - " # as argument `inputs`. Otherwise, if not provided, the gradient is accumulated into\n", - " # all the leaf Tensors (including the meta-parameters) that were used to compute the\n", - " # objective output. 
Alternatively, please use `torch.autograd.grad` instead.\n", - " loss.backward(inputs=params) # backward pass in inner-loop\n", - " inner_optim.step() # update inner parameters\n", - " return self\n", - "\n", - "\n", - "# Initialize the meta-network\n", - "meta_net = Net(4)\n", - "inner_net = InnerNet(meta_net, 100, reg_param=1)\n", - "\n", - "# Solve for inner-loop\n", - "optimal_inner_net = inner_net.solve(x, y)\n", - "outer_loss = optimal_inner_net(x).mean()\n", - "\n", - "# Derive the meta-gradient\n", - "torch.autograd.grad(outer_loss, meta_net.parameters())" - ] - }, - { - "cell_type": "markdown", - "id": "2b69a5d6-b5e4-4f08-af0a-40afc2382b45", - "metadata": {}, - "source": [ - "We also show an example on how to implement implicit gradient calculation when the inner-level optimal solution reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [DEQ](https://arxiv.org/abs/1909.01377), based on the OOP API. " - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "de87c308-d847-4491-9aa1-bc393e6dd1d8", - "metadata": {}, - "outputs": [ + "cell_type" : "markdown", + "id" : "8d7f9865-dc02-43d4-be90-da1160c4e4dd", + "metadata" : {}, + "source" : + ["By treating the solution $\\phi^{\\star}$ as an implicit function of $\\theta$, the " + "idea of implicit differentiation is to directly get analytical best-response " + "derivatives $\\partial \\phi^{\\star}(\\theta)/ \\partial \\theta$ by implicit " + "function theorem. This is suitable for algorithms when the inner-level optimal " + "solution is achieved ${\\left. \\frac{\\partial F (\\phi, \\theta)}{\\partial \\phi} " + "\\right\\rvert}_{\\phi = \\phi^{\\star}} = 0$ or reaches some stationary conditions $F " + "(\\phi^{\\star}, \\theta) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and " + "[DEQ](https://arxiv.org/abs/1909.01377)."] + }, + { + "cell_type" : "markdown", + "id" : "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", + "metadata" : {}, + "source" : ["In this tutorial, we will introduce how TorchOpt can be used to conduct " + "implicit differentiation."] + }, + { + "cell_type" : "code", + "execution_count" : 1, + "id" : "8f13ae67-e328-409f-84a8-1fc425c03a66", + "metadata" : {}, + "outputs" : [], + "source" : [ + "import functorch\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torchopt" + ] + }, + { + "attachments" : {}, + "cell_type" : "markdown", + "id" : "0cdaac49-4b94-4900-9bb5-a39057ac8b21", + "metadata" : {}, + "source" : [ + "## 1. Functional API\n", + "\n", + "The basic functional API is `torchopt.diff.implicit.custom_root`, which is used as the " + "decorator for the forward process implicit gradient procedures. Users are required to " + "implement the stationary conditions for the inner-loop process, which will be used as the " + "input of custom_root decorator. 
We show the pseudo code in the following part.\n", + "\n", + "```python\n", + "# Functional API for implicit gradient\n", + "def stationary(params, meta_params, data):\n", + "    # stationary condition construction\n", + "    return stationary condition\n", + "\n", + "# Decorator that wraps the function\n", + "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", + "@torchopt.diff.implicit.custom_root(stationary, solve=linear_solver)\n", + "def solve(params, meta_params, data):\n", + "    # Forward optimization process for params\n", + "    return optimal_params\n", + "\n", + "# Define params, meta_params and get data\n", + "params, meta_params, data = ..., ..., ...\n", + "optimal_params = solve(params, meta_params, data)\n", + "loss = outer_loss(optimal_params)\n", + "\n", + "meta_grads = torch.autograd.grad(loss, meta_params)\n", + "```" + ] + }, + { + "cell_type" : "markdown", + "id" : "dbef87df-2164-4f1d-8919-37a6fbdc5011", + "metadata" : {}, + "source" : [ + "Here we use [iMAML](https://arxiv.org/abs/1909.04630) as a real example. " + "For iMAML, the inner-loop objective is described by the following equation.\n", + "\n", + "$$\n", + "{\\mathcal{Alg}}^{\\star} \\left( \\boldsymbol{\\theta}, \\mathcal{D}_{i}^{\\text{tr}} " + "\\right) = \\underset{\\phi'}{\\operatorname{\\arg \\min}} ~ G \\left( " + "\\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\triangleq \\mathcal{L} \\left( " + "\\boldsymbol{\\phi}', \\mathcal{D}_{i}^{\\text{tr}} \\right) + \\frac{\\lambda}{2} " + "{\\left\\| \\boldsymbol{\\phi}' - \\boldsymbol{\\theta} \\right\\|}^{2}\n", + "$$\n", + "\n", + "According to this objective, we can define the forward function `inner_solver`, which " + "solves this equation with a sufficient number of gradient descent steps. For such an " + "inner-loop process, the optimality condition is that the gradient w.r.t. the inner-loop " + "parameter is $0$.\n", + "\n", + "$$\n", + "{\\left. \\nabla_{\\boldsymbol{\\phi}'} G \\left( \\boldsymbol{\\phi}', " + "\\boldsymbol{\\theta} \\right) \\right\\rvert}_{\\boldsymbol{\\phi}' = " + "\\boldsymbol{\\phi}^{\\star}} = 0\n", + "$$\n", + "\n", + "Thus we can define the optimality function as the first-order gradient of " + "`imaml_objective` w.r.t. the inner-loop parameter, which we obtain by calling " + "`functorch.grad(imaml_objective, argnums=0)`. Finally, the forward function is decorated " + "by the `@torchopt.diff.implicit.custom_root` decorator with the optimality condition we " + "define."
+ ] + }, + { + "cell_type" : "code", + "execution_count" : 2, + "id" : "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", + "metadata" : {}, + "outputs" : [], + "source" : [ + "# Inner-loop objective function\n", + "# The optimality function: grad(imaml_objective)\n", + "def imaml_objective(params, meta_params, data):\n", + " x, y, fmodel = data\n", + " y_pred = fmodel(params, x)\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " loss = F.mse_loss(y_pred, y) + regularization_loss\n", + " return loss\n", + "\n", + "\n", + "# Optimality Condition is: the gradient w.r.t inner-loop optimal params is 0 (we achieve " + "so by\n", + "# specifying argnums=0 in functorch.grad) the argnums=1 specify which meta-parameter we " + "want to\n", + "# backpropogate, in this case we want to backpropogate to the initial parameters so we " + "set it as 1.\n", + "# You can also set argnums as (1, 2) if you want to backpropogate through multiple " + "meta-parameters\n", + "\n", + "\n", + "# Here we pass argnums=1 to the custom_root. That means we want to compute the gradient " + "of\n", + "# optimal_params w.r.t. the 1-indexed argument in inner_solver, i.e., params.\n", + "# torchopt.linear_solve.solve_normal_cg specify that we use the conjugate gradient based " + "linear solver\n", + "@torchopt.diff.implicit.custom_root(\n", + " functorch.grad(imaml_objective, argnums=0), # optimality function\n", + " argnums=1,\n", + " solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + ")\n", + "def inner_solver(params, meta_params, data):\n", + " # Initial functional optimizer based on TorchOpt\n", + " x, y, fmodel = data\n", + " optimizer = torchopt.sgd(lr=2e-2)\n", + " opt_state = optimizer.init(params)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for i in range(100):\n", + " pred = fmodel(params, x)\n", + " loss = F.mse_loss(pred, y) # compute loss\n", + "\n", + " # Compute regularization loss\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - " + "p2.view(-1)))\n", + " final_loss = loss + regularization_loss\n", + "\n", + " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get " + "updates\n", + " params = torchopt.apply_updates(params, updates, inplace=True)\n", + "\n", + " optimal_params = params\n", + " return optimal_params\n", + "\n", + "\n", + "# torchopt.linear_solve.solve_inv specify that we use the Neumann Series inversion linear " + "solver\n", + "@torchopt.diff.implicit.custom_root(\n", + " functorch.grad(imaml_objective, argnums=0), # optimality function\n", + " argnums=1,\n", + " solve=torchopt.linear_solve.solve_inv(ns=True, maxiter=100, alpha=0.1),\n", + ")\n", + "def inner_solver_inv_ns(params, meta_params, data):\n", + " # Initial functional optimizer based on TorchOpt\n", + " x, y, fmodel = data\n", + " optimizer = torchopt.sgd(lr=2e-2)\n", + " opt_state = optimizer.init(params)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for i in range(100):\n", + " pred = fmodel(params, x)\n", + " loss = F.mse_loss(pred, y) # compute loss\n", + "\n", + " # Compute regularization loss\n", + " regularization_loss = 0.0\n", + " for 
p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - " + "p2.view(-1)))\n", + " final_loss = loss + regularization_loss\n", + "\n", + " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get " + "updates\n", + " params = torchopt.apply_updates(params, updates, inplace=True)\n", + "\n", + " optimal_params = params\n", + " return optimal_params" + ] + }, + { + "cell_type" : "markdown", + "id" : "32a75c81-d479-4120-a73d-5b2b488358d0", + "metadata" : {}, + "source" : ["In the next step, we consider a specific case for one layer neural network to " + "fit the linear data."] + }, + { + "cell_type" : "code", + "execution_count" : 3, + "id" : "fb95538b-1fd9-4ec8-9f57-6360bedc05b7", + "metadata" : {}, + "outputs" : [], + "source" : [ + "torch.manual_seed(0)\n", + "x = torch.randn(20, 4)\n", + "w = torch.randn(4, 1)\n", + "b = torch.randn(1)\n", + "y = x @ w + b + 0.5 * torch.randn(20, 1)" + ] + }, + { + "cell_type" : "markdown", + "id" : "eeb1823a-2231-4471-bb68-cce7724f2578", + "metadata" : {}, + "source" : ["We instantiate an one layer neural network, where the weights and bias are " + "initialized with constant."] + }, + { + "cell_type" : "code", + "execution_count" : 4, + "id" : "d50a7bfe-ac69-4089-8cf8-3cbd69d6d4e7", + "metadata" : {"tags" : []}, + "outputs" : [], + "source" : [ + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1, bias=True)\n", + " nn.init.ones_(self.fc.weight)\n", + " nn.init.zeros_(self.fc.bias)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "model = Net(4)\n", + "fmodel, meta_params = functorch.make_functional(model)\n", + "data = (x, y, fmodel)\n", + "\n", + "\n", + "# Clone function for parameters\n", + "def clone(params):\n", + " cloned = []\n", + " for item in params:\n", + " if isinstance(item, torch.Tensor):\n", + " cloned.append(item.clone().detach_().requires_grad_(True))\n", + " else:\n", + " cloned.append(item)\n", + " return tuple(cloned)" + ] + }, + { + "cell_type" : "markdown", + "id" : "065c36c4-89e2-4a63-8213-63db6ee3b08e", + "metadata" : {}, + "source" : ["We take the forward process by calling out the forward function, then we pass " + "the optimal params into the outer-loop loss function."] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "(\n", - "│ tensor([[ 0.0272, 0.0031, -0.0156, -0.0238],\n", - "│ │ [ 0.1004, 0.0113, -0.0573, -0.0878],\n", - "│ │ [ 0.0666, 0.0075, -0.0380, -0.0583],\n", - "│ │ [ 0.1446, 0.0163, -0.0826, -0.1265]]),\n", - "│ tensor([0.0574, 0.2114, 0.1403, 0.3046])\n", - ")\n" - ] + "cell_type" : "code", + "execution_count" : 5, + "id" : "115e79c6-911f-4743-a2ed-e50a71c3a813", + "metadata" : {"tags" : []}, + "outputs" : [], + "source" : [ + "optimal_params = inner_solver(clone(meta_params), meta_params, data)\n", + "\n", + "outer_loss = fmodel(optimal_params, x).mean()" + ] + }, + { + "cell_type" : "markdown", + "id" : "e2812351-f635-496e-9732-c80831ac04a6", + "metadata" : {}, + "source" : ["Finally, we can get the meta-gradient as shown below."] + }, + { + "cell_type" : "code", + "execution_count" : 6, + "id" : "6bdcbe8d-2336-4f80-b124-eb43c5a2fc0a", + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : ["(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n"] + } ], + "source" : 
["torch.autograd.grad(outer_loss, meta_params)"] + }, + { + "cell_type" : "markdown", + "id" : "926ae8bb", + "metadata" : {}, + "source" : ["Also we can switch to the Neumann Series inversion linear solver."] + }, + { + "cell_type" : "code", + "execution_count" : 7, + "id" : "43df0374", + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : ["(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n"] + } ], + "source" : [ + "optimal_params = inner_solver_inv_ns(clone(meta_params), meta_params, data)\n", + "outer_loss = fmodel(optimal_params, x).mean()\n", + "torch.autograd.grad(outer_loss, meta_params)" + ] + }, + { + "attachments" : {}, + "cell_type" : "markdown", + "id" : "c92e67ea-b220-4a14-a1ea-4eb3c5f52b6b", + "metadata" : {}, + "source" : [ + "## 2. OOP API\n", + "\n", + "The basic OOP class is the class `ImplicitMetaGradientModule`. We make the network as an " + "`nn.Module` following a classical PyTorch style. Users need to define the stationary " + "condition/objective function and the inner-loop solve function to enable implicit " + "gradient computation. We show the pseudo code in the following part.\n", + "\n", + "```python\n", + "from torchopt.nn import ImplicitMetaGradientModule\n", + "\n", + "# Inherited from the class ImplicitMetaGradientModule\n", + "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", + "class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):\n", + " def __init__(self, meta_module):\n", + " ...\n", + "\n", + " def forward(self, batch):\n", + " # Forward process\n", + " ...\n", + "\n", + " def optimality(self, batch, labels):\n", + " # Stationary condition construction for calculating implicit gradient\n", + " # NOTE: If this method is not implemented, it will be automatically derived from " + "the\n", + " # gradient of the `objective` function.\n", + " ...\n", + "\n", + " def objective(self, batch, labels):\n", + " # Define the inner-loop optimization objective\n", + " # NOTE: This method is optional if method `optimality` is implemented.\n", + " ...\n", + "\n", + " def solve(self, batch, labels):\n", + " # Conduct the inner-loop optimization\n", + " ...\n", + " return self # optimized module\n", + "\n", + "# Get meta_params and data\n", + "meta_params, data = ..., ...\n", + "inner_net = InnerNet()\n", + "\n", + "# Solve for inner-loop process related to the meta-parameters\n", + "optimal_inner_net = inner_net.solve(meta_params, *data)\n", + "\n", + "# Get outer-loss and solve for meta-gradient\n", + "loss = outer_loss(optimal_inner_net)\n", + "meta_grad = torch.autograd.grad(loss, meta_params)\n", + "```" + ] + }, + { + "cell_type" : "markdown", + "id" : "62fbe520-11d0-41ff-9b0a-c6508b1d01cf", + "metadata" : {}, + "source" : + ["The class `ImplicitMetaGradientModule` is to enable the gradient flow from " + "`self.parameters()` to `self.meta_parameters()`. In `__init__` function, users need to " + "define the inner parameters and meta-parameters. By default, " + "`ImplicitMetaGradientModule` treats all tensors and modules from input as " + "`self.meta_parameters()`, and all tensors and modules defined in the `__init__` are " + "regarded as `self.parameters()`. 
Users can also register `self.parameters()` and " + "`self.meta_parameters()` by calling `self.register_parameter()` and " + "`self.register_meta_parameter()` respectively."] + }, + { + "cell_type" : "code", + "execution_count" : 8, + "id" : "c3999684-f4d3-4bc0-86ab-a7e803b2fe80", + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : ["(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n"] + } ], + "source" : [ + "class InnerNet(\n", + " torchopt.nn.ImplicitMetaGradientModule,\n", + " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + "):\n", + " def __init__(self, meta_net, n_inner_iter, reg_param):\n", + " super().__init__()\n", + " # Declaration of the meta-parameter\n", + " self.meta_net = meta_net\n", + " # Get a deepcopy, register inner-parameter\n", + " self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)\n", + " self.n_inner_iter = n_inner_iter\n", + " self.reg_param = reg_param\n", + "\n", + " def forward(self, x):\n", + " return self.net(x)\n", + "\n", + " def objective(self, x, y):\n", + " # We do not implement the optimality conditions, so it will be automatically " + "derived from\n", + " # the gradient of the `objective` function.\n", + " y_pred = self(x)\n", + " loss = F.mse_loss(y_pred, y)\n", + " regularization_loss = 0\n", + " for p1, p2 in zip(\n", + " self.parameters(), # parameters of `self.net`\n", + " self.meta_parameters(), # parameters of `self.meta_net`\n", + " ):\n", + " regularization_loss += (\n", + " 0.5 * self.reg_param * torch.sum(torch.square(p1.view(-1) - " + "p2.view(-1)))\n", + " )\n", + " return loss + regularization_loss\n", + "\n", + " def solve(self, x, y):\n", + " params = tuple(self.parameters())\n", + " inner_optim = torchopt.SGD(params, lr=2e-2)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for _ in range(self.n_inner_iter):\n", + " loss = self.objective(x, y)\n", + " inner_optim.zero_grad()\n", + " # NOTE: The parameter inputs should be explicitly specified in `backward` " + "function\n", + " # as argument `inputs`. Otherwise, if not provided, the gradient is " + "accumulated into\n", + " # all the leaf Tensors (including the meta-parameters) that were used to " + "compute the\n", + " # objective output. Alternatively, please use `torch.autograd.grad` " + "instead.\n", + " loss.backward(inputs=params) # backward pass in inner-loop\n", + " inner_optim.step() # update inner parameters\n", + " return self\n", + "\n", + "\n", + "# Initialize the meta-network\n", + "meta_net = Net(4)\n", + "inner_net = InnerNet(meta_net, 100, reg_param=1)\n", + "\n", + "# Solve for inner-loop\n", + "optimal_inner_net = inner_net.solve(x, y)\n", + "outer_loss = optimal_inner_net(x).mean()\n", + "\n", + "# Derive the meta-gradient\n", + "torch.autograd.grad(outer_loss, meta_net.parameters())" + ] + }, + { + "cell_type" : "markdown", + "id" : "2b69a5d6-b5e4-4f08-af0a-40afc2382b45", + "metadata" : {}, + "source" : ["We also show an example on how to implement implicit gradient calculation when " + "the inner-level optimal solution reaches some stationary conditions $F " + "(\\phi^{\\star}, \\theta) = 0$, such as " + "[DEQ](https://arxiv.org/abs/1909.01377), based on the OOP API. 
"] + }, + { + "cell_type" : "code", + "execution_count" : 9, + "id" : "de87c308-d847-4491-9aa1-bc393e6dd1d8", + "metadata" : {}, + "outputs" : [ { + "name" : "stdout", + "output_type" : "stream", + "text" : [ + "\n", + "(\n", + "│ tensor([[ 0.0272, 0.0031, -0.0156, -0.0238],\n", + "│ │ [ 0.1004, 0.0113, -0.0573, -0.0878],\n", + "│ │ [ 0.0666, 0.0075, -0.0380, -0.0583],\n", + "│ │ [ 0.1446, 0.0163, -0.0826, -0.1265]]),\n", + "│ tensor([0.0574, 0.2114, 0.1403, 0.3046])\n", + ")\n" + ] + } ], + "source" : [ + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, dim)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "class InnerNet(\n", + " torchopt.nn.ImplicitMetaGradientModule,\n", + " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + "):\n", + " def __init__(self, meta_net, x0):\n", + " super().__init__()\n", + " # Register meta-parameter\n", + " self.meta_net = meta_net\n", + " # Declaration of the inner-parameter, register inner-parameter\n", + " self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)\n", + "\n", + " def forward(self, x):\n", + " return self.meta_net(x)\n", + "\n", + " def optimality(self):\n", + " # Fixed-point condition\n", + " return (self.x - self(self.x),)\n", + "\n", + " def solve(self):\n", + " # Solving inner-loop fixed-point iteration\n", + " # This is just an illustrating example for solving fixed-point iteration\n", + " # one can use more advanced method to solve fixed-point iteration\n", + " # such as anderson acceleration.\n", + " for _ in range(10):\n", + " self.x.copy_(self(self.x))\n", + " return self\n", + "\n", + "\n", + "# Initialize meta-network\n", + "torch.manual_seed(0)\n", + "meta_net = Net(4)\n", + "x0 = torch.randn(1, 4)\n", + "inner_net = InnerNet(meta_net, x0)\n", + "\n", + "# Solve for inner-loop\n", + "optimal_inner_net = inner_net.solve()\n", + "outer_loss = optimal_inner_net.x.mean()\n", + "\n", + "# Derive the meta-gradient\n", + "torch.autograd.grad(outer_loss, meta_net.parameters())" + ] + } + ], + "metadata" : { + "kernelspec" : + {"display_name" : "Python 3 (ipykernel)", "language" : "python", "name" : "python3"}, + "language_info" : { + "codemirror_mode" : {"name" : "ipython", "version" : 3}, + "file_extension" : ".py", + "mimetype" : "text/x-python", + "name" : "python", + "nbconvert_exporter" : "python", + "pygments_lexer" : "ipython3", + "version" : "3.9.15" + }, + "vscode" : { + "interpreter" : + {"hash" : "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da"} } - ], - "source": [ - "class Net(nn.Module):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, dim)\n", - "\n", - " def forward(self, x):\n", - " return self.fc(x)\n", - "\n", - "\n", - "class InnerNet(\n", - " torchopt.nn.ImplicitMetaGradientModule,\n", - " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", - "):\n", - " def __init__(self, meta_net, x0):\n", - " super().__init__()\n", - " # Register meta-parameter\n", - " self.meta_net = meta_net\n", - " # Declaration of the inner-parameter, register inner-parameter\n", - " self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)\n", - "\n", - " def forward(self, x):\n", - " return self.meta_net(x)\n", - "\n", - " def optimality(self):\n", - " # Fixed-point condition\n", - " return (self.x - self(self.x),)\n", - "\n", - " def solve(self):\n", - " # Solving inner-loop fixed-point iteration\n", - " # This is 
just an illustrating example for solving fixed-point iteration\n", - " # one can use more advanced method to solve fixed-point iteration\n", - " # such as anderson acceleration.\n", - " for _ in range(10):\n", - " self.x.copy_(self(self.x))\n", - " return self\n", - "\n", - "\n", - "# Initialize meta-network\n", - "torch.manual_seed(0)\n", - "meta_net = Net(4)\n", - "x0 = torch.randn(1, 4)\n", - "inner_net = InnerNet(meta_net, x0)\n", - "\n", - "# Solve for inner-loop\n", - "optimal_inner_net = inner_net.solve()\n", - "outer_loss = optimal_inner_net.x.mean()\n", - "\n", - "# Derive the meta-gradient\n", - "torch.autograd.grad(outer_loss, meta_net.parameters())" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" }, - "vscode": { - "interpreter": { - "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "nbformat" : 4, + "nbformat_minor" : 5 } diff --git a/tutorials/6_Zero_Order_Differentiation.ipynb b/tutorials/6_Zero_Order_Differentiation.ipynb index d6cb028c..d6e2a4e0 100644 --- a/tutorials/6_Zero_Order_Differentiation.ipynb +++ b/tutorials/6_Zero_Order_Differentiation.ipynb @@ -1,356 +1,390 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", - "metadata": {}, - "source": [ - "# TorchOpt for Zero-Order Differentiation" - ] - }, - { - "cell_type": "markdown", - "id": "2b547376", - "metadata": {}, - "source": [ - "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/6_Zero_Order_Differentiation.ipynb)" - ] - }, - { - "cell_type": "markdown", - "id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd", - "metadata": {}, - "source": [ - "When the inner-loop process is non-differentiable or one wants to eliminate the heavy computation burdens in the previous two modes (brought by Hessian), one can choose ZD. ZD typically gets gradients based on zero-order estimation, such as finite-difference, or Evolutionary Strategy.\n", - "\n", - "TorchOpt offers API for ES-based differentiation. Instead of optimizing the objective $f (\\boldsymbol{\\theta}): \\mathbb{R}^n \\to \\mathbb{R}$, ES optimizes a Gaussion smoothing objective defined as $\\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\mathbb{E}_{\\boldsymbol{z} \\sim \\mathcal{N}( 0, {I}_d )} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) ]$, where $\\sigma$ denotes precision. The gradient of such objective is $\\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\mathcal{N}( 0, {I}_d )} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ]$. Refer to [ES-MAML](https://arxiv.org/pdf/1910.01215.pdf) for more details." - ] - }, - { - "cell_type": "markdown", - "id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", - "metadata": {}, - "source": [ - "In this tutorial, we will introduce how TorchOpt can be used to ES-based differentiation." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "8f13ae67-e328-409f-84a8-1fc425c03a66", - "metadata": {}, - "outputs": [], - "source": [ - "import functorch\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "import torchopt" - ] - }, - { - "cell_type": "markdown", - "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", - "metadata": {}, - "source": [ - "## 1. Functional API\n", - "\n", - "The basic functional API is `torchopt.diff.zero_order.zero_order`, which is used as the decorator for the forward process zero-order gradient procedures. Users are required to implement the noise sampling function, which will be used as the input of zero_order decorator. Here we show the specific meaning for each parameter used in the decorator.\n", - "\n", - "- `distribution` for noise sampling distribution. The distribution $\\lambda$ should be spherical symmetric and with a constant variance of $1$ for each element. I.e.:\n", - "\n", - " - Spherical symmetric: $\\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ \\boldsymbol{z} ] = \\boldsymbol{0}$.\n", - " - Constant variance of $1$ for each element: $\\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ {\\lvert z_i \\rvert}^2 ] = 1$.\n", - " - For example, the standard multi-dimensional normal distribution $\\mathcal{N} (\\boldsymbol{0}, \\boldsymbol{1})$.\n", - "\n", - "- `method` for different kind of algorithms, we support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://arxiv.org/abs/1803.07055)).\n", - "\n", - " $$\n", - " \\begin{align*}\n", - " \\text{naive} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ] \\\\\n", - " \\text{forward} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ ( f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) - f (\\boldsymbol{\\theta}) ) \\cdot \\boldsymbol{z} ] \\\\\n", - " \\text{antithetic} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{2 \\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ (f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) - f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) ) \\cdot \\boldsymbol{z} ]\n", - " \\end{align*}\n", - " $$\n", - "\n", - "- `argnums` specifies which parameter we want to trace the meta-gradient.\n", - "- `num_samples` specifies how many times we want to conduct the sampling.\n", - "- `sigma` is for precision. This is the scaling factor for the sampling distribution.\n", - "\n", - "We show the pseudo code in the following part.\n", - "\n", - "```python\n", - "# Functional API for zero-order differentiation\n", - "# 1. Customize the noise distribution via a distribution class\n", - "class Distribution:\n", - " def sample(self, sample_shape=torch.Size()):\n", - " # Sampling function for noise\n", - " # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.\n", - " ...\n", - " return noise_batch\n", - "\n", - "distribution = Distribution()\n", - "\n", - "# 2. 
Customize the noise distribution via a sampling function\n", - "def distribution(sample_shape=torch.Size()):\n", - " # Sampling function for noise\n", - " # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.\n", - " ...\n", - " return noise_batch\n", - "\n", - "# 3. Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`\n", - "distribution = torch.distributions.Normal(loc=0, scale=1)\n", - "\n", - "# Decorator that wraps the function\n", - "@torchopt.diff.zero_order(distribution=distribution, method='naive', argnums=0, num_samples=100, sigma=0.01)\n", - "def forward(params, data):\n", - " # Forward optimization process for params\n", - " ...\n", - " return objective # the returned tensor should be a scalar tensor\n", - "\n", - "# Define params and get data\n", - "params, data = ..., ...\n", - "\n", - "# Forward pass\n", - "loss = forward(params, data)\n", - "# Backward pass using zero-order differentiation\n", - "grads = torch.autograd.grad(loss, params)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "dbef87df-2164-4f1d-8919-37a6fbdc5011", - "metadata": {}, - "source": [ - "Here we use the example of a linear layer as an example, note that this is just an example to show linear layer can work with ES." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", - "metadata": {}, - "outputs": [ + "cells" : [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "001: tensor(0.0265, grad_fn=)\n", - "002: tensor(0.0243, grad_fn=)\n", - "003: tensor(0.0222, grad_fn=)\n", - "004: tensor(0.0202, grad_fn=)\n", - "005: tensor(0.0184, grad_fn=)\n", - "006: tensor(0.0170, grad_fn=)\n", - "007: tensor(0.0157, grad_fn=)\n", - "008: tensor(0.0146, grad_fn=)\n", - "009: tensor(0.0137, grad_fn=)\n", - "010: tensor(0.0130, grad_fn=)\n", - "011: tensor(0.0123, grad_fn=)\n", - "012: tensor(0.0118, grad_fn=)\n", - "013: tensor(0.0114, grad_fn=)\n", - "014: tensor(0.0111, grad_fn=)\n", - "015: tensor(0.0111, grad_fn=)\n", - "016: tensor(0.0111, grad_fn=)\n", - "017: tensor(0.0113, grad_fn=)\n", - "018: tensor(0.0115, grad_fn=)\n", - "019: tensor(0.0118, grad_fn=)\n", - "020: tensor(0.0120, grad_fn=)\n", - "021: tensor(0.0121, grad_fn=)\n", - "022: tensor(0.0121, grad_fn=)\n", - "023: tensor(0.0122, grad_fn=)\n", - "024: tensor(0.0122, grad_fn=)\n", - "025: tensor(0.0122, grad_fn=)\n" - ] - } - ], - "source": [ - "torch.random.manual_seed(0)\n", - "\n", - "fmodel, params = functorch.make_functional(nn.Linear(32, 1))\n", - "x = torch.randn(64, 32) * 0.1\n", - "y = torch.randn(64, 1) * 0.1\n", - "distribution = torch.distributions.Normal(loc=0, scale=1)\n", - "\n", - "\n", - "@torchopt.diff.zero_order(\n", - " distribution=distribution, method='forward', argnums=0, num_samples=100, sigma=0.01\n", - ")\n", - "def forward_process(params, fn, x, y):\n", - " y_pred = fn(params, x)\n", - " loss = F.mse_loss(y_pred, y)\n", - " return loss\n", - "\n", - "\n", - "optimizer = torchopt.adam(lr=0.01)\n", - "opt_state = optimizer.init(params) # init optimizer\n", - "\n", - "for i in range(25):\n", - " loss = forward_process(params, fmodel, x, y) # compute loss\n", - "\n", - " grads = torch.autograd.grad(loss, params) # compute gradients\n", - " updates, opt_state = optimizer.update(grads, opt_state) # get updates\n", - " params = torchopt.apply_updates(params, updates) # update network parameters\n", - "\n", - " print(f'{i + 1:03d}: {loss!r}')" - ] - }, - { - 
"attachments": {}, - "cell_type": "markdown", - "id": "db723f6b", - "metadata": {}, - "source": [ - "## 2. OOP API\n", - "\n", - "The basic OOP API is the class `ZeroOrderGradientModule`. We make the network as an `nn.Module` following a classical PyTorch style. Users need to define the forward process zero-order gradient procedures `forward()` and a noise sampling function `sample()`. Here we show the specific meaning for each parameter used in the class.\n", - "\n", - "- `method` for different kind of algorithms, we support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://d1wqtxts1xzle7.cloudfront.net/75609515/coredp2011_1web-with-cover-page-v2.pdf?Expires=1670215467&Signature=RfP~mQhhhI7aGknwXbRBgSggFrKuNTPYdyUSdMmfTxOa62QoOJAm-Xhr3F1PLyjUQc2JVxmKIKGGuyYvyfCTpB31dfmMtuVQxZMWVF-SfErTN05SliC93yjA1x1g2kjhn8bkBFdQqGl~1RQSKnhj88BakgSeDNzyCxwbD5VgR89BXRs4YIK5RBIKYtgLhoyz5jar7wHS3TJhRzs3WNeTIAjAmLqJ068oGFZ0Jr7maGquTe3w~8LEEIprJ6cyCMc6b1UUJkmwjNq0RLTVbxgFjfi4Z9kyxyJB9IOS1J25OOON4jfwh5JlXS7MVskuONUyHJim1TQ8OwCraKlBsQLPQw__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA)).\n", - "- `num_samples` specifies how many times we want to conduct the sampling.\n", - "- `sigma` is for precision. This is the scaling factor for the sampling distribution.\n", - "\n", - "We show the pseudo code in the following part.\n", - "\n", - "```python\n", - "from torchopt.nn import ZeroOrderGradientModule\n", - "\n", - "# Inherited from the class ZeroOrderGradientModule\n", - "# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling\n", - "class Net(ZeroOrderGradientModule, method='naive', num_samples=100, sigma=0.01):\n", - " def __init__(self, ...):\n", - " ...\n", - "\n", - " def forward(self, batch):\n", - " # Forward process\n", - " ...\n", - " return objective # the returned tensor should be a scalar tensor\n", - "\n", - " def sample(self, sample_shape=torch.Size()):\n", - " # Generate a batch of noise samples\n", - " # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.\n", - " ...\n", - " return noise_batch\n", - "\n", - "# Get model and data\n", - "net = Net(...)\n", - "data = ...\n", - "\n", - "# Forward pass\n", - "loss = Net(data)\n", - "# Backward pass using zero-order differentiation\n", - "grads = torch.autograd.grad(loss, net.parameters())\n", - "```" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "b53524f5", - "metadata": {}, - "source": [ - "Here we reimplement the functional API example above with the OOP API." 
- ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "ecc5730c", - "metadata": {}, - "outputs": [ + "cell_type" : "markdown", + "id" : "8850c832-3b54-4971-8ee0-2cd64b585ea8", + "metadata" : {}, + "source" : ["# TorchOpt for Zero-Order Differentiation"] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "001: tensor(0.0201, grad_fn=)\n", - "002: tensor(0.0181, grad_fn=)\n", - "003: tensor(0.0167, grad_fn=)\n", - "004: tensor(0.0153, grad_fn=)\n", - "005: tensor(0.0142, grad_fn=)\n", - "006: tensor(0.0133, grad_fn=)\n", - "007: tensor(0.0125, grad_fn=)\n", - "008: tensor(0.0119, grad_fn=)\n", - "009: tensor(0.0116, grad_fn=)\n", - "010: tensor(0.0114, grad_fn=)\n", - "011: tensor(0.0112, grad_fn=)\n", - "012: tensor(0.0112, grad_fn=)\n", - "013: tensor(0.0113, grad_fn=)\n", - "014: tensor(0.0116, grad_fn=)\n", - "015: tensor(0.0118, grad_fn=)\n", - "016: tensor(0.0121, grad_fn=)\n", - "017: tensor(0.0123, grad_fn=)\n", - "018: tensor(0.0125, grad_fn=)\n", - "019: tensor(0.0127, grad_fn=)\n", - "020: tensor(0.0127, grad_fn=)\n", - "021: tensor(0.0125, grad_fn=)\n", - "022: tensor(0.0123, grad_fn=)\n", - "023: tensor(0.0120, grad_fn=)\n", - "024: tensor(0.0118, grad_fn=)\n", - "025: tensor(0.0117, grad_fn=)\n" - ] + "cell_type" : "markdown", + "id" : "2b547376", + "metadata" : {}, + "source" : ["[](https://" + "colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/" + "6_Zero_Order_Differentiation.ipynb)"] + }, + { + "cell_type" : "markdown", + "id" : "8d7f9865-dc02-43d4-be90-da1160c4e4dd", + "metadata" : {}, + "source" : [ + "When the inner-loop process is non-differentiable or one wants to eliminate the heavy " + "computation burdens in the previous two modes (brought by Hessian), one can choose ZD. ZD " + "typically gets gradients based on zero-order estimation, such as finite-difference, or " + "Evolutionary Strategy.\n", + "\n", + "TorchOpt offers API for ES-based differentiation. Instead of optimizing the objective $f " + "(\\boldsymbol{\\theta}): \\mathbb{R}^n \\to \\mathbb{R}$, ES optimizes a Gaussion " + "smoothing objective defined as $\\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = " + "\\mathbb{E}_{\\boldsymbol{z} \\sim \\mathcal{N}( 0, {I}_d )} [ f (\\boldsymbol{\\theta} + " + "\\sigma \\, \\boldsymbol{z}) ]$, where $\\sigma$ denotes precision. The gradient of such " + "objective is $\\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} " + "(\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim " + "\\mathcal{N}( 0, {I}_d )} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) " + "\\cdot \\boldsymbol{z} ]$. Refer to [ES-MAML](https://arxiv.org/pdf/1910.01215.pdf) for " + "more details." + ] + }, + { + "cell_type" : "markdown", + "id" : "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", + "metadata" : {}, + "source" : ["In this tutorial, we will introduce how TorchOpt can be used to ES-based " + "differentiation."] + }, + { + "cell_type" : "code", + "execution_count" : 1, + "id" : "8f13ae67-e328-409f-84a8-1fc425c03a66", + "metadata" : {}, + "outputs" : [], + "source" : [ + "import functorch\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torchopt" + ] + }, + { + "cell_type" : "markdown", + "id" : "0cdaac49-4b94-4900-9bb5-a39057ac8b21", + "metadata" : {}, + "source" : [ + "## 1. 
Functional API\n", + "\n", + "The basic functional API is `torchopt.diff.zero_order.zero_order`, which is used as the " + "decorator for the forward process zero-order gradient procedures. Users are required to " + "implement the noise sampling function, which will be used as the input of zero_order " + "decorator. Here we show the specific meaning for each parameter used in the decorator.\n", + "\n", + "- `distribution` for noise sampling distribution. The distribution $\\lambda$ should be " + "spherical symmetric and with a constant variance of $1$ for each element. I.e.:\n", + "\n", + " - Spherical symmetric: $\\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ " + "\\boldsymbol{z} ] = \\boldsymbol{0}$.\n", + " - Constant variance of $1$ for each element: $\\mathbb{E}_{\\boldsymbol{z} \\sim " + "\\lambda} [ {\\lvert z_i \\rvert}^2 ] = 1$.\n", + " - For example, the standard multi-dimensional normal distribution $\\mathcal{N} " + "(\\boldsymbol{0}, \\boldsymbol{1})$.\n", + "\n", + "- `method` for different kind of algorithms, we support `'naive'` " + "([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` " + "([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and " + "`'antithetic'` ([antithetic](https://arxiv.org/abs/1803.07055)).\n", + "\n", + " $$\n", + " \\begin{align*}\n", + " \\text{naive} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} " + "(\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} " + "[ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ] \\\\\n", + " \\text{forward} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} " + "(\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} " + "[ ( f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) - f (\\boldsymbol{\\theta}) ) " + "\\cdot \\boldsymbol{z} ] \\\\\n", + " \\text{antithetic} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} " + "(\\boldsymbol{\\theta}) = \\frac{1}{2 \\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim " + "\\lambda} [ (f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) - f " + "(\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) ) \\cdot \\boldsymbol{z} ]\n", + " \\end{align*}\n", + " $$\n", + "\n", + "- `argnums` specifies which parameter we want to trace the meta-gradient.\n", + "- `num_samples` specifies how many times we want to conduct the sampling.\n", + "- `sigma` is for precision. This is the scaling factor for the sampling distribution.\n", + "\n", + "We show the pseudo code in the following part.\n", + "\n", + "```python\n", + "# Functional API for zero-order differentiation\n", + "# 1. Customize the noise distribution via a distribution class\n", + "class Distribution:\n", + " def sample(self, sample_shape=torch.Size()):\n", + " # Sampling function for noise\n", + " # NOTE: The distribution should be spherical symmetric and with a constant " + "variance of 1.\n", + " ...\n", + " return noise_batch\n", + "\n", + "distribution = Distribution()\n", + "\n", + "# 2. Customize the noise distribution via a sampling function\n", + "def distribution(sample_shape=torch.Size()):\n", + " # Sampling function for noise\n", + " # NOTE: The distribution should be spherical symmetric and with a constant variance " + "of 1.\n", + " ...\n", + " return noise_batch\n", + "\n", + "# 3. 
+ {
+ "cell_type" : "markdown",
+ "id" : "dbef87df-2164-4f1d-8919-37a6fbdc5011",
+ "metadata" : {},
+ "source" : ["Here we use a linear layer as an example; note that this is only a simple "
+ "demonstration that a linear layer can work with ES."]
+ },
+ {
+ "cell_type" : "code",
+ "execution_count" : 2,
+ "id" : "8d623b2f-48ee-4df6-a2ce-cf306b4c9067",
+ "metadata" : {},
+ "outputs" : [ {
+ "name" : "stdout",
+ "output_type" : "stream",
+ "text" : [
+ "001: tensor(0.0265, grad_fn=)\n",
+ "002: tensor(0.0243, grad_fn=)\n",
+ "003: tensor(0.0222, grad_fn=)\n",
+ "004: tensor(0.0202, grad_fn=)\n",
+ "005: tensor(0.0184, grad_fn=)\n",
+ "006: tensor(0.0170, grad_fn=)\n",
+ "007: tensor(0.0157, grad_fn=)\n",
+ "008: tensor(0.0146, grad_fn=)\n",
+ "009: tensor(0.0137, grad_fn=)\n",
+ "010: tensor(0.0130, grad_fn=)\n",
+ "011: tensor(0.0123, grad_fn=)\n",
+ "012: tensor(0.0118, grad_fn=)\n",
+ "013: tensor(0.0114, grad_fn=)\n",
+ "014: tensor(0.0111, grad_fn=)\n",
+ "015: tensor(0.0111, grad_fn=)\n",
+ "016: tensor(0.0111, grad_fn=)\n",
+ "017: tensor(0.0113, grad_fn=)\n",
+ "018: tensor(0.0115, grad_fn=)\n",
+ "019: tensor(0.0118, grad_fn=)\n",
+ "020: tensor(0.0120, grad_fn=)\n",
+ "021: tensor(0.0121, grad_fn=)\n",
+ "022: tensor(0.0121, grad_fn=)\n",
+ "023: tensor(0.0122, grad_fn=)\n",
+ "024: tensor(0.0122, grad_fn=)\n",
+ "025: tensor(0.0122, grad_fn=)\n"
+ ]
+ } ],
+ "source" : [
+ "torch.random.manual_seed(0)\n",
+ "\n",
+ "fmodel, params = functorch.make_functional(nn.Linear(32, 1))\n",
+ "x = torch.randn(64, 32) * 0.1\n",
+ "y = torch.randn(64, 1) * 0.1\n",
+ "distribution = torch.distributions.Normal(loc=0, scale=1)\n",
+ "\n",
+ "\n",
+ "@torchopt.diff.zero_order(\n",
+ "    distribution=distribution, method='forward', argnums=0, num_samples=100, sigma=0.01\n",
+ ")\n",
+ "def forward_process(params, fn, x, y):\n",
+ "    y_pred = fn(params, x)\n",
+ "    loss = F.mse_loss(y_pred, y)\n",
+ "    return loss\n",
+ "\n",
+ "\n",
+ "optimizer = torchopt.adam(lr=0.01)\n",
+ "opt_state = optimizer.init(params)  # init optimizer\n",
+ "\n",
+ "for i in range(25):\n",
+ "    loss = forward_process(params, fmodel, x, y)  # compute loss\n",
+ "\n",
+ "    grads = torch.autograd.grad(loss, params)  # compute gradients\n",
+ "    updates, opt_state = optimizer.update(grads, opt_state)  # get updates\n",
+ "    params = torchopt.apply_updates(params, updates)  # update network parameters\n",
+ "\n",
+ "    print(f'{i + 1:03d}: {loss!r}')"
+ ]
+ },
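Since the regression objective in the cell above is itself differentiable, the zero-order gradients can be sanity-checked against ordinary autograd. A minimal sketch (an editorial aside, not a cell of the notebook; it reuses `fmodel`, `params`, `x`, `y`, and `forward_process` from the cell above, and the cosine-similarity comparison is just one reasonable choice of metric):

```python
# Zero-order gradients through the decorated forward process
loss = forward_process(params, fmodel, x, y)
zo_grads = torch.autograd.grad(loss, params)

# First-order gradients of the same MSE objective via functorch autograd
fo_grads = functorch.grad(lambda p: F.mse_loss(fmodel(p, x), y))(params)

# The two should roughly agree in direction; agreement improves as `num_samples` grows.
for zo, fo in zip(zo_grads, fo_grads):
    cos = F.cosine_similarity(zo.flatten(), fo.flatten(), dim=0)
    print(f'cosine similarity: {cos.item():.3f}')
```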
+ {
+ "attachments" : {},
+ "cell_type" : "markdown",
+ "id" : "db723f6b",
+ "metadata" : {},
+ "source" : [
+ "## 2. OOP API\n",
+ "\n",
+ "The basic OOP API is the class `ZeroOrderGradientModule`. We define the network as an "
+ "`nn.Module` following the classical PyTorch style. Users need to define the forward process "
+ "of the zero-order gradient procedure in `forward()` and a noise sampling function "
+ "`sample()`. Here we show the specific meaning of each parameter used in the class.\n",
+ "\n",
+ "- `method` selects the algorithm; we support `'naive'` "
+ "([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` "
+ "([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and "
+ "`'antithetic'` ([antithetic](https://arxiv.org/abs/1803.07055)).\n",
+ "- `num_samples` specifies how many times we want to conduct the sampling.\n",
+ "- `sigma` is the precision, i.e., the scaling factor for the sampling distribution.\n",
+ "\n",
+ "We show the pseudo code below.\n",
+ "\n",
+ "```python\n",
+ "from torchopt.nn import ZeroOrderGradientModule\n",
+ "\n",
+ "# Inherited from the class ZeroOrderGradientModule\n",
+ "# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling\n",
+ "class Net(ZeroOrderGradientModule, method='naive', num_samples=100, sigma=0.01):\n",
+ "    def __init__(self, ...):\n",
+ "        ...\n",
+ "\n",
+ "    def forward(self, batch):\n",
+ "        # Forward process\n",
+ "        ...\n",
+ "        return objective  # the returned tensor should be a scalar tensor\n",
+ "\n",
+ "    def sample(self, sample_shape=torch.Size()):\n",
+ "        # Generate a batch of noise samples\n",
+ "        # NOTE: The distribution should be spherically symmetric with a constant "
+ "variance of 1.\n",
+ "        ...\n",
+ "        return noise_batch\n",
+ "\n",
+ "# Get model and data\n",
+ "net = Net(...)\n",
+ "data = ...\n",
+ "\n",
+ "# Forward pass\n",
+ "loss = net(data)\n",
+ "# Backward pass using zero-order differentiation\n",
+ "grads = torch.autograd.grad(loss, net.parameters())\n",
+ "```"
+ ]
+ },
+ {
+ "attachments" : {},
+ "cell_type" : "markdown",
+ "id" : "b53524f5",
+ "metadata" : {},
+ "source" : ["Here we reimplement the functional API example above with the OOP API."]
+ },
+ {
+ "cell_type" : "code",
+ "execution_count" : 3,
+ "id" : "ecc5730c",
+ "metadata" : {},
+ "outputs" : [ {
+ "name" : "stdout",
+ "output_type" : "stream",
+ "text" : [
+ "001: tensor(0.0201, grad_fn=)\n",
+ "002: tensor(0.0181, grad_fn=)\n",
+ "003: tensor(0.0167, grad_fn=)\n",
+ "004: tensor(0.0153, grad_fn=)\n",
+ "005: tensor(0.0142, grad_fn=)\n",
+ "006: tensor(0.0133, grad_fn=)\n",
+ "007: tensor(0.0125, grad_fn=)\n",
+ "008: tensor(0.0119, grad_fn=)\n",
+ "009: tensor(0.0116, grad_fn=)\n",
+ "010: tensor(0.0114, grad_fn=)\n",
+ "011: tensor(0.0112, grad_fn=)\n",
+ "012: tensor(0.0112, grad_fn=)\n",
+ "013: tensor(0.0113, grad_fn=)\n",
+ "014: tensor(0.0116, grad_fn=)\n",
+ "015: tensor(0.0118, grad_fn=)\n",
+ "016: tensor(0.0121, grad_fn=)\n",
+ "017: tensor(0.0123, grad_fn=)\n",
+ "018: tensor(0.0125, grad_fn=)\n",
+ "019: tensor(0.0127, grad_fn=)\n",
+ "020: tensor(0.0127, grad_fn=)\n",
+ "021: tensor(0.0125, grad_fn=)\n",
+ "022: tensor(0.0123, grad_fn=)\n",
+ "023: tensor(0.0120, grad_fn=)\n",
+ "024: tensor(0.0118, grad_fn=)\n",
+ "025: 
tensor(0.0117, grad_fn=)\n" + ] + } ], + "source" : [ + "torch.random.manual_seed(0)\n", + "\n", + "\n", + "class Net(torchopt.nn.ZeroOrderGradientModule, method='forward', num_samples=100, " + "sigma=0.01):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1)\n", + " self.distribution = torch.distributions.Normal(loc=0, scale=1)\n", + "\n", + " def forward(self, x, y):\n", + " y_pred = self.fc(x)\n", + " loss = F.mse_loss(y_pred, y)\n", + " return loss\n", + "\n", + " def sample(self, sample_shape=torch.Size()):\n", + " return self.distribution.sample(sample_shape)\n", + "\n", + "\n", + "x = torch.randn(64, 32) * 0.1\n", + "y = torch.randn(64, 1) * 0.1\n", + "net = Net(dim=32)\n", + "\n", + "\n", + "optimizer = torchopt.Adam(net.parameters(), lr=0.01)\n", + "\n", + "for i in range(25):\n", + " loss = net(x, y) # compute loss\n", + "\n", + " optimizer.zero_grad()\n", + " loss.backward() # backward pass\n", + " optimizer.step() # update network parameters\n", + "\n", + " print(f'{i + 1:03d}: {loss!r}')" + ] + } + ], + "metadata" : { + "kernelspec" : + {"display_name" : "Python 3.9.15 ('torchopt')", "language" : "python", "name" : "python3"}, + "language_info" : { + "codemirror_mode" : {"name" : "ipython", "version" : 3}, + "file_extension" : ".py", + "mimetype" : "text/x-python", + "name" : "python", + "nbconvert_exporter" : "python", + "pygments_lexer" : "ipython3", + "version" : "3.9.15" + }, + "vscode" : { + "interpreter" : + {"hash" : "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da"} } - ], - "source": [ - "torch.random.manual_seed(0)\n", - "\n", - "\n", - "class Net(torchopt.nn.ZeroOrderGradientModule, method='forward', num_samples=100, sigma=0.01):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, 1)\n", - " self.distribution = torch.distributions.Normal(loc=0, scale=1)\n", - "\n", - " def forward(self, x, y):\n", - " y_pred = self.fc(x)\n", - " loss = F.mse_loss(y_pred, y)\n", - " return loss\n", - "\n", - " def sample(self, sample_shape=torch.Size()):\n", - " return self.distribution.sample(sample_shape)\n", - "\n", - "\n", - "x = torch.randn(64, 32) * 0.1\n", - "y = torch.randn(64, 1) * 0.1\n", - "net = Net(dim=32)\n", - "\n", - "\n", - "optimizer = torchopt.Adam(net.parameters(), lr=0.01)\n", - "\n", - "for i in range(25):\n", - " loss = net(x, y) # compute loss\n", - "\n", - " optimizer.zero_grad()\n", - " loss.backward() # backward pass\n", - " optimizer.step() # update network parameters\n", - "\n", - " print(f'{i + 1:03d}: {loss!r}')" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.9.15 ('torchopt')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" }, - "vscode": { - "interpreter": { - "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "nbformat" : 4, + "nbformat_minor" : 5 }