Releases: shyamsn97/hyper-nn
Refactored Models
Made it easier to create customized hypernetworks by removing the embedding module and weight generator concepts from the base hypernetwork.
Before:

```python
class TorchHyperNetwork(nn.Module, HyperNetwork):
    def __init__(
        self,
        target_network: nn.Module,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 100,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
        custom_embedding_module: Optional[nn.Module] = None,
        custom_weight_generator: Optional[nn.Module] = None,
    ):
        ...
```
After:

```python
class TorchHyperNetwork(nn.Module, HyperNetwork):
    def __init__(
        self,
        target_network: nn.Module,
        num_target_parameters: Optional[int] = None,
    ):
        ...
```
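For illustration only, here is a minimal sketch of the idea the base hypernetwork implements: a small generator emits a flat parameter vector that is reshaped into the target network's weights. All names below are hypothetical and NumPy stands in for torch/flax; this is not the library's API.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical target: a single linear layer y = x @ W + b
target_shapes = {"W": (4, 3), "b": (3,)}
num_target_params = sum(int(np.prod(s)) for s in target_shapes.values())  # 15

# Hypothetical hypernetwork: learned embedding -> linear generator -> flat params
embedding_dim = 8
embedding = rng.normal(size=(embedding_dim,))
generator = rng.normal(size=(embedding_dim, num_target_params)) * 0.1

def generate_params():
    """Produce the target network's parameters from the embedding."""
    flat = embedding @ generator
    params, i = {}, 0
    for name, shape in target_shapes.items():
        n = int(np.prod(shape))
        params[name] = flat[i:i + n].reshape(shape)
        i += n
    return params

def target_forward(params, x):
    # The target consumes generated parameters instead of owning its own.
    return x @ params["W"] + params["b"]

x = rng.normal(size=(2, 4))
out = target_forward(generate_params(), x)
print(out.shape)  # (2, 3)
```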
Updating Functional Wrapper
- fixes some functional wrapper bugs
Refactoring forward APIs
- Refactored forward APIs so they require minimal modifications to existing pipelines. This makes it easier to replace any `nn.Module` with a hypernetwork and use it almost exactly how the target is originally used.
- Specifically, removed the `inp` keyword; forward now just takes in `*args, **kwargs`. In addition, to allow for specific `generate_params` keywords, an optional dict of arguments can be provided through `generate_params_kwargs`.
For standard usage:

```python
output = hypernetwork(inp=[inp])  # old
output = hypernetwork(inp)        # new
```

For dynamic hypernetworks:

```python
output = dynamic_hypernetwork(inp, hidden_state=torch.zeros((1, 32)))                                # old
output = dynamic_hypernetwork(inp, generate_params_kwargs=dict(hidden_state=torch.zeros((1, 32))))   # new
```
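The call-pattern change above can be sketched with a plain-Python stand-in (hypothetical names, no torch dependency): the wrapper forwards `*args, **kwargs` straight through to the target, and routes parameter-generation arguments through a single `generate_params_kwargs` dict.

```python
class FakeHyperNetwork:
    """Stand-in illustrating the forward API shape; not the real class."""

    def generate_params(self, scale=1.0):
        # Hypothetical parameter generation controlled by keyword arguments.
        return {"w": 2.0 * scale}

    def __call__(self, *args, generate_params_kwargs=None, **kwargs):
        params = self.generate_params(**(generate_params_kwargs or {}))
        # Target usage is unchanged: args/kwargs pass straight through.
        return self.target_forward(params, *args, **kwargs)

    def target_forward(self, params, x, bias=0.0):
        return params["w"] * x + bias

hn = FakeHyperNetwork()
print(hn(3.0))                                          # 6.0
print(hn(3.0, bias=1.0))                                # 7.0
print(hn(3.0, generate_params_kwargs=dict(scale=0.5)))  # 3.0
```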
Added additional documentation
- Added detailed documentation
- Rearranged functions for better readability
Changed default behavior of JAX hypernets
- JAX hypernets had `has_aux=True` by default, when it should be `False`
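For context: in JAX transforms such as `jax.grad`, `has_aux=True` means the wrapped function returns a `(output, aux)` pair rather than a bare output, so defaulting it to `True` breaks functions that return only an output. A plain-Python sketch of that contract (hypothetical wrapper, not the library's code):

```python
def apply(fn, x, has_aux=False):
    """Mimic the has_aux contract used by JAX transforms."""
    result = fn(x)
    if has_aux:
        out, aux = result  # expects a (output, aux) pair
        return out, aux
    return result          # expects a bare output

square = lambda x: x * x
print(apply(square, 3.0))  # 9.0

# With has_aux=True the same function fails to unpack a bare float:
# apply(square, 3.0, has_aux=True)  -> TypeError

with_aux = lambda x: (x * x, {"input": x})
print(apply(with_aux, 3.0, has_aux=True))  # (9.0, {'input': 3.0})
```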
Adding sources to PyPI
- 0.1.1: bump version with new sources
First Release
- Basic Hypernetworks that work for generic PyTorch & Flax modules
- Dynamic Hypernetworks that adapt their weights based on the input
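A dynamic hypernetwork, in contrast to the static case, conditions the generated weights on the current input (and optionally a recurrent hidden state). A minimal NumPy sketch of that idea, with all names hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

in_dim, out_dim = 4, 3
# Hypothetical generator: maps the input itself to a flat weight vector.
generator = rng.normal(size=(in_dim, in_dim * out_dim)) * 0.1

def dynamic_forward(x):
    """Weights are a function of the input, so they change every call."""
    w_flat = x @ generator               # shape (in_dim * out_dim,)
    W = w_flat.reshape(in_dim, out_dim)  # input-conditioned weights
    return x @ W

a = dynamic_forward(rng.normal(size=(in_dim,)))
b = dynamic_forward(rng.normal(size=(in_dim,)))
print(a.shape, b.shape)  # (3,) (3,)
```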