Source code for suricata_check_design_principles.checkers.principle.ml

  1"""`PrincipleMLChecker`."""
  2
  3import copy
  4import logging
  5import os
  6import pickle
  7from collections import Counter
  8from collections.abc import Iterable
  9from typing import Any, Literal, overload
 10
 11import suricata_check
 12import xgboost
 13from pandas import DataFrame, Series
 14from sklearn.metrics import f1_score, make_scorer, precision_score, recall_score
 15from sklearn.model_selection import (
 16    GridSearchCV,
 17    RepeatedStratifiedKFold,
 18    cross_val_score,
 19)
 20from sklearn.pipeline import Pipeline
 21from suricata_check.checkers.interface import CheckerInterface
 22from suricata_check.utils.checker import get_rule_option, get_rule_suboptions
 23from suricata_check.utils.checker_typing import ISSUES_TYPE, Issue
 24from suricata_check.utils.rule import Rule
 25
 26from suricata_check_design_principles._version import SURICATA_CHECK_DIR
 27from suricata_check_design_principles.checkers.principle._utils import get_message
 28
 29_PICKLE_PATH = os.path.join(SURICATA_CHECK_DIR, "data", "principle_ml_checker.pkl")
 30N_JOBS = 8
 31
 32
 33_logger = logging.getLogger(__name__)
 34
 35
 36COUNT_COLUMNS = (
 37    "flowbits.isset.count",
 38    "flowbits.isntoset.count",
 39    "flowint.isset.count",
 40    "flowint.isntoset.count",
 41    "xbits.isset.count",
 42    "xbits.uisnotset.count",
 43    "http.uri.count",
 44    "http.method.count",
 45    "dns.query.count",
 46    "content.count",
 47    "pcre.count",
 48    "startswith.count",
 49    "bsize.count",
 50    "depth.count",
 51    "urilen.count",
 52    "flow.from_server.count",
 53    "flow.to_server.count",
 54    "flow.from_client.count",
 55    "flow.to_client.count",
 56)
 57STRING_COLUMNS = ()
 58DROPDOWN_COLUMNS = (
 59    "proto",
 60    "threshold.type",
 61)
 62NUMERICAL_COLUMNS = ("threshold.count",)
 63SPLITTABLE_FEATURES = (
 64    "metadata",
 65    "flow",
 66    "threshold",
 67)
 68MSG_KEYWORDS = ("Suspicious", "CVE", "Vulnerability", "Response")
 69MSG_COLUMNS = ("msg.contains." + keyword for keyword in MSG_KEYWORDS)
 70IP_KEYWORDS = ("$HOME_NET", "$HTTP_SERVERS", "$EXTERNAL_NET", "any")
 71IP_COLUMNS = tuple(
 72    ["source_addr.contains." + keyword for keyword in IP_KEYWORDS]
 73    + ["dest_addr.contains." + keyword for keyword in IP_KEYWORDS]
 74)
 75
 76
 77PIPELINE = Pipeline(
 78    [
 79        (
 80            "classify",
 81            xgboost.XGBClassifier(),
 82        )
 83    ]
 84)
 85# https://shengyg.github.io/repository/machine%20learning/2017/02/25/Complete-Guide-to-Parameter-Tuning-xgboost.html
 86PARAM_GRID: list[dict] = [
 87    {
 88        # Fixed parameters for problem / desired complexity
 89        "classify__n_estimators": [1000],
 90        "classify__objective": ["binary:logistic"],
 91        ###
 92        # Parameters to optimize
 93        ## Learning rate
 94        "classify__eta": [0.01, 0.1, 0.3],
 95        ## Tree parameters
 96        "classify__subsample": [1.0],
 97        "classify__colsample_bytree": [0.25, 0.5, 0.75, 1.0],
 98        "classify__scale_pos_weight": [0.1, 0.25, 0.5, 1.0, 2.0, 4.0, 10.0],
 99        "classify__max_depth": [1, 3],
100        "classify__min_child_weight": [1],
101        "classify__gamma": [0, 0.1],
102        ## Regularization
103        "classify__lambda": [0, 0.01, 0.1],
104        "classify__alpha": [0, 0.01, 0.1],
105    },
106]
107
108PRECISION_WEIGHT = 10
109SCORER = make_scorer(
110    lambda y, y_pred: (PRECISION_WEIGHT + 1)
111    / (
112        PRECISION_WEIGHT / (precision_score(y, y_pred, zero_division=1) + 1e-10)  # type: ignore reportArgumentType
113        + 1 / (recall_score(y, y_pred, zero_division=0) + 1e-10)  # type: ignore reportArgumentType
114    )
115)
116SPLITTER = RepeatedStratifiedKFold(n_splits=2, n_repeats=10)
117GRIDSEARCHCV = GridSearchCV(
118    PIPELINE, PARAM_GRID, cv=SPLITTER, scoring=SCORER, error_score="raise", n_jobs=N_JOBS, verbose=1  # type: ignore reportArgumentType
119)
120
121
class PrincipleMLChecker(CheckerInterface):
    """The `PrincipleMLChecker` contains several checks based on the Ruling the Unruly paper and target specificity and coverage.

    Codes Q000-Q005 report on non-adherence to rule design principles.
    Unlike rule-based checks, they are the result of machine learning analysis
    of the rules: every code is backed by its own trained classifier (see `train`).
    """

    count_columns = COUNT_COLUMNS
    string_columns = STRING_COLUMNS
    dropdown_columns = DROPDOWN_COLUMNS
    numerical_columns = NUMERICAL_COLUMNS
    splittable_features = SPLITTABLE_FEATURES
    msg_keywords = MSG_KEYWORDS
    # Materialised as a tuple: the column names are iterated for every rule
    # (in `_get_raw_features` and `_preprocess_features`) and must therefore be
    # re-iterable even if `MSG_COLUMNS` is a one-shot iterator.
    msg_columns = tuple(MSG_COLUMNS)
    ip_keywords = IP_KEYWORDS
    ip_columns = IP_COLUMNS

    codes = {
        "Q000": {"severity": logging.INFO},
        "Q001": {"severity": logging.INFO},
        "Q002": {"severity": logging.INFO},
        "Q003": {"severity": logging.INFO},
        "Q004": {"severity": logging.INFO},
        "Q005": {"severity": logging.INFO},
    }

    enabled_by_default = (
        False  # Since the checker is relatively slow, it is disabled by default
    )

    # Feature column -> dtype of the training data; `None` until `train` has run.
    _dtypes: dict[str, Any] | None = None
    # Issue code -> trained pipeline. This class-level default is shared by all
    # instances and is never mutated; `train` always binds a per-instance dict.
    _models: dict[str, Pipeline] = {}

    def __new__(
        cls: type["PrincipleMLChecker"],
        filepath: str | None = _PICKLE_PATH,
        *args: Any,
        **kwargs: Any,
    ) -> "PrincipleMLChecker":
        """Return a new or unpickled instance of the class.

        When `filepath` is set and exists, the object pickled there is returned
        if its class name matches and it holds trained models (an `include`
        keyword argument, if given, is applied to it). In every other case a
        fresh, untrained instance is created.

        The file is loaded with `pickle`, so `filepath` must only ever point to
        a trusted file, such as the one shipped with this package.
        """
        if not filepath:
            return super().__new__(cls, *args, **kwargs)  # type: ignore reportargumentType

        if not os.path.exists(filepath):
            _logger.warning("No model found for PrincipleMLChecker at %s", filepath)
            return super().__new__(cls, *args, **kwargs)  # type: ignore reportargumentType

        with open(filepath, "rb") as f:
            inst = pickle.load(f)  # trusted, package-provided file (see docstring)

        if inst.__class__.__name__ != cls.__name__:
            _logger.error("Unpickled object is not of type %s", cls)
            return super().__new__(cls, *args, **kwargs)  # type: ignore reportargumentType

        if not hasattr(inst, "_models") or len(inst._models) == 0:
            _logger.error("Unpickled object does not have trained models")
            return super().__new__(cls, *args, **kwargs)  # type: ignore reportargumentType

        if "include" in kwargs:
            inst.include = kwargs["include"]
        _logger.info("Unpickled object with trained models successfully")
        return inst

    def __getnewargs__(
        self: "PrincipleMLChecker",
    ) -> tuple:
        """Returns the arguments to be passed to the __new__ method when unpickling.

        Passing `None` as `filepath` stops `__new__` from loading the default
        pickle file again while this object is itself being unpickled.
        """
        return (None,)

    def _check_rule(
        self: "PrincipleMLChecker",
        rule: Rule,
    ) -> ISSUES_TYPE:
        """Run every trained model on `rule` and report each code it predicts."""
        issues: ISSUES_TYPE = []

        if len(self._models) == 0:
            return issues

        # The feature frame does not depend on the model, so build it only once.
        features = self._get_features(rule, True)

        for code, model in self._models.items():
            if model.predict(features)[0]:
                issues.append(
                    Issue(
                        code=code,
                        message=get_message(code),
                    )
                )

        return issues

    @staticmethod
    def _validate_feature_columns(
        X: DataFrame,  # noqa: N803
        message: str,
        level: int,
    ) -> None:
        """Log every column of `X`, re-raising if its variance cannot be computed."""
        for col in X.columns:
            try:
                X[col].var()
            except Exception:
                _logger.error("Error with column %s", col)
                _logger.error(X[col])
                raise
            _logger.log(level, message, col)

    def train(  # noqa: C901
        self: "PrincipleMLChecker",
        df: DataFrame,
        rule_col: str = "rule.rule",
        principle_cols: dict[str, str] | None = None,
        reuse_models: bool = False,
    ) -> None:
        """Train several models for the checker to detect issues in rules.

        One classifier is trained per issue code. For a code, a rule is a
        positive example (issue reported) when its label column equals 0.

        The checker class with trained models is stored in a pickle file (`_PICKLE_PATH`).

        Args:
            df: Training data with the raw rules and one label column per code.
            rule_col: Name of the column holding the raw rule strings.
            principle_cols: Mapping of issue code to label column. Defaults to
                the `labelled.*` columns for Q000-Q005.
            reuse_models: If true, codes that already have a model skip the grid
                search and only have their model evaluated and refitted.
        """
        # Imported locally so this method does not depend on extra module-level imports.
        from sklearn.model_selection import cross_validate

        if principle_cols is None:
            principle_cols = {
                "Q000": "labelled.no_proxy",
                "Q001": "labelled.success",
                "Q002": "labelled.thresholded",
                "Q003": "labelled.exceptions",
                "Q004": "labelled.generalized_match_content",
                "Q005": "labelled.generalized_match_location",
            }

        self._dtypes = None
        # Always bind a per-instance dict, so the class-level default shared by
        # all instances is never mutated below.
        self._models = dict(self._models) if reuse_models else {}

        # Extract features and determine feature dtypes
        X_train = self._get_train_df(df[rule_col])  # noqa: N806
        self._validate_feature_columns(X_train, "Detected column: %s", logging.DEBUG)

        # Drop zero variance columns
        X_train = X_train.drop(  # noqa: N806
            X_train.columns[(X_train.fillna(-1337).var(axis=0) <= 0)].to_list(),  # type: ignore reportAttributeAccessIssue
            axis=1,
        )

        # Drop columns (other than counts, numbers and lengths) in which some
        # value occurs at most once.
        rare_cols = [
            col
            for col in X_train.columns
            if not col.endswith((".count", ".num", ".len"))
            and X_train[col].value_counts().min() <= 1
        ]
        X_train = X_train.drop(rare_cols, axis=1)  # noqa: N806
        self._validate_feature_columns(X_train, "Using column: %s", logging.INFO)

        # Store used features and their dtypes
        self._dtypes = X_train.dtypes.to_dict()
        _logger.debug(self._dtypes)

        # Redo feature extraction now that FE parameters are set
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        _logger.info(
            "Training model with features: [%s]",
            ", ".join([str(x) for x in X_train.columns]),
        )

        _logger.info(X_train)

        # Evaluate all metrics in a single cross-validation run on the same folds.
        cv_scorers = {
            "precision": make_scorer(precision_score, zero_division=0.0),
            "recall": make_scorer(recall_score, zero_division=0.0),
            "f1": make_scorer(f1_score, zero_division=0.0),
        }

        for code, col in principle_cols.items():
            y_true = df[col].to_numpy() == 0

            if not reuse_models or code not in self._models:
                # Train new model with grid search to find optimal parameters
                gridsearchcv: GridSearchCV = copy.deepcopy(GRIDSEARCHCV)

                gridsearchcv.fit(X_train, y_true)

                _logger.info("Code %s params: %s", code, gridsearchcv.best_params_)
                _logger.info(
                    "Code %s Weighted F1-score: %s", code, gridsearchcv.best_score_
                )

                self._models[code] = gridsearchcv.best_estimator_

            cv_results = cross_validate(
                self._models[code],
                X_train,
                y_true,
                scoring=cv_scorers,
                cv=SPLITTER,
                n_jobs=N_JOBS,
            )
            _logger.info(
                "Code %s Precision score: %s", code, cv_results["test_precision"].mean()
            )
            _logger.info(
                "Code %s Recall score: %s", code, cv_results["test_recall"].mean()
            )
            _logger.info("Code %s F1-score: %s", code, cv_results["test_f1"].mean())

            # Refit model with training data.
            self._models[code].fit(X_train, y_true)

        with open(_PICKLE_PATH, "wb") as f:
            pickle.dump(self, f)

    def _get_train_df(self: "PrincipleMLChecker", rules: Iterable[str]) -> DataFrame:
        """Parse every rule string and stack their feature vectors into a DataFrame.

        Raises:
            ValueError: If a rule cannot be parsed.
        """
        feature_vectors = []
        for rule in rules:
            parsed_rule = suricata_check.utils.rule.parse(rule)
            if parsed_rule is None:
                msg = f"Failed to parse rule: {rule}"
                raise ValueError(msg)
            feature_vectors.append(self._get_features(parsed_rule, False))

        return DataFrame(feature_vectors)

    def _get_raw_features(  # noqa: C901
        self: "PrincipleMLChecker", rule: Rule
    ) -> Series:
        """Collect the unprocessed features of `rule`.

        The result contains the protocol, every option value, per-option
        occurrence counts, the expanded suboptions (and their counts) of the
        splittable options, and boolean flags for keywords found in the
        message and in the source/destination address fields.
        """
        d: dict[str, str | int | None] = {"proto": get_rule_option(rule, "proto")}

        options = rule.options

        for option in options:
            d[option.name] = option.value

        for name, count in Counter(option.name for option in options).items():
            d[name + ".count"] = count

        for option in options:
            if option.name not in self.splittable_features:
                continue

            suboptions = list(get_rule_suboptions(rule, option.name, warn=False))
            if len(suboptions) == 0:
                continue

            for sub_name, sub_value in suboptions:
                d[option.name + "." + sub_name] = sub_value

            sub_counts = Counter(sub_name for sub_name, _ in suboptions)
            for sub_name, count in sub_counts.items():
                d[option.name + "." + sub_name + ".count"] = count

        msg = get_rule_option(rule, "msg")
        assert msg is not None
        msg = msg.lower()
        for col, keyword in zip(self.msg_columns, self.msg_keywords):
            d[col] = keyword.lower() in msg

        for field in ("source_addr", "dest_addr"):
            addr = get_rule_option(rule, field)
            assert addr is not None
            addr = addr.lower()
            for keyword in self.ip_keywords:
                d[field + ".contains." + keyword] = keyword.lower() in addr

        return Series(d)

    def _preprocess_features(self: "PrincipleMLChecker", data: Series) -> Series:
        """Encode raw option values as features and drop unused raw columns.

        String columns become `<col>.len`, dropdown columns become a one-hot
        `<col>.<value>.bool` flag, and numerical columns become a float
        `<col>.num`. Every other original column that is not a count, message
        or address feature is dropped.
        """
        original_cols: set[str] = set(data.index)

        for col in self.string_columns:
            if col not in data:
                continue
            data[col + ".len"] = len(data[col])
            data = data.drop(col)

        for col in self.dropdown_columns:
            if col not in data:
                continue
            data[col + "." + data[col] + ".bool"] = 1
            data = data.drop(col)

        for col in self.numerical_columns:
            if col not in data:
                continue
            data[col + ".num"] = float(data[col])  # type: ignore reportArgumentType
            data = data.drop(col)

        remaining_cols = (
            original_cols
            - set(self.count_columns)
            - set(self.string_columns)
            - set(self.dropdown_columns)
            - set(self.numerical_columns)
            - set(self.msg_columns)
            - set(self.ip_columns)
        )

        # Drop all unused raw columns in one call instead of one at a time.
        return data.drop(list(remaining_cols))

    def _get_features_frame(self: "PrincipleMLChecker", features: Series) -> DataFrame:
        """Turn a feature Series into a one-row DataFrame with the training dtypes."""
        features_frame = features.to_frame().transpose()

        if self._dtypes is None:
            return features_frame

        for col, dtype in self._dtypes.items():
            if features_frame.dtypes[col] != dtype:
                features_frame[col] = features_frame[col].astype(dtype)

        return features_frame

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: Rule, frame: Literal[True]
    ) -> DataFrame: ...

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: Rule, frame: Literal[False]
    ) -> Series: ...

    def _get_features(
        self: "PrincipleMLChecker", rule: Rule, frame: bool
    ) -> Series | DataFrame:
        """Compute the model features of `rule`.

        Before training (`_dtypes is None`), all preprocessed features are
        returned as a Series. Afterwards, the features are aligned to the
        training columns: missing counts and bools are filled with 0 and
        missing numbers with -1. If `frame` is true, the result is returned as
        a one-row DataFrame with the training dtypes.
        """
        features: Series = self._get_raw_features(rule)
        features = self._preprocess_features(features)

        features["custom.negated.count"] = rule.raw.count(':!"')

        if self._dtypes is None:
            return features

        for col, dtype in self._dtypes.items():
            if col not in features:
                if col.endswith(".count"):
                    features[col] = 0
                elif col.endswith(".bool"):
                    features[col] = 0
                elif col.endswith(".num"):
                    features[col] = -1
                else:
                    # NOTE(review): the column stays missing, so the selection
                    # below presumably raises a KeyError for it.
                    _logger.error(
                        "Unsure how to handle missing feature %s of type %s",
                        col,
                        dtype,
                    )

        features = features[list(self._dtypes.keys())]  # type: ignore reportAssignmentType

        if not frame:
            return features

        return self._get_features_frame(features)