sieber-lab/custom_gbm

 
 


Custom_gbm

JAX-based implementation of custom loss functions for LightGBM. The package currently includes:

More will be added in the future. The package also provides a wrapper that simplifies adding new loss functions, e.g. by handling automatic differentiation, initial score calculation, and interfacing with the LightGBM Train API.
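To illustrate what handling automatic differentiation can look like, here is a minimal sketch of deriving the per-sample gradient and Hessian that a LightGBM custom objective has to return, using JAX. The names `focal_loss` and `make_grad_hess` are hypothetical and are not taken from this package; the binary focal loss formulation on raw scores is likewise an assumption made for the example, and initial score calculation is not covered.

```python
import jax
import jax.numpy as jnp


def focal_loss(score, label, gamma=2.0):
    """Binary focal loss for one sample, computed on a raw (pre-sigmoid) score.

    Hypothetical helper for illustration; label is 0.0 or 1.0.
    """
    # Flip the sign of the score for negatives so that sigmoid(z) is always
    # the probability assigned to the true class.
    z = (2.0 * label - 1.0) * score
    p_true = jax.nn.sigmoid(z)
    return -((1.0 - p_true) ** gamma) * jax.nn.log_sigmoid(z)


def make_grad_hess(loss_fn, **loss_params):
    """Build a function returning per-sample first and second derivatives
    of loss_fn with respect to the raw score (hypothetical helper)."""

    def per_sample(score, label):
        return loss_fn(score, label, **loss_params)

    # Differentiate once and twice with respect to the score (argument 0),
    # then vectorize over all samples with vmap.
    first = jax.grad(per_sample, argnums=0)
    second = jax.grad(first, argnums=0)
    grad_fn = jax.vmap(first)
    hess_fn = jax.vmap(second)

    def grad_hess(scores, labels):
        return grad_fn(scores, labels), hess_fn(scores, labels)

    return jax.jit(grad_hess)


# Example call on toy data (values chosen only for illustration)
scores = jnp.array([0.0, 1.5, -2.0])
labels = jnp.array([1.0, 0.0, 1.0])
grad, hess = make_grad_hess(focal_loss, gamma=2.0)(scores, labels)
```

With `gamma=0.0` the loss reduces to the ordinary log loss, whose gradient is `sigmoid(score) - label`, which gives a quick sanity check for any new loss written this way.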

Installation

The package requires LightGBM 3.5.5, SciPy 1.10.1 and JAX 0.4.12. Alternatively, you can create a conda environment with all dependencies from the environment.yml file in the GitHub repository.

git clone https://github.com/dahvida/custom_gbm
cd custom_gbm
conda env create --name customgbm --file=environment.yml
conda activate customgbm

Usage

Here is a minimal example for training a model with CustomGBM. Additional tutorials are provided in the notebook folder.

# import packages
from focal import Focal_loss
from custom_gbm import CustomGBM

# define parameters
loss_fn = Focal_loss
loss_params = {"gamma": 2.0}
booster_params = {"num_boost_round": 100}
gbm = CustomGBM(loss_fn, loss_params, booster_params)

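# fit model (x_train and y_train are your own training features and labels)
gbm.fit(x_train, y_train)

# get predictions (x_test is your own held-out feature matrix)
predictions = gbm.predict(x_test)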
