feat: get a float encoding from the full value
jannisborn committed Nov 16, 2023
1 parent 7596e9b commit 0ed6b80
Showing 1 changed file with 28 additions and 0 deletions.
terminator/numerical_encodings.py
@@ -49,6 +49,34 @@ def get_float_encoding(
    return vals / (vmax / 10)


def get_full_float_encoding(
    value: float, embedding_size: int, vmax: float = 1.0
) -> Tensor:
    """
    Convert a float value into a _fixed_ embedding vector.

    Args:
        value: The float value to be encoded.
        embedding_size: The size of the embedding.
        vmax: Maximal value the `value` variable can take. This normalizes values
            to be in the range ~ [-10, 10]. NOTE: If the remaining nn.embeddings in
            the model use `max_norm`, this might result in large range discrepancies.

    Returns:
        torch.Tensor of shape (embedding_size,) containing the embedding.
    """
    if embedding_size % 2 != 0:
        raise ValueError(f"Embedding size {embedding_size} can't be odd.")
    integer = int(value)
    decimal = value - integer
    scalar = integer * 10**decimal
    embedding = torch.zeros((embedding_size,))
    for i in range(0, embedding_size, 2):
        embedding[i] = scalar / (i + 1)
        embedding[i + 1] = -scalar / (i + 1)
    return embedding


def get_int_encoding(token: str, embedding_size: int) -> torch.Tensor:
    """Convert a token representing an integer into a _fixed_ embedding vector.
    NOTE: This can be used only for positive integers - the generation of the
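For context, a minimal usage sketch of the new function (hypothetical input values; assumes `torch` is installed and the module is importable). Note that `vmax` is documented but not referenced in the body shown above, so it does not affect the returned tensor:

import torch

from terminator.numerical_encodings import get_full_float_encoding

# value = 3.5 -> integer = 3, decimal = 0.5, scalar = 3 * 10**0.5 ≈ 9.4868
emb = get_full_float_encoding(value=3.5, embedding_size=4)

# Even indices hold scalar / (i + 1), odd indices the negation:
# [scalar / 1, -scalar / 1, scalar / 3, -scalar / 3]
print(emb)  # tensor([ 9.4868, -9.4868,  3.1623, -3.1623])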
