Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: simplify dataset construction #4437

Merged
merged 14 commits into from
Dec 9, 2024

Conversation

caic99
Copy link
Member

@caic99 caic99 commented Nov 27, 2024

Summary by CodeRabbit

  • New Features

    • Introduced a new function for dataset construction, enhancing data loading processes.
    • Added a method to improve pickling and unpickling capabilities for path handling classes.
  • Bug Fixes

    • Updated summary printing to prevent redundant output during distributed training.
  • Refactor

    • Simplified initialization of the BackgroundConsumer class.
    • Streamlined consumer thread and queue handling in the BufferedIterator class.

@caic99
Copy link
Member Author

caic99 commented Nov 27, 2024

@anyangml This PR adds a more detailed warning output in the case of dataset reading is throttled.

Copy link
Contributor

coderabbitai bot commented Nov 27, 2024

📝 Walkthrough
📝 Walkthrough

Walkthrough

The changes in this pull request involve modifications to two primary files: dataloader.py and path.py. In dataloader.py, enhancements are made to the DpLoaderSet, BackgroundConsumer, and BufferedIterator classes, including the addition of a new dataset construction function and updates to existing methods for improved modularity and clarity. In path.py, a new abstract method for pickling support is introduced in the DPPath class and its subclasses, enhancing the class hierarchy's serialization capabilities.

Changes

File Change Summary
deepmd/pt/utils/dataloader.py - Added construct_dataset(system) method.
- Updated print_summary to include rank check.
- Modified BackgroundConsumer constructor to remove max_len parameter.
- Adjusted BufferedIterator to remove max_len from consumer initialization.
deepmd/utils/path.py - Added __getnewargs__ method in DPPath, DPOSPath, and DPH5Path classes for enhanced pickling support.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant DpLoaderSet
    participant BackgroundConsumer
    participant BufferedIterator

    User->>DpLoaderSet: Initialize with parameters
    DpLoaderSet->>DpLoaderSet: Call construct_dataset
    DpLoaderSet->>BackgroundConsumer: Initialize consumer
    BackgroundConsumer->>BufferedIterator: Initialize iterator
    BufferedIterator->>BackgroundConsumer: Start consuming data
    BackgroundConsumer->>DpLoaderSet: Signal end of data loading
    DpLoaderSet->>User: Print summary (if rank 0)
Loading

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (3)
deepmd/pt/utils/dataloader.py (2)

5-7: LGTM! Consider adding type hints for better code clarity.

The introduction of construct_dataset function and the use of partial improves code reusability and maintainability. Consider adding type hints to make the function signature more explicit:

-def construct_dataset(system, type_map):
+def construct_dataset(system: str, type_map: list[str]) -> DeepmdDataSetForLoader:

Also applies to: 11-13, 59-63


225-246: Consider making queue size and warning threshold configurable.

The current implementation has hardcoded values that might not be optimal for all scenarios:

  1. QUEUESIZE = 32 might need adjustment based on memory constraints or dataset characteristics
  2. The 1-second warning threshold might be too aggressive for larger batches or slower storage systems

Consider making these values configurable:

-QUEUESIZE = 32
+DEFAULT_QUEUE_SIZE = 32
+DEFAULT_WARNING_THRESHOLD = 1.0

 class BufferedIterator:
-    def __init__(self, iterable) -> None:
+    def __init__(
+        self,
+        iterable,
+        queue_size: int = DEFAULT_QUEUE_SIZE,
+        warning_threshold: float = DEFAULT_WARNING_THRESHOLD
+    ) -> None:
-        self._queue = Queue(QUEUESIZE)
+        self._queue = Queue(queue_size)
         self._iterable = iterable
         self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
         self.len = len(iterable)
+        self._warning_threshold = warning_threshold

     def __next__(self):
         start_wait = time.time()
         item = self._queue.get()
         wait_time = time.time() - start_wait
-        if wait_time > 1.0:
+        if wait_time > self._warning_threshold:
             log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")
deepmd/utils/path.py (1)

47-49: LGTM! Consider adding docstring for clarity.

The __getnewargs__ implementation correctly returns the essential arguments needed for pickling. However, consider adding a docstring to explain its purpose:

 def __getnewargs__(self):
+    """Return a tuple of arguments needed to create a new instance during unpickling.
+
+    Returns
+    -------
+    tuple
+        A tuple of (path, mode) used by __new__ to reconstruct the object.
+    """
     return (self.path, self.mode)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 4a45fe5 and 417da94.

📒 Files selected for processing (2)
  • deepmd/pt/utils/dataloader.py (5 hunks)
  • deepmd/utils/path.py (1 hunks)
🔇 Additional comments (3)
deepmd/pt/utils/dataloader.py (3)

194-208: LGTM! Rank check prevents duplicate output in distributed training.

The addition of rank check ensures that summary information is printed only once in distributed training environments, improving log clarity.


212-223: LGTM! More Pythonic implementation with StopIteration.

The simplified interface and use of StopIteration as a sentinel value aligns better with Python's iterator protocol.


Line range hint 100-110: LGTM! Verify pool size calculation in distributed environments.

The parallel dataset construction using Pool is efficient and properly considers distributed training environments. However, let's verify the pool size calculation behavior across different distributed configurations.

✅ Verification successful

Let me gather more information about the distributed training setup to verify the pool size calculation.


LGTM! Pool size calculation is correctly implemented for distributed environments

The pool size calculation is properly implemented and consistent with the distributed training setup:

  • The code correctly checks for distributed environment using dist.is_available() and dist.is_initialized()
  • LOCAL_WORLD_SIZE is only used in this specific pool size calculation, which is appropriate as it represents the number of processes on the local machine
  • The distributed training initialization and rank handling throughout the codebase follows consistent patterns
  • The pool size division ensures resources are properly allocated across distributed processes without oversubscription
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for potential issues with pool size calculation in distributed environments

# Search for other uses of LOCAL_WORLD_SIZE to ensure consistent handling
rg "LOCAL_WORLD_SIZE" -A 2 -B 2

# Search for similar pool size calculations to verify consistency
ast-grep --pattern 'Pool(
  $$$
  os.cpu_count()
  $$$
)'

Length of output: 425


Script:

#!/bin/bash
# Check for distributed training initialization and configuration
rg "dist\.init_process_group|dist\.is_initialized\(\)" -A 2 -B 2

# Look for other process/thread pool implementations
rg "multiprocessing\.(Pool|Process|ThreadPool)" -A 2 -B 2

# Check if there are any other environment variables used for process control
rg "os\.environ\[.*WORLD_SIZE" -A 2 -B 2

# Search for any comments or documentation about distributed training
rg -g "*.md" -g "*.rst" "distributed training|LOCAL_WORLD_SIZE"

Length of output: 6341

deepmd/utils/path.py Outdated Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
deepmd/pt/utils/dataloader.py (1)

246-252: Consider making the warning threshold configurable.

The 1.0-second threshold for slow data loading warnings is hardcoded. Consider making this configurable to accommodate different system capabilities and requirements.

-        if (
-            wait_time > 1.0
-        ):  # Even for Multi-Task training, each step usually takes < 1s
+        # Default to 1.0s, but allow configuration through environment variable
+        warning_threshold = float(os.getenv('DEEPMD_LOADER_WARNING_THRESHOLD', '1.0'))
+        if wait_time > warning_threshold:
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 417da94 and e9672d0.

📒 Files selected for processing (1)
  • deepmd/pt/utils/dataloader.py (5 hunks)
🔇 Additional comments (5)
deepmd/pt/utils/dataloader.py (5)

60-64: LGTM! Clean factory function implementation.

The new construct_dataset function provides a clean abstraction for dataset construction, improving code modularity and reusability.


Line range hint 102-112: LGTM! Improved parallel dataset construction.

Good improvements:

  • Using partial for cleaner parameter binding
  • Proper CPU pool sizing for distributed training
  • More maintainable parallel dataset construction

196-210: LGTM! Proper handling of distributed output.

Good addition of rank-aware printing to prevent duplicate output in distributed training environments.


233-237: LGTM! Proper length protocol implementation.

Good addition of the len attribute to support Python's length protocol.


224-225: Verify StopIteration handling in DataLoader context.

The comment suggests this case shouldn't occur for DataLoader. Let's verify the DataLoader's iteration behavior.

✅ Verification successful

StopIteration handling is correctly implemented

The code review reveals that the StopIteration concern is properly handled:

  • The BackgroundConsumer sends StopIteration through the queue when iteration is complete
  • The BufferedIterator's __next__ method checks if the received item is an Exception (which includes StopIteration) and raises it appropriately
  • The comment "this should not happen for DataLoader" is misleading and can be ignored, as the StopIteration handling is a necessary part of the Python iterator protocol implementation
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for DataLoader iteration handling to verify StopIteration behavior
ast-grep --pattern 'class DataLoader {
  $$$
  def __iter__() {
    $$$
  }
  $$$
}'

Length of output: 86


Script:

#!/bin/bash
# Let's try a broader search for DataLoader class and its iteration-related methods
rg -A 10 "class DataLoader" deepmd/pt/utils/dataloader.py

# Also search for any StopIteration handling
rg -B 2 -A 2 "StopIteration" deepmd/pt/utils/dataloader.py

Length of output: 429


Script:

#!/bin/bash
# Let's search for the BackgroundConsumer class and its context
rg -B 10 -A 20 "class BackgroundConsumer" deepmd/pt/utils/dataloader.py

# Also search for where this consumer is used
rg -B 5 -A 5 "BackgroundConsumer" deepmd/pt/utils/dataloader.py

Length of output: 1668


Script:

#!/bin/bash
# Let's see the complete BufferedIterator implementation to understand how it handles iteration
rg -B 2 -A 15 "class BufferedIterator" deepmd/pt/utils/dataloader.py

# And search for where BufferedIterator is used
rg -B 5 -A 5 "BufferedIterator" deepmd/pt/utils/dataloader.py

Length of output: 830


Script:

#!/bin/bash
# Let's see the complete __next__ implementation of BufferedIterator
rg -B 2 -A 10 "def __next__" deepmd/pt/utils/dataloader.py

# And let's see if there's a DataLoader class implementation in the file
rg -B 5 -A 15 "class.*DataLoader" deepmd/pt/utils/dataloader.py

Length of output: 565

@iProzd
Copy link
Collaborator

iProzd commented Nov 27, 2024

Duplicated in #4434

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (4)
deepmd/pt/utils/dataloader.py (3)

102-102: LGTM! Good distributed training awareness.

The changes improve the code by:

  1. Preventing redundant output in distributed training with rank check
  2. Properly scaling pool size based on LOCAL_WORLD_SIZE
  3. Using clean dataset construction abstraction with Pool

Consider adding error handling for the Pool operations to gracefully handle potential failures during parallel dataset construction.

 with Pool(
     os.cpu_count()
     // (
         int(os.environ["LOCAL_WORLD_SIZE"])
         if dist.is_available() and dist.is_initialized()
         else 1
     )
 ) as pool:
-    self.systems = pool.map(construct_dataset_systems, systems)
+    try:
+        self.systems = pool.map(construct_dataset_systems, systems)
+    except Exception as e:
+        log.error(f"Failed to construct datasets in parallel: {e}")
+        # Fallback to sequential construction
+        self.systems = [construct_dataset_systems(system) for system in systems]

Also applies to: 196-210


214-226: Improve the comment about DataLoader.

The changes to use StopIteration for signaling are good, but the comment "this should not happen for DataLoader" is unclear. Consider clarifying when and why StopIteration might occur.

-        # Signal the consumer we are done; this should not happen for DataLoader
+        # Signal the end of iteration. Note: For DataLoader, this typically only occurs
+        # when the DataLoader is explicitly closed or the dataset is exhausted
         self._queue.put(StopIteration)

233-253: Enhance warning system and error handling.

The changes improve the warning output, but consider these enhancements:

  1. Make the warning threshold configurable
  2. Add more actionable information to the warning
  3. Make the error handling more explicit
+    # Class variable for warning threshold
+    SLOW_LOADING_THRESHOLD = 1.0  # seconds
+
     def __init__(self, iterable) -> None:
         self._queue = Queue(QUEUESIZE)
         self._iterable = iterable
         self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
         self.len = len(iterable)
+        self._warned = False  # Track if warning was already issued

     def __next__(self):
         start_wait = time.time()
         item = self._queue.get()
         wait_time = time.time() - start_wait
-        if (
-            wait_time > 1.0
-        ):  # Even for Multi-Task training, each step usually takes < 1s
-            log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")
+        if wait_time > self.SLOW_LOADING_THRESHOLD and not self._warned:
+            self._warned = True  # Warn only once to avoid log spam
+            log.warning(
+                f"Data loading is slow (waited {wait_time:.2f}s). Consider:\n"
+                "1. Increasing the number of worker processes\n"
+                "2. Reducing batch size\n"
+                "3. Using memory-mapped files or faster storage"
+            )
-        if issubclass(item, Exception):
+        if item is StopIteration:
             raise item
+        elif isinstance(item, Exception):
+            raise RuntimeError(f"Background worker failed: {item}")
         return item
deepmd/pt/train/training.py (1)

1061-1071: Consider enhancing error handling and logging for data loading issues.

While the code correctly handles StopIteration for dataloader refresh, it would be beneficial to:

  1. Add logging when the dataloader is refreshed to help with debugging
  2. Consider adding a max retry count to prevent infinite loops in case of persistent data loading issues

Here's a suggested improvement:

 if data is None and not is_train:
     return {}, {}, {}
 if self.multi_task:
     data = data[task_key]
     dataloader = dataloader[task_key]
+max_retries = 3
+retry_count = 0
 try:
     batch_data = next(iter(data))
 except StopIteration:
+    log.debug(f"Refreshing dataloader for {'training' if is_train else 'validation'} data")
     # Refresh the status of the dataloader to start from a new epoch
     data = BufferedIterator(iter(dataloader))
+    while retry_count < max_retries:
+        try:
+            batch_data = next(iter(data))
+            break
+        except StopIteration:
+            retry_count += 1
+            log.warning(f"Failed to get batch after refresh, attempt {retry_count}/{max_retries}")
+            data = BufferedIterator(iter(dataloader))
+    if retry_count == max_retries:
+        log.error("Failed to get batch after maximum retries")
+        raise RuntimeError("Failed to get batch after maximum retries")
-    batch_data = next(iter(data))
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between e9672d0 and 3895c94.

📒 Files selected for processing (2)
  • deepmd/pt/train/training.py (1 hunks)
  • deepmd/pt/utils/dataloader.py (5 hunks)
🔇 Additional comments (3)
deepmd/pt/utils/dataloader.py (2)

5-7: LGTM! Clean abstraction for dataset construction.

The new construct_dataset function provides a clean abstraction for creating dataset instances, and using partial to bind type_map is a good functional programming practice.

Also applies to: 60-64


Line range hint 1-253: Verify PR objectives implementation.

The changes successfully achieve the PR objectives:

  1. Dataset construction is simplified through the new construct_dataset function and proper Pool usage
  2. Warning output for throttled reading is improved

However, consider enhancing the warning system further to provide even more detailed diagnostics about the cause of throttling.

✅ Verification successful

Warning system implementation is appropriate for the use case

The warning system in the dataloader is well-implemented and sufficient for its purpose:

  • The warning at line 252 provides clear timing information: log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")
  • The 1-second threshold is reasonable as it's calibrated based on typical step times ("Even for Multi-Task training, each step usually takes < 1s")
  • The warning includes both the fact that data loading is slow and the exact wait time for diagnostics

Additional diagnostic details would add unnecessary complexity since:

  1. The root cause (I/O bottleneck) is clear from the wait time
  2. The warning already triggers at an appropriate threshold
  3. The message is actionable - users can adjust batch sizes or I/O parameters
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that the warning system is properly triggered in different scenarios

# Test: Search for all warning-related code to ensure comprehensive coverage
rg -n "warning|warn" --type py

# Test: Check for any existing performance-related issues or discussions
gh issue list --label "performance"
gh pr list --label "performance"

Length of output: 10444

deepmd/pt/train/training.py (1)

1056-1060: LGTM! Clean data selection logic.

The code cleanly handles the selection between training and validation data using tuple unpacking, which improves readability.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (3)
deepmd/pt/utils/dataloader.py (3)

5-7: LGTM! Consider adding docstring to the new function.

The addition of construct_dataset function and the use of partial improves code modularity by extracting the dataset construction logic. Consider adding a docstring to document the parameters and return type.

 def construct_dataset(system, type_map):
+    """Construct a DeepmdDataSetForLoader instance.
+
+    Parameters
+    ----------
+    system : str
+        Path to the system data
+    type_map : list
+        Mapping of atom types
+
+    Returns
+    -------
+    DeepmdDataSetForLoader
+        Dataset instance for the given system
+    """
     return DeepmdDataSetForLoader(
         system=system,
         type_map=type_map,
     )

Also applies to: 60-64


Line range hint 102-112: LGTM! Consider adding progress feedback for large datasets.

The parallel dataset construction using Pool is a good improvement. For better user experience when loading large datasets, consider adding progress feedback.

         construct_dataset_systems = partial(construct_dataset, type_map=type_map)
+        if len(systems) >= 100:
+            from tqdm import tqdm
+            with Pool(...) as pool:
+                self.systems = list(tqdm(
+                    pool.imap(construct_dataset_systems, systems),
+                    total=len(systems),
+                    desc="Constructing datasets"
+                ))
-        with Pool(...) as pool:
-            self.systems = pool.map(construct_dataset_systems, systems)

237-237: Consider making the warning threshold configurable.

While the warning for slow data loading is helpful, the 1-second threshold might not be suitable for all scenarios. Consider making this threshold configurable.

 class BufferedIterator:
     def __init__(self, iterable) -> None:
         self._queue = Queue(QUEUESIZE)
         self._iterable = iterable
         self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
         self.len = len(iterable)
+        self.warning_threshold = 1.0  # seconds
 
     def __next__(self):
         start_wait = time.time()
         item = self._queue.get()
         wait_time = time.time() - start_wait
-        if wait_time > 1.0:
+        if wait_time > self.warning_threshold:
             log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")

Also applies to: 246-252

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 3895c94 and ddc39de.

📒 Files selected for processing (1)
  • deepmd/pt/utils/dataloader.py (5 hunks)
🔇 Additional comments (3)
deepmd/pt/utils/dataloader.py (3)

196-210: LGTM! Good handling of distributed output.

The rank check in print_summary prevents duplicate output in distributed training, which is the correct approach.


214-226: LGTM! Good simplification of the consumer implementation.

The changes improve the code by:

  • Removing the unused max_len parameter
  • Using StopIteration for end-of-data signaling, which better aligns with Python's iterator protocol

233-237: Verify thread safety of the iterator implementation.

Let's verify that the iterator implementation is thread-safe, particularly around the len attribute access.

✅ Verification successful

Thread safety implementation is correct and well-designed

The implementation is thread-safe due to several key design choices:

  • The Queue is properly used as a thread-safe producer-consumer pattern
  • Length is captured once during initialization and stored as an immutable attribute
  • BackgroundConsumer runs as a daemon thread with clean separation of concerns
  • The iterator implementation uses blocking queue operations which handle synchronization
  • No shared mutable state exists between the consumer thread and main thread
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for potential race conditions in iterator usage
ast-grep --pattern 'class $_ {
  $$$
  def __iter__($$$) {
    $$$
  }
  $$$
  def __len__($$$) {
    $$$
  }
  $$$
}'

Length of output: 120


Script:

#!/bin/bash
# Let's check the implementation of BackgroundConsumer and related iterator code
rg -A 10 "class BackgroundConsumer" 

# Also check the implementation of the class containing this Queue setup
rg -B 5 -A 10 "self._queue = Queue\(QUEUESIZE\)"

# Check if there are any other references to len or iterator access
rg "def __len__" deepmd/pt/utils/dataloader.py
rg "def __iter__" deepmd/pt/utils/dataloader.py

Length of output: 1817

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
deepmd/pt/train/training.py (1)

1056-1071: LGTM! Clean and efficient data loading implementation.

The refactored code is well-structured and handles both single-task and multi-task scenarios elegantly. The improvements include:

  • Cleaner data/dataloader selection using tuple unpacking
  • Clear handling of validation data absence
  • Proper error handling for dataset iteration

Consider adding type hints to improve code maintainability:

-    def get_data(self, is_train=True, task_key="Default"):
+    def get_data(
+        self,
+        is_train: bool = True,
+        task_key: str = "Default"
+    ) -> tuple[dict, dict, dict]:
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between ddc39de and 1310691.

📒 Files selected for processing (1)
  • deepmd/pt/train/training.py (1 hunks)

deepmd/pt/train/training.py Outdated Show resolved Hide resolved
Co-authored-by: Han Wang <[email protected]>
Signed-off-by: Chun Cai <[email protected]>
Copy link

codecov bot commented Nov 28, 2024

Codecov Report

Attention: Patch coverage is 84.21053% with 3 lines in your changes missing coverage. Please review.

Project coverage is 83.75%. Comparing base (037cf3f) to head (a199c44).
Report is 10 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/utils/path.py 66.66% 2 Missing ⚠️
deepmd/pt/utils/dataloader.py 92.30% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4437      +/-   ##
==========================================
+ Coverage   83.27%   83.75%   +0.47%     
==========================================
  Files         667      667              
  Lines       61446    61525      +79     
  Branches     3486     3487       +1     
==========================================
+ Hits        51167    51528     +361     
+ Misses       9151     8872     -279     
+ Partials     1128     1125       -3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

deepmd/pt/train/training.py Outdated Show resolved Hide resolved
deepmd/utils/path.py Outdated Show resolved Hide resolved
@njzjz njzjz added the Test CUDA Trigger test CUDA workflow label Nov 28, 2024
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Nov 28, 2024
@njzjz njzjz added the Test CUDA Trigger test CUDA workflow label Nov 28, 2024
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Nov 28, 2024
deepmd/pt/train/training.py Outdated Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
deepmd/pt/train/training.py (1)

1066-1072: Add a clarifying comment for StopIteration handling.

The StopIteration handling logic ensures continuous training across epochs, but this important detail isn't immediately obvious.

Add a comment to explain the purpose:

        try:
            batch_data = next(iter(data))
        except StopIteration:
-            # Refresh the status of the dataloader to start from a new epoch
+            # When all data in the current epoch is exhausted, refresh the dataloader
+            # to start a new epoch. This ensures continuous training across epochs.
            with torch.device("cpu"):
                data = BufferedIterator(iter(dataloader))
            batch_data = next(iter(data))
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between a1506af and 91d2b9c.

📒 Files selected for processing (1)
  • deepmd/pt/train/training.py (1 hunks)
🔇 Additional comments (1)
deepmd/pt/train/training.py (1)

1056-1065: LGTM! But verify test coverage for multi-task handling.

The refactoring improves code readability by using tuple unpacking and clear conditional logic. However, based on past review comments, ensure that the multi-task handling at line 1062 is properly covered by tests.

@caic99
Copy link
Member Author

caic99 commented Nov 29, 2024

I'm reverting changes to training.py for there are too many changes in one PR, making it hard to test.

@caic99 caic99 requested a review from njzjz November 29, 2024 02:27
@caic99 caic99 requested a review from wanghan-iapcm November 29, 2024 05:22
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (2)
deepmd/pt/utils/dataloader.py (2)

Line range hint 102-112: Add error handling for parallel dataset construction

While the parallel construction using Pool is a good optimization, consider adding error handling to gracefully handle failures during dataset construction. This is especially important when dealing with multiple systems in parallel.

 with Pool(
     os.cpu_count()
     // (
         int(os.environ["LOCAL_WORLD_SIZE"])
         if dist.is_available() and dist.is_initialized()
         else 1
     )
 ) as pool:
-    self.systems = pool.map(construct_dataset_systems, systems)
+    try:
+        self.systems = pool.map(construct_dataset_systems, systems)
+    except Exception as e:
+        log.error(f"Failed to construct datasets in parallel: {str(e)}")
+        # Fallback to sequential construction
+        self.systems = [construct_dataset_systems(system) for system in systems]

247-256: Good improvement on warning messages and throttling!

The enhanced warning system with throttling addresses the concern about too many warnings. However, consider making the warning threshold configurable.

 class BufferedIterator:
     def __init__(self, iterable) -> None:
         self._queue = Queue(QUEUESIZE)
         self._iterable = iterable
         self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
         self.last_warning_time = time.time()
         self.len = len(iterable)
+        self.warning_threshold = 1.0  # seconds
+
+    @property
+    def warning_threshold(self) -> float:
+        return self._warning_threshold
+
+    @warning_threshold.setter
+    def warning_threshold(self, value: float) -> None:
+        if value <= 0:
+            raise ValueError("Warning threshold must be positive")
+        self._warning_threshold = value

     def __next__(self):
         start_wait = time.time()
         item = self._queue.get()
         wait_time = time.time() - start_wait
         if (
-            wait_time > 1.0 and start_wait - self.last_warning_time > 15 * 60
+            wait_time > self.warning_threshold and start_wait - self.last_warning_time > 15 * 60
         ):
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 91d2b9c and 68dd750.

📒 Files selected for processing (1)
  • deepmd/pt/utils/dataloader.py (5 hunks)
🔇 Additional comments (1)
deepmd/pt/utils/dataloader.py (1)

Line range hint 5-64: Good abstraction of dataset construction logic!

The introduction of construct_dataset function and use of partial from functools shows good separation of concerns and functional programming practices. This makes the code more maintainable and testable.

deepmd/pt/utils/dataloader.py Show resolved Hide resolved
@caic99 caic99 requested a review from wanghan-iapcm December 2, 2024 05:25
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (2)
deepmd/pt/utils/dataloader.py (2)

227-231: Consider making queue size configurable

The QUEUESIZE is hardcoded to 32. Consider making it configurable through constructor parameters to allow tuning based on memory constraints and performance requirements.

 class BufferedIterator:
-    def __init__(self, iterable) -> None:
+    def __init__(self, iterable, queue_size: int = 32) -> None:
         self._queue = Queue(QUEUESIZE)
         self._iterable = iterable
         self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
         self.last_warning_time = time.time()

240-249: Consider making warning thresholds configurable

The warning thresholds (1.0 second wait time and 15 minutes cooldown) are hardcoded. These values might need adjustment based on different workloads and environments.

 class BufferedIterator:
-    def __init__(self, iterable) -> None:
+    def __init__(
+        self,
+        iterable,
+        wait_threshold: float = 1.0,
+        warning_cooldown: float = 900.0  # 15 minutes
+    ) -> None:
         self._queue = Queue(QUEUESIZE)
         self._iterable = iterable
         self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
         self.last_warning_time = time.time()
+        self._wait_threshold = wait_threshold
+        self._warning_cooldown = warning_cooldown

     def __next__(self):
         start_wait = time.time()
         item = self._queue.get()
         wait_time = time.time() - start_wait
         if (
-            wait_time > 1.0 and start_wait - self.last_warning_time > 15 * 60
+            wait_time > self._wait_threshold and
+            start_wait - self.last_warning_time > self._warning_cooldown
         ):
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 68dd750 and a199c44.

📒 Files selected for processing (1)
  • deepmd/pt/utils/dataloader.py (2 hunks)
🔇 Additional comments (2)
deepmd/pt/utils/dataloader.py (2)

8-10: LGTM: Import change improves clarity

The explicit import of Queue is clearer than the generic queue import.


208-220: Consider improving edge case handling

The comment "should not happen for DataLoader" suggests an edge case that should be properly documented or handled. Consider:

  1. Adding documentation explaining why this case shouldn't occur
  2. Adding proper error handling if it does occur unexpectedly

Let's verify if this edge case is handled elsewhere:

Consider this improvement:

 def run(self) -> None:
+    """Background consumer thread that feeds items from source to queue.
+    
+    Note: StopIteration is used as a sentinel to signal completion, though
+    this should not occur during normal DataLoader operation as it typically
+    cycles indefinitely.
+    """
     for item in self._source:
         self._queue.put(item)  # Blocking if the queue is full
     # Signal the consumer we are done; this should not happen for DataLoader
     self._queue.put(StopIteration())
✅ Verification successful

Let me gather more context about the BackgroundConsumer usage to understand how StopIteration is handled by the consumer.


Let me check how the consumer side handles the StopIteration.


Let me check the complete implementation of the consumer class and its usage in the BufferedIterator.


The StopIteration handling is correct and well-designed

The comment "should not happen for DataLoader" is actually accurate and doesn't indicate a problem. Here's why:

  1. The BackgroundConsumer is used within BufferedIterator, which wraps an iterable (DataLoader)
  2. DataLoaders in PyTorch are designed to be infinite iterators during training - they automatically restart when reaching the end of an epoch
  3. The StopIteration signal is properly handled by the consumer-producer pattern:
    • Producer (BackgroundConsumer) signals completion with StopIteration
    • Consumer (BufferedIterator.next) checks if the item is an Exception and raises it

The current implementation follows a standard producer-consumer pattern and handles the edge case appropriately. The comment serves as documentation for maintainers to understand the expected behavior.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for StopIteration handling in the codebase
rg -A 3 "StopIteration" --type py

Length of output: 5685


Script:

#!/bin/bash
# Search for BackgroundConsumer usage and its consumer side
rg -B 3 -A 10 "BackgroundConsumer" --type py

Length of output: 3197


Script:

#!/bin/bash
# Search for the __next__ implementation in the same file
rg -B 2 -A 10 "__next__" deepmd/pt/utils/dataloader.py

Length of output: 587


Script:

#!/bin/bash
# Get the complete implementation of BackgroundConsumer and BufferedIterator
cat deepmd/pt/utils/dataloader.py

Length of output: 10011

@caic99
Copy link
Member Author

caic99 commented Dec 6, 2024

I've trained 100k steps from scratch. Since this PR only changes the warning mechanism of DataLoader, the loss curve does not change significantly. The deviation is not introduced in this PR - I've attached the curve of two runs on devel branch, and they are not exactly the same.

Run 1 on devel branch

image

Run 2 on devel branch

image

This PR

image

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Dec 9, 2024
Merged via the queue into deepmodeling:devel with commit b4ade5c Dec 9, 2024
60 checks passed
@caic99 caic99 deleted the refactor-dl branch December 9, 2024 04:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants