Last active
May 10, 2022 14:58
-
-
Save neilherbertuk/7839e6321903da3a3df08d1da306bd4f to your computer and use it in GitHub Desktop.
OpenIAM AWS SNS OTP Provider
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.text.SimpleDateFormat | |
import java.util.LinkedHashMap | |
import java.util.Map | |
import java.util.TimeZone | |
import java.util.concurrent.TimeUnit | |
import javax.crypto.Mac | |
import javax.crypto.spec.SecretKeySpec | |
import org.apache.commons.codec.binary.Hex | |
import org.apache.commons.codec.digest.DigestUtils | |
import org.apache.commons.lang3.StringUtils | |
import org.apache.commons.logging.Log | |
import org.apache.commons.logging.LogFactory | |
import org.apache.http.HttpEntity | |
import org.apache.http.HttpHeaders | |
import org.apache.http.client.methods.CloseableHttpResponse | |
import org.apache.http.client.methods.HttpGet | |
import org.apache.http.impl.client.CloseableHttpClient | |
import org.apache.http.impl.client.HttpClients | |
import org.apache.http.util.EntityUtils | |
import org.openiam.base.ws.ResponseCode | |
import org.openiam.esb.core.auth.module.AbstractOTPModule | |
import org.openiam.esb.core.otp.service.OTPProviderService | |
import org.openiam.exception.BasicDataServiceException | |
import org.openiam.http.client.OpenIAMHttpClient | |
import org.openiam.http.model.HttpClientResponseWrapper | |
import org.openiam.idm.searchbeans.OTPProviderSearchBean | |
import org.openiam.idm.srvc.auth.domain.LoginEntity | |
import org.openiam.idm.srvc.otp.dto.OTPProvider | |
import org.openiam.idm.srvc.otp.dto.OTPProviderName | |
import org.redisson.api.RMap | |
import org.redisson.api.RedissonClient | |
import org.springframework.beans.factory.annotation.Autowired | |
/**
 * AWSSNSOTPModule - AWS SNS SMS OTP Provider
 *
 * Can be used to send SMS OTP using AWS SNS
 *
 * In the OTP Provider set the following:
 *
 * AWS_REGION = region to use SNS in
 * AWS_ACCESS_KEY_ID
 * AWS_SECRET_ACCESS_KEY
 * GROOVY_PATH = Path of this script
 * TEXT_MESSAGE_FORMAT = String with text you wish to send, %s will be replaced by the OTP code,
 *                       e.g. "Your code is %s" (the default if not provided)
 *
 * Phone numbers are normalised by getPhoneNumber() (everything but digits and a leading '+' is
 * stripped, "00"/"+00" becomes "+", "+0" becomes DEFAULT_COUNTRY_CODE) and the result must be E.164.
 *
 * The script will be called in the following order:
 *  - validate
 *  - getText
 *  - send
 *      - getSandboxStatus
 *          - makeAPICall
 *      - getOptOutStatus
 *          - makeAPICall
 *      - sendMessage
 *          - makeAPICall
 *
 * @OpenIAM: >= 4.2.1.2
 * @Author: Neil Herbert <info@ea3.co.uk>
 */
class AWSSNSOTPModule extends AbstractOTPModule {
    private static final String ALGORITHM = "HmacSHA256"
    private static final String OTP_PROVIDER_NAME = "Text OTP by AWS SNS" // Set this to the name of your configured OTP Provider
    private static final String DEFAULT_TEXT_MESSAGE = "Your code is %s"
    // Used as a regex replacement string in getPhoneNumber(): must not contain '$' or '\'
    private static final String DEFAULT_COUNTRY_CODE = "+44"
    private static final String AWS_DEFAULT_REGION = "eu-west-2"
    private static final String AWS_URL = "amazonaws.com"
    private static final String AWS_SERVICE_NAME = "sns"
    private static final Integer THROTTLE_IN_SEC = 300 // Length of the throttle window in seconds
    private static final Integer THROTTLE_MAX_REQ = 5 // Maximum number of OTP requests in the THROTTLE_IN_SEC timeframe
    // Characters SigV4 URI encoding leaves untouched (A-Za-z0-9-_.~); every other byte is percent-encoded
    private static final String UNRESERVED_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"
    private static final Log log = LogFactory.getLog("AWSSNSOTPModule")

    @Autowired
    protected OTPProviderService otpProviderService

    protected OTPProvider otpProvider

    @Autowired
    private RedissonClient redissonClient

    /**
     * AWSSNSOTPModule constructor
     *
     * Attempts to find the OTP Provider. Note that @Autowired fields are injected AFTER the constructor
     * runs, so otpProviderService is normally still null here; the provider is therefore also resolved
     * lazily (see resolveOtpProvider()) the first time it is needed.
     */
    public AWSSNSOTPModule() {
        log.info("Initialised")
        resolveOtpProvider()
    }

    /**
     * resolveOtpProvider()
     *
     * Look up the OTP Provider named OTP_PROVIDER_NAME if it has not been found yet.
     *
     * @return OTPProvider - the cached/found provider, or null if the service is not yet injected or no
     *                       provider with that name exists
     */
    protected OTPProvider resolveOtpProvider() {
        if (otpProvider == null && otpProviderService != null) {
            log.info("Finding OTP Provider")
            OTPProviderSearchBean sb = new OTPProviderSearchBean()
            sb.setName(OTP_PROVIDER_NAME)
            def content = otpProviderService.find(sb, 0, 1)?.getContent()
            // Guard against an empty result instead of indexing blindly
            otpProvider = content ? content[0] : null
        }
        return otpProvider
    }

    /**
     * validate()
     *
     * Validate that the provider has been found, the required attributes are available, the phone number
     * is valid E.164 (after normalisation) and the number has not exceeded the request throttle.
     *
     * @param String phone - Phone number to send message to
     * @param LoginEntity login - Account to send code for
     * @throws BasicDataServiceException AUTH_PROVIDER_NOT_FOUND, RESULT_INVALID_CONFIGURATION,
     *         WRONG_TYPE_IS_FOR_SMS (invalid number) or SMS_TOKEN_GENERATE_ERROR (throttled)
     */
    protected void validate(String phone, LoginEntity login) throws BasicDataServiceException {
        log.info("Validate called")
        log.debug("Phone: ${phone}")
        log.debug("Validating OTP Provider")
        OTPProvider provider = resolveOtpProvider()
        if (provider == null) {
            throw new BasicDataServiceException(ResponseCode.AUTH_PROVIDER_NOT_FOUND)
        }
        // Check that OTP Provider is configured
        String awsRegion = provider.getValue("AWS_REGION")
        log.debug("AWS_REGION: ${awsRegion}")
        String awsAccessKeyId = provider.getValue("AWS_ACCESS_KEY_ID")
        String awsSecretAccessKey = provider.getValue("AWS_SECRET_ACCESS_KEY")
        if (StringUtils.isBlank(awsRegion) || StringUtils.isBlank(awsAccessKeyId) || StringUtils.isBlank(awsSecretAccessKey)) {
            throw new BasicDataServiceException(ResponseCode.RESULT_INVALID_CONFIGURATION)
        }

        log.info("Validating phone number")
        if (phone == null) {
            log.warn("Phone number not valid")
            throw new BasicDataServiceException(ResponseCode.WRONG_TYPE_IS_FOR_SMS)
        }
        // Validate the normalised number: it is the one that is throttled and actually sent
        String phoneNumber = getPhoneNumber(phone)
        log.debug(phoneNumber)
        // E.164: '+', a non-zero leading digit, at most 15 digits in total
        if (!(phoneNumber ==~ /\+[1-9]\d{1,14}/)) {
            log.warn("Phone number not valid")
            throw new BasicDataServiceException(ResponseCode.WRONG_TYPE_IS_FOR_SMS)
        }

        // Check if phone number is to be throttled
        RMap<String, Integer> map = redissonClient.getMap("otp-throttle-" + phoneNumber)
        Integer count = (map.get("count") ?: 0) + 1
        map.put("count", count)
        // Start the expiry only when a new window begins (or if the key somehow has no TTL). Refreshing it
        // on every request would turn the fixed window into an inactivity timeout that never resets while
        // requests keep arriving.
        if (count == 1 || map.remainTimeToLive() < 0) {
            map.expire(THROTTLE_IN_SEC, TimeUnit.SECONDS)
        }
        log.info("Number of OTP Requests in the last ${THROTTLE_IN_SEC/60} minutes: ${count}")
        if (count > THROTTLE_MAX_REQ) {
            log.error("There have been too many requests to send an OTP token to ${phoneNumber}")
            throw new BasicDataServiceException(ResponseCode.SMS_TOKEN_GENERATE_ERROR)
        }
    }

    /**
     * getText()
     *
     * Build the SMS text from TEXT_MESSAGE_FORMAT on the OTP Provider, or DEFAULT_TEXT_MESSAGE if blank.
     *
     * @param String phone - Phone number to send message to
     * @param LoginEntity login - Account to send code for
     * @param String token - generated code
     * @return String - the format with %s replaced by the token
     */
    protected String getText(String phone, LoginEntity login, String token) {
        log.info("getText called")
        // The token itself is deliberately not logged
        log.debug("phone: ${phone}")
        String textMessageFormat = resolveOtpProvider()?.getValue(OTPProviderName.TEXT_MESSAGE_FORMAT)
        textMessageFormat = StringUtils.isBlank(textMessageFormat) ? DEFAULT_TEXT_MESSAGE : textMessageFormat
        return String.format(textMessageFormat, token)
    }

    /**
     * send()
     *
     * Send the OTP code via AWS SNS, after checking sandbox and opt-out status.
     *
     * @param String phone - Phone number to send message to
     * @param LoginEntity login - Account to send code for
     * @param String text - message to send
     * @throws BasicDataServiceException WRONG_TYPE_IS_FOR_SMS if the number has opted out, or
     *         SEND_PUSH_NOTIFICATION_FAILED if the Publish call fails
     */
    protected void send(String phone, LoginEntity login, String text) {
        log.info("Send called")
        log.debug("Formatting Input Phone Number")
        String phoneNumber = getPhoneNumber(phone)
        log.debug("Phone Number: ${phoneNumber}")
        // The message contains the OTP, so only its length is logged
        log.debug("Message length: ${text?.length()}")
        // Check if account is in sandbox (informational only)
        boolean sandboxStatus = getSandboxStatus("true")
        if (sandboxStatus) {
            log.info("Your AWS Account is Sandboxed. You will only be able to send messages to verified phone numbers.")
        }
        // Check if phone number has been opted out
        boolean optOutStatus = getOptOutStatus(phoneNumber)
        if (optOutStatus) {
            log.info("This phone number has opted out of receiving messages from your AWS Account. Message cannot be sent.")
            throw new BasicDataServiceException(ResponseCode.WRONG_TYPE_IS_FOR_SMS)
        }
        sendMessage(phoneNumber, text)
    }

    /**
     * getSandboxStatus()
     *
     * Checks whether the AWS Account associated with the caller identity is sandboxed. Best effort:
     * any failure is logged and reported as "not sandboxed".
     * https://docs.aws.amazon.com/sns/latest/api/API_GetSMSSandboxAccountStatus.html
     *
     * @param String - Not used - required to get groovy to compile/save
     * @return boolean - true if the response reports "IsInSandbox": true
     */
    protected boolean getSandboxStatus(String check) {
        log.info("Checking if AWS Account is Sandboxed")
        Map<String, String> queryParams = new LinkedHashMap<>()
        queryParams.put("Action", "GetSMSSandboxAccountStatus")
        try {
            String response = makeAPICall(queryParams)
            log.debug(response)
            // Tolerate optional whitespace around the JSON colon
            if ((response =~ /"IsInSandbox"\s*:\s*true/).find()) {
                log.debug("Account is Sandboxed")
                return true
            }
        } catch (Exception e) {
            log.error("Unable to determine AWS SNS sandbox status", e)
        }
        return false
    }

    /**
     * getOptOutStatus()
     *
     * Checks whether the phone number has opted out of receiving messages from your AWS SNS account.
     * Best effort: any failure is logged and reported as "not opted out".
     * https://docs.aws.amazon.com/sns/latest/api/API_CheckIfPhoneNumberIsOptedOut.html
     *
     * @param String phoneNumber - number in E.164 format
     * @return boolean - true if the response reports "isOptedOut": true
     */
    protected boolean getOptOutStatus(String phoneNumber) {
        log.info("Checking if phone number has opted out of SMS from this AWS Account")
        Map<String, String> queryParams = new LinkedHashMap<>()
        queryParams.put("Action", "CheckIfPhoneNumberIsOptedOut")
        queryParams.put("phoneNumber", phoneNumber)
        try {
            String response = makeAPICall(queryParams)
            log.debug(response)
            if ((response =~ /"isOptedOut"\s*:\s*true/).find()) {
                log.debug("Number has opted out")
                return true
            }
        } catch (Exception e) {
            log.error("Unable to determine AWS SNS opt-out status", e)
        }
        return false
    }

    /**
     * sendMessage()
     *
     * Send SMS via the SNS Publish action.
     * https://docs.aws.amazon.com/sns/latest/api/API_Publish.html
     *
     * @param String phoneNumber - Phone number to send out to, must be in E.164 format
     * @param String text - Message to send to user
     * @throws BasicDataServiceException SEND_PUSH_NOTIFICATION_FAILED on a non-200 response
     */
    private void sendMessage(String phoneNumber, String text) {
        log.info("Sending Message")
        Map<String, String> queryParams = new LinkedHashMap<>()
        queryParams.put("Action", "Publish")
        queryParams.put("Message", text)
        queryParams.put("PhoneNumber", phoneNumber)
        String response = makeAPICall(queryParams)
        log.info("API Response: ${response}")
    }

    /**
     * makeAPICall()
     *
     * Performs a GET call to the AWS SNS Query API, signed with AWS Signature V4 (query-string auth).
     * https://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html
     *
     * The caller's map is not modified; the X-Amz-* auth parameters are added to a sorted copy.
     *
     * @param Map<String, String> queryParams - query string parameters (e.g. Action) to send to the API
     * @return String - JSON Response from API Call as a String
     * @throws BasicDataServiceException AUTH_PROVIDER_NOT_FOUND if no provider is configured,
     *         SEND_PUSH_NOTIFICATION_FAILED on a non-200 response
     */
    private String makeAPICall(Map<String, String> queryParams) {
        log.info("makeAPICall Called")
        OTPProvider provider = resolveOtpProvider()
        if (provider == null) {
            throw new BasicDataServiceException(ResponseCode.AUTH_PROVIDER_NOT_FOUND)
        }
        log.debug("Getting aws credentials from OTP Provider")
        String awsRegion = provider.getValue("AWS_REGION") ?: AWS_DEFAULT_REGION
        String awsAccessKeyId = provider.getValue("AWS_ACCESS_KEY_ID")
        String awsSecretAccessKey = provider.getValue("AWS_SECRET_ACCESS_KEY")

        // ISO 8601 basic timestamp in UTC (e.g. 20220510T145800Z); the first 8 chars are the date stamp
        SimpleDateFormat sdf = new SimpleDateFormat("yyyyMMdd'T'HHmmss'Z'", Locale.ENGLISH)
        sdf.setTimeZone(TimeZone.getTimeZone("UTC"))
        String dateTime = sdf.format(new Date())
        String date = dateTime.substring(0, 8)
        log.debug("Date & Time: ${dateTime}")

        String host = String.format("%s.%s.%s", AWS_SERVICE_NAME, awsRegion, AWS_URL)
        log.debug("Host: ${host}")
        String credentialScope = String.format("%s/%s/%s/aws4_request", date, awsRegion, AWS_SERVICE_NAME)

        // Headers to sign and send. SigV4 requires canonical headers sorted by lowercase name; a
        // case-insensitive TreeMap makes that ordering explicit.
        Map<String, String> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER)
        headers.put("Accept", "application/json")
        headers.put(HttpHeaders.CONTENT_TYPE, "application/json; charset=utf-8")
        headers.put("Host", host)
        headers.put("X-Amz-Date", dateTime)
        List<String> headerNames = new ArrayList<>()
        for (String name : headers.keySet()) {
            headerNames.add(name.toLowerCase(Locale.ROOT))
        }
        String signedHeaders = String.join(";", headerNames)

        // SigV4 requires query parameters sorted by name in code point order; a TreeMap over String keys
        // gives exactly that (upper-case names before lower-case ones). Copy so the caller's map is untouched.
        Map<String, String> sortedQueryParams = new TreeMap<>(queryParams)
        sortedQueryParams.put("X-Amz-Algorithm", "AWS4-HMAC-SHA256")
        sortedQueryParams.put("X-Amz-Credential", awsAccessKeyId + "/" + credentialScope)
        sortedQueryParams.put("X-Amz-Date", dateTime)
        sortedQueryParams.put("X-Amz-SignedHeaders", signedHeaders)
        List<String> pairs = new ArrayList<>()
        for (Map.Entry<String, String> entry : sortedQueryParams.entrySet()) {
            pairs.add(encode(entry.getKey()) + "=" + encode(entry.getValue()))
        }
        // Not logged: the query string can contain the OTP message
        String queryString = String.join("&", pairs)

        // Step 1 - Canonical request: method, path, query, headers (each "name:value\n"), blank line,
        // signed header list, SHA-256 of the (empty) payload
        log.info("Generating AWS Signature V4")
        StringBuilder canonicalRequest = new StringBuilder()
        canonicalRequest.append("GET").append("\n")
        canonicalRequest.append("/").append("\n")
        canonicalRequest.append(queryString).append("\n")
        for (Map.Entry<String, String> header : headers.entrySet()) {
            canonicalRequest.append(header.getKey().toLowerCase(Locale.ROOT)).append(":").append(header.getValue().trim()).append("\n")
        }
        canonicalRequest.append("\n")
        canonicalRequest.append(signedHeaders).append("\n")
        canonicalRequest.append(DigestUtils.sha256Hex(""))

        // Step 2 - String to sign
        String stringToSign = String.join("\n", "AWS4-HMAC-SHA256", dateTime, credentialScope,
                DigestUtils.sha256Hex(canonicalRequest.toString()))
        log.debug("StringToSign: ${stringToSign}")

        // Step 3 - Derive the signing key (never logged: it is a credential) and sign
        byte[] signingKey = getSignatureKey(awsSecretAccessKey, date, awsRegion, AWS_SERVICE_NAME)
        String signature = Hex.encodeHexString(HmacSHA256(stringToSign, signingKey))
        log.debug("Signature: ${signature}")

        log.info("Creating HTTP Request")
        HttpGet request = new HttpGet(String.format("https://%s/?%s&X-Amz-Signature=%s", host, queryString, signature))
        for (Map.Entry<String, String> header : headers.entrySet()) {
            // HttpClient derives the Host header from the URL, so do not add it again
            if (!"host".equalsIgnoreCase(header.getKey())) {
                request.addHeader(header.getKey(), header.getValue())
            }
        }
        return executeRequest(request)
    }

    /**
     * executeRequest()
     *
     * Execute the request, always closing the response and client.
     *
     * @param HttpGet request - the signed request
     * @return String - the response body ("" if there is none)
     * @throws BasicDataServiceException SEND_PUSH_NOTIFICATION_FAILED on a non-200 response
     */
    private String executeRequest(HttpGet request) {
        CloseableHttpClient httpClient = HttpClients.createDefault()
        try {
            CloseableHttpResponse response = httpClient.execute(request)
            try {
                int statusCode = response.getStatusLine().getStatusCode()
                log.info("Response Status Code: ${statusCode}")
                log.debug("Headers: ${response.getAllHeaders()}")
                HttpEntity entity = response.getEntity()
                String result = ""
                if (entity != null) {
                    result = EntityUtils.toString(entity, "UTF-8")
                    log.debug("HTTP Response: ${result}")
                } else {
                    log.warn("HTTP Request did not return anything")
                }
                if (statusCode != 200) {
                    log.error("HTTP Request returned an error: ${statusCode} ${result}")
                    throw new BasicDataServiceException(ResponseCode.SEND_PUSH_NOTIFICATION_FAILED)
                }
                return result
            } finally {
                response.close()
            }
        } finally {
            httpClient.close()
        }
    }

    /**
     * getPhoneNumber()
     *
     * Normalise an entered phone number:
     *  - remove everything except digits and '+'
     *  - remove every '+' that is not the first character
     *  - convert a leading "00" or "+00" international prefix to "+"
     *  - replace a remaining leading "+0" with DEFAULT_COUNTRY_CODE
     * e.g. "+44 (0) 7700+900123" -> "+4407700900123", "0033 1 23" -> "+33123", "+07700900123" -> "+447700900123"
     *
     * Regexes and replacements use single-quoted (non-interpolated) strings: a GString such as "${1}"
     * would be evaluated by Groovy before reaching the regex engine.
     *
     * @param String number - raw phone number (must not be null)
     * @return String - normalised number
     */
    private static String getPhoneNumber(String number) {
        return number
                .replaceAll('[^0-9+]', '')
                .replaceAll('(?!^)\\+', '')
                .replaceAll('^\\+?00', '+')
                .replaceAll('^\\+0', DEFAULT_COUNTRY_CODE)
    }

    /**
     * encode()
     *
     * URI-encode a string to AWS SigV4 rules: the string is UTF-8 encoded and every byte outside
     * A-Za-z0-9-_.~ is written as %XY with two upper-case hex digits (space becomes %20, '/' becomes %2F).
     *
     * @param String value - value to encode (must not be null)
     * @return String - encoded value
     */
    private static String encode(String value) {
        StringBuilder output = new StringBuilder()
        for (byte b : value.getBytes("UTF-8")) {
            int octet = b & 0xFF
            if (UNRESERVED_CHARS.indexOf(octet) >= 0) {
                output.append((char) octet)
            } else {
                output.append(String.format("%%%02X", octet))
            }
        }
        return output.toString()
    }

    /**
     * HmacSHA256()
     *
     * Generate HMAC SHA256 hash - from https://docs.aws.amazon.com/general/latest/gr/signature-v4-examples.html#signature-v4-examples-java
     *
     * @param String data - input to hash (UTF-8 encoded)
     * @param byte[] key - key to hash data with
     * @return byte[] - raw MAC
     */
    private static byte[] HmacSHA256(String data, byte[] key) throws Exception {
        Mac mac = Mac.getInstance(ALGORITHM)
        mac.init(new SecretKeySpec(key, mac.getAlgorithm()))
        return mac.doFinal(data.getBytes("UTF-8"))
    }

    /**
     * getSignatureKey()
     *
     * Derive the SigV4 signing key - from https://docs.aws.amazon.com/general/latest/gr/signature-v4-examples.html#signature-v4-examples-java
     *
     * @param String key - AWS Secret Access Key
     * @param String dateStamp - Date in yyyyMMdd format
     * @param String regionName - AWS Region e.g. eu-west-2
     * @param String serviceName - Name of the AWS Service being called in lowercase e.g. sns
     * @return byte[] - the signing key (a credential: do not log it)
     */
    static byte[] getSignatureKey(String key, String dateStamp, String regionName, String serviceName) throws Exception {
        byte[] kSecret = ("AWS4" + key).getBytes("UTF-8")
        byte[] kDate = HmacSHA256(dateStamp, kSecret)
        byte[] kRegion = HmacSHA256(regionName, kDate)
        byte[] kService = HmacSHA256(serviceName, kRegion)
        return HmacSHA256("aws4_request", kService)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment