LFU Cache
Design and implement a data structure for a Least Frequently Used (LFU) cache. It should support the following operations: get and set.
get(key) - Get the value (which will always be positive) of the key if the key exists in the cache; otherwise, return -1.
set(key, value) - Update the value if the key is present, or insert it if the key is not already present. When the cache reaches its capacity, it should invalidate the least frequently used item before inserting a new item. For the purpose of this problem, when there is a tie (i.e., two or more keys that have the same frequency), the least recently used key should be evicted.
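To pin down the eviction rule before optimizing anything, here is a minimal brute-force reference model. It is a hypothetical helper (`BruteForceLFU` is not part of the original code): every access bumps a per-key counter and stamps a logical clock, and eviction simply scans all keys for the smallest (counter, last access) pair, so it is O(n) per eviction but easy to trust.

```python
class BruteForceLFU(object):
    """Hypothetical O(n)-eviction reference model of the LFU rules."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.tick = 0   # logical clock, bumped on every access
        self.data = {}  # key -> [value, access_count, last_access_tick]

    def _touch(self, entry):
        # Record one more access and when it happened.
        self.tick += 1
        entry[1] += 1
        entry[2] = self.tick

    def get(self, key):
        entry = self.data.get(key)
        if entry is None:
            return -1
        self._touch(entry)
        return entry[0]

    def set(self, key, value):
        if self.capacity <= 0:
            return
        entry = self.data.get(key)
        if entry is not None:
            entry[0] = value
            self._touch(entry)
            return
        if len(self.data) >= self.capacity:
            # Lowest count loses; among equal counts, the oldest access loses.
            victim = min(self.data,
                         key=lambda k: (self.data[k][1], self.data[k][2]))
            del self.data[victim]
        entry = [value, 0, 0]
        self.data[key] = entry
        self._touch(entry)
```

Because it follows the problem statement literally, a model like this can serve as an oracle when cross-checking the faster implementations.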
Solution
As with the LRU cache, the get operation also modifies the internal state. More concretely, we need to track the number of accesses (reads and writes) of each key, and keep the keys ordered by that count so that finding the key to invalidate stays cheap.
A dict is chosen as the internal data store so that get and set can look up a key in O(1) time. The access frequency is tracked in a doubly linked list. Without loss of generality, we keep the doubly linked list in descending order of access count: after a key is accessed (by get or set), its node is moved towards the head, so the tail always holds the key to evict.
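The ordering invariant can be illustrated without any linked list. This sketch (the function `descending_order` is hypothetical, and it ignores evictions) replays a sequence of accessed keys and sorts them by (access count, most recent access), highest first, which is the head-to-tail order the list should hold: equal counts are broken in favour of the most recently used key, so the tail is the least frequently, then least recently, used key.

```python
def descending_order(accesses):
    """Keys head -> tail as the descending list should hold them (no evictions)."""
    count, last = {}, {}
    for tick, key in enumerate(accesses):
        count[key] = count.get(key, 0) + 1  # reads and writes both count
        last[key] = tick                     # most recent access wins ties
    return sorted(count, key=lambda k: (count[k], last[k]), reverse=True)
```

For example, `descending_order([1, 2, 1])` returns `[1, 2]`: key 2 sits at the tail, which is why `set(3, 3)` evicts key 2 in the sanity check further down.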
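Here is the doubly linked list in Python. Each node is a list [prev, next, data], and the head is a sentinel node whose prev points to the last node and whose next points to the first: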
def dll_init():
    # A sentinel head node [prev, next, data] that links to itself
    head = []
    head[:] = [head, head, None]
    return head

def dll_append(head, value):
    # Insert a new node holding value just before the head, i.e. at the tail
    last = head[0]
    node = [last, head, value]
    last[1] = head[0] = node
    return node

def dll_remove(node):
    # Unlink the node from its neighbours; its own links are left untouched
    prev_link, next_link, _ = node
    prev_link[1] = next_link
    next_link[0] = prev_link
    return node

def dll_iter(root):
    curr = root[1]  # start at the first node
    while curr is not root:
        yield curr[2]  # yield the node's data
        curr = curr[1]  # move to the next node
Here are some basic unit test cases:
def unit_test(cls, update_func=None):
    print('Sanity check.')
    cache = cls(2, update_func)
    cache.set(1, 1)
    cache.set(2, 2)
    assert cache.get(1) == 1
    cache.set(3, 3)  # evicts key 2
    assert cache.get(2) == -1
    assert cache.get(3) == 3
    cache.set(4, 4)  # evicts key 1
    assert cache.get(1) == -1
    assert cache.get(3) == 3
    assert cache.get(4) == 4

    print('The set operation SHOULD increment the counter.')
    cache = cls(2, update_func)
    cache.set(3, 1)
    cache.set(2, 1)
    cache.set(2, 2)
    cache.set(4, 4)  # evicts key 3, which has the lower counter
    assert cache.get(2) == 2

    print('Append the node, and then sync the linked list.')
    cache = cls(3, update_func)
    for x in [1, 2, 3, 4]:
        cache.set(x, x)  # setting key 4 evicts key 1
    assert cache.get(4) == 4
    assert cache.get(3) == 3
    assert cache.get(2) == 2
    assert cache.get(1) == -1
    cache.set(5, 5)  # evicts key 4, the least recently used of the counter-2 keys
    assert cache.get(1) == -1
    assert cache.get(2) == 2
    assert cache.get(3) == 3
    assert cache.get(4) == -1
    assert cache.get(5) == 5
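The hand-written cases above cover specific sequences only. A randomized cross-check can replay the same stream of operations on two implementations and stop at the first `get` whose results differ. This is a sketch: `cross_check` and its parameters are hypothetical, and only `get` results are compared because `set` returns nothing.

```python
import random

def cross_check(make_a, make_b, n_ops=10000, key_space=20, seed=0):
    """Replay one random get/set stream on two caches; raise on divergence."""
    rng = random.Random(seed)  # fixed seed keeps failures reproducible
    a, b = make_a(), make_b()
    for step in range(n_ops):
        key = rng.randrange(key_space)
        if rng.random() < 0.5:
            ra, rb = a.get(key), b.get(key)
            if ra != rb:
                raise AssertionError('step %d: get(%r) -> %r vs %r'
                                     % (step, key, ra, rb))
        else:
            value = rng.randrange(1, 1000)  # values are always positive
            a.set(key, value)
            b.set(key, value)
    return n_ops
```

For instance, `cross_check(lambda: LFUCache(8, bubble_update), lambda: LFUCache(8, insert_update))` would compare the two update policies once both are defined; a small `key_space` relative to the capacity keeps evictions frequent.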
The first attempt
The first attempt appends each new node to the tail of the doubly linked list and, on every access, calls bubble_update to bubble the node up by repeatedly swapping it with its predecessor (LFUCache is written in a pluggable fashion so that different update policies can be tested).
def bubble_update(node, head):
    '''Update the access counter of the node, then bubble it up
    by swapping the node with its predecessor.'''
    # update access count for node
    prev, next_, (key, value, counter) = node
    counter += 1
    node[2] = (key, value, counter)
    # Keep the list in descending order of counter; moving past equal
    # counters puts the most recently used key first among its peers.
    while prev is not head and prev[2][2] <= counter:
        # swap prev and node: prev2 <-> prev <-> node <-> next_
        # becomes             prev2 <-> node <-> prev <-> next_
        prev2 = prev[0]
        prev2[1] = node
        node[0] = prev2
        node[1] = prev
        prev[0] = node
        prev[1] = next_
        next_[0] = prev
        # reset the variables for the next iteration
        next_ = prev
        prev = prev2
class LFUCache(object):
    def __init__(self, capacity, update_func):
        self.capacity = capacity
        self.update_freq = update_func  # pluggable update policy
        self.cache = dict()
        self.head = dll_init()

    def set(self, key, value):
        # special case for capacity <= 0
        if self.capacity <= 0:
            return
        # Does the key exist in the cache?
        node = self.cache.get(key)
        if node is not None:
            key, _, counter = node[2]
            node[2] = (key, value, counter)
            self.update_freq(node, self.head)
            return
        # Evict the least frequently (then least recently) used key when full
        if len(self.cache) >= self.capacity:
            # remove the last element
            node = dll_remove(self.head[0])
            self.cache.pop(node[2][0])
        # append to the end of the doubly linked list
        node = dll_append(self.head, (key, value, 0))
        self.cache[key] = node
        self.update_freq(node, self.head)

    def get(self, key):
        node = self.cache.get(key)
        if node is None:
            return -1
        self.update_freq(node, self.head)
        return node[2][1]

    def dump(self):
        print(list(dll_iter(self.head)))

unit_test(cls=LFUCache, update_func=bubble_update)
Sanity check.
The set operation SHOULD increment the counter.
Append the node, and then sync the linked list.
The second attempt
Swapping a node with its neighbour in a doubly linked list is expensive, because every swap rewrites six pointers. Insertion is much cheaper: walk towards the head to find the pivot, then unlink the node and re-insert it once:
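A rough operation count, read directly off the two functions, makes the difference concrete. Suppose the accessed node has to move past $k$ nodes whose counter is less than or equal to its own; each swap in bubble_update writes prev2[1], node[0], node[1], prev[0], prev[1] and next_[0], while insert_update writes 2 pointers in dll_remove and 4 more to re-link the node after the pivot.

```latex
\begin{aligned}
\texttt{bubble\_update}: &\quad O(k) \text{ counter comparisons}, \quad 6k \text{ pointer writes} \\
\texttt{insert\_update}: &\quad O(k) \text{ counter comparisons}, \quad 6 \text{ pointer writes } (2 + 4)
\end{aligned}
```

Both walks are still linear in $k$, so insertion only removes the per-step pointer shuffling; the search for the pivot itself remains $O(n)$ in the worst case.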
def insert_update(node, head):
    '''Update the access counter, then walk from the node towards
    the head to find the pivot, and insert the node AFTER the pivot.'''
    # update access count for node
    pivot, _, (key, value, counter) = node
    counter += 1
    node[2] = (key, value, counter)
    # skip every node whose counter is <= the updated counter
    while pivot is not head and pivot[2][2] <= counter:
        pivot = pivot[0]
    # Insert the node AFTER the pivot
    if pivot is not node[0]:
        # remove the node from the linked list
        dll_remove(node)
        # insert after the pivot
        node[1] = pivot[1]
        node[0] = pivot
        pivot[1][0] = node
        pivot[1] = node

unit_test(cls=LFUCache, update_func=insert_update)
Sanity check.
The set operation SHOULD increment the counter.
Append the node, and then sync the linked list.
Here is a JSON blob from one of LeetCode’s test cases, which we can use to benchmark our implementation:
%load_ext line_profiler
import json

def benchmark(cache):
    lut = {
        'set': cache.set,
        'get': cache.get
    }
    with open('lfu-cache-test-fixture.json') as f:
        fixture = json.load(f)
    # fixture is a pair [method_names, argument_lists]
    for method, args in zip(*fixture):
        lut[method](*args)

%lprun -f benchmark -f LFUCache.get -f LFUCache.set -f insert_update \
benchmark(LFUCache(2048, insert_update))
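The fixture file itself is not included here. Assuming it is a two-element list `[method_names, argument_lists]` (which is what `zip(*fixture)` in benchmark expects), a synthetic workload of the same shape can be generated as below. `make_fixture`, `write_fixture` and their parameters are hypothetical, and keys are drawn uniformly, which is gentler than a skewed real-world access pattern.

```python
import json
import random

def make_fixture(n_ops, key_space, seed=0, get_ratio=0.5):
    """Build a hypothetical [method_names, argument_lists] workload."""
    rng = random.Random(seed)  # deterministic for repeatable benchmarks
    methods, args = [], []
    for _ in range(n_ops):
        key = rng.randrange(key_space)
        if rng.random() < get_ratio:
            methods.append('get')
            args.append([key])
        else:
            methods.append('set')
            args.append([key, rng.randint(1, 10**6)])  # positive values only
    return [methods, args]

def write_fixture(path, fixture):
    # Serialize in the same JSON layout that benchmark() loads.
    with open(path, 'w') as f:
        json.dump(fixture, f)
```

Writing, say, `make_fixture(200000, 4096)` to a file and pointing benchmark at it gives a workload whose key space is larger than the 2048-entry cache, so evictions are exercised.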
It takes 8.45s to run the benchmark, and 99% of the CPU time is spent in the linear search for the pivot. We need a more efficient way to update the frequency.
The third attempt
The LFU paper presents a new approach:
- A doubly linked list is used as the frequency list.
- Each node of the frequency list points to another doubly linked list, the value list, which holds all the keys with that access frequency.
- A hashmap maps each key to its node, just as described above.
Thus, when a node is accessed, we first remove it from its value list, then find the frequency-list node for the next frequency, inserting a new one if necessary; finally, the node is added to the most recently used end of that node's value list.
To evict a node, we remove the least recently used node from the value list of the lowest-frequency node in the frequency list.
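For comparison, here is a compact sketch of the same idea using standard-library containers. It is not the author's implementation: `OrderedDictLFU` is hypothetical, it uses one `collections.OrderedDict` per frequency in place of each value list, and it tracks the minimum frequency in a plain integer instead of keeping a frequency list. That counter is what lets eviction avoid scanning the bucket keys, as `min(self.bucket_lut.keys())` in LFUCache2 does.

```python
from collections import OrderedDict, defaultdict

class OrderedDictLFU(object):
    """Sketch: one OrderedDict per frequency plus a tracked minimum frequency."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.entries = {}                        # key -> (value, freq)
        self.buckets = defaultdict(OrderedDict)  # freq -> keys, oldest first
        self.min_freq = 0

    def _bump(self, key):
        # Move key from bucket(freq) to the newest end of bucket(freq + 1).
        value, freq = self.entries[key]
        del self.buckets[freq][key]
        if not self.buckets[freq]:
            del self.buckets[freq]
            if self.min_freq == freq:
                self.min_freq = freq + 1
        self.buckets[freq + 1][key] = None
        self.entries[key] = (value, freq + 1)

    def get(self, key):
        if key not in self.entries:
            return -1
        self._bump(key)
        return self.entries[key][0]

    def set(self, key, value):
        if self.capacity <= 0:
            return
        if key in self.entries:
            self.entries[key] = (value, self.entries[key][1])
            self._bump(key)
            return
        if len(self.entries) >= self.capacity:
            # The oldest key in the lowest-frequency bucket is the victim.
            victim, _ = self.buckets[self.min_freq].popitem(last=False)
            if not self.buckets[self.min_freq]:
                del self.buckets[self.min_freq]
            del self.entries[victim]
        self.entries[key] = (value, 1)
        self.buckets[1][key] = None
        self.min_freq = 1  # a new key always has the lowest frequency
```

Every operation touches a constant number of dict entries, so get and set are O(1) on average; the trade-off is relying on OrderedDict's internal linked list rather than managing the pointers explicitly.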
from collections import defaultdict

class LFUCache2(object):
    def __init__(self, capacity, *args):
        self.capacity = capacity
        # counter -> bucket (doubly linked list of (key, value, counter))
        self.bucket_lut = defaultdict(dll_init)
        self.cache = dict()

    def remove_node(self, node):
        # Remove the node from the cache and also the bucket
        _, _, (key, value, counter) = node
        dll_remove(node)
        self.cache.pop(key)
        # drop the bucket if it is now empty
        bucket = self.bucket_lut[counter]
        if bucket[1] is bucket:
            self.bucket_lut.pop(counter)

    def add_node(self, data):
        # Create a node to host the data, add it to the cache and
        # append it to the bucket for its counter.
        bucket = self.bucket_lut[data[2]]
        node = dll_append(bucket, data)
        self.cache[data[0]] = node

    def set(self, key, value):
        # special case for capacity <= 0
        if self.capacity <= 0:
            return
        # Does the key exist in the cache?
        node = self.cache.get(key)
        if node is not None:
            # Update the value and counter
            counter = node[2][2]
            self.remove_node(node)
            self.add_node((key, value, counter + 1))
            return
        if len(self.cache) >= self.capacity:
            # Remove the least used, least recently accessed key.
            # min() scans every distinct counter, so this step is not O(1).
            min_counter = min(self.bucket_lut.keys())
            bucket = self.bucket_lut[min_counter]
            self.remove_node(bucket[1])
        self.add_node((key, value, 1))

    def get(self, key):
        node = self.cache.get(key)
        if node is None:
            return -1
        key, value, counter = node[2]
        self.remove_node(node)
        self.add_node((key, value, counter + 1))
        return value

unit_test(cls=LFUCache2)
Sanity check.
The set operation SHOULD increment the counter.
Append the node, and then sync the linked list.
%lprun -f benchmark -f LFUCache2.get -f LFUCache2.set \
-f LFUCache2.add_node -f LFUCache2.remove_node \
benchmark(LFUCache2(2048))
The new approach takes only 0.35s to run, compared with 8.45s for insert_update!
from collections import namedtuple

def dll_insert_before(successor, data):
    # Create a node to host data, and insert it BEFORE the successor
    node = [successor[0], successor, data]
    successor[0][1] = node
    successor[0] = node
    return node

def dll_insert_after(predecessor, data):
    # Create a node to host data, and insert it AFTER the predecessor
    node = [predecessor, predecessor[1], data]
    predecessor[1][0] = node
    predecessor[1] = node
    return node

# counter: access count shared by the bucket's items;
# head: sentinel of the doubly linked list of Item nodes
Bucket = namedtuple('Bucket', ['counter', 'head'])
# bucket_node: the node of the bucket list that holds this item's Bucket
Item = namedtuple('Item', ['key', 'value', 'bucket_node'])
class LFUCache3(object):
    def __init__(self, capacity, *args):
        self.capacity = capacity
        # doubly linked list of Bucket, in ascending order of counter
        self.bucket_head = dll_init()
        # data storage
        self.cache = dict()

    def remove_node(self, node):
        item = node[2]
        self.cache.pop(item.key)  # remove from cache
        dll_remove(node)  # remove node from the bucket
        bucket = item.bucket_node[2]
        # print('Remove %s from bucket(%d)' % (item.key, bucket.counter))
        if bucket.head[1] is bucket.head:
            # remove the bucket if empty; its own links stay intact, so
            # add_node can still reach its old successor
            # print('Remove bucket(%s)' % bucket.counter)
            dll_remove(item.bucket_node)

    def add_node(self, key, value, original_bucket_node):
        '''Add the (key, value) content pulled from original_bucket_node
        to the bucket with the next counter.'''
        counter = (0 if original_bucket_node is self.bucket_head
                   else original_bucket_node[2].counter)
        next_bucket_node = original_bucket_node[1]
        if next_bucket_node is self.bucket_head:
            # No bucket(counter + k) exists, append a new bucket(counter + 1)
            bucket = Bucket(counter + 1, dll_init())
            bucket_node = dll_append(self.bucket_head, bucket)
            # print('Append bucket(%s)' % (counter + 1))
        elif next_bucket_node[2].counter != counter + 1:
            # bucket(counter + k) exists, insert bucket(counter + 1) BEFORE
            # next_bucket_node
            bucket = Bucket(counter + 1, dll_init())
            bucket_node = dll_insert_before(next_bucket_node, bucket)
            # print('Insert bucket(%s)' % (counter + 1))
        else:
            # bucket(counter + 1) exists, use it
            bucket = next_bucket_node[2]
            bucket_node = next_bucket_node
        # Create the item, append it to the bucket and add to the cache.
        # print('Add %s to bucket(%s)' % (key, bucket.counter))
        item = Item(key, value, bucket_node)
        self.cache[key] = dll_append(bucket.head, item)

    def set(self, key, value):
        # special case for capacity <= 0
        if self.capacity <= 0:
            return
        # Does the key exist in the cache?
        node = self.cache.get(key)
        if node is not None:
            item = node[2]
            self.remove_node(node)
            self.add_node(item.key, value, item.bucket_node)
            return
        if len(self.cache) >= self.capacity:
            # Apply LRFU algorithm here!
            # For now, evict the oldest item of the lowest-counter bucket.
            bucket = self.bucket_head[1][2]
            self.remove_node(bucket.head[1])
        self.add_node(key, value, self.bucket_head)

    def get(self, key):
        node = self.cache.get(key)
        if node is None:
            return -1
        item = node[2]
        self.remove_node(node)
        self.add_node(item.key, item.value, item.bucket_node)
        return item.value

unit_test(cls=LFUCache3)
Sanity check.
The set operation SHOULD increment the counter.
Append the node, and then sync the linked list.
%lprun -f benchmark -f LFUCache3.get -f LFUCache3.set \
-f LFUCache3.add_node -f LFUCache3.remove_node \
benchmark(LFUCache3(2048))
LFUCache3 has slightly better performance in the synthetic benchmark, but performs worse
in LeetCode’s tests.
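Because line_profiler instruments every line it profiles, the totals reported by %lprun include that overhead. When only the overall runtime of the implementations matters, a plain timer on the same workload may give a fairer comparison. This sketch is hypothetical (`time_workload` is not from the original); it takes the best of several runs, each on a fresh cache.

```python
import time

def time_workload(factory, methods, args, repeat=3):
    """Best-of-`repeat` wall-clock seconds to replay a workload on fresh caches."""
    best = float('inf')
    for _ in range(repeat):
        cache = factory()  # a fresh cache per run keeps runs independent
        ops = {'set': cache.set, 'get': cache.get}
        start = time.perf_counter()
        for method, arg in zip(methods, args):
            ops[method](*arg)
        best = min(best, time.perf_counter() - start)
    return best
```

With a fixture loaded as in benchmark, `time_workload(lambda: LFUCache3(2048), *fixture)` and `time_workload(lambda: LFUCache2(2048), *fixture)` can be compared directly.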