Skip to content

Commit

Permalink
Fix generator bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti committed Sep 8, 2024
1 parent 0e0c1c8 commit 9458e42
Showing 1 changed file with 28 additions and 26 deletions.
54 changes: 28 additions & 26 deletions stingray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,18 +1436,20 @@ def stream_from_gti_lists(
fmt : str
The format of the output files. Default is 'hdf5'.
Returns
-------
Yields
------
output_files : list of str
A list of the output file names.
"""

if only_attrs is not None and root_file_name is not None:
raise ValueError("You can only use only_attrs with a generator.")
new_gti_lists = np.asanyarray(new_gti_lists)
if len(new_gti_lists[0]) == len(self.gti) and np.all(
np.abs(np.asanyarray(new_gti_lists[0]).flatten() - self.gti.flatten()) < 1e-3
):
logger.info("No change of GTI")
if only_attrs is not None:
yield [copy.deepcopy(getattr(self, attr)) for attr in only_attrs]
else:
Expand All @@ -1458,29 +1460,29 @@ def stream_from_gti_lists(
output_file = root_file_name + f"_00." + fmt.lstrip(".")
ev.write(output_file, fmt=fmt)
yield output_file
else:
for i, gti in enumerate(new_gti_lists):
if len(gti) == 0:
continue

for i, gti in enumerate(new_gti_lists):
if len(gti) == 0:
continue

lower_edge, upper_edge = self.get_idx_from_time_range(gti[0, 0], gti[-1, 1])
lower_edge, upper_edge = self.get_idx_from_time_range(gti[0, 0], gti[-1, 1])

if only_attrs is not None:
yield [
copy.deepcopy(getattr(self, attr)[lower_edge : upper_edge + 1])
for attr in only_attrs
]
else:
ev = self[lower_edge : upper_edge + 1]
ev.gti = gti

if root_file_name is not None:
new_file = root_file_name + f"_{i:002d}." + fmt.lstrip(".")
logger.info(f"Writing {new_file}")
ev.write(new_file, fmt=fmt)
yield new_file
if only_attrs is not None:
yield [
copy.deepcopy(getattr(self, attr)[lower_edge : upper_edge + 1])
for attr in only_attrs
]
else:
yield ev
ev = self[lower_edge : upper_edge + 1]
ev.gti = gti

if root_file_name is not None:
new_file = root_file_name + f"_{i:002d}." + fmt.lstrip(".")
logger.info(f"Writing {new_file}")
ev.write(new_file, fmt=fmt)
yield new_file
else:
yield ev

def stream_by_number_of_samples(
self, nsamples, root_file_name=None, fmt=DEFAULT_FORMAT, only_attrs=None
Expand All @@ -1501,8 +1503,8 @@ def stream_by_number_of_samples(
fmt : str
The format of the output files. Default is 'hdf5'.
Returns
-------
Yields
------
output_files : list of str
A list of the output file names.
"""
Expand Down Expand Up @@ -1533,8 +1535,8 @@ def stream_from_time_intervals(
fmt : str
The format of the output files. Default is 'hdf5'.
Returns
-------
Yields
------
output_files : list of str
A list of the output file names.
"""
Expand Down

0 comments on commit 9458e42

Please sign in to comment.