"""
Generates rules using decision trees (applied to Koalas/Spark
dataframes).
"""
import numpy as np
import iguanas.utils as utils
from iguanas.rule_generation._base_generator import BaseGenerator
from iguanas.rules import Rules
from iguanas.utils.types import KoalasDataFrame, KoalasSeries
from iguanas.utils.typing import KoalasDataFrameType, KoalasSeriesType,\
PandasDataFrameType
from datetime import date
from typing import Callable, List, Set, Tuple
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, \
    RandomForestClassificationModel, DecisionTreeClassificationModel
from pyspark.sql import DataFrame
import warnings


class RuleGeneratorDTSpark(Rules, BaseGenerator):
"""
Generate rules by extracting the highest performing branches from a
tree ensemble model.

    Parameters
----------
opt_func : Callable
A function/method which calculates the desired optimisation metric
(e.g. Fbeta score).
n_total_conditions : int
The maximum number of conditions per generated rule.
tree_ensemble : pyspark.ml.classification.RandomForestClassifier
        Pyspark tree ensemble classifier object used to generate rules.
precision_threshold : float, optional
Precision threshold for the tree/branch to be used to create rules.
If the overall precision of the tree/branch is less than or equal
to this value, it will not be used in rule generation. Note that if
`bootstrap` == True in the tree_ensemble class, the precision will
be based on the bootstrapped sample used to create the tree.
Defaults to 0.
target_feat_corr_types : Union[Dict[str, List[str]], str], optional
Limits the conditions of the rules based on the target-feature
correlation (e.g. if a feature has a positive correlation with
respect to the target, then only greater than operators are used
for conditions that utilise that feature). Can be either a
dictionary specifying the list of positively correlated features
wrt the target (under the key `PositiveCorr`) and negatively
correlated features wrt the target (under the key `NegativeCorr`),
or 'Infer' (where each target-feature correlation type is inferred
from the data). Defaults to None.
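        For example: `{'PositiveCorr': ['feat_a'], 'NegativeCorr':
        ['feat_b']}`, where `feat_a` and `feat_b` are illustrative
        feature names.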
    verbose : int, optional
        Controls the verbosity - the higher, the more messages. >0 : gives
        the overall progress of the training of the ensemble model and the
        extraction of the rules from the trees. Defaults to 0.
rule_name_prefix : str, optional
Prefix to use for each rule name. If None, the standard prefix is
used. Defaults to None.

    Attributes
----------
rule_strings : Dict[str, str]
The generated rules, defined using the standard Iguanas string
format (values) and their names (keys).
rule_descriptions : PandasDataFrameType
A dataframe showing the logic of the rules and their performance
metrics on the given dataset.
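
    Examples
    --------
    A minimal sketch (assumes a running Spark session with Koalas
    available; `precision` below is a hypothetical stand-in for any
    optimisation metric callable, e.g. an Fbeta scorer, and `X`/`y`
    are assumed to be a Koalas feature set and binary target):

    >>> from pyspark.ml.classification import RandomForestClassifier
    >>> def precision(y_true, y_preds, sample_weight=None):
    ...     flagged = float((y_preds == 1).sum())
    ...     tps = float(((y_true == 1) & (y_preds == 1)).sum())
    ...     return 0.0 if flagged == 0 else tps / flagged
    >>> rg = RuleGeneratorDTSpark(
    ...     opt_func=precision,
    ...     n_total_conditions=4,
    ...     tree_ensemble=RandomForestClassifier(numTrees=10)
    ... )
    >>> X_rules = rg.fit(X=X, y=y)  # doctest: +SKIP

    After fitting, `rule_strings` maps generated rule names to rule
    logic strings of the form `"(X['feature']>value)"` (illustrative).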
"""

    def __init__(self, opt_func: Callable,
n_total_conditions: int,
tree_ensemble: RandomForestClassifier,
precision_threshold=0,
target_feat_corr_types=None,
verbose=0, rule_name_prefix=None):
self.tree_ensemble = tree_ensemble
self.tree_ensemble.setMaxDepth(n_total_conditions)
self.tree_ensemble.setSeed(0)
self.tree_ensemble.setLabelCol('label_')
self.opt_func = opt_func
self.precision_threshold = precision_threshold
self.target_feat_corr_types = target_feat_corr_types
self.verbose = verbose
self._rule_name_counter = 0
today = date.today()
self.today = today.strftime("%Y%m%d")
self.rule_name_prefix = rule_name_prefix
Rules.__init__(self, rule_strings={}, opt_func=self.opt_func)

    def __repr__(self):
if self.rule_strings:
return f'RuleGeneratorDTSpark object with {len(self.rule_strings)} rules generated'
else:
            return f'RuleGeneratorDTSpark(opt_func={self.opt_func}, n_total_conditions={self.tree_ensemble.getMaxDepth()}, tree_ensemble={self.tree_ensemble}, precision_threshold={self.precision_threshold}, target_feat_corr_types={self.target_feat_corr_types})'

    def fit(self, X: KoalasDataFrameType, y: KoalasSeriesType,
            sample_weight=None) -> KoalasDataFrameType:
"""
Generates rules by extracting the highest performing branches in a tree
ensemble model.

        Parameters
----------
X : KoalasDataFrameType
The feature set used for training the model.
y : KoalasSeriesType
The target column.
sample_weight : KoalasSeriesType, optional
Record-wise weights to apply. Defaults to None.

        Returns
-------
KoalasDataFrameType
The binary columns of the rules on the fitted dataset.
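
        Examples
        --------
        A usage sketch, assuming `rg` is configured as in the
        class-level example and that `X` and `y` are Koalas objects:

        >>> X_rules = rg.fit(X=X, y=y)  # doctest: +SKIP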
"""
utils.check_allowed_types(X, 'X', [KoalasDataFrame])
utils.check_allowed_types(y, 'y', [KoalasSeries])
if sample_weight is not None:
utils.check_allowed_types(
sample_weight, 'sample_weight', [KoalasSeries])
if self.target_feat_corr_types == 'Infer':
if self.verbose:
print(
'--- Calculating correlation of features with respect to the target ---')
self.target_feat_corr_types = self._calc_target_ratio_wrt_features(
X=X, y=y
)
if self.verbose:
print('--- Returning column datatypes ---')
columns_int, columns_cat, _ = utils.return_columns_types(X)
if self.verbose:
print('--- Creating Spark DataFrame for training ---')
spark_df = self._create_train_spark_df(
X=X, y=y, sample_weight=sample_weight)
if self.verbose:
print('--- Training tree ensemble ---')
if sample_weight is not None:
self.tree_ensemble.setWeightCol('sample_weight_')
trained_tree_ensemble = self.tree_ensemble.fit(spark_df)
if self.verbose:
print('--- Extracting rules from tree ensemble ---')
X_rules = self._extract_rules_from_ensemble(
X=X, y=y, sample_weight=sample_weight,
tree_ensemble=trained_tree_ensemble, columns_int=columns_int,
columns_cat=columns_cat,
)
self.rule_strings = self.rule_descriptions['Logic'].to_dict()
return X_rules

    def _extract_rules_from_ensemble(self, X: KoalasDataFrameType,
                                     y: KoalasSeriesType,
                                     sample_weight: KoalasSeriesType,
                                     tree_ensemble: RandomForestClassificationModel,
                                     columns_int: List[str],
                                     columns_cat: List[str]) -> KoalasDataFrameType:
"""
        Extracts and returns all of the rules from the trees in the
        ensemble model.
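        Rules that appear in more than one tree are deduplicated via a
        set union before being named.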
"""
list_of_rule_string_sets = []
for i, decision_tree in enumerate(tree_ensemble.trees):
if decision_tree.depth == 0:
warnings.warn(
f'Decision Tree {i} has a depth of zero - skipping')
continue
dt_rule_strings_set = self._extract_rules_from_dt(
columns=X.columns.tolist(), decision_tree=decision_tree,
columns_cat=columns_cat, columns_int=columns_int
)
list_of_rule_string_sets.append(dt_rule_strings_set)
rule_strings_set = sorted(set().union(*list_of_rule_string_sets))
self.rule_strings = dict((self._generate_rule_name_dt(), rule_string)
for rule_string in rule_strings_set)
if not self.rule_strings:
raise Exception(
'No rules could be generated. Try changing the class parameters.')
X_rules = self.transform(X=X, y=y, sample_weight=sample_weight)
return X_rules

    def _extract_rules_from_dt(self, columns: List[str],
decision_tree: DecisionTreeClassificationModel,
columns_int: List[str],
columns_cat: List[str]) -> Set[str]:
"""
        Returns the rules extracted from the given decision tree, unless
        the tree's overall precision is at or below the precision
        threshold (in which case an empty set is returned).
"""
left, right, features, thresholds, precisions, tree_prec = self._get_pyspark_tree_structure(
decision_tree._call_java('rootNode')
)
if tree_prec <= self.precision_threshold:
return set()
else:
return self._extract_rules_from_tree(
columns=columns, precision_threshold=self.precision_threshold,
columns_int=columns_int, columns_cat=columns_cat,
left=left, right=right, features=features,
thresholds=thresholds, precisions=precisions
)

    @staticmethod
def _create_train_spark_df(X: KoalasDataFrameType, y: KoalasSeriesType,
sample_weight: KoalasSeriesType) -> DataFrame:
"""
Creates a Spark DataFrame from `X`, `y` and `sample_weight` (if
provided) for training the pyspark tree ensemble.
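
        The returned DataFrame contains the original feature columns
        plus the assembled `features` vector column expected by the
        pyspark classifier.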
"""
spark_df = utils.create_spark_df(
X=X, y=y, sample_weight=sample_weight)
vectorAssembler = VectorAssembler(
inputCols=X.columns.tolist(), outputCol="features")
spark_df = vectorAssembler.transform(spark_df)
return spark_df

    @staticmethod
def _get_pyspark_tree_structure(node) -> Tuple[
np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, float]:
"""
Extracts the left_children, right_children, features, thresholds,
node precisions and overall precision of the given DT (equivalent to
        those attributes seen in Sklearn's DecisionTreeClassifier class).
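
        For example, a depth-1 tree that splits on feature 0 at
        threshold 0.5 yields (leaf positions marked with the -1/-2
        sentinels)::

            left_children  = [1, -1, -1]
            right_children = [2, -1, -1]
            features       = [0, -2, -2]
            thresholds     = [0.5, -2, -2]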
"""
def _get_node_info(node):
"""
Appends info to the left_children, right_children, features,
thresholds, node precisions and node predictions arrays for each
            node.
            """
"""
            # Register the current split node and advance the running
            # node counter (`k` is mutable so the recursion shares it)
            network[k[0]-1] = {}
            k[0] = k[0] + 1
            num_desc_left = node.leftChild().numDescendants()
            num_desc_right = node.rightChild().numDescendants()
            features.append(node.split().featureIndex())
            thresholds.append(node.split().threshold())
            node_splits.append(list(node.impurityStats().stats()))
if num_desc_left == 0:
network[k[0]-2]['Left'] = k[0]-1
t = k[0]-2
k[0] = k[0] + 1
left_children.append(-1)
node_splits.append(
list(node.leftChild().impurityStats().stats()))
features.append(-2)
thresholds.append(-2)
node_preds.append(node.leftChild().prediction())
else:
network[k[0]-2]['Left'] = k[0]-1
t = k[0]-2
left_children.append(k[0])
node_preds.append(None)
_get_node_info(node.leftChild())
if num_desc_right == 0:
network[t]['Right'] = k[0]-1
k[0] = k[0] + 1
left_children.append(-1)
node_splits.append(
list(node.rightChild().impurityStats().stats()))
features.append(-2)
thresholds.append(-2)
node_preds.append(node.rightChild().prediction())
else:
network[t]['Right'] = k[0]-1
left_children.append(k[0])
node_preds.append(None)
_get_node_info(node.rightChild())
        # `k` holds the running node count in a one-element list so that
        # the recursive closure can mutate it in place; `network` maps
        # each split node's index to the indices of its children
        k = [1]
        left_children = [k[0]]
        features = []
        thresholds = []
        node_preds = [None]
        node_splits = []
        node_precs = []
        network = {}
_get_node_info(node)
        left_children = np.array(left_children)
        features = np.array(features)
        thresholds = np.array(thresholds)
        right_children = np.array([
            network[i]['Right'] if i in network else -1
            for i in range(len(left_children))
        ])
        node_precs = np.empty(len(node_splits))
        tps_l, tps_fps_l = [], []
        for i, lc in enumerate(left_children):
            # Node precision = positive records / total records at the node
            node_precs[i] = node_splits[i][1]/sum(node_splits[i])
            # Accumulate counts across leaves that predict the positive
            # class to compute the tree's overall precision
            if lc == -1 and node_preds[i] == 1:
                tps_l.append(node_splits[i][1])
                tps_fps_l.append(sum(node_splits[i]))
        tps = sum(tps_l)
        tps_fps = sum(tps_fps_l)
        tree_prec = 0 if tps_fps == 0 else tps/tps_fps
return left_children, right_children, features, thresholds, node_precs, tree_prec