]> gitweb.ps.run Git - matrix_esp_thesis/blob - ext/olm/src/ratchet.cpp
changes to olm and esp
[matrix_esp_thesis] / ext / olm / src / ratchet.cpp
1 /* Copyright 2015, 2016 OpenMarket Ltd
2  *
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #include "olm/ratchet.hh"
16 #include "olm/message.hh"
17 #include "olm/memory.hh"
18 #include "olm/cipher.h"
19 #include "olm/pickle.hh"
20
21 #include <cstring>
22
23 namespace {
24
/* Version value written into outgoing messages and required of incoming
 * ones. */
static const std::uint8_t PROTOCOL_VERSION = 0x03;
/* Single-byte HMAC input used when deriving a message key from a chain key. */
static const std::uint8_t MESSAGE_KEY_SEED[] = {0x01};
/* Single-byte HMAC input used when stepping a chain key to its successor. */
static const std::uint8_t CHAIN_KEY_SEED[] = {0x02};
/* Upper bound on how far ahead of a chain a message counter may be before we
 * refuse to compute the intermediate hashes. */
static const std::size_t MAX_MESSAGE_GAP = 2000;
29
30
31 /**
32  * Advance the root key, creating a new message chain.
33  *
34  * @param root_key            previous root key R(n-1)
35  * @param our_key             our new ratchet key T(n)
36  * @param their_key           their most recent ratchet key T(n-1)
37  * @param info                table of constants for the ratchet function
38  * @param new_root_key[out]   returns the new root key R(n)
39  * @param new_chain_key[out]  returns the first chain key in the new chain
40  *                            C(n,0)
41  */
42 static void create_chain_key(
43     olm::SharedKey const & root_key,
44     _olm_curve25519_key_pair const & our_key,
45     _olm_curve25519_public_key const & their_key,
46     olm::KdfInfo const & info,
47     olm::SharedKey & new_root_key,
48     olm::ChainKey & new_chain_key
49 ) {
50     olm::SharedKey secret;
51     _olm_crypto_curve25519_shared_secret(&our_key, &their_key, secret);
52     std::uint8_t derived_secrets[2 * olm::OLM_SHARED_KEY_LENGTH];
53     _olm_crypto_hkdf_sha256(
54         secret, sizeof(secret),
55         root_key, sizeof(root_key),
56         info.ratchet_info, info.ratchet_info_length,
57         derived_secrets, sizeof(derived_secrets)
58     );
59     std::uint8_t const * pos = derived_secrets;
60     pos = olm::load_array(new_root_key, pos);
61     pos = olm::load_array(new_chain_key.key, pos);
62     new_chain_key.index = 0;
63     olm::unset(derived_secrets);
64     olm::unset(secret);
65 }
66
67
68 static void advance_chain_key(
69     olm::ChainKey const & chain_key,
70     olm::ChainKey & new_chain_key
71 ) {
72     _olm_crypto_hmac_sha256(
73         chain_key.key, sizeof(chain_key.key),
74         CHAIN_KEY_SEED, sizeof(CHAIN_KEY_SEED),
75         new_chain_key.key
76     );
77     new_chain_key.index = chain_key.index + 1;
78 }
79
80
81 static void create_message_keys(
82     olm::ChainKey const & chain_key,
83     olm::KdfInfo const & info,
84     olm::MessageKey & message_key) {
85     _olm_crypto_hmac_sha256(
86         chain_key.key, sizeof(chain_key.key),
87         MESSAGE_KEY_SEED, sizeof(MESSAGE_KEY_SEED),
88         message_key.key
89     );
90     message_key.index = chain_key.index;
91 }
92
93
94 static std::size_t verify_mac_and_decrypt(
95     _olm_cipher const *cipher,
96     olm::MessageKey const & message_key,
97     olm::MessageReader const & reader,
98     std::uint8_t * plaintext, std::size_t max_plaintext_length
99 ) {
100     return cipher->ops->decrypt(
101         cipher,
102         message_key.key, sizeof(message_key.key),
103         reader.input, reader.input_length,
104         reader.ciphertext, reader.ciphertext_length,
105         plaintext, max_plaintext_length
106     );
107 }
108
109
110 static std::size_t verify_mac_and_decrypt_for_existing_chain(
111     olm::Ratchet const & session,
112     olm::ChainKey const & chain,
113     olm::MessageReader const & reader,
114     std::uint8_t * plaintext, std::size_t max_plaintext_length
115 ) {
116     if (reader.counter < chain.index) {
117         return std::size_t(-1);
118     }
119
120     /* Limit the number of hashes we're prepared to compute */
121     if (reader.counter - chain.index > MAX_MESSAGE_GAP) {
122         return std::size_t(-1);
123     }
124
125     olm::ChainKey new_chain = chain;
126
127     while (new_chain.index < reader.counter) {
128         advance_chain_key(new_chain, new_chain);
129     }
130
131     olm::MessageKey message_key;
132     create_message_keys(new_chain, session.kdf_info, message_key);
133
134     std::size_t result = verify_mac_and_decrypt(
135         session.ratchet_cipher, message_key, reader,
136         plaintext, max_plaintext_length
137     );
138
139     olm::unset(new_chain);
140     return result;
141 }
142
143
144 static std::size_t verify_mac_and_decrypt_for_new_chain(
145     olm::Ratchet const & session,
146     olm::MessageReader const & reader,
147     std::uint8_t * plaintext, std::size_t max_plaintext_length
148 ) {
149     olm::SharedKey new_root_key;
150     olm::ReceiverChain new_chain;
151
152     /* They shouldn't move to a new chain until we've sent them a message
153      * acknowledging the last one */
154     if (session.sender_chain.empty()) {
155         return std::size_t(-1);
156     }
157
158     /* Limit the number of hashes we're prepared to compute */
159     if (reader.counter > MAX_MESSAGE_GAP) {
160         return std::size_t(-1);
161     }
162     olm::load_array(new_chain.ratchet_key.public_key, reader.ratchet_key);
163
164     create_chain_key(
165         session.root_key, session.sender_chain[0].ratchet_key,
166         new_chain.ratchet_key, session.kdf_info,
167         new_root_key, new_chain.chain_key
168     );
169     std::size_t result = verify_mac_and_decrypt_for_existing_chain(
170         session, new_chain.chain_key, reader,
171         plaintext, max_plaintext_length
172     );
173     olm::unset(new_root_key);
174     olm::unset(new_chain);
175     return result;
176 }
177
178 } // namespace
179
180
181 olm::Ratchet::Ratchet(
182     olm::KdfInfo const & kdf_info,
183     _olm_cipher const * ratchet_cipher
184 ) : kdf_info(kdf_info),
185     ratchet_cipher(ratchet_cipher),
186     last_error(OlmErrorCode::OLM_SUCCESS) {
187 }
188
189
190 void olm::Ratchet::initialise_as_bob(
191     std::uint8_t const * shared_secret, std::size_t shared_secret_length,
192     _olm_curve25519_public_key const & their_ratchet_key
193 ) {
194     std::uint8_t derived_secrets[2 * olm::OLM_SHARED_KEY_LENGTH];
195     _olm_crypto_hkdf_sha256(
196         shared_secret, shared_secret_length,
197         nullptr, 0,
198         kdf_info.root_info, kdf_info.root_info_length,
199         derived_secrets, sizeof(derived_secrets)
200     );
201     receiver_chains.insert();
202     receiver_chains[0].chain_key.index = 0;
203     std::uint8_t const * pos = derived_secrets;
204     pos = olm::load_array(root_key, pos);
205     pos = olm::load_array(receiver_chains[0].chain_key.key, pos);
206     receiver_chains[0].ratchet_key = their_ratchet_key;
207     olm::unset(derived_secrets);
208 }
209
210
211 void olm::Ratchet::initialise_as_alice(
212     std::uint8_t const * shared_secret, std::size_t shared_secret_length,
213     _olm_curve25519_key_pair const & our_ratchet_key
214 ) {
215     std::uint8_t derived_secrets[2 * olm::OLM_SHARED_KEY_LENGTH];
216     _olm_crypto_hkdf_sha256(
217         shared_secret, shared_secret_length,
218         nullptr, 0,
219         kdf_info.root_info, kdf_info.root_info_length,
220         derived_secrets, sizeof(derived_secrets)
221     );
222     sender_chain.insert();
223     sender_chain[0].chain_key.index = 0;
224     std::uint8_t const * pos = derived_secrets;
225     pos = olm::load_array(root_key, pos);
226     pos = olm::load_array(sender_chain[0].chain_key.key, pos);
227     sender_chain[0].ratchet_key = our_ratchet_key;
228     olm::unset(derived_secrets);
229 }
230
231 namespace olm {
232
233
234 static std::size_t pickle_length(
235     const olm::SharedKey & value
236 ) {
237     return olm::OLM_SHARED_KEY_LENGTH;
238 }
239
240
241 static std::uint8_t * pickle(
242     std::uint8_t * pos,
243     const olm::SharedKey & value
244 ) {
245     return olm::pickle_bytes(pos, value, olm::OLM_SHARED_KEY_LENGTH);
246 }
247
248
249 static std::uint8_t const * unpickle(
250     std::uint8_t const * pos, std::uint8_t const * end,
251     olm::SharedKey & value
252 ) {
253     return olm::unpickle_bytes(pos, end, value, olm::OLM_SHARED_KEY_LENGTH);
254 }
255
256
257 static std::size_t pickle_length(
258     const olm::SenderChain & value
259 ) {
260     std::size_t length = 0;
261     length += olm::pickle_length(value.ratchet_key);
262     length += olm::pickle_length(value.chain_key.key);
263     length += olm::pickle_length(value.chain_key.index);
264     return length;
265 }
266
267
268 static std::uint8_t * pickle(
269     std::uint8_t * pos,
270     const olm::SenderChain & value
271 ) {
272     pos = olm::pickle(pos, value.ratchet_key);
273     pos = olm::pickle(pos, value.chain_key.key);
274     pos = olm::pickle(pos, value.chain_key.index);
275     return pos;
276 }
277
278
279 static std::uint8_t const * unpickle(
280     std::uint8_t const * pos, std::uint8_t const * end,
281     olm::SenderChain & value
282 ) {
283     pos = olm::unpickle(pos, end, value.ratchet_key); UNPICKLE_OK(pos);
284     pos = olm::unpickle(pos, end, value.chain_key.key); UNPICKLE_OK(pos);
285     pos = olm::unpickle(pos, end, value.chain_key.index); UNPICKLE_OK(pos);
286     return pos;
287 }
288
289 static std::size_t pickle_length(
290     const olm::ReceiverChain & value
291 ) {
292     std::size_t length = 0;
293     length += olm::pickle_length(value.ratchet_key);
294     length += olm::pickle_length(value.chain_key.key);
295     length += olm::pickle_length(value.chain_key.index);
296     return length;
297 }
298
299
300 static std::uint8_t * pickle(
301     std::uint8_t * pos,
302     const olm::ReceiverChain & value
303 ) {
304     pos = olm::pickle(pos, value.ratchet_key);
305     pos = olm::pickle(pos, value.chain_key.key);
306     pos = olm::pickle(pos, value.chain_key.index);
307     return pos;
308 }
309
310
311 static std::uint8_t const * unpickle(
312     std::uint8_t const * pos, std::uint8_t const * end,
313     olm::ReceiverChain & value
314 ) {
315     pos = olm::unpickle(pos, end, value.ratchet_key); UNPICKLE_OK(pos);
316     pos = olm::unpickle(pos, end, value.chain_key.key); UNPICKLE_OK(pos);
317     pos = olm::unpickle(pos, end, value.chain_key.index); UNPICKLE_OK(pos);
318     return pos;
319 }
320
321
322 static std::size_t pickle_length(
323     const olm::SkippedMessageKey & value
324 ) {
325     std::size_t length = 0;
326     length += olm::pickle_length(value.ratchet_key);
327     length += olm::pickle_length(value.message_key.key);
328     length += olm::pickle_length(value.message_key.index);
329     return length;
330 }
331
332
333 static std::uint8_t * pickle(
334     std::uint8_t * pos,
335     const olm::SkippedMessageKey & value
336 ) {
337     pos = olm::pickle(pos, value.ratchet_key);
338     pos = olm::pickle(pos, value.message_key.key);
339     pos = olm::pickle(pos, value.message_key.index);
340     return pos;
341 }
342
343
344 static std::uint8_t const * unpickle(
345     std::uint8_t const * pos, std::uint8_t const * end,
346     olm::SkippedMessageKey & value
347 ) {
348     pos = olm::unpickle(pos, end, value.ratchet_key); UNPICKLE_OK(pos);
349     pos = olm::unpickle(pos, end, value.message_key.key); UNPICKLE_OK(pos);
350     pos = olm::unpickle(pos, end, value.message_key.index); UNPICKLE_OK(pos);
351     return pos;
352 }
353
354
355 } // namespace olm
356
357
358 std::size_t olm::pickle_length(
359     olm::Ratchet const & value
360 ) {
361     std::size_t length = 0;
362     length += olm::OLM_SHARED_KEY_LENGTH;
363     length += olm::pickle_length(value.sender_chain);
364     length += olm::pickle_length(value.receiver_chains);
365     length += olm::pickle_length(value.skipped_message_keys);
366     return length;
367 }
368
369 std::uint8_t * olm::pickle(
370     std::uint8_t * pos,
371     olm::Ratchet const & value
372 ) {
373     pos = pickle(pos, value.root_key);
374     pos = pickle(pos, value.sender_chain);
375     pos = pickle(pos, value.receiver_chains);
376     pos = pickle(pos, value.skipped_message_keys);
377     return pos;
378 }
379
380
381 std::uint8_t const * olm::unpickle(
382     std::uint8_t const * pos, std::uint8_t const * end,
383     olm::Ratchet & value,
384     bool includes_chain_index
385 ) {
386     pos = unpickle(pos, end, value.root_key); UNPICKLE_OK(pos);
387     pos = unpickle(pos, end, value.sender_chain); UNPICKLE_OK(pos);
388     pos = unpickle(pos, end, value.receiver_chains); UNPICKLE_OK(pos);
389     pos = unpickle(pos, end, value.skipped_message_keys); UNPICKLE_OK(pos);
390
391     // pickle v 0x80000001 includes a chain index; pickle v1 does not.
392     if (includes_chain_index) {
393         std::uint32_t dummy;
394         pos = unpickle(pos, end, dummy); UNPICKLE_OK(pos);
395     }
396     return pos;
397 }
398
399
400 std::size_t olm::Ratchet::encrypt_output_length(
401     std::size_t plaintext_length
402 ) const {
403     std::size_t counter = 0;
404     if (!sender_chain.empty()) {
405         counter = sender_chain[0].chain_key.index;
406     }
407     std::size_t padded = ratchet_cipher->ops->encrypt_ciphertext_length(
408         ratchet_cipher,
409         plaintext_length
410     );
411     return olm::encode_message_length(
412         counter, CURVE25519_KEY_LENGTH, padded, ratchet_cipher->ops->mac_length(ratchet_cipher)
413     );
414 }
415
416
417 std::size_t olm::Ratchet::encrypt_random_length() const {
418     return sender_chain.empty() ? CURVE25519_RANDOM_LENGTH : 0;
419 }
420
421
422 std::size_t olm::Ratchet::encrypt(
423     std::uint8_t const * plaintext, std::size_t plaintext_length,
424     std::uint8_t const * random, std::size_t random_length,
425     std::uint8_t * output, std::size_t max_output_length
426 ) {
427     std::size_t output_length = encrypt_output_length(plaintext_length);
428
429     if (random_length < encrypt_random_length()) {
430         last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM;
431         return std::size_t(-1);
432     }
433     if (max_output_length < output_length) {
434         last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
435         return std::size_t(-1);
436     }
437
438     if (sender_chain.empty()) {
439         sender_chain.insert();
440         _olm_crypto_curve25519_generate_key(random, &sender_chain[0].ratchet_key);
441         create_chain_key(
442             root_key,
443             sender_chain[0].ratchet_key,
444             receiver_chains[0].ratchet_key,
445             kdf_info,
446             root_key, sender_chain[0].chain_key
447         );
448     }
449
450     MessageKey keys;
451     create_message_keys(sender_chain[0].chain_key, kdf_info, keys);
452     advance_chain_key(sender_chain[0].chain_key, sender_chain[0].chain_key);
453
454     std::size_t ciphertext_length = ratchet_cipher->ops->encrypt_ciphertext_length(
455         ratchet_cipher,
456         plaintext_length
457     );
458     std::uint32_t counter = keys.index;
459     _olm_curve25519_public_key const & ratchet_key =
460         sender_chain[0].ratchet_key.public_key;
461
462     olm::MessageWriter writer;
463
464     olm::encode_message(
465         writer, PROTOCOL_VERSION, counter, CURVE25519_KEY_LENGTH,
466         ciphertext_length,
467         output
468     );
469
470     olm::store_array(writer.ratchet_key, ratchet_key.public_key);
471
472     ratchet_cipher->ops->encrypt(
473         ratchet_cipher,
474         keys.key, sizeof(keys.key),
475         plaintext, plaintext_length,
476         writer.ciphertext, ciphertext_length,
477         output, output_length
478     );
479
480     olm::unset(keys);
481     return output_length;
482 }
483
484
485 std::size_t olm::Ratchet::decrypt_max_plaintext_length(
486     std::uint8_t const * input, std::size_t input_length
487 ) {
488     olm::MessageReader reader;
489     olm::decode_message(
490         reader, input, input_length,
491         ratchet_cipher->ops->mac_length(ratchet_cipher)
492     );
493
494     if (!reader.ciphertext) {
495         last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
496         return std::size_t(-1);
497     }
498
499     return ratchet_cipher->ops->decrypt_max_plaintext_length(
500         ratchet_cipher, reader.ciphertext_length);
501 }
502
503
504 std::size_t olm::Ratchet::decrypt(
505     std::uint8_t const * input, std::size_t input_length,
506     std::uint8_t * plaintext, std::size_t max_plaintext_length
507 ) {
508     olm::MessageReader reader;
509     olm::decode_message(
510         reader, input, input_length,
511         ratchet_cipher->ops->mac_length(ratchet_cipher)
512     );
513
514     if (reader.version != PROTOCOL_VERSION) {
515         last_error = OlmErrorCode::OLM_BAD_MESSAGE_VERSION;
516         return std::size_t(-1);
517     }
518
519     if (!reader.has_counter || !reader.ratchet_key || !reader.ciphertext) {
520         last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
521         return std::size_t(-1);
522     }
523
524     std::size_t max_length = ratchet_cipher->ops->decrypt_max_plaintext_length(
525         ratchet_cipher,
526         reader.ciphertext_length
527     );
528
529     if (max_plaintext_length < max_length) {
530         last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
531         return std::size_t(-1);
532     }
533
534     if (reader.ratchet_key_length != CURVE25519_KEY_LENGTH) {
535         last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
536         return std::size_t(-1);
537     }
538
539     ReceiverChain * chain = nullptr;
540
541     for (olm::ReceiverChain & receiver_chain : receiver_chains) {
542         if (0 == std::memcmp(
543                 receiver_chain.ratchet_key.public_key, reader.ratchet_key,
544                 CURVE25519_KEY_LENGTH
545         )) {
546             chain = &receiver_chain;
547             break;
548         }
549     }
550
551     std::size_t result = std::size_t(-1);
552
553     if (!chain) {
554         result = verify_mac_and_decrypt_for_new_chain(
555             *this, reader, plaintext, max_plaintext_length
556         );
557     } else if (chain->chain_key.index > reader.counter) {
558         /* Chain already advanced beyond the key for this message
559          * Check if the message keys are in the skipped key list. */
560         for (olm::SkippedMessageKey & skipped : skipped_message_keys) {
561             if (reader.counter == skipped.message_key.index
562                     && 0 == std::memcmp(
563                         skipped.ratchet_key.public_key, reader.ratchet_key,
564                         CURVE25519_KEY_LENGTH
565                     )
566             ) {
567                 /* Found the key for this message. Check the MAC. */
568
569                 result = verify_mac_and_decrypt(
570                     ratchet_cipher, skipped.message_key, reader,
571                     plaintext, max_plaintext_length
572                 );
573
574                 if (result != std::size_t(-1)) {
575                     /* Remove the key from the skipped keys now that we've
576                      * decoded the message it corresponds to. */
577                     olm::unset(skipped);
578                     skipped_message_keys.erase(&skipped);
579                     return result;
580                 }
581             }
582         }
583     } else {
584         result = verify_mac_and_decrypt_for_existing_chain(
585             *this, chain->chain_key,
586             reader, plaintext, max_plaintext_length
587         );
588     }
589
590     if (result == std::size_t(-1)) {
591         last_error = OlmErrorCode::OLM_BAD_MESSAGE_MAC;
592         return std::size_t(-1);
593     }
594
595     if (!chain) {
596         /* They have started using a new ephemeral ratchet key.
597          * We need to derive a new set of chain keys.
598          * We can discard our previous ephemeral ratchet key.
599          * We will generate a new key when we send the next message. */
600
601         chain = receiver_chains.insert();
602         olm::load_array(chain->ratchet_key.public_key, reader.ratchet_key);
603
604         // TODO: we've already done this once, in
605         // verify_mac_and_decrypt_for_new_chain(). we could reuse the result.
606         create_chain_key(
607             root_key, sender_chain[0].ratchet_key, chain->ratchet_key,
608             kdf_info, root_key, chain->chain_key
609         );
610
611         olm::unset(sender_chain[0]);
612         sender_chain.erase(sender_chain.begin());
613     }
614
615     while (chain->chain_key.index < reader.counter) {
616         olm::SkippedMessageKey & key = *skipped_message_keys.insert();
617         create_message_keys(chain->chain_key, kdf_info, key.message_key);
618         key.ratchet_key = chain->ratchet_key;
619         advance_chain_key(chain->chain_key, chain->chain_key);
620     }
621
622     advance_chain_key(chain->chain_key, chain->chain_key);
623
624     return result;
625 }