
Sub-byte custom element types #63

Open
hawkinsp opened this issue May 3, 2023 · 2 comments

Comments

@hawkinsp

hawkinsp commented May 3, 2023

I'm wondering if anyone has considered sub-byte user types as part of the API design. The assumption that element sizes are an integer number of bytes seems to be baked into the current user DType design.

For example, consider a packed int4 type, whose elements are nibbles (half-bytes). It doesn't seem possible to express such a type in the user dtype API, e.g., the type descriptor requires a byte size for each element, and ufuncs require byte strides. I was toying with the idea of trying to write such a type, but it doesn't seem to be possible without API changes.
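To make the layout concrete, here is a minimal NumPy sketch of a packed int4 representation. It assumes two's-complement nibbles stored low nibble first; that ordering is an assumption, and other libraries may choose differently. Element i starts at bit 4*i, so no integer byte itemsize or byte stride can describe it. The names pack_int4 and unpack_int4 are hypothetical helpers, not part of NumPy or any dtype API.

```python
import numpy as np

def pack_int4(values) -> np.ndarray:
    """Pack signed int4 values (range [-8, 7]) two per byte, low nibble first."""
    v = np.asarray(values, dtype=np.int8)
    if v.size and (v.min() < -8 or v.max() > 7):
        raise ValueError("int4 values must lie in [-8, 7]")
    # Keep only the low 4 bits (two's complement), as unsigned bytes.
    nibbles = (v & 0xF).astype(np.uint8)
    if nibbles.size % 2:
        nibbles = np.append(nibbles, np.uint8(0))  # pad to an even count
    return nibbles[0::2] | (nibbles[1::2] << 4)

def unpack_int4(packed: np.ndarray, n: int) -> np.ndarray:
    """Recover the first n signed int4 values from a packed uint8 buffer."""
    p = np.asarray(packed, dtype=np.uint8)
    nibbles = np.empty(2 * p.size, dtype=np.uint8)
    nibbles[0::2] = p & 0xF  # element 2k lives in the low nibble of byte k
    nibbles[1::2] = p >> 4   # element 2k+1 lives in the high nibble of byte k
    signed = nibbles[:n].astype(np.int8)
    # Sign-extend: nibbles 8..15 represent -8..-1.
    return np.where(signed >= 8, signed - 16, signed).astype(np.int8)

values = [1, -2, 7, -8, 3]
packed = pack_int4(values)
print(packed.tolist(), packed.nbytes)             # [225, 135, 3] 3
print(unpack_int4(packed, len(values)).tolist())  # [1, -2, 7, -8, 3]
```

Five elements occupy three bytes here, which is exactly the layout a byte-sized itemsize cannot express.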

I see this was mentioned briefly in NEP 41 (as int2), but I can't find any further discussion. Has there been any more thinking along those lines?

@seberg
Member

seberg commented May 3, 2023

Well, there was always the issue of ABI. Now we can change the DType ABI, at least in 2.0, and I am planning to do so to make additions easier. With that, representing a bit-sized dtype should be relatively straightforward. Maybe via:

  • A flag, or even a negative size.
  • Keep byte-size, but additionally indicate a bit size.
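A minimal sketch of how the two options could coexist, assuming a hypothetical descriptor with an elsize field (bytes, as today) and an optional bitsize field. ElementLayout, bits_per_element and nbytes are illustrative names only, not NumPy's actual descriptor struct or API.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ElementLayout:
    """Hypothetical descriptor fields; not NumPy's real descriptor layout."""
    elsize: int       # byte size, as in today's descriptors
    bitsize: int = 0  # option 2: extra field; 0 means "byte-sized only"

    @property
    def bits_per_element(self) -> int:
        if self.bitsize:          # option 2: an explicit bit size wins
            return self.bitsize
        if self.elsize < 0:       # option 1: a negative elsize encodes bits
            return -self.elsize
        return 8 * self.elsize    # ordinary byte-sized dtype

    def nbytes(self, n: int) -> int:
        """Bytes needed to store n densely packed elements (rounded up)."""
        return (n * self.bits_per_element + 7) // 8

int8 = ElementLayout(elsize=1)
int4_negative = ElementLayout(elsize=-4)          # option 1
int4_extra = ElementLayout(elsize=1, bitsize=4)   # option 2
print(int8.nbytes(5), int4_negative.nbytes(5), int4_extra.nbytes(5))  # 5 3 3
```

One appeal of option 2 in this sketch is that code which only reads elsize still sees a valid (padded, 1-byte) element size, whereas a negative elsize breaks any code that assumes it is positive.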

The bigger problem is how you actually work with it. We pass dtypes and some state to ufuncs, casts, etc. Would a dtype indicating a bit size always be passed bit pointers and bit strides?

Right now, we pass around pointers such as char *. On 32-bit systems, I think you would have to change this to a 64-bit number: strides and array sizes seem acceptable at 32 bits, but a 32-bit pointer won't do, since addressing individual bits takes three more bits than addressing bytes.
We could do that, but would it be annoying for normal byte loops?
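The arithmetic behind the 32-bit concern, as a small sketch: if a bit pointer is modeled as byte address × 8 plus a bit offset (one possible encoding, assumed here for illustration), the last bit of a full 4 GiB address space needs 35 bits, so it cannot fit in a 32-bit pointer. bit_address and split_bit_address are hypothetical helper names.

```python
BITS_PER_BYTE = 8

def bit_address(byte_addr: int, bit_offset: int) -> int:
    """Combine a byte address and a bit offset (0-7) into one bit address."""
    if not 0 <= bit_offset < BITS_PER_BYTE:
        raise ValueError("bit_offset must be in [0, 8)")
    return byte_addr * BITS_PER_BYTE + bit_offset

def split_bit_address(bit_addr: int) -> tuple[int, int]:
    """Inverse of bit_address: (byte address, bit offset within that byte)."""
    return divmod(bit_addr, BITS_PER_BYTE)

last_byte_32 = 2**32 - 1                 # highest byte a 32-bit pointer reaches
last_bit = bit_address(last_byte_32, 7)  # highest bit in that address space
print(last_bit.bit_length())             # 35: does not fit in 32 bits
print(split_bit_address(last_bit) == (last_byte_32, 7))  # True
```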

So, a lot of questions, and I don't have all the answers:

  • Let's say we allow the NumPy array to store such a dtype directly; how would you do that? Still use a normal pointer that remains valid, but indicate the bit offset somewhere (e.g. as a last stride)?
  • The main question: how do we make casts and ufunc signatures work with it decently? Let's not worry about NumPy itself there!
    • I would like your int4 to work very transparently as an Int4(bitsize=8), so that you could easily store it in a byte-strided array. But to provide that convenience, your dtype might need both bit and byte loops?
    • Or would you actually have bit- and byte-loops with distinct signatures? We would just have a separate mechanism to fetch the bit-sized versions (or a flag asking for it?). (We could indicate it additionally, so you could have a single loop in principle.)
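To illustrate how a single loop could serve both layouts, here is a sketch of a unary int4 negation inner loop that takes a starting bit offset and a bit stride for each operand, roughly the "bit offset as a last stride" idea. With a bit stride of 4 it walks a packed buffer; with a bit stride of 8 it walks an Int4(bitsize=8)-style padded buffer. The loop signature, the low-nibble-first ordering, and the requirement that offsets be multiples of 4 are all assumptions; this is not NumPy's ufunc inner-loop API.

```python
import numpy as np

def read_int4(buf: np.ndarray, bit_addr: int) -> int:
    """Read a signed int4 at bit_addr (assumes bit_addr % 4 == 0, low nibble first)."""
    byte_index, bit_offset = divmod(bit_addr, 8)
    nibble = (int(buf[byte_index]) >> bit_offset) & 0xF
    return nibble - 16 if nibble >= 8 else nibble

def write_int4(buf: np.ndarray, bit_addr: int, value: int) -> None:
    """Overwrite the 4 bits at bit_addr, leaving the rest of the byte untouched."""
    byte_index, bit_offset = divmod(bit_addr, 8)
    mask = 0xF << bit_offset
    kept = int(buf[byte_index]) & ~mask & 0xFF
    buf[byte_index] = kept | ((value & 0xF) << bit_offset)

def negative_int4_loop(n, src, src_offset, src_stride, dst, dst_offset, dst_stride):
    """Hypothetical inner loop: offsets and strides are in bits, not bytes."""
    for i in range(n):
        v = read_int4(src, src_offset + i * src_stride)
        # Wraps like a 4-bit integer: -(-8) is stored as -8.
        write_int4(dst, dst_offset + i * dst_stride, -v)

packed = np.array([0x21, 0xF3], dtype=np.uint8)  # int4 values 1, 2, 3, -1
padded = np.zeros(4, dtype=np.uint8)             # one int4 per byte (bitsize=8)
negative_int4_loop(4, packed, 0, 4, padded, 0, 8)
print([read_int4(padded, 8 * i) for i in range(4)])  # [-1, -2, -3, 1]
```

In this sketch the padding bits of each padded byte keep their previous value (zero here), so the bytes are not sign-extended int8 values; how a padded Int4 should fill those bits is left open.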

I also don't think you could update all of NumPy's functionality on a reasonable timeline without hiring one or two people dedicated to it. So we could make additions to the array object to allow it in principle, and even get it to work with most things. But I doubt you will get it to work with everything quickly (which is fine by me; unsupported paths could just raise NotImplementedError).

Anyway, questions... The interesting part might be if we can formulate anything that affects the array object or the inner-loop signatures for ufuncs/casts.

@hawkinsp
Author

jax-ml/ml_dtypes#71 is a prototype of adding a padded int4 NumPy type to ml_dtypes. Since we cannot represent sub-byte types, the best we can do is use an int4 type where each element is padded up to a byte.

This is good enough for my use case, although the padding makes me a little bit sad.
