Skip to content

Add ability to use custom IV in AES/GCM #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion cryptography-core/src/commonMain/kotlin/algorithms/AES.kt
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public interface AES<K : AES.Key> : CryptographyAlgorithm {

@SubclassOptInRequired(CryptographyProviderApi::class)
public interface Key : AES.Key {
public fun cipher(tagSize: BinarySize = 128.bits): AuthenticatedCipher
public fun cipher(tagSize: BinarySize = 128.bits): IvAuthenticatedCipher
}
}

Expand Down Expand Up @@ -119,4 +119,41 @@ public interface AES<K : AES.Key> : CryptographyAlgorithm {
public fun decryptBlocking(iv: ByteString, ciphertext: ByteString): ByteString =
decryptBlocking(iv.asByteArray(), ciphertext.asByteArray()).asByteString()
}

@SubclassOptInRequired(CryptographyProviderApi::class)
public interface IvAuthenticatedCipher : IvAuthenticatedEncryptor, IvAuthenticatedDecryptor

@SubclassOptInRequired(CryptographyProviderApi::class)
public interface IvAuthenticatedEncryptor : AuthenticatedEncryptor {
@DelicateCryptographyApi
public suspend fun encrypt(iv: ByteArray, plaintext: ByteArray, associatedData: ByteArray? = null): ByteArray = encryptBlocking(iv, plaintext, associatedData)

@DelicateCryptographyApi
public fun encryptBlocking(iv: ByteArray, plaintext: ByteArray, associatedData: ByteArray? = null): ByteArray

@DelicateCryptographyApi
public suspend fun encrypt(iv: ByteString, plaintext: ByteString, associatedData: ByteString? = null): ByteString =
encrypt(iv.asByteArray(), plaintext.asByteArray(), associatedData?.toByteArray()).asByteString()

@DelicateCryptographyApi
public fun encryptBlocking(iv: ByteString, plaintext: ByteString, associatedData: ByteString? = null): ByteString =
encryptBlocking(iv.asByteArray(), plaintext.asByteArray(), associatedData?.toByteArray()).asByteString()
}

@SubclassOptInRequired(CryptographyProviderApi::class)
public interface IvAuthenticatedDecryptor : AuthenticatedDecryptor {
@DelicateCryptographyApi
public suspend fun decrypt(iv: ByteArray, ciphertext: ByteArray, associatedData: ByteArray? = null): ByteArray = decryptBlocking(iv, ciphertext, associatedData)

@DelicateCryptographyApi
public fun decryptBlocking(iv: ByteArray, ciphertext: ByteArray, associatedData: ByteArray? = null): ByteArray

@DelicateCryptographyApi
public suspend fun decrypt(iv: ByteString, ciphertext: ByteString, associatedData: ByteString? = null): ByteString =
decrypt(iv.asByteArray(), ciphertext.asByteArray(), associatedData?.toByteArray()).asByteString()

@DelicateCryptographyApi
public fun decryptBlocking(iv: ByteString, ciphertext: ByteString, associatedData: ByteString? = null): ByteString =
decryptBlocking(iv.asByteArray(), ciphertext.asByteArray(), associatedData?.toByteArray()).asByteString()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ abstract class AesGcmCompatibilityTest(provider: CryptographyProvider) :
AesBasedCompatibilityTest<AES.GCM.Key, AES.GCM>(AES.GCM, provider) {

@Serializable
private data class CipherParameters(val tagSizeBits: Int) : TestParameters
private data class CipherParameters(
val tagSizeBits: Int,
val iv: ByteStringAsString?,
) : TestParameters {
override fun toString(): String = "CipherParameters(tagSizeBits=${tagSizeBits}, iv.size=${iv?.size})"
}

override suspend fun CompatibilityTestScope<AES.GCM>.generate(isStressTest: Boolean) {
val associatedDataIterations = when {
Expand All @@ -31,17 +36,28 @@ abstract class AesGcmCompatibilityTest(provider: CryptographyProvider) :
isStressTest -> 10
else -> 5
}
val ivIterations = when {
isStressTest -> 10
else -> 5
}

val tagSizes = listOf(96, 128)

val tagSizes = listOf(96, 128).map { tagSizeBits ->
val id = api.ciphers.saveParameters(CipherParameters(tagSizeBits))
id to tagSizeBits.bits
val parametersList = buildList {
tagSizes.forEach { tagSize ->
// size of IV = 12
(List(ivIterations) { ByteString(CryptographyRandom.nextBytes(12)) } + listOf(null)).forEach { iv ->
val parameters = CipherParameters(tagSize, iv)
val id = api.ciphers.saveParameters(parameters)
add(id to parameters)
}
}
}

generateKeys(isStressTest) { key, keyReference, _ ->
tagSizes.forEach { (cipherParametersId, tagSize) ->
logger.log { "tagSize = $tagSize" }
val cipher = key.cipher(tagSize)
parametersList.forEach { (cipherParametersId, parameters) ->
logger.log { "parameters = $parameters" }
val cipher = key.cipher(parameters.tagSizeBits.bits)
repeat(associatedDataIterations) { adIndex ->
val associatedDataSize = if (adIndex == 0) null else CryptographyRandom.nextInt(maxAssociatedDataSize)
logger.log { "associatedData.size = $associatedDataSize" }
Expand All @@ -50,10 +66,21 @@ abstract class AesGcmCompatibilityTest(provider: CryptographyProvider) :
val plaintextSize = CryptographyRandom.nextInt(maxPlaintextSize)
logger.log { "plaintext.size = $plaintextSize" }
val plaintext = ByteString(CryptographyRandom.nextBytes(plaintextSize))
val ciphertext = cipher.encrypt(plaintext, associatedData)
logger.log { "ciphertext.size = ${ciphertext.size}" }

assertContentEquals(plaintext, cipher.decrypt(ciphertext, associatedData), "Initial Decrypt")
val ciphertext = when (val iv = parameters.iv) {
null -> {
val ciphertext = cipher.encrypt(plaintext, associatedData)
logger.log { "ciphertext.size = ${ciphertext.size}" }
assertContentEquals(plaintext, cipher.decrypt(ciphertext, associatedData), "Initial Decrypt")
ciphertext
}
else -> {
val ciphertext = cipher.encrypt(iv, plaintext, associatedData)
logger.log { "ciphertext.size = ${ciphertext.size}" }
assertContentEquals(plaintext, cipher.decrypt(iv, ciphertext, associatedData), "Initial Decrypt")
ciphertext
}
}

api.ciphers.saveData(
cipherParametersId,
Expand All @@ -68,16 +95,29 @@ abstract class AesGcmCompatibilityTest(provider: CryptographyProvider) :
override suspend fun CompatibilityTestScope<AES.GCM>.validate() {
val keys = validateKeys()

api.ciphers.getParameters<CipherParameters> { (tagSize), parametersId, _ ->
api.ciphers.getParameters<CipherParameters> { (tagSize, iv), parametersId, _ ->
api.ciphers.getData<AuthenticatedCipherData>(parametersId) { (keyReference, associatedData, plaintext, ciphertext), _, _ ->
keys[keyReference]?.forEach { key ->
val cipher = key.cipher(tagSize.bits)
assertContentEquals(plaintext, cipher.decrypt(ciphertext, associatedData), "Decrypt")
assertContentEquals(
plaintext,
cipher.decrypt(cipher.encrypt(plaintext, associatedData), associatedData),
"Encrypt-Decrypt"
)

when (iv) {
null -> {
assertContentEquals(plaintext, cipher.decrypt(ciphertext, associatedData), "Decrypt")
assertContentEquals(
plaintext,
cipher.decrypt(cipher.encrypt(plaintext, associatedData), associatedData),
"Encrypt-Decrypt"
)
}
else -> {
assertContentEquals(plaintext, cipher.decrypt(iv, ciphertext, associatedData), "Decrypt")
assertContentEquals(
plaintext,
cipher.decrypt(iv, cipher.encrypt(iv, plaintext, associatedData), associatedData),
"Encrypt-Decrypt"
)
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,41 +18,55 @@ internal class JdkAesGcm(
) : AES.GCM {
private val keyWrapper: (JSecretKey) -> AES.GCM.Key = { key ->
object : AES.GCM.Key, JdkEncodableKey<AES.Key.Format>(key) {
override fun cipher(tagSize: BinarySize): AuthenticatedCipher = AesGcmCipher(state, key, tagSize)
override fun cipher(tagSize: BinarySize): AES.IvAuthenticatedCipher = AesGcmCipher(state, key, tagSize)

override fun encodeToByteArrayBlocking(format: AES.Key.Format): ByteArray = when (format) {
AES.Key.Format.JWK -> error("$format is not supported")
AES.Key.Format.RAW -> encodeToRaw()
}
}
}

private val keyDecoder = JdkSecretKeyDecoder<AES.Key.Format, _>("AES", keyWrapper)

override fun keyDecoder(): KeyDecoder<AES.Key.Format, AES.GCM.Key> = keyDecoder

override fun keyGenerator(keySize: BinarySize): KeyGenerator<AES.GCM.Key> = JdkSecretKeyGenerator(state, "AES", keyWrapper) {
init(keySize.inBits, state.secureRandom)
}
}

private const val ivSizeBytes = 12 //bytes for GCM
private const val ivSizeBytes = 12 // bytes for GCM

private class AesGcmCipher(
private val state: JdkCryptographyState,
private val key: JSecretKey,
private val tagSize: BinarySize,
) : AuthenticatedCipher {
) : AES.IvAuthenticatedCipher {
private val cipher = state.cipher("AES/GCM/NoPadding")

override fun encryptBlocking(plaintext: ByteArray, associatedData: ByteArray?): ByteArray = cipher.use { cipher ->
override fun encryptBlocking(plaintext: ByteArray, associatedData: ByteArray?): ByteArray {
val iv = ByteArray(ivSizeBytes).also(state.secureRandom::nextBytes)
return iv + encryptBlocking(iv, plaintext, associatedData)
}

@DelicateCryptographyApi
override fun encryptBlocking(iv: ByteArray, plaintext: ByteArray, associatedData: ByteArray?): ByteArray = cipher.use { cipher ->
cipher.init(JCipher.ENCRYPT_MODE, key, GCMParameterSpec(tagSize.inBits, iv), state.secureRandom)
associatedData?.let(cipher::updateAAD)
iv + cipher.doFinal(plaintext)
cipher.doFinal(plaintext)
}

override fun decryptBlocking(ciphertext: ByteArray, associatedData: ByteArray?): ByteArray = cipher.use { cipher ->
cipher.init(JCipher.DECRYPT_MODE, key, GCMParameterSpec(tagSize.inBits, ciphertext, 0, ivSizeBytes), state.secureRandom)
associatedData?.let(cipher::updateAAD)
cipher.doFinal(ciphertext, ivSizeBytes, ciphertext.size - ivSizeBytes)
}

@DelicateCryptographyApi
override fun decryptBlocking(iv: ByteArray, ciphertext: ByteArray, associatedData: ByteArray?): ByteArray = cipher.use { cipher ->
cipher.init(JCipher.DECRYPT_MODE, key, GCMParameterSpec(tagSize.inBits, iv), state.secureRandom)
associatedData?.let(cipher::updateAAD)
cipher.doFinal(ciphertext)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package dev.whyoleg.cryptography.providers.openssl3.algorithms

import dev.whyoleg.cryptography.*
import dev.whyoleg.cryptography.algorithms.*
import dev.whyoleg.cryptography.operations.*
import dev.whyoleg.cryptography.providers.openssl3.internal.*
import dev.whyoleg.cryptography.providers.openssl3.internal.cinterop.*
import dev.whyoleg.cryptography.random.*
Expand All @@ -25,28 +24,28 @@ internal object Openssl3AesGcm : AES.GCM, Openssl3Aes<AES.GCM.Key>() {
else -> error("Unsupported key size")
}

override fun cipher(tagSize: BinarySize): AuthenticatedCipher = AesGcmCipher(algorithm, key, tagSize)
override fun cipher(tagSize: BinarySize): AES.IvAuthenticatedCipher = AesGcmCipher(algorithm, key, tagSize)
}
}

private const val ivSizeBytes = 12 //bytes for CBC
private const val ivSizeBytes = 12 //bytes for GCM

private class AesGcmCipher(
algorithm: String,
private val key: ByteArray,
private val tagSize: BinarySize,
) : AuthenticatedCipher {
) : AES.IvAuthenticatedCipher {

private val cipher = EVP_CIPHER_fetch(null, algorithm, null)

@OptIn(ExperimentalNativeApi::class)
private val cleaner = createCleaner(cipher, ::EVP_CIPHER_free)

override fun encryptBlocking(plaintext: ByteArray, associatedData: ByteArray?): ByteArray = memScoped {
@DelicateCryptographyApi
override fun encryptBlocking(iv: ByteArray, plaintext: ByteArray, associatedData: ByteArray?): ByteArray = memScoped {
require(iv.size == ivSizeBytes) { "IV size is wrong" }
val context = EVP_CIPHER_CTX_new()
try {
val iv = ByteArray(ivSizeBytes).also { CryptographyRandom.nextBytes(it) }

checkError(
EVP_EncryptInit_ex2(
ctx = context,
Expand Down Expand Up @@ -101,22 +100,50 @@ private class AesGcmCipher(
)
)
val produced = producedWithFinal + tagSize.inBytes
iv + ciphertextOutput.ensureSizeExactly(produced)
ciphertextOutput.ensureSizeExactly(produced)
} finally {
EVP_CIPHER_CTX_free(context)
}
}

override fun decryptBlocking(ciphertext: ByteArray, associatedData: ByteArray?): ByteArray = memScoped {
override fun encryptBlocking(plaintext: ByteArray, associatedData: ByteArray?): ByteArray {
val iv = ByteArray(ivSizeBytes).also { CryptographyRandom.nextBytes(it) }
return iv + encryptBlocking(iv, plaintext, associatedData)
}

override fun decryptBlocking(ciphertext: ByteArray, associatedData: ByteArray?): ByteArray {
require(ciphertext.size >= ivSizeBytes + tagSize.inBytes) { "Ciphertext is too short" }

return decrypt(
iv = ciphertext,
ciphertext = ciphertext,
ciphertextStartIndex = ivSizeBytes,
associatedData = associatedData,
)
}

@DelicateCryptographyApi
override fun decryptBlocking(iv: ByteArray, ciphertext: ByteArray, associatedData: ByteArray?): ByteArray {
require(iv.size == ivSizeBytes) { "IV size is wrong" }
require(ciphertext.size >= tagSize.inBytes) { "Ciphertext is too short" }

return decrypt(
iv = iv,
ciphertext = ciphertext,
ciphertextStartIndex = 0,
associatedData = associatedData,
)
}

private fun decrypt(iv: ByteArray, ciphertext: ByteArray, ciphertextStartIndex: Int, associatedData: ByteArray?): ByteArray = memScoped {
val context = EVP_CIPHER_CTX_new()
try {
checkError(
EVP_DecryptInit_ex2(
ctx = context,
cipher = cipher,
key = key.refToU(0),
iv = ciphertext.refToU(0),
iv = iv.refToU(0),
params = null
)
)
Expand All @@ -135,14 +162,14 @@ private class AesGcmCipher(
check(outl.value == ad.size) { "Unexpected output length: got ${outl.value} expected ${ad.size}" }
}

val plaintextOutput = ByteArray(ciphertext.size - ivSizeBytes - tagSize.inBytes)
val plaintextOutput = ByteArray(ciphertext.size - ciphertextStartIndex - tagSize.inBytes)
checkError(
EVP_DecryptUpdate(
ctx = context,
out = plaintextOutput.safeRefToU(0),
outl = outl.ptr,
`in` = ciphertext.refToU(ivSizeBytes),
inl = ciphertext.size - ivSizeBytes - tagSize.inBytes
`in` = ciphertext.refToU(ciphertextStartIndex),
inl = ciphertext.size - ciphertextStartIndex - tagSize.inBytes
)
)
if (plaintextOutput.isEmpty()) check(outl.value == 0) { "Unexpected output length: got ${outl.value} expected 0" }
Expand Down
Loading