• 极客专栏正式上线!欢迎访问 https://www.jikewenku.com/topic.html
  • 极客专栏正式上线!欢迎访问 https://www.jikewenku.com/topic.html

ThreadLocal的使用及原理分析

技术杂谈 勤劳的小蚂蚁 3个月前 (02-01) 62次浏览 已收录 0个评论 扫描二维码

 

文章简介

 

 

ThreadLocal应该都比较熟悉,这篇文章会基于ThreadLocal的应用以及实现原理做一个全面的分析

 

 

内容导航

 

 

  • 什么是ThreadLocal
  • ThreadLocal的使用
  • 分析ThreadLocal的实现原理
  • ThreadLocal的应用场景及问题

 

01

 

什么是ThreadLocal

 

ThreadLocal,简单翻译过来就是本地线程,但是直接这么翻译很难理解ThreadLocal的作用,如果换一种说法,可以称为线程本地存储。简单来说,就是ThreadLocal为共享变量在每个线程中都创建一个副本,每个线程可以访问自己内部的副本变量。这样做的好处是可以保证共享变量在多线程环境下访问的线程安全性

 

02

 

ThreadLocal的使用演示

 

ThreadLocal的使用

没有使用ThreadLocal时

通过一个简单的例子来演示一下ThreadLocal的作用,这段代码是定义了一个静态的成员变量 num,然后通过构造5个线程对这个 num做递增
  1. publicclassThreadLocalDemo{
  2.    privatestaticInteger num=0;
  3.    publicstaticvoid main(String[] args){
  4.        Thread[] threads=newThread[5];
  5.        for(int i=0;i<5;i++){
  6.            threads[i]=newThread(()->{
  7.               num+=5;
  8.               System.out.println(Thread.currentThread().getName()+" : "+num);
  9.            },"Thread-"+i);
  10.        }
  11.        for(Thread thread:threads){
  12.            thread.start();
  13.        }
  14.    }
  15. }
运行结果
  1. Thread-0:5
  2. Thread-1:10
  3. Thread-2:15
  4. Thread-3:20
  5. Thread-4:25
每个线程都会对这个成员变量做递增,如果线程的执行顺序不确定,那么意味着每个线程获得的结果也是不一样的。

使用了ThreadLocal以后

通过ThreadLocal对上面的代码做一个改动
  1. publicclassThreadLocalDemo{
  2.    privatestaticfinalThreadLocal<Integer> local=newThreadLocal<Integer>(){
  3.        protectedInteger initialValue(){
  4.            return0;//通过initialValue方法设置默认值
  5.        }
  6.    };
  7.    publicstaticvoid main(String[] args){
  8.        Thread[] threads=newThread[5];
  9.        for(int i=0;i<5;i++){
  10.            threads[i]=newThread(()->{
  11.                int num=local.get().intValue();
  12.                num+=5;
  13.               System.out.println(Thread.currentThread().getName()+" : "+num);
  14.            },"Thread-"+i);
  15.        }
  16.        for(Thread thread:threads){
  17.            thread.start();
  18.        }
  19.    }
  20. }
运行结果
  1. Thread-0:5
  2. Thread-4:5
  3. Thread-2:5
  4. Thread-1:5
  5. Thread-3:5
从结果可以看到,每个线程的值都是5,意味着各个线程都是从ThreadLocal的 initialValue方法中拿到默认值0并且做了 num+=5的操作,同时也意味着每个线程从ThreadLocal中拿到的值都是0,这样使得各个线程对于共享变量num来说,是完全隔离彼此不相互影响.
ThreadLocal会给定一个初始值,也就是 initialValue()方法,而每个线程都会从ThreadLocal中获得这个初始化的值的副本,这样可以使得每个线程都拥有一个副本拷贝

 

03

 

从源码分析ThreadLocal的实现

 

看到这里,估计有很多人都会和我一样有一些疑问
  1. 每个线程的变量副本是怎么存储的?
  2. ThreadLocal是如何实现多线程场景下的共享变量副本隔离?
带着疑问,来看一下ThreadLocal这个类的定义(默认情况下,JDK的源码都是基于1.8版本)
从ThreadLocal的方法定义来看,还是挺简单的。就几个方法
  • get: 获取ThreadLocal中当前线程对应的线程局部变量
  • set:设置当前线程的线程局部变量的值
  • remove:将当前线程局部变量的值删除
另外,还有一个initialValue()方法,在前面的代码中有演示,作用是返回当前线程局部变量的初始值,这个方法是一个 protected方法,主要是在构造ThreadLocal时用于设置默认的初始值

set方法的实现

set方法是设置一个线程的局部变量的值,相当于当前线程通过set设置的局部变量的值,只对当前线程可见。
  1.    publicvoid set(T value){
  2.        Thread t =Thread.currentThread();//获取当前执行的线程
  3.        ThreadLocalMap map = getMap(t);//获得当前线程的ThreadLocalMap实例
  4.        if(map !=null)//如果map不为空,说明当前线程已经有了一个ThreadLocalMap实例
  5.            map.set(this, value);//直接将当前value设置到ThreadLocalMap中
  6.        else
  7.            createMap(t, value);//说明当前线程是第一次使用线程本地变量,构造map
  8.    }
  • Thread.currentThread 获取当前执行的线程
  • getMap(t) ,根据当前线程得到当前线程的ThreadLocalMap对象,这个对象具体是做什么的?稍后分析
  • 如果map不为空,说明当前线程已经构造过ThreadLocalMap,直接将值存储到map中
  • 如果map为空,说明是第一次使用,调用 createMap构造

ThreadLocalMap是什么?

我们来分析一下这句话, ThreadLocalMapmap=getMap(t)获得一个ThreadLocalMap对象,那这个对象是干嘛的呢? 其实不用分析,基本上也能猜测出来,Map是一个集合,集合用来存储数据,那么在ThreadLocal中,应该就是用来存储线程的局部变量的。 ThreadLocalMap这个类很关键。
  1.    ThreadLocalMap getMap(Thread t){
  2.        return t.threadLocals;
  3.    }
t.threadLocals实际上就是访问Thread类中的ThreadLocalMap这个成员变量
  1. public
  2. classThreadimplementsRunnable{
  3. /* ThreadLocal values pertaining to this thread. This map is maintained
  4.     * by the ThreadLocal class. */
  5.    ThreadLocal.ThreadLocalMap threadLocals =null;
  6. ...
  7. }
从上面的代码发现每一个线程都有自己单独的ThreadLocalMap实例,而对应这个线程的所有本地变量都会保存到这个map内

ThreadLocalMap是在哪里构造?

set方法中,有一行代码 createmap(t,value);,这个方法就是用来构造ThreadLocalMap,从传入的参数来看,它的实现逻辑基本也能猜出出几分吧
  1.    void createMap(Thread t, T firstValue){
  2.        t.threadLocals =newThreadLocalMap(this, firstValue);
  3.    }
Threadt 是通过 Thread.currentThread()来获取的表示当前线程,然后直接通过 newThreadLocalMap将当前线程中的 threadLocals做了初始化 ThreadLocalMap是一个静态内部类,内部定义了一个Entry对象用来真正存储数据
  1. staticclassThreadLocalMap{
  2.        staticclassEntryextendsWeakReference<ThreadLocal<?>>{
  3.            /** The value associated with this ThreadLocal. */
  4.            Object value;
  5.            Entry(ThreadLocal<?> k,Object v){
  6.                super(k);
  7.                value = v;
  8.            }
  9.        }
  10.        ThreadLocalMap(ThreadLocal<?> firstKey,Object firstValue){
  11.            //构造一个Entry数组,并设置初始大小
  12.            table =newEntry[INITIAL_CAPACITY];
  13.            //计算Entry数据下标
  14.            int i = firstKey.threadLocalHashCode &(INITIAL_CAPACITY -1);
  15.            //将`firstValue`存入到指定的table下标中
  16.            table[i]=newEntry(firstKey, firstValue);
  17.            size =1;//设置节点长度为1
  18.            setThreshold(INITIAL_CAPACITY);//设置扩容的阈值
  19.        }
  20.    //...省略部分代码
  21. }
分析到这里,基本知道了ThreadLocalMap长啥样了,也知道它是如何构造的?那么我看到这里的时候仍然有疑问
  • Entry集成了 WeakReference,这个表示什么意思?
  • 在构造ThreadLocalMap的时候 newThreadLocalMap(this,firstValue);,key其实是this,this表示当前对象的引用,在当前的案例中,this指的是 ThreadLocal<Integer>local。那么多个线程对应同一个ThreadLocal实例,怎么对每一个ThreadLocal对象做区分呢?

解惑WeakReference

weakReference表示弱引用,在Java中有四种引用类型,强引用、弱引用、软引用、虚引用。 使用弱引用的对象,不会阻止它所指向的对象被垃圾回收器回收。
在Java语言中, 当一个对象o被创建时, 它被放在Heap里. 当GC运行的时候, 如果发现没有任何引用指向o, o就会被回收以腾出内存空间. 也就是说, 一个对象被回收, 必须满足两个条件:
  • 没有任何引用指向它
  • GC被运行.
这段代码中,构造了两个对象a,b,a是对象DemoA的引用,b是对象DemoB的引用,对象DemoB同时还依赖对象DemoA,那么这个时候我们认为从对象DemoB是可以到达对象DemoA的。这种称为强可达(strongly reachable)
  1. DemoA a=newDemoA();
  2. DemoB b=newDemoB(a);
如果我们增加一行代码来将a对象的引用设置为null,当一个对象不再被其他对象引用的时候,是会被GC回收的,但是对于这个场景来说,即时是a=null,也不可能被回收,因为DemoB依赖DemoA,这个时候是可能造成内存泄漏的
  1. DemoA a=newDemoA();
  2. DemoB b=newDemoB(a);
  3. a=null;
通过弱引用,有两个方法可以避免这样的问题
  1. //方法1
  2. DemoA a=newDemoA();
  3. DemoB b=newDemoB(a);
  4. a=null;
  5. b=null;
  6. //方法2
  7. DemoA a=newDemoA();
  8. WeakReference b=newWeakReference(a);
  9. a=null;
对于方法2来说,DemoA只是被弱引用依赖,假设垃圾收集器在某个时间点决定一个对象是弱可达的(weakly reachable)(也就是说当前指向它的全都是弱引用),这时垃圾收集器会清除所有指向该对象的弱引用,然后把这个弱可达对象标记为可终结(finalizable)的,这样它随后就会被回收。
试想一下如果这里没有使用弱引用,意味着ThreadLocal的生命周期和线程是强绑定,只要线程没有销毁,那么ThreadLocal一直无法回收。而使用弱引用以后,当ThreadLocal被回收时,由于Entry的key是弱引用,不会影响ThreadLocal的回收防止内存泄漏,同时,在后续的源码分析中会看到,ThreadLocalMap本身的垃圾清理会用到这一个好处,方便对无效的Entry进行回收

解惑ThreadLocalMap以this作为key

在构造ThreadLocalMap时,使用this作为key来存储,那么对于同一个ThreadLocal对象,如果同一个Thread中存储了多个值,是如何来区分存储的呢? 答案就在 firstKey.threadLocalHashCode&(INITIAL_CAPACITY-1)
  1. void createMap(Thread t, T firstValue){
  2.        t.threadLocals =newThreadLocalMap(this, firstValue);
  3. }
  4. ThreadLocalMap(ThreadLocal<?> firstKey,Object firstValue){
  5.            table =newEntry[INITIAL_CAPACITY];
  6.            int i = firstKey.threadLocalHashCode &(INITIAL_CAPACITY -1);
  7.            table[i]=newEntry(firstKey, firstValue);
  8.            size =1;
  9.            setThreshold(INITIAL_CAPACITY);
  10. }
关键点是 threadLocalHashCode,它相当于一个ThreadLocal的ID,实现的逻辑如下
  1. privatefinalint threadLocalHashCode = nextHashCode();
  2. privatestaticAtomicInteger nextHashCode =
  3.        newAtomicInteger();
  4. privatestaticfinalint HASH_INCREMENT =0x61c88647;
  5. privatestaticint nextHashCode(){
  6.    return nextHashCode.getAndAdd(HASH_INCREMENT);
  7. }
这里用到了一个非常完美的散列算法,可以简单理解为,对于同一个ThreadLocal下的多个线程来说,当任意线程调用set方法存入一个数据到Entry中的时候,其实会根据 threadLocalHashCode生成一个唯一的id标识对应这个数据,存储在Entry数据下标中。
  • threadLocalHashCode是通过
    nextHashCode.getAndAdd(HASH_INCREMENT)来实现的 
    i*HASH_INCREMENT+HASH_INCREMENT,每次新增一个元素(ThreadLocal)到Entry[],都会自增0x61c88647,目的为了让哈希码能均匀的分布在2的N次方的数组里
  • Entry[i]= hashCode & (length-1)

魔数0x61c88647

从上面的分析可以看出,它是在上一个被构造出的ThreadLocal的threadLocalHashCode的基础上加上一个魔数0x61c88647。我们来做一个实验,看看这个散列算法的运算结果
  1.    privatestaticfinalint HASH_INCREMENT =0x61c88647;
  2.    publicstaticvoid main(String[] args){
  3.        magicHash(16);
  4.        magicHash(32);
  5.    }
  6.    privatestaticvoid magicHash(int size){
  7.        int hashCode =0;
  8.        for(int i=0;i<size;i++){
  9.            hashCode = i*HASH_INCREMENT+HASH_INCREMENT;
  10.            System.out.print((hashCode &(size-1))+" ");
  11.        }
  12.        System.out.println();
  13.    }
输出结果
  1. 7145123101815613411290
  2. 714212831017243161320272916233051219261815222941118250
根据运行结果,这个算法在长度为2的N次方的数组上,确实可以完美散列,没有任何冲突, 是不是很神奇。
魔数0x61c88647的选取和斐波那契散列有关,0x61c88647对应的十进制为1640531527。而斐波那契散列的乘数可以用 (long)((1L<<31)*(Math.sqrt(5)-1)); 如果把这个值给转为带符号的int,则会得到-1640531527。也就是说(long)((1L<<31)*(Math.sqrt(5)-1));得到的结果就是1640531527,也就是魔数0x61c88647
  1. //(根号5-1)*2的31次方=(根号5-1)/2 *2的32次方=黄金分割数*2的32次方
  2. long l1 =(long)((1L<<31)*(Math.sqrt(5)-1));
  3. System.out.println("32位无符号整数: "+ l1);
  4. int i1 =(int) l1;
  5. System.out.println("32位有符号整数:   "+ i1);
总结,我们用0x61c88647作为魔数累加为每个ThreadLocal分配各自的ID也就是threadLocalHashCode再与2的幂取模,得到的结果分布很均匀。

图形分析

为了更直观的体现 set方法的实现,通过一个图形表示如下

set剩余源码分析

前面分析了set方法第一次初始化ThreadLocalMap的过程,也对ThreadLocalMap的结构有了一个全面的了解。那么接下来看一下map不为空时的执行逻辑
  1. privatevoid set(ThreadLocal<?> key,Object value){
  2.            Entry[] tab = table;
  3.            int len = tab.length;
  4.            // 根据哈希码和数组长度求元素放置的位置,即数组下标
  5.            int i = key.threadLocalHashCode &(len-1);
  6.             //从i开始往后一直遍历到数组最后一个Entry(线性探索)
  7.            for(Entry e = tab[i];
  8.                 e !=null;
  9.                 e = tab[i = nextIndex(i, len)]){
  10.                ThreadLocal<?> k = e.get();
  11.                 //如果key相等,覆盖value
  12.                if(k == key){
  13.                    e.value = value;
  14.                    return;
  15.                }
  16.                 //如果key为null,用新key、value覆盖,同时清理历史key=null的陈旧数据
  17.                if(k ==null){
  18.                    replaceStaleEntry(key, value, i);
  19.                    return;
  20.                }
  21.            }
  22.            tab[i]=newEntry(key, value);
  23.            int sz =++size;
  24.             //如果超过阀值,就需要扩容了
  25.            if(!cleanSomeSlots(i, sz)&& sz >= threshold)
  26.                rehash();
  27.        }
主要逻辑
  • 根据key的散列哈希计算Entry的数组下标
  • 通过线性探索探测从i开始往后一直遍历到数组的最后一个Entry
  • 如果map中的key和传入的key相等,表示该数据已经存在,直接覆盖
  • 如果map中的key为空,则用新的key、value覆盖,并清理key=null的数据
  • rehash扩容

replaceStaleEntry

由于Entry的key为弱引用,如果key为空,说明ThreadLocal这个对象被GC回收了。 replaceStaleEntry的作用就是把陈旧的Entry进行替换
  1. privatevoid replaceStaleEntry(ThreadLocal<?> key,Object value,
  2.                                       int staleSlot){
  3.            Entry[] tab = table;
  4.            int len = tab.length;
  5.            Entry e;
  6.           //向前扫描,查找最前一个无效的slot
  7.            int slotToExpunge = staleSlot;
  8.            for(int i = prevIndex(staleSlot, len);
  9.                 (e = tab[i])!=null;
  10.                 i = prevIndex(i, len))
  11.                if(e.get()==null)
  12.                   //通过循环遍历,可以定位到最前面一个无效的slot
  13.                    slotToExpunge = i;
  14.            //从i开始往后一直遍历到数组最后一个Entry(线性探索)
  15.            for(int i = nextIndex(staleSlot, len);
  16.                 (e = tab[i])!=null;
  17.                 i = nextIndex(i, len)){
  18.                ThreadLocal<?> k = e.get();
  19.                //找到匹配的key以后
  20.                if(k == key){
  21.                    e.value = value;//更新对应slot的value值
  22.                    //与无效的sloat进行交换
  23.                    tab[i]= tab[staleSlot];
  24.                    tab[staleSlot]= e;
  25.                    //如果最早的一个无效的slot和当前的staleSlot相等,则从i作为清理的起点
  26.                    if(slotToExpunge == staleSlot)
  27.                        slotToExpunge = i;
  28.                    //从slotToExpunge开始做一次连续的清理
  29.                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
  30.                    return;
  31.                }
  32.                //如果当前的slot已经无效,并且向前扫描过程中没有无效slot,则更新slotToExpunge为当前位置
  33.                if(k ==null&& slotToExpunge == staleSlot)
  34.                    slotToExpunge = i;
  35.            }
  36.            //如果key对应的value在entry中不存在,则直接放一个新的entry
  37.            tab[staleSlot].value =null;
  38.            tab[staleSlot]=newEntry(key, value);
  39.           //如果有任何一个无效的slot,则做一次清理
  40.            if(slotToExpunge != staleSlot)
  41.                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
  42.        }

cleanSomeSlots

这个函数有两处地方会被调用,用于清理无效的Entry
  • 插入的时候可能会被调用
  • 替换无效slot的时候可能会被调用
区别是前者传入的n为元素个数,后者为table的容量
  1. privateboolean cleanSomeSlots(int i,int n){
  2.            boolean removed =false;
  3.            Entry[] tab = table;
  4.            int len = tab.length;
  5.            do{
  6.                 // i在任何情况下自己都不会是一个无效slot,所以从下一个开始判断
  7.                i = nextIndex(i, len);
  8.                Entry e = tab[i];
  9.                if(e !=null&& e.get()==null){
  10.                    n = len;// 扩大扫描控制因子
  11.                    removed =true;
  12.                    i = expungeStaleEntry(i);// 清理一个连续段
  13.                }
  14.            }while((n >>>=1)!=0);
  15.            return removed;
  16.        }

expungeStaleEntry

执行一次全量清理
  1. privateint expungeStaleEntry(int staleSlot){
  2.            Entry[] tab = table;
  3.            int len = tab.length;
  4.            // expunge entry at staleSlot
  5.            tab[staleSlot].value =null;//删除value
  6.            tab[staleSlot]=null;//删除entry
  7.            size--;//map的size递减
  8.            // Rehash until we encounter null
  9.            Entry e;
  10.            int i;
  11.            for(i = nextIndex(staleSlot, len);// 遍历指定删除节点,所有后续节点
  12.                 (e = tab[i])!=null;
  13.                 i = nextIndex(i, len)){
  14.                ThreadLocal<?> k = e.get();
  15.                if(k ==null){//key为null,执行删除操作
  16.                    e.value =null;
  17.                    tab[i]=null;
  18.                    size--;
  19.                }else{//key不为null,重新计算下标
  20.                    int h = k.threadLocalHashCode &(len -1);
  21.                    if(h != i){//如果不在同一个位置
  22.                        tab[i]=null;//把老位置的entry置null(删除)
  23.                        // 从h开始往后遍历,一直到找到空为止,插入
  24.                        while(tab[h]!=null)
  25.                            h = nextIndex(h, len);
  26.                        tab[h]= e;
  27.                    }
  28.                }
  29.            }
  30.            return i;
  31.        }

get操作

set的逻辑分析完成以后,get的源码分析就很简单了
  1. public T get(){
  2.        Thread t =Thread.currentThread();
  3.        //从当前线程中获取ThreadLocalMap
  4.        ThreadLocalMap map = getMap(t);
  5.        if(map !=null){
  6.            //查询当前ThreadLocal变量实例对应的Entry
  7.            ThreadLocalMap.Entry e = map.getEntry(this);
  8.            if(e !=null){//获取成功,直接返回
  9.                @SuppressWarnings("unchecked")
  10.                T result =(T)e.value;
  11.                return result;
  12.            }
  13.        }
  14.        //如果map为null,即还没有初始化,走初始化方法
  15.        return setInitialValue();
  16.    }

setInitialValue

根据 initialValue()的value初始化ThreadLocalMap
  1.    private T setInitialValue(){
  2.        T value = initialValue();//protected方法,用户可以重写
  3.        Thread t =Thread.currentThread();
  4.        ThreadLocalMap map = getMap(t);
  5.        if(map !=null)
  6.            //如果map不为null,把初始化value设置进去
  7.            map.set(this, value);
  8.        else
  9.            //如果map为null,则new一个map,并把初始化value设置进去
  10.            createMap(t, value);
  11.        return value;
  12.    }
  • 从当前线程中获取ThreadLocalMap,查询当前ThreadLocal变量实例对应的Entry,如果不为null,获取value,返回
  • 如果map为null,即还没有初始化,走初始化方法

remove方法

remove的方法比较简单,从Entry[]中删除指定的key就行
  1.     publicvoid remove(){
  2.         ThreadLocalMap m = getMap(Thread.currentThread());
  3.         if(m !=null)