use near_sdk::borsh::{self, BorshDeserialize, BorshSerialize}; const TWO_TO_16: u128 = 1u128 << 16; const TWO_TO_18: u128 = 1u128 << 18; const TWO_TO_32: u128 = 1u128 << 32; const TWO_TO_64: u128 = 1u128 << 64; const TWO_TO_96: u128 = 1u128 << 96; #[derive( BorshDeserialize, BorshSerialize, PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Default, )] pub struct Rate { x1: u8, x2: u16, } impl Rate { pub fn new(numerator: u128, denominator: u128) -> Self { assert!( numerator != 0 && denominator != 0 && numerator / denominator < TWO_TO_32 && denominator / numerator < TWO_TO_32 ); // rate will contain num / den * 2^96 with at least 32 most significant bits correct. let rate: u128 = if numerator < TWO_TO_64 { // den is at most 2^32 * nom, so nom * 2^64 / den gives at least 32 bits correct numerator * TWO_TO_64 / denominator * TWO_TO_32 } else if numerator < TWO_TO_96 { if numerator >= denominator { numerator * TWO_TO_32 / denominator * TWO_TO_64 } else { // den here is at least 2^64, so dividing it by 2^32 keeps enough precision numerator * TWO_TO_32 / (denominator / TWO_TO_32) * TWO_TO_32 } } else { if numerator >= denominator { // den here is at least 2^64, so dividing it by 2^32 keeps enough precision numerator / (denominator / TWO_TO_32) * TWO_TO_64 } else { // den here is at least 2^96, so dividing it by 2^64 keeps enough precision numerator / (denominator / TWO_TO_64) * TWO_TO_32 } }; assert!(rate >= TWO_TO_64); // Since rate is >= 2^64, the exponent is between 0 and 63, and fits into 6 bits let exponent_inv = rate.leading_zeros() as u128; assert!(exponent_inv < 64); let exponent = 63 - exponent_inv; let mantissa = (rate >> (46 + exponent)) & (TWO_TO_18 - 1); Self { x1: ((exponent << 2) + (mantissa >> 16)) as u8, x2: (mantissa & (TWO_TO_16 - 1)) as u16, } } pub fn mul(&self, mut other: u128, round_up: bool) -> u128 { let exponent = self.x1 as u128 >> 2; let mut mantissa = self.x2 as u128 + ((self.x1 as u128 & 3) << 16); if round_up { mantissa += 1; } let mut rate = (mantissa + (1 << 18)) << (46 + exponent); let mut shift = 96; if rate >= TWO_TO_64 { rate >>= 32; shift -= 32; } if rate >= TWO_TO_64 { rate >>= 32; shift -= 32; } if other >= TWO_TO_64 { other >>= 32; shift -= 32; } if other >= TWO_TO_64 && shift > 0 { other >>= 32; shift -= 32; } (other.saturating_mul(rate)) >> shift } pub fn as_raw_u32(&self) -> u32 { self.x1 as u32 * (TWO_TO_16 as u32) + (self.x2 as u32) } pub fn from_raw_u32(val: u32) -> Self { Rate { x1: (val >> 16) as u8, x2: (val & ((1 << 16) - 1)) as u16, } } pub fn is_none(&self) -> bool { self.x1 == 0 && self.x2 == 0 } } #[cfg(not(target_arch = "wasm32"))] #[cfg(test)] mod tests { use super::*; use rand::{thread_rng, Rng}; /// `Rate` represents a floating point number as an exponent and mantissa. If one iterates over /// ratios from 1 / 2^32 to 2^32, increasing the numerator by the 1 >> exponent at each step, /// we expect that the integer representation of the rate goes from 1 all the way to 2^24 - 1, with /// increments of one. This test ensures that. #[test] fn test_rate_n_over_two_to_64() { for shift in &[0, 13, 14, 15, 45, 46, 47, 76, 77] { let mut add = 1; let mut nom = (1 << 18) + add; let mut expected = 1; while nom / (1 << 50) < TWO_TO_32 { // if shift <= 46, there should be no overflows, and the test should run through // the entire range. Otherwise check for overflow if *shift > 46 && u128::MAX >> *shift < nom { break; } let rate = Rate::new(nom << shift, (1 << 50) << shift); let rate_raw = rate.as_raw_u32(); assert_eq!(rate_raw, expected); if (rate_raw as u128) & (TWO_TO_18 - 1) == 0 { add *= 2; } nom += add; expected += 1; } // For shifts <= 32, we expect the test to go through the entire range if *shift <= 46 { assert_eq!(expected, 1 << 24); } } } #[test] fn test_rate_mul() { for _ in 0..500000 { let a: u128 = thread_rng().gen(); let b: u128 = thread_rng().gen_range(a / TWO_TO_32 + 1, a.saturating_mul(TWO_TO_32)); let r = Rate::new(a, b); for round_up in &[false, true] { let c = r.mul(b, *round_up); if a >= c { assert!(!round_up); assert!(a - c <= a / TWO_TO_18, "{} {}", a, c); } else { assert!(round_up); assert!(c - a <= a / TWO_TO_18, "{} {}", a, c); } } } } }