-- | Elliptic Curve Arithmetic.
--
-- /WARNING:/ These functions are vulnerable to timing attacks.
module Crypto.PubKey.ECC.Prim
    ( pointAdd
    , pointDouble
    , pointMul
    , isPointAtInfinity
    , isPointValid
    ) where

import Data.Maybe
import Crypto.Number.ModArithmetic
import Crypto.Number.F2m
import Crypto.Types.PubKey.ECC

-- TODO: extract a helper for the repeated `fromMaybe PointO $ do ...` pattern.

-- | Elliptic Curve point addition.
--
-- /WARNING:/ Vulnerable to timing attacks.
--
-- 'PointO' is the identity. If the two points are inverses of each other
-- the result is 'PointO'. If they are equal the work is delegated to
-- 'pointDouble'. Otherwise the chord slope is computed and the standard
-- affine addition formulas for the curve kind are applied. If the slope's
-- denominator cannot be inverted, the result falls back to 'PointO'.
pointAdd :: Curve -> Point -> Point -> Point
pointAdd _ PointO PointO = PointO
pointAdd _ PointO q = q
pointAdd _ p PointO = p
pointAdd c@(CurveFP (CurvePrime pr _)) p@(Point xp yp) q@(Point xq yq)
    -- q == -p: same x, and y-coordinates summing to 0 modulo pr. The
    -- comparison is done modulo pr rather than against the raw (negative)
    -- @-yq@, so this guard also fires for coordinates in [0, pr).
    | xp == xq && (yp + yq) `mod` pr == 0 = PointO
    | p == q = pointDouble c p
    | otherwise = fromMaybe PointO $ do
        -- chord slope (yp - yq) / (xp - xq) mod pr
        s <- divmod (yp - yq) (xp - xq) pr
        let xr = (s ^ (2::Int) - xp - xq) `mod` pr
            yr = (s * (xp - xr) - yp) `mod` pr
        return $ Point xr yr
pointAdd c@(CurveF2m (CurveBinary fx cc)) p@(Point xp yp) q@(Point xq yq)
    -- on a binary curve the inverse of (x, y) is (x, x + y)
    | p == Point xq (xq `addF2m` yq) = PointO
    | p == q = pointDouble c p
    | otherwise = fromMaybe PointO $ do
        -- chord slope (yp + yq) / (xp + xq) in GF(2^m)
        s <- divF2m fx (yp `addF2m` yq) (xp `addF2m` xq)
        let xr = mulF2m fx s s `addF2m` s `addF2m` xp `addF2m` xq `addF2m` a
            yr = mulF2m fx s (xp `addF2m` xr) `addF2m` xr `addF2m` yp
        return $ Point xr yr
  where a = ecc_a cc

-- | Elliptic Curve point doubling.
--
-- /WARNING:/ Vulnerable to timing attacks.
--
-- On a prime curve this performs the following calculation, with every
-- step reduced modulo the curve prime:
--
-- > lambda = (3 * xp ^ 2 + a) / (2 * yp)
-- > xr     = lambda ^ 2 - 2 * xp
-- > yr     = lambda * (xp - xr) - yp
--
-- When 'divmod' cannot invert @2 * yp@ the result is 'PointO'.
--
-- On a binary curve:
--
-- > xp == 0   => result is O
-- > otherwise =>
-- >    s  = xp + (yp / xp)
-- >    xr = s ^ 2 + s + a
-- >    yr = xp ^ 2 + (s + 1) * xr
--
-- When 'divF2m' fails the result is 'PointO'.
pointDouble :: Curve -> Point -> Point
pointDouble _ PointO = PointO
pointDouble (CurveFP (CurvePrime pr cc)) (Point xp yp) =
    case divmod slopeNum slopeDen pr of
        Nothing     -> PointO
        Just lambda ->
            let xr = (lambda ^ (2::Int) - 2 * xp) `mod` pr
                yr = (lambda * (xp - xr) - yp) `mod` pr
            in  Point xr yr
  where
    -- tangent slope numerator / denominator
    slopeNum = 3 * xp ^ (2::Int) + ecc_a cc
    slopeDen = 2 * yp
pointDouble (CurveF2m (CurveBinary fx cc)) (Point xp yp)
    | xp == 0   = PointO
    | otherwise = maybe PointO fromSlope (fmap (addF2m xp) (divF2m fx yp xp))
  where
    -- build the doubled point from s = xp + yp / xp
    fromSlope s =
        let xr = mulF2m fx s s `addF2m` s `addF2m` ecc_a cc
            yr = mulF2m fx xp xp `addF2m` mulF2m fx xr (s `addF2m` 1)
        in  Point xr yr

-- | Elliptic curve point multiplication (double and add algorithm).
--
-- /WARNING:/ Vulnerable to timing attacks.
--
-- @pointMul c n p@ computes @n * p@. A scalar of 0 yields 'PointO'. A
-- negative scalar multiplies the negated point by @-n@. The negation is
-- curve specific, so the result stays a valid curve point:
--
-- * prime curves: @(x, -y mod p)@
-- * binary curves: @(x, x + y)@
pointMul :: Curve -> Integer -> Point -> Point
pointMul _ _ PointO = PointO
pointMul c n p@(Point xp yp)
    | n <  0 = pointMul c (-n) negated
    | n == 0 = PointO
    | n == 1 = p
    | odd n = pointAdd c p (pointMul c (n - 1) p)
    | otherwise = pointMul c (n `div` 2) (pointDouble c p)
  where
    -- Additive inverse of p. A plain @Point xp (-yp)@ is wrong on binary
    -- curves and leaves an out-of-range (negative) y on prime curves.
    negated = case c of
        CurveFP (CurvePrime pr _) -> Point xp ((-yp) `mod` pr)
        CurveF2m _                -> Point xp (xp `addF2m` yp)

-- | Test for the point at infinity: 'True' for 'PointO', 'False' for any
-- affine point.
isPointAtInfinity :: Point -> Bool
isPointAtInfinity pt = case pt of
    PointO -> True
    _      -> False

-- | Check whether a point lies on the given curve.
--
-- The point at infinity is always considered valid. For an affine point
-- three checks are performed:
--
-- * x is in range: @0 <= x < p@ on prime curves; on binary curves, a
--   non-negative polynomial already reduced modulo the field polynomial
-- * y is in range (same criterion as x)
-- * the curve equation holds: @y^2 = x^3 + a*x + b (mod p)@ on prime
--   curves, @y^2 + x*y = x^3 + a*x^2 + b@ in GF(2^m) on binary curves
isPointValid :: Curve -> Point -> Bool
isPointValid _ PointO = True
isPointValid (CurveFP (CurvePrime p cc)) (Point x y) =
    isValid x && isValid y && (y ^ (2 :: Int)) `eqModP` (x ^ (3 :: Int) + a * x + b)
  where a = ecc_a cc
        b = ecc_b cc
        eqModP z1 z2 = (z1 `mod` p) == (z2 `mod` p)
        isValid e = e >= 0 && e < p
isPointValid (CurveF2m (CurveBinary fx cc)) (Point x y) =
    and [ isValid x
        , isValid y
          -- ((x + a) * x + y) * x + b + y^2 == x^3 + a*x^2 + x*y + b + y^2,
          -- which is 0 exactly when y^2 + x*y = x^3 + a*x^2 + b
        , ((((x `add` a) `mul` x `add` y) `mul` x) `add` b `add` (squareF2m fx y)) == 0
        ]
  where a = ecc_a cc
        b = ecc_b cc
        add = addF2m
        mul = mulF2m fx
        -- A negative Integer is not a binary polynomial. Reject it
        -- explicitly, matching the prime case's @e >= 0@ check, instead of
        -- relying on how modF2m treats negative input.
        isValid e = e >= 0 && modF2m fx e == e

-- | Modular division: @divmod y x m@ computes @y * x^-1 (mod m)@.
--
-- Returns 'Nothing' when 'inverse' finds no inverse of @x mod m@ modulo
-- @m@. Otherwise returns the product reduced with 'mod'.
divmod :: Integer -> Integer -> Integer -> Maybe Integer
divmod y x m = fmap scale (inverse (x `mod` m) m)
  where scale xInv = (y * xInv) `mod` m