Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Stats compute perf #449

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open

Conversation

Kh4L
Copy link

@Kh4L Kh4L commented Sep 9, 2024

This PR improves performance of StatType.compute by calling np.hstack only when needed.

Before change:

$ time python  gnn_node.py --dataset=rel-hm --task=user-churn --epochs 1 --max_steps_per_epoch 250
[...]
real    3m17.517s
user    3m19.252s
sys     0m6.757s

After change:

$ time python  gnn_node.py --dataset=rel-hm --task=user-churn --epochs 1 --max_steps_per_epoch 250
[...]
real    1m43.020s
user    1m45.716s
sys     0m6.022s

@weihua916
Copy link
Contributor

Thanks.
@Kh4L could you pass the test? @akihironitta could you help review/update this PR?

Copy link
Member

@akihironitta akihironitta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @Kh4L, this PR looks great so far! 🔥 Would you still be interested in finishing this PR?

Comment on lines +88 to +89
val = np.hstack(ser.values) if ser.values.ndim > 1 else ser.values
flattened = np.hstack(val) if val.ndim > 1 else val
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the chagne doesn't take account of the sequence_numerical case:

self = <StatType.MEAN: 'MEAN'>
ser = 0                                                 [nan]
1         [0.3309801272693331, nan, 0.3018088119575514]
2     ...42, nan, 0.9997390177887806, ...
19                            [nan, 0.36[1182](https://github.com/pyg-team/pytorch-frame/pull/449/checks#step:6:1183)4486676457]
Name: seq_num_1, dtype: object
sep = None

    def compute(
        self,
        ser: Series,
        sep: str | None = None,
    ) -> Any:
        if self == StatType.MEAN:
            val = np.hstack(ser.values) if ser.values.ndim > 1 else ser.values
            flattened = np.hstack(val) if val.ndim > 1 else val
>           finite_mask = np.isfinite(flattened)
E           TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''

torch_frame/data/stats.py:90: TypeError

https://github.com/pyg-team/pytorch-frame/actions/runs/12376953024/job/34545299882

@akihironitta akihironitta self-assigned this Dec 17, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants