-- hkt.hs: hOpenPGP key tool
-- Copyright © 2013-2014  Clint Adams
--
-- vim: softtabstop=4:shiftwidth=4:expandtab
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU Affero General Public License for more details.
--
-- You should have received a copy of the GNU Affero General Public License
-- along with this program.  If not, see <http://www.gnu.org/licenses/>.

import HOpenPGP.Tools.Common (banner, versioner, warranty, keyMatchesFingerprint, keyMatchesEightOctetKeyId, keyMatchesUIDSubString)
import HOpenPGP.Tools.ExpressionParsing (pPE)
import Codec.Encryption.OpenPGP.Fingerprint (fingerprint, eightOctetKeyID)
import Codec.Encryption.OpenPGP.KeyInfo (pubkeySize, pkalgoAbbrev)
import Codec.Encryption.OpenPGP.KeySelection (parseEightOctetKeyId, parseFingerprint)
import Codec.Encryption.OpenPGP.Serialize ()
import Codec.Encryption.OpenPGP.Signatures (verifyTKWith, verifySigWith, verifyAgainstKeyring)
import Codec.Encryption.OpenPGP.Types
import Control.Applicative ((<$>),(<*>), optional, (<|>), pure)
import Control.Arrow ((&&&))
import Control.Lens ((^.), _1, _2, (^..))
import Control.Monad.Trans.Resource (runResourceT, MonadResource)
import qualified Control.Monad.Trans.State.Lazy as S
import Control.Monad.Trans.Writer.Lazy (execWriter, tell)
import qualified Data.Attoparsec.Text as A
import qualified Data.ByteString as B
import Data.Conduit (($=),($$), Source)
import qualified Data.Conduit.Binary as CB
import Data.Conduit.Cereal (conduitGet)
import qualified Data.Conduit.List as CL
import Data.Conduit.OpenPGP.Filter (Expr(..), PKPPredicate(..), PKPOp(..), PKPVar(..), PKPValue(..))
import Data.Conduit.OpenPGP.Keyring (conduitToTKsDropping, sinkKeyringMap)
import Data.Data.Lens (biplate)
import Data.Either (rights)
import qualified Data.IxSet as IxSet
import Data.Graph.Inductive.Graph (Graph(mkGraph), emap)
import Data.Graph.Inductive.PatriciaTree (Gr)
import Data.Graph.Inductive.Query.SP (sp)
import Data.GraphViz (graphToDot, nonClusteredParams)
import Data.GraphViz.Types (printDotGraph)
import Data.HashMap.Lazy (HashMap)
import qualified Data.HashMap.Lazy as HashMap
import Data.List (nub, sort)
import Data.Maybe (fromMaybe, mapMaybe, listToMaybe)
import Data.Monoid ((<>))
import Data.Serialize (get, put, runPut)
import qualified Data.Text as T
import qualified Data.Text.Lazy.IO as TLIO
import Data.Time.Clock.POSIX (getPOSIXTime, posixSecondsToUTCTime)
import Data.Traversable (traverse)
import Data.Tuple (swap)
import System.Directory (getHomeDirectory)

import Options.Applicative.Builder (argument, command, footer, header, help, info, long, metavar, option, prefs, progDesc, showDefault, showHelpOnError, str, strOption, subparser, switch, value)
import Options.Applicative.Extra (customExecParser, helper)
import Options.Applicative.Types (Parser)

import System.IO (Handle, hFlush, hPutStrLn, stderr, hSetBuffering, BufferMode(..))

-- | Stream the transferable keys from the keyring file @fp@ that match @srch@.
--
-- With @filt@ set, @srch@ is parsed as a filter expression ('pPE') and
-- evaluated against each key's primary public-key payload.  An unparsable
-- expression raises an 'error' when the first key is tested.  Otherwise
-- @srch@ is parsed both as a fingerprint and as an eight-octet key ID.
-- Those results are combined with '<|>', with a user-ID substring match as
-- the last alternative (presumably the first successful parse decides).
grabMatchingKeysConduit :: MonadResource m => FilePath -> Bool -> String -> Source m TK
grabMatchingKeysConduit fp filt srch = CB.sourceFile fp $= conduitGet get $= conduitToTKsDropping $= CL.filter (if filt then filterMatch else matchAny)
    where
        matchAny tk = either (const False) id $ fmap (keyMatchesFingerprint True tk) efp <|> fmap (keyMatchesEightOctetKeyId True tk . Right) eeok <|> return (keyMatchesUIDSubString srch tk)
        -- Parsed once, outside the per-key predicate, so every key shares
        -- the same parsed expression instead of re-parsing it.
        filterExpr = either error id (A.parseOnly pPE (T.pack srch))
        filterMatch tk = eval pkpEval filterExpr (tk^.tkKey._1)
        efp = parseFingerprint . T.pack $ srch
        eeok = parseEightOctetKeyId . T.pack $ srch

-- | Read the keyring file and return every key selected by
-- 'grabMatchingKeysConduit', as a list.
grabMatchingKeys :: FilePath -> Bool -> String -> IO [TK]
grabMatchingKeys fp filt srch = runResourceT (matching $$ CL.consume)
    where
        matching = grabMatchingKeysConduit fp filt srch

-- | Read the keyring file and collect every key selected by
-- 'grabMatchingKeysConduit' into a 'Keyring'.
grabMatchingKeysKeyring :: FilePath -> Bool -> String -> IO Keyring
grabMatchingKeysKeyring fp filt srch = runResourceT (matching $$ sinkKeyringMap)
    where
        matching = grabMatchingKeysConduit fp filt srch

-- | Print a key as a short listing:
--
-- * a @pub@ line for the primary key,
-- * one @uid@ line per user ID,
-- * one @sub@ line per subkey.
--
-- Each key line shows the key size, the algorithm abbreviation and the
-- 0x-prefixed eight-octet key ID.  "unknown" replaces the size or key ID
-- when it cannot be computed.  Both public and secret subkey packets are
-- listed.  Any other packet in the subkey list is skipped rather than
-- aborting the listing with a pattern-match failure.
showKey :: TK -> IO ()
showKey key = putStrLn . unlines . execWriter $ do
    tell [ "pub   " ++ describe (key^.tkKey._1) ]
    tell $ map (\(x,_) -> "uid                            " ++ x) (key^.tkUIDs)
    tell $ map (("sub   " ++) . describe) (mapMaybe (subkeyPayload . fst) (key^.tkSubs))
    where
        -- Size, algorithm and key ID of one public-key payload.
        describe p = either (const "unknown") show (pubkeySize (p^.pubkey)) ++ pkalgoAbbrev (p^.pkalgo) ++ "/0x" ++ (either (const "unknown") show . eightOctetKeyID $ p)
        -- Subkeys may arrive as PublicSubkeyPkt or, presumably when --keyring
        -- points at secret keys, as SecretSubkeyPkt; both carry the public
        -- payload as their first field.
        subkeyPayload (PublicSubkeyPkt x) = Just x
        subkeyPayload (SecretSubkeyPkt x _) = Just x
        subkeyPayload _ = Nothing

-- | Command-line settings shared by all subcommands.  Each subcommand's
-- parser fills the fields it does not use with placeholder values.
data Options = Options
    { keyring :: String                      -- ^ keyring file to read (@--keyring@)
    , graphOutputFormat :: GraphOutputFormat -- ^ @--output-format@; parsed for graph/findpaths, not read by doGraph or doFindPaths
    , targetIsFilter :: Bool                 -- ^ @--filter@: treat the targets as filter expressions
    , target1 :: String                      -- ^ TARGET (TARGET-SET for findpaths): selects the keys loaded from the keyring
    , target2 :: String                      -- ^ findpaths FROM-KEYS; empty for other subcommands
    , target3 :: String                      -- ^ findpaths TO-KEYS; empty for other subcommands
    }

-- | The subcommand chosen on the command line, with its options.
data Command
    = CmdList Options          -- ^ @list@
    | CmdExportPubkeys Options -- ^ @export-pubkeys@
    | CmdGraph Options         -- ^ @graph@
    | CmdFindPaths Options     -- ^ @findpaths@

-- | Output formats accepted by @--output-format@.  'Show' backs
-- 'showDefault'; 'Read' presumably backs the option's value parser.
data GraphOutputFormat
    = GraphViz -- ^ Graphviz DOT
    deriving (Eq, Read, Show)

-- | Option parser for @list@ and @export-pubkeys@.  It reads:
--
-- * an optional @--keyring@ (default: @homedir ++ "/.gnupg/pubring.gpg"@),
-- * the @--filter@ switch,
-- * a single TARGET argument.
--
-- The output format and the two extra targets are placeholders.
listO :: String -> Parser Options
listO homedir =
    Options <$> keyringOpt
            <*> pure GraphViz -- unused
            <*> switch (long "filter" <> help "treat target as filter")
            <*> argument str (metavar "TARGET")
            <*> pure ""
            <*> pure ""
    where
        keyringOpt = fromMaybe (homedir ++ "/.gnupg/pubring.gpg")
            <$> optional (strOption (long "keyring" <> metavar "FILE" <> help "file containing keyring"))

-- | Option parser for @graph@.  It reads:
--
-- * an optional @--keyring@ (default: @homedir ++ "/.gnupg/pubring.gpg"@),
-- * @--output-format@ (default 'GraphViz'),
-- * the @--filter@ switch,
-- * a single TARGET argument.
--
-- The two extra targets are left empty.
graphO :: String -> Parser Options
graphO homedir =
    Options <$> keyringOpt
            <*> formatOpt
            <*> switch (long "filter" <> help "treat target as filter")
            <*> argument str (metavar "TARGET")
            <*> pure ""
            <*> pure ""
    where
        keyringOpt = fromMaybe (homedir ++ "/.gnupg/pubring.gpg")
            <$> optional (strOption (long "keyring" <> metavar "FILE" <> help "file containing keyring"))
        formatOpt = option (long "output-format" <> metavar "FORMAT" <> value GraphViz <> showDefault <> help "output format")

-- | Option parser for @findpaths@.  It reads:
--
-- * an optional @--keyring@ (default: @homedir ++ "/.gnupg/pubring.gpg"@),
-- * @--output-format@ (default 'GraphViz'),
-- * the @--filter@ switch,
-- * three positional arguments: TARGET-SET, FROM-KEYS and TO-KEYS.
findPathsO :: String -> Parser Options
findPathsO homedir =
    Options <$> keyringOpt
            <*> formatOpt
            <*> switch (long "filter" <> help "treat targets as filter")
            <*> argument str (metavar "TARGET-SET")
            <*> argument str (metavar "FROM-KEYS")
            <*> argument str (metavar "TO-KEYS")
    where
        keyringOpt = fromMaybe (homedir ++ "/.gnupg/pubring.gpg")
            <$> optional (strOption (long "keyring" <> metavar "FILE" <> help "file containing keyring"))
        formatOpt = option (long "output-format" <> metavar "FORMAT" <> value GraphViz <> showDefault <> help "output format")

-- | Print the banner to stderr, flush it, then run the selected subcommand.
dispatch :: Command -> IO ()
dispatch c = announce >> run c
    where
        announce = banner' stderr >> hFlush stderr
        run (CmdList o) = doList o
        run (CmdExportPubkeys o) = doExportPubkeys o
        run (CmdGraph o) = doGraph o
        run (CmdFindPaths o) = doFindPaths o

-- | Entry point.  Sets stderr to line buffering, parses the command line
-- (showing help on parse errors), and runs the chosen subcommand.
main :: IO ()
main = do
    hSetBuffering stderr LineBuffering
    homedir <- getHomeDirectory
    let parserInfo = info (helper <*> versioner <*> cmd homedir)
                          (header (banner "hkt") <> progDesc "hOpenPGP Keyring Tool" <> footer (warranty "hkt"))
    dispatch =<< customExecParser (prefs showHelpOnError) parserInfo

-- | Subcommand parser.  @homedir@ is passed to each subcommand's option
-- parser to form the default keyring path.
cmd :: String -> Parser Command
cmd homedir = subparser $
       sub "list" (CmdList <$> listO homedir) "list matching keys"
    <> sub "export-pubkeys" (CmdExportPubkeys <$> listO homedir) "export matching keys to stdout"
    <> sub "graph" (CmdGraph <$> graphO homedir) "graph certifications"
    <> sub "findpaths" (CmdFindPaths <$> findPathsO homedir) "find short paths between keys"
    where
        -- One subcommand: its name, its parser and its one-line description.
        sub name parser desc = command name (info parser (progDesc desc))

-- | Write the program banner, then the warranty notice on the next line,
-- to the given handle.
banner' :: Handle -> IO ()
banner' h = hPutStrLn h $ banner "hkt" <> "\n" <> warranty "hkt"

-- | @list@: print every key matching TARGET with 'showKey'.
doList :: Options -> IO ()
doList o = grabMatchingKeys (keyring o) (targetIsFilter o) (target1 o) >>= mapM_ showKey

-- | @export-pubkeys@: write every key matching TARGET to stdout as binary
-- OpenPGP packets.  For each key, in order:
--
-- * the primary key (as a public key),
-- * its revocation signatures,
-- * each user ID and each user attribute, each followed by its signatures,
-- * each subkey, followed by its signatures.
--
-- Only public key material is written.  The primary key is re-wrapped as
-- 'PublicKey', and secret subkey packets are re-emitted as the
-- corresponding public subkey packets.
doExportPubkeys :: Options -> IO ()
doExportPubkeys o = do
    keys <- grabMatchingKeys (keyring o) (targetIsFilter o) (target1 o)
    mapM_ (B.putStr . putTK') keys
    where
        putTK' key = runPut $ do
            put (PublicKey (key^.tkKey._1))
            putSigs (_tkRevs key)
            mapM_ putUid' (_tkUIDs key)
            mapM_ putUat' (_tkUAts key)
            mapM_ putSub' (_tkSubs key)
        putSigs sps = mapM_ (put . Signature) sps
        putUid' (u, sps) = put (UserId u) >> putSigs sps
        putUat' (us, sps) = put (UserAttribute us) >> putSigs sps
        -- Re-serializing a SecretSubkeyPkt verbatim would leak the subkey's
        -- secret material; keep only its public payload.
        putSub' (SecretSubkeyPkt pkp _, sps) = put (PublicSubkeyPkt pkp) >> putSigs sps
        putSub' (p, sps) = put p >> putSigs sps

-- | @graph@: print the certification graph of the keys matching TARGET as
-- Graphviz DOT on stdout.
--
-- Each key is verified against the matching keyring at the current time.
-- Keys that fail verification are left out.  The @--output-format@ value
-- is not consulted here.
doGraph :: Options -> IO ()
doGraph o = do
    now <- posixSecondsToUTCTime <$> getPOSIXTime
    kr <- grabMatchingKeysKeyring (keyring o) (targetIsFilter o) (target1 o)
    let checkKey = verifyTKWith (verifySigWith (verifyAgainstKeyring kr)) (Just now)
        validKeys = rights (map checkKey (IxSet.toList kr))
        graph = buildKeyGraph (buildMaps validKeys, validKeys)
    TLIO.putStrLn (printDotGraph (graphToDot nonClusteredParams graph))

-- | Assign consecutive node numbers to the keys (via 'mapsInsertions') and
-- build the lookup tables.  Returns the tables together with the number of
-- keys inserted, which is also the highest node number assigned.
buildMaps :: [TK] -> (KeyMaps, Int)
buildMaps ks = S.execState insertAll emptyMaps
    where
        insertAll = mapM_ mapsInsertions ks
        emptyMaps :: (KeyMaps, Int)
        emptyMaps = (KeyMaps HashMap.empty HashMap.empty HashMap.empty, 0)

-- FIXME: this presumes no keyID collisions in the input
-- | Lookup tables between key IDs, primary-key fingerprints and the integer
-- node numbers of the certification graph.
data KeyMaps = KeyMaps
    { _k2f :: HashMap EightOctetKeyId TwentyOctetFingerprint -- ^ key ID of any (sub)key payload to its primary fingerprint
    , _f2i :: HashMap TwentyOctetFingerprint Int             -- ^ primary fingerprint to node number
    , _i2f :: HashMap Int TwentyOctetFingerprint             -- ^ node number back to primary fingerprint
    }

-- | Register one key in the lookup tables.
--
-- The counter in the state is incremented, and the new value becomes the
-- node number of the key's primary fingerprint.  The key IDs of every
-- public-key payload found inside the key (via 'biplate') are mapped to
-- that fingerprint.
mapsInsertions :: TK -> S.State (KeyMaps, Int) ()
mapsInsertions tk = S.modify register
    where
        primaryFp = fingerprint (tk^.tkKey._1)
        ownKeyIds = rights (map eightOctetKeyID (tk ^.. biplate :: [PKPayload]))
        register (KeyMaps byKeyId byFp byIndex, lastIndex) =
            let nodeId = lastIndex + 1
            in ( KeyMaps (foldr (`HashMap.insert` primaryFp) byKeyId ownKeyIds)
                         (HashMap.insert primaryFp nodeId byFp)
                         (HashMap.insert nodeId primaryFp byIndex)
               , nodeId
               )

-- | Build the certification graph.
--
-- Nodes are the (node number, fingerprint) pairs of the fingerprint map.
-- Each signature on a user ID or user attribute of a key in @ks@ becomes
-- an edge labelled with the signature's hash algorithm.  The edge runs
-- from the issuing key's node to the certified key's node.  Duplicate
-- edges and self-signatures (issuer node equals certified node) are
-- removed.
--
-- A signature is skipped when its issuer does not resolve to a node:
-- there is no issuer subpacket, or the issuer's key ID or fingerprint is
-- absent from the maps.  This keeps every edge between nodes that
-- 'mkGraph' actually creates.
buildKeyGraph :: ((KeyMaps, Int), [TK]) -> Gr TwentyOctetFingerprint HashAlgorithm
buildKeyGraph ((KeyMaps k2f f2i _, _), ks) = mkGraph nodes edges
    where
        nodes = map swap . HashMap.toList $ f2i
        edges = filter (not . samesies) . dedupSorted . sort . concatMap tkToEdges $ ks
        tkToEdges tk =
            [ (src, target tk, ha)
            | (ha, issuer) <- mapMaybe (fakejoin . (hashAlgo &&& sigissuer)) (sigs tk)
            , Just src <- [source issuer]
            ]
        target tk = fromMaybe (error "Epic fail") (HashMap.lookup (fingerprint (tk^.tkKey._1)) f2i)
        -- Issuer key ID -> primary fingerprint -> node number.
        source i = HashMap.lookup i k2f >>= flip HashMap.lookup f2i
        fakejoin (x, y) = fmap ((,) x) y
        sigs tk = concat ((tk^..tkUIDs.traverse._2) ++ (tk^..tkUAts.traverse._2))
        samesies (x,y,_) = x == y
        -- Drop adjacent duplicates.  On a sorted list this equals 'nub' but
        -- runs in linear rather than quadratic time.
        dedupSorted (x:rest@(y:_))
            | x == y = dedupSorted rest
            | otherwise = x : dedupSorted rest
        dedupSorted xs = xs

-- | @findpaths@: find shortest certification paths from every FROM-KEYS
-- key to every TO-KEYS key.
--
-- The keyring is first narrowed to the keys matching TARGET-SET
-- ('target1').  FROM-KEYS ('target2') and TO-KEYS ('target3') then select
-- keys from that set.  They are read as filter expressions with
-- @--filter@, otherwise as fingerprint, key ID or user-ID substring.
-- Keys failing verification at the current time are left out of the
-- graph.  Every edge has weight 1, so a shortest path has the fewest
-- certification hops.
--
-- Output: the list of paths (node-number lists), then one line per node
-- that occurs on any path, with its fingerprint where known.
doFindPaths :: Options -> IO ()
doFindPaths o = do
    cpt <- getPOSIXTime
    kr <- grabMatchingKeysKeyring (keyring o) (targetIsFilter o) (target1 o)
    let candidates = IxSet.toList kr
        selectBy srch = filter (matcher srch) candidates
        keys1 = selectBy (target2 o)
        keys2 = selectBy (target3 o)
        verified = rights (map (verifyTKWith (verifySigWith (verifyAgainstKeyring kr)) (Just (posixSecondsToUTCTime cpt))) candidates)
        maps@(KeyMaps _ f2i i2f, _) = buildMaps verified
        -- Built once: unit weights do not depend on the endpoints, so this
        -- must not be recomputed for every (from, to) pair.
        unitGraph = emap (const (1.0 :: Double)) (buildKeyGraph (maps, verified))
        keysToIs = mapMaybe (\x -> HashMap.lookup (fingerprint (x^.tkKey._1)) f2i)
        combos = [ (f, t) | f <- keysToIs keys1, t <- keysToIs keys2 ]
        paths = map (\(x, y) -> sp x y unitGraph) combos
        describe n = maybe (show n) (\fp -> show (n, fp)) (HashMap.lookup n i2f)
    print paths
    putStrLn . unlines $ map describe (nub (sort (concat paths)))
    where  -- FIXME: deduplicate this (same matching rules as grabMatchingKeysConduit)
        matcher srch
            | targetIsFilter o = filterMatch srch
            | otherwise = matchAny srch
        -- The parses of @srch@ are bound once per search string and shared
        -- by every key the returned predicate is applied to.
        matchAny srch =
            let parsedFp = parseFingerprint (T.pack srch)
                parsedKeyId = parseEightOctetKeyId (T.pack srch)
            in \tk -> either (const False) id $
                          fmap (keyMatchesFingerprint True tk) parsedFp
                      <|> fmap (keyMatchesEightOctetKeyId True tk . Right) parsedKeyId
                      <|> return (keyMatchesUIDSubString srch tk)
        filterMatch srch =
            let expr = either error id (A.parseOnly pPE (T.pack srch))
            in \tk -> eval pkpEval expr (tk^.tkKey._1)

-- FIXME: deduplicate the following code
-- | Evaluate a boolean filter expression against a subject.
--
-- 'EAny' is always true; 'EAnd', 'EOr' and 'ENot' combine the results of
-- their sub-expressions.  A leaf @E p@ is decided by applying the supplied
-- test to @p@ and the subject.
eval :: (a -> v -> Bool) -> Expr a -> v -> Bool
eval leafTest expr subject = go expr
  where
        go node = case node of
            EAny -> True
            EAnd l r -> go l && go r
            EOr l r -> go l || go r
            ENot inner -> not (go inner)
            E leaf -> leafTest leaf subject

-- | Decide one filter predicate against a public-key payload.
--
-- The attribute named on the left-hand side is extracted from the payload
-- and compared with the right-hand value using the predicate's operator.
-- Extracted values:
--
-- * version: 3 or 4,
-- * algorithm,
-- * key size: 0 when unknown,
-- * creation timestamp.
pkpEval :: PKPPredicate -> PKPayload -> Bool
pkpEval (PKPPredicate lhs op rhs) pkp = compareWith op (valueOf lhs) rhs
    where
        compareWith PKEquals = (==)
        compareWith PKLessThan = (<)
        compareWith PKGreaterThan = (>)
        valueOf var = case var of
            PKPVVersion -> PKPInt (versionNumber (_keyVersion pkp))
            PKPVPKA -> PKPPKA (_pkalgo pkp)
            PKPVKeysize -> PKPInt (either (const 0) id (pubkeySize (_pubkey pkp))) -- FIXME: this should be smarter
            PKPVTimestamp -> PKPInt (fromIntegral (_timestamp pkp))
        versionNumber DeprecatedV3 = 3
        versionNumber V4 = 4

-- | Key ID of the key that issued a signature.
--
-- For a V4 signature this is the first 'Issuer' subpacket found in its two
-- subpacket lists, searched in field order.  Every other signature version
-- (V3, version-2 'SigVOther', and any unrecognized version) yields
-- 'Nothing'.  'buildKeyGraph' then skips the signature instead of the
-- whole command aborting on an unknown version.
sigissuer :: SignaturePayload -> Maybe EightOctetKeyId
sigissuer (SigV4 _ _ _ ys xs _ _) = listToMaybe . mapMaybe (getIssuer . _sspPayload) $ (ys++xs) -- FIXME: what should this be if there are multiple matches?
sigissuer _ = Nothing

-- | The key ID carried by an 'Issuer' subpacket; 'Nothing' for any other
-- subpacket.
getIssuer :: SigSubPacketPayload -> Maybe EightOctetKeyId
getIssuer (Issuer i) = Just i
getIssuer _ = Nothing

-- | Hash algorithm of a V4 signature.  Partial: other versions raise an
-- error.  'buildKeyGraph' only forces it for signatures where 'sigissuer'
-- returned a key ID, and those are V4.
hashAlgo :: SignaturePayload -> HashAlgorithm
hashAlgo (SigV4 _ _ x _ _ _ _) = x
hashAlgo _ = error "V3 sig not supported here"

