refactor: simplify dataset construction #4437
Conversation
for more information, see https://pre-commit.ci
@anyangml This PR adds a more detailed warning output when dataset reading is throttled.
📝 Walkthrough
The changes in this pull request involve modifications to two primary files: deepmd/pt/utils/dataloader.py and deepmd/utils/path.py.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant DpLoaderSet
    participant BackgroundConsumer
    participant BufferedIterator
    User->>DpLoaderSet: Initialize with parameters
    DpLoaderSet->>DpLoaderSet: Call construct_dataset
    DpLoaderSet->>BackgroundConsumer: Initialize consumer
    BackgroundConsumer->>BufferedIterator: Initialize iterator
    BufferedIterator->>BackgroundConsumer: Start consuming data
    BackgroundConsumer->>DpLoaderSet: Signal end of data loading
    DpLoaderSet->>User: Print summary (if rank 0)
```
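To make the flow above concrete, here is a minimal, self-contained sketch of the producer-consumer pattern the diagram describes. The class names mirror `BackgroundConsumer` and `BufferedIterator` from `deepmd/pt/utils/dataloader.py`, but the body is an illustration under simplified assumptions, not the actual implementation.

```python
import logging
import time
from queue import Queue
from threading import Thread

log = logging.getLogger(__name__)
QUEUESIZE = 32  # bounded queue; the producer blocks when it is full


class BackgroundConsumer(Thread):
    """Producer thread: pulls items from the source iterable into a queue."""

    def __init__(self, queue, source) -> None:
        super().__init__(daemon=True)  # daemon so it never blocks interpreter exit
        self._queue = queue
        self._source = source

    def run(self) -> None:
        for item in self._source:
            self._queue.put(item)  # blocks if the queue is full
        # Sentinel: tell the consumer that the source is exhausted
        self._queue.put(StopIteration())


class BufferedIterator:
    """Consumer side: yields prefetched items and warns when loading is slow."""

    def __init__(self, iterable) -> None:
        self._queue = Queue(QUEUESIZE)
        self._consumer = BackgroundConsumer(self._queue, iterable)
        self._consumer.start()
        self.len = len(iterable)

    def __iter__(self):
        return self

    def __len__(self) -> int:
        return self.len

    def __next__(self):
        start_wait = time.time()
        item = self._queue.get()  # blocks until the producer supplies an item
        wait_time = time.time() - start_wait
        if wait_time > 1.0:  # threshold discussed in this PR
            log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")
        if isinstance(item, Exception):  # StopIteration sentinel or worker error
            raise item
        return item


if __name__ == "__main__":
    for batch in BufferedIterator(range(5)):
        print(batch)
```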
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (3)
deepmd/pt/utils/dataloader.py (2)
5-7: LGTM! Consider adding type hints for better code clarity.

The introduction of the `construct_dataset` function and the use of `partial` improves code reusability and maintainability. Consider adding type hints to make the function signature more explicit:

```diff
-def construct_dataset(system, type_map):
+def construct_dataset(system: str, type_map: list[str]) -> DeepmdDataSetForLoader:
```

Also applies to: 11-13, 59-63
225-246: Consider making queue size and warning threshold configurable.

The current implementation has hardcoded values that might not be optimal for all scenarios:

- `QUEUESIZE = 32` might need adjustment based on memory constraints or dataset characteristics
- The 1-second warning threshold might be too aggressive for larger batches or slower storage systems

Consider making these values configurable:

```diff
-QUEUESIZE = 32
+DEFAULT_QUEUE_SIZE = 32
+DEFAULT_WARNING_THRESHOLD = 1.0

 class BufferedIterator:
-    def __init__(self, iterable) -> None:
+    def __init__(
+        self,
+        iterable,
+        queue_size: int = DEFAULT_QUEUE_SIZE,
+        warning_threshold: float = DEFAULT_WARNING_THRESHOLD
+    ) -> None:
-        self._queue = Queue(QUEUESIZE)
+        self._queue = Queue(queue_size)
         self._iterable = iterable
         self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
         self.len = len(iterable)
+        self._warning_threshold = warning_threshold

     def __next__(self):
         start_wait = time.time()
         item = self._queue.get()
         wait_time = time.time() - start_wait
-        if wait_time > 1.0:
+        if wait_time > self._warning_threshold:
             log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")
```

deepmd/utils/path.py (1)
47-49: LGTM! Consider adding a docstring for clarity.

The `__getnewargs__` implementation correctly returns the essential arguments needed for pickling. However, consider adding a docstring to explain its purpose:

```diff
 def __getnewargs__(self):
+    """Return a tuple of arguments needed to create a new instance during unpickling.
+
+    Returns
+    -------
+    tuple
+        A tuple of (path, mode) used by __new__ to reconstruct the object.
+    """
     return (self.path, self.mode)
```
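As background, `__getnewargs__` matters because pickle (protocol 2 and later) passes the returned tuple to `__new__` when rebuilding the object, which is required whenever `__new__` takes mandatory arguments, for example so instances can be shipped to `multiprocessing.Pool` workers. The class below is a hypothetical stand-in, not the actual `DPPath`:

```python
import pickle


class PathLike:
    """Hypothetical stand-in for a class whose __new__ requires arguments."""

    def __new__(cls, path: str, mode: str = "r"):
        obj = super().__new__(cls)
        obj.path = path
        obj.mode = mode
        return obj

    def __getnewargs__(self):
        # pickle (protocol 2+) passes this tuple back to __new__ on load
        return (self.path, self.mode)


# Round-trip: without __getnewargs__, unpickling would call __new__ with no
# arguments and fail because `path` is required.
p = pickle.loads(pickle.dumps(PathLike("/tmp/system", "r")))
assert (p.path, p.mode) == ("/tmp/system", "r")
```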
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
- deepmd/pt/utils/dataloader.py (5 hunks)
- deepmd/utils/path.py (1 hunks)
🔇 Additional comments (3)
deepmd/pt/utils/dataloader.py (3)
194-208: LGTM! Rank check prevents duplicate output in distributed training.
The addition of the rank check ensures that summary information is printed only once in distributed training environments, improving log clarity.
212-223: LGTM! More Pythonic implementation with StopIteration.
The simplified interface and the use of StopIteration as a sentinel value align better with Python's iterator protocol.
Line range hint 100-110: LGTM! Verify pool size calculation in distributed environments.
The parallel dataset construction using Pool is efficient and properly considers distributed training environments. However, let's verify the pool size calculation behavior across different distributed configurations.
✅ Verification successful
Let me gather more information about the distributed training setup to verify the pool size calculation.
LGTM! Pool size calculation is correctly implemented for distributed environments
The pool size calculation is properly implemented and consistent with the distributed training setup:
- The code correctly checks for a distributed environment using `dist.is_available()` and `dist.is_initialized()`
- `LOCAL_WORLD_SIZE` is only used in this specific pool size calculation, which is appropriate as it represents the number of processes on the local machine
- The distributed training initialization and rank handling throughout the codebase follow consistent patterns
- The pool size division ensures resources are properly allocated across distributed processes without oversubscription (a minimal sketch of this calculation follows this list)
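For reference, a minimal sketch of the pool-sizing and `partial` binding described above; `construct_dataset` here is a placeholder for the real factory in `deepmd/pt/utils/dataloader.py`, and the `max(1, ...)` guard is an added assumption, not part of the PR.

```python
import os
from functools import partial
from multiprocessing import Pool

import torch.distributed as dist


def construct_dataset(system, type_map):
    # Placeholder for DeepmdDataSetForLoader(system=system, type_map=type_map)
    return {"system": system, "type_map": type_map}


def build_datasets(systems, type_map):
    # Split the node's CPUs among the trainer processes running on it
    # (LOCAL_WORLD_SIZE) so parallel dataset construction does not
    # oversubscribe the machine.
    local_world_size = (
        int(os.environ["LOCAL_WORLD_SIZE"])
        if dist.is_available() and dist.is_initialized()
        else 1
    )
    n_workers = max(1, (os.cpu_count() or 1) // local_world_size)
    construct_dataset_systems = partial(construct_dataset, type_map=type_map)
    with Pool(n_workers) as pool:
        return pool.map(construct_dataset_systems, systems)
```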
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for potential issues with pool size calculation in distributed environments
# Search for other uses of LOCAL_WORLD_SIZE to ensure consistent handling
rg "LOCAL_WORLD_SIZE" -A 2 -B 2
# Search for similar pool size calculations to verify consistency
ast-grep --pattern 'Pool(
$$$
os.cpu_count()
$$$
)'
Length of output: 425
Script:
#!/bin/bash
# Check for distributed training initialization and configuration
rg "dist\.init_process_group|dist\.is_initialized\(\)" -A 2 -B 2
# Look for other process/thread pool implementations
rg "multiprocessing\.(Pool|Process|ThreadPool)" -A 2 -B 2
# Check if there are any other environment variables used for process control
rg "os\.environ\[.*WORLD_SIZE" -A 2 -B 2
# Search for any comments or documentation about distributed training
rg -g "*.md" -g "*.rst" "distributed training|LOCAL_WORLD_SIZE"
Length of output: 6341
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
deepmd/pt/utils/dataloader.py (1)
246-252: Consider making the warning threshold configurable.

The 1.0-second threshold for slow data loading warnings is hardcoded. Consider making this configurable to accommodate different system capabilities and requirements.

```diff
-        if (
-            wait_time > 1.0
-        ):  # Even for Multi-Task training, each step usually takes < 1s
+        # Default to 1.0s, but allow configuration through environment variable
+        warning_threshold = float(os.getenv('DEEPMD_LOADER_WARNING_THRESHOLD', '1.0'))
+        if wait_time > warning_threshold:
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/utils/dataloader.py
(5 hunks)
🔇 Additional comments (5)
deepmd/pt/utils/dataloader.py (5)
60-64: LGTM! Clean factory function implementation.
The new `construct_dataset` function provides a clean abstraction for dataset construction, improving code modularity and reusability.
Line range hint 102-112: LGTM! Improved parallel dataset construction.
Good improvements:
- Using `partial` for cleaner parameter binding
- Proper CPU pool sizing for distributed training
- More maintainable parallel dataset construction
196-210: LGTM! Proper handling of distributed output.
Good addition of rank-aware printing to prevent duplicate output in distributed training environments.
233-237: LGTM! Proper length protocol implementation.
Good addition of the `len` attribute to support Python's length protocol.
224-225: Verify StopIteration handling in DataLoader context.
The comment suggests this case shouldn't occur for DataLoader. Let's verify the DataLoader's iteration behavior.
✅ Verification successful
StopIteration handling is correctly implemented
The code review reveals that the StopIteration concern is properly handled:
- The BackgroundConsumer sends StopIteration through the queue when iteration is complete
- The BufferedIterator's `__next__` method checks if the received item is an Exception (which includes StopIteration) and raises it appropriately
- The comment "this should not happen for DataLoader" is misleading and can be ignored, as the StopIteration handling is a necessary part of the Python iterator protocol implementation
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for DataLoader iteration handling to verify StopIteration behavior
ast-grep --pattern 'class DataLoader {
$$$
def __iter__() {
$$$
}
$$$
}'
Length of output: 86
Script:
#!/bin/bash
# Let's try a broader search for DataLoader class and its iteration-related methods
rg -A 10 "class DataLoader" deepmd/pt/utils/dataloader.py
# Also search for any StopIteration handling
rg -B 2 -A 2 "StopIteration" deepmd/pt/utils/dataloader.py
Length of output: 429
Script:
#!/bin/bash
# Let's search for the BackgroundConsumer class and its context
rg -B 10 -A 20 "class BackgroundConsumer" deepmd/pt/utils/dataloader.py
# Also search for where this consumer is used
rg -B 5 -A 5 "BackgroundConsumer" deepmd/pt/utils/dataloader.py
Length of output: 1668
Script:
#!/bin/bash
# Let's see the complete BufferedIterator implementation to understand how it handles iteration
rg -B 2 -A 15 "class BufferedIterator" deepmd/pt/utils/dataloader.py
# And search for where BufferedIterator is used
rg -B 5 -A 5 "BufferedIterator" deepmd/pt/utils/dataloader.py
Length of output: 830
Script:
#!/bin/bash
# Let's see the complete __next__ implementation of BufferedIterator
rg -B 2 -A 10 "def __next__" deepmd/pt/utils/dataloader.py
# And let's see if there's a DataLoader class implementation in the file
rg -B 5 -A 15 "class.*DataLoader" deepmd/pt/utils/dataloader.py
Length of output: 565
Duplicated in #4434
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (4)
deepmd/pt/utils/dataloader.py (3)
102-102: LGTM! Good distributed training awareness.

The changes improve the code by:
- Preventing redundant output in distributed training with rank check
- Properly scaling pool size based on LOCAL_WORLD_SIZE
- Using clean dataset construction abstraction with Pool
Consider adding error handling for the Pool operations to gracefully handle potential failures during parallel dataset construction.
```diff
 with Pool(
     os.cpu_count()
     // (
         int(os.environ["LOCAL_WORLD_SIZE"])
         if dist.is_available() and dist.is_initialized()
         else 1
     )
 ) as pool:
-    self.systems = pool.map(construct_dataset_systems, systems)
+    try:
+        self.systems = pool.map(construct_dataset_systems, systems)
+    except Exception as e:
+        log.error(f"Failed to construct datasets in parallel: {e}")
+        # Fallback to sequential construction
+        self.systems = [construct_dataset_systems(system) for system in systems]
```

Also applies to: 196-210
214-226: Improve the comment about DataLoader.

The changes to use StopIteration for signaling are good, but the comment "this should not happen for DataLoader" is unclear. Consider clarifying when and why StopIteration might occur.

```diff
-        # Signal the consumer we are done; this should not happen for DataLoader
+        # Signal the end of iteration. Note: For DataLoader, this typically only occurs
+        # when the DataLoader is explicitly closed or the dataset is exhausted
         self._queue.put(StopIteration)
```
233-253: Enhance warning system and error handling.

The changes improve the warning output, but consider these enhancements:
- Make the warning threshold configurable
- Add more actionable information to the warning
- Make the error handling more explicit
```diff
+    # Class variable for warning threshold
+    SLOW_LOADING_THRESHOLD = 1.0  # seconds
+
     def __init__(self, iterable) -> None:
         self._queue = Queue(QUEUESIZE)
         self._iterable = iterable
         self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
         self.len = len(iterable)
+        self._warned = False  # Track if warning was already issued

     def __next__(self):
         start_wait = time.time()
         item = self._queue.get()
         wait_time = time.time() - start_wait
-        if (
-            wait_time > 1.0
-        ):  # Even for Multi-Task training, each step usually takes < 1s
-            log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")
+        if wait_time > self.SLOW_LOADING_THRESHOLD and not self._warned:
+            self._warned = True  # Warn only once to avoid log spam
+            log.warning(
+                f"Data loading is slow (waited {wait_time:.2f}s). Consider:\n"
+                "1. Increasing the number of worker processes\n"
+                "2. Reducing batch size\n"
+                "3. Using memory-mapped files or faster storage"
+            )
-        if issubclass(item, Exception):
+        if item is StopIteration:
             raise item
+        elif isinstance(item, Exception):
+            raise RuntimeError(f"Background worker failed: {item}")
         return item
```

deepmd/pt/train/training.py (1)
1061-1071: Consider enhancing error handling and logging for data loading issues.

While the code correctly handles StopIteration for dataloader refresh, it would be beneficial to:
- Add logging when the dataloader is refreshed to help with debugging
- Consider adding a max retry count to prevent infinite loops in case of persistent data loading issues
Here's a suggested improvement:
```diff
 if data is None and not is_train:
     return {}, {}, {}
 if self.multi_task:
     data = data[task_key]
     dataloader = dataloader[task_key]
+max_retries = 3
+retry_count = 0
 try:
     batch_data = next(iter(data))
 except StopIteration:
+    log.debug(f"Refreshing dataloader for {'training' if is_train else 'validation'} data")
     # Refresh the status of the dataloader to start from a new epoch
     data = BufferedIterator(iter(dataloader))
+    while retry_count < max_retries:
+        try:
+            batch_data = next(iter(data))
+            break
+        except StopIteration:
+            retry_count += 1
+            log.warning(f"Failed to get batch after refresh, attempt {retry_count}/{max_retries}")
+            data = BufferedIterator(iter(dataloader))
+    if retry_count == max_retries:
+        log.error("Failed to get batch after maximum retries")
+        raise RuntimeError("Failed to get batch after maximum retries")
-    batch_data = next(iter(data))
```
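For orientation, a hedged sketch of the refresh-on-StopIteration pattern being discussed, reusing the illustrative `BufferedIterator` from the sketch near the top of this thread; the helper name `next_batch` is hypothetical, and the real logic lives in the trainer's `get_data`.

```python
def next_batch(data, dataloader):
    """Return the next batch, refreshing the buffered iterator at epoch end."""
    try:
        return next(iter(data)), data
    except StopIteration:
        # The epoch is exhausted: wrap the underlying DataLoader again so
        # training continues seamlessly into the next epoch.
        data = BufferedIterator(iter(dataloader))
        return next(iter(data)), data
```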
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
- deepmd/pt/train/training.py (1 hunks)
- deepmd/pt/utils/dataloader.py (5 hunks)
🔇 Additional comments (3)
deepmd/pt/utils/dataloader.py (2)
5-7: LGTM! Clean abstraction for dataset construction.
The new `construct_dataset` function provides a clean abstraction for creating dataset instances, and using `partial` to bind `type_map` is a good functional programming practice.
Also applies to: 60-64
Line range hint 1-253: Verify PR objectives implementation.
The changes successfully achieve the PR objectives:
- Dataset construction is simplified through the new `construct_dataset` function and proper Pool usage
- Warning output for throttled reading is improved
However, consider enhancing the warning system further to provide even more detailed diagnostics about the cause of throttling.
✅ Verification successful
Warning system implementation is appropriate for the use case
The warning system in the dataloader is well-implemented and sufficient for its purpose:
- The warning at line 252 provides clear timing information: `log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")`
- The 1-second threshold is reasonable as it's calibrated based on typical step times ("Even for Multi-Task training, each step usually takes < 1s")
- The warning includes both the fact that data loading is slow and the exact wait time for diagnostics
Additional diagnostic details would add unnecessary complexity since:
- The root cause (I/O bottleneck) is clear from the wait time
- The warning already triggers at an appropriate threshold
- The message is actionable - users can adjust batch sizes or I/O parameters
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify that the warning system is properly triggered in different scenarios
# Test: Search for all warning-related code to ensure comprehensive coverage
rg -n "warning|warn" --type py
# Test: Check for any existing performance-related issues or discussions
gh issue list --label "performance"
gh pr list --label "performance"
Length of output: 10444
deepmd/pt/train/training.py (1)
1056-1060: LGTM! Clean data selection logic.
The code cleanly handles the selection between training and validation data using tuple unpacking, which improves readability.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (3)
deepmd/pt/utils/dataloader.py (3)
5-7: LGTM! Consider adding a docstring to the new function.

The addition of the `construct_dataset` function and the use of `partial` improves code modularity by extracting the dataset construction logic. Consider adding a docstring to document the parameters and return type.

```diff
 def construct_dataset(system, type_map):
+    """Construct a DeepmdDataSetForLoader instance.
+
+    Parameters
+    ----------
+    system : str
+        Path to the system data
+    type_map : list
+        Mapping of atom types
+
+    Returns
+    -------
+    DeepmdDataSetForLoader
+        Dataset instance for the given system
+    """
     return DeepmdDataSetForLoader(
         system=system,
         type_map=type_map,
     )
```

Also applies to: 60-64
Line range hint 102-112: LGTM! Consider adding progress feedback for large datasets.

The parallel dataset construction using Pool is a good improvement. For better user experience when loading large datasets, consider adding progress feedback.

```diff
 construct_dataset_systems = partial(construct_dataset, type_map=type_map)
+if len(systems) >= 100:
+    from tqdm import tqdm
+    with Pool(...) as pool:
+        self.systems = list(tqdm(
+            pool.imap(construct_dataset_systems, systems),
+            total=len(systems),
+            desc="Constructing datasets"
+        ))
-with Pool(...) as pool:
-    self.systems = pool.map(construct_dataset_systems, systems)
```
237-237: Consider making the warning threshold configurable.

While the warning for slow data loading is helpful, the 1-second threshold might not be suitable for all scenarios. Consider making this threshold configurable.

```diff
 class BufferedIterator:
     def __init__(self, iterable) -> None:
         self._queue = Queue(QUEUESIZE)
         self._iterable = iterable
         self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
         self.len = len(iterable)
+        self.warning_threshold = 1.0  # seconds

     def __next__(self):
         start_wait = time.time()
         item = self._queue.get()
         wait_time = time.time() - start_wait
-        if wait_time > 1.0:
+        if wait_time > self.warning_threshold:
             log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")
```

Also applies to: 246-252
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/utils/dataloader.py
(5 hunks)
🔇 Additional comments (3)
deepmd/pt/utils/dataloader.py (3)
196-210: LGTM! Good handling of distributed output.
The rank check in print_summary prevents duplicate output in distributed training, which is the correct approach.
214-226: LGTM! Good simplification of the consumer implementation.
The changes improve the code by:
- Removing the unused `max_len` parameter
- Using `StopIteration` for end-of-data signaling, which better aligns with Python's iterator protocol
233-237: Verify thread safety of the iterator implementation.
Let's verify that the iterator implementation is thread-safe, particularly around the len attribute access.
✅ Verification successful
Thread safety implementation is correct and well-designed
The implementation is thread-safe due to several key design choices:
- The Queue is properly used as a thread-safe producer-consumer pattern
- Length is captured once during initialization and stored as an immutable attribute
- BackgroundConsumer runs as a daemon thread with clean separation of concerns
- The iterator implementation uses blocking queue operations which handle synchronization
- No shared mutable state exists between the consumer thread and main thread
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for potential race conditions in iterator usage
ast-grep --pattern 'class $_ {
$$$
def __iter__($$$) {
$$$
}
$$$
def __len__($$$) {
$$$
}
$$$
}'
Length of output: 120
Script:
#!/bin/bash
# Let's check the implementation of BackgroundConsumer and related iterator code
rg -A 10 "class BackgroundConsumer"
# Also check the implementation of the class containing this Queue setup
rg -B 5 -A 10 "self._queue = Queue\(QUEUESIZE\)"
# Check if there are any other references to len or iterator access
rg "def __len__" deepmd/pt/utils/dataloader.py
rg "def __iter__" deepmd/pt/utils/dataloader.py
Length of output: 1817
Signed-off-by: Chun Cai <[email protected]>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
deepmd/pt/train/training.py (1)
1056-1071: LGTM! Clean and efficient data loading implementation.

The refactored code is well-structured and handles both single-task and multi-task scenarios elegantly. The improvements include:
- Cleaner data/dataloader selection using tuple unpacking
- Clear handling of validation data absence
- Proper error handling for dataset iteration
Consider adding type hints to improve code maintainability:
```diff
-    def get_data(self, is_train=True, task_key="Default"):
+    def get_data(
+        self,
+        is_train: bool = True,
+        task_key: str = "Default"
+    ) -> tuple[dict, dict, dict]:
```
Co-authored-by: Han Wang <[email protected]>
Signed-off-by: Chun Cai <[email protected]>
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##            devel   #4437      +/-   ##
==========================================
+ Coverage   83.27%  83.75%   +0.47%
==========================================
  Files         667     667
  Lines       61446   61525      +79
  Branches     3486    3487       +1
==========================================
+ Hits        51167   51528     +361
+ Misses       9151    8872     -279
+ Partials     1128    1125       -3
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Chun Cai <[email protected]>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
deepmd/pt/train/training.py (1)
1066-1072: Add a clarifying comment for StopIteration handling.

The StopIteration handling logic ensures continuous training across epochs, but this important detail isn't immediately obvious.
Add a comment to explain the purpose:
```diff
 try:
     batch_data = next(iter(data))
 except StopIteration:
-    # Refresh the status of the dataloader to start from a new epoch
+    # When all data in the current epoch is exhausted, refresh the dataloader
+    # to start a new epoch. This ensures continuous training across epochs.
     with torch.device("cpu"):
         data = BufferedIterator(iter(dataloader))
     batch_data = next(iter(data))
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/train/training.py
(1 hunks)
🔇 Additional comments (1)
deepmd/pt/train/training.py (1)
1056-1065: LGTM! But verify test coverage for multi-task handling.
The refactoring improves code readability by using tuple unpacking and clear conditional logic. However, based on past review comments, ensure that the multi-task handling at line 1062 is properly covered by tests.
I'm reverting changes to
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (2)
deepmd/pt/utils/dataloader.py (2)
Line range hint 102-112: Add error handling for parallel dataset construction

While the parallel construction using Pool is a good optimization, consider adding error handling to gracefully handle failures during dataset construction. This is especially important when dealing with multiple systems in parallel.
```diff
 with Pool(
     os.cpu_count()
     // (
         int(os.environ["LOCAL_WORLD_SIZE"])
         if dist.is_available() and dist.is_initialized()
         else 1
     )
 ) as pool:
-    self.systems = pool.map(construct_dataset_systems, systems)
+    try:
+        self.systems = pool.map(construct_dataset_systems, systems)
+    except Exception as e:
+        log.error(f"Failed to construct datasets in parallel: {str(e)}")
+        # Fallback to sequential construction
+        self.systems = [construct_dataset_systems(system) for system in systems]
```
247-256: Good improvement on warning messages and throttling!

The enhanced warning system with throttling addresses the concern about too many warnings. However, consider making the warning threshold configurable.

```diff
 class BufferedIterator:
     def __init__(self, iterable) -> None:
         self._queue = Queue(QUEUESIZE)
         self._iterable = iterable
         self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
         self.last_warning_time = time.time()
         self.len = len(iterable)
+        self.warning_threshold = 1.0  # seconds
+
+    @property
+    def warning_threshold(self) -> float:
+        return self._warning_threshold
+
+    @warning_threshold.setter
+    def warning_threshold(self, value: float) -> None:
+        if value <= 0:
+            raise ValueError("Warning threshold must be positive")
+        self._warning_threshold = value

     def __next__(self):
         start_wait = time.time()
         item = self._queue.get()
         wait_time = time.time() - start_wait
         if (
-            wait_time > 1.0 and start_wait - self.last_warning_time > 15 * 60
+            wait_time > self.warning_threshold and start_wait - self.last_warning_time > 15 * 60
         ):
```
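To show the cooldown idea in isolation, here is a small generic rate-limited warning helper. It is only a sketch of the throttling technique; the PR implements the equivalent check inline in `BufferedIterator.__next__` with a 1-second threshold and a 15-minute cooldown, and the helper name below is hypothetical.

```python
import logging
import time

log = logging.getLogger(__name__)


class RateLimitedWarner:
    """Emit a slow-loading warning at most once per cooldown period."""

    def __init__(self, threshold: float = 1.0, cooldown: float = 15 * 60) -> None:
        self.threshold = threshold  # seconds a wait may take before warning
        self.cooldown = cooldown    # minimum seconds between two warnings
        self._last_warning = float("-inf")

    def maybe_warn(self, wait_time: float) -> None:
        now = time.time()
        if wait_time > self.threshold and now - self._last_warning > self.cooldown:
            self._last_warning = now
            log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")
```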
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/utils/dataloader.py
(5 hunks)
🔇 Additional comments (1)
deepmd/pt/utils/dataloader.py (1)
Line range hint 5-64: Good abstraction of dataset construction logic!
The introduction of the `construct_dataset` function and the use of `partial` from functools shows good separation of concerns and functional programming practices. This makes the code more maintainable and testable.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
deepmd/pt/utils/dataloader.py (2)
227-231: Consider making queue size configurable

The `QUEUESIZE` is hardcoded to 32. Consider making it configurable through constructor parameters to allow tuning based on memory constraints and performance requirements.

```diff
 class BufferedIterator:
-    def __init__(self, iterable) -> None:
+    def __init__(self, iterable, queue_size: int = 32) -> None:
         self._queue = Queue(QUEUESIZE)
         self._iterable = iterable
         self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
         self.last_warning_time = time.time()
```
240-249: Consider making warning thresholds configurable

The warning thresholds (1.0-second wait time and 15-minute cooldown) are hardcoded. These values might need adjustment based on different workloads and environments.

```diff
 class BufferedIterator:
-    def __init__(self, iterable) -> None:
+    def __init__(
+        self,
+        iterable,
+        wait_threshold: float = 1.0,
+        warning_cooldown: float = 900.0  # 15 minutes
+    ) -> None:
         self._queue = Queue(QUEUESIZE)
         self._iterable = iterable
         self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
         self.last_warning_time = time.time()
+        self._wait_threshold = wait_threshold
+        self._warning_cooldown = warning_cooldown

     def __next__(self):
         start_wait = time.time()
         item = self._queue.get()
         wait_time = time.time() - start_wait
         if (
-            wait_time > 1.0 and start_wait - self.last_warning_time > 15 * 60
+            wait_time > self._wait_threshold and
+            start_wait - self.last_warning_time > self._warning_cooldown
         ):
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/utils/dataloader.py
(2 hunks)
🔇 Additional comments (2)
deepmd/pt/utils/dataloader.py (2)
8-10: LGTM: Import change improves clarity
The explicit import of `Queue` is clearer than the generic queue import.
208-220: Consider improving edge case handling
The comment "should not happen for DataLoader" suggests an edge case that should be properly documented or handled. Consider:
- Adding documentation explaining why this case shouldn't occur
- Adding proper error handling if it does occur unexpectedly
Let's verify if this edge case is handled elsewhere:
Consider this improvement:
```diff
     def run(self) -> None:
+        """Background consumer thread that feeds items from source to queue.
+
+        Note: StopIteration is used as a sentinel to signal completion, though
+        this should not occur during normal DataLoader operation as it typically
+        cycles indefinitely.
+        """
         for item in self._source:
             self._queue.put(item)  # Blocking if the queue is full
         # Signal the consumer we are done; this should not happen for DataLoader
+        self._queue.put(StopIteration())
```
✅ Verification successful
Let me gather more context about the BackgroundConsumer usage to understand how StopIteration is handled by the consumer.
Let me check how the consumer side handles the StopIteration.
Let me check the complete implementation of the consumer class and its usage in the BufferedIterator.
The StopIteration handling is correct and well-designed
The comment "should not happen for DataLoader" is actually accurate and doesn't indicate a problem. Here's why:
- The BackgroundConsumer is used within BufferedIterator, which wraps an iterable (DataLoader)
- DataLoaders in PyTorch are designed to be infinite iterators during training - they automatically restart when reaching the end of an epoch
- The StopIteration signal is properly handled by the consumer-producer pattern:
- Producer (BackgroundConsumer) signals completion with StopIteration
- Consumer (`BufferedIterator.__next__`) checks if the item is an Exception and raises it
The current implementation follows a standard producer-consumer pattern and handles the edge case appropriately. The comment serves as documentation for maintainers to understand the expected behavior.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for StopIteration handling in the codebase
rg -A 3 "StopIteration" --type py
Length of output: 5685
Script:
#!/bin/bash
# Search for BackgroundConsumer usage and its consumer side
rg -B 3 -A 10 "BackgroundConsumer" --type py
Length of output: 3197
Script:
#!/bin/bash
# Search for the __next__ implementation in the same file
rg -B 2 -A 10 "__next__" deepmd/pt/utils/dataloader.py
Length of output: 587
Script:
#!/bin/bash
# Get the complete implementation of BackgroundConsumer and BufferedIterator
cat deepmd/pt/utils/dataloader.py
Length of output: 10011
I've trained 100k steps from scratch. Since this PR only changes the warning mechanism of DataLoader, the loss curve does not change significantly. The deviation is not introduced in this PR - I've attached the curves of two runs on the devel branch, and they are not exactly the same.
Summary by CodeRabbit
New Features
Bug Fixes
Refactor