-
-
Save briandilley/3eae82be989cd392bf8f56a8975f3d03 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.streamhub.core.concurrency | |
import com.streamhub.core.config.StreamHubCoreProperties | |
import com.streamhub.core.jdbc.AbstractJdbcDao | |
import com.streamhub.core.util.getLogger | |
import org.codejargon.fluentjdbc.api.FluentJdbc | |
import org.springframework.context.annotation.Bean | |
import org.springframework.context.annotation.Configuration | |
import java.sql.Connection | |
import java.util.concurrent.TimeUnit | |
import java.util.concurrent.atomic.AtomicInteger | |
import java.util.concurrent.locks.Condition | |
import java.util.concurrent.locks.Lock | |
import javax.sql.DataSource | |
/**
 * Spring configuration that publishes a [JdbcDistributedLockFactory] bean.
 *
 * The factory is wired with the injected [FluentJdbc] and [DataSource], and uses
 * `props.defaultLockTimeoutMillis` as the default timeout for every lock it creates.
 */
@Configuration
class JdbcDistributedLockFactoryConfig(
    val jdbc: FluentJdbc,
    val dataSource: DataSource,
    var props: StreamHubCoreProperties
) {
    /** Builds the lock factory from this configuration's collaborators. */
    @Bean
    fun jdbcDistributedLockFactory(): JdbcDistributedLockFactory =
        JdbcDistributedLockFactory(
            jdbc = jdbc,
            dataSource = dataSource,
            defaultLockTimeoutMillis = props.defaultLockTimeoutMillis
        )
}
/**
 * [DistributedLockFactory] whose locks are [JdbcPostgresAdvisoryLock] instances backed by [dataSource].
 *
 * @param jdbc query API handed to the [AbstractJdbcDao] superclass and reused by every created lock.
 * @param dataSource source of the connections each lock borrows while it is held.
 * @param defaultLockTimeoutMillis timeout handed to every created lock; defaults to 1_000.
 */
class JdbcDistributedLockFactory(
    jdbc: FluentJdbc,
    val dataSource: DataSource,
    val defaultLockTimeoutMillis: Long = 1_000
) : AbstractJdbcDao(jdbc), DistributedLockFactory {

    /** Returns a fresh, independent lock object for [id] on every call. */
    override fun createLock(id: Long): Lock {
        val advisoryLock = JdbcPostgresAdvisoryLock(
            id = id,
            defaultLockTimeoutMillis = defaultLockTimeoutMillis,
            jdbc = jdbc,
            dataSource = dataSource
        )
        return advisoryLock
    }
}
/**
 * [Lock] implemented on top of the database functions `acquire_lock(lock_id, timeout)` and
 * `release_lock(lock_id)`, executed through [jdbc] on a connection borrowed from [dataSource].
 *
 * NOTE(review): the class name suggests these functions wrap PostgreSQL session-level advisory
 * locks; confirm against the SQL function definitions.
 *
 * One connection is borrowed on demand and kept open while [lockCount] is positive, so that
 * `release_lock` runs on the same connection that ran `acquire_lock`. Whenever an operation
 * finishes with the count at zero, the connection is closed again.
 *
 * The cached connection field is not synchronized. NOTE(review): presumably an instance is meant
 * to be used by one thread at a time — confirm with callers.
 *
 * @param id value bound to `:lock_id` in both SQL functions.
 * @param defaultLockTimeoutMillis value bound to `:timeout` by [lock] and the no-argument [tryLock].
 */
class JdbcPostgresAdvisoryLock(
    val id: Long,
    private val defaultLockTimeoutMillis: Long,
    private val jdbc: FluentJdbc,
    private val dataSource: DataSource
) : Lock {

    private var _connection: Connection? = null
    private val _lockCount: AtomicInteger = AtomicInteger(0)

    /** Number of successful acquisitions not yet matched by a successful [unlock]. */
    val lockCount: Int get() = _lockCount.get()

    /** The connection currently held for this lock, or `null` when none is open. */
    val connection: Connection? get() = _connection

    /** Returns the cached connection, borrowing a new one from [dataSource] if none is open. */
    private fun acquireConnection(): Connection =
        _connection ?: dataSource.connection.also { conn ->
            _connection = conn
            LOGGER.trace("Acquired connection: ${conn.hashCode()} for lock $id")
        }

    /**
     * Closes the cached connection when no acquisition is outstanding.
     *
     * The field is cleared before `close()` is called. A failing close therefore cannot leave an
     * already-closed connection cached for the next [acquireConnection] call.
     */
    private fun maybeReleaseConnection() {
        if (_lockCount.get() > 0) return
        val conn = _connection ?: return
        _connection = null
        LOGGER.trace("Releasing connection: ${conn.hashCode()} for lock $id")
        conn.close()
    }

    /**
     * Runs `acquire_lock` with [timeoutMillis] and records a successful acquisition.
     * Shared by [lock] and both [tryLock] overloads.
     *
     * @return the boolean `lock_result` column returned by `acquire_lock`.
     */
    private fun acquire(timeoutMillis: Long): Boolean {
        var locked = false
        try {
            locked = jdbc.queryOn(acquireConnection())
                .select("select acquire_lock(:lock_id, :timeout) as lock_result;")
                .namedParam("lock_id", id)
                .namedParam("timeout", timeoutMillis)
                .singleResult { it.getBoolean("lock_result") }
        } finally {
            // Count first so a held lock keeps its connection. On failure (or exception) with
            // nothing else held, the connection is closed.
            if (locked) {
                _lockCount.incrementAndGet()
            }
            maybeReleaseConnection()
        }
        return locked
    }

    /**
     * Acquires the lock, passing [defaultLockTimeoutMillis] as the `acquire_lock` timeout.
     *
     * @throws IllegalStateException if `acquire_lock` returns false.
     */
    override fun lock() {
        check(acquire(defaultLockTimeoutMillis)) { "Unable to acquire lock: $id, deadlocked?" }
    }

    /**
     * Attempts to acquire the lock, passing [defaultLockTimeoutMillis] as the timeout.
     *
     * This uses the same timeout as [lock]. It does not make the non-waiting attempt that
     * [Lock.tryLock] describes.
     *
     * @return true if `acquire_lock` reported success.
     */
    override fun tryLock(): Boolean = acquire(defaultLockTimeoutMillis)

    /**
     * Attempts to acquire the lock, passing [time] converted to milliseconds as the timeout.
     *
     * @return true if `acquire_lock` reported success.
     */
    override fun tryLock(time: Long, unit: TimeUnit): Boolean = acquire(unit.toMillis(time))

    /**
     * Releases one acquisition via `release_lock`. The connection is closed once no
     * acquisitions remain.
     *
     * @throws IllegalStateException if `release_lock` returns false.
     */
    override fun unlock() {
        var unlocked = false
        try {
            unlocked = jdbc.queryOn(acquireConnection())
                .select("select release_lock(:lock_id) as lock_result;")
                .namedParam("lock_id", id)
                .singleResult { it.getBoolean("lock_result") }
            check(unlocked) { "Unable to release lock: $id, not the owner?" }
        } finally {
            if (unlocked) {
                _lockCount.decrementAndGet()
            }
            maybeReleaseConnection()
        }
    }

    /**
     * Same as [lock], but first honours the [Lock.lockInterruptibly] contract. If the calling
     * thread is already interrupted, it throws [InterruptedException] and clears the
     * interrupt flag. Interruption during the database call itself is not checked here.
     */
    override fun lockInterruptibly() {
        if (Thread.interrupted()) {
            throw InterruptedException("Interrupted before acquiring lock: $id")
        }
        lock()
    }

    /** Conditions are not supported by this lock. */
    override fun newCondition(): Condition = throw UnsupportedOperationException()

    companion object {
        private val LOGGER = getLogger(JdbcPostgresAdvisoryLock::class)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment