"""Filters rules based on performance metrics."""
import operator
import warnings
from typing import Dict, Union, List

import iguanas.utils as utils
from iguanas.utils.typing import PandasDataFrameType
class SimpleFilter:
    """
    Filter rules based on performance metrics.

    Parameters
    ----------
    filters : Dict[str, Dict[str, Union[str, float]]]
        Gives the filtering metric (keys) and the filtering conditions
        (values). The filtering conditions are another dictionary
        containing the keys 'Operator' (which specifies the filter
        operator - one of '>', '>=', '<', '<=', '==' or '!=') and 'Value'
        (which specifies the value to filter by).
    rule_descriptions : PandasDataFrameType, optional
        The standard performance metrics dataframe associated with the
        rules (if available). If not given, it will be calculated from
        `X_rules`. Defaults to None.
    opt_func : Callable, optional
        The custom method/function to be applied to the rules (e.g. Fbeta
        score) if `rule_descriptions` is not given. Use the filtering
        metric key 'OptMetric' in the filters parameter if you need to
        filter by this metric. Defaults to None.

    Attributes
    ----------
    rules_to_keep : List[str]
        List of rules which remain after the filters have been applied.
    """

    # Maps each supported 'Operator' string to the comparison it performs.
    # A fixed table is used so that filter conditions are never executed
    # as arbitrary Python code.
    _COMPARISONS = {
        '>': operator.gt,
        '>=': operator.ge,
        '<': operator.lt,
        '<=': operator.le,
        '==': operator.eq,
        '!=': operator.ne,
    }

    def __init__(self, filters: Dict[str, Dict[str, Union[str, float]]],
                 rule_descriptions=None, opt_func=None):
        self.filters = filters
        self.rule_descriptions = rule_descriptions
        self.opt_func = opt_func

    def fit(self, X_rules: PandasDataFrameType, y=None, sample_weight=None) -> None:
        """
        Calculates the rules remaining after the filters have been applied.

        The result is stored in the `rules_to_keep` attribute. If
        `rule_descriptions` was not provided, it is calculated from
        `X_rules` and stored on the instance before filtering.

        Parameters
        ----------
        X_rules : PandasDataFrameType
            The binary columns of the rules applied to a dataset.
        y : PandasSeriesType, optional
            The binary target column. Not required if `rule_descriptions` is
            given. Defaults to None.
        sample_weight : PandasSeriesType, optional
            Row-wise weights to apply. Defaults to None.

        Raises
        ------
        ValueError
            If `rule_descriptions` is None, 'OptMetric' is one of the
            filtering metrics and `opt_func` was not provided; if a
            filtering metric is not a column of `rule_descriptions`; or if
            a filter uses an unsupported operator.
        """
        if self.rule_descriptions is None:
            # 'OptMetric' can only be computed when a custom function exists.
            if self.opt_func is None and 'OptMetric' in self.filters:
                raise ValueError(
                    'Must provide `opt_func` when `rule_descriptions` is None and "OptMetric" is included in filters.')
            self.rule_descriptions = utils.return_rule_descriptions_from_X_rules(
                X_rules=X_rules, X_rules_cols=X_rules.columns, y_true=y,
                sample_weight=sample_weight, opt_func=self.opt_func)
        self.rules_to_keep = self._iterate_rule_descriptions(
            rule_descriptions=self.rule_descriptions, filters=self.filters)

    @staticmethod
    def _iterate_rule_descriptions(rule_descriptions: PandasDataFrameType,
                                   filters: Dict[str, Dict[str, Union[str, float]]]) -> List[str]:
        """
        Iterates through rule_descriptions and applies filters, returning
        the rules which meet the filter requirements.

        Filters are applied one after another, so a rule is kept only if
        it satisfies every condition. A warning is issued when no rules
        remain.

        Raises
        ------
        ValueError
            If a filtering metric is not a column of `rule_descriptions`,
            or if a filter's 'Operator' is not a supported comparison.
        """
        remaining = rule_descriptions
        for metric, condition in filters.items():
            if metric not in remaining.columns:
                raise ValueError(
                    f'{metric} is not in the rule_descriptions dataframe')
            op_symbol = condition['Operator']
            value = condition['Value']
            # Tolerate surrounding whitespace in the operator (e.g. ' >= ').
            if isinstance(op_symbol, str):
                op_symbol = op_symbol.strip()
            compare = SimpleFilter._COMPARISONS.get(op_symbol)
            if compare is None:
                supported = ', '.join(SimpleFilter._COMPARISONS)
                raise ValueError(
                    f'Unsupported operator {op_symbol!r} for metric '
                    f'{metric}; expected one of: {supported}')
            # Element-wise comparison gives a boolean mask over the rules.
            mask = compare(remaining[metric], value)
            remaining = remaining[mask]
        rules_to_keep = remaining.index.tolist()
        if not rules_to_keep:
            warnings.warn('No rules remaining after filtering')
        return rules_to_keep