事件在pipeline中传播时如何跳过非必须的handler?

Netty 中通过在 pipeline 上添加各种 handler 组合来实现不同的逻辑,handler 又可以分为 ChannelInboundHandlerChannelOutboundHandler,它们分别用于处理入站事件和出站事件。

下图所示是 ChannelInboundHandler 的方法列表,对应它支持处理的各种入站事件,例如 channelRegistered 用于处理 registered 事件

同样地,ChannelOutboundHandler 提供了用于处理出站事件的方法,如下图所示,例如 bind 方法用于处理 channel 的绑定事件

因为每个 handler 有自己的职责,它可能并不关心所有的事件,而只关心自己感兴趣的事件。

当一个 pipeline 上添加的 handler 变多时,调用链路也会相应变长,这时就会引起一些问题。一方面调用栈会比较深,事件处理过程中占用内存增多;另一方面调用耗时也会增加。如果 handler 不关心某些事件,只是做向后或向前传播,这种情况下如果能跳过这些handler,则会使得实际的调用链路变短,起到很好地优化效果。

Netty 中使用 @Skip 注解标志是否要跳过该事件的处理,例如 ChannelInboundHandlerAdapter 中对 ChannelInboundHandler 的实现均添加了 @Skip 注解,如下是其 channelRegistered 方法的实现

1
2
3
4
5
6
7
8
9
10
11
/**
* Calls {@link ChannelHandlerContext#fireChannelRegistered()} to forward
* to the next {@link ChannelInboundHandler} in the {@link ChannelPipeline}.
*
* Sub-classes may override this method to change behavior.
*/
@Skip
@Override
public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
ctx.fireChannelRegistered();
}

如果自定义的 handler 继承了 ChannelInboundHandlerAdapter ,并且重写了 ChannelInboundHandler 提供的某个事件处理方法。当把它添加到 pipeline 上后,它只会处理重写过的方法对应的事件,不会参与其他事件的处理。

Skip注解是如何生效的呢?

ChannelRegistered 事件为例,当 channel 完成注册之后,它会调用 DefaultChannelPipelinefireChannelRegistered 方法传播 ChannelRegistered 事件。

fireChannelRegistered的调用路径

最终会调用到 head.fireChannelRegistered() 方法,这里可以分成两步来看

  1. findContextInbound(MASK_CHANNEL_REGISTERED) 用于找到下一个支持处理 channelRegistered 事件的 InboundHandler

  2. invokeChannelRegistered 用于调用该 handler 的 channelRegistered 方法

findContextInbound的处理逻辑

findContextInbound 用于向后查找 pipeline 上的下一个符合条件的 InboundHandler,它的代码如下

1
2
3
4
5
6
7
8
9
// AbstractChannelHandlerContext#findContextInbound
private AbstractChannelHandlerContext findContextInbound(int mask) {
AbstractChannelHandlerContext ctx = this;
EventExecutor currentExecutor = executor();
do {
ctx = ctx.next;
} while (skipContext(ctx, currentExecutor, mask, MASK_ONLY_INBOUND));
return ctx;
}

它的入参是一个 mask,用于代表 ChannelInboundHandlerChannelOutboundHandler 的不同方法,例如这里的 MASK_CHANNEL_REGISTERED 代表的是 ChannelInboundHandlerchannelRegistered 方法

skipContext 用于判断是否要跳过 handler 所在的 handlerContext,它的代码如下

1
2
3
4
5
6
7
8
9
10
11
// AbstractChannelHandlerContext#skipContext
private static boolean skipContext(
AbstractChannelHandlerContext ctx, EventExecutor currentExecutor, int mask, int onlyMask) {
// Ensure we correctly handle MASK_EXCEPTION_CAUGHT which is not included in the MASK_EXCEPTION_CAUGHT
return (ctx.executionMask & (onlyMask | mask)) == 0 ||
// We can only skip if the EventExecutor is the same as otherwise we need to ensure we offload
// everything to preserve ordering.
//
// See https://github.com/netty/netty/issues/10067
(ctx.executor() == currentExecutor && (ctx.executionMask & mask) == 0);
}

注意 (ctx.executionMask & mask) == 0 这段代码,其中 ctx.executionMask 表示对应的 handler 的 mask 值, mask = MASK_CHANNEL_REGISTERED 表示 channelRegistered 方法。如果两者进行与运算结果为0,就说明该 handler 需要跳过 mask 对应的方法,即跳过 channelRegistered 方法。

HeadContext 调用 findContextInbound(MASK_CHANNEL_REGISTERED) 时,它的作用是从当前 ChannelHandlerContext 也就是 HeadContext 开始,往后找下一个不需要跳过 channelRegistered 方法的 ChannelHandlerContext

ctx.executionMask的计算

直接看代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
AbstractChannelHandlerContext(DefaultChannelPipeline pipeline, EventExecutor executor,
String name, Class<? extends ChannelHandler> handlerClass) {
this.name = ObjectUtil.checkNotNull(name, "name");
this.pipeline = pipeline;
this.executor = executor;
// 计算mask的值
this.executionMask = mask(handlerClass);
// Its ordered if its driven by the EventLoop or the given Executor is an instanceof OrderedEventExecutor.
ordered = executor == null || executor instanceof OrderedEventExecutor;
}

// ChannelHandlerMask#mask
static int mask(Class<? extends ChannelHandler> clazz) {
// Try to obtain the mask from the cache first. If this fails calculate it and put it in the cache for fast
// lookup in the future.
Map<Class<? extends ChannelHandler>, Integer> cache = MASKS.get();
Integer mask = cache.get(clazz);
if (mask == null) {
// 是在这里计算的
mask = mask0(clazz);
cache.put(clazz, mask);
}
return mask;
}

可以看到是通过 ChannelHandlerMaskmask0 方法计算的,它的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
private static int mask0(Class<? extends ChannelHandler> handlerType) {
int mask = MASK_EXCEPTION_CAUGHT;
try {
if (ChannelInboundHandler.class.isAssignableFrom(handlerType)) {
mask |= MASK_ALL_INBOUND;

if (isSkippable(handlerType, "channelRegistered", ChannelHandlerContext.class)) {
mask &= ~MASK_CHANNEL_REGISTERED;
}
if (isSkippable(handlerType, "channelUnregistered", ChannelHandlerContext.class)) {
mask &= ~MASK_CHANNEL_UNREGISTERED;
}
if (isSkippable(handlerType, "channelActive", ChannelHandlerContext.class)) {
mask &= ~MASK_CHANNEL_ACTIVE;
}
if (isSkippable(handlerType, "channelInactive", ChannelHandlerContext.class)) {
mask &= ~MASK_CHANNEL_INACTIVE;
}
if (isSkippable(handlerType, "channelRead", ChannelHandlerContext.class, Object.class)) {
mask &= ~MASK_CHANNEL_READ;
}
if (isSkippable(handlerType, "channelReadComplete", ChannelHandlerContext.class)) {
mask &= ~MASK_CHANNEL_READ_COMPLETE;
}
if (isSkippable(handlerType, "channelWritabilityChanged", ChannelHandlerContext.class)) {
mask &= ~MASK_CHANNEL_WRITABILITY_CHANGED;
}
if (isSkippable(handlerType, "userEventTriggered", ChannelHandlerContext.class, Object.class)) {
mask &= ~MASK_USER_EVENT_TRIGGERED;
}
}

if (ChannelOutboundHandler.class.isAssignableFrom(handlerType)) {
mask |= MASK_ALL_OUTBOUND;

if (isSkippable(handlerType, "bind", ChannelHandlerContext.class,
SocketAddress.class, ChannelPromise.class)) {
mask &= ~MASK_BIND;
}
if (isSkippable(handlerType, "connect", ChannelHandlerContext.class, SocketAddress.class,
SocketAddress.class, ChannelPromise.class)) {
mask &= ~MASK_CONNECT;
}
if (isSkippable(handlerType, "disconnect", ChannelHandlerContext.class, ChannelPromise.class)) {
mask &= ~MASK_DISCONNECT;
}
if (isSkippable(handlerType, "close", ChannelHandlerContext.class, ChannelPromise.class)) {
mask &= ~MASK_CLOSE;
}
if (isSkippable(handlerType, "deregister", ChannelHandlerContext.class, ChannelPromise.class)) {
mask &= ~MASK_DEREGISTER;
}
if (isSkippable(handlerType, "read", ChannelHandlerContext.class)) {
mask &= ~MASK_READ;
}
if (isSkippable(handlerType, "write", ChannelHandlerContext.class,
Object.class, ChannelPromise.class)) {
mask &= ~MASK_WRITE;
}
if (isSkippable(handlerType, "flush", ChannelHandlerContext.class)) {
mask &= ~MASK_FLUSH;
}
}

if (isSkippable(handlerType, "exceptionCaught", ChannelHandlerContext.class, Throwable.class)) {
mask &= ~MASK_EXCEPTION_CAUGHT;
}
} catch (Exception e) {
// Should never reach here.
PlatformDependent.throwException(e);
}

return mask;
}

ChannelInboundHandlerchannelRegistered 处理为例,mask 的初始值为 MASK_EXCEPTION_CAUGHT |= MASK_ALL_INBOUND,表示所有的 InboundHandler 方法,如果需要跳过某个方法,则将对应的二进制位置位0

1
mask &= ~MASK_CHANNEL_REGISTERED

~MASK_CHANNEL_REGISTERED 用于对 MASK_CHANNEL_REGISTERED 的二进制位取反,也就是说将非目标位全部置位1。然后和 mask 值进行按位与运算,其结果就是将目标位置位0

再来看下 isSkippable 的代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
private static boolean isSkippable(
final Class<?> handlerType, final String methodName, final Class<?>... paramTypes) throws Exception {
return AccessController.doPrivileged(new PrivilegedExceptionAction<Boolean>() {
@Override
public Boolean run() throws Exception {
Method m;
try {
// 查找对应的方法
m = handlerType.getMethod(methodName, paramTypes);
} catch (NoSuchMethodException e) {
if (logger.isDebugEnabled()) {
logger.debug(
"Class {} missing method {}, assume we can not skip execution", handlerType, methodName, e);
}
return false;
}

// 判断是否包含@Skip注解
return m != null && m.isAnnotationPresent(Skip.class);
}
});
}

它会查找 channelHandler 上的指定方法,判断是否包含 @Skip 注解,若包含该注解,则说明需要跳过。

mask0 对所有的 InboundHandler 方法和 OutboundHandler 方法进行处理,得出一个最终结果,作为该 handler 的 mask 值。