• 解析CountDownLatch源码

    CountDownLatch「AQS共享模式」下的典型实现,这里贴出之前AQS共享模式的一张图:

    私有静态内部类Sync是一个AQS的实现:

    public class CountDownLatch {

        private static final class Sync extends AbstractQueuedSynchronizer {

            private static final long serialVersionUID = 4982264981922014374L;
            
            // 设置AQS的state的值为输入的计数值
            Sync(int count) {
                setState(count);
            }
            
            // 获取AQS中的state属性
            int getCount() {
                return getState();
            }
            
            // 共享模式下获取资源,这里无视共享模式下需要获取的资源数,只判断当前的state值是否为0,为0的时候,意味资源获取成功,闭锁已经释放,所有等待线程需要解除阻塞
            // 如果state当前已经为0,那么线程完全不会加入AQS同步队列中等待,表现为直接运行
            protected int tryAcquireShared(int acquires) {
                return (getState() == 0) ? 1 : -1;
            }
            
            // 共享模式下释放资源,这里也无视共享模式下需要释放的资源数,每次让状态值通过CAS减少1,当减少到0的时候,返回true
            protected boolean tryReleaseShared(int releases) {
                // Decrement count; signal when transition to zero
                for (;;) {
                    int c = getState();
                    // 这种情况下说明了当前state为0,从tryAcquireShared方法来看,线程不会加入AQS同步队列进行阻塞,所以也无须释放
                    if (c == 0)
                        return false;
                    // state的快照值减少1,并且通过CAS设置快照值更新为state,如果state减少为0则返回true,意味着需要唤醒阻塞线程
                    int nextc = c - 1;
                    if (compareAndSetState(c, nextc))
                        return nextc == 0;
                }
            }
        }
         
        // 输入的计数值不能小于0,意味着AQS的state属性必须大于等于0
        public CountDownLatch(int count) {
            if (count < 0throw new IllegalArgumentException("count < 0");
            this.sync = new Sync(count);
        }

        public void await() throws InterruptedException {
            // 共享模式下获取资源,响应中断
            sync.acquireSharedInterruptibly(1);
        }

        public boolean await(long timeout, TimeUnit unit)
            throws InterruptedException 
    {
            // 共享模式下获取资源,响应中断,带超时期限
            return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
        }

        public void countDown() {
            // 共享模式下释放资源
            sync.releaseShared(1);
        }

        public long getCount() {
            // 获取当前state的值
            return sync.getCount();
        }

        public String toString() {
            return super.toString() + "[Count = " + sync.getCount() + "]";
        }
    }

    接下来再分步解析每一个方法。先看构造函数:

    // 构造函数,其实就是对AQS的state进行赋值
    public CountDownLatch(int count) {
        if (count < 0throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    // 私有静态内部类Sync中
    private static final class Sync extends AbstractQueuedSynchronizer {

        Sync(int count) {
            setState(count);
        }

        // ......
    }

    // AbstractQueuedSynchronizer中
    public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {

        // ......
        
        // volatile修饰的状态值,变更会强制写回主内存,以便多线程环境下可见
        private volatile int state;

        // 调用的是这个父类方法
        protected final void setState(int newState) {
            state = newState;
        }

        // ......
    }

    由于AQS的头尾节点都是「懒创建」的,所以只初始化了state的情况下,AQS是"空的"。接着看await()方法:

    // CountDownLatch中
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    ↓↓↓↓↓↓↓↓↓↓↓↓
    // 私有静态内部类Sync中
    private static final class Sync extends AbstractQueuedSynchronizer {
        
        // state等于0的时候返回1,大于0的时候返回-1
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        // ......
    }

    ↓↓↓↓↓↓↓↓↓↓↓↓
    // AbstractQueuedSynchronizer中
    public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {

        // ......
        
        // 共享模式下获取资源,响应中断
        public final void acquireSharedInterruptibly(int arg)
                throws InterruptedException 
    {
            // 如果线程已经处于中断状态,则清空中断状态位,抛出InterruptedException
            if (Thread.interrupted())
                throw new InterruptedException();
            // 尝试获取资源,此方法由子类CountDownLatch中的Sync实现,小于0的时候,说明state > 0
            if (tryAcquireShared(arg) < 0)
                doAcquireSharedInterruptibly(arg);
        }

        // ......

        //
        private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {
            // 基于当前线程新建一个标记为共享的新节点
            final Node node = addWaiter(Node.SHARED);
            try {
                for (;;) {
                    // 获取新入队节点的前驱节点
                    final Node p = node.predecessor();
                    // 前驱节点为头节点
                    if (p == head) {
                        // 并且尝试获取资源成功,也就是每一轮循环都会调用tryAcquireShared尝试获取资源(r >= 0意味获取成功),除非阻塞或者跳出循环
                        // 由前文可知,CountDownLatch中只有当state = 0的情况下,r才会大于等于0
                        int r = tryAcquireShared(arg);
                        if (r >= 0) {
                            // 设置头结点,并且传播获取资源成功的状态,这个方法的作用是确保唤醒状态传播到所有的后继节点
                            // 然后任意一个节点晋升为头节点都会唤醒其第一个有效的后继节点,起到一个链式释放和解除阻塞的动作
                            setHeadAndPropagate(node, r);
                            // 由于节点晋升,原来的位置需要断开,置为NULL便于GC
                            p.next = null// help GC
                            return;
                        }
                    }
                    // shouldParkAfterFailedAcquire ->  判断获取资源失败是否需要阻塞,这里会把前驱节点的等待状态CAS更新为Node.SIGNAL
                    // parkAndCheckInterrupt -> 判断到获取资源失败并且需要阻塞,调用LockSupport.park()阻塞节点中的线程实例,(解除阻塞后)清空中断状态位并且返回该中断状态
                    if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt())
                        throw new InterruptedException();
                }
            } catch (Throwable t) {
                cancelAcquire(node);
                throw t;
            }
        }

        // ......
    }

    因为是工程化的代码,并且引入了死循环避免竞态条件下的异常,代码看起来比较复杂,其实做了下面几件事:

    接着看countDown()方法:

    // CountDownLatch中
    public void countDown() {
        sync.releaseShared(1);
    }

    ↓↓↓↓↓↓↓↓↓↓↓↓
    // 私有静态内部类Sync中
    private static final class Sync extends AbstractQueuedSynchronizer {
        
        // 共享模式下释放资源,这里也无视共享模式下需要释放的资源数,每次让状态值通过CAS减少1,当减少到0的时候,返回true
        protected boolean tryReleaseShared(int releases) {
            // 减少计数值state,直到变为0,则进行释放
            for (;;) {
                int c = getState();
                // 如果已经为0,直接返回false,不能再递减到小于0,返回false也意味着不会进入AQS的doReleaseShared()逻辑
                if (c == 0)
                    return false;
                int nextc = c - 1;
                // CAS原子更新state = state - 1
                if (compareAndSetState(c, nextc))
                    // 如果此次递减为0则返回true
                    return nextc == 0;
            }
        }

        // ......
    }

    ↓↓↓↓↓↓↓↓↓↓↓↓
    // AbstractQueuedSynchronizer中
    public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {
         
        // ......
        
        // 共享模式下,释放arg个资源
        public final boolean releaseShared(int arg) {
            // 从上面的分析来看,这里只有一种可能返回true并且进入doReleaseShared()方法,就是state由1递减为0的时候
            if (tryReleaseShared(arg)) {
                doReleaseShared();
                return true;
            }
            return false;
        }
        
        // 共享模式下的释放操作
        private void doReleaseShared() {
            // 死循环是避免因为新节点入队产生影响,CAS做状态设置被放在死循环中失败了会在下一轮循环中重试
            for (;;) {
                Node h = head;
                // 头不等于尾,也就是AQS同步等待队列不为空
                // h == NULL,说明AQS同步等待队列刚进行了初始化,并未有持有线程实例的节点
                if (h != null && h != tail) {
                    int ws = h.waitStatus;
                    // 头节点为Node.SIGNAL(-1),也就是后继节点需要唤醒,CAS设置头节点状态-1 -> 0,并且唤醒头节点的后继节点(也就是紧挨着头节点后的第一个节点)
                    if (ws == Node.SIGNAL) {
                        // 这个if分支是对于Node.SIGNAL状态的头节点,这种情况下,说明
                        // 这里使用CAS的原因是setHeadAndPropagate()方法和releaseXX()方法都会调用此doReleaseShared()方法,CAS也是并发控制的一种手段
                        // 如果CAS失败,很大可能是头节点发生了变更,需要进入下一轮循环更变头节点的引用再进行判断
                        // 该状态一定是由后继节点为当前节点设置的,具体见shouldParkAfterFailedAcquire()方法
                        if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
                            continue;            // loop to recheck cases
                        // 唤醒后继节点,如果有后继节点被唤醒,则后继节点会调用setHeadAndPropagate()方法,更变头节点和转播唤醒状态
                        unparkSuccessor(h);
                    }
                    // 头节点状态为0,说明头节点的后继节点未设置前驱节点的waitStatus为SIGNAL,代表无需唤醒
                    // CAS更新它的状态0 -> Node.PROPAGATE(-3),这个标识目的是为了把节点状态设置为跟Node.SIGNAL(-1)一样的负数值,
                    // 便于某个后继节点解除阻塞后,在一轮doAcquireSharedInterruptibly()循环中调用shouldParkAfterFailedAcquire()方法返回false,实现"链式唤醒"
                    else if (ws == 0 && !h.compareAndSetWaitStatus(0, Node.PROPAGATE))  
                        continue;                // loop on failed CAS
                }
                // 如果头节点未发生变化,则代表当前没有其他线程获取到资源,晋升为头节点,直接退出循环
                // 如果头节点已经发生变化,代表已经有线程(后继节点)获取到资源,
                if (h == head)                   // loop if head changed
                    break;
            }
        }

        // 解除传入节点的第一个后继节点的阻塞状态,当前处理节点的等待状态会被CAS更新为0
        private void unparkSuccessor(Node node) {
            // 当前处理的节点状态小于0则直接CAS更新为0
            int ws = node.waitStatus;
            if (ws < 0)
                node.compareAndSetWaitStatus(ws, 0);
            // 如果节点的第一个后继节点为null或者等待状态大于0(取消),则从等待队列的尾节点向前遍历,
            // 找到最后一个(这里指的是队列尾部->队列头部搜索路径的最后一个满足的节点,一般是传入的node节点的next节点)不为null,并且等待状态小于等于0的节点
            Node s = node.next;
            if (s == null || s.waitStatus > 0) {
                s = null;
                for (Node p = tail; p != node && p != null; p = p.prev)
                    if (p.waitStatus <= 0)
                        s = p;
            }
            // 解除传入节点的后继节点的阻塞状态,唤醒后继节点所存放的线程
            if (s != null)
                LockSupport.unpark(s.thread);
        }
        // ......
    }

    这里一定要注意,unparkSuccessor()只是会「唤醒当前传入节点参数的"正常的"后继节点,并不是唤醒同步队列中的所有阻塞节点」。头节点的后继节点被唤醒之后,该节点所在的线程会解除阻塞,在doAcquireSharedInterruptibly()方法中被唤醒,解除阻塞后进入下一轮循环,然后调用setHeadAndPropagate()唤醒后继节点,把状态标记为Node.PROPAGATE(-3),这个过程简单理解为"链式反应唤醒"。

    复杂的setHeadAndPropagate方法

    这里再重点分析一下setHeadAndPropagate()方法的实现,个人认为这是AQS里面的一个比较"烧脑"的方法,复杂不在于它自身的逻辑,而在于它需要结合doAcquireSharedInterruptibly()方法中的死循环和doReleaseShared()方法中的死循环来推演。先看看setHeadAndPropagate()方法的源码 :

    private void setHeadAndPropagate(Node node, int propagate) {
        // 这里的临时变量h存放了旧的头节点引用
        Node h = head; // Record old head for check below
        // 这里的输入参数node基本上就是原来旧头节点的后继节点,而propagate的值来源于tryAcquireShared(),由图中可知propagate >= 0 恒成立 
        setHead(node);
        /*
            * Try to signal next queued node if:
            *   Propagation was indicated by caller,
            *     or was recorded (as h.waitStatus either before
            *     or after setHead) by a previous operation
            *     (note: this uses sign-check of waitStatus because
            *      PROPAGATE status may transition to SIGNAL.)
            * and
            *   The next node is waiting in shared mode,
            *     or we don't know, because it appears null
            *
            * The conservatism in both of these checks may cause
            * unnecessary wake-ups, but only when there are multiple
            * racing acquires/releases, so most need signals now or soon
            * anyway.
            */

        // 这里是一个很复杂的IF条件,下文一个一个条件看
        if (propagate > 0 || h == null || h.waitStatus < 0 ||
            (h = head) == null || h.waitStatus < 0) {
            Node s = node.next;
            if (s == null || s.isShared())
                doReleaseShared();
        }
    }

    // 设置头节点,输入节点的前驱节点和持有线程实例都会置空,因为它持有的线程实例已经从shouldParkAfterFailedAcquire()中解除阻塞
    private void setHead(Node node) {
        head = node;
        node.thread = null;
        node.prev = null;
    }

    // 判断当获取资源失败的时候是否应该阻塞当前处理的节点中的线程实例
    // node为当前处理的节点或者新入队的节点
    // pred则为node的前驱节点
    private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
        // 前驱节点的状态值
        int ws = pred.waitStatus;
        // 前驱节点处于Node.SIGNAL(-1)状态,说明当前节点可以唤醒,返回true以便调用在下一轮循环进入setHeadAndPropagate()方法
        if (ws == Node.SIGNAL)
            /*
                * This node has already set status asking a release
                * to signal it, so it can safely park.
                */

            return true;
        // 状态值大于0,说明当前处理的节点的前驱节点处于取消状态,则需要跳过这些取消状态的前驱节点
        if (ws > 0) {
            /*
                * Predecessor was cancelled. Skip over predecessors and
                * indicate retry.
                */

            do {
                node.prev = pred = pred.prev;
            } while (pred.waitStatus > 0);
            pred.next = node;
        } else {
            /*
                * waitStatus must be 0 or PROPAGATE.  Indicate that we
                * need a signal, but don't park yet.  Caller will need to
                * retry to make sure it cannot acquire before parking.
                */

            // 剩下的就是其他情况,初始化状态0或者无条件传播状态Node.PROPAGATE(-3),这两种情况把前驱节点状态CAS更新为Node.SIGNAL(-1),表明当前节点可以被唤醒
            pred.compareAndSetWaitStatus(ws, Node.SIGNAL);
        }
        return false;
    }

    setHeadAndPropagate()中有一个极度复杂的IF分支判断,紧记propagate代表的是tryAcquireShared()的返回值:

    补充一下:共享模式下(其实独占模式下也是这样)在不满足唤醒条件的前提下,一个全新的节点加入到AQS的同步等待队列中,在doAcquireSharedInterruptibly()需要经历两轮循环才会成功阻塞:

    如果至此,还有不理解"PROPAGATE"这个词的含义的话,可以看看下图的状态推演,笔者把节点状态变更的细节全部标明:

    最后,带超时阻塞的doAcquireSharedNanos()方法实现思路其实差不多,只是在获取资源的循环体中会判断阻塞是否超越了输入的期限,超时的节点应用cancelAcquire(node)更变状态为取消,然后阻塞线程的方法使用了LockSupport.parkNanos()

    唤醒顺序的问题

    基于上面的源码分析,总结一下CountDownLatch对于阻塞线程唤醒的顺序,如果达到唤醒条件后:

    AQS的数据结构是CLH锁队列的变体,毕竟是队列数据结构,所以阻塞节点的出队(解除阻塞)也遵循于「FIFO」的特性。节点持有线程解除阻塞后的执行顺序,有可能会和预期不一样,这是因为很多时候线程解除阻塞之后,会参与其他类型的锁竞争,例如System.out.println()方法,本质是一个同步方法,后解除阻塞的线程有可能先获取到锁并且先执行。这里可以取巧用其他手段监测一下这些阻塞线程的解除阻塞的顺序,例如在LockSupport.unpark()方法做一个埋点,可以应用Instrumentation和字节码增强工具。先引入Javassist

    <dependency>
        <groupId>org.javassist</groupId>
        <artifactId>javassist</artifactId>
        <version>3.27.0-GA</version>
    </dependency>

    编写一个Agent

    public class AqsAgent {

        private static final byte[] NO_TRANSFORM = null;

        public static void premain(final String agentArgs, @NonNull final Instrumentation inst) {
            inst.addTransformer(new LockSupportClassFileTransformer(), true);
        }

        private static class LockSupportClassFileTransformer implements ClassFileTransformer {

            @Override
            public byte[] transform(ClassLoader loader,
                                    String classFileName,
                                    Class<?> classBeingRedefined,
                                    ProtectionDomain protectionDomain,
                                    byte[] classfileBuffer) throws IllegalClassFormatException {
                String className = toClassName(classFileName);
                if (className.contains("concurrent")) {
                    System.out.println("正在处理:" + className);
                }
                if (className.equals("java.util.concurrent.locks.AbstractQueuedSynchronizer")) {
                    return processTransform(loader, classfileBuffer);
                }
                return NO_TRANSFORM;
            }
        }

        private static byte[] processTransform(ClassLoader loader, byte[] classfileBuffer) {
            try {
                final ClassPool classPool = new ClassPool(true);
                if (loader == null) {
                    classPool.appendClassPath(new LoaderClassPath(ClassLoader.getSystemClassLoader()));
                } else {
                    classPool.appendClassPath(new LoaderClassPath(loader));
                }
                final CtClass clazz = classPool.makeClass(new ByteArrayInputStream(classfileBuffer), false);
                clazz.defrost();
                final CtClass paramClass = clazz.getClassPool().get("java.util.concurrent.locks.AbstractQueuedSynchronizer$Node");
                final CtMethod unparkMethod = clazz.getDeclaredMethod("unparkSuccessor"new CtClass[]{paramClass});
                unparkMethod.insertBefore("{java.lang.Object x = $1;\n" +
                        "            java.lang.reflect.Field nextField = Class.forName(\"java.util.concurrent.locks.AbstractQueuedSynchronizer$Node\").getDeclaredField(\"next\");\n" +
                        "            java.lang.reflect.Field threadField = Class.forName(\"java.util.concurrent.locks.AbstractQueuedSynchronizer$Node\").getDeclaredField(\"thread\");\n" +
                        "            nextField.setAccessible(true);\n" +
                        "            threadField.setAccessible(true);\n" +
                        "            java.lang.Object next = nextField.get(x);\n" +
                        "            if (null != next){" +
                        "java.lang.Object thread = threadField.get(next);\n" +
                        "System.out.println(\"当前解除阻塞的线程名称为:\"+ thread);\n" +
                        "}\n" +
                        "}");
                return clazz.toBytecode();
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }

        private static String toClassName(@NonNull final String classFileName) {
            return classFileName.replace('/''.');
        }
    }

    文章开头提到的例子里面,在VM参数添加-javaagent:I:\J-Projects\lock-support-agent\target\lock-support-agent.jar引入这个做好的Agent,再运行一次,结果如下:

    线程thirdThread准备调用await方法......
    线程firstThread准备调用await方法......
    线程secondThread准备调用await方法......
    main线程释放CountDownLatch......
    当前解除阻塞的线程名称为:Thread[thirdThread,5,main]
    当前解除阻塞的线程名称为:Thread[firstThread,5,main]
    当前解除阻塞的线程名称为:Thread[secondThread,5,main]
    线程thirdThread解除阻塞继续运行......
    线程firstThread解除阻塞继续运行......
    线程secondThread解除阻塞继续运行......

    可见唤醒和阻塞的顺序是完全一致的,印证了前面的源码分析过程。

    CountDownLatch实战

    除了一些个人项目或者Demo,笔者在生产环境中只在少量场景应用过CountDownLatch,其中一个场景是数据迁移的「生产者-消费者线程模型控制模块」。这个方案在前公司迁移一张存量一亿多的图片信息记录表,拆分到单库128张新建的图片信息表中,过程一共用了几个小时(主要瓶颈在带宽和写操作,因为当时数据库机器的磁盘为机械硬盘,带宽有限并且需要保证业务稳定的情况下只能尽可能减少写线程的数量)。伪代码如下:

    public class PhotoMigration {
        
        // 毒丸对象,参考自《Java并发编程实战》的7.2.3一节
        static final List<Photo> POISON = Collections.emptyList();
        static final int WORKER_COUNT = 10;
        static final BlockingQueue<List<Photo>> QUEUE = new ArrayBlockingQueue<>(WORKER_COUNT * 10);

        public void process() throws Exception {
            CountDownLatch latch = new CountDownLatch(WORKER_COUNT);
            ThreadPoolExecutor executor = new ThreadPoolExecutor(WORKER_COUNT, WORKER_COUNT, 0, TimeUnit.SECONDS, new LinkedBlockingQueue<>(), new ThreadFactory() {

                private final AtomicInteger counter = new AtomicInteger();

                @Override
                public Thread newThread(@NotNull Runnable r) {
                    Thread thread = new Thread(r);
                    thread.setDaemon(true);
                    thread.setName("MigrationWorker-" + counter.getAndIncrement());
                    return thread;
                }
            });
            long id = 0;
            for (; ; ) {
                List<Photo> photos = photoDao.selectByIdScrollingPagination(id, 500);
                if (photos.isEmpty()) {
                    for(int i = 0; i < WORKER_COUNT; i++){
                        QUEUE.put(POISON);
                    }
                    break;
                } else {
                    QUEUE.put(photos);
                    id = photos.stream().map(Photo::getId).max(Long::compareTo).orElse(Long.MAX_VALUE);
                }
            }
            // 解除主线程阻塞
            latch.await();
            EventBus.send([发布迁移完成事件]);
            // 关闭线程池
            executor.shutdown();
        }

        @RequiredArgsConstructor
        private static class Task implements Runnable {

            private final CountDownLatch latch;

            @Override
            public void run() {
                for (; ; ) {
                    try {
                        List<Photo> photos = QUEUE.take();
                        if (POISON == photos) {
                            latch.countDown();
                            break;
                        }
                        // 执行迁移逻辑和入库
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    }
                }
            }
        }

        @Data
        private static class Photo {

            private Long id;
        }
    }

    上面的伪代码基本可以使用在任意的数据迁移场景,可以动态调整查询线程和写线程的数量(这里只是简单描述了一个查询线程多个写线程的伪代码,实践上应用了多线程查和多线程写,笔者之前为了简单起见,提前对需要迁移的数据进行了ID分段,查询线程已经提前规划好需要查询的ID段),如果做得更加完善,甚至可以控制每批次处理的数量、异常记录和恢复等等,「方案的灵感来源于Java并发编程圣经《Java并发编程实战》中的第五章」。里面的组件可以做改良,升级为多进程版本,例如保证单进程做查询的前提下,把队列替换为RabbitMQ中的"队列"概念,把Worker调整为RabbitMQ的消费者即可。原则上,存在"等待某些操作完成之后执行其他操作"的场景,可以考虑使用CountDownLatch

    小结

    CountDownLatch算是JUC包中实现思路相对简单的一个组件,不过在使用的时候需要注意几个事项:

    至此,CountDownLatch的基本使用和源码分析基本结束,AQS的状态变更和状态判断,大量的死循环和复杂的条件判断看起来真是让人觉得烧脑,但毫无疑问这是一个比较优秀的并发工具,推荐使用于下面的类似场景:

    这里给出文中编写好的Agent的仓库:

    (本文完 c-5-d e-a-20200831 耗费了大量时间作图、DEGUG和编写Agent)


    本文分享自微信公众号 - Throwable文摘(throwable-doge)。
    如有侵权,请联系 support@oschina.cn 删除。
    本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。

    09-03 13:26