ThreadLocal 的实现原理

在多线程环境下,如果要访问共享变量,通常要使用同步机制来保证访问的正确性。如果变量不需要共享,可以将其移到单个线程内,这样就不需要同步操作。ThreadLocal 把线程和要使用的对象关联起来,线程自己保存一份独立的副本,从而实现了数据访问的安全性。

使用方法

先看一个简单的示例,该示例来自于「Java并发编程的艺术」这本书

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public class Profiler {
public static final ThreadLocal<Long> TIME_THREADLOCAL = new ThreadLocal<Long>();

public static final void begin() {
TIME_THREADLOCAL.set(System.currentTimeMillis());
}

public static final long end() {
return System.currentTimeMillis() - TIME_THREADLOCAL.get();
}

public static void main(String[] args) throws InterruptedException {
Profiler.begin();
TimeUnit.SECONDS.sleep(1);
System.out.println("Cost: " + Profiler.end() + " mills");
}
}

Profile 类包含 begin()end()两个方法,可以用于耗时统计,在线程内任务开始前调用 begin() 方法记录开始时间,在结束时调用 end() 方法获取耗时。

ThreadLocal 的数据保存在线程中,所以对其方法的调用不限制在一个方法中,需满足在同一个线程中即可。

原理分析

直接看源码实现也可以理解实现原理,但是容易忘记。这里换种思路,假如说要从头实现 ThreadLocal 这样一个功能,应该怎么去做呢?我们可以一步一步的去分析

保存变量

ThreadLocal 把线程和要使用的对象关联起来,线程自己保存一份独立的副本

每个线程自己保存一份变量的副本,可在线程 Thread 类里设置一个属性 value 来实现。

Thread-Object value;+void setValue(Object value);+Object getValue();

这样就可以调用 setValue 设置线程变量,调用 getValue 方法获取线程变量。

功能封装

通过 Thread 类操作线程变量,每次需要先调用 Thread.currentThread() 获取当前线程对象,这部分代码可以封装起来。新增一个 ThreadLocal 类,封装了线程变量的操作功能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public class ThreadLocal {
public static Object get() {
// 获取当前线程
Thread thread = Thread.currentThread();

// 获取线程变量
Object value = thread.getValue();
return value;
}

public static void set(Object value) {
// 获取当前线程
Thread thread = Thread.currentThread();

// 设置线程变量
thread.setValue(value);
}
}

这样在使用时只需要调用 ThreadLocal.get()ThreadLocal.set() 来操作线程变量,更加方便。

1
2
3
4
5
// 设置变量值
int value = 2;
ThreadLocal.set(value);
// 获取变量值
int result = ThreadLocal.get();

多个线程变量

上面实现了线程变量的功能,但是每个线程只能保存一个变量,这在使用时有很大的局限性,如果有超过一个变量需要保存就无能为力了。

接口定义

为支持多个变量,可使用 ThreadLocal 指定要保存的变量和类型,通过 ThreadLocal 变量进行操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
public class ThreadLocal<T> {
public T get() {
// 获取当前线程
Thread thread = Thread.currentThread();

// 获取线程变量
T value = getFromThread();
return value;
}

private T getFromThread(Thread thread) {
// TODO 从线程中获取变量
}

public void set(T value) {
// 获取当前线程
Thread thread = Thread.currentThread();

// 设置线程变量
setToThread(value);
}

private void setToThread(Thread thread) {
// TODO 设置线程变量
}
}

这里将 ThreadLocalgetset 方法改为了通过对象调用,根据 ThreadLocal 对象的不同确定要操作的变量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 定义两个 ThreadLocal
ThreadLocal<Integer> iThreadLocal = new ThreadLocal<Integer>();
ThreadLocal<String> sThreadLocal = new ThreadLocal<String>();

// 设置 int 变量
int number = 2;
iThreadLocal.set(number);

// 设置 String 变量
String str = "gorden5566";
sThreadLocal.set(str);

// 获取变量
System.out.println(iThreadLocal.get());
System.out.println(sThreadLocal.get());

数据结构

要保存多个变量,首先要对数据结构进行调整。再看下获取变量的代码

1
2
3
private T getFromThread(Thread thread) {
// TODO 从线程中获取变量
}

这里是从 Thread 对象中获取线程变量,具体来讲是:根据调用者 ThreadLocalThread 对象中获取对应的变量。伪代码如下:

1
2
// 根据调用者 ThreadLocal 去 Thread 对象中获取对应的变量
T t = thread.get(threadLocal);

很显然可以使用 map 保存数据,map 的 key 为 ThreadLocal 对象, map 的 value 为要保存的变量。因此定义结构如下:

Thread#Map<ThreadLocal, Object> threadLocalMap;

将原来的 Object value 改为 Map<ThreadLocal, Object> threadLocalMap ,用于保存多个变量。

这样 getFromThread 的功能也可以实现了,setToThread 方法类似

1
2
3
4
5
6
7
8
9
private T getFromThread(Thread thread) {
// 从线程中获取变量
Map<ThreadLocal, Object> map = getMap(thread);
return map.get(this);
}

Map<ThreadLocal, Object> getMap(Thread thread) {
return thread.threadLocalMap;
}

使用

通常情况下把 ThreadLocal 变量定义为类的静态属性,通过静态方法暴露外部接口,实现工具类功能。

1
private static final ThreadLocal<Integer> INTEGER_THREADLOCAL = new ThreadLocal<Integer>();

总结

ThreadLocal 的实现原理基本上如上文所讲,但是实际代码比这要复杂很多,需要注意的是用于保存变量的 map 是一个定制化的 HashMap:ThreadLocalMap,它的 key 使用了 WeakReference 以支持垃圾回收。

引用关系图如下:

ObjectThreadEntryThreadLocal key;Object value;ThreadLocalMapEntry[] table;ThreadLocal

参考

程晓明,方腾飞,魏鹏. Java并发编程的艺术