Skip to content

Latest commit

 

History

History
913 lines (733 loc) · 28.7 KB

9_并发实践.md

File metadata and controls

913 lines (733 loc) · 28.7 KB

并发实践

线程不安全示例

如果多个线程对同一个共享数据进行访问而不采取同步操作的话,那么操作的结果是不一致的。

以下代码演示了 1000 个线程同时对 cnt 执行自增操作,操作结束之后它的值有可能小于 1000。

public class ThreadUnsafeExample {

    private int cnt = 0;

    public void add() {
        cnt++;
    }

    public int get() {
        return cnt;
    }
}
public static void main(String[] args) throws InterruptedException {
    final int threadSize = 1000;
    ThreadUnsafeExample example = new ThreadUnsafeExample();
    final CountDownLatch countDownLatch = new CountDownLatch(threadSize);
    ExecutorService executorService = Executors.newCachedThreadPool();
    for (int i = 0; i < threadSize; i++) {
        executorService.execute(() -> {
            example.add();
            countDownLatch.countDown();
        });
    }
    countDownLatch.await();
    executorService.shutdown();
    System.out.println(example.get());
}

输出结果:

998 //如果线程是安全的最终结果就是 1000

解决线程安全问题:

  • 方案一:使用 AtomicInteger(实际上非阻塞同步)

    public class ThreadSafeExample {
        // cnt 的初始值为 0。
        private AtomicInteger cnt = new AtomicInteger(0);
    
        public void add() {
            cnt.incrementAndGet();
        }
    
        public int get() {
            return cnt.get();
        }
    
        public static void main(String[] args) throws InterruptedException {
            final int threadSize = 1000;
            ThreadSafeExample example = new ThreadSafeExample();
            final CountDownLatch countDownLatch = new CountDownLatch(threadSize);
            ExecutorService executorService = Executors.newCachedThreadPool();
            for (int i = 0; i < threadSize; i++) {
                executorService.execute(() -> {
                    example.add();
                    countDownLatch.countDown();
                });
            }
            countDownLatch.await();
            executorService.shutdown();
            System.out.println(example.get());
        }
    }

    输出结果:

    1000
  • 方案二:使用 synchronized(实际上是阻塞同步/互斥同步)

    public class ThreadSafeExample2 {
        // cnt 的初始值为 0。
        private int cnt = 0;
    
        public synchronized void add() {
            cnt++;
        }
    
        public synchronized int get() {
            return cnt;
        }
    
        public static void main(String[] args) throws InterruptedException {
            final int threadSize = 1000;
            ThreadSafeExample2 example = new ThreadSafeExample2();
            final CountDownLatch countDownLatch = new CountDownLatch(threadSize);
            ExecutorService executorService = Executors.newCachedThreadPool();
            for (int i = 0; i < threadSize; i++) {
                executorService.execute(() -> {
                    example.add();
                    countDownLatch.countDown();
                });
            }
            countDownLatch.await();
            executorService.shutdown();
            System.out.println(example.get());
        }
    }

    输出结果:

    1000
  • 方案三:使用 Reentrant(实际上是阻塞同步/互斥同步)

    public class ThreadSafeExample3 {
        // cnt 的初始值为 0。
        private int cnt = 0;
    
        private ReentrantLock lock = new ReentrantLock();
    
        public void add() {
            try{
                lock.lock();
                cnt++;
            }finally {
                lock.unlock();
            }
        }
    
        public int get() {
            try{
                lock.lock();
                return cnt;
            }finally {
                lock.unlock();
            }
        }
    
        public static void main(String[] args) throws InterruptedException {
            final int threadSize = 1000;
            ThreadSafeExample3 example = new ThreadSafeExample3();
            final CountDownLatch countDownLatch = new CountDownLatch(threadSize);
            ExecutorService executorService = Executors.newCachedThreadPool();
            for (int i = 0; i < threadSize; i++) {
                executorService.execute(() -> {
                    example.add();
                    countDownLatch.countDown();
                });
            }
            countDownLatch.await();
            executorService.shutdown();
            System.out.println(example.get());
        }
    }

    输出结果:

    1000

线程安全

多个线程不管以何种方式访问某个类,并且在主调代码中不需要进行同步,都能表现正确的行为。

Servlet 存在线程安全问题,原因如下:

Servlet 在 Tomcat 中是以单例模式存在的,也就是说,当客户端第一次请求 Servlet 时,Tomcat 会根据 web.xml 生成对应的 Servlet 实例,当客户端第二次请求的是同一个 Servlet,Tomcat 不会再生成新的TServlet。不同请求共享同一 Servlet 实例的资源,从而导致线程不安全的问题

多个线程不管以何种方式访问某个类,并且在主调代码中不需要进行同步,都能表现正确的行为。

线程安全有以下几种实现方式:

方式一:不可变

不可变(Immutable)的对象一定是线程安全的,不需要再采取任何的线程安全保障措施。只要一个不可变的对象被正确地构建出来,永远也不会看到它在多个线程之中处于不一致的状态。多线程环境下,应当尽量使对象成为不可变,来满足线程安全。

不可变的类型:

  • final 关键字修饰的基本数据类型
  • String
  • 枚举类型
  • Number 部分子类,如 Long 和 Double 等数值包装类型,BigInteger 和 BigDecimal 等大数据类型。但同为 Number 的原子类 AtomicInteger 和 AtomicLong 则是可变的。

对于集合类型,可以使用 Collections.unmodifiableXXX() 方法来获取一个不可变的集合。

public class ImmutableExample {
    public static void main(String[] args) {
        Map<String, Integer> map = new HashMap<>();
        Map<String, Integer> unmodifiableMap = Collections.unmodifiableMap(map);
        unmodifiableMap.put("a", 1);
    }
}Copy to clipboardErrorCopied
Exception in thread "main" java.lang.UnsupportedOperationException
    at java.util.Collections$UnmodifiableMap.put(Collections.java:1457)
    at ImmutableExample.main(ImmutableExample.java:9)Copy to clipboardErrorCopied

Collections.unmodifiableXXX() 先对原始的集合进行拷贝,需要对集合进行修改的方法都直接抛出异常。

public V put(K key, V value) {
    throw new UnsupportedOperationException();
}Copy to clipboardErrorCopied

方式二:互斥同步

synchronized 和 ReentrantLock。

方式三:非阻塞同步

原子操作类(Unsafe 的 CAS 操作)。

方式四:无同步方案

要保证线程安全,并不是一定就要进行同步。如果一个方法本来就不涉及共享数据,那它自然就无须任何同步措施去保证正确性。

1. 栈封闭

多个线程访问同一个方法的局部变量时,不会出现线程安全问题,因为局部变量存储在虚拟机栈中,属于线程私有的。

public class StackClosedExample {
    public void add100() {
        int cnt = 0;
        for (int i = 0; i < 100; i++) {
            cnt++;
        }
        System.out.println(cnt);
    }
}Copy to clipboardErrorCopied
public static void main(String[] args) {
    StackClosedExample example = new StackClosedExample();
    ExecutorService executorService = Executors.newCachedThreadPool();
    executorService.execute(() -> example.add100());
    executorService.execute(() -> example.add100());
    executorService.shutdown();
}Copy to clipboardErrorCopied
100
100Copy to clipboardErrorCopied

2. 线程本地存储(Thread Local Storage)

如果一段代码中所需要的数据必须与其他代码共享,那就看看这些共享数据的代码是否能保证在同一个线程中执行。如果能保证,我们就可以把共享数据的可见范围限制在同一个线程之内,这样,无须同步也能保证线程之间不出现数据争用的问题。

符合这种特点的应用并不少见,大部分使用消费队列的架构模式(如“生产者-消费者”模式)都会将产品的消费过程尽量在一个线程中消费完。其中最重要的一个应用实例就是经典 Web 交互模型中的“一个请求对应一个服务器线程”(Thread-per-Request)的处理方式,这种处理方式的广泛应用使得很多 Web 服务端应用都可以使用线程本地存储来解决线程安全问题。

可以使用 java.lang.ThreadLocal 类来实现线程本地存储功能。

对于以下代码,thread1 中设置 threadLocal 为 1,而 thread2 设置 threadLocal 为 2。过了一段时间之后,thread1 读取 threadLocal 依然是 1,不受 thread2 的影响。

public class ThreadLocalExample {
    public static void main(String[] args) {
        ThreadLocal threadLocal = new ThreadLocal();
        Thread thread1 = new Thread(() -> {
            threadLocal.set(1);
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(threadLocal.get());
            threadLocal.remove();
        });
        Thread thread2 = new Thread(() -> {
            threadLocal.set(2);
            threadLocal.remove();
        });
        thread1.start();
        thread2.start();
    }
}Copy to clipboardErrorCopied
1Copy to clipboardErrorCopied

为了理解 ThreadLocal,先看以下代码:

public class ThreadLocalExample1 {
    public static void main(String[] args) {
        ThreadLocal threadLocal1 = new ThreadLocal();
        ThreadLocal threadLocal2 = new ThreadLocal();
        Thread thread1 = new Thread(() -> {
            threadLocal1.set(1);
            threadLocal2.set(1);
        });
        Thread thread2 = new Thread(() -> {
            threadLocal1.set(2);
            threadLocal2.set(2);
        });
        thread1.start();
        thread2.start();
    }
}Copy to clipboardErrorCopied

它所对应的底层结构图为:

img

每个 Thread 都有一个 ThreadLocal.ThreadLocalMap 对象。

/* ThreadLocal values pertaining to this thread. This map is maintained
 * by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;Copy to clipboardErrorCopied

当调用一个 ThreadLocal 的 set(T value) 方法时,先得到当前线程的 ThreadLocalMap 对象,然后将 ThreadLocal->value 键值对插入到该 Map 中。

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}Copy to clipboardErrorCopied

get() 方法类似。

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}Copy to clipboardErrorCopied

ThreadLocal 从理论上讲并不是用来解决多线程并发问题的,因为根本不存在多线程竞争。

在一些场景 (尤其是使用线程池) 下,由于 ThreadLocal.ThreadLocalMap 的底层数据结构导致 ThreadLocal 有内存泄漏的情况,应该尽可能在每次使用 ThreadLocal 后手动调用 remove(),以避免出现 ThreadLocal 经典的内存泄漏甚至是造成自身业务混乱的风险。

线程间通信

当多个线程可以一起工作去解决某个问题时,如果某些部分必须在其它部分之前完成,那么就需要对线程进行协调。

join()

在线程中调用另一个线程的 join() 方法,会将当前线程挂起,而不是忙等待,直到目标线程结束。

对于以下代码,虽然 b 线程先启动,但是因为在 b 线程中调用了 a 线程的 join() 方法,b 线程会等待 a 线程结束才继续执行,因此最后能够保证 a 线程的输出先于 b 线程的输出。

public class JoinExample {

    private class A extends Thread {
        @Override
        public void run() {
            System.out.println("A");
        }
    }

    private class B extends Thread {

        private A a;

        B(A a) {
            this.a = a;
        }

        @Override
        public void run() {
            try {
                a.join();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println("B");
        }
    }

    public void test() {
        A a = new A();
        B b = new B(a);
        b.start();
        a.start();
    }
}

public static void main(String[] args) {
    JoinExample example = new JoinExample();
    example.test();
}
A
B

wait() notify() notifyAll()

调用 wait() 使得线程等待某个条件满足,线程在等待时会被挂起,当其他线程的运行使得这个条件满足时,其它线程会调用 notify() 或者 notifyAll() 来唤醒挂起的线程。

它们都属于 Object 的一部分,而不属于 Thread。

只能用在同步方法或者同步控制块中使用,否则会在运行时抛出 IllegalMonitorStateException。

使用 wait() 挂起期间,线程会释放锁。这是因为,如果没有释放锁,那么其它线程就无法进入对象的同步方法或者同步控制块中,那么就无法执行 notify() 或者 notifyAll() 来唤醒挂起的线程,造成死锁。

public class WaitNotifyExample {

    public synchronized void before() {
        System.out.println("before");
        notifyAll();
    }

    public synchronized void after() {
        try {
            wait();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println("after");
    }
}

public static void main(String[] args) {
    ExecutorService executorService = Executors.newCachedThreadPool();
    WaitNotifyExample example = new WaitNotifyExample();
    executorService.execute(() -> example.after());
    executorService.execute(() -> example.before());
}
before
after

await() signal() signalAll()

java.util.concurrent 类库中提供了 Condition 类来实现线程之间的协调,可以在 Condition 上调用 await() 方法使线程等待,其它线程调用 signal() 或 signalAll() 方法唤醒等待的线程。

相比于 wait() 这种等待方式,await() 可以指定等待的条件,因此更加灵活。

使用 Lock 来获取一个 Condition 对象。

public class AwaitSignalExample {

    private Lock lock = new ReentrantLock();
    private Condition condition = lock.newCondition();

    public void before() {
        lock.lock();
        try {
            System.out.println("before");
            condition.signalAll();
        } finally {
            lock.unlock();
        }
    }

    public void after() {
        lock.lock();
        try {
            condition.await();
            System.out.println("after");
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            lock.unlock();
        }
    }
}

public static void main(String[] args) {
    ExecutorService executorService = Executors.newCachedThreadPool();
    AwaitSignalExample example = new AwaitSignalExample();
    executorService.execute(() -> example.after());
    executorService.execute(() -> example.before());
}
before
after

实现生产者和消费者

1. wait() / notifyAll()

public class ProducerConsumer {
    private static class Producer implements Runnable{
        private List<Integer> list;
        private int capacity;

        public Producer(List list, int capacity) {
            this.list = list;
            this.capacity = capacity;
        }

        @Override
        public void run() {
            while (true) {
                synchronized (list) {
                    try {
                        String producer = Thread.currentThread().getName();
                        while (list.size() == capacity) {
                            System.out.println("生产者 " + producer +
                                    "  list 已达到最大容量,进行 wait");
                            list.wait();
                            System.out.println("生产者 " + producer +
                                    "  退出 wait");
                        }
                        Random random = new Random();
                        int i = random.nextInt();
                        System.out.println("生产者 " + producer +
                                " 生产数据" + i);
                        list.add(i);
                        list.notifyAll();
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            }
        }
    }

    private static class Consumer implements Runnable{
        private List<Integer> list;

        public Consumer(List list) {
            this.list = list;
        }

        @Override
        public void run() {
            while (true) {
                synchronized (list) {
                    try {
                        String consumer = Thread.currentThread().getName();
                        while (list.isEmpty()) {
                            System.out.println("消费者 " + consumer +
                                    "  list 为空,进行 wait");
                            list.wait();
                            System.out.println("消费者 " + consumer +
                                    "  退出wait");
                        }
                        Integer element = list.remove(0);
                        System.out.println("消费者 " + consumer +
                                "  消费数据:" + element);
                        list.notifyAll();
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            }
        }
    }

    public static void main(String[] args) {
        final LinkedList linkedList = new LinkedList();
        final int capacity = 5;
        Thread producer = new Thread(new Producer(linkedList,capacity),"producer");
        Thread consumer = new Thread(new Consumer(linkedList),"consumer");

        consumer.start();
        producer.start();
    }
}

输出结果:

生产者 producer 生产数据-1652445373
生产者 producer 生产数据1234295578
生产者 producer 生产数据-1885445180
生产者 producer 生产数据864400496
生产者 producer 生产数据621858426
生产者 producer  list 已达到最大容量,进行 wait
消费者 consumer  消费数据:-1652445373
消费者 consumer  消费数据:1234295578
消费者 consumer  消费数据:-1885445180
消费者 consumer  消费数据:864400496
消费者 consumer  消费数据:621858426
消费者 consumer  list 为空,进行 wait
生产者 producer  退出 wait

2. await() / sigalAll()

public class ProducerConsumer {
    private static ReentrantLock lock = new ReentrantLock();

    private static Condition full = lock.newCondition();
    private static Condition empty = lock.newCondition();

    private static class Producer implements Runnable{
        private List<Integer> list;
        private int capacity;

        public Producer(List list, int capacity) {
            this.list = list;
            this.capacity = capacity;
        }

        @Override
        public void run() {
            while (true) {
                try {
                    lock.lock();
                    String producer = Thread.currentThread().getName();
                    while (list.size() == capacity) {
                        System.out.println("生产者 " + producer +
                                           "  list 已达到最大容量,进行 wait");
                        full.await();
                        System.out.println("生产者 " + producer +
                                           "  退出 wait");
                    }
                    Random random = new Random();
                    int i = random.nextInt();
                    System.out.println("生产者 " + producer +
                                       " 生产数据" + i);
                    list.add(i);
                    empty.signalAll();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                } finally {
                    lock.unlock();
                }
            }
        }
    }

    private static class Consumer implements Runnable{
        private List<Integer> list;

        public Consumer(List list) {
            this.list = list;
        }

        @Override
        public void run() {
            while (true) {
                try {
                    lock.lock();
                    String consumer = Thread.currentThread().getName();
                    while (list.isEmpty()) {
                        System.out.println("消费者 " + consumer +
                                           "  list 为空,进行 wait");
                        empty.await();
                        System.out.println("消费者 " + consumer +
                                           "  退出wait");
                    }
                    Integer element = list.remove(0);
                    System.out.println("消费者 " + consumer +
                                       "  消费数据:" + element);
                    full.signalAll();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                } finally {
                    lock.unlock();
                }
            }
        }
    }

    public static void main(String[] args) {
        final LinkedList linkedList = new LinkedList();
        final int capacity = 5;
        Thread producer = new Thread(new Producer(linkedList,capacity),"producer");
        Thread consumer = new Thread(new Consumer(linkedList),"consumer");

        consumer.start();
        producer.start();
    }
}

输出结果:

生产者 producer 生产数据-1748993481
生产者 producer 生产数据-131075825
生产者 producer 生产数据-683676621
生产者 producer 生产数据1543722525
生产者 producer 生产数据804266076
生产者 producer  list 已达到最大容量,进行 wait
消费者 consumer  消费数据:-1748993481
消费者 consumer  消费数据:-131075825
消费者 consumer  消费数据:-683676621
消费者 consumer  消费数据:1543722525
消费者 consumer  消费数据:804266076
消费者 consumer  list 为空,进行 wait
生产者 producer  退出 wait

3. 阻塞队列

public class ProducerConsumer {

    private static class Producer implements Runnable{
        private BlockingQueue<Integer> queue;

        public Producer(BlockingQueue<Integer> queue) {
            this.queue = queue;
        }

        @Override
        public void run() {
            while (true) {
                try {
                    String producer = Thread.currentThread().getName();
                    Random random = new Random();
                    int i = random.nextInt();
                    System.out.println("生产者 " + producer +
                            " 生产数据" + i);
                    queue.put(i);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    private static class Consumer implements Runnable{
        private BlockingQueue<Integer> queue;

        public Consumer(BlockingQueue<Integer> queue) {
            this.queue =queue;
        }

        @Override
        public void run() {
            while (true) {
                try {
                    String consumer = Thread.currentThread().getName();
                    Integer element = queue.take();
                    System.out.println("消费者 " + consumer +
                            "  消费数据:" + element);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    public static void main(String[] args) {
        final BlockingQueue<Integer> queue = new LinkedBlockingQueue<>();
        Thread producer = new Thread(new Producer(queue),"producer");
        Thread consumer = new Thread(new Consumer(queue),"consumer");

        consumer.start();
        producer.start();
    }
}

输出结果:

生产者producer生产数据-222876564
消费者consumer正在消费数据-906876105
生产者producer生产数据-9385856
消费者consumer正在消费数据1302744938
生产者producer生产数据-177925219
生产者producer生产数据-881052378
生产者producer生产数据-841780757
生产者producer生产数据-1256703008
消费者consumer正在消费数据1900668223
消费者consumer正在消费数据2070540191
消费者consumer正在消费数据1093187
消费者consumer正在消费数据6614703
消费者consumer正在消费数据-1171326759

LeetCode 上多线程编程问题

1114. 按序打印

我们提供了一个类:

public class Foo {
  public void one() { print("one"); }
  public void two() { print("two"); }
  public void three() { print("three"); }
}

三个不同的线程将会共用一个 Foo 实例。

  • 线程 A 将会调用 one() 方法
  • 线程 B 将会调用 two() 方法
  • 线程 C 将会调用 three() 方法

请设计修改程序,以确保 two() 方法在 one() 方法之后被执行,three() 方法在 two() 方法之后被执行。

示例 1:

输入: [1,2,3]
输出: "onetwothree"
解释: 
有三个线程会被异步启动。
输入 [1,2,3] 表示线程 A 将会调用 one() 方法,线程 B 将会调用 two() 方法,线程 C 将会调用 three() 方法。
正确的输出是 "onetwothree"。

示例 2:

输入: [1,3,2]
输出: "onetwothree"
解释: 
输入 [1,3,2] 表示线程 A 将会调用 one() 方法,线程 B 将会调用 three() 方法,线程 C 将会调用 two() 方法。
正确的输出是 "onetwothree"。

注意:

尽管输入中的数字似乎暗示了顺序,但是我们并不保证线程在操作系统中的调度顺序。

你看到的输入格式主要是为了确保测试的全面性。

import java.util.concurrent.CountDownLatch;

//解法一:
class Foo {

    private CountDownLatch countDownLatchA; //等待 A 线程执行完
    private CountDownLatch countDownLatchB; //等待 B 线程执行完

    public Foo() {
        countDownLatchA = new CountDownLatch(1); //只等待一个线程,即 A 线程
        countDownLatchB = new CountDownLatch(1); //只等待一个线程,即 B 线程
    }

    public void first(Runnable printFirst) throws InterruptedException {

        // printFirst.run() outputs "first". Do not change or remove this line.
        printFirst.run();
        countDownLatchA.countDown();
        //减去 1 后,cnt =0 那些因为调用 await() 方法而在等待的线程就会被唤醒。
    }

    public void second(Runnable printSecond) throws InterruptedException {
        countDownLatchA.await();
        // printSecond.run() outputs "second". Do not change or remove this line.
        printSecond.run();
        countDownLatchB.countDown();
        //减去 1 后,cnt =0 那些因为调用 await() 方法而在等待的线程就会被唤醒。
    }

    public void third(Runnable printThird) throws InterruptedException {
        countDownLatchB.await();
        // printThird.run() outputs "third". Do not change or remove this line.
        printThird.run();
    }
}

多线程良好开发实践

  • 给线程起个有意义的名字,这样可以方便找 Bug。
  • 缩小同步范围,从而减少锁争用。例如对于 synchronized,应该尽量使用同步块而不是同步方法。
  • 多用同步工具少用 wait() 和 notify()。首先,CountDownLatch, CyclicBarrier, Semaphore 和 Exchanger 这些同步类简化了编码操作,而用 wait() 和 notify() 很难实现复杂控制流;其次,这些同步类是由最好的企业编写和维护,在后续的 JDK 中还会不断优化和完善。
  • 使用 BlockingQueue 实现生产者消费者问题。
  • 多用并发集合少用同步集合,例如应该使用 ConcurrentHashMap 而不是 Hashtable。
  • 使用本地变量和不可变类来保证线程安全。
  • 使用线程池而不是直接创建线程,这是因为创建线程代价很高,线程池可以有效地利用有限的线程来启动任务。