Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various fixes for IdentityLinearOperator with different input/output structures #50

Merged
merged 1 commit into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,21 +681,37 @@ def __init__(
output_structure = _inexact_structure(output_structure)
self.input_structure = jtu.tree_flatten(input_structure)
self.output_structure = jtu.tree_flatten(output_structure)
if self.in_size() != self.out_size():
raise ValueError(
"input and output structures must have the same number of elements."
)

def mv(self, vector):
if jax.eval_shape(lambda: vector) != self.in_structure():
raise ValueError("Vector and operator structures do not match")
return vector
elif self.input_structure == self.output_structure:
return vector # fast-path for common special case
else:
# TODO(kidger): this could be done slightly more efficiently, by iterating
# leaf-by-leaf.
leaves = jtu.tree_leaves(vector)
dtype = jnp.result_type(*leaves)
vector = jnp.concatenate([x.astype(dtype).reshape(-1) for x in leaves])
out_size = self.out_size()
if vector.size < out_size:
vector = jnp.concatenate(
[vector, jnp.zeros(out_size - vector.size, vector.dtype)]
)
else:
vector = vector[:out_size]
leaves, treedef = jtu.tree_flatten(self.out_structure())
sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]])
split = jnp.split(vector, sizes)
assert len(split) == len(leaves)
shaped = [x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves)]
return jtu.tree_unflatten(treedef, shaped)

def as_matrix(self):
return jnp.eye(self.in_size())
return jnp.eye(self.out_size(), self.in_size())

def transpose(self):
return self
return IdentityLinearOperator(self.out_structure(), self.in_structure())

def in_structure(self):
leaves, treedef = self.input_structure
Expand Down Expand Up @@ -1697,7 +1713,6 @@ def _(operator):

@diagonal.register(TaggedLinearOperator)
def _(operator):
# Untagged; we might not have any of the properties our tags represent any more.
return diagonal(operator.operator)


Expand Down
29 changes: 29 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,32 @@ def run(diag):
return out.value

jax.jvp(run, (diag,), (t_diag,))


def test_identity_with_different_structures():
structure1 = (
jax.ShapeDtypeStruct((), jnp.float32),
jax.ShapeDtypeStruct((2, 3), jnp.float16),
)
structure2 = {"a": jax.ShapeDtypeStruct((5,), jnp.float32)}
# structure3 = (None, jax.ShapeDtypeStruct((2, 3), jnp.float16))
op1 = lx.IdentityLinearOperator(structure1, structure2)
op2 = lx.IdentityLinearOperator(structure2, structure1)
# op3 = lx.IdentityLinearOperator(structure3, structure2)

assert op1.T == op2
# assert op2.transpose((True, False)) == op3
assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7))
assert op1.in_size() == 7
assert op1.out_size() == 5
vec1 = (
jnp.array(1.0, dtype=jnp.float32),
jnp.array([[2, 3, 4], [5, 6, 7]], dtype=jnp.float16),
)
vec2 = {"a": jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=jnp.float32)}
vec1b = (
jnp.array(1.0, dtype=jnp.float32),
jnp.array([[2, 3, 4], [5, 0, 0]], dtype=jnp.float16),
)
assert shaped_allclose(op1.mv(vec1), vec2)
assert shaped_allclose(op2.mv(vec2), vec1b)
Loading