Skip to content

Instantly share code, notes, and snippets.

@lynxluna
Created November 24, 2021 02:39
Show Gist options
  • Save lynxluna/be439d2f1f1652f464ec128edea76280 to your computer and use it in GitHub Desktop.
Save lynxluna/be439d2f1f1652f464ec128edea76280 to your computer and use it in GitHub Desktop.
AES-CBC-256 Crypto with Go
package main
import (
"bufio"
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"os"
"golang.org/x/crypto/hkdf"
)
// AES-256 key size, expressed in bits and in bytes. keySizeByte is derived
// from keySizeBit so the two can never disagree.
const (
	keySizeBit  = 256
	keySizeByte = keySizeBit / 8
)
// info is the HKDF "info" (context) input passed to every key derivation in
// this file. Its content is arbitrary, but it is part of the derivation, so
// changing it changes every derived key and makes ciphertexts produced under
// the old value undecryptable.
const info = "this key is secret"
// Sentinel errors returned by the key-derivation and encryption helpers.
// Compare against them with errors.Is.
var (
	// ErrPasswordEmpty reports that an empty password or secret was supplied.
	ErrPasswordEmpty = errors.New("password cannot be empty")

	// ErrKeyTooShort is reserved for key-length validation.
	// NOTE(review): nothing in this file currently returns it.
	ErrKeyTooShort = errors.New("key too short")

	// ErrIVTooShort reports a missing or incompletely filled initialization vector.
	ErrIVTooShort = errors.New("iv too short")

	// ErrUnpadded reports input whose length is not a multiple of the AES block size.
	ErrUnpadded = errors.New("unpadded input")
)
// generateKeyByte derives a keySizeByte-long key from password using a fresh
// random salt of the same length, via generateKeyWithSalt.
//
// It returns the salt, the derived key, and an error. The salt is not secret,
// but it must be stored alongside the ciphertext so the key can be re-derived
// later with generateKeyWithSalt. On any error both slices are nil. An empty
// password yields ErrPasswordEmpty.
//
// NOTE(review): HKDF has no work factor, so it offers little resistance to
// brute-forcing low-entropy passwords. A dedicated password KDF (scrypt or
// Argon2 from golang.org/x/crypto) would be the usual choice here.
func generateKeyByte(password []byte) (salt []byte, key []byte, err error) {
	// Validate before doing any allocation or consuming randomness.
	if len(password) == 0 {
		return nil, nil, ErrPasswordEmpty
	}

	salt = make([]byte, keySizeByte)
	if _, err = rand.Read(salt); err != nil {
		// Do not hand back a half-initialised salt/key pair with the error.
		return nil, nil, fmt.Errorf("generating salt: %w", err)
	}

	key, err = generateKeyWithSalt(salt, password)
	if err != nil {
		return nil, nil, err
	}
	return salt, key, nil
}
// generateKeyWithSalt derives a keySizeByte-long key from secret and salt
// using HKDF with SHA-256 and the package-level info string. The derivation is
// deterministic: the same salt and secret always give the same key, which is
// what lets decryption re-create the encryption key. An empty secret returns
// ErrPasswordEmpty. Any error from reading the HKDF stream is returned as is.
func generateKeyWithSalt(salt, secret []byte) (key []byte, err error) {
	if len(secret) == 0 {
		return nil, ErrPasswordEmpty
	}

	kdf := hkdf.New(sha256.New, secret, salt, []byte(info))
	derived := make([]byte, keySizeByte)
	if _, readErr := kdf.Read(derived); readErr != nil {
		return nil, readErr
	}
	return derived, nil
}
// generateKeyStr is the string-typed convenience form of generateKeyByte. It
// converts password to bytes and returns generateKeyByte's salt, key and
// error unchanged.
//
// NOTE(review): nothing in this file calls it.
func generateKeyStr(password string) (salt []byte, key []byte, err error) {
	raw := []byte(password)
	salt, key, err = generateKeyByte(raw)
	return salt, key, err
}
// PKCS5Padding pads its input to a multiple of blockSize using PKCS#7-style
// padding (commonly called PKCS#5). Despite the parameter name, ciphertext is
// the plaintext to be padded before encryption.
//
// It always adds between 1 and blockSize bytes, each equal to the number of
// bytes added. An input that is already block-aligned therefore gains one full
// block, which keeps the padding unambiguous to strip.
//
// The result is a newly allocated slice, so the caller's backing array is
// never written to, even when it has spare capacity. blockSize must be in
// [1, 255] because the pad length is stored in a single byte. Any other value
// is a programming error and panics.
func PKCS5Padding(ciphertext []byte, blockSize int) []byte {
	if blockSize < 1 || blockSize > 255 {
		panic("PKCS5Padding: block size must be in [1, 255]")
	}
	padding := blockSize - len(ciphertext)%blockSize

	out := make([]byte, len(ciphertext)+padding)
	n := copy(out, ciphertext)
	for i := n; i < len(out); i++ {
		out[i] = byte(padding)
	}
	return out
}
// PKCS5Trimming removes PKCS#5/PKCS#7 padding from a decrypted buffer and
// returns the unpadded prefix. The result is a sub-slice of encrypt, not a copy.
//
// The last byte gives the pad length n. The final n bytes must all equal n,
// and n must be between 1 and len(encrypt). This signature cannot report
// errors, so malformed padding is returned unchanged. That covers an empty
// buffer, a zero or over-long pad length, and an inconsistent tail. The
// function never panics and never trims a tail that is not valid padding.
// Use pkcs5Unpad when the caller needs to detect malformed input.
func PKCS5Trimming(encrypt []byte) []byte {
	out, err := pkcs5Unpad(encrypt)
	if err != nil {
		return encrypt
	}
	return out
}

// errBadPadding is returned by pkcs5Unpad for any malformed padding.
var errBadPadding = errors.New("invalid PKCS#5 padding")

// pkcs5Unpad validates and strips PKCS#5/PKCS#7 padding from data, returning
// errBadPadding if the padding is malformed. It does not check the length
// against a particular block size.
//
// NOTE(review): the byte comparison is not constant-time. Do not expose its
// success or failure to untrusted parties, since CBC padding errors can act as
// a decryption oracle.
func pkcs5Unpad(data []byte) ([]byte, error) {
	n := len(data)
	if n == 0 {
		return nil, errBadPadding
	}
	pad := int(data[n-1])
	if pad == 0 || pad > n {
		return nil, errBadPadding
	}
	for _, b := range data[n-pad:] {
		if int(b) != pad {
			return nil, errBadPadding
		}
	}
	return data[:n-pad], nil
}
// encrypt encrypts plaintext with AES in CBC mode under key and returns a
// single buffer laid out as IV || ciphertext. The IV is one freshly generated
// random block.
//
// plaintext must already be a multiple of aes.BlockSize (for example via
// PKCS5Padding); otherwise ErrUnpadded is returned. key must be a length
// accepted by aes.NewCipher (16, 24 or 32 bytes), and its error is returned
// as is. ErrIVTooShort is returned if fewer than aes.BlockSize random bytes
// were read. plaintext itself is not modified.
func encrypt(key []byte, plaintext []byte) ([]byte, error) {
	if rem := len(plaintext) % aes.BlockSize; rem != 0 {
		return nil, ErrUnpadded
	}

	// One allocation holds both the IV prefix and the encrypted body.
	out := make([]byte, aes.BlockSize+len(plaintext))
	iv, body := out[:aes.BlockSize], out[aes.BlockSize:]

	got, err := rand.Read(iv)
	switch {
	case err != nil:
		return nil, err
	case got < aes.BlockSize:
		return nil, ErrIVTooShort
	}

	blk, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	cipher.NewCBCEncrypter(blk, iv).CryptBlocks(body, plaintext)
	return out, nil
}
// decrypt reverses encrypt. ciphertext is expected to be IV || CBC blocks,
// with the IV in the first aes.BlockSize bytes. It returns the decrypted
// blocks with their padding still attached (strip it with PKCS5Trimming).
//
// Errors:
//   - ErrUnpadded when len(ciphertext) is not a multiple of aes.BlockSize.
//   - ErrIVTooShort when ciphertext is empty and so has no room for an IV.
//   - Any error from aes.NewCipher, for an invalid key length.
//
// CBC is unauthenticated: a wrong key or tampered input decrypts to garbage
// rather than producing an error.
func decrypt(key []byte, ciphertext []byte) ([]byte, error) {
	if len(ciphertext)%aes.BlockSize != 0 {
		return nil, ErrUnpadded
	}
	// An empty input passes the check above but has no IV. Slicing it below
	// would panic, so reject it explicitly.
	if len(ciphertext) < aes.BlockSize {
		return nil, ErrIVTooShort
	}

	// Split off the IV; the remainder is the encrypted content.
	iv, body := ciphertext[:aes.BlockSize], ciphertext[aes.BlockSize:]

	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	plaintext := make([]byte, len(body))
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, body)
	return plaintext, nil
}
// main runs an interactive encrypt/decrypt round trip:
//  1. Read a plaintext line and a password from stdin.
//  2. Derive a key with a random salt and encrypt the plaintext.
//  3. Print the salt, key and ciphertext as hex.
//  4. Ask for the password again and decrypt with a key re-derived from the
//     saved salt.
func main() {
	// Known marker prepended to the plaintext. Finding it again after
	// decryption is how this demo detects a wrong password. CBC itself is
	// unauthenticated, so this is a heuristic, not an integrity check.
	checkbytes := []byte{0xCA, 0xFE, 0xBA, 0xBE}

	sc := bufio.NewScanner(os.Stdin)
	// Raise the line limit above bufio.Scanner's 64 KiB default.
	sc.Buffer(nil, maxInputLine)

	ptb, err := readLine(sc, "Enter plaintext: ")
	if err != nil {
		panic(err)
	}
	rtb, err := readLine(sc, "Enter password: ")
	if err != nil {
		panic(err)
	}

	// Derive the key with a fresh random salt. The salt is not secret, but it
	// must be kept to re-derive the same key for decryption. A fixed salt would
	// also work, but then every session with the same password would derive
	// the same key.
	salt, key, err := generateKeyByte(rtb)
	if err != nil {
		panic(err)
	}

	content := make([]byte, 0, len(checkbytes)+len(ptb))
	content = append(content, checkbytes...)
	content = append(content, ptb...)
	content = PKCS5Padding(content, aes.BlockSize)

	ct, err := encrypt(key, content)
	if err != nil {
		panic(err)
	}
	// Printing the derived key is for demonstration only.
	fmt.Printf("\n\nEncryption\n-----\nSalt: %s\nKey: %s\nCiphertext: %s\n",
		hex.EncodeToString(salt), hex.EncodeToString(key), hex.EncodeToString(ct))

	rtb2, err := readLine(sc, "\nDecryption\n-----\nEnter previous password: ")
	if err != nil {
		panic(err)
	}
	// Re-derive the key from the second password and the salt saved above.
	key2, err := generateKeyWithSalt(salt, rtb2)
	if err != nil {
		panic(err)
	}
	pt, err := decrypt(key2, ct)
	if err != nil {
		panic(err)
	}
	if len(pt) < len(checkbytes) || !bytes.Equal(checkbytes, pt[:len(checkbytes)]) {
		fmt.Println("invalid key")
		return
	}
	pt = PKCS5Trimming(pt)
	if len(pt) < len(checkbytes) {
		fmt.Println("invalid padding")
		return
	}
	// Strip the marker so only the user's plaintext is shown.
	fmt.Printf("Plaintext: %s\n", pt[len(checkbytes):])
}

// maxInputLine is the longest input line, in bytes, that main accepts.
const maxInputLine = 1 << 20

// readLine prints prompt, then returns the next line from sc with its line
// terminator ("\n" or "\r\n") removed. A last line that ends at EOF without a
// newline is still returned. The result is a fresh copy, so it stays valid
// after later calls to sc.Scan. It returns an error on a scanner failure (for
// example a line longer than the scanner's buffer limit) or when the input
// ends before any line is read.
func readLine(sc *bufio.Scanner, prompt string) ([]byte, error) {
	fmt.Print(prompt)
	if !sc.Scan() {
		if err := sc.Err(); err != nil {
			return nil, err
		}
		return nil, errors.New("unexpected end of input")
	}
	return []byte(sc.Text()), nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment