Rust实现的LRU安全吗?

LRU全称:(Least Recently Used)即最近最少使用的,这个算法经常用于缓存场景,可以做到O(1)的读写复杂度,自动淘汰最近最少使用的数据

我们用Rust实现的时候,如果不想碰繁琐但安全的Rc<Refcell> 组合,可以试试Unsafe来实现。

数据结构

由于我们需要哈希表才能实现O(1)的读写时间复杂度,我们用哈希表索引所有KV。此外我们还需要链表来确保我们能在O(1)时间复杂度内完成最近更新的状态更新,这意味着任何读写操作都会把数据从链表中挪到队列尾,如果写操作导致LRU大于最大值则淘汰队列头的数据。

由于我们淘汰数据时不仅要移除链表尾,还必须移除对应哈希表的数据,所以链表节点中必须包含数据的K,所以链表的节点数据结构就是这样

1
2
3
4
5
6
struct Entry<K, V> {
pre: *mut Entry<K, V>,
next: *mut Entry<K, V>,
k: K,
v: V,
}

LRU的数据结构类似这样
1

1
2
3
4
5
6
7
8
9
10

struct LruCache<K, V>
where
K: Hash + Eq,
{
data: HashMap<K, *mut Entry<K, V>>,
head: *mut Entry<K, V>,
tail: *mut Entry<K, V>,
cap: usize,
}

这里我们哈希表直接使用std的HashMap

我们为其实现简单的new size接口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
impl<K, V> LruCache<K, V>
where
K: Hash + Eq,
{
pub fn new(cap: usize) -> Self {
LruCache {
data: HashMap::new(),
head: null_mut(),
tail: null_mut(),
cap,
}
}

pub fn size(&self) -> usize {
self.data.len()
}
}

假如泛型实现了Debug,我们为其实现print,打印链表

1
2
3
4
5
6
7
8
9
10
11
12
13
pub fn print(&self)
where
K: Debug,
V: Debug,
{
let mut cur = self.head;
while !cur.is_null() {
unsafe {
println!("k = {:?} v = {:?}", (*cur).k, (*cur).v);
cur = (*cur).next;
}
}
}

然后我们需要一个内部函数用于将节点从链中移动到链尾

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
unsafe fn move_to_last(&mut self, entry: *mut Entry<K, V>)
where
K: Eq + Hash,
{
// 说明当前没数据
if self.tail.is_null() || entry.is_null() {
return;
}
let e = entry.read();
let pre = e.pre;
let next = e.next;
// 把前后节点接上
if !pre.is_null() {
(*pre).next = next;
}
if !next.is_null() {
(*next).pre = pre;
}
// 如果当前节点是头节点,头节点指向下一个
if self.head == entry {
self.head = next
}
// 将当前节点追加到尾节点后面,如果当前节点就是尾节点,执行完下面四步什么都不会发生
(*self.tail).next = entry;
(*entry).pre = self.tail;
(*entry).next = std::ptr::null_mut();
self.tail = entry;
}

然后我们就能实现get函数,这个函数的意图是,从哈希表中查找K,如果K存在则执行更新队尾操作,如果不存在则返回None

1
2
3
4
5
6
7
8
9
10
fn get(&mut self, k: &K) -> Option<&V> {
match self.data.get(k) {
Some(x) => unsafe {
self.move_to_last(*x);
let r = Some(&(**x).v);
return r;
},
None => return None,
}
}

如果我们这样写,实现会发现行不通,我们在get中对self有不可变引用,move_to_last又要求使用可变引用,换句话说,我们始终无法直接使用哈希表中的结果来修改链表。这里我们就需要做些修改,提前结束不可变引用的生命周期。

1
2
3
4
5
6
7
8
9
10
fn get(&mut self, k: &K) -> Option<&V> {
if !self.data.contains_key(k) {
return None;
};
let r = *self.data.get(k).unwrap();
unsafe {
self.move_to_last(r);
return Some(&(*r).v);
}
}

还有个将节点添加到链表中的函数

1
2
3
4
5
6
7
8
9
10
11
12
13
unsafe fn add_new_entry(&mut self, new_entry: *mut Entry<K, V>) {
// 如果链表为空,首尾指向 当前节点
if self.data.len() == 0 {
self.head = new_entry;
self.tail = new_entry;
return;
};
// 将当前节点接到尾节点后面
(*self.tail).next = new_entry;
(*new_entry).pre = self.tail;
(*new_entry).next = null_mut();
self.tail = new_entry;
}

淘汰函数,负责将首节点弹出,这里就是由于最少使用弹出的

1
2
3
4
5
6
7
8
9
10
unsafe fn evict(&mut self) -> *mut Entry<K, V> {
if self.head.is_null() {
panic!("empty");
};
let h = self.head;
self.head = (*h).next;
(*self.head).pre = null_mut();
(*h).next = null_mut();
return h;
}

然后我们就能实现put了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
fn put(&mut self, k: K, v: V) {
// 如果数据存在于哈希表中,则更新
if self.data.contains_key(&k) {
let entry = *self.data.get(&k).unwrap();
unsafe {
(*entry).v = v;
self.move_to_last(entry);
return;
}
};
unsafe {
// 如果lru数据量达到阈值,执行太太
if self.data.len() == self.cap {
let new_entry = self.evict();
self.data.remove(&(*new_entry).k);
drop(Box::from_raw(new_entry));
}
// 这里把kv封装到Entry,在塞进Box中,并转裸指针
let new_entry = Box::into_raw(Box::new(Entry {
pre: null_mut(),
next: null_mut(),
k,
v,
}));
self.add_new_entry(new_entry);
self.data.insert(k, new_entry);
}
}

但是,上面的代码运行也会报错,这是因为,我们在哈希表和链表的节点中都添加了K,一个值只能move一次。一个work around的方法是将K的约束加个Copy

即:

1
2
3
4
5
6
7
8
9
10
11
struct LruCache<K, V>
where
K: Hash + Eq + Copy,
{
...

#[allow(dead_code)]
impl<K, V> LruCache<K, V>
where
K: Hash + Eq + Copy,
...

还有一个办法是封装K的裸指针,可以不用Copy

1
2
3
4
5
6
7
8
9
10
11

struct Key<T: Hash>(*const T);

impl<K: Eq + Hash> Eq for Key<K> {}

impl<T: Hash> Hash for Key<T> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
unsafe { (*self.0).hash(state) }
}
}

封装裸指针后,需要计算hash就解引用。

在查询时构建Key,只需要将K的引用传入,会自动转为裸指针。

1
2
3
4
5
6
7
8
9
10
11
pub fn get(&mut self, k: &K) -> Option<&V> {
if !self.data.contains_key(&Key(k)) {
return None;
};
let r = *self.data.get(&Key(k)).unwrap();
unsafe {
self.move_to_last(r);
return Some(&(*r).v);
}
}

然后我们加上一些测试用例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

#[cfg(test)]
mod test {
use crate::common::cache::lru::LruCache;

#[test]
fn lru_f1() {
let mut cache = LruCache::new(3);
cache.put("2", 1);
cache.put("3", 2);
cache.put("4", 2);
cache.put("5", 2);
cache.put("7", 2);
assert_eq!(cache.size(), 3);
}

#[test]
fn lru_f2() {
let mut cache = LruCache::new(3);
cache.put("2", 1);
cache.put("3", 2);
cache.put("4", 4);
assert_eq!(cache.get(&"2").unwrap(), &1);
}

#[test]
fn lru_f3() {
let mut cache = LruCache::new(3);
cache.put("2", 1);
cache.put("3", 2);
cache.put("4", 4);
cache.put("5", 6);
assert_eq!(cache.get(&"2"), None);
}

#[test]
fn lru_f4() {
let mut cache = LruCache::new(3);
cache.put("2", 1);
cache.print();
cache.put("3", 2);
cache.print();
assert_eq!(cache.get(&"2"), Some(&1));
cache.print();
cache.put("4", 4);
cache.print();
cache.put("5", 6);
cache.print();
assert_eq!(cache.get(&"3"), None);
}
}

执行cargo test lru 即可执行上面的单元测试。

如果一切顺利,我们就能看到:

1
2
3
4
5
6
7
running 4 tests
test common::cache::lru::test::lru_f1 ... ok
test common::cache::lru::test::lru_f2 ... ok
test common::cache::lru::test::lru_f3 ... ok
test common::cache::lru::test::lru_f4 ... ok

test result: ok. 4 passed; 0 failed; 0 ignored; 0 measured; 11 filtered out; finished in 0.00s

但是我们的程序真的没问题吗?我们的数据结构中的一堆unsafe没问题吗?

这里我们引入miri来测试下,miri是一款Rust的MIR实验性解释器,能用于检测未定义行为 。miri依赖nightly Rust:

1
rustup +nightly component add miri

然后执行:

1
cargo miri test lru

发现报错:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
error: memory leaked: alloc36198 (Rust heap, size: 40, align: 8), allocated here:
--> /home/fenix/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/alloc/src/alloc.rs:98:9
|
98 | __rust_alloc(layout.size(), layout.align())
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

...

note: inside `common::cache::lru::LruCache::<&str, i32>::put`
--> src/common/cache/lru.rs:151:43
|
151 | let new_entry = Box::into_raw(Box::new(Entry {
| ___________________________________________^
152 | | pre: null_mut(),
153 | | next: null_mut(),
154 | | k,
155 | | v,
156 | | }));
| |______________^

原来是我们的Box内存泄漏了,我们要给LruCache加个Drop

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
impl<K, V> Drop for LruCache<K, V>
where
K: Hash + Eq + Copy,
{
fn drop(&mut self) {
let mut cur = self.head;
unsafe {
while !cur.is_null() {
self.data.remove(&(*cur).k);
let d = cur;
cur = (*cur).next;
drop(Box::from_raw(d));
}
}
}
}

再次执行

1
cargo miri test lru

发现已经不会报错了