Source code for scripts.train.trainers.xgboost_line_classifier_trainer

from typing import List

import xgbfir
from xgboost import XGBClassifier

from scripts.train.trainers.base_sklearn_line_classifier import BaseSklearnLineClassifierTrainer


[docs]class XGBoostLineClassifierTrainer(BaseSklearnLineClassifierTrainer): """ Trainer of XGBoost line classifier. See documentation of `XGBClassifier <https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBClassifier>`_ to get more details. """
[docs] def _get_classifier(self) -> XGBClassifier: """ Initialize the XGBClassifier. :return: XGBClassifier instance for training """ return XGBClassifier(random_state=self.random_seed, **self.classifier_parameters)
[docs] def _save_features_importances(self, cls: XGBClassifier, feature_names: List[str]) -> None: """ Save information about most important features for XGBClassifier using `xgbfir <https://github.com/limexp/xgbfir>`_ library. :param cls: XGBClassifier trained on the features with names `feature_names` :param feature_names: column names of the feature matrix, that was used for classifier training """ xgbfir.saveXgbFI(cls, feature_names=feature_names, OutputXlsxFile=self.path_features_importances)