Skip to content

Commit

Permalink
Add Iris example
Browse files Browse the repository at this point in the history
  • Loading branch information
jrp2014 committed Feb 2, 2020
1 parent f7b53b0 commit 6ab97a4
Showing 1 changed file with 73 additions and 52 deletions.
125 changes: 73 additions & 52 deletions examples/main/iris.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}

module Main where

Expand All @@ -15,7 +16,7 @@ import Control.Monad
import Control.Monad.Random

import Data.List ( foldl' )
import Data.Maybe ( fromMaybe , mapMaybe)
import Data.Maybe ( mapMaybe )
#if ! MIN_VERSION_base(4,13,0)
import Data.Semigroup ( (<>) )
#endif
Expand All @@ -32,7 +33,6 @@ import System.FilePath ( (</>) )
import System.Random.Shuffle (shuffleM)
import GHC.Generics (Generic)
import GHC.Float ( float2Double)
import GHC.Int

import Grenade
import Grenade.Utils.OneHot
Expand All @@ -45,70 +45,92 @@ import Grenade.Utils.OneHot
-- This network is used to show how we can embed a Network as a layer in the larger IrisNetwork
-- type.

-- | Feed-forward network for the Iris data: a 4 -> 10 fully connected layer,
-- a Relu, then a 10 -> 3 fully connected layer.
-- NOTE(review): the alias is written out twice below (a one-line head and a
-- reflowed head); both list the same layers and the same shapes.
type IrisNetwork = Network
'[FullyConnected 4 10, Relu, FullyConnected 10 3]
'[ 'D1 4, 'D1 10, 'D1 10, 'D1 3]
type IrisNetwork
= Network
'[FullyConnected 4 10, Relu, FullyConnected 10 3]
'[ 'D1 4, 'D1 10, 'D1 10, 'D1 3]

-- | One data row: a 4-element input vector paired with a 3-element target
-- vector.
type IrisRow = (S ('D1 4), S ('D1 3))
type IrisRow = (S ( 'D1 4), S ( 'D1 3))

-- | Produce an 'IrisNetwork' by running 'randomNetwork' in any 'MonadRandom'
-- monad.
randomIris :: MonadRandom m => m IrisNetwork
randomIris = randomNetwork

-- | Train an 'IrisNetwork' on records read from @dataDir@, making one pass per
-- entry of @[1 .. iterations]@ and printing validation results after each.
-- NOTE(review): two variants are interleaved below. One takes a @Maybe Int@
-- sample limit and reads @iris.data@ for training and @iris.names@ for
-- validation; the other takes a plain @Int@ and splits the records of
-- @iris.data@ at @nSamples@.
runIris :: Int -> FilePath -> Maybe Int -> LearningParameters -> IO ()
runIris :: Int -> FilePath -> Int -> LearningParameters -> IO ()
runIris iterations dataDir nSamples rate = do
trainRecords <- readIrisFromFile (dataDir </> "iris.data")
validateRecords <- readIrisFromFile (dataDir </> "iris.names")
records <- readIrisFromFile (dataDir </> "iris.data")
let numRecords = V.length records
-- Ask for as many records as the file holds; presumably this yields a
-- shuffled copy (chooseRandomRecords is defined outside this view - confirm).
shuffledRecords <- chooseRandomRecords records numRecords
-- The first nSamples records become training data, the remainder validation.
let (trainRecords, validateRecords) = V.splitAt nSamples shuffledRecords

-- mapMaybe drops every record for which parseRecord returns Nothing ...
let trainData = mapMaybe parseRecord (V.toList trainRecords)
let validateData = mapMaybe parseRecord (V.toList validateRecords)

-- ... so a length mismatch means at least one record was rejected; in that
-- case only the message is printed and no training happens.
if length trainData /= length trainRecords || length validateData /= length validateRecords
then putStrLn "Parsing train data or validation data could not be fully parsed"
if length trainData
/= length trainRecords
|| length validateData
/= length validateRecords
then putStrLn
"Parsing train data or validation data could not be fully parsed"
else do
initialNetwork <- randomIris
-- Maybe-variant: Nothing trains on all rows, Just n on the first n rows.
foldM_ (runIteration (maybe trainData (`take` trainData) nSamples) validateData)
initialNetwork [1..iterations]
where
-- The bang pattern forces the network argument before each training step.
trainEach rate' !network (i, o) = train rate' network i o

runIteration trainRows validateRows net i = do
-- The learning rate is scaled by 0.9 ^ i for iteration i.
let trained' = foldl'
(trainEach (rate { learningRate = learningRate rate * 0.9 ^ i }))
net
trainRows
print trained'

putStrLn "Checking..."
-- Pair each validation row's label vector with the network output, then
-- reduce both to the index of their largest entry and count agreements.
let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows
let res' = fmap (\(S1D label, S1D prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
let matched = length $ filter ((==) <$> fst <*> snd) res'
let total = length res'
let matchedpc = fromIntegral matched / fromIntegral total * 100.0 :: Float
putStrLn $ "Iteration " ++ show i ++ ": matched " ++ show matched ++ " of " ++ show total ++ " (" ++ show matchedpc ++ "%)"
return trained'

-- Option record of the Maybe-variant: data directory, optional sample limit,
-- iteration count, learning parameters.
data IrisOpts = IrisOpts FilePath (Maybe Int) Int LearningParameters
-- Int-variant: thread the network through iterations 1 .. iterations; foldM_
-- discards the final network.
foldM_ (run trainData validateData) initialNetwork [1 .. iterations]
where

-- One training pass. Draws (nSamples * 3) `div` 4 rows from the training data
-- via chooseRandomRecords, folds them through the network with the learning
-- rate scaled by 0.99 ^ epoch (a gentle per-iteration decay), then prints how
-- many validation rows have matching target and output labels.
run :: [IrisRow] -> [IrisRow] -> IrisNetwork -> Int -> IO IrisNetwork
run trainRows validationRows net0 epoch = do
  let subsetSize = nSamples * 3 `div` 4
  batch <- chooseRandomRecords (V.fromList trainRows) subsetSize
  let decayed  = rate { learningRate = learningRate rate * 0.99 ^ epoch }
      trained  = foldl' (trainRow decayed) net0 (V.toList batch)
      -- Each outcome is (label from the data, label chosen by the network).
      outcomes = map (getLabels . testRow trained) validationRows
      hits     = length [ () | (expected, got) <- outcomes, expected == got ]
  putStrLn ("Iteration: " ++ show epoch)
  putStrLn (show hits ++ " correct out of: " ++ show (length outcomes))
  pure trained

-- Run 'train' on the network for a single (input, target) row using the given
-- learning parameters, returning the updated network.
trainRow :: LearningParameters -> IrisNetwork -> IrisRow -> IrisNetwork
trainRow params net (features, target) = train params net features target

-- Pairs a row's target vector (taken from the data) with the network's output
-- for that row's input. The target comes first in the result.
testRow :: IrisNetwork -> IrisRow -> (S ( 'D1 3), S ( 'D1 3))
testRow net (features, target) = (target, runNet net features)

-- Reduces a (target, network output) pair of vectors to the index of the
-- largest entry in each, so the two can be compared as class labels.
getLabels :: (S ( 'D1 3), S ( 'D1 3)) -> (Int, Int)
getLabels (S1D target, S1D netOutput) = (labelOf target, labelOf netOutput)
  where
    labelOf v = maxIndex (SA.extract v)



-- | Parsed command-line options, in order: data directory, number of training
-- samples, number of iterations, learning parameters.
data IrisOpts = IrisOpts FilePath Int Int LearningParameters

-- | Command-line parser for 'IrisOpts': a positional DATADIR argument followed
-- by optional flags with defaults (iterations 15, train_rate 0.01,
-- momentum 0.9, l2 0.0005).
-- NOTE(review): two variants of the body follow. They differ only in the
-- sample flag: "limit_samples_to" (short l, default Nothing) in the first,
-- "training_samples" (short t, default 100) in the second.
iris' :: Parser IrisOpts
iris' = IrisOpts <$> argument str (metavar "DATADIR")
-- option to reduce the number of train samples used from 60,000
-- to avoid running out of memory
<*> option (Just <$> auto) (long "limit_samples_to" <> short 'l' <> value Nothing)
<*> option auto (long "iterations" <> short 'i' <> value 15)
<*> (LearningParameters
<$> option auto (long "train_rate" <> short 'r' <> value 0.01)
<*> option auto (long "momentum" <> value 0.9)
<*> option auto (long "l2" <> value 0.0005)
)
iris' =
IrisOpts
<$> argument str (metavar "DATADIR")
-- How many samples from the dataset should be used for training?
-- (The rest are used for validation)
<*> option auto (long "training_samples" <> short 't' <> value 100)
<*> option auto (long "iterations" <> short 'i' <> value 15)
<*> ( LearningParameters
<$> option auto (long "train_rate" <> short 'r' <> value 0.01)
<*> option auto (long "momentum" <> value 0.9)
<*> option auto (long "l2" <> value 0.0005)
)

-- | Parse the command line with @iris'@, print the sample count, then hand the
-- options to 'runIris'.
main :: IO ()
main = do
IrisOpts dataDir nSamples iter rate <- execParser (info (iris' <**> helper) idm)
-- NOTE(review): the message says "convolutional", but the IrisNetwork type in
-- this file lists only FullyConnected and Relu layers - confirm the wording.
putStr "Training convolutional neural network with "
-- Maybe-variant: prints "all" when no sample limit was given.
putStr $ maybe "all" show nSamples
putStrLn " samples..."
IrisOpts dataDir nSamples iter rate <- execParser
(info (iris' <**> helper) idm)
putStr "Training convolutional neural network with "
putStr $ show nSamples
putStrLn " samples..."

-- Note the argument order: iterations first, then the data directory.
runIris iter dataDir nSamples rate
runIris iter dataDir nSamples rate

-- | The three iris classes.
-- 'parseRecord' passes 'fromEnum' of a record's species to 'oneHot', so -
-- assuming that field is an 'IrisClass' (its declaration is not fully visible
-- here) - the constructor order determines the one-hot index.
data IrisClass = Setosa | Versicolor | Virginica
deriving (Show, Read, Eq, Ord, Generic, Enum, Bounded)
Expand All @@ -124,10 +146,10 @@ data IrisRecord = IrisRecord {
-- No methods are defined here, so the class's default implementation is used.
instance FromRecord IrisRecord

-- | Decode the class-name field into an 'IrisClass'; any value other than the
-- three listed names fails with "unknown iris class".
-- NOTE(review): matching the field against string literals presumably relies
-- on the OverloadedStrings pragma enabled at the top of this file - confirm.
-- Some clauses appear twice below; the repeats are identical.
instance FromField IrisClass where
parseField "Iris-setosa" = return Setosa
parseField "Iris-setosa" = return Setosa
parseField "Iris-versicolor" = return Versicolor
parseField "Iris-virginica" = return Virginica
parseField _ = fail "unknown iris class"
parseField "Iris-virginica" = return Virginica
parseField _ = fail "unknown iris class"

parseRecord :: IrisRecord -> Maybe IrisRow
parseRecord record = case (input, output) of
Expand All @@ -142,7 +164,6 @@ parseRecord record = case (input, output) of
, sepalWidth record / 8.0
, petalLength record / 8.0
, petalWidth record / 8.0
, specie record
]
output = oneHot (fromEnum $ specie record)

Expand Down

0 comments on commit 6ab97a4

Please sign in to comment.