LRU全称:(Least Recently Used)即最近最少使用的,这个算法经常用于缓存场景,可以做到O(1)
的读写复杂度,自动淘汰最近最少使用的数据
。
我们用Rust实现的时候,如果不想碰繁琐但安全的Rc<Refcell>
组合,可以试试Unsafe
来实现。
数据结构 由于我们需要哈希表才能实现O(1)的读写时间复杂度,我们用哈希表
索引所有KV。此外我们还需要链表
来确保我们能在O(1)时间复杂度内完成最近更新
的状态更新,这意味着任何读写操作都会把数据从链表中挪到队列尾
,如果写操作导致LRU大于最大值则淘汰队列头
的数据。
由于我们淘汰数据时不仅要移除链表尾,还必须移除对应哈希表的数据,所以链表节点中必须包含数据的K,所以链表的节点数据结构就是这样
1 2 3 4 5 6 struct Entry <K, V> { pre: *mut Entry<K, V>, next: *mut Entry<K, V>, k: K, v: V, }
LRU的数据结构类似这样
1 2 3 4 5 6 7 8 9 10 struct LruCache <K, V>where K: Hash + Eq , { data: HashMap<K, *mut Entry<K, V>>, head: *mut Entry<K, V>, tail: *mut Entry<K, V>, cap: usize , }
这里我们哈希表直接使用std的HashMap
我们为其实现简单的new size接口
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 impl <K, V> LruCache<K, V>where K: Hash + Eq , { pub fn new (cap: usize ) -> Self { LruCache { data: HashMap::new (), head: null_mut (), tail: null_mut (), cap, } } pub fn size (&self ) -> usize { self .data.len () } }
假如泛型实现了Debug
,我们为其实现print,打印链表
1 2 3 4 5 6 7 8 9 10 11 12 13 pub fn print (&self )where K: Debug , V: Debug , { let mut cur = self .head; while !cur.is_null () { unsafe { println! ("k = {:?} v = {:?}" , (*cur).k, (*cur).v); cur = (*cur).next; } } }
然后我们需要一个内部函数用于将节点从链中移动到链尾
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 unsafe fn move_to_last (&mut self , entry: *mut Entry<K, V>)where K: Eq + Hash, { if self .tail.is_null () || entry.is_null () { return ; } let e = entry.read (); let pre = e.pre; let next = e.next; if !pre.is_null () { (*pre).next = next; } if !next.is_null () { (*next).pre = pre; } if self .head == entry { self .head = next } (*self .tail).next = entry; (*entry).pre = self .tail; (*entry).next = std::ptr::null_mut (); self .tail = entry; }
然后我们就能实现get函数,这个函数的意图是,从哈希表中查找K,如果K存在则执行更新队尾操作,如果不存在则返回None
1 2 3 4 5 6 7 8 9 10 fn get (&mut self , k: &K) -> Option <&V> { match self .data.get (k) { Some (x) => unsafe { self .move_to_last (*x); let r = Some (&(**x).v); return r; }, None => return None , } }
如果我们这样写,实现会发现行不通,我们在get中对self有不可变引用,move_to_last又要求使用可变引用,换句话说,我们始终无法直接使用哈希表中的结果来修改链表。这里我们就需要做些修改,提前结束不可变引用的生命周期。
1 2 3 4 5 6 7 8 9 10 fn get (&mut self , k: &K) -> Option <&V> { if !self .data.contains_key (k) { return None ; }; let r = *self .data.get (k).unwrap (); unsafe { self .move_to_last (r); return Some (&(*r).v); } }
还有个将节点添加到链表中的函数
1 2 3 4 5 6 7 8 9 10 11 12 13 unsafe fn add_new_entry (&mut self , new_entry: *mut Entry<K, V>) { if self .data.len () == 0 { self .head = new_entry; self .tail = new_entry; return ; }; (*self .tail).next = new_entry; (*new_entry).pre = self .tail; (*new_entry).next = null_mut (); self .tail = new_entry; }
淘汰函数,负责将首节点弹出,这里就是由于最少使用
弹出的
1 2 3 4 5 6 7 8 9 10 unsafe fn evict (&mut self ) -> *mut Entry<K, V> { if self .head.is_null () { panic! ("empty" ); }; let h = self .head; self .head = (*h).next; (*self .head).pre = null_mut (); (*h).next = null_mut (); return h; }
然后我们就能实现put了
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 fn put (&mut self , k: K, v: V) { if self .data.contains_key (&k) { let entry = *self .data.get (&k).unwrap (); unsafe { (*entry).v = v; self .move_to_last (entry); return ; } }; unsafe { if self .data.len () == self .cap { let new_entry = self .evict (); self .data.remove (&(*new_entry).k); drop (Box ::from_raw (new_entry)); } let new_entry = Box ::into_raw (Box ::new (Entry { pre: null_mut (), next: null_mut (), k, v, })); self .add_new_entry (new_entry); self .data.insert (k, new_entry); } }
但是,上面的代码运行也会报错,这是因为,我们在哈希表和链表的节点中都添加了K,一个值只能move一次。一个work around的方法是将K的约束加个Copy
。
即:
1 2 3 4 5 6 7 8 9 10 11 struct LruCache <K, V>where K: Hash + Eq + Copy , { ... #[allow(dead_code)] impl <K, V> LruCache<K, V>where K: Hash + Eq + Copy , ...
还有一个办法是封装K的裸指针,可以不用Copy
1 2 3 4 5 6 7 8 9 10 11 struct Key <T: Hash>(*const T);impl <K: Eq + Hash> Eq for Key <K> {}impl <T: Hash> Hash for Key <T> { fn hash <H: std::hash::Hasher>(&self , state: &mut H) { unsafe { (*self .0 ).hash (state) } } }
封装裸指针后,需要计算hash就解引用。
在查询时构建Key,只需要将K的引用传入,会自动转为裸指针。
1 2 3 4 5 6 7 8 9 10 11 pub fn get (&mut self , k: &K) -> Option <&V> { if !self .data.contains_key (&Key (k)) { return None ; }; let r = *self .data.get (&Key (k)).unwrap (); unsafe { self .move_to_last (r); return Some (&(*r).v); } }
然后我们加上一些测试用例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 #[cfg(test)] mod test { use crate::common::cache::lru::LruCache; #[test] fn lru_f1 () { let mut cache = LruCache::new (3 ); cache.put ("2" , 1 ); cache.put ("3" , 2 ); cache.put ("4" , 2 ); cache.put ("5" , 2 ); cache.put ("7" , 2 ); assert_eq! (cache.size (), 3 ); } #[test] fn lru_f2 () { let mut cache = LruCache::new (3 ); cache.put ("2" , 1 ); cache.put ("3" , 2 ); cache.put ("4" , 4 ); assert_eq! (cache.get (&"2" ).unwrap (), &1 ); } #[test] fn lru_f3 () { let mut cache = LruCache::new (3 ); cache.put ("2" , 1 ); cache.put ("3" , 2 ); cache.put ("4" , 4 ); cache.put ("5" , 6 ); assert_eq! (cache.get (&"2" ), None ); } #[test] fn lru_f4 () { let mut cache = LruCache::new (3 ); cache.put ("2" , 1 ); cache.print (); cache.put ("3" , 2 ); cache.print (); assert_eq! (cache.get (&"2" ), Some (&1 )); cache.print (); cache.put ("4" , 4 ); cache.print (); cache.put ("5" , 6 ); cache.print (); assert_eq! (cache.get (&"3" ), None ); } }
执行cargo test lru
即可执行上面的单元测试。
如果一切顺利,我们就能看到:
1 2 3 4 5 6 7 running 4 tests test common::cache::lru::test::lru_f1 ... ok test common::cache::lru::test::lru_f2 ... ok test common::cache::lru::test::lru_f3 ... ok test common::cache::lru::test::lru_f4 ... ok test result: ok. 4 passed; 0 failed; 0 ignored; 0 measured; 11 filtered out; finished in 0.00s
但是我们的程序真的没问题吗?我们的数据结构中的一堆unsafe没问题吗?
这里我们引入miri来测试下,miri是一款Rust的MIR实验性解释器,能用于检测未定义行为 。miri依赖nightly Rust:
1 rustup +nightly component add miri
然后执行:
发现报错:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 error: memory leaked: alloc36198 (Rust heap, size: 40, align: 8), allocated here: --> /home/fenix/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/alloc/src/alloc.rs:98:9 | 98 | __rust_alloc(layout.size(), layout.align()) | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ... note: inside `common::cache::lru::LruCache::<&str, i32>::put` --> src/common/cache/lru.rs:151:43 | 151 | let new_entry = Box::into_raw(Box::new(Entry { | ___________________________________________^ 152 | | pre: null_mut(), 153 | | next: null_mut(), 154 | | k, 155 | | v, 156 | | })); | |______________^
原来是我们的Box内存泄漏了,我们要给LruCache加个Drop
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 impl <K, V> Drop for LruCache <K, V>where K: Hash + Eq + Copy , { fn drop (&mut self ) { let mut cur = self .head; unsafe { while !cur.is_null () { self .data.remove (&(*cur).k); let d = cur; cur = (*cur).next; drop (Box ::from_raw (d)); } } } }
再次执行
发现已经不会报错了