Skip to content

Instantly share code, notes, and snippets.

@snambi
Created September 18, 2023 16:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save snambi/b9080ed6382ae7495a1b467f03d877ff to your computer and use it in GitHub Desktop.
Save snambi/b9080ed6382ae7495a1b467f03d877ff to your computer and use it in GitHub Desktop.
Generate JWT Using Java
import org.apache.commons.codec.binary.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.security.SecureRandom;
import java.security.Signature;
import java.text.MessageFormat;
import org.apache.commons.codec.binary.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.PrivateKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
/**
 * Builds RS256-signed JSON Web Tokens (header.payload.signature, each part
 * Base64URL-encoded without padding) from an RSA private key stored on disk.
 *
 * <p>The claim set contains {@code iss}, {@code sub}, {@code aud}, {@code exp}
 * and a random {@code jti}.
 */
public class JWTGenerator {

    private static final Logger logger = LoggerFactory.getLogger(JWTGenerator.class);

    /** JOSE header announcing the RS256 (RSASSA-PKCS1-v1_5 with SHA-256) algorithm. */
    public static final String JWT_HEADER = "{ \"alg\":\"RS256\" }";

    /**
     * {@link MessageFormat} pattern for the claim set. Literal braces are quoted
     * as {@code '{'} / {@code '}'}. The {@code exp} placeholder is intentionally
     * NOT wrapped in JSON quotes: RFC 7519 requires {@code exp} to be a JSON number.
     */
    public static final String JWT_CLAIM_TEMPLATE = "'{' \"iss\": \"{0}\", \"sub\": \"{1}\", \"aud\": \"{2}\", \"exp\": {3}, \"jti\": \"{4}\" '}'";

    /** Shared generator for token identifiers; {@link SecureRandom} is safe to share. */
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();

    /** Number of random bytes in a {@code jti} (128 bits keeps collisions negligible). */
    private static final int JTI_BYTES = 16;

    /**
     * Demo entry point: generates one token and prints it to stdout.
     *
     * <p>NOTE(review): issuer, subject, audience and key path are hard-coded here,
     * presumably as demo values; move them to arguments or configuration before
     * any real use.
     *
     * @param args ignored
     */
    public static void main(String[] args ){
        String jwt = generate("3HjG9oe0cTvGocgRrFOWSLtGr0NkYA0iTTJsPZoHWbR0vhr3T2j6lOttcafh1Mmhh7QkXH.qTyvZ096qHhF2x",
                "snambi@gmail.com",
                "https://test.myapp.com",
                60*10,
                "/Users/snambi/.secrets/myapp/server.key");
        System.out.println("JWT: "+ jwt);
    }

    /**
     * Creates a signed JWT.
     *
     * @param iss              value of the {@code iss} (issuer) claim
     * @param sub              value of the {@code sub} (subject) claim
     * @param aud              value of the {@code aud} (audience) claim
     * @param expiry           token lifetime in seconds, added to the current time
     *                         to produce the {@code exp} claim
     * @param pathToPrivateKey path of the RSA private key file (PKCS#1 PEM,
     *                         PKCS#8 PEM or PKCS#8 DER)
     * @return the compact serialization {@code header.payload.signature}
     * @throws IllegalArgumentException if a claim value or the key path is
     *                                  {@code null}, or the key cannot be parsed
     * @throws IllegalStateException    if the key file cannot be read or signing fails
     */
    public static String generate(String iss, String sub, String aud, int expiry, String pathToPrivateKey) {
        requireArgument(iss, "iss");
        requireArgument(sub, "sub");
        requireArgument(aud, "aud");
        requireArgument(pathToPrivateKey, "pathToPrivateKey");

        try {
            String encodedHeader = Base64.encodeBase64URLSafeString(JWT_HEADER.getBytes(StandardCharsets.UTF_8));
            String claims = buildClaims(iss, sub, aud, expiry);
            String encodedClaims = Base64.encodeBase64URLSafeString(claims.getBytes(StandardCharsets.UTF_8));

            // The signature covers exactly "<header>.<claims>" in its encoded form.
            String signingInput = encodedHeader + "." + encodedClaims;

            PrivateKey privateKey = PrivateKeyReader.loadKey(pathToPrivateKey);
            Signature signature = Signature.getInstance("SHA256withRSA");
            signature.initSign(privateKey);
            signature.update(signingInput.getBytes(StandardCharsets.UTF_8));
            String encodedSignature = Base64.encodeBase64URLSafeString(signature.sign());

            return signingInput + "." + encodedSignature;
        } catch (GeneralSecurityException | IOException e) {
            throw new IllegalStateException("Failed to generate JWT using key " + pathToPrivateKey, e);
        }
    }

    /** Renders the JSON claim set for the given values with a fresh random {@code jti}. */
    private static String buildClaims(String iss, String sub, String aud, int expiry) {
        long expiresAtEpochSeconds = (System.currentTimeMillis() / 1000) + expiry;

        byte[] jtiBytes = new byte[JTI_BYTES];
        SECURE_RANDOM.nextBytes(jtiBytes);

        Object[] claimValues = new Object[] {
            escapeJson(iss),
            escapeJson(sub),
            escapeJson(aud),
            // Passed as a String so MessageFormat does not apply locale-specific
            // digit grouping (e.g. "1,695,000,000") to the number.
            Long.toString(expiresAtEpochSeconds),
            // Base64URL output only contains [A-Za-z0-9_-], so it needs no escaping.
            Base64.encodeBase64URLSafeString(jtiBytes)
        };

        logger.debug("Generating JWT claims: aud={}, exp={}", aud, expiresAtEpochSeconds);
        return MessageFormat.format(JWT_CLAIM_TEMPLATE, claimValues);
    }

    /** Rejects {@code null} arguments instead of letting them become the literal claim "null". */
    private static void requireArgument(Object value, String name) {
        if (value == null) {
            throw new IllegalArgumentException(name + " must not be null");
        }
    }

    /** Escapes a value for use inside a double-quoted JSON string (RFC 8259, section 7). */
    private static String escapeJson(String value) {
        StringBuilder escaped = new StringBuilder(value.length() + 8);
        for (int i = 0; i < value.length(); i++) {
            char c = value.charAt(i);
            switch (c) {
                case '"':
                    escaped.append("\\\"");
                    break;
                case '\\':
                    escaped.append("\\\\");
                    break;
                case '\n':
                    escaped.append("\\n");
                    break;
                case '\r':
                    escaped.append("\\r");
                    break;
                case '\t':
                    escaped.append("\\t");
                    break;
                default:
                    if (c < 0x20) {
                        // Remaining control characters must use the \\uXXXX form.
                        escaped.append(String.format("\\u%04x", (int) c));
                    } else {
                        escaped.append(c);
                    }
            }
        }
        return escaped.toString();
    }

    /**
     * Loads RSA private keys from PKCS#1 PEM, PKCS#8 PEM or PKCS#8 DER files.
     */
    public static class PrivateKeyReader {

        private static final Logger logger = LoggerFactory.getLogger(PrivateKeyReader.class);

        private static final String PKCS_1_PEM_HEADER = "-----BEGIN RSA PRIVATE KEY-----";
        private static final String PKCS_1_PEM_FOOTER = "-----END RSA PRIVATE KEY-----";
        private static final String PKCS_8_PEM_HEADER = "-----BEGIN PRIVATE KEY-----";
        private static final String PKCS_8_PEM_FOOTER = "-----END PRIVATE KEY-----";

        /**
         * Reads the key file once and parses it according to the PEM header found
         * in it; a file without a recognised header is treated as PKCS#8 DER.
         *
         * @param keyFilePath path of the key file
         * @return the parsed RSA private key
         * @throws IOException              if the file cannot be read
         * @throws GeneralSecurityException if the RSA key factory is unavailable
         * @throws IllegalArgumentException if the key material cannot be parsed
         */
        public static PrivateKey loadKey(String keyFilePath) throws GeneralSecurityException, IOException {
            byte[] keyDataBytes = Files.readAllBytes(Paths.get(keyFilePath));
            String keyDataString = new String(keyDataBytes, StandardCharsets.UTF_8);

            if (keyDataString.contains(PKCS_1_PEM_HEADER)) {
                // OpenSSL / PKCS#1 Base64 PEM encoded file
                logger.debug("Reading {} as PKCS#1 PEM", keyFilePath);
                String body = keyDataString
                        .replace(PKCS_1_PEM_HEADER, "")
                        .replace(PKCS_1_PEM_FOOTER, "");
                return readPkcs1PrivateKey(Base64.decodeBase64(body));
            }

            if (keyDataString.contains(PKCS_8_PEM_HEADER)) {
                // PKCS#8 Base64 PEM encoded file
                logger.debug("Reading {} as PKCS#8 PEM", keyFilePath);
                String body = keyDataString
                        .replace(PKCS_8_PEM_HEADER, "")
                        .replace(PKCS_8_PEM_FOOTER, "");
                return readPkcs8PrivateKey(Base64.decodeBase64(body));
            }

            // No PEM header: assume a PKCS#8 DER encoded binary file and reuse the
            // bytes already read instead of hitting the disk a second time.
            logger.debug("Reading {} as PKCS#8 DER", keyFilePath);
            return readPkcs8PrivateKey(keyDataBytes);
        }

        /**
         * Parses PKCS#8 encoded bytes into an RSA private key.
         *
         * @throws IllegalArgumentException if the bytes are not a valid key spec
         */
        private static PrivateKey readPkcs8PrivateKey(byte[] pkcs8Bytes) throws GeneralSecurityException {
            // The provider is pinned to "SunRsaSign"; getInstance throws
            // NoSuchProviderException when that provider is not installed.
            KeyFactory keyFactory = KeyFactory.getInstance("RSA", "SunRsaSign");
            PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(pkcs8Bytes);
            try {
                return keyFactory.generatePrivate(keySpec);
            } catch (InvalidKeySpecException e) {
                throw new IllegalArgumentException("Unexpected key format!", e);
            }
        }

        /**
         * Wraps PKCS#1 bytes in a hand-built PKCS#8 envelope and parses the result.
         *
         * <p>The 22 bytes added to the total length are the parts of the header
         * that follow the outer sequence tag/length: 3 (version) + 15 (algorithm
         * identifier) + 4 (octet string tag/length). Both lengths are written as
         * two bytes after a 0x82 marker, so this assumes they fit in 16 bits.
         */
        private static PrivateKey readPkcs1PrivateKey(byte[] pkcs1Bytes) throws GeneralSecurityException {
            // We can't use Java internal APIs to parse ASN.1 structures, so we build a PKCS#8 key Java can understand
            int pkcs1Length = pkcs1Bytes.length;
            int totalLength = pkcs1Length + 22;
            byte[] pkcs8Header = new byte[] {
                0x30, (byte) 0x82, (byte) ((totalLength >> 8) & 0xff), (byte) (totalLength & 0xff), // Sequence + total length
                0x2, 0x1, 0x0, // Integer (0)
                0x30, 0xD, 0x6, 0x9, 0x2A, (byte) 0x86, 0x48, (byte) 0x86, (byte) 0xF7, 0xD, 0x1, 0x1, 0x1, 0x5, 0x0, // Sequence: 1.2.840.113549.1.1.1, NULL
                0x4, (byte) 0x82, (byte) ((pkcs1Length >> 8) & 0xff), (byte) (pkcs1Length & 0xff) // Octet string + length
            };
            byte[] pkcs8bytes = join(pkcs8Header, pkcs1Bytes);
            return readPkcs8PrivateKey(pkcs8bytes);
        }

        /** Returns a new array holding {@code byteArray1} followed by {@code byteArray2}. */
        private static byte[] join(byte[] byteArray1, byte[] byteArray2){
            byte[] bytes = new byte[byteArray1.length + byteArray2.length];
            System.arraycopy(byteArray1, 0, bytes, 0, byteArray1.length);
            System.arraycopy(byteArray2, 0, bytes, byteArray1.length, byteArray2.length);
            return bytes;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment