Skip to content

Instantly share code, notes, and snippets.

@neilherbertuk
Last active May 10, 2022 14:58
Show Gist options
  • Save neilherbertuk/7839e6321903da3a3df08d1da306bd4f to your computer and use it in GitHub Desktop.
Save neilherbertuk/7839e6321903da3a3df08d1da306bd4f to your computer and use it in GitHub Desktop.
OpenIAM AWS SNS OTP Provider
import java.text.SimpleDateFormat
import java.util.LinkedHashMap
import java.util.Map
import java.util.TimeZone
import java.util.concurrent.TimeUnit
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import org.apache.commons.codec.binary.Hex
import org.apache.commons.codec.digest.DigestUtils
import org.apache.commons.lang3.StringUtils
import org.apache.commons.logging.Log
import org.apache.commons.logging.LogFactory
import org.apache.http.HttpEntity
import org.apache.http.HttpHeaders
import org.apache.http.client.methods.CloseableHttpResponse
import org.apache.http.client.methods.HttpGet
import org.apache.http.impl.client.CloseableHttpClient
import org.apache.http.impl.client.HttpClients
import org.apache.http.util.EntityUtils
import org.openiam.base.ws.ResponseCode
import org.openiam.esb.core.auth.module.AbstractOTPModule
import org.openiam.esb.core.otp.service.OTPProviderService
import org.openiam.exception.BasicDataServiceException
import org.openiam.http.client.OpenIAMHttpClient
import org.openiam.http.model.HttpClientResponseWrapper
import org.openiam.idm.searchbeans.OTPProviderSearchBean
import org.openiam.idm.srvc.auth.domain.LoginEntity
import org.openiam.idm.srvc.otp.dto.OTPProvider
import org.openiam.idm.srvc.otp.dto.OTPProviderName
import org.redisson.api.RMap
import org.redisson.api.RedissonClient
import org.springframework.beans.factory.annotation.Autowired
/**
 * AWSSNSOTPModule - AWS SNS SMS OTP Provider
 *
 * Can be used to send SMS OTP using AWS SNS
 *
 * In the OTP Provider set the following:
 *
 * AWS_REGION = region to use SNS in
 * AWS_ACCESS_KEY_ID
 * AWS_SECRET_ACCESS_KEY
 * GROOVY_PATH = Path of this script
 * TEXT_MESSAGE_FORMAT = String with text you wish to send, %s will be replaced by the OTP code
 *                       e.g. "Your code is: %s" (defaults to "Your code is %s" if not provided)
 *
 * The script will be called in the following order:
 * - validate (looks up the OTP Provider if it was not found at construction time)
 * - getText
 * - send
 *   - getSandboxStatus
 *     - makeAPICall
 *   - getOptOutStatus
 *     - makeAPICall
 *   - sendMessage
 *     - makeAPICall
 *
 * @OpenIAM: >= 4.2.1.2
 * @Author: Neil Herbert <info@ea3.co.uk>
 */
class AWSSNSOTPModule extends AbstractOTPModule {
    private static final String ALGORITHM = "HmacSHA256"
    private static final String OTP_PROVIDER_NAME = "Text OTP by AWS SNS" // Set this to the name of your configured OTP Provider
    private static final String DEFAULT_TEXT_MESSAGE = "Your code is %s"
    private static final String DEFAULT_COUNTRY_CODE = "+44"
    private static final String AWS_DEFAULT_REGION = "eu-west-2"
    private static final String AWS_URL = "amazonaws.com"
    private static final String AWS_SERVICE_NAME = "sns"
    private static final Integer THROTTLE_IN_SEC = 300 // Number in seconds
    private static final Integer THROTTLE_MAX_REQ = 5 // Maximum number of OTP requests in the THROTTLE_IN_SEC timeframe
    // Characters that SigV4 URI encoding leaves as-is; every other byte is percent-encoded
    private static final String UNRESERVED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"
    // E.164: '+', a non-zero leading digit, at most 15 digits in total
    private static final String E164_REGEX = '\\+[1-9]\\d{1,14}'
    private static final Log log = LogFactory.getLog("AWSSNSOTPModule")

    @Autowired
    protected OTPProviderService otpProviderService

    protected OTPProvider otpProvider

    @Autowired
    private RedissonClient redissonClient

    /**
     * AWSSNSOTPModule constructor
     *
     * Attempts to find the OTP Provider. @Autowired fields are normally injected only after the
     * constructor has run, so when otpProviderService is not available yet the lookup is retried
     * in validate().
     */
    public AWSSNSOTPModule() {
        log.info("Initialised")
        otpProvider = findOtpProvider()
    }

    /**
     * findOtpProvider()
     *
     * Look up the OTP Provider named OTP_PROVIDER_NAME.
     *
     * @return OTPProvider - first match, or null if the service is not available or nothing matched
     */
    private OTPProvider findOtpProvider() {
        if (otpProviderService == null) {
            log.debug("OTPProviderService not available yet - OTP Provider lookup deferred")
            return null
        }
        log.info("Finding OTP Provider")
        OTPProviderSearchBean sb = new OTPProviderSearchBean()
        sb.setName(OTP_PROVIDER_NAME)
        def providers = otpProviderService.find(sb, 0, 1)?.getContent()
        return providers ? providers[0] : null
    }

    /**
     * validate()
     *
     * Validate that the provider has been found and the required attributes are available,
     * that the (normalised) phone number is in E.164 format, and that the number is not being throttled.
     *
     * @param String phone - Phone number to send message to
     * @param LoginEntity login - Account to send code for
     * @throws BasicDataServiceException AUTH_PROVIDER_NOT_FOUND, RESULT_INVALID_CONFIGURATION,
     *         WRONG_TYPE_IS_FOR_SMS (missing/invalid phone) or SMS_TOKEN_GENERATE_ERROR (throttled)
     */
    protected void validate(String phone, LoginEntity login) throws BasicDataServiceException {
        log.info("Validate called")
        log.debug("Phone: ${phone}")
        log.debug("Validating OTP Provider")
        if (otpProvider == null) {
            // The constructor may have run before dependency injection, so retry the lookup now
            otpProvider = findOtpProvider()
        }
        if (otpProvider == null) {
            throw new BasicDataServiceException(ResponseCode.AUTH_PROVIDER_NOT_FOUND)
        }
        // Check that OTP Provider is configured
        String awsRegion = otpProvider.getValue("AWS_REGION")
        log.debug("AWS_REGION: ${awsRegion}")
        String awsAccessKeyId = otpProvider.getValue("AWS_ACCESS_KEY_ID")
        String awsSecretAccessKey = otpProvider.getValue("AWS_SECRET_ACCESS_KEY")
        if (awsRegion == null || awsAccessKeyId == null || awsSecretAccessKey == null) {
            throw new BasicDataServiceException(ResponseCode.RESULT_INVALID_CONFIGURATION)
        }

        log.info("Validating phone number")
        if (StringUtils.isBlank(phone)) {
            log.warn("Phone number not provided")
            throw new BasicDataServiceException(ResponseCode.WRONG_TYPE_IS_FOR_SMS)
        }
        // Validate the normalised number, since that is what is actually sent to AWS
        String phoneNumber = getPhoneNumber(phone)
        log.debug("Normalised phone number: ${phoneNumber}")
        if (!phoneNumber.matches(E164_REGEX)) {
            log.warn("Phone number not valid")
            throw new BasicDataServiceException(ResponseCode.WRONG_TYPE_IS_FOR_SMS)
        }

        // Check if phone number is to be throttled
        // NOTE(review): get/put is not atomic, so concurrent requests for the same number may undercount
        RMap<String, Integer> map = redissonClient.getMap("otp-throttle-${phoneNumber}")
        Integer count = map.get("count") ?: 0 // Set to 0 if null
        count++
        map.put("count", count)
        map.expire(THROTTLE_IN_SEC, TimeUnit.SECONDS) // Every request restarts the throttle window
        log.info("Number of OTP Requests in the last ${THROTTLE_IN_SEC/60} minutes: ${count}")
        if (count > THROTTLE_MAX_REQ) {
            log.error("There have been too many requests to send an OTP token to ${phoneNumber}")
            throw new BasicDataServiceException(ResponseCode.SMS_TOKEN_GENERATE_ERROR)
        }
    }

    /**
     * getText()
     *
     * Get TEXT_MESSAGE_FORMAT from OTP Provider (or DEFAULT_TEXT_MESSAGE when blank) and insert the token.
     *
     * @param String phone - Phone number to send message to
     * @param LoginEntity login - Account to send code for
     * @param String token - generated code
     * @return String - message text to send
     */
    protected String getText(String phone, LoginEntity login, String token) {
        log.info("getText called")
        // The token is a one-time secret, so it is deliberately not logged
        log.debug("phone: ${phone}")
        String textMessageFormat = otpProvider.getValue(OTPProviderName.TEXT_MESSAGE_FORMAT)
        textMessageFormat = StringUtils.isBlank(textMessageFormat) ? DEFAULT_TEXT_MESSAGE : textMessageFormat
        return String.format(textMessageFormat, token)
    }

    /**
     * send()
     *
     * Send the OTP code via AWS SNS. Logs a notice if the AWS account is sandboxed and refuses
     * to send to numbers that have opted out.
     *
     * @param String phone - Phone number to send message to
     * @param LoginEntity login - Account to send code for
     * @param String text - message to send
     * @throws BasicDataServiceException WRONG_TYPE_IS_FOR_SMS if the number has opted out,
     *         SEND_PUSH_NOTIFICATION_FAILED if the Publish call fails
     */
    protected void send(String phone, LoginEntity login, String text) {
        log.info("Send called")
        log.debug("Formatting Input Phone Number")
        String phoneNumber = getPhoneNumber(phone)
        log.debug("Phone Number: ${phoneNumber}")
        // Check if account is in sandbox (informational only)
        if (getSandboxStatus("true")) {
            log.info("Your AWS Account is Sandboxed. You will only be able to send messages to verified phone numbers.")
        }
        // Check if phone number has been opted out
        if (getOptOutStatus(phoneNumber)) {
            log.info("This phone number has opted out of receiving messages from your AWS Account. Message cannot be sent.")
            throw new BasicDataServiceException(ResponseCode.WRONG_TYPE_IS_FOR_SMS)
        }
        sendMessage(phoneNumber, text)
    }

    /**
     * getSandboxStatus()
     *
     * Checks whether the AWS Account associated with the caller identity is sandboxed.
     * https://docs.aws.amazon.com/sns/latest/api/API_GetSMSSandboxAccountStatus.html
     *
     * Best effort: any failure is logged and reported as "not sandboxed".
     *
     * @param String - Not used - required to get groovy to compile/save
     * @return boolean - true if the response reports IsInSandbox true
     */
    protected boolean getSandboxStatus(String check) {
        log.info("Checking if AWS Account is Sandboxed")
        Map<String, String> queryParams = new LinkedHashMap<>()
        queryParams.put("Action", "GetSMSSandboxAccountStatus")
        try {
            String response = makeAPICall(queryParams)
            log.debug(response)
            if (response.contains("\"IsInSandbox\":true")) {
                log.debug("Account is Sandboxed")
                return true
            }
        } catch (Exception e) {
            log.error("Unable to determine AWS SNS sandbox status", e)
        }
        return false
    }

    /**
     * getOptOutStatus()
     *
     * Checks whether the phone number has opted out of receiving messages from your AWS SNS account.
     * https://docs.aws.amazon.com/sns/latest/api/API_CheckIfPhoneNumberIsOptedOut.html
     *
     * Best effort: any failure is logged and reported as "not opted out".
     *
     * @param String phoneNumber - E.164 phone number
     * @return boolean - true if the response reports isOptedOut true
     */
    protected boolean getOptOutStatus(String phoneNumber) {
        log.info("Checking if phone number has opted out of SMS from this AWS Account")
        Map<String, String> queryParams = new LinkedHashMap<>()
        queryParams.put("Action", "CheckIfPhoneNumberIsOptedOut")
        queryParams.put("phoneNumber", phoneNumber)
        try {
            String response = makeAPICall(queryParams)
            log.debug(response)
            if (response.contains("\"isOptedOut\":true")) {
                log.debug("Number has opted out")
                return true
            }
        } catch (Exception e) {
            log.error("Unable to determine opt-out status", e)
        }
        return false
    }

    /**
     * sendMessage()
     *
     * Send SMS
     * https://docs.aws.amazon.com/sns/latest/api/API_Publish.html
     *
     * @param String phoneNumber - Phone number to send out to, must be in E.164 format
     * @param String text - Message to send to user
     * @throws BasicDataServiceException SEND_PUSH_NOTIFICATION_FAILED if AWS does not return HTTP 200
     */
    private void sendMessage(String phoneNumber, String text) {
        log.info("Sending Message")
        Map<String, String> queryParams = new LinkedHashMap<>()
        queryParams.put("Action", "Publish")
        queryParams.put("Message", text)
        queryParams.put("PhoneNumber", phoneNumber)
        String response = makeAPICall(queryParams)
        log.info("API Response: ${response}")
    }

    /**
     * makeAPICall()
     *
     * Performs a GET call to the AWS SNS Query API, signed with AWS Signature V4 via query-string
     * parameters. The caller's map is not modified.
     *
     * @param Map<String, String> queryParams - query string parameters to send to the API (any order)
     * @return String - JSON Response from API Call as a String
     * @throws BasicDataServiceException SEND_PUSH_NOTIFICATION_FAILED if AWS does not return HTTP 200
     */
    private String makeAPICall(Map<String, String> queryParams) {
        log.info("makeAPICall Called")
        log.debug("Getting aws credentials from OTP Provider")
        String awsRegion = otpProvider.getValue("AWS_REGION") ?: AWS_DEFAULT_REGION
        String awsAccessKeyId = otpProvider.getValue("AWS_ACCESS_KEY_ID")
        String awsSecretAccessKey = otpProvider.getValue("AWS_SECRET_ACCESS_KEY")

        // SigV4 timestamps are UTC in ISO 8601 basic format
        SimpleDateFormat sdf = new SimpleDateFormat("yyyyMMdd'T'HHmmss'Z'", Locale.ENGLISH)
        sdf.setTimeZone(TimeZone.getTimeZone("UTC"))
        String dateTime = sdf.format(new Date())
        String date = dateTime.substring(0, 8)
        log.debug("Date & Time: ${dateTime}")

        String host = String.format("%s.%s.%s", AWS_SERVICE_NAME, awsRegion, AWS_URL)
        log.debug("Host: ${host}")
        String credentialScope = String.format("%s/%s/%s/aws4_request", date, awsRegion, AWS_SERVICE_NAME)

        // Headers to sign. The insertion order is already sorted by lower-cased name, as SigV4 requires.
        Map<String, String> headers = new LinkedHashMap<>()
        headers.put("Accept", "application/json")
        headers.put(HttpHeaders.CONTENT_TYPE, "application/json; charset=utf-8")
        headers.put("Host", host)
        headers.put("X-Amz-Date", dateTime)
        String signedHeaders = headers.keySet().collect { it.toLowerCase(Locale.ROOT) }.join(";")

        // TreeMap sorts keys by code point (A-Z before a-z) and keeps the caller's map untouched
        Map<String, String> sortedQueryParams = new TreeMap<String, String>(queryParams)
        sortedQueryParams.put("X-Amz-Algorithm", "AWS4-HMAC-SHA256")
        sortedQueryParams.put("X-Amz-Credential", awsAccessKeyId + "/" + credentialScope)
        sortedQueryParams.put("X-Amz-Date", dateTime)
        sortedQueryParams.put("X-Amz-SignedHeaders", signedHeaders)
        String queryString = buildCanonicalQueryString(sortedQueryParams)

        // AWS Signature V4 Signing Request - https://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html
        log.info("Generating AWS Signature V4")
        // Step 1 - canonical request: method, path, query, headers, blank line, signed headers, payload hash
        StringBuilder canonicalRequest = new StringBuilder()
        canonicalRequest.append("GET").append("\n")
        canonicalRequest.append("/").append("\n")
        canonicalRequest.append(queryString).append("\n")
        for (Map.Entry<String, String> entry : headers.entrySet()) {
            canonicalRequest.append(entry.getKey().toLowerCase(Locale.ROOT)).append(":").append(entry.getValue().trim()).append("\n")
        }
        canonicalRequest.append("\n").append(signedHeaders).append("\n")
        canonicalRequest.append(DigestUtils.sha256Hex("")) // GET request has an empty payload
        // NOTE: contains the message text (including any OTP) - only enable DEBUG when troubleshooting
        log.debug("canonicalRequest: ${canonicalRequest}")

        // Step 2 - string to sign
        String stringToSign = ["AWS4-HMAC-SHA256", dateTime, credentialScope, DigestUtils.sha256Hex(canonicalRequest.toString())].join("\n")
        log.debug("StringToSign: ${stringToSign}")

        // Step 3 - derive signing key and calculate signature.
        // The signing key is derived from the secret access key, so it is never logged.
        byte[] signatureKey = getSignatureKey(awsSecretAccessKey, date, awsRegion, AWS_SERVICE_NAME)
        String signature = Hex.encodeHexString(HmacSHA256(stringToSign, signatureKey))
        log.debug("Signature: ${signature}")

        log.info("Creating HTTP Request")
        String url = String.format("https://%s/?%s&X-Amz-Signature=%s", host, queryString, signature)
        HttpGet request = new HttpGet(url)
        for (Map.Entry<String, String> entry : headers.entrySet()) {
            // HttpClient sets the Host header itself - do not add it again
            if (!"host".equalsIgnoreCase(entry.getKey())) {
                request.addHeader(entry.getKey(), entry.getValue())
            }
        }
        return executeRequest(request)
    }

    /**
     * buildCanonicalQueryString()
     *
     * Join already-sorted parameters as "key=value" pairs separated by '&', with keys and values URI-encoded.
     *
     * @param Map<String, String> sortedParams - parameters in canonical (code point) order
     * @return String
     */
    private static String buildCanonicalQueryString(Map<String, String> sortedParams) {
        List<String> pairs = new ArrayList<>()
        for (Map.Entry<String, String> entry : sortedParams.entrySet()) {
            pairs.add(encode(entry.getKey()) + "=" + encode(entry.getValue()))
        }
        return pairs.join("&")
    }

    /**
     * executeRequest()
     *
     * Execute the request and return the response body. The HTTP client and response are always closed.
     *
     * @param HttpGet request - fully signed request
     * @return String - response body ("" if there is none)
     * @throws BasicDataServiceException SEND_PUSH_NOTIFICATION_FAILED if the status code is not 200
     */
    private String executeRequest(HttpGet request) {
        CloseableHttpClient httpClient = HttpClients.createDefault()
        try {
            CloseableHttpResponse response = httpClient.execute(request)
            try {
                int statusCode = response.getStatusLine().getStatusCode()
                log.info("Response Status Code: ${statusCode}")
                log.debug("Headers: ${response.getAllHeaders()}")
                HttpEntity entity = response.getEntity()
                String result = ""
                if (entity != null) {
                    result = EntityUtils.toString(entity)
                    log.debug("HTTP Response: ${result}")
                } else {
                    log.warn("HTTP Request did not return anything")
                }
                if (statusCode != 200) {
                    log.error("HTTP Request returned an error (${statusCode}): ${result}")
                    throw new BasicDataServiceException(ResponseCode.SEND_PUSH_NOTIFICATION_FAILED)
                }
                return result
            } finally {
                response.close()
            }
        } finally {
            httpClient.close()
        }
    }

    /**
     * getPhoneNumber()
     *
     * Take the entered phone number and convert it to E.164-style format:
     * strip everything except digits and '+', drop any '+' that is not the first character,
     * turn a leading "+00" international prefix into "+", then replace a remaining leading
     * "+0" (no country code) with DEFAULT_COUNTRY_CODE.
     * e.g. "+44 (0)7700+900123" -> "+4407700900123", "+0044 7700" -> "+447700", "+07700" -> "+447700"
     *
     * @param String number - raw phone number (must not be null)
     * @return String
     */
    private static String getPhoneNumber(String number) {
        // Single-quoted strings so Groovy does not interpolate anything in the patterns
        return number.replaceAll('[^0-9+]', '')
                .replaceAll('(?<=.)\\+', '')
                .replaceFirst('^\\+00', '+')
                .replaceFirst('^\\+0', DEFAULT_COUNTRY_CODE)
    }

    /**
     * encode()
     *
     * URI-encode a string as required by AWS Signature V4: the UTF-8 bytes of A-Z a-z 0-9 - _ . ~
     * are kept, and every other byte becomes '%' plus two upper-case hex digits
     * (e.g. " " -> "%20", "\n" -> "%0A", "é" -> "%C3%A9").
     *
     * @param String value
     * @return String
     */
    private static String encode(String value) {
        StringBuilder output = new StringBuilder()
        for (byte b : value.getBytes("UTF-8")) {
            int c = b & 0xFF
            if (UNRESERVED_CHARACTERS.indexOf(c) >= 0) {
                output.append((char) c)
            } else {
                output.append(String.format("%%%02X", c))
            }
        }
        return output.toString()
    }

    /**
     * HmacSHA256()
     *
     * Generate HMAC SHA256 hash - from https://docs.aws.amazon.com/general/latest/gr/signature-v4-examples.html#signature-v4-examples-java
     *
     * @param String data - input to hash (UTF-8 encoded before hashing)
     * @param byte[] key - key to hash data with
     * @return byte[]
     */
    private static byte[] HmacSHA256(String data, byte[] key) throws Exception {
        Mac mac = Mac.getInstance(ALGORITHM)
        mac.init(new SecretKeySpec(key, mac.getAlgorithm()))
        return mac.doFinal(data.getBytes("UTF-8"))
    }

    /**
     * getSignatureKey()
     *
     * Derive the SigV4 signing key - from https://docs.aws.amazon.com/general/latest/gr/signature-v4-examples.html#signature-v4-examples-java
     *
     * @param String key - AWS Secret Access Key
     * @param String dateStamp - Date in yyyyMMdd format
     * @param String regionName - AWS Region e.g. eu-west-2
     * @param String serviceName - Name of the AWS Service being called in lowercase e.g. sns
     * @return byte[] - signing key (sensitive: do not log)
     */
    static byte[] getSignatureKey(String key, String dateStamp, String regionName, String serviceName) throws Exception {
        byte[] kSecret = ("AWS4" + key).getBytes("UTF-8")
        byte[] kDate = HmacSHA256(dateStamp, kSecret)
        byte[] kRegion = HmacSHA256(regionName, kDate)
        byte[] kService = HmacSHA256(serviceName, kRegion)
        byte[] kSigning = HmacSHA256("aws4_request", kService)
        return kSigning
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment