ThreadLocal 的实现原理

在多线程环境下,如果要访问共享变量,通常要使用同步机制来保证访问的正确性。如果变量不需要共享,可以将其移到单个线程内,这样就不需要同步操作。ThreadLocal 把线程和要使用的对象关联起来,线程自己保存一份独立的副本,从而实现了数据访问的安全性。

使用方法

先看一个简单的示例,该示例来自于「Java并发编程的艺术」这本书

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public class Profiler {
public static final ThreadLocal<Long> TIME_THREADLOCAL = new ThreadLocal<Long>();

public static final void begin() {
TIME_THREADLOCAL.set(System.currentTimeMillis());
}

public static final long end() {
return System.currentTimeMillis() - TIME_THREADLOCAL.get();
}

public static void main(String[] args) throws InterruptedException {
Profiler.begin();
TimeUnit.SECONDS.sleep(1);
System.out.println("Cost: " + Profiler.end() + " mills");
}
}

Profile 类包含 begin()end()两个方法,可以用于耗时统计,在线程内任务开始前调用 begin() 方法记录开始时间,在结束时调用 end() 方法获取耗时。

ThreadLocal 的数据保存在线程中,所以对其方法的调用不限制在一个方法中,需满足在同一个线程中即可。

原理分析

直接看源码实现也可以理解实现原理,但是容易忘记。这里换种思路,假如说要从头实现 ThreadLocal 这样一个功能,应该怎么去做呢?我们可以一步一步的去分析

保存变量

ThreadLocal 把线程和要使用的对象关联起来,线程自己保存一份独立的副本

每个线程自己保存一份变量的副本,可在线程 Thread 类里设置一个属性 value 来实现。

这样就可以调用 setValue 设置线程变量,调用 getValue 方法获取线程变量。

功能封装

通过 Thread 类操作线程变量,每次需要先调用 Thread.currentThread() 获取当前线程对象,这部分代码可以封装起来。新增一个 ThreadLocal 类,封装了线程变量的操作功能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public class ThreadLocal {
public static Object get() {
// 获取当前线程
Thread thread = Thread.currentThread();

// 获取线程变量
Object value = thread.getValue();
return value;
}

public static void set(Object value) {
// 获取当前线程
Thread thread = Thread.currentThread();

// 设置线程变量
thread.setValue(value);
}
}

这样在使用时只需要调用 ThreadLocal.get()ThreadLocal.set() 来操作线程变量,更加方便。

1
2
3
4
5
// 设置变量值
int value = 2;
ThreadLocal.set(value);
// 获取变量值
int result = ThreadLocal.get();

多个线程变量

上面实现了线程变量的功能,但是每个线程只能保存一个变量,这在使用时有很大的局限性,如果有超过一个变量需要保存就无能为力了。

接口定义

为支持多个变量,可使用 ThreadLocal 指定要保存的变量和类型,通过 ThreadLocal 变量进行操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
public class ThreadLocal<T> {
public T get() {
// 获取当前线程
Thread thread = Thread.currentThread();

// 获取线程变量
T value = getFromThread();
return value;
}

private T getFromThread(Thread thread) {
// TODO 从线程中获取变量
}

public void set(T value) {
// 获取当前线程
Thread thread = Thread.currentThread();

// 设置线程变量
setToThread(value);
}

private void setToThread(Thread thread) {
// TODO 设置线程变量
}
}

这里将 ThreadLocalgetset 方法改为了通过对象调用,根据 ThreadLocal 对象的不同确定要操作的变量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 定义两个 ThreadLocal
ThreadLocal<Integer> iThreadLocal = new ThreadLocal<Integer>();
ThreadLocal<String> sThreadLocal = new ThreadLocal<String>();

// 设置 int 变量
int number = 2;
iThreadLocal.set(number);

// 设置 String 变量
String str = "gorden5566";
sThreadLocal.set(str);

// 获取变量
System.out.println(iThreadLocal.get());
System.out.println(sThreadLocal.get());

数据结构

要保存多个变量,首先要对数据结构进行调整。再看下获取变量的代码

1
2
3
private T getFromThread(Thread thread) {
// TODO 从线程中获取变量
}

这里是从 Thread 对象中获取线程变量,具体来讲是:根据调用者 ThreadLocalThread 对象中获取对应的变量。伪代码如下:

1
2
// 根据调用者 ThreadLocal 去 Thread 对象中获取对应的变量
T t = thread.get(threadLocal);

很显然可以使用 map 保存数据,map 的 key 为 ThreadLocal 对象, map 的 value 为要保存的变量。因此定义结构如下:

将原来的 Object value 改为 Map<ThreadLocal, Object> threadLocalMap ,用于保存多个变量。

这样 getFromThread 的功能也可以实现了,setToThread 方法类似

1
2
3
4
5
6
7
8
9
private T getFromThread(Thread thread) {
// 从线程中获取变量
Map<ThreadLocal, Object> map = getMap(thread);
return map.get(this);
}

Map<ThreadLocal, Object> getMap(Thread thread) {
return thread.threadLocalMap;
}

使用

通常情况下把 ThreadLocal 变量定义为类的静态属性,通过静态方法暴露外部接口,实现工具类功能。

1
private static final ThreadLocal<Integer> INTEGER_THREADLOCAL = new ThreadLocal<Integer>();

总结

ThreadLocal 的实现原理基本上如上文所讲,但是实际代码比这要复杂很多,需要注意的是用于保存变量的 map 是一个定制化的 HashMap:ThreadLocalMap,它的 key 使用了 WeakReference 以支持垃圾回收。

引用关系图如下:

参考

程晓明,方腾飞,魏鹏. Java并发编程的艺术

0%