diff --git a/README.md b/README.md
index 6fa1da4..3034207 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
 **Note: This library is experimental and currently under development - the flax implementations in particular are far from perfect and can be improved. If you have any suggestions on how to improve this library, please open a github issue or feel free to reach out directly!**
 
-`hyper-nn` gives users with the ability to create easily customizable [Hypernetworks](https://arxiv.org/abs/1609.09106) for almost any generic `torch.nn.Module` from [Pytorch](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) and `flax.linen.Module` from [Flax](https://flax.readthedocs.io/en/latest/flax.linen.html). Our Hypernetwork objects are also `torch.nn.Modules` and `flax.linen.Modules`, allowing for easy integration with existing systems
+`hyper-nn` gives users the ability to create easily customizable [Hypernetworks](https://arxiv.org/abs/1609.09106) for almost any generic `torch.nn.Module` from [Pytorch](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) and `flax.linen.Module` from [Flax](https://flax.readthedocs.io/en/latest/flax.linen.html). Our Hypernetwork objects are also `torch.nn.Modules` and `flax.linen.Modules`, allowing for easy integration with existing systems. For Pytorch, we make use of the amazing library [`functorch`](https://github.com/pytorch/functorch).

Generating Policy Weights for Lunar Lander

@@ -68,6 +68,9 @@ The main classes to use are `TorchHyperNetwork` and `JaxHyperNetwork` and those
 ```python
 import torch.nn as nn
 
+# static hypernetwork
+from hypernn.torch.hypernet import TorchHyperNetwork
+
 # any module
 target_network = nn.Sequential(
     nn.Linear(32, 64),
@@ -75,9 +78,6 @@ target_network = nn.Sequential(
     nn.Linear(64, 32)
 )
 
-# static hypernetwork
-from hypernn.torch.hypernet import TorchHyperNetwork
-
 EMBEDDING_DIM = 4
 NUM_EMBEDDINGS = 32
 
@@ -107,6 +107,9 @@ import flax.linen as nn
 import jax.numpy as jnp
 from jax import random
 
+# static hypernetwork
+from hypernn.jax.hypernet import JaxHyperNetwork
+
 # any module
 target_network = nn.Sequential(
     [
@@ -116,9 +119,6 @@ target_network = nn.Sequential(
     ]
 )
 
-# static hypernetwork
-from hypernn.jax.hypernet import JaxHyperNetwork
-
 EMBEDDING_DIM = 4
 NUM_EMBEDDINGS = 32
 
@@ -130,9 +130,9 @@ hypernetwork = JaxHyperNetwork.from_target(
 )
 
 # now we can use the hypernetwork like any other nn.Module
-inp = jnp.zeros((1, 32)
+inp = jnp.zeros((1, 32))
 key = random.PRNGKey(0)
-hypernetwork_params = hypernetwork.init(key, inp=[inp)]) # flax needs to initialize hypernetwork parameters first
+hypernetwork_params = hypernetwork.init(key, inp=[inp]) # flax needs to initialize hypernetwork parameters first
 
 # by default we only output what we'd expect from the target network
 output = hypernetwork.apply(hypernetwork_params, inp=[inp])
@@ -143,10 +143,56 @@ output, generated_params, aux_output = hypernetwork.apply(hypernetwork_params, i
 
 # generate params separately
 generated_params, aux_output = hypernetwork.apply(hypernetwork_params, inp=[inp], method=hypernetwork.generate_params)
-output = hypernetwork.apply(inp=[inp], generated_params=generated_params)
+output = hypernetwork.apply(hypernetwork_params, inp=[inp], generated_params=generated_params)
 ```
 
 ---
+## Advanced: Using vmap for batching operations
+This is useful when dealing with dynamic hypernetworks, which generate different params depending on their inputs.
+
+### Pytorch
+```python
+import torch
+import torch.nn as nn
+from functorch import vmap
+
+# dynamic hypernetwork
+from hypernn.torch.dynamic_hypernet import TorchDynamicHyperNetwork
+
+# any module
+target_network = nn.Sequential(
+    nn.Linear(8, 256),
+    nn.ReLU(),
+    nn.Linear(256, 32)
+)
+
+EMBEDDING_DIM = 4
+NUM_EMBEDDINGS = 32
+
+# conditioned on input to generate param vector
+hypernetwork = TorchDynamicHyperNetwork.from_target(
+    target_network = target_network,
+    embedding_dim = EMBEDDING_DIM,
+    num_embeddings = NUM_EMBEDDINGS,
+    input_dim = 8
+)
+
+# batch of 10 inputs
+inp = torch.randn((10, 1, 8))
+
+# use with a for loop
+outputs = []
+for i in range(10):
+    outputs.append(hypernetwork(inp=[inp[i]]))
+outputs = torch.stack(outputs)
+assert outputs.size() == (10, 1, 32)
+
+# using vmap
+outputs = vmap(hypernetwork)([inp])
+assert outputs.size() == (10, 1, 32)
+```
+
+
 ## Detailed Explanation
 
 ### EmbeddingModule
@@ -253,11 +298,11 @@ class HyperNetwork(metaclass=abc.ABCMeta):
 ```
 
 ---
-### Citation
+## Citing hyper-nn
 
 If you use this software in your academic work please cite
 
-```
+```bibtex
 @misc{sudhakaran2022,
   author = {Sudhakaran, Shyam Sudhakaran},
   title = {hyper-nn},
@@ -266,4 +311,15 @@ If you use this software in your academic work please cite
   journal = {GitHub repository},
   howpublished = {\url{https://github.com/shyamsn97/hyper-nn}}
 }
+```
+---
+
+### Projects used in hyper-nn
+```bibtex
+@Misc{functorch2021,
+  author = {Horace He, Richard Zou},
+  title = {functorch: JAX-like composable function transforms for PyTorch},
+  howpublished = {\url{https://github.com/pytorch/functorch}},
+  year = {2021}
+}
 ```
\ No newline at end of file
diff --git a/hypernn/jax/hypernet.py b/hypernn/jax/hypernet.py
index aa56124..4e4d7d0 100644
--- a/hypernn/jax/hypernet.py
+++ b/hypernn/jax/hypernet.py
@@ -77,7 +77,7 @@ def forward(
         self,
         inp: Iterable[Any] = [],
         generated_params: Optional[jnp.array] = None,
-        has_aux: bool = True,
+        has_aux: bool = False,
         *args,
         **kwargs,
     ) -> Tuple[jnp.array, List[jnp.array]]:
@@ -103,7 +103,7 @@ def __call__(
         self,
         inp: Iterable[Any] = [],
         generated_params: Optional[jnp.array] = None,
-        has_aux: bool = True,
+        has_aux: bool = False,
         *args,
         **kwargs,
     ) -> Tuple[jnp.array, List[jnp.array]]:
diff --git a/notebooks/dynamic_hypernetworks/JaxDynamicHyperRNN.ipynb b/notebooks/dynamic_hypernetworks/JaxDynamicHyperRNN.ipynb
index ea5333e..ef5a2e7 100644
--- a/notebooks/dynamic_hypernetworks/JaxDynamicHyperRNN.ipynb
+++ b/notebooks/dynamic_hypernetworks/JaxDynamicHyperRNN.ipynb
@@ -98,7 +98,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [
     {
@@ -159,7 +159,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -198,7 +198,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [
     {
@@ -214,7 +214,7 @@
       "193211"
      ]
     },
-    "execution_count": 4,
+    "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -233,7 +233,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -257,7 +257,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
     {
@@ -266,7 +266,7 @@
       "24439"
      ]
     },
-    "execution_count": 6,
+    "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -578,7 +578,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
"version": "3.9.7" + "version": "3.8.0" } }, "nbformat": 4, diff --git a/setup.py b/setup.py index 76460ee..fc60c01 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ setup( name="hyper-nn", packages=find_packages(exclude=('tests',)), - version="0.1.1", + version="0.1.2", url="https://github.com/shyamsn97/hyper-nn", license='MIT',