There are two reasonable strategies:
- symmetric crypto + database to speed up search
- relies on nonces being mostly sequential and not too many messages getting lost
- asymmetric crypto
Note: both overhead figures below include the 4-byte device ID, which doesn't have to be included in the plaintext.
Symmetric is very compelling due to its low overhead: 24 bytes (16-byte nonce+deviceID hint, 8-byte MAC), or 16 bytes total if you can find a fast 64-bit block cipher for the hint. You get the sequence number for free here.
Asymmetric has 44 bytes of overhead (32-byte point + 4-byte deviceID + 8-byte MAC) but low code complexity on the base station/backend, since there's no database or search code.
Symmetric system
Devices keep a sequential nonce counter which starts at zero.
We derive a `k_nonce_hint` key from the main device symmetric key.
Each nonce is used to generate a `nonce_hint` byte string. This can be as simple as `nonce_hint(nonce) = AES_enc(nonce, k_nonce_hint)[:HINTLEN]` or a hash function.
Messages sent over the air look like `nonce_hint || AEAD(M, nonce)`, where AEAD is some encrypt+MAC scheme.
Note: the device ID and message sequence number are sent implicitly via the nonce hint. The MAC tag provides message integrity and further validates the device ID and nonce.
At the other end
For each (device, nonce) pair there's an associated nonce hint that can be stored in a database. Assuming not too many messages get lost, you only need to keep a small number of hints (10–100) for each device in the DB.
After receiving a message, the application server queries the DB for the matching (device, nonce) pair(s) and tries decrypting the message with the corresponding device key(s) and nonce(s). If that works, it updates the list of nonce hints for that device in the DB.
Note that if the nonce hint is derived reversibly (e.g. a 16-byte hint produced by AES-encrypting the zero-padded nonce, as in the code below), nonce hints can be checked against all device keys as a fallback option. A lightweight 64-bit block cipher would be perfect; I suggest SPECK.
Why This Works
Symmetric crypto and databases are cheap. Adding a DB and structuring it appropriately turns an $O(k)$ problem (try all $k$ device keys) into an $O(1)$ problem (database lookup). For sufficiently large values of $k$ this is worthwhile. If the hint is generated with a reversible keyed permutation, search can be very fast for the fallback case.
Code
from Crypto.Cipher import AES
from os import urandom
import struct
uint32_to_bytes = struct.Struct("!I").pack  # pack 4-byte big-endian unsigned integer


def bytes_to_uint(s):
    """Decode a big-endian unsigned integer from a bytes-like object.

    Inverse of uint32_to_bytes for 4-byte input. Any length is accepted and
    an empty input decodes to 0.
    """
    return int.from_bytes(s, "big")
import hashlib
def sha256(s):
    """Return the raw 32-byte SHA-256 digest of s."""
    return hashlib.sha256(s).digest()
import hmac
def sha256_hmac(k, s):
    """Return the raw 32-byte HMAC-SHA256 of message s under key k."""
    mac = hmac.new(k, s, "sha256")
    return mac.digest()
def AES_CTR(m, k, nonce):
    """Apply the AES-CTR keystream for key k and the given nonce to m.

    CTR mode is symmetric, so this one call is used both to encrypt and to
    decrypt.
    """
    cipher = AES.new(key=k, mode=AES.MODE_CTR, nonce=nonce)
    return cipher.encrypt(m)
def nhint_make(nonce, k):
    """Return the 16-byte nonce hint for nonce under the hint key k.

    The hint is one AES-ECB block encrypting 12 zero bytes followed by
    nonce (expected to be the 4-byte counter from uint32_to_bytes). The zero
    padding is what the backend checks to recognise a well-formed hint.
    """
    padded = bytes(12) + nonce
    cipher = AES.new(key=k, mode=AES.MODE_ECB)
    return cipher.encrypt(padded)
class device:
    """Simulated device for the symmetric nonce-hint scheme.

    dev_key: the first 16 bytes are the AES-CTR key; the remaining bytes are
        the HMAC key.
    dev_key_nhint: the key nhint_make uses to derive nonce hints.
    next_nonce: sequential message counter, starting at 0.
    """

    def __init__(self, dev_id, dev_key, dev_key_nhint):
        self.dev_id = dev_id
        self.dev_key = dev_key
        self.dev_key_nhint = dev_key_nhint
        self.next_nonce = 0

    def encrypt_message(self, m):
        """Encrypt m under the next nonce and return nonce_hint || ciphertext || tag.

        The tag is HMAC-SHA256 over nonce_hint || ciphertext, truncated to 8 bytes.
        """
        counter = self.next_nonce
        # uint32_to_bytes raises once the 4-byte counter is exhausted, before
        # the counter is advanced.
        nonce_bytes = uint32_to_bytes(counter)
        hint = nhint_make(nonce_bytes, self.dev_key_nhint)
        self.next_nonce = counter + 1
        enc_key, mac_key = self.dev_key[:16], self.dev_key[16:]
        ct = AES_CTR(m, enc_key, bytes(4) + nonce_bytes)
        tag = sha256_hmac(mac_key, hint + ct)[:8]
        return b"".join((hint, ct, tag))
class backend():
    """Application server for the symmetric nonce-hint scheme.

    Derives every device's keys from one master secret. Keeps a window of
    expected nonce hints per device in ``hint_db``. Decrypts incoming messages
    by hint lookup, falling back to trying every device's hint key.
    """

    # Number of upcoming nonces per device whose hints are kept in hint_db.
    HINT_RUN_LENGTH = 10
    # Nonces are 4-byte counters (uint32_to_bytes), so no hint can be derived
    # for a nonce at or beyond this value.
    _NONCE_LIMIT = 1 << 32

    def __init__(self):
        self.master_secret = urandom(32)  # used to derive device keys
        # FIXME: consider using a real database or explicitly designed data
        # structure; python dicts and strings are oversized for this.
        self.hint_db_meta = {}  # dev_id: (nonce_start, nonce_end)
        self.hint_db = {}  # nhint: dev_id
        # In theory, nhint collisions should be handled by storing both dev_ids
        # in a list, but collisions are infrequent and this saves a lot of
        # memory. decrypt() falls back to the exhaustive search when a DB hit
        # fails to authenticate, so an overwritten entry only costs speed.
        self.hint_keys_all = bytearray()  # dense list of 16-byte nhint keys for the fallback route
        self._next_device_ID = 0

    def dev_key(self, dev_id):
        """Return the 32-byte device key (AES-CTR key || HMAC key) for dev_id."""
        return sha256_hmac(self.master_secret, b"dev_key:" + dev_id)

    def dev_key_nhint(self, dev_id):
        """Return the 16-byte AES key used to derive dev_id's nonce hints."""
        return sha256_hmac(self.master_secret, b"nhint_key:" + dev_id)[:16]

    def provision_device(self):
        """Allocate the next sequential device ID and return a configured device.

        Registers the device's initial hint window starting at nonce 0. Also
        appends its hint key to hint_keys_all, where the key's position
        encodes the device ID (see fallback_nonce_candidates).
        """
        dev_id = uint32_to_bytes(self._next_device_ID)
        self._next_device_ID += 1
        self.update_hints(dev_id, 0)
        self.hint_keys_all += self.dev_key_nhint(dev_id)
        return device(dev_id, self.dev_key(dev_id), self.dev_key_nhint(dev_id))

    def update_hints(self, dev_id, start, end=None):
        """Move dev_id's hint window to nonces [start, end).

        end defaults to start + HINT_RUN_LENGTH. Only hints that leave the
        window are removed and only hints that enter it are derived.
        """
        if end is None:
            end = start + self.HINT_RUN_LENGTH
        assert end >= start
        # Clamp to the 4-byte nonce range so that accepting a device's final
        # nonce (2**32 - 1) doesn't crash while deriving hints past it.
        start = min(start, self._NONCE_LIMIT)
        end = min(end, self._NONCE_LIMIT)
        old_nonces = set(range(*self.hint_db_meta.get(dev_id, (0, 0))))
        self.hint_db_meta[dev_id] = (start, end)
        new_nonces = set(range(start, end))
        key_nhint = self.dev_key_nhint(dev_id)
        for i in old_nonces - new_nonces:
            nh = nhint_make(uint32_to_bytes(i), key_nhint)
            # Only drop the entry if it still points at this device: after a
            # hint collision it may belong to another device.
            if self.hint_db.get(nh) == dev_id:
                del self.hint_db[nh]
        for i in new_nonces - old_nonces:
            self.hint_db[nhint_make(uint32_to_bytes(i), key_nhint)] = dev_id

    def decrypt(self, data):
        """Authenticate and decrypt one over-the-air message.

        data is nonce_hint (16 bytes) || ciphertext || tag (8 bytes).
        Returns (dev_id, nonce, plaintext) and moves that device's hint window
        to start just after nonce.

        Raises ValueError if data is too short, if no device authenticates it,
        or if it authenticates with a nonce below the device's current window.
        """
        # Explicit check rather than assert: data is untrusted and asserts are
        # stripped under -O.
        if len(data) < 16 + 8:  # nhint_make output, tag
            raise ValueError("couldn't decrypt the data: message too short")
        nh, ct, tag = data[:16], data[16:-8], data[-8:]
        db_dev_id = self.hint_db.get(nh)
        if db_dev_id is not None:
            result = self._try_devices(nh, ct, tag, (db_dev_id,))
            if result is not None:
                return result
        else:
            print("warning:failed to look up hint, trying fallback")
        # Either the hint isn't in the DB (lost messages / desynchronisation),
        # or the DB entry didn't authenticate (hint collision): search every
        # device, skipping the one already tried.
        candidates = (
            d for d in self.fallback_nonce_candidates(nh) if d != db_dev_id
        )
        result = self._try_devices(nh, ct, tag, candidates)
        if result is not None:
            return result
        raise ValueError("couldn't decrypt the data")

    def _try_devices(self, nh, ct, tag, dev_ids):
        """Return (dev_id, nonce, plaintext) for the first device in dev_ids
        that authenticates the message, or None if none does."""
        for dev_id in dev_ids:
            # check the nonce hint is well formed for this device
            block = AES.new(key=self.dev_key_nhint(dev_id), mode=AES.MODE_ECB).decrypt(nh)
            if not block.startswith(b"\0" * 12):
                continue
            nb = block[12:]
            # check the tag
            key = self.dev_key(dev_id)
            tag_correct = sha256_hmac(key[16:], nh + ct)[:8]
            if not hmac.compare_digest(tag, tag_correct):
                continue
            nonce = bytes_to_uint(nb)
            # Each accepted message moves the window start to nonce + 1, so a
            # nonce below it is a replay or older than an accepted message.
            # Accepting it would also rewind the hint window.
            if nonce < self.hint_db_meta.get(dev_id, (0, 0))[0]:
                raise ValueError("couldn't decrypt the data: replayed or stale nonce")
            m = AES_CTR(ct, key[:16], b"\0" * 4 + nb)
            self.update_hints(dev_id, nonce + 1)
            return dev_id, nonce, m
        return None

    def fallback_nonce_candidates(self, nh):
        """Yield the dev_id of every provisioned device whose hint key decrypts
        nh to a well-formed block (12 zero bytes followed by the nonce).

        Relies on hint_keys_all holding one 16-byte key per device in
        provisioning (= ID) order, so a key's offset // 16 is its device ID.
        """
        for i in range(0, len(self.hint_keys_all), 16):
            key = self.hint_keys_all[i:i + 16]
            block = AES.new(key=key, mode=AES.MODE_ECB).decrypt(nh)
            if block.startswith(b"\0" * 12):
                yield uint32_to_bytes(i // 16)
if __name__=="__main__":
    # Demo/benchmark for the symmetric scheme: provision n devices, measure
    # the memory taken by their hint-DB entries, then exercise decrypt() on
    # the hint-lookup path and on the fallback (desynchronised) path.
    import time
    # timeinc() returns the time since its previous call as a "%.6f" seconds
    # string; the first call measures from when this def ran. The mutable
    # default list is used on purpose to keep the previous timestamp.
    def timeinc(a=[time.monotonic()]):then,a[0]=a[0],time.monotonic();return "%.6f"%(a[0]-then)
    # res is the (dev_id, nonce, message) tuple returned by backend.decrypt().
    def printres(res):print("result:{devID:%r seq:%i, message:%r}"%res)
    base=backend()
    n=20000
    print("provisioning %i devices ... "%n,end="")
    # This instance attribute shadows the class default, so provision_device()
    # registers an empty hint window for each device.
    base.HINT_RUN_LENGTH=0#don't populate hint db yet
    lots_of_devs=[base.provision_device() for i in range(n)]
    print(timeinc()+ "\nfilling hint DB ... ",end="")
    # psutil (third-party) is only needed for the memory measurement.
    import psutil
    process = psutil.Process()
    mem1=process.memory_info().rss
    # Removing the instance attribute re-exposes the class default window size.
    del base.HINT_RUN_LENGTH#populate afterwards
    for d in lots_of_devs:
        base.update_hints(d.dev_id, 0)
    mem2=process.memory_info().rss
    # The RSS delta approximates the hint DB's footprint; it also includes any
    # other allocations made in between.
    mem_used=mem2-mem1
    print(timeinc()+"\nmemory used:%i (%.2f/device,%.2f/hint)"%(mem_used,mem_used/n,mem_used/n/base.HINT_RUN_LENGTH))
    print("check round trip for first device")
    dev1=lots_of_devs[0]
    m=b"hello world!"
    encd=dev1.encrypt_message(m)
    #check it decrypts correctly
    result=base.decrypt(encd)
    printres(result)
    #result:{devID:b'\x00\x00\x00\x00' seq:0, message:b'hello world!'}
    print("1 message processed in "+timeinc())
    assert (result[0],result[2])==(dev1.dev_id,m)
    m=b"some telemetry"
    dev2=lots_of_devs[n//2]
    print("\ncheck round trip for device %i"%(n//2))
    # Every message is delivered, so each nonce's hint is already in the
    # device's DB window and decrypt() takes the lookup path.
    for i in range(1000):
        encd=dev2.encrypt_message(m)
        try:result=base.decrypt(encd)
        except Exception:
            # report which message failed, then propagate the error
            print(i)
            raise
        assert result[::2]==(dev2.dev_id,m)
    print("1k messages processed in "+timeinc())
    print("\ncheck desynchronisation recovery")
    # Encrypt 100 messages but drop all except the last (nonce 1099). That
    # nonce is outside the device's DB window (1000-1009), so decrypt() must
    # use the fallback search over all device hint keys.
    for i in range(100):
        encd=dev2.encrypt_message(m)
    print("100 messages discarded to desynchronise "+str(timeinc()));
    result=base.decrypt(encd)
    print("decryption done "+str(timeinc()))
    printres(result)
    #result:{devID:b"\x00\x00'\x10" seq:1099, message:b'some telemetry'}
Performance
I'm getting about 500 µs per encrypt/decrypt round trip, but this uses an in-memory dict as the database with 20k simulated devices × 10 hints. Memory usage is about 120 B per hint (× 10 hints × 20k devices = 24 MB), which is ~10× larger than needed.
The optimal data structure for this is a sorted list of `nonce_hint || device_ID` strings. These can be stored in a radix trie, allowing sorted chunks with common prefixes to drop a byte or two.
A static precomputed data structure can use sparse `nonce_hint` values to index into a dense array of `device_ID`s. This requires only a few bytes per entry, at the cost of needing to try a few dozen values per lookup.
When there's no database match for the hint, trying all device keys takes ~10µs per device. That's not going to get much faster since that's a single AES decryption and leading zeroes check on the result.
A block cipher without complex key setup could be a lot faster if vectorised. Maybe run it on a GPU? That plausibly gets you to 1B block cipher ops per second.
Time Based Nonce
If an attacker jams messages from one device, they can cause desynchronisation and subsequently sent nonce hints won't be in the database.
If the devices have a clock inside them, you could do something time-based (e.g. `32-bit day || 32-bit sequence counter`) to ensure resynchronisation, or go fully time-based with ~1 minute granularity. This gives an authenticated sending time for free. For more accurate timestamps, a few bytes of the message can hold a finer-grained time offset, allowing messages that use "future" time values to indicate the real sending time.
Some BLE location beacons use this to restrict use to paying customers. Beacons generate time based pseudorandom values every 15s or so and end-user devices submit these values to a commercial API. The service generates every expected value for the current time (±1 min to account for clock drift) and then does a lookup for submitted values. The cost to run such a service for a few million beacons is a fraction of a CPU core + <1KB RAM per device.
Security
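Message validity for random attacker-supplied garbage requires:
- a well-formed nonce hint
  - bits of security: $32 - \log_2(\mathrm{device\_count})$. This assumes a 64-bit hint block holding a 32-bit nonce plus 32 bits of zero padding. The 16-byte AES hint in the code above has 96 zero bits, giving $96 - \log_2(\mathrm{device\_count})$.
  - unless you make more than 4 billion ($2^{32}$) of these things, you won't have to cut into your MAC security margin
- a well-formed MAC tag
  - bits of security: $\mathrm{MAC\_bits}$ (whatever you decide on)
  - I suggest 64 bits
- total security margin is $\mathrm{MAC\_bits} + 32 - \log_2(\mathrm{device\_count})$
  - an attacker has a $2^{-\mathrm{margin}}$ chance of forging a message by guessing randomly
  - an attacker has a $2^{-\mathrm{MAC\_bits}}$ chance of modifying a legitimate message payload to forge a new valid message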
Note that if devices send 1M messages over their lifetime, that's 12 MB of DB space per device (about 12 bytes per hint). That is plausibly a 12 TB database for 1M devices, which isn't that big and allows decrypting all possible messages with just a DB lookup. Optimizations are possible, of course.
Asymmetric option
The cheapest secure option that meets your requirements would be unauthenticated ECDH encryption plus a symmetric MAC using per-device keys.
Here's some python code demonstrating this option. It's not very exciting.
from Crypto.Cipher import AES
from nacl import bindings
from os import urandom
import struct
_UINT32_BE = struct.Struct(">I")  # '>' and '!' are both big-endian, standard size
uint32_to_bytes = _UINT32_BE.pack  # pack 4 byte integer
import hashlib
def sha256(s):
    """Return the raw 32-byte SHA-256 digest of s."""
    return hashlib.sha256(s).digest()
import hmac
def sha256_hmac(k, s):
    """Return the raw 32-byte HMAC-SHA256 of message s under key k."""
    mac = hmac.new(k, s, "sha256")
    return mac.digest()
def AES_CTR(m, k, nonce=bytes(8)):
    """Apply the AES-CTR keystream for key k to m (encrypts and decrypts).

    nonce must be exactly 8 bytes; the all-zero default is used because
    every message in this scheme gets a fresh ECDH-derived key.
    """
    assert len(nonce) == 8
    cipher = AES.new(key=k, mode=AES.MODE_CTR, nonce=nonce)
    return cipher.encrypt(m)
class device:
    """Simulated device for the asymmetric (ephemeral ECDH) scheme.

    Holds the backend's public key, its own device ID, and its per-device
    MAC key.
    """

    def __init__(self, backend_pubkey, device_ID, device_key):
        self.backend_pubkey = backend_pubkey
        self.device_ID = device_ID
        self.device_key = device_key

    def encrypt_message(self, m):
        """Return eph_pubkey || AES-CTR(device_ID || m) || 8-byte tag.

        A fresh ephemeral keypair is generated per message. The tag is
        HMAC-SHA256 under the device key over session_key || ciphertext.
        """
        eph_public, eph_secret = bindings.crypto_kx_keypair()
        session_keys = bindings.crypto_kx_client_session_keys(
            eph_public, eph_secret, self.backend_pubkey
        )
        shared = session_keys[0]
        ciphertext = AES_CTR(self.device_ID + m, shared[:16])
        mac = sha256_hmac(self.device_key, shared + ciphertext)[:8]
        return b"".join((eph_public, ciphertext, mac))
class backend():
    """Application server for the asymmetric (ephemeral ECDH + per-device MAC) scheme.

    Holds a long-term crypto_kx keypair whose public half is given to every
    device. Derives each device's MAC key from a master secret rather than
    storing it.
    """

    def __init__(self):
        self.pk_srv, self.sk_srv = bindings.crypto_kx_keypair()
        self.master_secret = urandom(32)  # used to derive device keys
        self._next_device_ID = 0

    def dev_mac_key(self, dev_id):
        """Return the HMAC-SHA256 key for dev_id, derived from the master secret."""
        return sha256_hmac(self.master_secret, b"device_mac_key:" + dev_id)

    def provision_device(self):
        """Allocate the next sequential 4-byte device ID and return a configured device."""
        dev_id = uint32_to_bytes(self._next_device_ID)
        self._next_device_ID += 1
        return device(self.pk_srv, dev_id, self.dev_mac_key(dev_id))

    def decrypt(self, data):
        """Authenticate and decrypt one message from a device.

        data is pk_eph (32 bytes) || AES-CTR(device_ID (4 bytes) || m) || tag (8 bytes).
        Returns (dev_id, m).

        Raises ValueError if data is too short or the MAC doesn't verify.
        """
        # Explicit check rather than assert: data is untrusted and asserts are
        # stripped under -O.
        if len(data) < 32 + 4 + 8:  # pk_eph, device_id, tag
            raise ValueError("bad decryption:message too short")
        pk_eph, ct, tag = data[:32], data[32:-8], data[-8:]
        # Server tx key == client rx key in crypto_kx; the device uses index
        # [0] (rx) of crypto_kx_client_session_keys, so index [1] here.
        k = bindings.crypto_kx_server_session_keys(self.pk_srv, self.sk_srv, pk_eph)[1]
        pt = AES_CTR(ct, k[:16])
        # dev_id is unauthenticated until the MAC below verifies; here it only
        # selects which MAC key to check against.
        dev_id, m = pt[:4], pt[4:]
        correct_tag = sha256_hmac(self.dev_mac_key(dev_id), k + ct)[:8]
        if not hmac.compare_digest(correct_tag, tag):
            raise ValueError("bad decryption:bad MAC")
        return dev_id, m
if __name__ == "__main__":
    base = backend()
    dev1 = base.provision_device()
    dev2 = base.provision_device()
    # Each device's message must come back tagged with that device's ID.
    for dev, m in ((dev1, b"hello world!"), (dev2, b"some telemetry")):
        encd = dev.encrypt_message(m)
        result = base.decrypt(encd)
        assert result == (dev.device_ID, m)
        print("result:{devID:%r, message:%r}" % result)
    # Expected output:
    # result:{devID:b'\x00\x00\x00\x00', message:b'hello world!'}
    # result:{devID:b'\x00\x00\x00\x01', message:b'some telemetry'}
```