use near_sdk::borsh::{self, BorshDeserialize, BorshSerialize}; use near_sdk::env; pub trait Heap { fn new() -> Self; fn get_min(&self) -> Option; fn remove_min(&mut self, storage_key: u128) -> bool; fn get(&self, storage_key: u128, key: u32) -> bool; fn set(&mut self, storage_key: u128, key: u32) -> bool; fn remove(&mut self, storage_key: u128, key: u32) -> bool; fn num_bits() -> u64; } #[derive(BorshDeserialize, BorshSerialize, Default, Copy, Clone)] pub struct HeapNBit { min_value: Option, bucket_mask: Heap8Bit, min_bucket: HeapChild, } #[derive(BorshDeserialize, BorshSerialize, Default, Copy, Clone)] pub struct Heap8Bit { data: [u64; 4], } pub type Heap16Bit = HeapNBit; pub type Heap24Bit = HeapNBit; impl Heap for HeapNBit { fn new() -> Self { Self { min_value: None, bucket_mask: Heap8Bit::new(), min_bucket: HeapChild::new(), } } /// Returns the smallest element /// Doesn't incur any additional IO cost fn get_min(&self) -> Option { self.min_value } /// Removes the smallest element from the heap. /// Returns whether the heap became empty /// Panics if the heap is empty. /// Can incur at most one write, and only if the heap remained non-empty. fn remove_min(&mut self, storage_key: u128) -> bool { match self.bucket_mask.get_min() { None => { self.min_value = None; true } Some(min_bucket_ord) => { self.min_value = Some( self.min_bucket.get_min().unwrap() + min_bucket_ord * (1 << HeapChild::num_bits()), ); // If this `remove_min` incurs a write, it also returns false, and is is the only write incurred let bucket_is_empty = self.min_bucket.remove_min(compute_child_storage_key( storage_key, min_bucket_ord, HeapChild::num_bits(), )); if bucket_is_empty { // Heap8Bit::remove_min doesn't incur a write if !self.bucket_mask.remove_min(0) { env::storage_remove( &compute_bucket_storage_key( storage_key, self.bucket_mask.get_min().unwrap(), HeapChild::num_bits(), ) .to_be_bytes(), ); self.min_bucket = HeapChild::try_from_slice(&env::storage_get_evicted().unwrap()) .unwrap(); } } false } } } fn get(&self, storage_key: u128, key: u32) -> bool { if self.min_value == Some(key) { return true; } let bucket_ord = key >> HeapChild::num_bits(); let within_bucket = key & ((1 << HeapChild::num_bits()) - 1); let bucket_storage_key = compute_bucket_storage_key(storage_key, bucket_ord, HeapChild::num_bits()); let child_storage_key = compute_child_storage_key(storage_key, bucket_ord, HeapChild::num_bits()); match self.bucket_mask.get_min() { None => false, Some(min_bucket_ord) => { if min_bucket_ord == bucket_ord { self.min_bucket.get(child_storage_key, within_bucket) } else if self.bucket_mask.get(0, bucket_ord) { let mask = HeapChild::try_from_slice( &env::storage_read(&bucket_storage_key.to_be_bytes()).unwrap(), ) .unwrap(); mask.get(child_storage_key, within_bucket) } else { false } } } } /// Sets the element /// Returns whether the element previously was not set /// Heap16Bit incurs at most one read and one write /// Heap24Bit incurs at most two reads and two writes fn set(&mut self, storage_key: u128, mut key: u32) -> bool { assert!(key < (1 << Self::num_bits())); match self.min_value { None => { self.min_value = Some(key); return true; } Some(ref mut min_value) => { if key < *min_value { std::mem::swap(min_value, &mut key); } else if key == *min_value { return false; } } }; let bucket_ord = key >> HeapChild::num_bits(); let within_bucket = key & ((1 << HeapChild::num_bits()) - 1); let bucket_storage_key = compute_bucket_storage_key(storage_key, bucket_ord, HeapChild::num_bits()); let child_storage_key = compute_child_storage_key(storage_key, bucket_ord, HeapChild::num_bits()); match self.bucket_mask.get_min() { None => { self.bucket_mask.set(0, bucket_ord); self.min_bucket = HeapChild::new(); self.min_bucket.set(child_storage_key, within_bucket) } Some(min_bucket_ord) => { if bucket_ord < min_bucket_ord { env::storage_write( &compute_bucket_storage_key( storage_key, min_bucket_ord, HeapChild::num_bits(), ) .to_be_bytes(), &self.min_bucket.try_to_vec().unwrap(), ); self.min_bucket = HeapChild::new(); self.bucket_mask.set(0, bucket_ord); self.min_bucket.set(child_storage_key, within_bucket) } else if bucket_ord == min_bucket_ord { self.min_bucket.set(child_storage_key, within_bucket) } else if !self.bucket_mask.get(0, bucket_ord) { self.bucket_mask.set(0, bucket_ord); let mut bucket = HeapChild::new(); bucket.set(child_storage_key, within_bucket); env::storage_write( &bucket_storage_key.to_be_bytes(), &bucket.try_to_vec().unwrap(), ); true } else { let mut mask = HeapChild::try_from_slice( &env::storage_read(&bucket_storage_key.to_be_bytes()).unwrap(), ) .unwrap(); let ret = mask.set(child_storage_key, within_bucket); env::storage_write( &bucket_storage_key.to_be_bytes(), &mask.try_to_vec().unwrap(), ); ret } } } } /// Removes the element /// Panics if the element being removed is not in the heap /// Heap16Bit incurs at most one read and two writes /// Heap24Bit incurs at most two reads and four writes fn remove(&mut self, storage_key: u128, key: u32) -> bool { if key == self.min_value.unwrap() { return self.remove_min(storage_key); } let bucket_ord = key >> HeapChild::num_bits(); let within_bucket = key & ((1 << HeapChild::num_bits()) - 1); let min_bucket_ord = self.bucket_mask.get_min().unwrap(); let bucket_storage_key = compute_bucket_storage_key(storage_key, bucket_ord, HeapChild::num_bits()); let child_storage_key = compute_child_storage_key(storage_key, bucket_ord, HeapChild::num_bits()); if bucket_ord == min_bucket_ord { if self.min_bucket.remove(child_storage_key, within_bucket) { if !self.bucket_mask.remove(0, bucket_ord) { env::storage_remove( &compute_bucket_storage_key( storage_key, self.bucket_mask.get_min().unwrap(), HeapChild::num_bits(), ) .to_be_bytes(), ); self.min_bucket = HeapChild::try_from_slice(&env::storage_get_evicted().unwrap()).unwrap(); } } } else { let mut mask = HeapChild::try_from_slice( &env::storage_read(&bucket_storage_key.to_be_bytes()).unwrap(), ) .unwrap(); if !mask.remove(child_storage_key, within_bucket) { env::storage_write( &bucket_storage_key.to_be_bytes(), &mask.try_to_vec().unwrap(), ); } else { env::storage_remove(&bucket_storage_key.to_be_bytes()); self.bucket_mask.remove(0, bucket_ord); } } false } fn num_bits() -> u64 { HeapChild::num_bits() + 8 } } impl Heap for Heap8Bit { fn new() -> Self { Self { data: [0; 4] } } fn get_min(&self) -> Option { for i in 0..4 { if self.data[i] != 0 { return Some(self.data[i].trailing_zeros() + (i as u32) * 64); } } None } fn remove_min(&mut self, storage_key: u128) -> bool { self.remove(storage_key, self.get_min().unwrap()) } fn get(&self, _storage_key: u128, key: u32) -> bool { assert!(key < 1 << 8); (self.data[(key >> 6) as usize] & (1 << (key & 63))) != 0 } fn set(&mut self, _storage_key: u128, key: u32) -> bool { assert!(key < (1 << 8)); let ret = self.data[(key >> 6) as usize] & (1 << (key & 63)) == 0; self.data[(key >> 6) as usize] |= 1 << (key & 63); ret } fn remove(&mut self, _storage_key: u128, key: u32) -> bool { assert!(key < (1 << 8)); self.data[(key >> 6) as usize] &= !(1 << (key & 63)); self.data[0] == 0 && self.data[1] == 0 && self.data[2] == 0 && self.data[3] == 0 } fn num_bits() -> u64 { 8 } } fn compute_child_storage_key(storage_key: u128, bucket_ord: u32, num_bits: u64) -> u128 { (((storage_key << 1) | 1) << num_bits) | (bucket_ord as u128) } fn compute_bucket_storage_key(storage_key: u128, bucket_ord: u32, num_bits: u64) -> u128 { (((storage_key << 1) | 0) << num_bits) | (bucket_ord as u128) } #[cfg(not(target_arch = "wasm32"))] #[cfg(test)] mod tests { use super::*; use near_sdk::test_utils::test_env; use rand::{thread_rng, Rng}; use std::collections::BTreeSet; #[test] fn test_heap_8bit() { test_heap_internal::(1 << 8); } #[test] fn test_heap_16bit() { test_heap_internal::(1 << 16); } #[test] fn test_heap_24bit() { test_heap_internal::(1 << 24); } fn test_heap_internal(max_value: u32) { test_env::setup_free(); // Run many small tests, and one large for run in 0..100 { let storage_key = run as u128 | 256; #[cfg(debug_assertions)] let num_iters = if run == 0 { 50000 } else { 1000 }; #[cfg(not(debug_assertions))] let num_iters = if run == 0 { 500000 } else { 10000 }; let mut heap = H::new(); let mut map = BTreeSet::new(); for _ in 0..num_iters { let key = thread_rng().gen_range(0, max_value); if thread_rng().gen() || map.len() == 1 { heap.set(storage_key, key); map.insert(key); } else { if map.remove(&key) { heap.remove(storage_key, key); } } assert_eq!(heap.get_min().as_ref(), map.range(..).next()); } } } }