一种在线程池中透传或继承ThreadLocal信息的方法

在实际的业务代码中,经常会使用到 ThreadLocal 用于跨业务代码 来获取在上游设置的值。比如,在spring mvc中 spring web mvc 中通过 RequestContextHolder 设置 HttpServletRequest,业务代码则可以在 controller 或者是 service中 通过 RequestContextHolder#getRequestAttributes 获取相应的对象. 但这种方法有一个限制,即 setValue 和 getValue 的代码必须在同一个线程内. 当然,这也是属于通过 ThreadLocal 来避免竞争的一种手法.

针对之前已经可以工作的代码,如果将相应的业务代码 迁移至一个新的线程池中运行,即封装为 1个 runnable 对象,那么相应的代码即不能正确地工作了。

如下参考所示

ThreadLocal<String> local = new ThreadLocal<>();
local.set("value1");

//打印 获取->value1
System.out.println("获取->" + local.get());

Runnable runnable = () -> {
    System.out.println("获取->" + local.get());
};
//打印 获取->null
new Thread(runnable).start();

上面的 sout 可以是任意1个业务上的 method 调用。仅仅是将 method调用 转由新线程来运行,相应的业务逻辑即不能正常工作。 至于这里将调用转由新线程来运行,可以有很多的场景。如 支持 timedout 调用(利用future.get(timed))。

本文通过反射调用读取当前Thread的信息,将值注入到新线程中的threadLocal中,以达到透传threadLocal的目的。不需要修改任何业务代码,也不需要使用InheritableThreadLocal(此类也并不用于当前场景)

主要的实现基于以下步骤

  • 提前提取当前线程ThreadLocal变量值
  • 执行时复制至新线程中
  • 执行结束之后删除未变化值

本文中的代码均基于反射调用,需要打开相应的 Accessible 属性.

本文用到的类,方法,字段如下所示

private static Class C_THREAD_LOCAL_MAP = forName("java.lang.ThreadLocal$ThreadLocalMap");
private static Class C_THREAD_LOCAL_MAP_ENTRY = forName("java.lang.ThreadLocal$ThreadLocalMap$Entry");
private static Field F_THREAD_$_THREAD_LOCALS = findField(Thread.class, "threadLocals");
private static Field F_THREAD_LOCAL_MAP_$_TABLE = findField(C_THREAD_LOCAL_MAP, "table");
private static Field F_THREAD_LOCAL_MAP_ENTRY_$_VALUE = findField(C_THREAD_LOCAL_MAP_ENTRY, "value");
private static Method M_THREAD_LOCAL_MAP_$_SET = findMethodByName(C_THREAD_LOCAL_MAP, "set");
private static Method M_THREAD_LOCAL_MAP_$_GET_ENTRY = findMethodByName(C_THREAD_LOCAL_MAP, "getEntry");
private static Method M_THREAD_LOCAL_MAP_$_REMOVE = findMethodByName(C_THREAD_LOCAL_MAP, "remove");

提前提取当前线程ThreadLocal变量值

待执行的runnable对象从添加至线程池到实际执行时,时间上差距可能较大。很有可能 调用线程中的数据已经变更了。在业务中实际是将调用线程创建 对象时的 Local变量值复制至 执行线程中,因此这时需要提前将相应的数据提取出来. 相应的代码如下参考所示:

Thread parentThread = Thread.currentThread();
Object parentThreadLocalsMap = ReflectionUtils.getField(F_THREAD_$_THREAD_LOCALS, parentThread);
WeakReference[] parentThreadLocalMapTables = (WeakReference[]) ReflectionUtils.getField(F_THREAD_LOCAL_MAP_$_TABLE, parentThreadLocalsMap);

//因为run实际调用时,可能已经很晚了,因此这里需要提前把当前threadLocal中的数据提取出来
List<Pair<ThreadLocal, Object>> copiedList = Lists.newArrayList();

for(WeakReference parentWeakRef : parentThreadLocalMapTables) {
    //... 跳过 null值 parentWeakRef
    ThreadLocal threadLocal = (ThreadLocal) parentWeakRef.get();
    //... 跳过 null值 threadLocal

    Object parentWeakRefValue = ReflectionUtils.getField(F_THREAD_LOCAL_MAP_ENTRY_$_VALUE, parentWeakRef);
    //... 跳过 null值 parentWeakRefValue
    copiedList.add(Pair.of(threadLocal, parentWeakRefValue));
}

执行时复制至新线程中

在执行线程实际执行时,因为之前已经记录要待复制的threadLocal变量,这里直接复制至新的Thread即可。不过,在处理中需要有一些额外的判定。比如,如果执行线程已经有相应的值,则不允许处理,即避免覆盖数据的情况出现. 同时,如果执行线程,还没有ThreadLocal变量,则不能获取相应的对象,这里需要进行dummy化处理. 如下参考所示:

Thread currentThread = Thread.currentThread();
Object currentThreadLocalsMap = ReflectionUtils.getField(F_THREAD_$_THREAD_LOCALS, currentThread);

//如果没有,则表示当前线程还没有 threadLocal值, 这里通过显示地set 触发 threadLocals 变量的创建
if(currentThreadLocalsMap == null) {
    DUMMY.set("");
    currentThreadLocalsMap = ReflectionUtils.getField(F_THREAD_$_THREAD_LOCALS, currentThread);
}

//这里使用新list记录实际上复制的值,与之前的待处理列表不一样
List<Pair<ThreadLocal, Object>> setedList = Lists.newArrayList();

for(Pair<ThreadLocal, Object> p : copiedList) {
    ThreadLocal threadLocal = p.getFirst();
    Object parentWeakRefValue = p.getSecond();

    Object currentThreadLocalEntry = ReflectionUtils.invokeMethod(M_THREAD_LOCAL_MAP_$_GET_ENTRY, currentThreadLocalsMap, threadLocal);
    //仅限当前线程无此值的情况下才设置
    if(currentThreadLocalEntry == null || ReflectionUtils.getField(F_THREAD_LOCAL_MAP_ENTRY_$_VALUE, currentThreadLocalEntry) == null) {
        setedList.add(p);
        ReflectionUtils.invokeMethod(M_THREAD_LOCAL_MAP_$_SET, currentThreadLocalsMap, threadLocal, parentWeakRefValue);
    }
}

执行结束之后删除未变化值

执行结束之后,一般采用 try finally 来处理,这里同样使用此种方法处理. 未变化值,即重新读取相应的数据,如果数据未发生变更,即表示在执行过程中并没有使用到或者是仅仅是读取,而没有进一步作其它处理。这些数据我们认为是可以删除,并且应该删除的。即属于额外的数据信息,这样避免线程池在执行新的线程时存在污染信息.

try{
    //执行线程
} finally {
    //... DUMMY变量处理

    for(Pair<ThreadLocal, Object> p : setedList) {
        Object currentThreadLocalEntry = ReflectionUtils.invokeMethod(M_THREAD_LOCAL_MAP_$_GET_ENTRY, currentThreadLocalsMap, p.getFirst());
        //... null 判断

        Object currentThreadLocalEntryValue = ReflectionUtils.getField(F_THREAD_LOCAL_MAP_ENTRY_$_VALUE, currentThreadLocalEntry);
        //仅当值相同时,才处理,避免当前线程内可能已经重新设置值的情况
        if(currentThreadLocalEntryValue == p.getSecond()) {
            ReflectionUtils.invokeMethod(M_THREAD_LOCAL_MAP_$_REMOVE, currentThreadLocalsMap, p.getFirst());
        }
    }
}

总结

以上的整个代码可以封装为 工具类方法,如 Runnable buildInherited(Runnable runnable), 即将一个 runnable 对象转换为另1个 runnable对象,这样经过转换之后,即拥有 ThreadLocal 值继承的能力. 如开头的代码

new Thread(runnable).start();
//变更后
new Thread(buildInherited(runnable)).start();

这样调整之后,即能成功地打印出 创建者线程中设置的 threadLocal值了。并且,如果在 执行线程中,对ThreadLocal 重新赋值,也不会改变外层线程的数据. 在一些需要线程池介入的业务代码中,通过这种手法,可以不需要对业务代码作大调整,之前的逻辑仍然可以正常地运行.

转载请标明出处:i flym
本文地址:https://www.iflym.com/index.php/code/202006020001.html

相关文章:

作者: flym

I am flym,the master of the site:)

发表评论

邮箱地址不会被公开。 必填项已用*标注