• 极客专栏正式上线!欢迎访问 https://www.jikewenku.com/topic.html
  • 极客专栏正式上线!欢迎访问 https://www.jikewenku.com/topic.html

动手实现一个 LRU cache

技术杂谈 勤劳的小蚂蚁 3个月前 (01-29) 44次浏览 已收录 0个评论 扫描二维码

前言

LRU 是 LeastRecentlyUsed 的简写,字面意思则是 最近最少使用
通常用于缓存的淘汰策略实现,由于缓存的内存非常宝贵,所以需要根据某种规则来剔除数据保证内存不被撑满。
如常用的 Redis 就有以下几种策略:
策略描述
volatile-lru从已设置过期时间的数据集中挑选最近最少使用的数据淘汰
volatile-ttl从已设置过期时间的数据集中挑选将要过期的数据淘汰
volatile-random从已设置过期时间的数据集中任意选择数据淘汰
allkeys-lru从所有数据集中挑选最近最少使用的数据淘汰
allkeys-random从所有数据集中任意选择数据进行淘汰
no-envicition禁止驱逐数据

实现一

之前也有接触过一道面试题,大概需求是:
  • 实现一个 LRU 缓存,当缓存数据达到 N 之后需要淘汰掉最近最少使用的数据。
  • N 小时之内没有被访问的数据也需要淘汰掉。
以下是我的实现:
  1. publicclassLRUAbstractMapextends java.util.AbstractMap{
  2.    privatefinalstaticLogger LOGGER =LoggerFactory.getLogger(LRUAbstractMap.class);
  3.    /**
  4.     * 检查是否超期线程
  5.     */
  6.    privateExecutorService checkTimePool ;
  7.    /**
  8.     * map 最大size
  9.     */
  10.    privatefinalstaticint MAX_SIZE =1024;
  11.    privatefinalstaticArrayBlockingQueue<Node> QUEUE =newArrayBlockingQueue<>(MAX_SIZE);
  12.    /**
  13.     * 默认大小
  14.     */
  15.    privatefinalstaticint DEFAULT_ARRAY_SIZE =1024;
  16.    /**
  17.     * 数组长度
  18.     */
  19.    privateint arraySize ;
  20.    /**
  21.     * 数组
  22.     */
  23.    privateObject[] arrays ;
  24.    /**
  25.     * 判断是否停止 flag
  26.     */
  27.    privatevolatileboolean flag =true;
  28.    /**
  29.     * 超时时间
  30.     */
  31.    privatefinalstaticLong EXPIRE_TIME =60*60*1000L;
  32.    /**
  33.     * 整个 Map 的大小
  34.     */
  35.    privatevolatileAtomicInteger size  ;
  36.    publicLRUAbstractMap(){
  37.        arraySize = DEFAULT_ARRAY_SIZE;
  38.        arrays =newObject[arraySize];
  39.        //开启一个线程检查最先放入队列的值是否超期
  40.        executeCheckTime();
  41.    }
  42.    /**
  43.     * 开启一个线程检查最先放入队列的值是否超期 设置为守护线程
  44.     */
  45.    privatevoid executeCheckTime(){
  46.        ThreadFactory namedThreadFactory =newThreadFactoryBuilder()
  47.                .setNameFormat("check-thread-%d")
  48.                .setDaemon(true)
  49.                .build();
  50.        checkTimePool =newThreadPoolExecutor(1,1,0L,TimeUnit.MILLISECONDS,
  51.                newArrayBlockingQueue<>(1),namedThreadFactory,newThreadPoolExecutor.AbortPolicy());
  52.        checkTimePool.execute(newCheckTimeThread());
  53.    }
  54.    @Override
  55.    publicSet<Entry> entrySet(){
  56.        returnsuper.keySet();
  57.    }
  58.    @Override
  59.    publicObject put(Object key,Object value){
  60.        int hash = hash(key);
  61.        int index = hash % arraySize ;
  62.        Node currentNode =(Node) arrays[index];
  63.        if(currentNode ==null){
  64.            arrays[index]=newNode(null,null, key, value);
  65.            //写入队列
  66.            QUEUE.offer((Node) arrays[index]);
  67.            sizeUp();
  68.        }else{
  69.            Node cNode = currentNode ;
  70.            Node nNode = cNode ;
  71.            //存在就覆盖
  72.            if(nNode.key == key){
  73.                cNode.val = value ;
  74.            }
  75.            while(nNode.next !=null){
  76.                //key 存在 就覆盖 简单判断
  77.                if(nNode.key == key){
  78.                    nNode.val = value ;
  79.                    break;
  80.                }else{
  81.                    //不存在就新增链表
  82.                    sizeUp();
  83.                    Node node =newNode(nNode,null,key,value);
  84.                    //写入队列
  85.                    QUEUE.offer(currentNode);
  86.                    cNode.next = node ;
  87.                }
  88.                nNode = nNode.next ;
  89.            }
  90.        }
  91.        returnnull;
  92.    }
  93.    @Override
  94.    publicObject get(Object key){
  95.        int hash = hash(key);
  96.        int index = hash % arraySize ;
  97.        Node currentNode =(Node) arrays[index];
  98.        if(currentNode ==null){
  99.            returnnull;
  100.        }
  101.        if(currentNode.next ==null){
  102.            //更新时间
  103.            currentNode.setUpdateTime(System.currentTimeMillis());
  104.            //没有冲突
  105.            return currentNode ;
  106.        }
  107.        Node nNode = currentNode ;
  108.        while(nNode.next !=null){
  109.            if(nNode.key == key){
  110.                //更新时间
  111.                currentNode.setUpdateTime(System.currentTimeMillis());
  112.                return nNode ;
  113.            }
  114.            nNode = nNode.next ;
  115.        }
  116.        returnsuper.get(key);
  117.    }
  118.    @Override
  119.    publicObject remove(Object key){
  120.        int hash = hash(key);
  121.        int index = hash % arraySize ;
  122.        Node currentNode =(Node) arrays[index];
  123.        if(currentNode ==null){
  124.            returnnull;
  125.        }
  126.        if(currentNode.key == key){
  127.            sizeDown();
  128.            arrays[index]=null;
  129.            //移除队列
  130.            QUEUE.poll();
  131.            return currentNode ;
  132.        }
  133.        Node nNode = currentNode ;
  134.        while(nNode.next !=null){
  135.            if(nNode.key == key){
  136.                sizeDown();
  137.                //在链表中找到了 把上一个节点的 next 指向当前节点的下一个节点
  138.                nNode.pre.next = nNode.next ;
  139.                nNode =null;
  140.                //移除队列
  141.                QUEUE.poll();
  142.                return nNode;
  143.            }
  144.            nNode = nNode.next ;
  145.        }
  146.        returnsuper.remove(key);
  147.    }
  148.    /**
  149.     * 增加size
  150.     */
  151.    privatevoid sizeUp(){
  152.        //在put值时候认为里边已经有数据了
  153.        flag =true;
  154.        if(size ==null){
  155.            size =newAtomicInteger();
  156.        }
  157.        int size =this.size.incrementAndGet();
  158.        if(size >= MAX_SIZE){
  159.            //找到队列头的数据
  160.            Node node = QUEUE.poll();
  161.            if(node ==null){
  162.                thrownewRuntimeException("data error");
  163.            }
  164.            //移除该 key
  165.            Object key = node.key ;
  166.            remove(key);
  167.            lruCallback();
  168.        }
  169.    }
  170.    /**
  171.     * 数量减小
  172.     */
  173.    privatevoid sizeDown(){
  174.        if(QUEUE.size()==0){
  175.            flag =false;
  176.        }
  177.        this.size.decrementAndGet();
  178.    }
  179.    @Override
  180.    publicint size(){
  181.        return size.get();
  182.    }
  183.    /**
  184.     * 链表
  185.     */
  186.    privateclassNode{
  187.        privateNode next ;
  188.        privateNode pre ;
  189.        privateObject key ;
  190.        privateObject val ;
  191.        privateLong updateTime ;
  192.        publicNode(Node pre,Node next,Object key,Object val){
  193.            this.pre = pre ;
  194.            this.next = next;
  195.            this.key = key;
  196.            this.val = val;
  197.            this.updateTime =System.currentTimeMillis();
  198.        }
  199.        publicvoid setUpdateTime(Long updateTime){
  200.            this.updateTime = updateTime;
  201.        }
  202.        publicLong getUpdateTime(){
  203.            return updateTime;
  204.        }
  205.        @Override
  206.        publicString toString(){
  207.            return"Node{"+
  208.                    "key="+ key +
  209.                    ", val="+ val +
  210.                    '}';
  211.        }
  212.    }
  213.    /**
  214.     * copy HashMap 的 hash 实现
  215.     * @param key
  216.     * @return
  217.     */
  218.    publicint hash(Object key){
  219.        int h;
  220.        return(key ==null)?0:(h = key.hashCode())^(h >>>16);
  221.    }
  222.    privatevoid lruCallback(){
  223.        LOGGER.debug("lruCallback");
  224.    }
  225.    privateclassCheckTimeThreadimplementsRunnable{
  226.        @Override
  227.        publicvoid run(){
  228.            while(flag){
  229.                try{
  230.                    Node node = QUEUE.poll();
  231.                    if(node ==null){
  232.                        continue;
  233.                    }
  234.                    Long updateTime = node.getUpdateTime();
  235.                    if((updateTime -System.currentTimeMillis())>= EXPIRE_TIME){
  236.                        remove(node.key);
  237.                    }
  238.                }catch(Exception e){
  239.                    LOGGER.error("InterruptedException");
  240.                }
  241.            }
  242.        }
  243.    }
  244. }
感兴趣的朋友可以直接从:
下载代码本地运行。
代码看着比较多,其实实现的思路还是比较简单:
  • 采用了与 HashMap 一样的保存数据方式,只是自己手动实现了一个简易版。
  • 内部采用了一个队列来保存每次写入的数据。
  • 写入的时候判断缓存是否大于了阈值 N,如果满足则根据队列的 FIFO 特性将队列头的数据删除。因为队列头的数据肯定是最先放进去的。
  • 再开启了一个守护线程用于判断最先放进去的数据是否超期(因为就算超期也是最先放进去的数据最有可能满足超期条件。)
  • 设置为守护线程可以更好的表明其目的(最坏的情况下,如果是一个用户线程最终有可能导致程序不能正常退出,因为该线程一直在运行,守护线程则不会有这个情况。)
以上代码大体功能满足了,但是有一个致命问题。
就是最近最少使用没有满足,删除的数据都是最先放入的数据。
不过其中的 putget 流程算是一个简易的 HashMap 实现,可以对 HashMap 加深一些理解。

实现二

因此如何来实现一个完整的 LRU 缓存呢,这次不考虑过期时间的问题。
其实从上一个实现也能想到一些思路:
  • 要记录最近最少使用,那至少需要一个有序的集合来保证写入的顺序。
  • 在使用了数据之后能够更新它的顺序。
基于以上两点很容易想到一个常用的数据结构:链表
  1. 每次写入数据时将数据放入链表头结点。
  2. 使用数据时候将数据移动到头结点
  3. 缓存数量超过阈值时移除链表尾部数据。
因此有了以下实现:
  1. publicclassLRUMap<K, V>{
  2.    privatefinalMap<K, V> cacheMap =newHashMap<>();
  3.    /**
  4.     * 最大缓存大小
  5.     */
  6.    privateint cacheSize;
  7.    /**
  8.     * 节点大小
  9.     */
  10.    privateint nodeCount;
  11.    /**
  12.     * 头结点
  13.     */
  14.    privateNode<K, V> header;
  15.    /**
  16.     * 尾结点
  17.     */
  18.    privateNode<K, V> tailer;
  19.    publicLRUMap(int cacheSize){
  20.        this.cacheSize = cacheSize;
  21.        //头结点的下一个结点为空
  22.        header =newNode<>();
  23.        header.next =null;
  24.        //尾结点的上一个结点为空
  25.        tailer =newNode<>();
  26.        tailer.tail =null;
  27.        //双向链表 头结点的上结点指向尾结点
  28.        header.tail = tailer;
  29.        //尾结点的下结点指向头结点
  30.        tailer.next = header;
  31.    }
  32.    publicvoid put(K key, V value){
  33.        cacheMap.put(key, value);
  34.        //双向链表中添加结点
  35.        addNode(key, value);
  36.    }
  37.    public V get(K key){
  38.        Node<K, V> node = getNode(key);
  39.        //移动到头结点
  40.        moveToHead(node);
  41.        return cacheMap.get(key);
  42.    }
  43.    privatevoid moveToHead(Node<K,V> node){
  44.        //如果是最后的一个节点
  45.        if(node.tail ==null){
  46.            node.next.tail =null;
  47.            tailer = node.next ;
  48.            nodeCount --;
  49.        }
  50.        //如果是本来就是头节点 不作处理
  51.        if(node.next ==null){
  52.            return;
  53.        }
  54.        //如果处于中间节点
  55.        if(node.tail !=null&& node.next !=null){
  56.            //它的上一节点指向它的下一节点 也就删除当前节点
  57.            node.tail.next = node.next ;
  58.            nodeCount --;
  59.        }
  60.        //最后在头部增加当前节点
  61.        //注意这里需要重新 new 一个对象,不然原本的node 还有着下面的引用,会造成内存溢出。
  62.        node =newNode<>(node.getKey(),node.getValue());
  63.        addHead(node);
  64.    }
  65.    /**
  66.     * 链表查询 效率较低
  67.     * @param key
  68.     * @return
  69.     */
  70.    privateNode<K,V> getNode(K key){
  71.        Node<K,V> node = tailer ;
  72.        while(node !=null){
  73.            if(node.getKey().equals(key)){
  74.                return node ;
  75.            }
  76.            node = node.next ;
  77.        }
  78.        returnnull;
  79.    }
  80.    /**
  81.     * 写入头结点
  82.     * @param key
  83.     * @param value
  84.     */
  85.    privatevoid addNode(K key, V value){
  86.        Node<K, V> node =newNode<>(key, value);
  87.        //容量满了删除最后一个
  88.        if(cacheSize == nodeCount){
  89.            //删除尾结点
  90.            delTail();
  91.        }
  92.        //写入头结点
  93.        addHead(node);
  94.    }
  95.    /**
  96.     * 添加头结点
  97.     *
  98.     * @param node
  99.     */
  100.    privatevoid addHead(Node<K, V> node){
  101.        //写入头结点
  102.        header.next = node;
  103.        node.tail = header;
  104.        header = node;
  105.        nodeCount++;
  106.        //如果写入的数据大于2个 就将初始化的头尾结点删除
  107.        if(nodeCount ==2){
  108.            tailer.next.next.tail =null;
  109.            tailer = tailer.next.next;
  110.        }
  111.    }    
  112.    privatevoid delTail(){
  113.        //把尾结点从缓存中删除
  114.        cacheMap.remove(tailer.getKey());
  115.        //删除尾结点
  116.        tailer.next.tail =null;
  117.        tailer = tailer.next;
  118.        nodeCount--;
  119.    }
  120.    privateclassNode<K, V>{
  121.        private K key;
  122.        private V value;
  123.        Node<K, V> tail;
  124.        Node<K, V> next;
  125.        publicNode(K key, V value){
  126.            this.key = key;
  127.            this.value = value;
  128.        }
  129.        publicNode(){
  130.        }
  131.        public K getKey(){
  132.            return key;
  133.        }
  134.        publicvoid setKey(K key){
  135.            this.key = key;
  136.        }
  137.        public V getValue(){
  138.            return value;
  139.        }
  140.        publicvoid setValue(V value){
  141.            this.value = value;
  142.        }
  143.    }
  144.    @Override
  145.    publicString toString(){
  146.        StringBuilder sb =newStringBuilder();
  147.        Node<K,V> node = tailer ;
  148.        while(node !=null){
  149.            sb.append(node.getKey()).append(":")
  150.                    .append(node.getValue())
  151.                    .append("-->");
  152.            node = node.next ;
  153.        }
  154.        return sb.toString();
  155.    }
  156. }
源码: https://github.com/crossoverJie/Java-Interview/blob/master/src/main/java/com/crossoverjie/actual/LRUMap.java
实际效果,写入时:
  1.    @Test
  2.    publicvoid put()throwsException{
  3.        LRUMap<String,Integer> lruMap =newLRUMap(3);
  4.        lruMap.put("1",1);
  5.        lruMap.put("2",2);
  6.        lruMap.put("3",3);
  7.        System.out.println(lruMap.toString());
  8.        lruMap.put("4",4);
  9.        System.out.println(lruMap.toString());
  10.        lruMap.put("5",5);
  11.        System.out.println(lruMap.toString());
  12.    }
  13. //输出:
  14. 1:1-->2:2-->3:3-->
  15. 2:2-->3:3-->4:4-->
  16. 3:3-->4:4-->5:5-->
使用时:
  1.    @Test
  2.    publicvoid get()throwsException{
  3.        LRUMap<String,Integer> lruMap =newLRUMap(3);
  4.        lruMap.put("1",1);