LFU Cache

Design and implement a data structure for a Least Frequently Used (LFU) cache. It should support the following operations: get and set.

get(key) - Get the value (will always be positive) of the key if the key exists in the cache, otherwise return -1.

set(key, value) - Update the value if the key is present, or insert it if the key is not already present. When the cache reaches its capacity, it should invalidate the least frequently used item before inserting a new item. For the purpose of this problem, when there is a tie (i.e., two or more keys with the same frequency), the least recently used key is evicted.
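To pin down these semantics, including the LRU tie-break, here is a deliberately naive reference model. It is a sketch added for illustration, not one of the implementations below; the class name NaiveLFUCache and its (count, tick) bookkeeping are assumptions. Every key stores its value, its access count and the tick of its last access, and eviction simply picks the key with the smallest (count, tick) pair, which costs O(n) per eviction.

```python
import itertools


class NaiveLFUCache:
    """Hypothetical reference model of the spec (O(n) eviction).

    entries maps key -> [value, count, last_tick]. The victim is the key
    with the smallest (count, last_tick): least frequently used first,
    ties broken by least recent use.
    """

    def __init__(self, capacity, *args):
        self.capacity = capacity
        self.entries = {}
        self.clock = itertools.count()   # monotonically increasing access tick

    def _touch(self, entry):
        entry[1] += 1                    # one more access (read or write)
        entry[2] = next(self.clock)      # remember when it happened

    def get(self, key):
        entry = self.entries.get(key)
        if entry is None:
            return -1
        self._touch(entry)
        return entry[0]

    def set(self, key, value):
        if self.capacity <= 0:
            return
        entry = self.entries.get(key)
        if entry is not None:
            entry[0] = value             # an update also counts as an access
            self._touch(entry)
            return
        if len(self.entries) >= self.capacity:
            victim = min(self.entries,
                         key=lambda k: (self.entries[k][1], self.entries[k][2]))
            del self.entries[victim]
        self.entries[key] = [value, 1, next(self.clock)]


cache = NaiveLFUCache(2)
cache.set(1, 1)
cache.set(2, 2)
print(cache.get(1))   # 1, and key 1 now has count 2
cache.set(3, 3)       # evicts key 2, the only key with count 1
print(cache.get(2))   # -1
```

Being slow but obviously correct, a model like this is handy as an oracle when checking the faster versions.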

Solution

As with the LRU cache, the get operation also modifies the internal state. More concretely, we need to track the number of accesses (reads and writes) of each key, and keep the keys ordered by that count so that invalidating a key takes O(1) time.

A dict is chosen as the internal data store so that get and set can locate a key in O(1). The access frequencies are tracked in a doubly linked list. Without loss of generality, we keep the doubly linked list in descending order of access count: after a key is accessed (by get or set), its node is moved towards the head, so the tail always holds the key to evict.

Here is the doubly linked list in Python:

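# Each node is a list [prev, next, data]. The head is a sentinel node whose
# data is None: head[1] is the first node and head[0] is the last node.
def dll_init():
    head = []
    head[:] = [head, head, None]   # an empty list: the sentinel links to itself
    return head

def dll_append(head, value):
    # Insert a new node holding value between the last node and the sentinel.
    last = head[0]
    node = [last, head, value]
    last[1] = head[0] = node
    return node

def dll_remove(node):
    # Unlink node in O(1); its own prev/next pointers are left untouched.
    prev_link, next_link, _ = node
    prev_link[1] = next_link
    next_link[0] = prev_link
    return node

def dll_iter(root):
    curr = root[1]              # start at the first node
    while curr is not root:
        yield curr[2]           # yield the node's data
        curr = curr[1]          # move to next node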
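To make the node layout concrete, here is a short usage sketch. The four helpers are repeated so the snippet runs on its own, and the sample data ('a', 'b', 'c') is purely illustrative.

```python
def dll_init():
    head = []
    head[:] = [head, head, None]   # empty list: the sentinel links to itself
    return head

def dll_append(head, value):
    last = head[0]
    node = [last, head, value]     # [prev, next, data]
    last[1] = head[0] = node
    return node

def dll_remove(node):
    prev_link, next_link, _ = node
    prev_link[1] = next_link
    next_link[0] = prev_link
    return node

def dll_iter(root):
    curr = root[1]
    while curr is not root:
        yield curr[2]
        curr = curr[1]


head = dll_init()
a = dll_append(head, 'a')
b = dll_append(head, 'b')
c = dll_append(head, 'c')
print(list(dll_iter(head)))     # ['a', 'b', 'c']
print(head[0][2], head[1][2])   # c a  (tail via head[0], first node via head[1])
dll_remove(b)                   # O(1): only b's neighbours are relinked
print(list(dll_iter(head)))     # ['a', 'c']
```

The sentinel removes every "is the list empty?" special case: appending and removing always find a real neighbour on both sides.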

Add some basic unit test cases:

def unit_test(cls, update_func=None):
    print('Sanity check.')
    cache = cls(2, update_func)
    cache.set(1, 1)
    cache.set(2, 2)
    assert cache.get(1) == 1
    cache.set(3, 3)    # evicts key 2
    assert cache.get(2) == -1
    assert cache.get(3) == 3
    cache.set(4, 4)    # evicts key 1
    assert cache.get(1) == -1
    assert cache.get(3) == 3
    assert cache.get(4) == 4
    
    print('The set operation SHOULD increment the counter.')
    cache = cls(2, update_func)
    cache.set(3, 1)
    cache.set(2, 1)
    cache.set(2, 2)
    cache.set(4, 4)
    assert cache.get(2) == 2
    
    print('Append the node, and then sync the linked list.')
    cache = cls(3, update_func)
    for x in [1, 2, 3, 4]:
        cache.set(x, x)

    assert cache.get(4) == 4
    assert cache.get(3) == 3
    assert cache.get(2) == 2
    assert cache.get(1) == -1

    cache.set(5, 5)

    assert cache.get(1) == -1
    assert cache.get(2) == 2
    assert cache.get(3) == 3
    assert cache.get(4) == -1
    assert cache.get(5) == 5
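The hand-written cases above cover only a few sequences. A randomized differential check is a cheap complement: replay the same stream of set/get calls on two implementations and fail on the first get() where they disagree. This is a minimal sketch; differential_test and its parameters are hypothetical names, and it assumes each class can be built as cls(capacity). For the first-attempt LFUCache, which also needs an update function, pass a factory such as lambda cap: LFUCache(cap, insert_update).

```python
import random


def differential_test(make_a, make_b, capacity=4, key_space=10,
                      n_ops=5000, seed=0):
    """Hypothetical helper: drive two caches with the same random
    set/get stream and assert that every get() agrees."""
    rng = random.Random(seed)              # fixed seed: reproducible stream
    a, b = make_a(capacity), make_b(capacity)
    for step in range(n_ops):
        key = rng.randrange(key_space)     # small key space forces evictions
        if rng.random() < 0.5:
            value = step + 1               # keep values positive, as the spec says
            a.set(key, value)
            b.set(key, value)
        else:
            got_a, got_b = a.get(key), b.get(key)
            assert got_a == got_b, (
                'step %d: get(%r) -> %r vs %r' % (step, key, got_a, got_b))
```

For example, differential_test(LFUCache2, LFUCache3) exercises the two bucket-based classes against each other once they are defined.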

The first attempt

The first attempt appends each newly inserted node to the tail of the doubly linked list, then, on every access, calls bubble_update to bubble the node up by swapping it with its predecessor (LFUCache has been rewritten in a pluggable fashion so that different update policies can be tested).

def bubble_update(node, head):
    '''Update the access counter of the node, then bubble it up
    by swapping the node with its predecessor.'''
    # update access count for node
    prev, next_, (key, value, counter) = node
    counter += 1
    node[2] = (key, value, counter)
    # Move past every predecessor whose count is <= ours, so among equal
    # counts the most recently accessed node ends up closest to the head.
    while prev is not head and prev[2][2] <= counter:
        # swap prev and node: prev2 <-> node <-> prev <-> next_
        prev2 = prev[0]
        prev2[1] = node

        node[0] = prev2
        node[1] = prev

        prev[0] = node
        prev[1] = next_

        next_[0] = prev

        # reset the variables for the next iteration
        next_ = prev
        prev = prev2

class LFUCache(object):
    def __init__(self, capacity, update_func):
        self.capacity = capacity
        self.update_freq = update_func
        self.cache = dict()
        self.head = dll_init()
    
    def set(self, key, value):
        # special case for capacity <= 0
        if self.capacity <= 0:
            return
        
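        # Does the key exist in the cache?
        node = self.cache.get(key)
        if node is not None:
            # Update the value; the write also counts as an access.
            key, _, counter = node[2]
            node[2] = (key, value, counter)
            self.update_freq(node, self.head)
            return

        # Evict the LFU key if the cache is full: the list is kept in
        # descending order of count, so the tail is the least frequently
        # used key (the least recently used one among ties).
        if len(self.cache) >= self.capacity:
            node = dll_remove(self.head[0])
            self.cache.pop(node[2][0])

        # Append to the end of the doubly linked list with a zero count;
        # update_freq bumps it to 1 and moves it into place.
        node = dll_append(self.head, (key, value, 0))
        self.cache[key] = node
        self.update_freq(node, self.head)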
            
    def get(self, key):
        node = self.cache.get(key)
        if node is None:
            return -1
        self.update_freq(node, self.head)
        return node[2][1]
    
    def dump(self):
        print(list(dll_iter(self.head)))
        

unit_test(cls=LFUCache, update_func=bubble_update)
Sanity check.
The set operation SHOULD increment the counter.
Append the node, and then sync the linked list.

The second approach

Swapping a node with its neighbour in a doubly linked list is expensive, since every swap rewires six pointers. Insertion is much cheaper for a linked list: find the final position first, then unlink the node and relink it there once:

def insert_update(node, head):
    '''Update the access counter, then walk from the node towards the
    head to find the pivot, and re-insert the node right AFTER the pivot.'''
    # update access count for node
    pivot, _, (key, value, counter) = node
    counter += 1
    node[2] = (key, value, counter)
    # The pivot is the nearest predecessor with a strictly larger count
    # (or the head sentinel if there is none).
    while pivot is not head and pivot[2][2] <= counter:
        pivot = pivot[0]

    # Insert the node AFTER the pivot, unless it is already there
    if pivot is not node[0]:
        # remove the node from the linked list
        dll_remove(node)
        # insert after the pivot
        node[1] = pivot[1]
        node[0] = pivot
        pivot[1][0] = node
        pivot[1] = node

unit_test(cls=LFUCache, update_func=insert_update)
Sanity check.
The set operation SHOULD increment the counter.
Append the node, and then sync the linked list.

Here is a JSON fixture taken from one of LeetCode’s test cases, which we can use to benchmark our implementation with line_profiler:

%load_ext line_profiler

import json

def benchmark(cache):
    lut = {
        'set': cache.set,
        'get': cache.get
    }
    with open('lfu-cache-test-fixture.json') as f:
        fixture = json.load(f)
        for method, args in zip(*fixture):
            lut[method](*args)
%lprun -f benchmark -f LFUCache.get -f LFUCache.set -f insert_update \
    benchmark(LFUCache(2048, insert_update))
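The fixture file itself isn't reproduced here. From zip(*fixture) and the lookup table, benchmark() assumes a JSON array of two parallel lists: method names ('set' or 'get') and their argument lists. The sketch below builds a synthetic fixture of that shape with hypothetical helpers (make_fixture, save_fixture, time_replay); it is not LeetCode's data, so timings will differ from the numbers reported here. time_replay uses time.perf_counter for a quick wall-clock measurement without line_profiler.

```python
import json
import random
import time


def make_fixture(n_ops, key_space, set_ratio=0.5, seed=0):
    """Build a synthetic [methods, args] fixture (format inferred from
    benchmark()); deterministic for a given seed."""
    rng = random.Random(seed)
    methods, args = [], []
    for i in range(n_ops):
        key = rng.randrange(key_space)
        if rng.random() < set_ratio:
            methods.append('set')
            args.append([key, i + 1])      # values stay positive
        else:
            methods.append('get')
            args.append([key])
    return [methods, args]


def save_fixture(path, fixture):
    # Write the fixture as JSON so benchmark() can json.load() it.
    with open(path, 'w') as f:
        json.dump(fixture, f)


def time_replay(cache, fixture):
    """Replay a fixture on any object with set/get; return elapsed seconds."""
    lut = {'set': cache.set, 'get': cache.get}
    start = time.perf_counter()
    for method, method_args in zip(*fixture):
        lut[method](*method_args)
    return time.perf_counter() - start
```

For instance, save_fixture('lfu-cache-test-fixture.json', make_fixture(100000, 4096)) produces a file benchmark() can read, and time_replay(LFUCache(2048, insert_update), fixture) gives a rough timing for one implementation.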

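It takes 8.45s to run the benchmark, and 99% of the CPU time is spent in the linear search for the pivot: every predecessor whose count is less than or equal to the new count must be skipped, so when many keys share the same count a single access can walk past all of them, making each update O(n) in the worst case. We need a more efficient way to update the frequency.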

The third attempt

The LFU paper presents a different approach:

  1. A doubly linked list is used as the frequency list.
  2. Each node of the frequency list points to another doubly linked list, the value list, which holds all the keys with that access frequency.
  3. A hash map maps each key to its node, just as described above.

When a key is accessed, we first remove its node from its current value list, then find the frequency node for the incremented count (inserting a new frequency node if necessary), and finally add the node to that value list as its most recently used entry.

To evict, we remove the least recently used node from the value list of the lowest-frequency node in the frequency list.
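The same bucket idea can also be sketched with the standard library. This is an alternative, not the approach taken in the code that follows: collections.OrderedDict serves as each frequency's value list (its insertion order doubles as recency order), and a min_freq variable replaces any search for the lowest frequency. Tracking min_freq is enough because a newly inserted key always starts at frequency 1, and promoting a key out of the lowest bucket can only raise the minimum to that key's new frequency. The class name LFUCacheOD is hypothetical.

```python
from collections import OrderedDict, defaultdict


class LFUCacheOD(object):
    """Sketch: LFU cache with O(1) average get/set.

    key_freq: key -> access count
    buckets:  count -> OrderedDict(key -> value), oldest entry first
    min_freq: smallest count that currently has a non-empty bucket
    """

    def __init__(self, capacity, *args):
        self.capacity = capacity
        self.key_freq = {}
        self.buckets = defaultdict(OrderedDict)
        self.min_freq = 0

    def _promote(self, key, value):
        # Move key from bucket(count) to bucket(count + 1).
        freq = self.key_freq[key]
        del self.buckets[freq][key]
        if not self.buckets[freq]:
            del self.buckets[freq]
            if self.min_freq == freq:
                self.min_freq = freq + 1      # the key itself lands there
        self.key_freq[key] = freq + 1
        self.buckets[freq + 1][key] = value   # appended as most recent

    def get(self, key):
        if key not in self.key_freq:
            return -1
        value = self.buckets[self.key_freq[key]][key]
        self._promote(key, value)
        return value

    def set(self, key, value):
        if self.capacity <= 0:
            return
        if key in self.key_freq:
            self._promote(key, value)
            return
        if len(self.key_freq) >= self.capacity:
            # The oldest entry of the lowest-frequency bucket is the victim.
            victim, _ = self.buckets[self.min_freq].popitem(last=False)
            if not self.buckets[self.min_freq]:
                del self.buckets[self.min_freq]
            del self.key_freq[victim]
        self.key_freq[key] = 1
        self.buckets[1][key] = value
        self.min_freq = 1
```

Because popitem(last=False) removes the oldest entry of an OrderedDict, the least recently used key of the lowest bucket is evicted without any scan over frequencies.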

from collections import defaultdict

    
class LFUCache2(object):
    def __init__(self, capacity, *args):
        self.capacity = capacity
        self.bucket_lut = defaultdict(dll_init)
        self.cache = dict()
        
    def remove_node(self, node):
        # Remove the node from the cache and also the bucket
        _, _, (key, value, counter) = node
        dll_remove(node)
        self.cache.pop(key)
        # clean up the bucket if it is empty
        bucket = self.bucket_lut[counter]
        if bucket[1] is bucket:
            self.bucket_lut.pop(counter)
            
    def add_node(self, data):
        # Create a node to host the data, add it to cache and 
        # append to the bucket.
        bucket = self.bucket_lut[data[2]]
        node = dll_append(bucket, data)
        self.cache[data[0]] = node
        
    def set(self, key, value):
        # special case for capacity <= 0
        if self.capacity <= 0:
            return
        
        # Does the key exist in the cache?
        node = self.cache.get(key)
        if node:
            # Update the value and counter
            counter = node[2][2]
            self.remove_node(node)
            self.add_node((key, value, counter + 1))
            return
        
        if len(self.cache) >= self.capacity:
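            # Evict the least recently used key of the lowest-count bucket.
            # min() scans all counters, so this step is linear in the number
            # of buckets rather than O(1).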
            min_counter = min(self.bucket_lut.keys())
            bucket = self.bucket_lut[min_counter]
            self.remove_node(bucket[1])
            
        self.add_node((key, value, 1))
            
    def get(self, key):
        node = self.cache.get(key)
        if node is None:
            return -1
        
        key, value, counter = node[2]
        self.remove_node(node)
        self.add_node((key, value, counter + 1))
        return value


unit_test(cls=LFUCache2)
Sanity check.
The set operation SHOULD increment the counter.
Append the node, and then sync the linked list.
%lprun -f benchmark -f LFUCache2.get -f LFUCache2.set \
    -f LFUCache2.add_node -f LFUCache2.remove_node \
    benchmark(LFUCache2(2048))

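The new approach takes only 0.35s to run!

LFUCache2 still calls min() over all bucket counters on every eviction. LFUCache3 below follows the structure described above more closely: the buckets themselves form a doubly linked list kept in ascending order of counter, so the lowest-frequency bucket is always the first one, and a promoted key always moves to the neighbouring bucket, which is created if needed.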

from collections import namedtuple


def dll_insert_before(successor, data):
    # Create a node to host data, and insert it BEFORE the successor
    node = [successor[0], successor, data]
    successor[0][1] = node
    successor[0] = node
    return node


def dll_insert_after(predecessor, data):
    # Create a node to host data, and insert it AFTER the predecessor
    node = [predecessor, predecessor[1], data]
    predecessor[1][0] = node
    predecessor[1] = node
    return node


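# A Bucket holds one access counter and the sentinel (head) of a doubly
# linked list of item nodes that share that counter.
Bucket = namedtuple('Bucket', ['counter', 'head'])
# An Item records the key, its value, and the bucket-list node it lives in.
Item = namedtuple('Item', ['key', 'value', 'bucket_node'])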

class LFUCache3(object):
    def __init__(self, capacity, *args):
        self.capacity = capacity
        self.bucket_head = dll_init()
        # data storage
        self.cache = dict()
        
    def remove_node(self, node):
        item = node[2]
        self.cache.pop(item.key)   # remove from cache
        dll_remove(node)           # remove node from the bucket
        
        bucket = item.bucket_node[2]
        # print('Remove %s from bucket(%d)' % (item.key, bucket.counter))
        
        
        if bucket.head[1] is bucket.head:
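            # remove the bucket if empty
            # print('Remove bucket(%s)' % bucket.counter)
            # Note: dll_remove leaves the removed bucket node's own prev/next
            # pointers intact; add_node relies on that to find the next bucket.
            dll_remove(item.bucket_node)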
            
    def add_node(self, key, value, original_bucket_node):
        '''Add the (key, value) content pulled from original_bucket_node
        to the bucket with counter + 1, creating that bucket if needed.'''
        counter = 0 if original_bucket_node is self.bucket_head \
            else (original_bucket_node[2].counter)
        next_bucket_node = original_bucket_node[1]
        
        if next_bucket_node is self.bucket_head:
            # No bucket(counter + k) exists, append a new bucket(counter + 1)
            bucket = Bucket(counter + 1, dll_init())
            bucket_node = dll_append(self.bucket_head, bucket)
            # print('Append bucket(%s)' % (counter + 1))
        elif next_bucket_node[2].counter != counter + 1:
            # bucket(counter + k) exists for some k > 1: insert
            # bucket(counter + 1) BEFORE next_bucket_node
            bucket = Bucket(counter + 1, dll_init())
            bucket_node = dll_insert_before(next_bucket_node, bucket)
            # print('Insert bucket(%s)' % (counter + 1))
        else:
            # bucket(counter + 1) exists, use it
            bucket = next_bucket_node[2]
            bucket_node = next_bucket_node
            
        # Create the item, append it to the bucket and add to the cache.
        # print('Add %s to bucket(%s)' % (key, bucket.counter))
        item = Item(key, value, bucket_node)
        self.cache[key] = dll_append(bucket.head, item)


    def set(self, key, value):
        # special case for capacity <= 0
        if self.capacity <= 0:
            return
        
        # Does the key exist in the cache?
        node = self.cache.get(key)
        if node:
            item = node[2]
            self.remove_node(node)
            self.add_node(item.key, value, item.bucket_node)
            return
        
        if len(self.cache) >= self.capacity:
            # Evict the least recently used key of the lowest-count bucket,
            # which is always the first bucket in the list.
            bucket = self.bucket_head[1][2]
            self.remove_node(bucket.head[1])
            
        self.add_node(key, value, self.bucket_head)
            
    def get(self, key):
        node = self.cache.get(key)
        if node is None:
            return -1
        
        item = node[2]
        self.remove_node(node)
        self.add_node(item.key, item.value, item.bucket_node)
        return item.value


unit_test(cls=LFUCache3)
Sanity check.
The set operation SHOULD increment the counter.
Append the node, and then sync the linked list.
%lprun -f benchmark -f LFUCache3.get -f LFUCache3.set \
    -f LFUCache3.add_node -f LFUCache3.remove_node \
    benchmark(LFUCache3(2048))

LFUCache3 performs slightly better than LFUCache2 in the synthetic benchmark, but worse in LeetCode’s tests.