Skip to content

Instantly share code, notes, and snippets.

@jonyt
Last active September 12, 2021 08:23
Show Gist options
  • Save jonyt/2ac61baea2bb4396e47aa29f379039bc to your computer and use it in GitHub Desktop.
Save jonyt/2ac61baea2bb4396e47aa29f379039bc to your computer and use it in GitHub Desktop.
import com.codahale.metrics.Histogram;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import com.codahale.metrics.jmx.JmxReporter;
import org.apache.catalina.Context;
import org.apache.catalina.Globals;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.Session;
import org.apache.catalina.session.StandardSession;
import org.apache.catalina.session.StoreBase;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import java.io.*;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.Set;
import java.util.concurrent.TimeUnit;
public class JedisStore extends StoreBase {
private static final String KEY_PREFIX = "TOMCAT_SESSION";
private static final Charset UTF8 = Charset.forName("UTF-8");
private static final MetricRegistry REGISTRY = new MetricRegistry();
private static final Timer TIMER = REGISTRY.timer(MetricRegistry.name(JedisStore.class, "storeRate"));
private static final Histogram HISTOGRAM = REGISTRY.histogram(MetricRegistry.name(JedisStore.class, "sessionSizes"));
private static final JmxReporter REPORTER = JmxReporter
.forRegistry(REGISTRY)
.convertDurationsTo(TimeUnit.SECONDS)
.convertDurationsTo(TimeUnit.SECONDS)
.build();
private final Log logger = LogFactory.getLog(this.getClass());
private JedisPool jedisPool;
private int expirationInSeconds = (int) Duration.ofDays(1).getSeconds();
@Override
public int getSize() {
return keys().length;
}
@Override
public String[] keys() {
try (Jedis jedis = getConnection()) {
Set < String > keys = jedis.keys(KEY_PREFIX + "*");
return keys == null ? new String[0] : keys.toArray(new String[keys.size()]);
}
}
@Override
public Session load(String sessionId) {
logger.debug("Going to load session " + sessionId);
String key = getKey(sessionId);
byte[] bytes;
try (Jedis jedis = getConnection()) {
bytes = jedis.get(key.getBytes(UTF8));
if (bytes == null || bytes.length == 0) {
logger.warn("Couldn't find session " + sessionId);
return null;
}
} catch (Exception e) {
logger.error("Error getting session " + sessionId + " from Redis", e);
return null;
}
Context context = manager.getContext();
ClassLoader oldThreadContextCL = context.bind(Globals.IS_SECURITY_ENABLED, null);
try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
ObjectInputStream ois = getObjectInputStream(bis)) {
StandardSession session = (StandardSession) manager.createEmptySession();
session.readObjectData(ois);
session.setManager(manager);
logger.info("Successfully loaded session " + sessionId);
return session;
} catch (Exception e) {
logger.error("Error deserializing session " + sessionId, e);
return null;
} finally {
context.unbind(Globals.IS_SECURITY_ENABLED, oldThreadContextCL);
}
}
@Override
public void remove(String sessionId) {
String key = getKey(sessionId);
try (Jedis jedis = getConnection()) {
jedis.del(key.getBytes(UTF8));
logger.info("Deleted session " + sessionId);
} catch (Exception e) {
logger.error("Failed to delete session " + sessionId, e);
}
}
@Override
public void clear() {
logger.debug("Going to delete all sessions");
String[] keys = keys();
if (keys.length > 0) {
try (Jedis jedis = getConnection()) {
jedis.del(keys);
logger.info("Deleted all sessions");
} catch (Exception e) {
logger.error("Failed to delete all sessions");
}
}
}
@Override
public void save(Session session) throws IOException {
logger.debug("Saving session " + session.getId());
String key = getKey(session.getId());
try (Timer.Context context = TIMER.time();
ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(bos)) {
((StandardSession) session).writeObjectData(oos);
bos.flush();
byte[] bytes = bos.toByteArray();
try (Jedis jedis = getConnection()) {
jedis.set(key.getBytes(UTF8), bytes);
jedis.expire(key.getBytes(UTF8), expirationInSeconds);
HISTOGRAM.update(bytes.length);
logger.info(
String.format("Session %s saved. Will expire in %d seconds unless saved again", session.getId(), expiration)
);
}
} catch (Exception e) {
logger.error("Failed to serialize session " + session.getId(), e);
}
}
@Override
protected synchronized void startInternal() throws LifecycleException {
super.startInternal();
REPORTER.start();
}
@Override
protected synchronized void stopInternal() throws LifecycleException {
super.stopInternal();
if (jedisPool != null && !jedisPool.isClosed())
jedisPool.close();
REPORTER.stop();
REPORTER.close();
}
public void setRedisAddress(String redisAddress) {
try {
URI uri = new URI(redisAddress);
jedisPool = new JedisPool(uri.getHost(), uri.getPort());
logger.info("Set Redis address to " + redisAddress);
} catch (Exception e) {
logger.error("Failed to set Redis address to " + redisAddress, e);
}
}
public void setExpirationInSeconds(int expirationInSeconds) {
this.expirationInSeconds = expirationInSeconds;
}
public synchronized static String getKey(String sessionId) {
return sessionId.startsWith(KEY_PREFIX) ? sessionId : String.format("%s:%s", KEY_PREFIX, sessionId);
}
private Jedis getConnection() {
if (jedisPool == null)
throw new NullPointerException("Jedis pool is uninitialized");
return jedisPool.getResource();
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment