Skip to content

Instantly share code, notes, and snippets.

@edwardbeckett
Created February 11, 2013 02:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save edwardbeckett/4752143 to your computer and use it in GitHub Desktop.
Save edwardbeckett/4752143 to your computer and use it in GitHub Desktop.
Beautiful Code ...
/*
* Copyright 2011 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.social.connect.jdbc;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import org.springframework.dao.DuplicateKeyException;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.security.crypto.encrypt.TextEncryptor;
import org.springframework.social.connect.Connection;
import org.springframework.social.connect.ConnectionData;
import org.springframework.social.connect.ConnectionFactory;
import org.springframework.social.connect.ConnectionFactoryLocator;
import org.springframework.social.connect.ConnectionKey;
import org.springframework.social.connect.ConnectionRepository;
import org.springframework.social.connect.DuplicateConnectionException;
import org.springframework.social.connect.NoSuchConnectionException;
import org.springframework.social.connect.NotConnectedException;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
 * JDBC-backed {@link ConnectionRepository} bound to a single local user id.
 *
 * <p>Every query and update issued by this class is restricted to the {@code userId} given at
 * construction time and targets the table {@code <tablePrefix>UserConnection}. The
 * {@code accessToken}, {@code secret} and {@code refreshToken} columns are passed through the
 * configured {@link TextEncryptor} on write and on read; {@code null} values are stored and
 * returned as {@code null} without being passed to the encryptor.
 *
 * <p>Within one provider, connections are ordered by the {@code rank} column; the row with
 * {@code rank = 1} is treated as the primary connection.
 */
class JdbcConnectionRepository implements ConnectionRepository {

    private final String userId;

    private final JdbcTemplate jdbcTemplate;

    private final ConnectionFactoryLocator connectionFactoryLocator;

    private final TextEncryptor textEncryptor;

    private final String tablePrefix;

    /**
     * @param userId the local user id every statement is scoped to
     * @param jdbcTemplate template used to run all SQL
     * @param connectionFactoryLocator resolves provider ids / API types to connection factories
     * @param textEncryptor encrypts and decrypts the token columns
     * @param tablePrefix prefix prepended to the {@code UserConnection} table name
     */
    public JdbcConnectionRepository(String userId, JdbcTemplate jdbcTemplate, ConnectionFactoryLocator connectionFactoryLocator, TextEncryptor textEncryptor, String tablePrefix) {
        this.userId = userId;
        this.jdbcTemplate = jdbcTemplate;
        this.connectionFactoryLocator = connectionFactoryLocator;
        this.textEncryptor = textEncryptor;
        this.tablePrefix = tablePrefix;
    }

    /**
     * Loads every connection of the current user, grouped by provider id.
     *
     * @return a map with one entry per registered provider id (an empty list when the user has
     *         no connection to that provider); lists are ordered by rank
     */
    public MultiValueMap<String, Connection<?>> findAllConnections() {
        List<Connection<?>> resultList = jdbcTemplate.query(selectFromUserConnection() + " where userId = ? order by providerId, rank", connectionMapper, userId);
        MultiValueMap<String, Connection<?>> connections = new LinkedMultiValueMap<String, Connection<?>>();
        Set<String> registeredProviderIds = connectionFactoryLocator.registeredProviderIds();
        for (String registeredProviderId : registeredProviderIds) {
            connections.put(registeredProviderId, Collections.<Connection<?>>emptyList());
        }
        for (Connection<?> connection : resultList) {
            String providerId = connection.getKey().getProviderId();
            List<Connection<?>> existing = connections.get(providerId);
            // The placeholder from Collections.emptyList() is immutable, so swap in a mutable
            // list before the first add. The null check covers a provider id that was not
            // pre-registered above instead of failing with a NullPointerException.
            if (existing == null || existing.isEmpty()) {
                connections.put(providerId, new LinkedList<Connection<?>>());
            }
            connections.add(providerId, connection);
        }
        return connections;
    }

    /**
     * Loads the current user's connections to one provider, ordered by rank.
     *
     * @param providerId the provider id to filter on
     * @return the matching connections; empty when there are none
     */
    public List<Connection<?>> findConnections(String providerId) {
        return jdbcTemplate.query(selectFromUserConnection() + " where userId = ? and providerId = ? order by rank", connectionMapper, userId, providerId);
    }

    /**
     * Loads the current user's connections to the provider whose connection factory is
     * registered for {@code apiType}.
     *
     * @param apiType the provider API binding type
     * @return the matching connections, ordered by rank
     */
    @SuppressWarnings("unchecked")
    public <A> List<Connection<A>> findConnections(Class<A> apiType) {
        // Unchecked: relies on the factory registered for apiType producing Connection<A>.
        List<?> connections = findConnections(getProviderId(apiType));
        return (List<Connection<A>>) connections;
    }

    /**
     * Loads the current user's connections to specific provider users.
     *
     * <p>For each provider present in the result, the returned list has the same size as the
     * requested id list for that provider; position {@code i} holds the connection to the
     * {@code i}-th requested provider user id, or {@code null} if no such connection exists.
     * Providers for which no row was found are absent from the result. Providers mapped to a
     * {@code null} or empty id list are ignored.
     *
     * @param providerUsers provider id to requested provider user ids
     * @return connections grouped by provider id, positionally aligned with the request
     * @throws IllegalArgumentException if {@code providerUsers} is {@code null} or empty
     */
    public MultiValueMap<String, Connection<?>> findConnectionsToUsers(MultiValueMap<String, String> providerUsers) {
        if (providerUsers == null || providerUsers.isEmpty()) {
            throw new IllegalArgumentException("Unable to execute find: no providerUsers provided");
        }
        StringBuilder providerUsersCriteriaSql = new StringBuilder();
        MapSqlParameterSource parameters = new MapSqlParameterSource();
        parameters.addValue("userId", userId);
        int clauseCount = 0;
        for (Entry<String, List<String>> entry : providerUsers.entrySet()) {
            List<String> providerUserIds = entry.getValue();
            if (providerUserIds == null || providerUserIds.isEmpty()) {
                // An empty list would expand to "in ()", which is not valid SQL.
                continue;
            }
            if (clauseCount > 0) {
                providerUsersCriteriaSql.append(" or ");
            }
            // Parameter names are derived from a counter rather than from the provider id so
            // that caller-supplied text never becomes part of the SQL statement itself.
            String providerIdParam = "providerId_" + clauseCount;
            String providerUserIdsParam = "providerUserIds_" + clauseCount;
            providerUsersCriteriaSql.append("(providerId = :").append(providerIdParam)
                    .append(" and providerUserId in (:").append(providerUserIdsParam).append("))");
            parameters.addValue(providerIdParam, entry.getKey());
            parameters.addValue(providerUserIdsParam, providerUserIds);
            clauseCount++;
        }
        MultiValueMap<String, Connection<?>> connectionsForUsers = new LinkedMultiValueMap<String, Connection<?>>();
        if (clauseCount == 0) {
            // Nothing to look up: every provider had an empty id list.
            return connectionsForUsers;
        }
        // The provider clauses are OR-ed together and must be parenthesized as a group:
        // AND binds tighter than OR, so without the parentheses only the first clause would
        // be restricted to this user's rows.
        List<Connection<?>> resultList = new NamedParameterJdbcTemplate(jdbcTemplate).query(selectFromUserConnection() + " where userId = :userId and (" + providerUsersCriteriaSql + ") order by providerId, rank", parameters, connectionMapper);
        for (Connection<?> connection : resultList) {
            String providerId = connection.getKey().getProviderId();
            List<String> userIds = providerUsers.get(providerId);
            List<Connection<?>> connections = connectionsForUsers.get(providerId);
            if (connections == null) {
                // Pre-fill with nulls so results can be placed at the requested positions.
                connections = new ArrayList<Connection<?>>(userIds.size());
                for (int i = 0; i < userIds.size(); i++) {
                    connections.add(null);
                }
                connectionsForUsers.put(providerId, connections);
            }
            String providerUserId = connection.getKey().getProviderUserId();
            int connectionIndex = userIds.indexOf(providerUserId);
            // A row whose id does not exactly equal any requested id (e.g. matched by a
            // case-insensitive database collation) has no position to occupy; skip it rather
            // than fail with an IndexOutOfBoundsException on index -1.
            if (connectionIndex >= 0) {
                connections.set(connectionIndex, connection);
            }
        }
        return connectionsForUsers;
    }

    /**
     * Loads the current user's connection identified by {@code connectionKey}.
     *
     * @param connectionKey provider id and provider user id of the connection
     * @return the connection
     * @throws NoSuchConnectionException if no matching row exists
     */
    public Connection<?> getConnection(ConnectionKey connectionKey) {
        try {
            return jdbcTemplate.queryForObject(selectFromUserConnection() + " where userId = ? and providerId = ? and providerUserId = ?", connectionMapper, userId, connectionKey.getProviderId(), connectionKey.getProviderUserId());
        } catch (EmptyResultDataAccessException e) {
            throw new NoSuchConnectionException(connectionKey);
        }
    }

    /**
     * Loads the current user's connection to {@code providerUserId} at the provider registered
     * for {@code apiType}.
     *
     * @param apiType the provider API binding type
     * @param providerUserId the user's id at the provider
     * @return the connection
     * @throws NoSuchConnectionException if no matching row exists
     */
    @SuppressWarnings("unchecked")
    public <A> Connection<A> getConnection(Class<A> apiType, String providerUserId) {
        String providerId = getProviderId(apiType);
        return (Connection<A>) getConnection(new ConnectionKey(providerId, providerUserId));
    }

    /**
     * Returns the current user's primary (rank 1) connection to the provider registered for
     * {@code apiType}.
     *
     * @param apiType the provider API binding type
     * @return the primary connection, never {@code null}
     * @throws NotConnectedException if the user has no rank 1 connection to that provider
     */
    @SuppressWarnings("unchecked")
    public <A> Connection<A> getPrimaryConnection(Class<A> apiType) {
        String providerId = getProviderId(apiType);
        Connection<A> connection = (Connection<A>) findPrimaryConnection(providerId);
        if (connection == null) {
            throw new NotConnectedException(providerId);
        }
        return connection;
    }

    /**
     * Like {@link #getPrimaryConnection(Class)} but returns {@code null} instead of throwing
     * when the user has no rank 1 connection to the provider.
     *
     * @param apiType the provider API binding type
     * @return the primary connection, or {@code null} if there is none
     */
    @SuppressWarnings("unchecked")
    public <A> Connection<A> findPrimaryConnection(Class<A> apiType) {
        String providerId = getProviderId(apiType);
        return (Connection<A>) findPrimaryConnection(providerId);
    }

    /**
     * Inserts a new connection row for the current user. The new row's rank is one higher
     * than the user's current maximum rank for that provider, or 1 if there is none.
     *
     * @param connection the connection to persist
     * @throws DuplicateConnectionException if the insert fails with a duplicate key
     */
    @Transactional
    public void addConnection(Connection<?> connection) {
        try {
            ConnectionData data = connection.createData();
            // coalesce(...) guarantees a single non-null value, so unboxing is safe.
            int rank = jdbcTemplate.queryForObject("select coalesce(max(rank) + 1, 1) as rank from " + tablePrefix + "UserConnection where userId = ? and providerId = ?", Integer.class, userId, data.getProviderId());
            jdbcTemplate.update("insert into " + tablePrefix + "UserConnection (userId, providerId, providerUserId, rank, displayName, profileUrl, imageUrl, accessToken, secret, refreshToken, expireTime) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
                    userId, data.getProviderId(), data.getProviderUserId(), rank, data.getDisplayName(), data.getProfileUrl(), data.getImageUrl(), encrypt(data.getAccessToken()), encrypt(data.getSecret()), encrypt(data.getRefreshToken()), data.getExpireTime());
        } catch (DuplicateKeyException e) {
            throw new DuplicateConnectionException(connection.getKey());
        }
    }

    /**
     * Overwrites the profile and token columns of the current user's existing row for the
     * given connection. Rank is left untouched; no error is raised if no row matches.
     *
     * @param connection the connection whose current data should be stored
     */
    @Transactional
    public void updateConnection(Connection<?> connection) {
        ConnectionData data = connection.createData();
        jdbcTemplate.update("update " + tablePrefix + "UserConnection set displayName = ?, profileUrl = ?, imageUrl = ?, accessToken = ?, secret = ?, refreshToken = ?, expireTime = ? where userId = ? and providerId = ? and providerUserId = ?",
                data.getDisplayName(), data.getProfileUrl(), data.getImageUrl(), encrypt(data.getAccessToken()), encrypt(data.getSecret()), encrypt(data.getRefreshToken()), data.getExpireTime(), userId, data.getProviderId(), data.getProviderUserId());
    }

    /**
     * Deletes all of the current user's connections to one provider.
     *
     * @param providerId the provider whose connections are removed
     */
    @Transactional
    public void removeConnections(String providerId) {
        jdbcTemplate.update("delete from " + tablePrefix + "UserConnection where userId = ? and providerId = ?", userId, providerId);
    }

    /**
     * Deletes the current user's connection identified by {@code connectionKey}, if present.
     *
     * @param connectionKey provider id and provider user id of the connection
     */
    @Transactional
    public void removeConnection(ConnectionKey connectionKey) {
        jdbcTemplate.update("delete from " + tablePrefix + "UserConnection where userId = ? and providerId = ? and providerUserId = ?", userId, connectionKey.getProviderId(), connectionKey.getProviderUserId());
    }

    // internal helpers

    /** Returns the shared {@code select ... from <prefix>UserConnection} fragment, without a where clause. */
    private String selectFromUserConnection() {
        return "select userId, providerId, providerUserId, displayName, profileUrl, imageUrl, accessToken, secret, refreshToken, expireTime from " + tablePrefix + "UserConnection";
    }

    /** Returns the current user's rank 1 connection to {@code providerId}, or {@code null} if there is none. */
    private Connection<?> findPrimaryConnection(String providerId) {
        List<Connection<?>> connections = jdbcTemplate.query(selectFromUserConnection() + " where userId = ? and providerId = ? and rank = 1", connectionMapper, userId, providerId);
        if (connections.size() > 0) {
            return connections.get(0);
        } else {
            return null;
        }
    }

    private final ServiceProviderConnectionMapper connectionMapper = new ServiceProviderConnectionMapper();

    /**
     * Maps one {@code UserConnection} row to a {@link Connection} by building
     * {@link ConnectionData} from the row and handing it to the connection factory registered
     * for the row's provider id. Must remain a non-static inner class: it reads the enclosing
     * instance's {@code connectionFactoryLocator} and {@code textEncryptor}.
     */
    private final class ServiceProviderConnectionMapper implements RowMapper<Connection<?>> {

        public Connection<?> mapRow(ResultSet rs, int rowNum) throws SQLException {
            ConnectionData connectionData = mapConnectionData(rs);
            ConnectionFactory<?> connectionFactory = connectionFactoryLocator.getConnectionFactory(connectionData.getProviderId());
            return connectionFactory.createConnection(connectionData);
        }

        /** Reads the row's columns into a {@link ConnectionData}, decrypting the token columns. */
        private ConnectionData mapConnectionData(ResultSet rs) throws SQLException {
            return new ConnectionData(rs.getString("providerId"), rs.getString("providerUserId"), rs.getString("displayName"), rs.getString("profileUrl"), rs.getString("imageUrl"),
                    decrypt(rs.getString("accessToken")), decrypt(rs.getString("secret")), decrypt(rs.getString("refreshToken")), expireTime(rs.getLong("expireTime")));
        }

        /** Decrypts a stored value, passing {@code null} through unchanged. */
        private String decrypt(String encryptedText) {
            return encryptedText != null ? textEncryptor.decrypt(encryptedText) : encryptedText;
        }

        /**
         * Converts a stored expire time of 0 to {@code null}. {@code ResultSet.getLong}
         * reports SQL NULL as 0, so an absent expire time ends up as {@code null} here.
         */
        private Long expireTime(long expireTime) {
            return expireTime == 0 ? null : expireTime;
        }

    }

    /** Resolves the provider id of the connection factory registered for {@code apiType}. */
    private <A> String getProviderId(Class<A> apiType) {
        return connectionFactoryLocator.getConnectionFactory(apiType).getProviderId();
    }

    /** Encrypts a value for storage, passing {@code null} through unchanged. */
    private String encrypt(String text) {
        return text != null ? textEncryptor.encrypt(text) : text;
    }

}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment