]> gitweb.ps.run Git - matrix_esp_thesis/blobdiff - ext/olm/src/pk.cpp
add dependencies to repo
[matrix_esp_thesis] / ext / olm / src / pk.cpp
diff --git a/ext/olm/src/pk.cpp b/ext/olm/src/pk.cpp
new file mode 100644 (file)
index 0000000..9217c48
--- /dev/null
@@ -0,0 +1,542 @@
+/* Copyright 2018, 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "olm/pk.h"
+#include "olm/cipher.h"
+#include "olm/crypto.h"
+#include "olm/ratchet.hh"
+#include "olm/error.h"
+#include "olm/memory.hh"
+#include "olm/base64.hh"
+#include "olm/pickle_encoding.h"
+#include "olm/pickle.hh"
+
+static const std::size_t MAC_LENGTH = 8;
+
+const struct _olm_cipher_aes_sha_256 olm_pk_cipher_aes_sha256 =
+    OLM_CIPHER_INIT_AES_SHA_256("");
+const struct _olm_cipher *olm_pk_cipher =
+    OLM_CIPHER_BASE(&olm_pk_cipher_aes_sha256);
+
+extern "C" {
+
+struct OlmPkEncryption {
+    OlmErrorCode last_error;
+    _olm_curve25519_public_key recipient_key;
+};
+
+const char * olm_pk_encryption_last_error(
+    const OlmPkEncryption * encryption
+) {
+    auto error = encryption->last_error;
+    return _olm_error_to_string(error);
+}
+
+OlmErrorCode olm_pk_encryption_last_error_code(
+    const OlmPkEncryption * encryption
+) {
+    return encryption->last_error;
+}
+
+size_t olm_pk_encryption_size(void) {
+    return sizeof(OlmPkEncryption);
+}
+
+OlmPkEncryption *olm_pk_encryption(
+    void * memory
+) {
+    olm::unset(memory, sizeof(OlmPkEncryption));
+    return new(memory) OlmPkEncryption;
+}
+
+size_t olm_clear_pk_encryption(
+    OlmPkEncryption *encryption
+) {
+    /* Clear the memory backing the encryption */
+    olm::unset(encryption, sizeof(OlmPkEncryption));
+    /* Initialise a fresh encryption object in case someone tries to use it */
+    new(encryption) OlmPkEncryption();
+    return sizeof(OlmPkEncryption);
+}
+
+size_t olm_pk_encryption_set_recipient_key (
+    OlmPkEncryption *encryption,
+    void const * key, size_t key_length
+) {
+    if (key_length < olm_pk_key_length()) {
+        encryption->last_error =
+            OlmErrorCode::OLM_INPUT_BUFFER_TOO_SMALL;
+        return std::size_t(-1);
+    }
+
+    olm::decode_base64(
+        (const uint8_t*)key,
+        olm_pk_key_length(),
+        (uint8_t *)encryption->recipient_key.public_key
+    );
+
+    return 0;
+}
+
+size_t olm_pk_ciphertext_length(
+    const OlmPkEncryption *encryption,
+    size_t plaintext_length
+) {
+    return olm::encode_base64_length(
+        _olm_cipher_aes_sha_256_ops.encrypt_ciphertext_length(olm_pk_cipher, plaintext_length)
+    );
+}
+
+size_t olm_pk_mac_length(
+    const OlmPkEncryption *encryption
+) {
+    return olm::encode_base64_length(_olm_cipher_aes_sha_256_ops.mac_length(olm_pk_cipher));
+}
+
+size_t olm_pk_encrypt_random_length(
+    const OlmPkEncryption *encryption
+) {
+    return CURVE25519_KEY_LENGTH;
+}
+
+size_t olm_pk_encrypt(
+    OlmPkEncryption *encryption,
+    void const * plaintext, size_t plaintext_length,
+    void * ciphertext, size_t ciphertext_length,
+    void * mac, size_t mac_length,
+    void * ephemeral_key, size_t ephemeral_key_size,
+    const void * random, size_t random_length
+) {
+    if (ciphertext_length
+            < olm_pk_ciphertext_length(encryption, plaintext_length)
+        || mac_length
+            < _olm_cipher_aes_sha_256_ops.mac_length(olm_pk_cipher)
+        || ephemeral_key_size
+            < olm_pk_key_length()) {
+        encryption->last_error =
+            OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+        return std::size_t(-1);
+    }
+    if (random_length < olm_pk_encrypt_random_length(encryption)) {
+        encryption->last_error =
+            OlmErrorCode::OLM_NOT_ENOUGH_RANDOM;
+        return std::size_t(-1);
+    }
+
+    _olm_curve25519_key_pair ephemeral_keypair;
+    _olm_crypto_curve25519_generate_key((const uint8_t *) random, &ephemeral_keypair);
+    olm::encode_base64(
+        (const uint8_t *)ephemeral_keypair.public_key.public_key,
+        CURVE25519_KEY_LENGTH,
+        (uint8_t *)ephemeral_key
+    );
+
+    olm::SharedKey secret;
+    _olm_crypto_curve25519_shared_secret(&ephemeral_keypair, &encryption->recipient_key, secret);
+    size_t raw_ciphertext_length =
+        _olm_cipher_aes_sha_256_ops.encrypt_ciphertext_length(olm_pk_cipher, plaintext_length);
+    uint8_t *ciphertext_pos = (uint8_t *) ciphertext + ciphertext_length - raw_ciphertext_length;
+    uint8_t raw_mac[MAC_LENGTH];
+    size_t result = _olm_cipher_aes_sha_256_ops.encrypt(
+        olm_pk_cipher,
+        secret, sizeof(secret),
+        (const uint8_t *) plaintext, plaintext_length,
+        (uint8_t *) ciphertext_pos, raw_ciphertext_length,
+        (uint8_t *) raw_mac, MAC_LENGTH
+    );
+    if (result != std::size_t(-1)) {
+        olm::encode_base64(raw_mac, MAC_LENGTH, (uint8_t *)mac);
+        olm::encode_base64(ciphertext_pos, raw_ciphertext_length, (uint8_t *)ciphertext);
+    }
+    return result;
+}
+
+struct OlmPkDecryption {
+    OlmErrorCode last_error;
+    _olm_curve25519_key_pair key_pair;
+};
+
+const char * olm_pk_decryption_last_error(
+    const OlmPkDecryption * decryption
+) {
+    auto error = decryption->last_error;
+    return _olm_error_to_string(error);
+}
+
+OlmErrorCode olm_pk_decryption_last_error_code(
+    const OlmPkDecryption * decryption
+) {
+    return decryption->last_error;
+}
+
+size_t olm_pk_decryption_size(void) {
+    return sizeof(OlmPkDecryption);
+}
+
+OlmPkDecryption *olm_pk_decryption(
+    void * memory
+) {
+    olm::unset(memory, sizeof(OlmPkDecryption));
+    return new(memory) OlmPkDecryption;
+}
+
+size_t olm_clear_pk_decryption(
+    OlmPkDecryption *decryption
+) {
+    /* Clear the memory backing the decryption */
+    olm::unset(decryption, sizeof(OlmPkDecryption));
+    /* Initialise a fresh decryption object in case someone tries to use it */
+    new(decryption) OlmPkDecryption();
+    return sizeof(OlmPkDecryption);
+}
+
+size_t olm_pk_private_key_length(void) {
+    return CURVE25519_KEY_LENGTH;
+}
+
+size_t olm_pk_generate_key_random_length(void) {
+    return olm_pk_private_key_length();
+}
+
+size_t olm_pk_key_length(void) {
+    return olm::encode_base64_length(CURVE25519_KEY_LENGTH);
+}
+
+size_t olm_pk_key_from_private(
+    OlmPkDecryption * decryption,
+    void * pubkey, size_t pubkey_length,
+    const void * privkey, size_t privkey_length
+) {
+    if (pubkey_length < olm_pk_key_length()) {
+        decryption->last_error =
+            OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+        return std::size_t(-1);
+    }
+    if (privkey_length < olm_pk_private_key_length()) {
+        decryption->last_error =
+            OlmErrorCode::OLM_INPUT_BUFFER_TOO_SMALL;
+        return std::size_t(-1);
+    }
+
+    _olm_crypto_curve25519_generate_key((const uint8_t *) privkey, &decryption->key_pair);
+    olm::encode_base64(
+        (const uint8_t *)decryption->key_pair.public_key.public_key,
+        CURVE25519_KEY_LENGTH,
+        (uint8_t *)pubkey
+    );
+    return 0;
+}
+
+size_t olm_pk_generate_key(
+    OlmPkDecryption * decryption,
+    void * pubkey, size_t pubkey_length,
+    const void * privkey, size_t privkey_length
+) {
+    return olm_pk_key_from_private(decryption, pubkey, pubkey_length, privkey, privkey_length);
+}
+
+namespace {
+    static const std::uint32_t PK_DECRYPTION_PICKLE_VERSION = 1;
+
+    static std::size_t pickle_length(
+        OlmPkDecryption const & value
+    ) {
+        std::size_t length = 0;
+        length += olm::pickle_length(PK_DECRYPTION_PICKLE_VERSION);
+        length += olm::pickle_length(value.key_pair);
+        return length;
+    }
+
+
+    static std::uint8_t * pickle(
+        std::uint8_t * pos,
+        OlmPkDecryption const & value
+    ) {
+        pos = olm::pickle(pos, PK_DECRYPTION_PICKLE_VERSION);
+        pos = olm::pickle(pos, value.key_pair);
+        return pos;
+    }
+
+
+    static std::uint8_t const * unpickle(
+        std::uint8_t const * pos, std::uint8_t const * end,
+        OlmPkDecryption & value
+    ) {
+        uint32_t pickle_version;
+        pos = olm::unpickle(pos, end, pickle_version); UNPICKLE_OK(pos);
+
+        switch (pickle_version) {
+        case 1:
+            break;
+
+        default:
+            value.last_error = OlmErrorCode::OLM_UNKNOWN_PICKLE_VERSION;
+            return nullptr;
+        }
+
+        pos = olm::unpickle(pos, end, value.key_pair); UNPICKLE_OK(pos);
+
+        return pos;
+    }
+}
+
+size_t olm_pickle_pk_decryption_length(
+    const OlmPkDecryption * decryption
+) {
+    return _olm_enc_output_length(pickle_length(*decryption));
+}
+
+size_t olm_pickle_pk_decryption(
+    OlmPkDecryption * decryption,
+    void const * key, size_t key_length,
+    void *pickled, size_t pickled_length
+) {
+    OlmPkDecryption & object = *decryption;
+    std::size_t raw_length = pickle_length(object);
+    if (pickled_length < _olm_enc_output_length(raw_length)) {
+        object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+        return std::size_t(-1);
+    }
+    pickle(_olm_enc_output_pos(reinterpret_cast<std::uint8_t *>(pickled), raw_length), object);
+    return _olm_enc_output(
+        reinterpret_cast<std::uint8_t const *>(key), key_length,
+        reinterpret_cast<std::uint8_t *>(pickled), raw_length
+    );
+}
+
+size_t olm_unpickle_pk_decryption(
+    OlmPkDecryption * decryption,
+    void const * key, size_t key_length,
+    void *pickled, size_t pickled_length,
+    void *pubkey, size_t pubkey_length
+) {
+    OlmPkDecryption & object = *decryption;
+    if (pubkey != NULL && pubkey_length < olm_pk_key_length()) {
+        object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+        return std::size_t(-1);
+    }
+    std::uint8_t * const input = reinterpret_cast<std::uint8_t *>(pickled);
+    std::size_t raw_length = _olm_enc_input(
+        reinterpret_cast<std::uint8_t const *>(key), key_length,
+        input, pickled_length, &object.last_error
+    );
+    if (raw_length == std::size_t(-1)) {
+        return std::size_t(-1);
+    }
+
+    std::uint8_t const * pos = input;
+    std::uint8_t const * end = pos + raw_length;
+
+    pos = unpickle(pos, end, object);
+
+    if (!pos) {
+        /* Input was corrupted. */
+        if (object.last_error == OlmErrorCode::OLM_SUCCESS) {
+            object.last_error = OlmErrorCode::OLM_CORRUPTED_PICKLE;
+        }
+        return std::size_t(-1);
+    } else if (pos != end) {
+        /* Input was longer than expected. */
+        object.last_error = OlmErrorCode::OLM_PICKLE_EXTRA_DATA;
+        return std::size_t(-1);
+    }
+
+    if (pubkey != NULL) {
+        olm::encode_base64(
+            (const uint8_t *)object.key_pair.public_key.public_key,
+            CURVE25519_KEY_LENGTH,
+            (uint8_t *)pubkey
+        );
+    }
+
+    return pickled_length;
+}
+
+size_t olm_pk_max_plaintext_length(
+    const OlmPkDecryption * decryption,
+    size_t ciphertext_length
+) {
+    return _olm_cipher_aes_sha_256_ops.decrypt_max_plaintext_length(
+        olm_pk_cipher, olm::decode_base64_length(ciphertext_length)
+    );
+}
+
+size_t olm_pk_decrypt(
+    OlmPkDecryption * decryption,
+    void const * ephemeral_key, size_t ephemeral_key_length,
+    void const * mac, size_t mac_length,
+    void * ciphertext, size_t ciphertext_length,
+    void * plaintext, size_t max_plaintext_length
+) {
+    if (max_plaintext_length
+            < olm_pk_max_plaintext_length(decryption, ciphertext_length)) {
+        decryption->last_error =
+            OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+        return std::size_t(-1);
+    }
+
+    size_t raw_ciphertext_length = olm::decode_base64_length(ciphertext_length);
+
+    if (ephemeral_key_length != olm::encode_base64_length(CURVE25519_KEY_LENGTH)
+        || mac_length != olm::encode_base64_length(MAC_LENGTH)
+        || raw_ciphertext_length == std::size_t(-1)) {
+        decryption->last_error = OlmErrorCode::OLM_INVALID_BASE64;
+        return std::size_t(-1);
+    }
+
+    struct _olm_curve25519_public_key ephemeral;
+    olm::decode_base64(
+        (const uint8_t*)ephemeral_key,
+        olm::encode_base64_length(CURVE25519_KEY_LENGTH),
+        (uint8_t *)ephemeral.public_key
+    );
+
+    olm::SharedKey secret;
+    _olm_crypto_curve25519_shared_secret(&decryption->key_pair, &ephemeral, secret);
+
+    uint8_t raw_mac[MAC_LENGTH];
+    olm::decode_base64(
+        (const uint8_t *)mac,
+        olm::encode_base64_length(MAC_LENGTH),
+        raw_mac
+    );
+
+    olm::decode_base64(
+        (const uint8_t *)ciphertext,
+        ciphertext_length,
+        (uint8_t *)ciphertext
+    );
+
+    size_t result = _olm_cipher_aes_sha_256_ops.decrypt(
+        olm_pk_cipher,
+        secret, sizeof(secret),
+        (uint8_t *) raw_mac, MAC_LENGTH,
+        (const uint8_t *) ciphertext, raw_ciphertext_length,
+        (uint8_t *) plaintext, max_plaintext_length
+    );
+    if (result == std::size_t(-1)) {
+        // we already checked the buffer sizes, so the only error that decrypt
+        // will return is if the MAC is incorrect
+        decryption->last_error =
+            OlmErrorCode::OLM_BAD_MESSAGE_MAC;
+        return std::size_t(-1);
+    } else {
+        return result;
+    }
+}
+
+size_t olm_pk_get_private_key(
+    OlmPkDecryption * decryption,
+    void *private_key, size_t private_key_length
+) {
+    if (private_key_length < olm_pk_private_key_length()) {
+        decryption->last_error =
+            OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+        return std::size_t(-1);
+    }
+    std::memcpy(
+        private_key,
+        decryption->key_pair.private_key.private_key,
+        olm_pk_private_key_length()
+    );
+    return olm_pk_private_key_length();
+}
+
+struct OlmPkSigning {
+    OlmErrorCode last_error;
+    _olm_ed25519_key_pair key_pair;
+};
+
+size_t olm_pk_signing_size(void) {
+    return sizeof(OlmPkSigning);
+}
+
+OlmPkSigning *olm_pk_signing(void * memory) {
+    olm::unset(memory, sizeof(OlmPkSigning));
+    return new(memory) OlmPkSigning;
+}
+
+const char * olm_pk_signing_last_error(const OlmPkSigning * sign) {
+    auto error = sign->last_error;
+    return _olm_error_to_string(error);
+}
+
+OlmErrorCode olm_pk_signing_last_error_code(const OlmPkSigning * sign) {
+    return sign->last_error;
+}
+
+size_t olm_clear_pk_signing(OlmPkSigning *sign) {
+    /* Clear the memory backing the signing */
+    olm::unset(sign, sizeof(OlmPkSigning));
+    /* Initialise a fresh signing object in case someone tries to use it */
+    new(sign) OlmPkSigning();
+    return sizeof(OlmPkSigning);
+}
+
+size_t olm_pk_signing_seed_length(void) {
+    return ED25519_RANDOM_LENGTH;
+}
+
+size_t olm_pk_signing_public_key_length(void) {
+    return olm::encode_base64_length(ED25519_PUBLIC_KEY_LENGTH);
+}
+
+size_t olm_pk_signing_key_from_seed(
+    OlmPkSigning * signing,
+    void * pubkey, size_t pubkey_length,
+    const void * seed, size_t seed_length
+) {
+    if (pubkey_length < olm_pk_signing_public_key_length()) {
+        signing->last_error =
+            OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+        return std::size_t(-1);
+    }
+    if (seed_length < olm_pk_signing_seed_length()) {
+        signing->last_error =
+            OlmErrorCode::OLM_INPUT_BUFFER_TOO_SMALL;
+        return std::size_t(-1);
+    }
+
+    _olm_crypto_ed25519_generate_key((const uint8_t *) seed, &signing->key_pair);
+    olm::encode_base64(
+        (const uint8_t *)signing->key_pair.public_key.public_key,
+        ED25519_PUBLIC_KEY_LENGTH,
+        (uint8_t *)pubkey
+    );
+    return 0;
+}
+
+size_t olm_pk_signature_length(void) {
+    return olm::encode_base64_length(ED25519_SIGNATURE_LENGTH);
+}
+
+size_t olm_pk_sign(
+    OlmPkSigning *signing,
+    uint8_t const * message, size_t message_length,
+    uint8_t * signature, size_t signature_length
+) {
+    if (signature_length < olm_pk_signature_length()) {
+        signing->last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+        return std::size_t(-1);
+    }
+    uint8_t *raw_sig = signature + olm_pk_signature_length() - ED25519_SIGNATURE_LENGTH;
+    _olm_crypto_ed25519_sign(
+        &signing->key_pair,
+        message, message_length, raw_sig
+    );
+    olm::encode_base64(raw_sig, ED25519_SIGNATURE_LENGTH, signature);
+    return olm_pk_signature_length();
+}
+
+}