Skip to content

Instantly share code, notes, and snippets.

@monzou
Last active June 13, 2019 07:51
Show Gist options
  • Save monzou/ee3d813a4fdb2c0d7a0ebb91fab28a78 to your computer and use it in GitHub Desktop.
Save monzou/ee3d813a4fdb2c0d7a0ebb91fab28a78 to your computer and use it in GitHub Desktop.
RedisLockManager
package sandbox;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.TimeoutException;
/**
* グローバルなロックを司るコンポーネントです。
*/
public interface LockManager {

    /**
     * Acquires the global lock called {@code lockName}, blocking for as long as it takes
     * until the lock becomes available.
     *
     * @param lockName name identifying the lock
     * @param expire   how long the lock is retained once it has been acquired
     * @return the acquired lock
     * @throws InterruptedException if the calling thread is interrupted while acquiring
     */
    LockState lock(String lockName, Duration expire) throws InterruptedException;

    /**
     * Acquires the global lock called {@code lockName}, giving up once {@code timeout}
     * has elapsed without success.
     *
     * @param lockName name identifying the lock
     * @param expire   how long the lock is retained once it has been acquired
     * @param timeout  maximum time to spend trying to acquire the lock
     * @return the acquired lock
     * @throws InterruptedException if the calling thread is interrupted while acquiring
     * @throws TimeoutException     if the lock could not be acquired within {@code timeout}
     */
    LockState lock(String lockName, Duration expire, Duration timeout)
            throws InterruptedException, TimeoutException;

    /**
     * Makes a single attempt to acquire the global lock called {@code lockName}.
     * Never waits: when the lock is unavailable, {@link Optional#empty()} is returned
     * straight away.
     *
     * @param lockName name identifying the lock
     * @param expire   how long the lock is retained once it has been acquired
     * @return the acquired lock, or an empty {@code Optional} if it is currently held
     */
    Optional<LockState> tryLock(String lockName, Duration expire);
}
package sandbox;
/**
* ロック状態です。
*/
public interface LockState {

    /**
     * Gives up the lock represented by this handle.
     *
     * @return {@code true} if this call released the lock
     */
    boolean unlock();
}
package sandbox;
import sandbox.LockManager;
import sandbox.LockState;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
import org.springframework.data.redis.core.RedisConnectionUtils;
import org.springframework.stereotype.Component;
import redis.clients.jedis.Jedis;
import javax.inject.Inject;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
 * {@link LockManager} backed by Redis.
 *
 * <p>A lock is a Redis key created with {@code SET key token NX PX expire}, where
 * {@code token} is a random UUID generated for each acquisition. Releasing runs a Lua
 * script that deletes the key only if it still holds that token. As a result, a holder
 * whose lock has already expired cannot delete a lock that was since acquired by
 * someone else.
 */
@Component
public class RedisLockManager implements LockManager {

    /** Delay between acquisition attempts while waiting for a lock, in milliseconds. */
    private static final int ACQUIRE_INTERVAL_MSEC = 100;

    /**
     * Lua script: delete KEYS[1] only if its value equals ARGV[1] (the caller's token).
     * Returns the result of DEL when the token matches, otherwise 0.
     */
    private static final String DELETE_COMMAND = String.join("\n",
            "if redis.call('get', KEYS[1]) == ARGV[1] then",
            " return redis.call('del', KEYS[1])",
            "else",
            " return 0",
            "end");

    private final JedisConnectionFactory master;

    /**
     * Process-local lock that serializes individual acquisition attempts made by this
     * instance.
     *
     * <p>It is held only for one Redis round trip, never while sleeping between attempts.
     * This means a thread waiting for one lock does not block {@link #tryLock}, timed
     * {@code lock} calls, or interrupts of other threads.
     *
     * <p>NOTE(review): {@code SET NX} is atomic on the Redis side, so this lock is
     * presumably not required for correctness. It is kept to preserve the original
     * serialization of attempts.
     */
    private final Lock lock;

    @Inject
    public RedisLockManager(@Qualifier("jedisConnectionFactory") JedisConnectionFactory master) {
        this.master = master;
        this.lock = new ReentrantLock(false);
    }

    @Override
    public LockState lock(String lockName, Duration expire) throws InterruptedException {
        try {
            return lock(lockName, expire, null);
        } catch (TimeoutException e) {
            // Unreachable: with no timeout the loop exits only by acquiring or by interruption.
            throw new IllegalStateException("Unexpected timeout without a deadline: " + lockName, e);
        }
    }

    /**
     * Polls Redis until the lock is acquired, the deadline passes, or the thread is
     * interrupted.
     *
     * <p>A {@code null} {@code timeout} means "wait forever"; this form is used by
     * {@link #lock(String, Duration)}.
     */
    @Override
    public LockState lock(String lockName, Duration expire, Duration timeout)
            throws InterruptedException, TimeoutException {
        Instant timeoutAt = timeout == null ? null : Instant.now().plus(timeout);
        while (true) {
            Optional<LockState> state = attempt(lockName, expire);
            if (state.isPresent()) {
                return state.get();
            }
            long sleepMillis = ACQUIRE_INTERVAL_MSEC;
            if (timeoutAt != null) {
                long remainingMillis = Duration.between(Instant.now(), timeoutAt).toMillis();
                if (remainingMillis <= 0) {
                    throw new TimeoutException("Failed to acquire the lock due to timed out: " + lockName);
                }
                // Do not sleep past the deadline; one final attempt is made when it is reached.
                sleepMillis = Math.min(sleepMillis, remainingMillis);
            }
            Thread.sleep(sleepMillis);
        }
    }

    @Override
    public Optional<LockState> tryLock(String lockName, Duration expire) {
        return attempt(lockName, expire);
    }

    /** Makes one acquisition attempt, holding the process-local lock only for that round trip. */
    private Optional<LockState> attempt(String lockName, Duration expire) {
        lock.lock();
        try {
            return execute(jedis -> tryLock(jedis, lockName, expire));
        } finally {
            lock.unlock();
        }
    }

    /**
     * Issues a single {@code SET lockName <random token> NX PX expire}. Returns a handle
     * bound to that token if Redis replied {@code OK}, otherwise empty.
     */
    private Optional<LockState> tryLock(Jedis jedis, String lockName, Duration expire) {
        String token = UUID.randomUUID().toString();
        if ("OK".equals(jedis.set(lockName, token, "NX", "PX", expire.toMillis()))) {
            return Optional.of(new LockStateImpl(lockName, token));
        }
        return Optional.empty();
    }

    /**
     * Runs {@code action} against the native Jedis connection obtained through
     * {@link RedisConnectionUtils}. The connection is always released afterwards.
     */
    private <T> T execute(Function<Jedis, T> action) {
        RedisConnection conn = RedisConnectionUtils.getConnection(master);
        try {
            return action.apply((Jedis) conn.getNativeConnection());
        } finally {
            RedisConnectionUtils.releaseConnection(conn, master);
        }
    }

    /** Handle for one successful acquisition, identified by its lock name and random token. */
    private class LockStateImpl implements LockState {
        private final String lockName;
        private final String token;
        private final AtomicBoolean released;

        LockStateImpl(String lockName, String token) {
            this.lockName = lockName;
            this.token = token;
            this.released = new AtomicBoolean(false);
        }

        /**
         * Releases the lock if it is still held with this handle's token.
         *
         * @return {@code true} only if this call deleted the lock key. Returns {@code false}
         *     if the handle was already released, or if the lock had expired (and possibly
         *     been re-acquired by another client) before this call.
         */
        @Override
        public boolean unlock() {
            // Claim the release first so that concurrent calls send at most one DEL script.
            if (!released.compareAndSet(false, true)) {
                return false;
            }
            Object deleted;
            try {
                deleted = execute(jedis -> jedis.eval(DELETE_COMMAND, Arrays.asList(lockName), Arrays.asList(token)));
            } catch (RuntimeException e) {
                // The release did not complete; let the caller retry. The retry is safe
                // because the script only deletes a key that still holds our token.
                released.set(false);
                throw e;
            }
            // The script returns DEL's integer reply (1 when deleted) or 0 on token mismatch.
            return Long.valueOf(1L).equals(deleted);
        }
    }
}
package sandbox;
import sandbox.LockState;
import junit.framework.TestCase;
import org.junit.Test;
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
import redis.clients.jedis.Jedis;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.*;
public class RedisLockManagerTest extends TestCase {
private JedisConnectionFactory redis;
private RedisLockManager subject;
/**
* ロックを取得して解放できる
*/
@Test
public void testLockAndUnlock() {
LockState lock = null;
try {
lock = subject.lock(lockName(), Duration.of(5, ChronoUnit.SECONDS));
} catch (InterruptedException e) {
fail();
} finally {
if (lock != null) {
lock.unlock();
}
}
}
/**
* ロックを確保しようとするが既に確保されているためタイムアウトする
*/
@Test
public void testWaitUntilTimeout() {
AtomicBoolean timedOut = new AtomicBoolean(false);
String lockName = lockName();
LockState lock = null;
try {
lock = subject.lock(lockName, Duration.of(5, ChronoUnit.SECONDS));
Thread later = new Thread(() -> {
LockState s = null;
try {
s = subject.lock(lockName, Duration.of(1, ChronoUnit.SECONDS), Duration.of(1, ChronoUnit.SECONDS));
} catch (InterruptedException e) {
fail();
} catch (TimeoutException e) {
timedOut.set(true);
} finally {
if (s != null) {
s.unlock();
}
}
});
later.start();
later.join();
assertTrue(timedOut.get());
} catch (InterruptedException e) {
fail();
} finally {
if (lock != null) {
lock.unlock();
}
}
}
/**
* ロックを確保しようとして待機している時に割り込まれると例外をスローする
*/
@Test
public void testExceptionOnInterrupted() {
AtomicBoolean interrupted = new AtomicBoolean(false);
String lockName = lockName();
LockState lock = null;
try {
lock = subject.lock(lockName, Duration.of(5, ChronoUnit.MINUTES));
Thread later = new Thread(() -> {
LockState s = null;
try {
s = subject.lock(lockName, Duration.of(1, ChronoUnit.SECONDS), Duration.of(1, ChronoUnit.MINUTES));
} catch (InterruptedException e) {
interrupted.set(true);
} catch (TimeoutException e) {
fail();
} finally {
if (s != null) {
s.unlock();
}
}
});
later.start();
later.interrupt();
later.join();
assertTrue(interrupted.get());
} catch (InterruptedException e) {
fail();
} finally {
if (lock != null) {
lock.unlock();
}
}
}
/**
* ロックが解放済みのため取得できる
*/
@Test
public void testLockAfterUnlock() {
Duration expire = Duration.of(5, ChronoUnit.SECONDS);
String lockName = lockName();
LockState lock = null;
try {
lock = subject.lock(lockName, expire);
} catch (InterruptedException e) {
fail();
} finally {
lock.unlock();
}
Optional<LockState> state = subject.tryLock(lockName, expire);
assertTrue(state.isPresent());
state.ifPresent(LockState::unlock);
}
/**
* ロックの解放に失敗している場合でもロックが有効期限切れしているため取得できる
*/
@Test
public void testLockAfterExpired() throws InterruptedException {
Duration expire = Duration.of(3, ChronoUnit.SECONDS);
String lockName = lockName();
subject.lock(lockName, expire); // クライアントが死ぬなどして解放されない場合
// まだ生きている
Thread.sleep(1000);
assertFalse(subject.tryLock(lockName, expire).isPresent());
Thread.sleep(1000);
assertFalse(subject.tryLock(lockName, expire).isPresent());
Thread.sleep(1000);
// ロックが期限切れ
Optional<LockState> state = subject.tryLock(lockName, expire);
assertTrue(state.isPresent());
state.ifPresent(LockState::unlock);
}
/**
* ロックの解放は一度だけ
*/
@Test
public void testUnlockOnlyOnce() {
LockState lock = null;
try {
lock = subject.lock(lockName(), Duration.of(5, ChronoUnit.SECONDS));
} catch (InterruptedException e) {
fail();
} finally {
if (lock != null) {
assertTrue(lock.unlock());
assertFalse(lock.unlock());
}
}
}
/**
* 複数スレッドがロックを取りに行っても問題ないか
*/
public void testDeadLock() {
int nThreads = 100;
AtomicInteger counter = new AtomicInteger(0);
CyclicBarrier barrier = new CyclicBarrier(nThreads + 1);
List<Thread> threads = IntStream.range(0, nThreads).mapToObj(i -> new Thread(() -> {
LockState state = null;
try {
barrier.await();
state = subject.lock(lockName(), Duration.of(10, ChronoUnit.SECONDS));
counter.incrementAndGet();
} catch (InterruptedException e) {
fail();
} catch (BrokenBarrierException e) {
fail();
} finally {
if (state != null) {
state.unlock();
}
}
}, String.format("thread-%d", i))).collect(Collectors.toList());
// 事前にスレッドを立ち上げておく
threads.forEach(Thread::start);
// どんなに競合しても 30 秒もかからないはずなので殺し屋を立てておく
Thread killer = new Thread(() -> {
try {
Thread.sleep(30000);
threads.forEach(Thread::interrupt);
} catch (InterruptedException e) {
fail();
}
}, "killer");
// 一斉にロックを取りに行く
try {
barrier.await();
// 完了を待機
threads.forEach(thread -> {
try {
thread.join();
} catch (InterruptedException e) {
fail();
}
});
// 全部取得できたことを確認
killer.interrupt();
assertThat(counter.get(), is(nThreads));
} catch (InterruptedException e) {
fail();
} catch (BrokenBarrierException e) {
fail();
}
}
@Override
protected void setUp() throws Exception {
redis = new JedisConnectionFactory();
redis.setHostName("10.101.10.214");
redis.setPort(6379);
redis.setDatabase(0);
redis.setTimeout(60000);
redis.afterPropertiesSet();
subject = new RedisLockManager(redis);
cleanup();
}
@Override
protected void tearDown() throws Exception {
cleanup();
}
private void cleanup() {
Jedis jedis = (Jedis) redis.getConnection().getNativeConnection();
jedis.del(lockName());
}
private String lockName() {
return getClass().getName();
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment