Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a minimal fix for Penalty calculation in UndelegateClaimCircuit #3478

Merged
merged 7 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/bench/benches/undelegate_claim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ fn undelegate_claim_proving_time(c: &mut Criterion) {
let start_epoch_index = 1;
let unbonding_token = UnbondingToken::new(validator_identity, start_epoch_index);
let unbonding_id = unbonding_token.id();
let penalty = Penalty(1u64);
let penalty = Penalty::from_bps_squared(1u64);
let balance = penalty.balance_for_claim(unbonding_id, unbonding_amount);
let balance_commitment = balance.commit(balance_blinding);

Expand Down
2 changes: 1 addition & 1 deletion crates/bin/pcli/tests/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ fn undelegate_claim_parameters_vs_current_undelegate_claim_circuit() {
let start_epoch_index = 1;
let unbonding_token = UnbondingToken::new(validator_identity, start_epoch_index);
let unbonding_id = unbonding_token.id();
let penalty = Penalty(penalty_amount);
let penalty = Penalty::from_bps_squared(penalty_amount);
let balance = penalty.balance_for_claim(unbonding_id, unbonding_amount);
let balance_commitment = balance.commit(balance_blinding);

Expand Down
10 changes: 5 additions & 5 deletions crates/core/component/stake/src/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ pub(crate) trait StakingImpl: StateWriteExt {
let penalty = self.get_stake_params().await?.slashing_penalty_downtime;

// Record the slashing penalty on this validator.
self.record_slashing_penalty(identity_key, Penalty(penalty))
self.record_slashing_penalty(identity_key, Penalty::from_bps_squared(penalty))
.await?;

// The validator's delegation pool begins unbonding. Jailed
Expand All @@ -273,7 +273,7 @@ pub(crate) trait StakingImpl: StateWriteExt {
let penalty = self.get_stake_params().await?.slashing_penalty_misbehavior;

// Record the slashing penalty on this validator.
self.record_slashing_penalty(identity_key, Penalty(penalty))
self.record_slashing_penalty(identity_key, Penalty::from_bps_squared(penalty))
.await?;

// Regardless of its current bonding state, the validator's
Expand Down Expand Up @@ -411,7 +411,7 @@ pub(crate) trait StakingImpl: StateWriteExt {
let penalty = self
.penalty_in_epoch(&validator.identity_key, epoch_to_end.index)
.await?
.unwrap_or_default();
.unwrap_or(Penalty::from_percent(0));
let prev_validator_rate_with_penalty = prev_validator_rate.slash(penalty);

// Then compute the next validator rate, accounting for funding streams and validator state.
Expand Down Expand Up @@ -1099,7 +1099,7 @@ pub trait StateReadExt: StateRead {
let start_key = state_key::penalty_in_epoch(id, start);
let end_key = state_key::penalty_in_epoch(id, end);

let mut compounded = Penalty::default();
let mut compounded = Penalty::from_percent(0);
for (_key, penalty) in all_penalties.range(start_key..end_key) {
compounded = compounded.compound(*penalty);
}
Expand Down Expand Up @@ -1413,7 +1413,7 @@ pub trait StateWriteExt: StateWrite {
let current_penalty = self
.penalty_in_epoch(identity_key, current_epoch_index)
.await?
.unwrap_or_default();
.unwrap_or(Penalty::from_percent(0));

let new_penalty = current_penalty.compound(slashing_penalty);

Expand Down
216 changes: 82 additions & 134 deletions crates/core/component/stake/src/penalty.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::str::FromStr;

use ark_ff::ToConstraintField;
use ark_r1cs_std::prelude::*;
use ark_relations::r1cs::SynthesisError;
use decaf377::{r1cs::FqVar, FieldExt, Fq};
use decaf377::Fq;
use penumbra_proto::{penumbra::core::component::stake::v1alpha1 as pbs, DomainType};
use serde::{Deserialize, Serialize};

Expand All @@ -12,52 +10,70 @@ use penumbra_asset::{
balance::BalanceVar,
Balance, Value, ValueVar, STAKING_TOKEN_ASSET_ID,
};
use penumbra_num::{fixpoint::bit_constrain, Amount, AmountVar};
use penumbra_num::{
fixpoint::{U128x128, U128x128Var},
Amount, AmountVar,
};

/// Tracks slashing penalties applied to a validator in some epoch.
///
/// The penalty is represented as a fixed-point integer in bps^2 (denominator 10^8).
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
/// You do not need to know how the penalty is represented.
///
/// If you insist on knowing, it's represented as a U128x128 between 0 and 1,
/// which denotes the amount *kept* after applying a penalty. e.g. a 1% penalty
/// would be 0.99.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(try_from = "pbs::Penalty", into = "pbs::Penalty")]
pub struct Penalty(pub u64);
pub struct Penalty(U128x128);

impl Penalty {
/// Create a `Penalty` from a percentage e.g.
/// `Penalty::from_percent(1)` is a 1% penalty.
/// `Penalty::from_percent(100)` is a 100% penalty.
pub fn from_percent(percent: u64) -> Self {
Penalty::from_bps(percent * 100)
Penalty::from_bps(percent.saturating_mul(100))
}

/// Create a `Penalty` from a basis point e.g.
/// `Penalty::from_bps(1)` is a 1 bps penalty.
/// `Penalty::from_bps(100)` is a 100 bps penalty.
pub fn from_bps(bps: u64) -> Self {
Penalty(bps * 10_000)
Penalty::from_bps_squared(bps.saturating_mul(10000))
}

/// Create a `Penalty` from a basis point squared e.g.
/// `Penalty::from_bps(1_0000_0000)` is a 100% penalty.
pub fn from_bps_squared(bps_squared: u64) -> Self {
assert!(bps_squared <= 1_0000_0000);
Self(U128x128::ratio(bps_squared, 1_0000_0000).expect(&format!(
"{bps_squared} bps^2 should be convertible to a U128x128"
)))
.one_minus_this()
}

fn one_minus_this(&self) -> Penalty {
Self(
(U128x128::from(1u64) - self.0)
.expect("1 - penalty should never underflow, because penalty is at most 1"),
)
}

/// Compound this `Penalty` with another `Penalty`.
pub fn compound(&self, other: Penalty) -> Penalty {
// We want to compute q sth (1 - q) = (1-p1)(1-p2)
// q = 1 - (1-p1)(1-p2)
// but since each p_i implicitly carries a factor of 10^8, we need to divide by 10^8 after multiplying.
let one = 1_0000_0000u128;
let p1 = self.0 as u128;
let p2 = other.0 as u128;
let q = u64::try_from(one - (((one - p1) * (one - p2)) / 1_0000_0000))
.expect("value should fit in 64 bits");
Penalty(q)
Self((self.0 * other.0).expect("compounding penalities will not overflow"))
}

/// Apply this `Penalty` to an `Amount` of unbonding tokens.
pub fn apply_to(&self, amount: Amount) -> Amount {
let penalized_amount = (u128::try_from(amount).expect("amount should be a valid u128"))
* (1_0000_0000 - self.0 as u128)
/ 1_0000_0000;
Amount::try_from(
u64::try_from(penalized_amount).expect("penalized amount should fit in u64"),
)
.expect("all u64 values should be valid Amounts")
pub fn apply_to_amount(&self, amount: Amount) -> Amount {
self.apply_to(amount)
.round_down()
.try_into()
.expect("converting integral U128xU128 into Amount will succeed")
}

/// Apply this `Penalty` to some fraction.
pub fn apply_to(&self, amount: impl Into<U128x128>) -> U128x128 {
(amount.into() * self.0).expect("should not overflow, because penalty is <= 1")
}

/// Helper method to compute the effect of an UndelegateClaim on the
Expand All @@ -76,21 +92,42 @@ impl Penalty {
asset_id: unbonding_id,
}
+ Value {
amount: self.apply_to(unbonding_amount),
amount: self.apply_to_amount(unbonding_amount),
asset_id: *STAKING_TOKEN_ASSET_ID,
}
}
}

impl ToConstraintField<Fq> for Penalty {
fn to_field_elements(&self) -> Option<Vec<Fq>> {
let field_elements = vec![Fq::from(self.0)];
Some(field_elements)
self.0.to_field_elements()
}
}

impl From<Penalty> for [u8; 32] {
fn from(value: Penalty) -> Self {
value.0.into()
}
}

impl<'a> TryFrom<&'a [u8]> for Penalty {
type Error = <U128x128 as TryFrom<&'a [u8]>>::Error;

fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
U128x128::try_from(value).map(Self)
}
}

/// One explicit choice in this in circuit representation is that we
/// DO NOT CHECK THAT THE PENALTY IS <= 1 IN CIRCUIT. This is in practice
/// the UndelegateClaim circuit is the ONLY consumer of this type, and
/// in the context of that circuit, the penalty is checked out of circuit
/// to conform to a real value which will be <= 1.
///
/// I repeat myself: IF YOU EVER USE THIS IN A DIFFERENT CIRCUIT, MAKE ABSOLUTELY
/// CERTAIN THAT A PENALTY BEING > 1 IN CIRCUIT IS NOT AN ISSUE.
pub struct PenaltyVar {
inner: FqVar,
inner: U128x128Var,
}

impl AllocVar<Penalty, Fq> for PenaltyVar {
Expand All @@ -99,80 +136,21 @@ impl AllocVar<Penalty, Fq> for PenaltyVar {
f: impl FnOnce() -> Result<T, SynthesisError>,
mode: ark_r1cs_std::prelude::AllocationMode,
) -> Result<Self, SynthesisError> {
let ns = cs.into();
let cs = ns.cs();
let inner: Penalty = *f()?.borrow();
let penalty = FqVar::new_variable(cs, || Ok(Fq::from(inner.0)), mode)?;
// Check the Penalty is 64 bits
let _ = bit_constrain(penalty.clone(), 64);
Ok(Self { inner: penalty })
}
}

impl From<&PenaltyVar> for AmountVar {
fn from(penalty_var: &PenaltyVar) -> Self {
// `AmountVar`s must fit in 128 bits, but `PenaltyVar`s have already
// been constrained to fit in 64 bits, so we can safely
// construct an `AmountVar` from a `PenaltyVar`.
AmountVar {
amount: penalty_var.inner.clone(),
}
assert!(
matches!(mode, ark_r1cs_std::prelude::AllocationMode::Input),
"Penalty must be an input variable"
);
Ok(Self {
inner: U128x128Var::new_variable(cs, || Ok(f()?.borrow().0), mode)?,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here in AllocVar as an additional guard we could return an error if mode is not AllocationMode::Input

})
}
}

impl PenaltyVar {
pub fn apply_to(&self, amount: AmountVar) -> Result<AmountVar, SynthesisError> {
let penalty = self.value().unwrap_or(Penalty(0));
/* Bound analysis
* `penalty_amount = amount * (1_0000_0000 - penalty) / 1_0000_0000`
* Order of operations:
* 1. cst: `penalty` cast to u128 (infallible)
* 2. sub: `1_0000_0000 - penalty`
* 3. mul: `amount * (1_0000_0000 - penalty)`
* 4. div: `amount * (1_0000_0000 - penalty) / 1_0000_0000`
*
* Units:
* `amount` : staking tokens to undelegate (128 bits)
* `penalty` : a bps^2 penalty factor between 0 and 10^8 (64 bits)
* `staking_token_unit_amount` : 10^6 ~ 2^20
* `bps_squared_constant` : 10^8 ~ 2^27
*
* Overflow condition: `amount * (1_0000_0000 - penalty) > 2^128 - 1`
* Undeflow condition: `penalty` > 10^8 (penalty is greater than 100%)
*
* Boundary: If penalty is 0, then `amount * 1_0000_0000 = amount * 2^27`
* With `amount` as 2^(x+20) - 1, where x is log2(staking tokens):
* What quantity of staking tokens would cause an overflow? (for 128 bits)
* Find x: 2^(x+20) * 2^27 < 2^128
* True for x < 81 (~10^24 staking tokens), so an overflow for 128 bits is implausible.
*
* What quantity of staking tokens would cause an overflow? (for 64 bits)
* Find x: 2^(x+20) * 2^27 < 2^64
* True for x < 17 (~10^5 staking tokens), so an overflow for 64 bits is possible and plausible.
*
*/

// Out of circuit penalized amount computation:
let amount_bytes = &amount.value().unwrap_or(Amount::from(0u64)).to_le_bytes()[0..16];
let amount_128 =
u128::from_le_bytes(amount_bytes.try_into().expect("should fit in 16 bytes"));
let penalized_amount = amount_128 * (1_0000_0000 - penalty.0 as u128) / 1_0000_0000;

// Witness the result in the circuit.
let penalized_amount_var = AmountVar::new_witness(self.cs(), || {
Ok(Amount::from(
u64::try_from(penalized_amount).expect("can fit in u64"),
))
})?;

// Now we certify the witnessed penalized amount was calculated correctly.
// Constrain: penalized_amount = amount * (1_0000_0000 - penalty (public)) / 1_0000_0000
let hundred_mil = AmountVar::new_constant(self.cs(), Amount::from(1_0000_0000u128))?; // 1_0000_0000
let numerator = amount * (hundred_mil.clone() - self.into());
let (penalized_amount_quo, _) = numerator.quo_rem(&hundred_mil)?;
penalized_amount_quo.enforce_equal(&penalized_amount_var)?;

Ok(penalized_amount_var)
U128x128Var::from_amount_var(amount)?
.checked_mul(&self.inner)?
.round_down_to_amount()
}

pub fn balance_for_claim(
Expand Down Expand Up @@ -203,26 +181,7 @@ impl R1CSVar<Fq> for PenaltyVar {
}

fn value(&self) -> Result<Self::Value, SynthesisError> {
let inner_fq = self.inner.value()?;
let inner_bytes = &inner_fq.to_bytes()[0..8];
let penalty_bytes: [u8; 8] = inner_bytes
.try_into()
.expect("should be able to fit in 16 bytes");
Ok(Penalty(u64::from_le_bytes(penalty_bytes)))
}
}

impl FromStr for Penalty {
type Err = <u64 as FromStr>::Err;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let v = u64::from_str(s)?;
Ok(Penalty(v))
}
}

impl std::fmt::Display for Penalty {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
Ok(Penalty(self.inner.value()?))
}
}

Expand All @@ -232,26 +191,15 @@ impl DomainType for Penalty {

impl From<Penalty> for pbs::Penalty {
fn from(v: Penalty) -> Self {
pbs::Penalty { inner: v.0 }
pbs::Penalty {
inner: <[u8; 32]>::from(v).to_vec(),
}
}
}

impl TryFrom<pbs::Penalty> for Penalty {
type Error = anyhow::Error;
fn try_from(v: pbs::Penalty) -> Result<Self, Self::Error> {
Ok(Penalty(v.inner))
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn penalty_display_fromstr_roundtrip() {
let p = Penalty(123456789);
let s = p.to_string();
let p2 = Penalty::from_str(&s).unwrap();
assert_eq!(p, p2);
Ok(Penalty::try_from(v.inner.as_slice())?)
}
}
21 changes: 9 additions & 12 deletions crates/core/component/stake/src/rate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,14 @@ impl RateData {

pub fn slash(&self, penalty: Penalty) -> Self {
let mut slashed = self.clone();
// (1 - penalty) * exchange_rate
slashed.validator_exchange_rate = self
.validator_exchange_rate
// Slashing penalty is in bps^2, so we divide by 1e8
.saturating_sub(
u64::try_from(
(self.validator_exchange_rate as u128 * penalty.0 as u128) / 1_0000_0000,
)
.expect("penalty should fit in u64"),
);

// This will automatically produce a ratio which is multiplied by 1_0000_0000, and so
// rounding down does what we want.
let penalized_exchange_rate: u64 = penalty
.apply_to(self.validator_exchange_rate)
.round_down()
.try_into()
.expect("multiplying will not overflow");
slashed.validator_exchange_rate = penalized_exchange_rate;
slashed
}

Expand Down Expand Up @@ -280,7 +277,7 @@ mod tests {
validator_exchange_rate: 2_0000_0000,
};
// 10%
let penalty = Penalty(1000_0000);
let penalty = Penalty::from_percent(10);
let slashed = rate_data.slash(penalty);
assert_eq!(slashed.validator_exchange_rate, 1_8000_0000);
}
Expand Down
Loading
Loading