diff --git a/examples/main/iris.hs b/examples/main/iris.hs index c5dc64b2..d1485030 100644 --- a/examples/main/iris.hs +++ b/examples/main/iris.hs @@ -7,6 +7,7 @@ {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE OverloadedStrings #-} module Main where @@ -15,7 +16,7 @@ import Control.Monad import Control.Monad.Random import Data.List ( foldl' ) -import Data.Maybe ( fromMaybe , mapMaybe) +import Data.Maybe ( mapMaybe ) #if ! MIN_VERSION_base(4,13,0) import Data.Semigroup ( (<>) ) #endif @@ -32,7 +33,6 @@ import System.FilePath ( () ) import System.Random.Shuffle (shuffleM) import GHC.Generics (Generic) import GHC.Float ( float2Double) -import GHC.Int import Grenade import Grenade.Utils.OneHot @@ -45,70 +45,92 @@ import Grenade.Utils.OneHot -- This network is used to show how we can embed a Network as a layer in the larger IrisNetwork -- type. -type IrisNetwork = Network - '[FullyConnected 4 10, Relu, FullyConnected 10 3] - '[ 'D1 4, 'D1 10, 'D1 10, 'D1 3] +type IrisNetwork + = Network + '[FullyConnected 4 10, Relu, FullyConnected 10 3] + '[ 'D1 4, 'D1 10, 'D1 10, 'D1 3] -type IrisRow = (S ('D1 4), S ('D1 3)) +type IrisRow = (S ( 'D1 4), S ( 'D1 3)) randomIris :: MonadRandom m => m IrisNetwork randomIris = randomNetwork -runIris :: Int -> FilePath -> Maybe Int -> LearningParameters -> IO () +runIris :: Int -> FilePath -> Int -> LearningParameters -> IO () runIris iterations dataDir nSamples rate = do - trainRecords <- readIrisFromFile (dataDir "iris.data") - validateRecords <- readIrisFromFile (dataDir "iris.names") + records <- readIrisFromFile (dataDir "iris.data") + let numRecords = V.length records + shuffledRecords <- chooseRandomRecords records numRecords + let (trainRecords, validateRecords) = V.splitAt nSamples shuffledRecords let trainData = mapMaybe parseRecord (V.toList trainRecords) let validateData = mapMaybe parseRecord (V.toList validateRecords) - if length trainData /= length trainRecords || 
length validateData /= length validateRecords - then putStrLn "Parsing train data or validation data could not be fully parsed" + if length trainData + /= length trainRecords + || length validateData + /= length validateRecords + then putStrLn + "Parsing train data or validation data could not be fully parsed" else do initialNetwork <- randomIris - foldM_ (runIteration (maybe trainData (`take` trainData) nSamples) validateData) - initialNetwork [1..iterations] - where - trainEach rate' !network (i, o) = train rate' network i o - - runIteration trainRows validateRows net i = do - let trained' = foldl' - (trainEach (rate { learningRate = learningRate rate * 0.9 ^ i })) - net - trainRows - print trained' - - putStrLn "Checking..." - let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows - let res' = fmap (\(S1D label, S1D prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res - let matched = length $ filter ((==) <$> fst <*> snd) res' - let total = length res' - let matchedpc = fromIntegral matched / fromIntegral total * 100.0 :: Float - putStrLn $ "Iteration " ++ show i ++ ": matched " ++ show matched ++ " of " ++ show total ++ " (" ++ show matchedpc ++ "%)" - return trained' - -data IrisOpts = IrisOpts FilePath (Maybe Int) Int LearningParameters + foldM_ (run trainData validateData) initialNetwork [1 .. 
iterations] + where + + run :: [IrisRow] -> [IrisRow] -> IrisNetwork -> Int -> IO IrisNetwork + run trainData validateData network iterationNum = do + sampledData <- V.toList + <$> chooseRandomRecords (V.fromList trainData) (nSamples * 3 `div` 4) + -- Slowly drop the learning rate + let rate' = rate { learningRate = learningRate rate * 0.99 ^ iterationNum } + let newNetwork = foldl' (trainRow rate') network sampledData + let labelVectors = fmap (testRow newNetwork) validateData + let labelValues = fmap getLabels labelVectors + let total = length labelValues + let correctEntries = length $ filter ((==) <$> fst <*> snd) labelValues + putStrLn $ "Iteration: " ++ show iterationNum + putStrLn $ show correctEntries ++ " correct out of: " ++ show total + return newNetwork + + trainRow :: LearningParameters -> IrisNetwork -> IrisRow -> IrisNetwork + trainRow lp network (input, output) = train lp network input output + + -- Takes a test row, returns the row's label vector paired with the network's output for that row. 
+ testRow :: IrisNetwork -> IrisRow -> (S ( 'D1 3), S ( 'D1 3)) + testRow net (rowInput, predictedOutput) = + (predictedOutput, runNet net rowInput) + + -- Goes from probability output vector to label + getLabels :: (S ( 'D1 3), S ( 'D1 3)) -> (Int, Int) + getLabels (S1D predictedLabel, S1D actualOutput) = + (maxIndex (SA.extract predictedLabel), maxIndex (SA.extract actualOutput)) + + + +data IrisOpts = IrisOpts FilePath Int Int LearningParameters iris' :: Parser IrisOpts -iris' = IrisOpts <$> argument str (metavar "DATADIR") - -- option to reduce the number of train samples used from 60,000 - -- to avoid running out of memory - <*> option (Just <$> auto) (long "limit_samples_to" <> short 'l' <> value Nothing) - <*> option auto (long "iterations" <> short 'i' <> value 15) - <*> (LearningParameters - <$> option auto (long "train_rate" <> short 'r' <> value 0.01) - <*> option auto (long "momentum" <> value 0.9) - <*> option auto (long "l2" <> value 0.0005) - ) +iris' = + IrisOpts + <$> argument str (metavar "DATADIR") + -- How many samples from the dataset should be used for training? + -- (The rest are used for validation) + <*> option auto (long "training_samples" <> short 't' <> value 100) + <*> option auto (long "iterations" <> short 'i' <> value 15) + <*> ( LearningParameters + <$> option auto (long "train_rate" <> short 'r' <> value 0.01) + <*> option auto (long "momentum" <> value 0.9) + <*> option auto (long "l2" <> value 0.0005) + ) main :: IO () main = do - IrisOpts dataDir nSamples iter rate <- execParser (info (iris' <**> helper) idm) - putStr "Training convolutional neural network with " - putStr $ maybe "all" show nSamples - putStrLn " samples..." + IrisOpts dataDir nSamples iter rate <- execParser + (info (iris' <**> helper) idm) + putStr "Training convolutional neural network with " + putStr $ show nSamples + putStrLn " samples..." 
- runIris iter dataDir nSamples rate + runIris iter dataDir nSamples rate data IrisClass = Setosa | Versicolor | Virginica deriving (Show, Read, Eq, Ord, Generic, Enum, Bounded) @@ -124,10 +146,10 @@ data IrisRecord = IrisRecord { instance FromRecord IrisRecord instance FromField IrisClass where - parseField "Iris-setosa" = return Setosa + parseField "Iris-setosa" = return Setosa parseField "Iris-versicolor" = return Versicolor - parseField "Iris-virginica" = return Virginica - parseField _ = fail "unknown iris class" + parseField "Iris-virginica" = return Virginica + parseField _ = fail "unknown iris class" parseRecord :: IrisRecord -> Maybe IrisRow parseRecord record = case (input, output) of @@ -142,7 +164,6 @@ parseRecord record = case (input, output) of , sepalWidth record / 8.0 , petalLength record / 8.0 , petalWidth record / 8.0 - , specie record ] output = oneHot (fromEnum $ specie record)