diff --git a/Sources/EpsilonGreedy.swift b/Sources/EpsilonGreedy.swift index 26010ef..50133af 100644 --- a/Sources/EpsilonGreedy.swift +++ b/Sources/EpsilonGreedy.swift @@ -1,22 +1,22 @@ import Foundation -struct EpsilonGreedy { - let epsilon: Float - let counts: [Int] - let values: [Float] +public struct EpsilonGreedy { + public let epsilon: Float + public let counts: [Int] + public let values: [Float] - init(epsilon: Float, counts: [Int], values: [Float]) { + public init(epsilon: Float, counts: [Int], values: [Float]) { self.epsilon = epsilon self.counts = counts self.values = values } - init(epsilon: Float, nArms: Int) { + public init(epsilon: Float, nArms: Int) { self = EpsilonGreedy(epsilon: epsilon, counts: [], values: []) .initialize(nArms: nArms) } - func initialize(nArms nArms: Int) -> EpsilonGreedy { + public func initialize(nArms nArms: Int) -> EpsilonGreedy { return EpsilonGreedy( epsilon: epsilon, counts: (0.. EpsilonGreedy { + public func newEpsilon(epsilon: Float) -> EpsilonGreedy { return EpsilonGreedy(epsilon: epsilon, counts: counts, values: values) } - func indMax(values: [Float]) -> Int? { + public func indMax(values: [Float]) -> Int? { guard let max = values.maxElement(), let index = values.indexOf(max) else { return nil } return Int(index) } - func selectArm() -> Int? { + public func selectArm() -> Int? { if Float(arc4random())/Float(UInt32.max) > epsilon { return indMax(values) } return Int(arc4random_uniform(UInt32(values.count)) + 1) } - func update(chosenArm: Int, reward: Float) -> EpsilonGreedy { + public func update(chosenArm: Int, reward: Float) -> EpsilonGreedy { var newCounts = counts newCounts[chosenArm] = counts[chosenArm] + 1