Skip to content

Commit

Permalink
Safety on case insensitive non-string sort
Browse files Browse the repository at this point in the history
  • Loading branch information
UnknownJoe796 committed Oct 24, 2023
1 parent f3e255f commit 66e36fa
Show file tree
Hide file tree
Showing 16 changed files with 101 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,29 @@ inline fun <Model : Any> FieldCollection<Model>.interceptCreate(crossinline inte
): Boolean = wraps.upsertOneIgnoringResult(condition, modification, interceptor(model))
}

/**
* Intercept all kinds of creates, including [FieldCollection.insert], [FieldCollection.upsertOne], and [FieldCollection.upsertOneIgnoringResult].
* Allows you to modify the object before it is actually created.
*/
inline fun <Model : Any> FieldCollection<Model>.interceptCreates(crossinline interceptor: suspend (Iterable<Model>) -> List<Model>): FieldCollection<Model> =
object : FieldCollection<Model> by this {
override val wraps = this@interceptCreates
override suspend fun insert(models: Iterable<Model>): List<Model> =
wraps.insertMany(models.let { interceptor(it) })

override suspend fun upsertOne(
condition: Condition<Model>,
modification: Modification<Model>,
model: Model
): EntryChange<Model> = wraps.upsertOne(condition, modification, interceptor(listOf(model)).first())

override suspend fun upsertOneIgnoringResult(
condition: Condition<Model>,
modification: Modification<Model>,
model: Model
): Boolean = wraps.upsertOneIgnoringResult(condition, modification, interceptor(listOf(model)).first())
}

/**
* Intercepts all kinds of replace operations.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class OauthClientEndpoints(
idSerializer = Serialization.module.serializer()
)
override val authOptions: AuthOptions<HasId<*>> get() = maintainPermissions as AuthOptions<HasId<*>>
override fun baseCollection(): FieldCollection<OauthClient> = database().collection<OauthClient>()
override fun collection(): FieldCollection<OauthClient> = database().collection<OauthClient>()

override suspend fun collection(auth: AuthAccessor<HasId<*>>): FieldCollection<OauthClient> = collection()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ class OneTimePasswordProofEndpoints(
it.value.idSerializer as KSerializer<Comparable<Any>>
),
authOptions = Authentication.isAdmin as AuthOptions<HasId<*>>,
getCollection = {
table(it.value).withPermissions(
getBaseCollection = { table(it.value) },
getCollection = { collection ->
collection.withPermissions(
ModelPermissions(
create = Condition.Always(),
read = Condition.Always(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ class PasswordProofEndpoints(
ModelRestEndpoints<HasId<*>, PasswordSecret<Comparable<Any>>, Comparable<Any>>(path("secrets/${it.value.name.lowercase()}"), modelInfo< HasId<*>, PasswordSecret<Comparable<Any>>, Comparable<Any>>(
serialization = ModelSerializationInfo(PasswordSecret.serializer(it.value.idSerializer as KSerializer<Comparable<Any>>), it.value.idSerializer as KSerializer<Comparable<Any>>),
authOptions = Authentication.isAdmin as AuthOptions<HasId<*>>,
getCollection = { table(it.value).withPermissions(ModelPermissions(
getBaseCollection = { table(it.value) },
getCollection = { c -> c.withPermissions(ModelPermissions(
create = Condition.Always(),
read = Condition.Always(),
readMask = Mask(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ class AuthEndpointsForSubject<SUBJECT : HasId<ID>, ID : Comparable<ID>>(
)
)
) + Authentication.isAdmin + Authentication.isSuperUser,
getCollection = {
getBaseCollection = {
database().collection(
sessionSerializer,
"${handler.name}Session"
)
},
getCollection = { it },
forUser = { collection: FieldCollection<Session<SUBJECT, ID>> ->
val requestAuth = this.authOrNull
val canUse: Condition<Session<SUBJECT, ID>> = when {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ import kotlinx.serialization.serializer
@Deprecated("User newer version with auth accessor instead, as it enables more potential optimizations.")
inline fun <reified USER : HasId<*>, reified T : HasId<ID>, reified ID : Comparable<ID>> ModelInfo(
noinline getCollection: () -> FieldCollection<T>,
noinline getBaseCollection: () -> FieldCollection<T> = { getCollection() },
noinline forUser: suspend FieldCollection<T>.(principal: USER) -> FieldCollection<T>,
modelName: String = Serialization.module.serializer<T>().descriptor.serialName.substringBefore('<')
.substringAfterLast('.'),
) = ModelInfo(
serialization = ModelSerializationInfo<T, ID>(),
authOptions = com.lightningkite.lightningserver.auth.authOptions<USER>(),
getCollection = getCollection,
getBaseCollection = getBaseCollection,
forUser = forUser,
modelName = modelName,
)
Expand All @@ -28,11 +30,13 @@ fun <USER : HasId<*>, T : HasId<ID>, ID : Comparable<ID>> ModelInfo(
serialization: ModelSerializationInfo<T, ID>,
authOptions: AuthOptions<USER>,
getCollection: () -> FieldCollection<T>,
getBaseCollection: () -> FieldCollection<T> = { getCollection() },
forUser: suspend FieldCollection<T>.(principal: USER) -> FieldCollection<T>,
modelName: String = serialization.serializer.descriptor.serialName.substringBefore('<').substringAfterLast('.')
) = object : ModelInfo<USER, T, ID> {
override val authOptions: AuthOptions<USER> = authOptions
override val serialization: ModelSerializationInfo<T, ID> = serialization
override fun baseCollection(): FieldCollection<T> = getBaseCollection()
override fun collection(): FieldCollection<T> = getCollection()
override suspend fun collection(auth: AuthAccessor<USER>): FieldCollection<T> = forUser(collection(), auth.user())

Expand All @@ -42,13 +46,15 @@ fun <USER : HasId<*>, T : HasId<ID>, ID : Comparable<ID>> ModelInfo(
fun <USER : HasId<*>?, T : HasId<ID>, ID : Comparable<ID>> modelInfo(
serialization: ModelSerializationInfo<T, ID>,
authOptions: AuthOptions<USER>,
getCollection: () -> FieldCollection<T>,
getBaseCollection: () -> FieldCollection<T>,
getCollection: (collection: FieldCollection<T>) -> FieldCollection<T> = { it },
forUser: suspend AuthAccessor<USER>.(collection: FieldCollection<T>) -> FieldCollection<T> = { it },
modelName: String = serialization.serializer.descriptor.serialName.substringBefore('<').substringAfterLast('.')
) = object : ModelInfo<USER, T, ID> {
override val authOptions: AuthOptions<USER> = authOptions
override val serialization: ModelSerializationInfo<T, ID> = serialization
override fun collection(): FieldCollection<T> = getCollection()
override fun baseCollection(): FieldCollection<T> = getBaseCollection()
override fun collection(): FieldCollection<T> = getCollection(this.baseCollection())
override suspend fun collection(auth: AuthAccessor<USER>): FieldCollection<T> =
auth.forUser(this.collection())

Expand All @@ -72,7 +78,8 @@ interface ModelInfo<USER : HasId<*>?, T : HasId<ID>, ID : Comparable<ID>> {
val collectionName: String
get() = serialization.serializer.descriptor.serialName.substringBefore('<').substringAfterLast('.')

fun collection(): FieldCollection<T>
fun baseCollection(): FieldCollection<T> = collection()
fun collection(): FieldCollection<T> = baseCollection()
suspend fun collection(auth: AuthAccessor<USER>): FieldCollection<T>
suspend fun collection(user: USER): FieldCollection<T> = collection(AuthAccessor.test(user))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ import kotlinx.serialization.serializer
@Deprecated("User newer version with auth accessor instead, as it enables more potential optimizations.")
inline fun <reified USER: HasId<*>?, reified T : HasId<ID>, reified ID : Comparable<ID>> ModelInfoWithDefault(
noinline getCollection: () -> FieldCollection<T>,
noinline getBaseCollection: () -> FieldCollection<T> = { getCollection() },
noinline forUser: suspend FieldCollection<T>.(principal: USER) -> FieldCollection<T>,
modelName: String = Serialization.module.serializer<T>().descriptor.serialName.substringBefore('<').substringAfterLast('.'),
noinline defaultItem: suspend (auth: USER) -> T,
noinline exampleItem: ()->T? = { null },
) = ModelInfoWithDefault(
): ModelInfoWithDefault<USER, T, ID> = ModelInfoWithDefault(
serialization = ModelSerializationInfo<T, ID>(),
authOptions = com.lightningkite.lightningserver.auth.authOptions<USER>(),
getCollection = getCollection,
getBaseCollection = getBaseCollection,
forUser = forUser,
modelName = modelName,
defaultItem = defaultItem,
Expand All @@ -30,13 +32,15 @@ fun <USER: HasId<*>?, T : HasId<ID>, ID : Comparable<ID>> ModelInfoWithDefault(
serialization: ModelSerializationInfo<T, ID>,
authOptions: AuthOptions<USER>,
getCollection: () -> FieldCollection<T>,
getBaseCollection: () -> FieldCollection<T> = { getCollection() },
forUser: suspend FieldCollection<T>.(principal: USER) -> FieldCollection<T>,
modelName: String = serialization.serializer.descriptor.serialName.substringBefore('<').substringAfterLast('.'),
defaultItem: suspend (auth: USER) -> T,
exampleItem: ()->T? = { null },
) = object : ModelInfoWithDefault<USER, T, ID> {
override val authOptions: AuthOptions<USER> = authOptions
override val serialization: ModelSerializationInfo<T, ID> = serialization
override fun baseCollection(): FieldCollection<T> = getBaseCollection()
override fun collection(): FieldCollection<T> = getCollection()
override suspend fun collection(auth: AuthAccessor<USER>): FieldCollection<T> = forUser(collection(), auth.user())

Expand All @@ -48,15 +52,17 @@ fun <USER: HasId<*>?, T : HasId<ID>, ID : Comparable<ID>> ModelInfoWithDefault(
fun <USER: HasId<*>?, T : HasId<ID>, ID : Comparable<ID>> modelInfoWithDefault(
serialization: ModelSerializationInfo<T, ID>,
authOptions: AuthOptions<USER>,
getCollection: () -> FieldCollection<T>,
getBaseCollection: () -> FieldCollection<T>,
getCollection: (collection: FieldCollection<T>) -> FieldCollection<T> = { it },
forUser: suspend AuthAccessor<USER>.(collection: FieldCollection<T>) -> FieldCollection<T> = { it },
modelName: String = serialization.serializer.descriptor.serialName.substringBefore('<').substringAfterLast('.'),
defaultItem: suspend AuthAccessor<USER>.() -> T,
exampleItem: ()->T? = { null },
) = object : ModelInfoWithDefault<USER, T, ID> {
override val authOptions: AuthOptions<USER> = authOptions
override val serialization: ModelSerializationInfo<T, ID> = serialization
override fun collection(): FieldCollection<T> = getCollection()
override fun baseCollection(): FieldCollection<T> = getBaseCollection()
override fun collection(): FieldCollection<T> = getCollection(this.baseCollection())
override suspend fun collection(auth: AuthAccessor<USER>): FieldCollection<T> =
auth.forUser(this.collection())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ class GroupedDatabaseExceptionReporter(val packageName: String, val database: Da
idSerializer = Serialization.module.serializer()
)
override val authOptions: AuthOptions<HasId<*>> get() = Authentication.isDeveloper as AuthOptions<HasId<*>>
override fun collection(): FieldCollection<ReportedExceptionGroup> = database.collection<ReportedExceptionGroup>()
override fun baseCollection(): FieldCollection<ReportedExceptionGroup> = database.collection<ReportedExceptionGroup>()
override fun collection(): FieldCollection<ReportedExceptionGroup> = baseCollection()

override suspend fun collection(auth: AuthAccessor<HasId<*>>): FieldCollection<ReportedExceptionGroup> = collection()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class ExternalAsyncTaskIntegration<REQUEST, RESPONSE : HasId<String>, RESULT>(
serializer = ExternalAsyncTaskRequest.serializer(),
idSerializer = String.serializer()
),
getCollection = {
getBaseCollection = {
database().collection<ExternalAsyncTaskRequest>(name = "$path/ExternalTaskRequest")
},
defaultItem = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ object TestSettings: ServerPathGroup(ServerPath.root) {
info = modelInfo(
authOptions = noAuth,
serialization = ModelSerializationInfo(),
getCollection = { database().collection() },
getBaseCollection = { database().collection() },
forUser = { it },
)
)
Expand All @@ -74,7 +74,7 @@ object TestSettings: ServerPathGroup(ServerPath.root) {
info = modelInfo(
authOptions = noAuth,
serialization = ModelSerializationInfo(),
getCollection = { database().collection() },
getBaseCollection = { database().collection() },
forUser = { it },
),
key = TestThing__id
Expand All @@ -87,7 +87,7 @@ object TestSettings: ServerPathGroup(ServerPath.root) {
}

val userInfo = modelInfoWithDefault<TestUser, TestUser, UUID>(
getCollection = {
getBaseCollection = {
database().collection<TestUser>()
},
defaultItem = { TestUser(email = "") },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.descriptors.PrimitiveKind
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.SchemaUtils.statementsRequiredToActualizeScheme
Expand All @@ -25,7 +26,8 @@ class PostgresCollection<T : Any>(

val table = SerialDescriptorTable(name, serializer.descriptor)

suspend inline fun <T> t(noinline action: suspend Transaction.()->T): T = newSuspendedTransaction(Dispatchers.IO, db = db, transactionIsolation = TRANSACTION_READ_COMMITTED, action)
suspend inline fun <T> t(noinline action: suspend Transaction.() -> T): T =
newSuspendedTransaction(Dispatchers.IO, db = db, transactionIsolation = TRANSACTION_READ_COMMITTED, action)

@OptIn(DelicateCoroutinesApi::class, ExperimentalSerializationApi::class)
val prepare = GlobalScope.async(Dispatchers.Unconfined, start = CoroutineStart.LAZY) {
Expand All @@ -47,7 +49,7 @@ class PostgresCollection<T : Any>(
val items = t {
table
.select { condition(condition, serializer, table).asOp() }
.orderBy(*orderBy.map { (if(it.ignoreCase) (table.col[it.field.colName]!! as Column<String>).lowerCase() else table.col[it.field.colName]!!) to if (it.ascending) SortOrder.ASC else SortOrder.DESC }
.orderBy(*orderBy.map { (if (it.ignoreCase && it.field.serializerAny.descriptor.kind == PrimitiveKind.STRING) (table.col[it.field.colName]!! as Column<String>).lowerCase() else table.col[it.field.colName]!!) to if (it.ascending) SortOrder.ASC else SortOrder.DESC }
.toTypedArray())
.limit(limit, skip.toLong())
// .prep
Expand Down Expand Up @@ -84,7 +86,7 @@ class PostgresCollection<T : Any>(
prepare.await()
return t {
val valueCol = table.col[property.colName] as Column<Number>
val agg = when(aggregate) {
val agg = when (aggregate) {
Aggregate.Sum -> Sum(valueCol, DecimalColumnType(Int.MAX_VALUE, 8))
Aggregate.Average -> Avg<Double, Double>(valueCol, 8)
Aggregate.StandardDeviationSample -> StdDevSamp(valueCol, 8)
Expand All @@ -106,7 +108,7 @@ class PostgresCollection<T : Any>(
return t {
val groupCol = table.col[groupBy.colName] as Column<Key>
val valueCol = table.col[property.colName] as Column<Number>
val agg = when(aggregate) {
val agg = when (aggregate) {
Aggregate.Sum -> Sum(valueCol, DoubleColumnType())
Aggregate.Average -> Avg<Double, Double>(valueCol, 8)
Aggregate.StandardDeviationSample -> StdDevSamp(valueCol, 8)
Expand All @@ -132,14 +134,22 @@ class PostgresCollection<T : Any>(
return updateOneImpl(condition, Modification.Assign(model), orderBy)
}

override suspend fun replaceOneIgnoringResultImpl(condition: Condition<T>, model: T, orderBy: List<SortPart<T>>): Boolean {
override suspend fun replaceOneIgnoringResultImpl(
condition: Condition<T>,
model: T,
orderBy: List<SortPart<T>>
): Boolean {
return updateOneIgnoringResultImpl(condition, Modification.Assign(model), orderBy)
}

override suspend fun upsertOneImpl(condition: Condition<T>, modification: Modification<T>, model: T): EntryChange<T> {
override suspend fun upsertOneImpl(
condition: Condition<T>,
modification: Modification<T>,
model: T
): EntryChange<T> {
return newSuspendedTransaction(db = db, transactionIsolation = TRANSACTION_SERIALIZABLE) {
val existing = findOne(condition)
if(existing == null) {
if (existing == null) {
EntryChange(null, insertImpl(listOf(model)).first())
} else
updateOneImpl(condition, modification)
Expand All @@ -153,7 +163,7 @@ class PostgresCollection<T : Any>(
): Boolean {
return newSuspendedTransaction(db = db, transactionIsolation = TRANSACTION_SERIALIZABLE) {
val existing = findOne(condition)
if(existing == null) {
if (existing == null) {
insertImpl(listOf(model))
false
} else
Expand All @@ -166,7 +176,7 @@ class PostgresCollection<T : Any>(
modification: Modification<T>,
orderBy: List<SortPart<T>>
): EntryChange<T> {
if(orderBy.isNotEmpty()) throw UnsupportedOperationException()
if (orderBy.isNotEmpty()) throw UnsupportedOperationException()
return t {
val old = table.updateReturningOld(
where = { condition(condition, serializer, table).asOp() },
Expand All @@ -186,7 +196,7 @@ class PostgresCollection<T : Any>(
modification: Modification<T>,
orderBy: List<SortPart<T>>
): Boolean {
if(orderBy.isNotEmpty()) throw UnsupportedOperationException()
if (orderBy.isNotEmpty()) throw UnsupportedOperationException()
return t {
table.update(
where = { condition(condition, serializer, table).asOp() },
Expand Down Expand Up @@ -226,7 +236,7 @@ class PostgresCollection<T : Any>(
}

override suspend fun deleteOneImpl(condition: Condition<T>, orderBy: List<SortPart<T>>): T? {
if(orderBy.isNotEmpty()) throw UnsupportedOperationException()
if (orderBy.isNotEmpty()) throw UnsupportedOperationException()
return t {
table.deleteReturningWhere(
limit = 1,
Expand All @@ -236,7 +246,7 @@ class PostgresCollection<T : Any>(
}

override suspend fun deleteOneIgnoringOldImpl(condition: Condition<T>, orderBy: List<SortPart<T>>): Boolean {
if(orderBy.isNotEmpty()) throw UnsupportedOperationException()
if (orderBy.isNotEmpty()) throw UnsupportedOperationException()
return t {
table.deleteWhere(
limit = 1,
Expand Down
Loading

0 comments on commit 66e36fa

Please sign in to comment.