Source code for suricata_check_design_principles.checkers.principle.ml

  1"""`PrincipleMLChecker`."""
  2
  3import copy
  4import logging
  5import os
  6import pickle
  7from collections import Counter
  8from collections.abc import Iterable
  9from typing import Any, Literal, Optional, Union, overload
 10
 11import idstools.rule
 12import xgboost
 13from pandas import DataFrame, Series
 14from sklearn.metrics import f1_score, make_scorer, precision_score, recall_score
 15from sklearn.model_selection import (
 16    GridSearchCV,
 17    RepeatedStratifiedKFold,
 18    cross_val_score,
 19)
 20from sklearn.pipeline import Pipeline
 21from suricata_check.checkers.interface.checker import CheckerInterface
 22from suricata_check.utils.checker import get_rule_option, get_rule_suboptions
 23from suricata_check.utils.checker_typing import ISSUES_TYPE, Issue
 24
 25from suricata_check_design_principles._version import SURICATA_CHECK_DIR
 26from suricata_check_design_principles.checkers.principle._utils import get_message
 27
 28_PICKLE_PATH = os.path.join(SURICATA_CHECK_DIR, "data", "principle_ml_checker.pkl")
 29N_JOBS = 8
 30
 31
 32_logger = logging.getLogger(__name__)
 33
 34
COUNT_COLUMNS = (
    "flowbits.isset.count",
    "flowbits.isnotset.count",
    "flowint.isset.count",
    "flowint.isnotset.count",
    "xbits.isset.count",
    "xbits.isnotset.count",
    "http.uri.count",
    "http.method.count",
    "dns.query.count",
    "content.count",
    "pcre.count",
    "startswith.count",
    "bsize.count",
    "depth.count",
    "urilen.count",
    "flow.from_server.count",
    "flow.to_server.count",
    "flow.from_client.count",
    "flow.to_client.count",
)
STRING_COLUMNS = ()
DROPDOWN_COLUMNS = (
    "proto",
    "threshold.type",
)
NUMERICAL_COLUMNS = ("threshold.count",)
SPLITTABLE_FEATURES = (
    "metadata",
    "flow",
    "threshold",
)
MSG_KEYWORDS = ("Suspicious", "CVE", "Vulnerability", "Response")
MSG_COLUMNS = tuple("msg.contains." + keyword for keyword in MSG_KEYWORDS)
IP_KEYWORDS = ("$HOME_NET", "$HTTP_SERVERS", "$EXTERNAL_NET", "any")
IP_COLUMNS = tuple(
    ["source_addr.contains." + keyword for keyword in IP_KEYWORDS]
    + ["dest_addr.contains." + keyword for keyword in IP_KEYWORDS]
)
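# These column tuples steer `_preprocess_features`: count columns are kept as-is,
# string columns are replaced by their length (".len"), dropdown columns are
# one-hot encoded (".bool"), numerical columns are cast to float (".num"), and the
# msg/IP indicator columns are kept; every other extracted column is dropped.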


PIPELINE = Pipeline(
    [
        (
            "classify",
            xgboost.XGBClassifier(),
        )
    ]
)
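# The pipeline has a single named step, so the GridSearchCV parameters below address
# the XGBoost classifier through the "classify__" prefix, following scikit-learn's
# `<step>__<param>` naming convention for pipelines.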
# https://shengyg.github.io/repository/machine%20learning/2017/02/25/Complete-Guide-to-Parameter-Tuning-xgboost.html
PARAM_GRID: list[dict] = [
    {
        # Fixed parameters for problem / desired complexity
        "classify__n_estimators": [1000],
        "classify__objective": ["binary:logistic"],
        ###
        # Parameters to optimize
        ## Learning rate
        "classify__eta": [0.01, 0.1, 0.3],
        ## Tree parameters
        "classify__subsample": [1.0],
        "classify__colsample_bytree": [0.25, 0.5, 0.75, 1.0],
        "classify__scale_pos_weight": [0.1, 0.25, 0.5, 1.0, 2.0, 4.0, 10.0],
        "classify__max_depth": [1, 3],
        "classify__min_child_weight": [1],
        "classify__gamma": [0, 0.1],
        ## Regularization
        "classify__lambda": [0, 0.01, 0.1],
        "classify__alpha": [0, 0.01, 0.1],
    },
]
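# The grid above spans 3,024 parameter combinations; with the 2-fold, 10-repeat
# splitter defined below, GridSearchCV fits each combination 20 times (once per
# cross-validation split) before refitting the best estimator.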

PRECISION_WEIGHT = 10
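# SCORER is a weighted harmonic mean of precision and recall in which precision
# counts PRECISION_WEIGHT times as much as recall (equivalent to an F-beta score
# with beta^2 = 1 / PRECISION_WEIGHT); the 1e-10 terms guard against division by zero.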
SCORER = make_scorer(
    lambda y, y_pred: (PRECISION_WEIGHT + 1)
    / (
        PRECISION_WEIGHT / (precision_score(y, y_pred, zero_division=1) + 1e-10)  # type: ignore reportArgumentType
        + 1 / (recall_score(y, y_pred, zero_division=0) + 1e-10)  # type: ignore reportArgumentType
    )
)
SPLITTER = RepeatedStratifiedKFold(n_splits=2, n_repeats=10)
GRIDSEARCHCV = GridSearchCV(
    PIPELINE, PARAM_GRID, cv=SPLITTER, scoring=SCORER, error_score="raise", n_jobs=N_JOBS, verbose=1  # type: ignore reportArgumentType
)


class PrincipleMLChecker(CheckerInterface):
    """The `PrincipleMLChecker` contains several checks based on the Ruling the Unruly paper, targeting rule specificity and coverage.

    Codes Q000-Q005 report on non-adherence to rule design principles, similar to the codes emitted by the non-ML `PrincipleChecker`.
    Differently, they are the result of machine learning analysis of the rules.
    """

    count_columns = COUNT_COLUMNS
    string_columns = STRING_COLUMNS
    dropdown_columns = DROPDOWN_COLUMNS
    numerical_columns = NUMERICAL_COLUMNS
    splittable_features = SPLITTABLE_FEATURES
    msg_keywords = MSG_KEYWORDS
    msg_columns = MSG_COLUMNS
    ip_keywords = IP_KEYWORDS
    ip_columns = IP_COLUMNS

    codes = {
        "Q000": {"severity": logging.INFO},
        "Q001": {"severity": logging.INFO},
        "Q002": {"severity": logging.INFO},
        "Q003": {"severity": logging.INFO},
        "Q004": {"severity": logging.INFO},
        "Q005": {"severity": logging.INFO},
    }

    enabled_by_default = (
        False  # Since the checker is relatively slow, it is disabled by default
    )

    _dtypes: Optional[dict[str, Any]] = None
    _models: dict[str, Pipeline] = {}

    def __new__(
        cls: type["PrincipleMLChecker"],
        filepath: Optional[str] = _PICKLE_PATH,
        *args: tuple,
        **kwargs: dict,
    ) -> "PrincipleMLChecker":
        """Returns a new or unpickled instance of the class."""
        if filepath:
            if os.path.exists(filepath):
                with open(filepath, "rb") as f:
                    inst = pickle.load(f)

                if not inst.__class__.__name__ == cls.__name__:
                    _logger.error("Unpickled object is not of type %s", cls)
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                elif not hasattr(inst, "_models") or len(inst._models) == 0:
                    _logger.error("Unpickled object does not have trained models")
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                else:
                    if "include" in kwargs:
                        inst.include = kwargs["include"]
                    _logger.info("Unpickled object with trained models successfully")
            else:
                _logger.warning("No model found for PrincipleMLChecker at %s", filepath)
                inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
        else:
            inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType

        return inst

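    # `__getnewargs__` below makes unpickling call `__new__` with `filepath=None`,
    # so loading a pickled checker does not recursively try to read `_PICKLE_PATH`.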
    def __getnewargs__(
        self: "PrincipleMLChecker",
    ) -> tuple:
        """Returns the arguments to be passed to the __new__ method when unpickling."""
        return (None,)

    def _check_rule(
        self: "PrincipleMLChecker",
        rule: idstools.rule.Rule,
    ) -> ISSUES_TYPE:
        issues: ISSUES_TYPE = []

        if len(self._models) == 0:
            return issues

        # Each trained model predicts non-adherence to its design principle;
        # a positive prediction yields the corresponding issue.
        for code, model in self._models.items():
            if model.predict(self._get_features(rule, True))[0]:
                issues.append(
                    Issue(
                        code=code,
                        message=get_message(code),
                    )
                )

        return issues

    def train(  # noqa: C901
        self: "PrincipleMLChecker",
        df: DataFrame,
        rule_col: str = "rule.rule",
        principle_cols: dict[str, str] = {
            "Q000": "labelled.no_proxy",
            "Q001": "labelled.success",
            "Q002": "labelled.thresholded",
            "Q003": "labelled.exceptions",
            "Q004": "labelled.generalized_match_content",
            "Q005": "labelled.generalized_match_location",
        },
        reuse_models: bool = False,
    ) -> None:
        """Train several models for the checker to detect issues in rules.

        The checker class with trained models is stored in a pickle file (`_PICKLE_PATH`).
        """
        self._dtypes = None
        if not reuse_models:
            self._models = {}

        # Extract features and determine feature dtypes
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.debug("Detected column: %s", col)
            except:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Drop zero-variance columns
        X_train = X_train.drop(  # noqa: N806
            X_train.columns[(X_train.fillna(-1337).var(axis=0) <= 0)].to_list(),  # type: ignore reportAttributeAccessIssue
            axis=1,
        )

        # Drop columns with too few occurrences of possible values
        for col in X_train.columns:
            if (
                not col.endswith(".count")
                and not col.endswith(".num")
                and not col.endswith(".len")
            ):
                if X_train[col].value_counts().min() <= 1:
                    X_train = X_train.drop(  # noqa: N806
                        [col],
                        axis=1,
                    )

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.info("Using column: %s", col)
            except:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Store used features and their dtypes
        self._dtypes = X_train.dtypes.to_dict()
        _logger.debug(self._dtypes)

        # Redo feature extraction now that FE parameters are set
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        _logger.info(
            "Training model with features: [%s]",
            ", ".join([str(x) for x in X_train.columns]),
        )

        _logger.info(X_train)

        for code, col in principle_cols.items():
            y_true = df[col].to_numpy() == 0

            if not reuse_models or code not in self._models:
                # Train new model with grid search to find optimal parameters
                gridsearchcv: GridSearchCV = copy.deepcopy(GRIDSEARCHCV)

                gridsearchcv.fit(X_train, y_true)

                _logger.info("Code %s params: %s", code, gridsearchcv.best_params_)
                _logger.info(
                    "Code %s Weighted F1-score: %s", code, gridsearchcv.best_score_
                )

                self._models[code] = gridsearchcv.best_estimator_

            precision = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(precision_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            recall = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(recall_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            f1 = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(f1_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            _logger.info("Code %s Precision score: %s", code, precision)
            _logger.info("Code %s Recall score: %s", code, recall)
            _logger.info("Code %s F1-score: %s", code, f1)

            # Refit model with training data.
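            # (`cross_val_score` above only fits clones for evaluation; fitting here
            # ensures the stored model is trained on the full training set before
            # the checker is pickled.)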
            self._models[code].fit(X_train, y_true)

        # Persist the trained checker (including `_dtypes` and `_models`) so that
        # `__new__` can reload it from `_PICKLE_PATH` on future runs.
        with open(_PICKLE_PATH, "wb") as f:
            pickle.dump(self, f)

    def _get_train_df(self: "PrincipleMLChecker", rules: Iterable[str]) -> DataFrame:
        """Parses an iterable of rules and returns a DataFrame of feature vectors."""
        feature_vectors = []
        for rule in rules:
            parsed_rule = idstools.rule.parse(rule)
            assert parsed_rule is not None
            feature_vectors.append(self._get_features(parsed_rule, False))

        return DataFrame(feature_vectors)

    def _get_raw_features(  # noqa: C901
        self: "PrincipleMLChecker", rule: idstools.rule.Rule
    ) -> Series:
        """Extracts raw features from a rule: option values and counts, suboption values and counts, and msg/address keyword indicators."""
        d: dict[str, Optional[Union[str, int]]] = {
            "proto": get_rule_option(rule, "proto")
        }

        options = rule["options"]

        for option in options:
            d[option["name"]] = option["value"]

        counter = Counter([option["name"] for option in options])
        for option, count in counter.items():
            d[option + ".count"] = count

        for option in options:
            if option["name"] not in self.splittable_features:
                continue

            suboptions = [
                {"name": k, "value": v}
                for k, v in get_rule_suboptions(rule, option["name"], warn=False)
            ]

            if len(suboptions) == 0:
                continue

            for suboption in suboptions:
                d[option["name"] + "." + suboption["name"]] = suboption["value"]

            counter = Counter([suboption["name"] for suboption in suboptions])
            for suboption, count in counter.items():
                d[option["name"] + "." + suboption + ".count"] = count

        msg = get_rule_option(rule, "msg")
        assert msg is not None
        msg = msg.lower()
        for col, keyword in zip(self.msg_columns, self.msg_keywords):
            d[col] = keyword.lower() in msg

        source_addr = get_rule_option(rule, "source_addr")
        assert source_addr is not None
        source_addr = source_addr.lower()
        for keyword in self.ip_keywords:
            col = "source_addr.contains." + keyword
            d[col] = keyword.lower() in source_addr

        dest_addr = get_rule_option(rule, "dest_addr")
        assert dest_addr is not None
        dest_addr = dest_addr.lower()
        for keyword in self.ip_keywords:
            col = "dest_addr.contains." + keyword
            d[col] = keyword.lower() in dest_addr

        return Series(d)

    def _preprocess_features(self: "PrincipleMLChecker", data: Series) -> Series:
        """Encodes raw features: strings to lengths, dropdowns to one-hot booleans, numericals to floats; drops all other non-whitelisted columns."""
        original_cols: set[str] = set(data.index)

        for col in self.string_columns:
            if col not in data:
                continue
            data[col + ".len"] = len(data[col])
            data = data.drop(col)

        for col in self.dropdown_columns:
            if col not in data:
                continue
            data[col + "." + data[col] + ".bool"] = 1
            data = data.drop(col)

        for col in self.numerical_columns:
            if col not in data:
                continue
            data[col + ".num"] = float(data[col])  # type: ignore reportArgumentType
            data = data.drop(col)

        remaining_cols = (
            original_cols
            - set(self.count_columns)
            - set(self.string_columns)
            - set(self.dropdown_columns)
            - set(self.numerical_columns)
            - set(self.msg_columns)
            - set(self.ip_columns)
        )

        for col in remaining_cols:
            data = data.drop(col)

        return data

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[True]
    ) -> DataFrame:
        pass

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[False]
    ) -> Series:
        pass

    def _get_features_frame(self: "PrincipleMLChecker", features: Series) -> DataFrame:
        """Converts a feature Series into a single-row DataFrame cast to the dtypes used during training."""
        features_frame = features.to_frame().transpose()

        if self._dtypes is None:
            return features_frame

        for col, dtype in self._dtypes.items():
            if features_frame.dtypes[col] != dtype:
                features_frame[col] = features_frame[col].astype(dtype)

        return features_frame

    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: bool
    ) -> Union[Series, DataFrame]:
        """Extracts and preprocesses features for a rule and aligns them with the columns used during training."""
        features: Series = self._get_raw_features(rule)
        features = self._preprocess_features(features)

        # Count negated content matches (e.g., content:!"...") in the raw rule.
        features["custom.negated.count"] = rule["raw"].count(':!"')

        if self._dtypes is None:
            return features

        for col, dtype in self._dtypes.items():
            if col not in features:
                if col.endswith(".count"):
                    features[col] = 0
                elif col.endswith(".bool"):
                    features[col] = 0
                elif col.endswith(".num"):
                    features[col] = -1
                else:
                    _logger.error(
                        "Unsure how to handle missing feature %s of type %s",
                        col,
                        dtype,
                    )

        features = features[list(self._dtypes.keys())]  # type: ignore reportAssignmentType

        if not frame:
            return features

        return self._get_features_frame(features)
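

# Illustrative usage sketch (comment only, not part of the module). Training assumes
# a DataFrame with a "rule.rule" column and the "labelled.*" columns listed in
# `train`; the input file name below is hypothetical.
#
#     import pandas as pd
#
#     labelled_df = pd.read_csv("labelled_rules.csv")  # hypothetical input
#     checker = PrincipleMLChecker()
#     checker.train(labelled_df)  # stores the trained checker at `_PICKLE_PATH`
#
#     # Subsequent instantiations unpickle the trained models via `__new__`:
#     trained = PrincipleMLChecker()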