-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Split #16
base: main
Are you sure you want to change the base?
Split #16
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ | |
// limitations under the License. | ||
|
||
extension ShapedArray { | ||
|
||
@inlinable | ||
public func reshaped(to newShape: [Int]) -> ShapedArray { | ||
let shape = { | ||
if newShape.contains(-1) { | ||
|
@@ -29,16 +31,18 @@ extension ShapedArray { | |
return newShape | ||
} | ||
}() | ||
|
||
precondition(shape.reduce(1, *) == self.scalarCount, "Cannot reshape to shape \(shape) because it has a different number of elements than the original shape \(self.shape).") | ||
|
||
return .init(shape: shape, scalars: self.scalars) | ||
} | ||
|
||
@inlinable | ||
public func reshaped(to newShape: Int...) -> ShapedArray { | ||
self.reshaped(to: newShape) | ||
} | ||
|
||
@inlinable | ||
public func reshaped<T>(like other: ShapedArray<T>) -> ShapedArray { | ||
self.reshaped(to: other.shape) | ||
} | ||
|
@@ -49,6 +53,115 @@ extension ShapedArray { | |
self.reshaped(to: -1) | ||
} | ||
|
||
// helper data structure for subarray generation | ||
@usableFromInline | ||
internal struct SubarrayIndices { | ||
@usableFromInline | ||
let start: [Int] | ||
@usableFromInline | ||
let end: [Int] | ||
} | ||
|
||
// helper function for subarray indices generation | ||
/// takes a shape for the original array, a shape for the subarray, | ||
/// and an axis on which the original array would be divided | ||
/// and returns an array of start and end indices for each subarray | ||
|
||
@usableFromInline | ||
internal func _calculateSubarrayIndices(shape: [Int], subarrayShape: [Int], axis: Int) -> [SubarrayIndices] { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The only use of the However, i think we can improve on the name more too. Possibly slicedSubArrayIndices(forShape shape: [Int], subarrayShape: [Int], alongAxis axis: Int) -> [SubarrayIndices] ? Lastly, and maybe this isn't possible with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. or maybe |
||
let numberOfSubarrays = shape[axis] / subarrayShape[axis] | ||
|
||
var subarrays = [SubarrayIndices]() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Trying to think of a name that better captures this type, maybe There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
for i in 0..<numberOfSubarrays { | ||
var startIndices = [Int](repeating: 0, count: shape.count) | ||
var endIndices = shape | ||
|
||
startIndices[axis] = i * subarrayShape[axis] | ||
endIndices[axis] = startIndices[axis] + subarrayShape[axis] | ||
|
||
subarrays.append(SubarrayIndices(start: startIndices, end: endIndices)) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this may look a bit nicer with a map let subarrays = (0..<numberSubarrays).map { i in
var startIndices = [Int](repeating: 0, count: shape.count)
var endIndices = shape
startIndices[axis] = i * subarrayShape[axis]
endIndices[axis] = startIndices[axis] + subarrayShape[axis]
return SubarrayIndices(start: startIndices, end: endIndices)
} |
||
|
||
return subarrays | ||
} | ||
|
||
// helper function for subarray indices generation | ||
/// takes a start and end index for each dimension and returns all the n-dim indices in between | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment as before about doc comments. |
||
@usableFromInline | ||
internal func _generateIndices(start: [Int], end: [Int], currentIndex: [Int] = [], depth: Int = 0) -> [[Int]] { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe allIndices(fromStart start: [Int], toEnd end: [Int], currentIndices: [Int] = [], depth: Int = 0) -> [[Int]] |
||
if depth == start.count { | ||
return [currentIndex] | ||
} | ||
|
||
var indices = [[Int]]() | ||
|
||
for i in start[depth]..<end[depth] { | ||
let newIndex = currentIndex + [i] | ||
let subIndices = _generateIndices(start: start, end: end, currentIndex: newIndex, depth: depth + 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This isn't tail recursive so there's no tail-call optimization. Granted probably plenty to optimize in the future, but maybe add a comment here calling that out explicitly? |
||
indices += subIndices | ||
} | ||
return indices | ||
} | ||
|
||
// helper function for subarray indices generation | ||
/// takes a shape, strides, and n-dim indices and returns the linear index | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here |
||
@usableFromInline | ||
internal func _calculateLinearIndex(shape: [Int], strides: [Int], indices: [Int]) -> Int { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe linearIndexFrom(shape: [Int], strides: [Int], indices: [Int]) -> Int ? |
||
var linearIndex = 0 | ||
for i in 0..<shape.count { | ||
let dimSize = shape[i] | ||
let stride = strides[i] | ||
let index = indices[i] | ||
if index >= dimSize { | ||
fatalError("Index out of bounds") | ||
} | ||
linearIndex += stride * index | ||
} | ||
return linearIndex | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. indentation is very off here... |
||
} | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. only 1 newline please |
||
/// Splits the ShapedArray into multiple subarrays along the given axis. | ||
/// - Parameters: | ||
/// - count: The number of subarrays to return. | ||
/// - axis: The axis along which to split the ShapedArray. Negative values wrap around. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
/// - Returns: An array of ShapedArrays. | ||
/// - Precondition: `count` must evenly divide the size of the ShapedArray along the given axis. | ||
/// - Precondition: `axis` must be in the range `[-rank, rank)`. | ||
@inlinable | ||
public func split(count: Int, alongAxis axis: Int = 0) -> [ShapedArray] { | ||
ensureValid(axis: axis) | ||
let axis = axis < 0 ? axis + self.rank : axis | ||
|
||
let newShape = self.shape.enumerated().map { $0.0 == axis ? $0.1 / count : $0.1 } | ||
let scalarsPerArray = newShape.reduce(1, *) | ||
|
||
// Generate the n-dim start and end indices for each subarray | ||
let indices = _calculateSubarrayIndices(shape: self.shape, subarrayShape: newShape, axis: axis) | ||
|
||
let newArrays = (0..<count).map { i -> ShapedArray in | ||
|
||
// Generate all the n-dim indices for each subarray using the start and end indices | ||
let allIndices = _generateIndices(start: indices[i].start, end: indices[i].end) | ||
|
||
let scalars = Array<Scalar>(unsafeUninitializedCapacity: scalarsPerArray) { buffer, initializedCount in | ||
for j in 0..<scalarsPerArray { | ||
// Calculate the linear index for each n-dim index | ||
let index = _calculateLinearIndex(shape: self.shape, strides: self.stride, indices: allIndices[j]) | ||
|
||
let value = self.scalars[index] | ||
buffer[j] = value | ||
} | ||
initializedCount = scalarsPerArray | ||
} | ||
return ShapedArray(shape: newShape, scalars: scalars) | ||
} | ||
|
||
return newArrays | ||
} | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. only 1 newline please |
||
/// Unpacks the given dimension of a rank-`R` ShapedArray into multiple rank-`(R-1)` ShapedArrays. | ||
/// Unpacks `N` ShapedArrays from this ShapedArray by chipping it along the `axis` dimension, where `N` | ||
/// is inferred from this ShapedArray's shape. For example, given a ShapedArray with shape | ||
|
@@ -70,15 +183,17 @@ extension ShapedArray { | |
/// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of | ||
/// the provided ShapedArrays. | ||
/// | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. both newlines here should be deleted please, this breaks up the comment |
||
/// - Returns: Array containing the unstacked ShapedArrays. | ||
@inlinable | ||
public func unstacked(alongAxis axis: Int = 0) -> [ShapedArray] { | ||
ensureValid(axis: axis) | ||
let axis = axis < 0 ? axis + self.rank : axis | ||
let numberOfArrays = self.shape[axis] | ||
|
||
let lengthAfterAxis = self.shape[(axis + 1)..<self.shape.count].reduce(1, *) | ||
let lengthAtAxis = self.shape[axis..<self.shape.count].reduce(1, *) | ||
let lengthAfterAxis = self.stride[axis] | ||
let lengthAtAxis = axis == 0 ? self.stride.reduce(1, *) : self.stride[axis - 1] | ||
let newShape = self.shape.enumerated().filter { $0.0 != axis }.map { $0.1 } | ||
let scalarsPerArray = newShape.reduce(1, *) | ||
|
||
|
@@ -99,6 +214,47 @@ extension ShapedArray { | |
|
||
return newArrays.map { ShapedArray(shape: newShape, scalars: $0) } | ||
} | ||
|
||
@inlinable | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some docs we can include:
|
||
public func expandingShape(at axes: [Int]) -> ShapedArray { | ||
var resultShape = self.shape | ||
for i in axes { | ||
var dim = i | ||
if dim < 0 { dim += resultShape.count + 1 } | ||
resultShape.insert(1, at: dim) | ||
} | ||
return self.reshaped(to: resultShape) | ||
} | ||
|
||
@inlinable | ||
public func expandingShape(at axes: Int...) -> ShapedArray { | ||
return self.expandingShape(at: axes) | ||
} | ||
|
||
@inlinable | ||
public func rankLifted() -> ShapedArray { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Possible docs for this:
|
||
return self.expandingShape(at: 0) | ||
} | ||
|
||
/// Removes the specified dimensions of size 1 from the shape of a tensor. If no dimensions are | ||
/// specified, then all dimensions of size 1 will be removed. | ||
@inlinable | ||
public func squeezingShape(at axes: [Int]) -> ShapedArray { | ||
var resultShape = self.shape | ||
for i in 0..<shape.count { | ||
if axes.contains(i) || (axes.isEmpty && shape[i] == 1) { | ||
precondition(shape[i] == 1, "Can't squeeze axis \(i) since its size is not 1") | ||
resultShape.remove(at: i) | ||
} | ||
} | ||
return self.reshaped(to: resultShape) | ||
} | ||
|
||
@inlinable | ||
public func squeezingShape(at axes: Int...) -> ShapedArray { | ||
return self.squeezingShape(at: axes) | ||
} | ||
|
||
} | ||
|
||
//===------------------------------------------------------------------------------------------===// | ||
|
@@ -139,7 +295,7 @@ extension ShapedArray { | |
line: line) | ||
return self.areValid(axes: axes.scalars) | ||
} | ||
|
||
/// Checks that each element of `axes` denotes an axis of `self`, and stops the program with a | ||
/// diagnostic otherwise. | ||
@usableFromInline | ||
|
@@ -155,7 +311,7 @@ extension ShapedArray { | |
file: file, | ||
line: line) | ||
} | ||
|
||
/// Checks that each element of `axes` denotes an axis of `self`, and stops the program with a | ||
/// diagnostic otherwise. | ||
@usableFromInline | ||
|
@@ -171,7 +327,7 @@ extension ShapedArray { | |
file: file, | ||
line: line) | ||
} | ||
|
||
/// Checks that `k` denotes an axis of `self`, and stops the program with a diagnostic otherwise. | ||
@usableFromInline | ||
func ensureValid( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -318,3 +318,14 @@ extension _ShapedArrayProtocol where Scalar: Equatable { | |
} | ||
} | ||
} | ||
|
||
extension _ShapedArrayProtocol where Scalar: Comparable { | ||
internal func _isLess(than other: Self) -> Bool { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see this function used anywhere, is this something we need to keep? |
||
return shape == other.shape | ||
&& withUnsafeBufferPointer { selfBuf in | ||
other.withUnsafeBufferPointer { otherBuf in | ||
selfBuf.lexicographicallyPrecedes(otherBuf) | ||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For doc comments, let's use both starting capitalization and ending punctuation. also
// helper function
is missing an/