Created July 15, 2009 21:07
-- Stochastic Gradient Descent in Haskell
--
-- Based on Mark Reid's implementation in Clojure:
--
--   http://github.com/mreid/injuce
--
--
-- Profiling
-- ---------
--
-- On my computer, about 70% of the execution time is spent parsing
-- the input, and 15% is spent looking up elements in the feature
-- maps. I do not know how to speed up either of these operations
-- (except possibly by using regular expressions rather than Parsec
-- for parsing), and so I am releasing the code as-is. Here are the
-- execution times:
--
-- With no optimization:
--
--   14.42 real  13.82 user  0.32 sys
--
-- And with optimization (-O3):
--
--   13.41 real  12.86 user  0.28 sys
--   13.49 real  13.03 user  0.26 sys
--
-- Changing to Data.IntMap:
--
--   11.94 real  11.34 user  0.24 sys
--   11.64 real  11.23 user  0.23 sys
--
-- Replaced parser with lazy ByteString splitting (built with -O3,
-- running with -H15M):
--
--    9.36 real   9.11 user  0.12 sys
--    9.75 real   9.19 user  0.14 sys
--
-- With -H300M:
--
--    8.89 real   7.98 user  0.54 sys
--    8.92 real   7.98 user  0.55 sys
--
import Control.Monad
import qualified Data.IntMap as Map
import qualified Data.ByteString.Lazy.Char8 as L
import Data.Word
import Data.Maybe

type Features = Map.IntMap Double

data Example = Example { label    :: Double,
                         features :: Features }
               deriving (Show)

data Model = Model { lambda :: Double,
                     step   :: Int,
                     w      :: Features,
                     errors :: Int
                   }
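
-- `Example` pairs a numeric label with a sparse feature vector, and
-- `Model` carries the training state: `lambda` is the regularization
-- parameter, `step` the 1-based iteration counter, `w` the current
-- (sparse) weight vector, and `errors` the number of examples seen so
-- far with positive hinge loss.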

-- Sparse vector operations
add x y = Map.unionWith (+) x y

inner :: Features -> Features -> Double
inner x y = {-# SCC "sum" #-} sum terms
  where terms = {-# SCC "terms" #-} map computeTerm (Map.keys y)
        computeTerm k = {-# SCC "computeTerm" #-}
          let lhs = {-# SCC "lhs" #-} case {-# SCC "lookup" #-} Map.lookup k x of
                      Just a  -> a
                      Nothing -> 0
              rhs = {-# SCC "rhs" #-} case {-# SCC "lookup" #-} Map.lookup k y of
                      Just a  -> a
                      Nothing -> 0
          in {-# SCC "multiply" #-} (lhs * rhs)

norm x = sqrt (inner x x)

scale :: Double -> Features -> Features
scale a x = Map.map (* a) x

project w r = scale (min (r / (norm w)) 1) w
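
-- `project w r` scales w back onto the L2 ball of radius r; when
-- ||w|| <= r the scaling factor is capped at 1 and w is left unchanged.
--
-- For reference only: an equivalent inner product written with
-- Data.IntMap's intersection. It is not used by the code above; the
-- name `inner'` is just an illustrative sketch.
inner' :: Features -> Features -> Double
inner' x y = sum (Map.elems (Map.intersectionWith (*) x y))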

-- Parsing routines
parse_examples :: L.ByteString -> [Example]
parse_examples content = let examplesstr = filter (/= L.pack "") $ L.split '\n' content
                         in map parse_example examplesstr
  where parse_example :: L.ByteString -> Example
        parse_example examplestr =
          let exampleFields = L.split ' ' examplestr
              label         = head exampleFields
              featureFields = tail exampleFields
              featurePairs  = map (L.split ':') featureFields
              features      = map (\x -> (fst $ fromJust $ L.readInt (x !! 0),
                                          read $ L.unpack (x !! 1))) featurePairs
          in Example { label    = fromIntegral $ fst $ fromJust $ L.readInt label,
                       features = Map.fromList features }
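
-- `parse_examples` expects one example per non-empty line in the sparse
-- "label index:value index:value ..." form, e.g. a line such as
-- "1 3:0.5 7:1.2" (the concrete numbers are only illustrative). The
-- label is read as an integer and each feature value as a Double.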

parse_input :: IO [Example]
parse_input = do input <- L.getContents
                 let examples = parse_examples input
                 -- putStrLn $ show examples
                 return examples

-- Okay, now for the meat and potatoes
hinge_loss w example = max 0 loss
  where loss = 1 - ((label example) * (inner w (features example)))
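
-- For an example with label y and features x, the hinge loss is
-- max(0, 1 - y * <w, x>); it is positive exactly when the example is
-- misclassified or classified with a margin smaller than 1.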

correct :: Features -> Example -> Int -> Double -> Features
correct w example t lambda =
  let x   = features example
      y   = label example
      w1  = scale (1.0 - (1.0 / fromIntegral t)) w
      eta = 1.0 / (lambda * fromIntegral t)
      r   = 1.0 / (sqrt lambda)
  in project (add w1 (scale (eta * y) x)) r
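
-- `correct` performs one corrective step on a margin violation:
--
--   w_{t+1} = project( (1 - 1/t) * w_t + eta_t * y * x,  1/sqrt(lambda) )
--   with eta_t = 1 / (lambda * t)
--
-- i.e. shrink the current weights, step towards the violating example,
-- then project back onto the ball of radius 1/sqrt(lambda) (a
-- Pegasos-style SVM update).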

report model interval =
  when (step model `mod` interval == 0) $ do
    let t    = step model
        size = length $ Map.keys (w model)
        e    = errors model
    putStrLn ("Step: " ++ show t ++
              "\t Features in w = " ++ show size ++
              "\t Errors = " ++ show e ++
              "\t Error rate = " ++ show (((fromIntegral e) /
                                           (fromIntegral t)) :: Double))

update :: Model -> Example -> IO Model
update model example =
  let l       = lambda model
      t       = step model
      w1      = w model
      e       = errors model
      mistake = (hinge_loss w1 example) > 0
  in do report model 100
        return $ Model {
          w      = (if mistake then correct w1 example t l else w1),
          lambda = l,
          step   = t + 1,
          errors = (if mistake then (e + 1) else e)
        }
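
-- `update` folds one example into the model: progress is printed every
-- 100 steps, the step counter always advances, and the weights and
-- error count change only when the hinge loss on this example is
-- positive. `train` below threads this through the example list with
-- foldM.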

train initial examples = foldM update initial examples

main = do
  let start = Model {
        lambda = 0.0001,
        step   = 1,
        w      = Map.empty,
        errors = 0
      }
  examples <- parse_input
  model    <- train start examples
  putStrLn $ show (w model)
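
-- Example build and run with a recent GHC, following the -O3/-H15M
-- notes at the top of this file (commands and file names are
-- illustrative only):
--
--   ghc -O3 --make -rtsopts sgd.hs -o sgd
--   ./sgd < train.dat +RTS -H15M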