Skip to content

Commit

Permalink
Merge pull request #20 from TomMonks/dev
Browse files Browse the repository at this point in the history
PATCH: 0.3.3
  • Loading branch information
TomMonks authored Feb 7, 2024
2 parents 6f15f67 + ae1e787 commit f9357d3
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 7 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import setuptools

# patch ed - build error due to imporing numpy into __init__
# patched - build error due to imporing numpy into __init__
# from sim_tools import __version__
VERSION = "0.3.2"
VERSION = "0.3.3"


# Read in the requirements.txt file
Expand Down
2 changes: 1 addition & 1 deletion sim_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.3.2'
__version__ = '0.3.3'
__author__ = 'Thomas Monks'

from . import datasets, distributions, time_dependent, ovs
8 changes: 7 additions & 1 deletion sim_tools/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,13 @@ def sample(self, size: Optional[int] = None) -> float | np.ndarray:
Number of samples to return. If integer then
numpy array returned.
"""
return self.rng.choice(self.values, p=self.probabilities, size=size).item()
sample = self.rng.choice(self.values, p=self.probabilities, size=size)

if size is None:
return sample.item()
else:
return sample



class TruncatedDistribution(Distribution):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_dists.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ def test_discrete():
assert type(d.sample()) == int


def test_discrete():
d = dists.Discrete(values=[1, 2, 3], freq=[95, 3, 2], random_seed=SEED_1)
assert type(d.sample()) == int

def test_discrete_multiple():
d = dists.Discrete(values=[1, 2, 3], freq=[95, 3, 2], random_seed=SEED_1)
assert len(d.sample(size=100)) == 100

def test_truncated_type():
d1 = dists.Normal(10, 1, random_seed=SEED_1)
Expand Down

0 comments on commit f9357d3

Please sign in to comment.