Last active
March 20, 2017 15:50
-
-
Save jphall663/752b1396f1a4c60468eab93af22655b1 to your computer and use it in GitHub Desktop.
h2o r gbm sample 1
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sample script: train a binomial GBM with h2o, report AUC on the
# train/validation/test splits, and export the model as a binary, a POJO and a
# MOJO. If no input data location is configured, a synthetic frame is
# generated instead.

# load h2o library
library(h2o)

### set global constants ######################################################

ip <- "localhost"           # host of the h2o server to connect to
port <- 54321               # port of the h2o server to connect to
nthreads <- -1              # number of threads to use, -1 indicates max. threads
max_mem_size <- "12G"       # defines amount of available memory, per node
dat <- NULL                 # input data location; NULL => simulate data below
col.types <- NULL           # vector of column types, c("numeric", "enum", ...)
X <- NULL                   # input variable names
y <- NULL                   # target variable name
weights_column <- NULL      # column name for weighting variable
build_tree_one_node <- TRUE # set to TRUE for smaller data
seed <- 12345               # random seed, increases reproducibility
path <- "/tmp"              # folder location to save java objects

### start and connect to h2o server ###########################################

h2o.init(
  ip = ip,
  port = port,
  nthreads = nthreads,
  max_mem_size = max_mem_size
)

### simulate data if none is available ########################################

if (is.null(dat)) {
  # Seed the generator as well, so the simulated frame (and therefore every
  # downstream metric) is repeatable from run to run.
  dat <- h2o.createFrame(
    rows = 200000,
    cols = 25,
    categorical_fraction = 0.2,
    has_response = TRUE,
    response_factors = 2,
    seed = seed
  )
  y <- "response"
  X <- setdiff(names(dat), y)
} else {
  dat <- h2o.importFile(
    path = dat,
    parse = TRUE,
    sep = "|",
    col.types = col.types
  )
  # Fail early with a readable message instead of an obscure indexing error
  # further down when the target was left unset or misspelled.
  if (is.null(y) || !(y %in% names(dat))) {
    stop("`y` must name a column of the imported data", call. = FALSE)
  }
  # With no explicit predictor list, fall back to every remaining column
  # except the target and the (optional) weights column.
  if (is.null(X)) {
    X <- setdiff(names(dat), c(y, weights_column))
  }
}

### train straightforward GBM #################################################

# set target as factor
dat[[y]] <- as.factor(dat[[y]])

### partition data
splits <- h2o.splitFrame(
  data = dat,
  ratios = c(0.6, 0.2), # only need to specify 2 fractions, the 3rd is implied
  destination_frames = c("train.hex", "valid.hex", "test.hex"),
  seed = seed
)
train <- splits[[1]]
valid <- splits[[2]]
test <- splits[[3]]

### train gbm
# real-time info at host:port
# ~ 4 minutes for 4000 trees on mac book pro

# Format the timestamp explicitly: the implicit character conversion of
# Sys.time() is not stable across R versions (it can include fractional
# seconds), which would leak unexpected characters into the model id.
model_id <- paste("gbm", format(Sys.time(), "%Y-%m-%d-%H-%M-%S"), sep = "-")

gbm1 <- h2o.gbm(
  x = X,
  y = y,
  training_frame = train,
  validation_frame = valid,
  ntrees = 4000,
  model_id = model_id,
  max_depth = 5,
  learn_rate = 0.005,
  # NOTE(review): min_rows is derived from the full frame, not from the
  # training split -- confirm this is intended.
  min_rows = nrow(dat) * 0.0025,
  build_tree_one_node = build_tree_one_node,
  weights_column = weights_column,
  seed = seed
)

### report AUC
# Collect the three AUCs in one named vector and print it explicitly, so the
# values are labeled and still shown when the script is run via source().
auc_by_split <- vapply(
  list(train = train, valid = valid, test = test),
  function(frame) h2o.auc(h2o.performance(gbm1, newdata = frame)),
  numeric(1)
)
print(auc_by_split)

### save model binary ########################################################

h2o.saveModel(gbm1, path = path)

### save POJO (plain old Java object) ########################################

h2o.download_pojo(model = gbm1, path = path, get_jar = TRUE)

### save MOJO (model-optimized Java object) ##################################

h2o.download_mojo(model = gbm1, path = path, get_genmodel_jar = TRUE)

### shutdown h2o server ######################################################

# Only ask for confirmation when a user is present; in batch mode the prompt
# cannot be answered and the server would otherwise be left running.
h2o.shutdown(prompt = interactive())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment