
Commit

changed default behavior for jax hypernets + functorch vmap example
shyamsn97 committed May 31, 2022
1 parent 95a59c8 commit 61e26e5
Showing 4 changed files with 79 additions and 23 deletions.
80 changes: 68 additions & 12 deletions README.md
@@ -4,7 +4,7 @@

**Note: This library is experimental and currently under development - the flax implementations in particular are far from perfect and can be improved. If you have any suggestions on how to improve this library, please open a GitHub issue or feel free to reach out directly!**

`hyper-nn` gives users the ability to create easily customizable [Hypernetworks](https://arxiv.org/abs/1609.09106) for almost any generic `torch.nn.Module` from [Pytorch](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) and `flax.linen.Module` from [Flax](https://flax.readthedocs.io/en/latest/flax.linen.html). Our Hypernetwork objects are also `torch.nn.Modules` and `flax.linen.Modules`, allowing for easy integration with existing systems. For Pytorch, we make use of the amazing library [`functorch`](https://github.com/pytorch/functorch).

<p align="center">Generating Policy Weights for Lunar Lander</p>

@@ -68,16 +68,16 @@ The main classes to use are `TorchHyperNetwork` and `JaxHyperNetwork` and those
```python
import torch.nn as nn

# static hypernetwork
from hypernn.torch.hypernet import TorchHyperNetwork

# any module
target_network = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 32)
)

EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32
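
# --- hedged sketch, not part of the original README: the rest of this example is collapsed in the diff ---
# assuming TorchHyperNetwork.from_target takes the same arguments as the Jax example below
hypernetwork = TorchHyperNetwork.from_target(
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS
)

# use it like any other nn.Module (assumes `import torch` alongside torch.nn above)
inp = torch.zeros((1, 32))
output = hypernetwork(inp=[inp])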

@@ -107,6 +107,9 @@ import flax.linen as nn
import jax.numpy as jnp
from jax import random

# static hypernetwork
from hypernn.jax.hypernet import JaxHyperNetwork

# any module
target_network = nn.Sequential(
    [
@@ -116,9 +119,6 @@
    ]
)

EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

@@ -130,9 +130,9 @@ hypernetwork = JaxHyperNetwork.from_target(
)

# now we can use the hypernetwork like any other nn.Module
inp = jnp.zeros((1, 32))
key = random.PRNGKey(0)
hypernetwork_params = hypernetwork.init(key, inp=[inp]) # flax needs to initialize hypernetwork parameters first

# by default we only output what we'd expect from the target network
output = hypernetwork.apply(hypernetwork_params, inp=[inp])
@@ -143,10 +143,55 @@ output, generated_params, aux_output = hypernetwork.apply(hypernetwork_params, i
# generate params separately
generated_params, aux_output = hypernetwork.apply(hypernetwork_params, inp=[inp], method=hypernetwork.generate_params)

output = hypernetwork.apply(hypernetwork_params, inp=[inp], generated_params=generated_params)
```
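
Since `has_aux` now defaults to `False`, `apply` returns only the target network's output unless asked otherwise. A minimal sketch of requesting the auxiliary outputs explicitly, assuming `apply` forwards `has_aux` to `__call__` as in the signature change shown in `hypernn/jax/hypernet.py` below:

```python
# assumes `hypernetwork_params` and `inp` from the example above
output, generated_params, aux_output = hypernetwork.apply(
    hypernetwork_params, inp=[inp], has_aux=True
)
```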
---

## Advanced: Using vmap for batching operations
This is useful for dynamic hypernetworks, which generate different parameters depending on their inputs.

### Pytorch
```python
import torch
import torch.nn as nn
from functorch import vmap

# dynamic hypernetwork
from hypernn.torch.dynamic_hypernet import TorchDynamicHyperNetwork

# any module
target_network = nn.Sequential(
    nn.Linear(8, 256),
    nn.ReLU(),
    nn.Linear(256, 32)
)

EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

# conditioned on input to generate param vector
hypernetwork = TorchDynamicHyperNetwork.from_target(
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS,
    input_dim = 8
)

# batch of 10 inputs
inp = torch.randn((10, 1, 8))

# use with a for loop
outputs = []
for i in range(10):
    outputs.append(hypernetwork(inp=[inp[i]]))
outputs = torch.stack(outputs)
assert outputs.size() == (10, 1, 32)

# using vmap
outputs = vmap(hypernetwork)([inp])
assert outputs.size() == (10, 1, 32)
```
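
The Flax side can be batched the same way with `jax.vmap`. A minimal sketch, assuming the dynamic Flax class is `JaxDynamicHyperNetwork` in `hypernn.jax.dynamic_hypernet`, that its `from_target` mirrors the Torch signature above, and that `apply` returns only the target output by default:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax import random

# assumed class / module path, mirroring TorchDynamicHyperNetwork above
from hypernn.jax.dynamic_hypernet import JaxDynamicHyperNetwork

target_network = nn.Sequential([nn.Dense(256), nn.relu, nn.Dense(32)])

# conditioned on input to generate a param vector (assumed signature)
hypernetwork = JaxDynamicHyperNetwork.from_target(
    target_network = target_network,
    embedding_dim = 4,
    num_embeddings = 32,
    input_dim = 8
)

# batch of 10 inputs
inp = jnp.zeros((10, 1, 8))
key = random.PRNGKey(0)
hypernetwork_params = hypernetwork.init(key, inp=[inp[0]])

# vmap over the leading batch axis, sharing the same hypernetwork parameters
outputs = jax.vmap(lambda x: hypernetwork.apply(hypernetwork_params, inp=[x]))(inp)
assert outputs.shape == (10, 1, 32)
```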


## Detailed Explanation

### EmbeddingModule
@@ -253,11 +298,11 @@ class HyperNetwork(metaclass=abc.ABCMeta):
```

---
## Citing hyper-nn

If you use this software in your academic work, please cite:

```bibtex
@misc{sudhakaran2022,
author = {Sudhakaran, Shyam},
title = {hyper-nn},
@@ -266,4 +311,15 @@
journal = {GitHub repository},
howpublished = {\url{https://github.com/shyamsn97/hyper-nn}}
}
```
---

### Projects used in hyper-nn
```bibtex
@Misc{functorch2021,
author = {Horace He, Richard Zou},
title = {functorch: JAX-like composable function transforms for PyTorch},
howpublished = {\url{https://github.com/pytorch/functorch}},
year = {2021}
}
```
4 changes: 2 additions & 2 deletions hypernn/jax/hypernet.py
@@ -77,7 +77,7 @@ def forward(
        self,
        inp: Iterable[Any] = [],
        generated_params: Optional[jnp.array] = None,
-       has_aux: bool = True,
+       has_aux: bool = False,
        *args,
        **kwargs,
    ) -> Tuple[jnp.array, List[jnp.array]]:
@@ -103,7 +103,7 @@ def __call__(
        self,
        inp: Iterable[Any] = [],
        generated_params: Optional[jnp.array] = None,
-       has_aux: bool = True,
+       has_aux: bool = False,
        *args,
        **kwargs,
    ) -> Tuple[jnp.array, List[jnp.array]]:
16 changes: 8 additions & 8 deletions notebooks/dynamic_hypernetworks/JaxDynamicHyperRNN.ipynb
@@ -98,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
@@ -159,7 +159,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -198,7 +198,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -214,7 +214,7 @@
"193211"
]
},
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -233,7 +233,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -257,7 +257,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -266,7 +266,7 @@
"24439"
]
},
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -578,7 +578,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
"version": "3.8.0"
}
},
"nbformat": 4,
2 changes: 1 addition & 1 deletion setup.py
@@ -15,7 +15,7 @@
setup(
    name="hyper-nn",
    packages=find_packages(exclude=('tests',)),
-   version="0.1.1",
+   version="0.1.2",
    url="https://github.com/shyamsn97/hyper-nn",
    license='MIT',

