module Vector where

import qualified Prelude
import Feldspar
import Feldspar.Vector
import Feldspar.Matrix

-- Blake crypto

-- | A message block: a vector of 32-bit words (indices 0..15).
type MessageBlock = DVector Word32 -- 0..15
-- | A round number; 'g' uses it to select a row of 'sigma'.
type Round = Data Index

-- | The working state: a matrix of 32-bit words (indices 0..3 by 0..3).
type State = Matrix Word32 -- 0..3 0..3

-- | The sixteen 32-bit constant words that 'g' XORs with message words.
-- They are looked up through the permutation table 'sigma'.
co :: DVector Word32
co = vector
  [ 0x243F6A88, 0x85A308D3, 0x13198A2E, 0x03707344
  , 0xA4093822, 0x299F31D0, 0x082EFA98, 0xEC4E6C89
  , 0x452821E6, 0x38D01377, 0xBE5466CF, 0x34E90C6C
  , 0xC0AC29B7, 0xC97C50DD, 0x3F84D5B5, 0xB5470917
  ]

-- | Permutation table: ten rows, each a permutation of the indices 0..15.
-- 'g' selects a row by round number and reads two adjacent entries from it
-- to pick the message word and constant word to combine.
sigma :: Matrix Index
sigma = matrix
  [ [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]
  , [14, 10,  4,  8,  9, 15, 13,  6,  1, 12,  0,  2, 11,  7,  5,  3]
  , [11,  8, 12,  0,  5,  2, 15, 13, 10, 14,  3,  6,  7,  1,  9,  4]
  , [ 7,  9,  3,  1, 13, 12, 11, 14,  2,  6,  5, 10,  4,  0, 15,  8]
  , [ 9,  0,  5,  7,  2,  4, 10, 15, 14,  1, 11, 12,  6,  8,  3, 13]
  , [ 2, 12,  6, 10,  0, 11,  8,  3,  4, 13,  7,  5, 15, 14,  1,  9]
  , [12,  5,  1, 15, 14, 13,  4, 10,  0,  7,  6,  3,  9,  2,  8, 11]
  , [13, 11,  7, 14, 12,  1,  3,  9,  5,  0, 15,  4,  8,  6,  2, 10]
  , [ 6, 15, 14,  9, 11,  3,  0,  8, 12,  2, 13,  7,  1,  4, 10,  5]
  , [10,  2,  8,  4,  7,  6,  1,  5, 15, 11,  9, 14,  3, 12, 13,  0]
  ]

-- | Apply one round to the state.
--
-- The round has two steps. First 'g' is applied with indices 0..3 to the
-- rows of the transposed state (i.e. the state's columns). The result is
-- transposed back, and 'g' is applied with indices 4..7 to its wrapped
-- diagonals. 'invDiagonals' then puts the diagonal results back into
-- row/column layout.
blakeRound :: MessageBlock -> State -> Round -> State
blakeRound m state r = invDiagonals diagonalStep
  where
    mix          = g m r
    columnStep   = zipWith mix (0 ... 3) (transpose state)
    diagonalStep = zipWith mix (4 ... 7) (diagonals (transpose columnStep))

-- | The mixing function applied to four state words.
--
-- Arguments:
--
--   * @m@ - the message block
--   * @r@ - the round number; it indexes 'sigma' directly, so it must be a
--           valid row index of that table (which has ten rows)
--   * @i@ - which application of the function this is; entries @2*i@ and
--           @2*i+1@ of the selected 'sigma' row pick the message word and
--           constant word
--   * @v@ - a vector whose first four elements are the words @a, b, c, d@
--
-- Returns a four-element vector holding the updated @a, b, c, d@.
--
-- The BLAKE specification defines the 16/12/8/7 steps as right /rotations/,
-- so 'rotateR' is used here. A plain right shift would discard the low bits
-- instead of wrapping them around.
g :: MessageBlock -> Round -> Data Index -> DVector Word32 -> DVector Word32
g m r i v = fromList [a'', b'', c'', d'']
  where
    [a, b, c, d] = toList 4 v
    -- First half: message word sigma[r][2i] with constant sigma[r][2i+1].
    a'  = a + b + (m!(sigma!r!(2*i)) ⊕ (co!(sigma!r!(2*i+1))))
    d'  = rotateR (d ⊕ a') 16
    c'  = c + d'
    b'  = rotateR (b ⊕ c') 12
    -- Second half: the same pair of sigma entries with their roles swapped.
    a'' = a' + b' + (m!(sigma!r!(2*i+1)) ⊕ (co!(sigma!r!(2*i))))
    d'' = rotateR (d' ⊕ a'') 8
    c'' = c' + d''
    b'' = rotateR (b' ⊕ c'') 7

-- | Collect the wrapped diagonals of a matrix.
-- Row @i@ of the result is @diag m i@, with one diagonal for every column
-- of the first row of @m@.
diagonals :: Type a => Matrix a -> Matrix a
diagonals m = map (diag m) offsets
  where
    width   = length (head m)
    offsets = 0 ... (width - 1)

-- | Extract one wrapped diagonal of a matrix.
-- Element @j@ of the result comes from row @j@ of @m@, at column
-- @(i + j)@ taken modulo that row's length.
diag :: Type a => Matrix a -> Data Index -> Vector (Data a)
diag m i = zipWith pick m (i ... (lastRow + i))
  where
    lastRow      = length m - 1
    -- Index into a row, wrapping around at the row's end.
    pick row col = row ! (col `mod` length row)

-- | Transpose the matrix and rotate row @i@ of the transpose right by @i@
-- positions. For a square matrix this undoes 'diagonals'.
invDiagonals :: Type a => Matrix a -> Matrix a
invDiagonals m = zipWith shiftVectorR shifts (transpose m)
  where
    shifts = 0 ... (length m - 1)

-- | Cyclically rotate a vector to the right by @i@ positions: the last @i@
-- elements move to the front. There is no modulo, so @i@ is expected to be
-- at most the length of the vector.
shiftVectorR :: Syntactic a => Data Index -> Vector a -> Vector a
shiftVectorR i v = reverse (lowPart ++ highPart)
  where
    flipped  = reverse v
    -- The first i elements of the reversed vector are the last i of v.
    highPart = take i flipped
    lowPart  = drop i flipped

-- | Build a Feldspar vector from a Haskell list of element expressions.
--
-- The array starts out filled with the first list element. Each remaining
-- element is then written to its own index with 'setIx', in list order.
-- The list is expected to be non-empty, because the initial fill uses
-- @Prelude.head@.
fromList :: Type a => [Data a] -> DVector a
fromList ls = unfreezeVector (Prelude.foldl store initial rest)
  where
    len     = fromIntegral (Prelude.length ls)
    initial = parallel (value len) (const (Prelude.head ls))
    -- Elements 1.. paired with their target index; element 0 is already
    -- in place from the initial fill.
    rest    = Prelude.zip [1 ..] (Prelude.drop 1 ls)
    store arr (i, x) = setIx arr (value i) x

-- | Read the first @n@ elements of a vector into a Haskell list of element
-- expressions, in index order.
--
-- @n@ is a Haskell-level count fixed at program-generation time; the caller
-- must ensure the vector has at least @n@ elements.
--
-- The indices are generated with @takeWhile (< n)@ rather than @[0..n-1]@.
-- 'Index' is an unsigned word type, so @n - 1@ would wrap around for
-- @n == 0@ and enumerate up to the maximum index. The vector is also no
-- longer matched against a specific constructor, so the function is total
-- over every vector representation.
toList :: Type a => Index -> Vector (Data a) -> [Data a]
toList n v = [ v ! value i | i <- Prelude.takeWhile (Prelude.< n) [0 ..] ]

-- DCT

-- | Discrete Cosine Transform, type 2.
-- Multiplies the input by the square coefficient matrix whose entries are
-- given by 'dct2nkl', sized by the input's length.
dct2 :: (DVector Float) -> (DVector Float)
dct2 xn = coeffs *** xn
  where
    n      = length xn
    coeffs = indexedMat n n (dct2nkl n)
      

-- | Entry @(k, l)@ of the DCT-2 matrix of size @n@:
-- @cos (k * (2*l + 1) * pi / (2*n))@, with all three integer arguments
-- converted to 'Float' first.
dct2nkl :: Data Length -> Data DefaultWord -> Data DefaultWord -> Data Float
dct2nkl n k l = cos (angleNum / angleDen)
  where
    angleNum = i2f k * (2 * i2f l + 1) * pi
    angleDen = 2 * i2f n

-- | Discrete Cosine Transform, type 3.
-- Multiplies the input by the transpose of the coefficient matrix used by
-- the type-2 transform (entries from 'dct2nkl').
dct3 :: (DVector Float) -> (DVector Float)
dct3 xn = coeffs *** xn
  where
    n      = length xn
    coeffs = transpose (indexedMat n n (dct2nkl n))

-- | Discrete Cosine Transform, type 4.
-- Multiplies the input by the square coefficient matrix whose entries are
-- given by 'dct4nkl', sized by the input's length.
dct4 :: (DVector Float) -> (DVector Float)
dct4 xn = coeffs *** xn
  where
    n      = length xn
    coeffs = indexedMat n n (dct4nkl n)
      

-- | Entry @(k, l)@ of the DCT-4 matrix of size @n@:
-- @cos ((2*k + 1) * (2*l + 1) * pi / (4*n))@, with all three integer
-- arguments converted to 'Float' first.
dct4nkl :: Data Length -> Data DefaultWord -> Data DefaultWord -> Data Float
dct4nkl n k l = cos (angleNum / angleDen)
  where
    angleNum = (2 * i2f k + 1) * (2 * i2f l + 1) * pi
    angleDen = 4 * i2f n

-- Low-pass filter

-- Placeholder: no FFT is implemented yet, so evaluating 'fft' raises the
-- error below. Anything built on it (such as 'frequencyTrans') will fail
-- the same way once its result is demanded.
fft = error "No FFT yet"
-- The inverse transform is currently just an alias for the same placeholder.
ifft = fft

-- | Keep the first @k@ elements of a vector and replace the rest with
-- zeros, so the result has the same length as the input.
--
-- The number of kept elements is clamped to the vector's length. The
-- length type is unsigned, so without the clamp @length v - k@ would wrap
-- around for @k > length v@ and request an enormous zero padding.
lowPassCore :: (Numeric a) => Data Index -> DVector a -> DVector a
lowPassCore k v = take kept v ++ replicate (n - kept) 0
  where
    n    = length v
    kept = min k n

-- | Low-pass filter: apply 'lowPassCore' with cut-off @k@ to the signal's
-- frequency-domain representation, using 'frequencyTrans' for the round
-- trip.
lowPass :: Data Index -> DVector Float -> DVector Float
lowPass k signal = frequencyTrans (lowPassCore k) signal

-- | Run a function on a real signal in the frequency domain.
--
-- The input is lifted to complex numbers with zero imaginary part and
-- passed through 'fft'. @innerFunction@ is applied to the result, which is
-- then passed through 'ifft'. Only the real part of each output element is
-- kept.
frequencyTrans :: (DVector (Complex Float) -> DVector (Complex Float)) 
               -> DVector Float 
               -> DVector Float
frequencyTrans innerFunction v = map realPart timeDomain
  where
    spectrum   = fft (map (`complex` 0) v)
    timeDomain = ifft (innerFunction spectrum)
