Skip to content

Instantly share code, notes, and snippets.

@KONFeature
Created July 12, 2024 23:16
Show Gist options
  • Save KONFeature/a69b789404b3a069e7c4707dbe35e8d8 to your computer and use it in GitHub Desktop.
Save KONFeature/a69b789404b3a069e7c4707dbe35e8d8 to your computer and use it in GitHub Desktop.
Viem KMS signer
import { KMSClient } from "@aws-sdk/client-kms";
import { logger } from "@frak-backend/core";
import {
Hex,
hashMessage,
hashTypedData,
keccak256,
serializeTransaction,
signatureToHex,
} from "viem";
import { toAccount } from "viem/accounts";
import { getKmsAddress } from "./address";
import { getKmsSignature } from "./signature";
/**
 * Module-wide KMS client, shared by every account built in this file.
 * The region is read from the `AWS_REGION` environment variable.
 */
const kmsClient = new KMSClient({ region: process.env.AWS_REGION });
/**
 * Build a viem account whose signatures are produced by an AWS KMS key.
 * details: https://ethereum.stackexchange.com/a/73371/5093
 *
 * The address is resolved once from the KMS public key; every signing method
 * hashes its payload locally and sends only the resulting hash to KMS.
 *
 * @param keyId The id of the KMS key to use
 * @returns The account returned by viem's `toAccount` for that key
 */
export const getKmsAccount = async ({ keyId }: { keyId: string }) => {
    // Resolve the address up-front: it is needed both for the account itself
    // and to pick the matching recovery value of each signature.
    const address = await getKmsAddress({ keyId, client: kmsClient });

    // Sign an already-hashed payload with the KMS key
    const signMsg = async (msg: Hex) =>
        getKmsSignature({ keyId, address, msg, client: kmsClient });

    // Build the account
    return toAccount({
        address,
        async signTransaction(
            transaction,
            { serializer = serializeTransaction } = {}
        ) {
            // Serialize the unsigned transaction a single time and reuse it
            // for both the debug log and the hash to sign.
            const unsignedPayload = serializer(transaction);
            logger.debug(
                { serialisedTransaction: unsignedPayload },
                "Signing a new transaction"
            );
            const signature = await signMsg(keccak256(unsignedPayload));
            // Re-serialize, this time with the signature attached
            return serializer(transaction, signature);
        },
        async signTypedData(typedData) {
            logger.debug({ typedData }, "Signing a new typed data");
            return signatureToHex(await signMsg(hashTypedData(typedData)));
        },
        async signMessage({ message }) {
            logger.debug({ message }, "Signing a new message");
            return signatureToHex(await signMsg(hashMessage(message)));
        },
    });
};
import { GetPublicKeyCommand, KMSClient } from "@aws-sdk/client-kms";
import { BitString, ObjectIdentifier, Sequence, verifySchema } from "asn1js";
import { getAddress, keccak256 } from "viem";
/**
 * ASN.1 schema used to validate and decode the public key returned by KMS:
 * an outer sequence made of an "algo" sequence (two object identifiers)
 * followed by the "pubKey" bit string.
 */
const EcdsaPubKey = new Sequence({
    name: "EcdsaPubKey",
    value: [
        new Sequence({
            name: "algo",
            value: ["a", "b"].map((name) => new ObjectIdentifier({ name })),
        }),
        new BitString({ name: "pubKey" }),
    ],
});
/**
 * Derive the ethereum address of a KMS key from its public key.
 *
 * @param keyId The id of the KMS key to look up
 * @param client The KMS client used to fetch the public key
 * @returns The checksummed address built from the last 20 bytes of the
 * keccak256 hash of the raw public key
 * @throws Error "Missing public key" when KMS returns no key, and
 * "Invalid public key" when the key does not match the expected ASN.1 schema
 */
export async function getKmsAddress({
    keyId,
    client,
}: {
    keyId: string;
    client: KMSClient;
}) {
    const { PublicKey: encodedKey } = await client.send(
        new GetPublicKeyCommand({ KeyId: keyId })
    );
    if (!encodedKey) {
        throw new Error("Missing public key");
    }

    // Make sure the public key matches the expected ASN.1 layout
    const parsed = verifySchema(Buffer.from(encodedKey), EcdsaPubKey);
    if (!parsed.verified) {
        throw new Error("Invalid public key");
    }

    // Second element of the outer sequence is the "pubKey" bit string.
    // Its first byte is dropped before hashing
    // (presumably the 0x04 uncompressed-point marker - TODO confirm).
    const pubKeyBits = (parsed.result.valueBlock as any).value[1].valueBlock
        .valueHexView;
    const pubKeyHash = keccak256(pubKeyBits.slice(1));

    // The address is the last 20 bytes (40 hex chars) of the hash
    return getAddress(`0x${pubKeyHash.slice(-40)}`);
}
import { KMSClient, SignCommand } from "@aws-sdk/client-kms";
import { Integer, Sequence, verifySchema } from "asn1js";
import { tryit } from "radash";
import {
Address,
Hex,
Signature,
isAddressEqual,
recoverAddress,
signatureToHex,
toBytes,
toHex,
} from "viem";
/**
 * ASN.1 schema used to validate and decode the signature returned by KMS:
 * a sequence of two integers, named "r" and "s".
 */
const EcdsaSigAsnParse = new Sequence({
    name: "EcdsaSig",
    value: ["r", "s"].map((name) => new Integer({ name })),
});
/**
 * Ask KMS to sign the given hash and convert the answer to a viem signature.
 *
 * @param keyId The id of the KMS key to sign with
 * @param address The address the signature must recover to
 * @param msg The hash to sign, hex encoded (sent to KMS as a digest)
 * @param client The KMS client used to send the sign request
 * @returns The `r`, `s` and `v` components of the signature
 * @throws Error "Missing signature" when KMS returns no signature
 */
export async function getKmsSignature({
    keyId,
    address,
    msg,
    client,
}: {
    keyId: string;
    address: Address;
    msg: Hex;
    client: KMSClient;
}) {
    const msgHash = Buffer.from(toBytes(msg));

    // The payload is flagged as a digest via `MessageType: "DIGEST"`
    const { Signature: encodedSignature } = await client.send(
        new SignCommand({
            KeyId: keyId,
            Message: msgHash,
            SigningAlgorithm: "ECDSA_SHA_256",
            MessageType: "DIGEST",
        })
    );
    if (!encodedSignature) {
        throw new Error("Missing signature");
    }

    return extractSignature({
        signature: Buffer.from(encodedSignature),
        msgHash,
        address,
    });
}
/**
 * Turn an encoded KMS signature into its `r`, `s` and `v` components.
 *
 * @param signature The encoded signature returned by KMS
 * @param msgHash The hash that was signed
 * @param address The address the signature must recover to
 * @returns `r` and `s` decoded from the signature, plus the `v` value that
 * makes the signature recover to `address`
 */
async function extractSignature({
    signature,
    msgHash,
    address,
}: {
    signature: Buffer;
    msgHash: Buffer;
    address: Address;
}): Promise<Signature> {
    // Decode r & s first, they are needed to search for v
    const { r, s } = getSigRs(signature);
    const { v } = await getSigV({
        msgHash,
        address,
        baseSignature: { r, s },
    });
    return { r, s, v };
}
/**
 * Decode 'r' and 's' from an encoded signature.
 *
 * @param signature The encoded signature, expected to match `EcdsaSigAsnParse`
 * @returns `r` and `s` as hex strings, with `s` moved to the lower half of
 * the secp256k1 curve order
 * @throws Error "Invalid signature" when the buffer does not match the schema
 */
function getSigRs(signature: Buffer) {
    // Ensure the signature matches our expected format
    const parsed = verifySchema(signature, EcdsaSigAsnParse);
    if (!parsed.verified) {
        throw new Error("Invalid signature");
    }

    // The two integers of the sequence, in schema order: r then s.
    // They go through hex since the values overflow a regular number.
    const fields = (parsed.result.valueBlock as any).value;
    const r = BigInt(toHex(fields[0].valueBlock.valueHexView));
    const decodedS = BigInt(toHex(fields[1].valueBlock.valueHexView));

    // Put the signature back in the lower half of the curve order:
    // an 's' above n/2 is replaced by 'n - s'
    const curveOrder = BigInt(
        "0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"
    );
    const halfCurveOrder = curveOrder / 2n;
    const s = decodedS > halfCurveOrder ? curveOrder - decodedS : decodedS;

    return { r: toHex(r), s: toHex(s) };
}
/**
 * Find the 'v' value that makes the signature recover to the given 'address'.
 *
 * Tries 27 first, then 28; a failing recovery is treated like a mismatch.
 *
 * @param msgHash The hash that was signed
 * @param address The address the signature must recover to
 * @param baseSignature The 'r' and 's' components of the signature
 * @returns The first 'v' candidate that recovers to `address`
 * @throws Error when neither candidate recovers to `address`
 */
async function getSigV({
    msgHash,
    address,
    baseSignature,
}: {
    msgHash: Buffer;
    address: Address;
    baseSignature: { r: Hex; s: Hex };
}) {
    for (const v of [27n, 28n]) {
        const signature = signatureToHex({
            r: baseSignature.r,
            s: baseSignature.s,
            v,
        });
        // tryit turns a throwing recovery into an undefined result
        const [, recovered] = await tryit(recoverAddress)({
            hash: msgHash,
            signature,
        });
        if (recovered && isAddressEqual(recovered, address)) {
            return { v };
        }
    }

    throw new Error("signature is invalid. recovered address does not match");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment