-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
464f4d1
commit 1f7ed45
Showing
8 changed files
with
1,649 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
#I "./packages" | ||
#r @"FSharp.Data.2.2.2\lib\net40\FSharp.Data.dll" | ||
open FSharp.Data | ||
|
||
// CSV type provider over "titanic.csv"; column names and types come from that file.
type Titanic = CsvProvider<"titanic.csv"> | ||
// Shorthand for a single row of the Titanic data set.
type Passenger = Titanic.Row | ||
|
||
// Data set used by every computation below (the provider's sample file).
let dataset = Titanic.GetSample () | ||
|
||
// Shannon entropy (natural logarithm) of the labels found in a sample.
//   label : maps an observation to the label whose distribution is measured
//   data  : the observations; enumerated twice (size, then label counts)
// Returns 0.0 for an empty sample or when every observation has the same label.
let entropy label data =
    let total = float (Seq.length data)
    // contribution of one label, given the share of the sample that carries it
    let term proportion =
        if proportion > 0. then - proportion * log proportion else 0.
    Seq.countBy label data
    |> Seq.sumBy (fun (_, occurrences) -> term (float occurrences / total))
|
||
// chart examples | ||
|
||
// Worked examples: each integer acts as its own label (label function = id).
// ex1 has three equally frequent values; ex2 has a single value, hence entropy 0.
let ex1 = entropy id [1;1;1;2;2;2;3;3;3]
let ex2 = entropy id [1;1;1;1;1;1;1;1;1]
let ex3 = entropy id [1;1;1;1;1;2;2;3;3]
let ex4 = entropy id [1;1;1;1;1;1;1;2;3]
|
||
|
||
// average entropy if break by feature | ||
|
||
// Turns an option-returning feature extractor into a predicate that is true
// when the feature is present (Some) for the given observation.
let hasData feature = fun observation -> feature observation |> Option.isSome
|
||
// Average label entropy after partitioning the sample by a feature.
//   extractLabel   : observation -> label
//   extractFeature : observation -> feature value option (None = missing)
//   data           : the observations
// Observations whose feature is missing are left out. Each remaining group
// contributes its own entropy weighted by its share of the retained sample.
let splitEntropy extractLabel extractFeature data =
    // only observations that have a value for the selected feature
    let known = Seq.filter (hasData extractFeature) data
    let total = Seq.length known
    let weightedEntropy (_, members) =
        let share = float (Seq.length members) / float total
        let inner = entropy extractLabel members
        share * inner
    known
    |> Seq.groupBy extractFeature
    |> Seq.sumBy weightedEntropy
|
||
|
||
// compare features on entire set | ||
|
||
// Label: did the passenger survive?
let survived (passenger:Passenger) = passenger.Survived

// Features wrapped in Some so they fit the option-based splitEntropy API.
let sex (passenger:Passenger) = passenger.Sex |> Some
let pclass (passenger:Passenger) = passenger.Pclass |> Some
// Port of embarkation; an empty string in the data is reported as missing (None).
let port (p:Passenger) =
    match p.Embarked with
    | "" -> None
    | harbour -> Some(harbour)
// Age bucket of a passenger: "Younger" below 12, "Older" otherwise.
// A missing age shows up as NaN in the float column. Since `NaN < 12.0` is
// false, such passengers used to be labelled "Older"; they are now reported
// as None so that splitEntropy ignores them, just as it does for a missing port.
let age (p:Passenger) =
    if System.Double.IsNaN p.Age then None
    elif p.Age < 12.0 then Some("Younger")
    else Some("Older")
|
||
// Entropy of the survival label on the whole data set, followed by the
// entropy that remains after splitting on each candidate feature.
printfn "Comparison: most informative feature"
let h = entropy survived dataset.Rows
printfn "Base entropy %.3f" h

printfn " Sex: %.3f" (splitEntropy survived sex dataset.Rows)
printfn " Class: %.3f" (splitEntropy survived pclass dataset.Rows)
printfn " Port: %.3f" (splitEntropy survived port dataset.Rows)
printfn " Age: %.3f" (splitEntropy survived age dataset.Rows)
|
||
// procedure can then be repeated... | ||
|
||
// Passengers grouped by sex; the same comparison is then run inside each group.
let bySex = Seq.groupBy sex dataset.Rows

bySex
|> Seq.iter (fun (groupName, group) ->
    printfn "Group: %s" groupName.Value
    let groupEntropy = entropy survived group
    printfn "Base entropy %.3f" groupEntropy

    printfn " Sex: %.3f" (splitEntropy survived sex group)
    printfn " Class: %.3f" (splitEntropy survived pclass group)
    printfn " Port: %.3f" (splitEntropy survived port group)
    printfn " Age: %.3f" (splitEntropy survived age group))
|
||
|
||
// analyzing a list of features | ||
|
||
// Prints the split entropy for a list of named features.
let test () =

    let survived (p:Passenger) = p.Survived

    // Both extractors return string option, so they share one type
    // and can live in the same list.
    let sex (p:Passenger) = Some(p.Sex)
    let pclass (p:Passenger) = Some(string p.Pclass)

    let features = [ "Sex", sex; "Class", pclass ]

    for (name, feat) in features do
        printfn "%s: %.3f" name (splitEntropy survived feat dataset.Rows)

test ()
|
||
|
||
// using entropy to partition continuous features | ||
|
||
// Searches for the age threshold that minimises the split entropy of the
// survival label, trying every distinct known age as a "Kid"/"Adult" cut-off.
//
// A missing age is NaN in the float column. Previously NaN could be tried as
// a threshold, and passengers without an age were counted as "Adult" because
// `NaN < a` is false, which biased every candidate split. Unknown ages are now
// skipped as thresholds and reported as None so that splitEntropy ignores them.
let bestAge () =

    let isKnown (value: float) = not (System.Double.IsNaN value)

    // candidate thresholds: every distinct age that is actually present
    let ages =
        dataset.Rows
        |> Seq.map (fun p -> p.Age)
        |> Seq.filter isKnown
        |> Seq.distinct
        |> Seq.toList

    printfn "Best age split"
    match ages with
    | [] ->
        // Seq.minBy would raise on an empty candidate list
        printfn "Age: no known ages"
    | _ ->
        let best =
            ages
            |> Seq.minBy (fun a ->
                let age (p:Passenger) =
                    if not (isKnown p.Age) then None
                    elif p.Age < a then Some("Kid")
                    else Some("Adult")
                dataset.Rows |> splitEntropy survived age)
        printfn "Age: %.3f" best

bestAge ()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
#r @".\packages\FSharp.Data.2.2.2\lib\net40\FSharp.Data.dll" | ||
open FSharp.Data | ||
|
||
// CSV type provider over "titanic.csv"; column names and types come from that file.
type Titanic = CsvProvider<"titanic.csv"> | ||
// Shorthand for a single row of the Titanic data set.
type Passenger = Titanic.Row | ||
|
||
// Data set used by every computation below (the provider's sample file).
let dataset = Titanic.GetSample () | ||
|
||
(* | ||
Probability of surviving | ||
*) | ||
|
||
// Column names of the CSV file (.Value raises if the provider exposes no headers).
for header in dataset.Headers.Value do
    printfn "%A" header

// How many passengers survived / did not survive.
for outcome in dataset.Rows |> Seq.countBy (fun passenger -> passenger.Survived) do
    printfn "%A" outcome

// Overall survival frequency: survivors count as 1.0, everyone else as 0.0.
dataset.Rows
|> Seq.map (fun passenger -> if passenger.Survived then 1.0 else 0.0)
|> Seq.average
|> printfn "Chances of survival: %.3f"
|
||
|
||
(* | ||
Possible predictors: lead towards | ||
entropy, and bring up missing values. | ||
*) | ||
|
||
// Percentage (0-100) of the given passengers who survived.
// The sequence is enumerated twice: once for the head count, once for survivors.
let survivalRate (passengers:Passenger seq) =
    let headCount = Seq.length passengers
    let survivorCount =
        passengers
        |> Seq.sumBy (fun p -> if p.Survived then 1 else 0)
    100.0 * (float survivorCount / float headCount)
|
||
// Passengers grouped by sex, then the survival rate of each group.
let bySex = dataset.Rows |> Seq.groupBy (fun p -> p.Sex)

for (s, g) in bySex do
    printfn "Sex %A: %f" s (survivalRate g)
|
||
// Passengers grouped by ticket class, then the survival rate of each group.
let byClass = dataset.Rows |> Seq.groupBy (fun p -> p.Pclass)

for (s, g) in byClass do
    printfn "Class %A: %f" s (survivalRate g)
|
||
|
||
(* | ||
Learning basic decision stumps | ||
*) | ||
|
||
// Given (feature, label) pairs, returns the label that occurs most often.
// On a tie the label encountered first wins; an empty group raises (Seq.maxBy).
let mostFrequentLabelIn group =
    let (winner, _) =
        group
        |> Seq.countBy (fun (_, label) -> label)
        |> Seq.maxBy (fun (_, occurrences) -> occurrences)
    winner
|
||
// Learns a decision stump: a classifier that predicts, for an observation,
// the most frequent label among training observations sharing its feature value.
//   sample         : training observations
//   extractFeature : observation -> feature value (needs equality)
//   extractLabel   : observation -> label
// Returns the classifier function. It raises KeyNotFoundException when asked
// about a feature value that never appeared in the sample.
let learn sample extractFeature extractLabel =
    // group together observations that have the
    // same value for the selected feature, and
    // find the most frequent label by group.
    // The table is materialised once with Seq.toArray: left as a lazy seq
    // pipeline, every prediction would re-enumerate and regroup the whole sample.
    let groups =
        sample
        |> Seq.map (fun obs -> extractFeature obs, extractLabel obs)
        |> Seq.groupBy fst
        |> Seq.map (fun (feat,group) -> feat, mostFrequentLabelIn group)
        |> Seq.toArray
    // for an observation, find the group with
    // matching feature value, and predict the
    // most frequent label for that group.
    let classifier obs =
        let featureValue = extractFeature obs
        groups
        |> Array.find (fun (f,_) -> f = featureValue)
        |> snd
    classifier
|
||
|
||
// Label extractor: did the passenger survive?
let survived (passenger:Passenger) = passenger.Survived
// Feature extractor: the passenger's sex as recorded in the data.
let sex (passenger:Passenger) = passenger.Sex
|
||
// Stump that predicts survival from the passenger's sex.
let sexClassifier = learn dataset.Rows sex survived

// checking predictions on a couple observations
for p in dataset.Rows |> Seq.take 10 do
    printfn "Real: %A, Pred: %A" p.Survived (sexClassifier p)

printfn "Stump: classify based on passenger sex."
// Share of passengers whose actual outcome matches the prediction.
dataset.Rows
|> Seq.map (fun p -> if p.Survived = sexClassifier p then 1.0 else 0.0)
|> Seq.average
|
||
printfn "Stump: classify based on passenger class."
// Stump that predicts survival from the ticket class.
let classClassifier = learn dataset.Rows (fun (p:Passenger) -> p.Pclass) survived

// Share of passengers whose actual outcome matches the prediction.
dataset.Rows
|> Seq.map (fun p -> if p.Survived = classClassifier p then 1.0 else 0.0)
|> Seq.average
|
||
|
||
(* | ||
Continuous / numeric features | ||
*) | ||
|
||
// Prints the survival rate for every distinct fare, cheapest first.
// The binding itself is unit: the printing happens when it is evaluated.
let survivalByPricePaid =
    let byFare =
        dataset.Rows
        |> Seq.groupBy (fun p -> p.Fare)
        |> Seq.sortBy fst
    for (price, passengers) in byFare do
        printfn "%6.2F: %6.2f" price (survivalRate passengers)
|
||
// how many cases are there? (number of distinct fare values)
dataset.Rows
|> Seq.distinctBy (fun p -> p.Fare)
|> Seq.length
|
||
// transforming fare into discrete "bins" | ||
|
||
// Mean fare over all passengers; used as the cut-off between the two fare bins.
let averageFare =
    dataset.Rows
    |> Seq.map (fun p -> p.Fare)
    |> Seq.average
|
||
// Discretises the fare: "Cheap" when strictly below the average fare,
// "Expensive" otherwise.
let fareLevel (p:Passenger) =
    match p.Fare < averageFare with
    | true -> "Cheap"
    | false -> "Expensive"
|
||
printfn "Stump: classify based on fare level."
// Stump that predicts survival from the fare bin.
let fareClassifier = learn dataset.Rows fareLevel survived

// Share of passengers whose actual outcome matches the prediction.
dataset.Rows
|> Seq.map (fun p -> if p.Survived = fareClassifier p then 1.0 else 0.0)
|> Seq.average
|
||
|
||
|
||
(* | ||
Missing values | ||
*) | ||
|
||
// Prints the survival rate per port of embarkation, including the group
// whose port is the empty string.
// The binding itself is unit: the printing happens when it is evaluated.
let survivalByPortOfOrigin =
    let byPort = dataset.Rows |> Seq.groupBy (fun p -> p.Embarked)
    for (port, passengers) in byPort do
        printfn "%1s: %6.2f" port (survivalRate passengers)
|
||
// Number of passengers per embarkation value.
Seq.countBy (fun (p:Passenger) -> p.Embarked) dataset.Rows
|
||
|
||
// Using Option<'a> to represent missing values | ||
|
||
// Turns an option-returning feature extractor into a predicate that is true
// when the feature is present (Some) for the given observation.
let hasData extractFeature = fun obs -> extractFeature obs |> Option.isSome
|
||
// Decision stump that tolerates missing feature values.
//   sample         : training observations
//   extractFeature : observation -> feature value option (None = missing)
//   extractLabel   : observation -> label
// Returns a classifier. When the feature is missing, or its value was never
// seen in training, it falls back to the most frequent label of the whole sample.
let betterLearn sample extractFeature extractLabel =
    // feature value -> most frequent label, built from observations with a value
    let branches =
        sample
        |> Seq.filter (hasData extractFeature)
        |> Seq.map (fun obs -> Option.get (extractFeature obs), extractLabel obs)
        |> Seq.groupBy fst
        |> Seq.map (fun (feat, group) -> feat, mostFrequentLabelIn group)
        |> Map.ofSeq
    // fallback prediction: the most frequent label over the entire sample
    let labelForMissingValues =
        sample
        |> Seq.countBy extractLabel
        |> Seq.maxBy snd
        |> fst
    fun obs ->
        let lookedUp =
            extractFeature obs
            |> Option.bind (fun value -> Map.tryFind value branches)
        match lookedUp with
        | Some(predictedLabel) -> predictedLabel
        | None -> labelForMissingValues
|
||
// Port of embarkation; an empty string in the data is reported as missing (None).
let port (p:Passenger) =
    match p.Embarked with
    | "" -> None
    | harbour -> Some(harbour)
|
||
// Stump on the port of embarkation, with missing ports handled by betterLearn.
let updatedClassifier = betterLearn dataset.Rows port survived

// Share of passengers whose actual outcome matches the prediction.
dataset.Rows
|> Seq.map (fun p -> if p.Survived = updatedClassifier p then 1.0 else 0.0)
|> Seq.average
Oops, something went wrong.