-- |
-- Module      : Crypto.PubKey.RSA.PSS
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good
--
module Crypto.PubKey.RSA.PSS
    ( PSSParams(..)
    , defaultPSSParams
    , defaultPSSParamsSHA1
    -- * Sign and verify functions
    , signWithSalt
    , sign
    , signSafer
    , verify
    ) where

import Crypto.Random
import Crypto.Types.PubKey.RSA
import Data.ByteString (ByteString)
import Data.Byteable
import qualified Data.ByteString as B
import Crypto.PubKey.RSA.Prim
import Crypto.PubKey.RSA.Types
import Crypto.PubKey.RSA (generateBlinder)
import Crypto.PubKey.HashDescr
import Crypto.PubKey.MaskGenFunction
import Crypto.Hash
import Data.Bits (xor, shiftR, (.&.))
import Data.Word

-- | Parameters for PSS signature/verification.
--
-- The same parameters must be used to sign and to verify a message.
data PSSParams = PSSParams
    { pssHash         :: HashFunction     -- ^ Hash function to use
    , pssMaskGenAlg   :: MaskGenAlgorithm -- ^ Mask generation algorithm to use
    , pssSaltLength   :: Int              -- ^ Length of the salt in bytes; expected to be <= hLen
    , pssTrailerField :: Word8            -- ^ Trailer field, usually 0xbc
    }

-- | Default parameters for the given hash function.
--
-- The defaults are: 'mgf1' as the mask generation function, a salt as
-- long as the hash output, and a trailer field of 0xbc.
defaultPSSParams :: HashFunction -> PSSParams
defaultPSSParams hashF = PSSParams hashF mgf1 digestLen 0xbc
  where
    -- output length of the hash function, measured on the empty input
    digestLen = B.length (hashF B.empty)

-- | Default parameters ('defaultPSSParams') with SHA-1 as the hash function.
defaultPSSParamsSHA1 :: PSSParams
defaultPSSParamsSHA1 = defaultPSSParams sha1Bytes
  where
    -- SHA-1 digest of the input, returned as raw bytes
    sha1Bytes :: ByteString -> ByteString
    sha1Bytes input = toBytes (hash input :: Digest SHA1)

-- | Sign using the PSS parameters and a salt passed explicitly.
--
-- The 'pssSaltLength' field of the parameters is ignored; the length of
-- the given salt is used instead.
--
-- Returns @Left InvalidParameters@ when the encoded message is too short
-- to hold the hash, the salt and the two framing bytes (the 0x01
-- separator and the trailer field).
--
-- The encoded message holds @modBits - 1@ bits, where @modBits@ is the bit
-- length of the modulus. So it is one byte shorter than the key when
-- @modBits@ is congruent to 1 modulo 8. Its unused leading bits are
-- cleared whatever the modulus size is.
signWithSalt :: ByteString    -- ^ Salt to use
             -> Maybe Blinder -- ^ optional blinder to use
             -> PSSParams     -- ^ PSS Parameters to use
             -> PrivateKey    -- ^ RSA Private Key
             -> ByteString    -- ^ Message to sign
             -> Either Error ByteString
signWithSalt salt blinder params pk m
    | emLen < hashLen + saltLen + 2 = Left InvalidParameters
    | otherwise                     = Right $ dp blinder pk em
    where hashF    = pssHash params
          hashLen  = B.length (hashF B.empty)
          saltLen  = B.length salt
          k        = private_size pk
          -- real bit length of the modulus (not private_size * 8, which is
          -- only correct for byte-aligned moduli)
          modBits  = integerBitLength (private_n pk)
          -- emLen = ceil((modBits - 1) / 8)
          emLen    = if (modBits - 1) .&. 0x7 == 0 then k - 1 else k
          dbLen    = emLen - hashLen - 1

          mHash    = hashF m
          m'       = B.concat [B.replicate 8 0, mHash, salt]
          h        = hashF m'
          -- DB = PS (zero padding) || 0x01 || salt
          db       = B.concat [B.replicate (dbLen - saltLen - 1) 0, B.singleton 1, salt]
          dbmask   = (pssMaskGenAlg params) hashF h dbLen
          -- clear the leading bits that fall outside the modBits - 1 bit message
          maskedDB = B.pack $ normalizeToKeySize modBits $ B.zipWith xor db dbmask
          em       = B.concat [maskedDB, h, B.singleton (pssTrailerField params)]

          -- number of significant bits of a non-negative integer (0 for 0)
          integerBitLength :: Integer -> Int
          integerBitLength = length . takeWhile (> 0) . iterate (`shiftR` 1)

-- | Sign using the PSS parameters, drawing a fresh salt from the generator.
--
-- The salt is 'pssSaltLength' bytes long. The updated generator is
-- returned alongside the signing result.
sign :: CPRG g
     => g               -- ^ random generator to use to generate the salt
     -> Maybe Blinder   -- ^ optional blinder to use
     -> PSSParams       -- ^ PSS Parameters to use
     -> PrivateKey      -- ^ RSA Private Key
     -> ByteString      -- ^ Message to sign
     -> (Either Error ByteString, g)
sign rng blinder params pk m =
    let saltLen          = pssSaltLength params
        (salt, rngAfter) = cprgGenerate saltLen rng
        signature        = signWithSalt salt blinder params pk m
    in (signature, rngAfter)

-- | Sign using the PSS parameters and an automatically generated blinder.
--
-- A blinder is derived from the generator and the key's modulus. The
-- resulting generator is then used by 'sign' to produce the salt.
signSafer :: CPRG g
          => g          -- ^ random generator
          -> PSSParams  -- ^ PSS Parameters to use
          -> PrivateKey -- ^ private key
          -> ByteString -- ^ message to sign
          -> (Either Error ByteString, g)
signSafer rng params pk m =
    let (blinder, rngAfterBlinder) = generateBlinder rng (private_n pk)
    in sign rngAfterBlinder (Just blinder) params pk m

-- | Verify a signature using the PSS parameters.
--
-- The salt length is recovered from the decoded data block (everything
-- after the first 0x01 byte), so 'pssSaltLength' is not consulted.
--
-- The signature is rejected when:
--
-- * its length differs from the key size;
-- * its integer value is not smaller than the modulus, so @s@ and @s + n@
--   cannot both be accepted;
-- * the encoded message is too short to hold the hash and framing bytes;
-- * the unused high-order bits of the encoded message are not zero;
-- * the trailer field, the zero padding or the 0x01 separator is wrong;
-- * the recomputed hash does not match.
verify :: PSSParams  -- ^ PSS Parameters to use to verify,
                     --   this need to be identical to the parameters when signing
       -> PublicKey  -- ^ RSA Public Key
       -> ByteString -- ^ Message to verify
       -> ByteString -- ^ Signature
       -> Bool
verify params pk m s
    | B.length s /= k                     = False
    | sigRep >= public_n pk               = False -- out-of-range signature representative
    | B.any (/= 0) pre                    = False -- message longer than emLen bytes
    | emLen < hashLen + 2                 = False -- cannot hold hash + framing; also keeps B.last total
    | B.last em /= pssTrailerField params = False
    | not topBitsClear                    = False
    | B.any (/= 0) ps0                    = False
    | b1 /= B.singleton 1                 = False
    | otherwise                           = h == h'
    where -- parameters
          hashF        = pssHash params
          hashLen      = B.length (hashF B.empty)
          k            = public_size pk
          -- real bit length of the modulus (not public_size * 8, which is
          -- only correct for byte-aligned moduli)
          modBits      = integerBitLength (public_n pk)
          -- emLen = ceil((modBits - 1) / 8)
          emLen        = if (modBits - 1) .&. 0x7 == 0 then k - 1 else k
          dbLen        = emLen - hashLen - 1
          -- big-endian integer value of the signature
          sigRep       = B.foldl' (\acc w -> acc * 256 + fromIntegral w) 0 s
          -- unmarshall fields; bytes in front of the emLen-byte message must be zero
          (pre, em)    = B.splitAt (k - emLen) (ep pk s)
          maskedDB     = B.take (B.length em - hashLen - 1) em
          h            = B.take hashLen $ B.drop (B.length maskedDB) em
          -- bits above modBits - 1 in the leading byte must already be clear
          firstByte    = B.unpack (B.take 1 maskedDB)
          topBitsClear = normalizeToKeySize modBits firstByte == firstByte
          dbmask       = (pssMaskGenAlg params) hashF h dbLen
          db           = B.pack $ normalizeToKeySize modBits $ B.zipWith xor maskedDB dbmask
          -- DB = PS (zeros) || 0x01 || salt
          (ps0, z)     = B.break (== 1) db
          (b1, salt)   = B.splitAt 1 z
          mHash        = hashF m
          m'           = B.concat [B.replicate 8 0, mHash, salt]
          h'           = hashF m'

          -- number of significant bits of a non-negative integer (0 for 0)
          integerBitLength :: Integer -> Int
          integerBitLength = length . takeWhile (> 0) . iterate (`shiftR` 1)

-- | Clear the excess high-order bits of the leading byte.
--
-- @bits@ is the size of the RSA modulus in bits. The leading byte keeps
-- only its low @(bits - 1) mod 8@ bits, or all 8 bits when that value is
-- zero. This is how EMSA-PSS confines the encoded message to @bits - 1@
-- bits. Bytes after the first are returned unchanged, and an empty list
-- stays empty.
normalizeToKeySize :: Int -> [Word8] -> [Word8]
normalizeToKeySize bits octets =
    case octets of
        []         -> []
        (o : rest) -> (o .&. topMask) : rest
  where
    keptBits = (bits - 1) .&. 0x7
    topMask
        | keptBits == 0 = 0xff
        | otherwise     = 0xff `shiftR` (8 - keptBits)