Skip to content

Commit

Permalink
Merge branch 'release/v0.4.0'
Browse files Browse the repository at this point in the history
  • Loading branch information
speedcell4 committed Jul 30, 2021
2 parents b011197 + 89d40d3 commit 69f458e
Show file tree
Hide file tree
Showing 23 changed files with 558 additions and 1,703 deletions.
91 changes: 64 additions & 27 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,92 @@

![Unit Tests](https://github.com/speedcell4/torchlatent/workflows/Unit%20Tests/badge.svg)
![Upload Python Package](https://github.com/speedcell4/torchlatent/workflows/Upload%20Python%20Package/badge.svg)
[![Downloads](https://pepy.tech/badge/torchlatent)](https://pepy.tech/project/torchlatent)

## Requirements

- Python 3.7
- PyTorch 1.6.0
- PyTorch 1.6.0

## Installation

`python3 -m pip install torchlatent`

## Quickstart
## Usage

```python
import torch
from torch.nn.utils.rnn import pack_sequence
from torchrua import pack_sequence

from torchlatent.crf import CrfDecoder

num_tags = 7
num_tags = 3
num_conjugates = 1

decoder = CrfDecoder(num_tags=num_tags)
decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates)

emissions = pack_sequence([
torch.randn((5, num_tags)),
torch.randn((2, num_tags)),
torch.randn((3, num_tags)),
], enforce_sorted=False)
emissions.data.requires_grad_(True)
torch.randn((5, num_conjugates, num_tags), requires_grad=True),
torch.randn((2, num_conjugates, num_tags), requires_grad=True),
torch.randn((3, num_conjugates, num_tags), requires_grad=True),
])

tags = pack_sequence([
torch.randint(0, num_tags, (5,)),
torch.randint(0, num_tags, (2,)),
torch.randint(0, num_tags, (3,)),
], enforce_sorted=False)

print(decoder.fit(emissions, tags, reduction='sum'))
print(decoder.decode(emissions))

# tensor(-24.1321, grad_fn=<SumBackward0>)
# PackedSequence(data=tensor([1, 3, 5, 6, 0, 2, 5, 2, 1, 1]), batch_sizes=tensor([3, 3, 2, 1, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1]))
torch.randint(0, num_tags, (5, num_conjugates)),
torch.randint(0, num_tags, (2, num_conjugates)),
torch.randint(0, num_tags, (3, num_conjugates)),
])

print(decoder.fit(emissions=emissions, tags=tags))
# tensor([[-6.7424],
# [-5.1288],
# [-2.7283]], grad_fn=<SubBackward0>)

print(decoder.decode(emissions=emissions))
# PackedSequence(data=tensor([[2],
# [0],
# [1],
# [0],
# [2],
# [0],
# [2],
# [0],
# [1],
# [2]]),
# batch_sizes=tensor([3, 3, 2, 1, 1]),
# sorted_indices=tensor([0, 2, 1]),
# unsorted_indices=tensor([0, 2, 1]))

print(decoder.marginals(emissions=emissions))
# tensor([[[0.1040, 0.1001, 0.7958]],
#
# [[0.5736, 0.0784, 0.3479]],
#
# [[0.0932, 0.8797, 0.0271]],
#
# [[0.6558, 0.0472, 0.2971]],
#
# [[0.2740, 0.1109, 0.6152]],
#
# [[0.4811, 0.2163, 0.3026]],
#
# [[0.2321, 0.3478, 0.4201]],
#
# [[0.4987, 0.1986, 0.3027]],
#
# [[0.2029, 0.5888, 0.2083]],
#
# [[0.2802, 0.2358, 0.4840]]], grad_fn=<AddBackward0>)
```

## Latent Structures and Utilities
## Latent Structures

- [x] Conditional Random Fields (CRF)
- [ ] Conditional Random Fields (CRF)
- [x] Conjugated
- [ ] Dynamic Transition Matrix
- [ ] Second-order
- [ ] Variant-order
- [ ] Tree CRF
- [ ] Non-Projective Dependency Tree (Matrix-tree Theorem)
- [ ] Probabilistic Context-free Grammars (PCFG)
- [ ] Dependency Model with Valence (DMV)

## Thanks

This library is greatly inspired by [torch-struct](https://github.com/harvardnlp/pytorch-struct).
- [ ] Dependency Model with Valence (DMV)
128 changes: 0 additions & 128 deletions benchmark.py

This file was deleted.

9 changes: 2 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name=name,
version='0.3.1',
version='0.4.0',
packages=[package for package in find_packages() if package.startswith(name)],
url='https://github.com/speedcell4/torchlatent',
license='MIT',
Expand All @@ -14,7 +14,7 @@
python_requires='>=3.7',
install_requires=[
'numpy',
'torchrua>=0.2.0',
'torchrua>=0.3.0',
],
extras_require={
'dev': [
Expand All @@ -23,10 +23,5 @@
'hypothesis',
'pytorch-crf',
],
'benchmark': [
'aku',
'tqdm',
'pytorch-crf',
]
}
)
71 changes: 42 additions & 29 deletions tests/strategies.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,69 @@
from typing import List

import torch

from hypothesis import strategies as st
from torch.nn.utils.rnn import pack_sequence

MAX_BATCH_SIZE = 7
TOTAL_LENGTH = 11
MAX_NUM_TAGS = 13
if torch.cuda.is_available():
MAX_BATCH_SIZE = 120
TINY_BATCH_SIZE = 24

MAX_TOKEN_SIZE = 512
TINY_TOKEN_SIZE = 12

MAX_NUM_TAGS = 100
MAX_NUM_CONJUGATES = 16

else:
MAX_BATCH_SIZE = 12
TINY_BATCH_SIZE = 6

MAX_TOKEN_SIZE = 24
TINY_TOKEN_SIZE = 12

MAX_NUM_TAGS = 12
MAX_NUM_CONJUGATES = 6


@st.composite
def devices(draw):
if not torch.cuda.is_available():
return torch.device('cpu')
device = torch.device('cpu')
else:
return torch.device('cuda')
device = torch.device('cuda:0')
_ = torch.empty((1,), device=device)
return device


@st.composite
def batch_size_integers(draw, max_batch_size: int = MAX_BATCH_SIZE):
return draw(st.integers(min_value=1, max_value=max_batch_size))
def batch_sizes(draw, max_value: int = MAX_BATCH_SIZE):
return draw(st.integers(min_value=1, max_value=max_value))


@st.composite
def length_integers(draw, total_length: int = TOTAL_LENGTH):
return draw(st.integers(min_value=1, max_value=total_length))
def batch_size_lists(draw, max_batch_size: int = MAX_BATCH_SIZE):
return [
draw(batch_sizes(max_value=max_batch_size))
for _ in range(draw(batch_sizes(max_value=max_batch_size)))
]


@st.composite
def length_lists(draw, total_length: int = TOTAL_LENGTH, batch_sizes: int = MAX_BATCH_SIZE):
return draw(st.lists(length_integers(total_length=total_length), min_size=1, max_size=batch_sizes))
def token_sizes(draw, max_value: int = MAX_TOKEN_SIZE):
return draw(st.integers(min_value=1, max_value=max_value))


@st.composite
def num_tags_integers(draw, max_num_tags: int = MAX_NUM_TAGS):
return draw(st.integers(min_value=1, max_value=max_num_tags))



def token_size_lists(draw, max_token_size: int = MAX_TOKEN_SIZE, max_batch_size: int = MAX_BATCH_SIZE):
return [
draw(token_sizes(max_value=max_token_size))
for _ in range(draw(batch_sizes(max_value=max_batch_size)))
]


@st.composite
def tags_packs(draw, lengths: List[int], num_tags: int):
return pack_sequence([
torch.randint(0, num_tags, (length,), device=draw(devices()))
for length in lengths
], enforce_sorted=False)
def tag_sizes(draw, max_value: int = MAX_NUM_TAGS):
return draw(st.integers(min_value=1, max_value=max_value))


@st.composite
def conjugated_tags_packs(draw, lengths: List[int], num_tags: int, num_conjugates: int):
return pack_sequence([
torch.randint(0, num_tags, (length, num_conjugates), device=draw(devices()))
for length in lengths
], enforce_sorted=False)
def conjugate_sizes(draw, max_value: int = MAX_NUM_CONJUGATES):
return draw(st.integers(min_value=1, max_value=max_value))
Loading

0 comments on commit 69f458e

Please sign in to comment.