diff --git a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py index 706746590b..b61d98271c 100644 --- a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +++ b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py @@ -66,7 +66,18 @@ def upgrade() -> None: sa.TIMESTAMP(timezone=True), nullable=True, ), - sa.CheckConstraint("password_hash is null or password_salt is not null", name="salt"), + sa.CheckConstraint( + "(password_hash IS NULL) = (password_salt IS NULL)", + name="password_hash_and_salt", + ), + sa.CheckConstraint( + "(oauth2_client_id IS NULL) = (oauth2_user_id IS NULL)", + name="oauth2_client_id_and_user_id", + ), + sa.CheckConstraint( + "password_hash IS NULL or oauth2_client_id IS NULL", + name="at_most_one_auth_method", + ), sa.UniqueConstraint( "oauth2_client_id", "oauth2_user_id", diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 27cd868447..bdf12cd4b3 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -690,7 +690,18 @@ def _auth_method_expression(cls) -> ColumnElement[Optional[str]]: ) __table_args__ = ( - CheckConstraint("password_hash is null or password_salt is not null", name="salt"), + CheckConstraint( + "(password_hash IS NULL) = (password_salt IS NULL)", + name="password_hash_and_salt", + ), + CheckConstraint( + "(oauth2_client_id IS NULL) = (oauth2_user_id IS NULL)", + name="oauth2_client_id_and_user_id", + ), + CheckConstraint( + "password_hash IS NULL or oauth2_client_id IS NULL", + name="at_most_one_auth_method", + ), UniqueConstraint( "oauth2_client_id", "oauth2_user_id",