1"""`PrincipleMLChecker`."""
2
3import copy
4import logging
5import os
6import pickle
7from collections import Counter
8from collections.abc import Iterable
9from typing import Any, Literal, overload
10
11import suricata_check
12import xgboost
13from pandas import DataFrame, Series
14from sklearn.metrics import f1_score, make_scorer, precision_score, recall_score
15from sklearn.model_selection import (
16 GridSearchCV,
17 RepeatedStratifiedKFold,
18 cross_val_score,
19)
20from sklearn.pipeline import Pipeline
21from suricata_check.checkers.interface import CheckerInterface
22from suricata_check.utils.checker import get_rule_option, get_rule_suboptions
23from suricata_check.utils.checker_typing import ISSUES_TYPE, Issue
24from suricata_check.utils.rule import Rule
25
26from suricata_check_design_principles._version import SURICATA_CHECK_DIR
27from suricata_check_design_principles.checkers.principle._utils import get_message
28
29_PICKLE_PATH = os.path.join(SURICATA_CHECK_DIR, "data", "principle_ml_checker.pkl")
30N_JOBS = 8
31
32
33_logger = logging.getLogger(__name__)
34
35
COUNT_COLUMNS = (
    "flowbits.isset.count",
    "flowbits.isnotset.count",
    "flowint.isset.count",
    "flowint.isnotset.count",
    "xbits.isset.count",
    "xbits.isnotset.count",
    "http.uri.count",
    "http.method.count",
    "dns.query.count",
    "content.count",
    "pcre.count",
    "startswith.count",
    "bsize.count",
    "depth.count",
    "urilen.count",
    "flow.from_server.count",
    "flow.to_server.count",
    "flow.from_client.count",
    "flow.to_client.count",
)
STRING_COLUMNS = ()
DROPDOWN_COLUMNS = (
    "proto",
    "threshold.type",
)
NUMERICAL_COLUMNS = ("threshold.count",)
SPLITTABLE_FEATURES = (
    "metadata",
    "flow",
    "threshold",
)
MSG_KEYWORDS = ("Suspicious", "CVE", "Vulnerability", "Response")
MSG_COLUMNS = tuple("msg.contains." + keyword for keyword in MSG_KEYWORDS)
IP_KEYWORDS = ("$HOME_NET", "$HTTP_SERVERS", "$EXTERNAL_NET", "any")
IP_COLUMNS = tuple(
    ["source_addr.contains." + keyword for keyword in IP_KEYWORDS]
    + ["dest_addr.contains." + keyword for keyword in IP_KEYWORDS]
)
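# The derived feature names are plain string keys; for example, MSG_COLUMNS
# contains "msg.contains.Suspicious" and IP_COLUMNS contains
# "source_addr.contains.$HOME_NET" and "dest_addr.contains.any".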


PIPELINE = Pipeline(
    [
        (
            "classify",
            xgboost.XGBClassifier(),
        )
    ]
)
# https://shengyg.github.io/repository/machine%20learning/2017/02/25/Complete-Guide-to-Parameter-Tuning-xgboost.html
PARAM_GRID: list[dict] = [
    {
        # Fixed parameters for problem / desired complexity
        "classify__n_estimators": [1000],
        "classify__objective": ["binary:logistic"],
        ###
        # Parameters to optimize
        ## Learning rate
        "classify__eta": [0.01, 0.1, 0.3],
        ## Tree parameters
        "classify__subsample": [1.0],
        "classify__colsample_bytree": [0.25, 0.5, 0.75, 1.0],
        "classify__scale_pos_weight": [0.1, 0.25, 0.5, 1.0, 2.0, 4.0, 10.0],
        "classify__max_depth": [1, 3],
        "classify__min_child_weight": [1],
        "classify__gamma": [0, 0.1],
        ## Regularization
        "classify__lambda": [0, 0.01, 0.1],
        "classify__alpha": [0, 0.01, 0.1],
    },
]
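# The grid spans 3 * 4 * 7 * 2 * 2 * 3 * 3 = 3024 candidate configurations;
# with the repeated 2-fold cross-validation below (20 fits per candidate),
# a full search performs roughly 60,000 model fits per code.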

PRECISION_WEIGHT = 10
SCORER = make_scorer(
    lambda y, y_pred: (PRECISION_WEIGHT + 1)
    / (
        PRECISION_WEIGHT / (precision_score(y, y_pred, zero_division=1) + 1e-10)  # type: ignore reportArgumentType
        + 1 / (recall_score(y, y_pred, zero_division=0) + 1e-10)  # type: ignore reportArgumentType
    )
)
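# SCORER computes the weighted harmonic mean of precision and recall,
# algebraically an F-beta score with beta^2 = 1 / PRECISION_WEIGHT
# (beta ~ 0.32), so model selection strongly favors precision; the 1e-10
# terms guard against division by zero when either score is 0.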
SPLITTER = RepeatedStratifiedKFold(n_splits=2, n_repeats=10)
GRIDSEARCHCV = GridSearchCV(
    PIPELINE, PARAM_GRID, cv=SPLITTER, scoring=SCORER, error_score="raise", n_jobs=N_JOBS, verbose=1  # type: ignore reportArgumentType
)


class PrincipleMLChecker(CheckerInterface):
123 """The `PrincipleChecker` contains several checks based on the Ruling the Unruly paper and target specificity and coverage.
124
125 Codes Q000-Q009 report on non-adherence to rule design principles similar to Q000-Q009.
126 Differently, they are the result of machine learning analysis of the rules.
127 """

    count_columns = COUNT_COLUMNS
    string_columns = STRING_COLUMNS
    dropdown_columns = DROPDOWN_COLUMNS
    numerical_columns = NUMERICAL_COLUMNS
    splittable_features = SPLITTABLE_FEATURES
    msg_keywords = MSG_KEYWORDS
    msg_columns = MSG_COLUMNS
    ip_keywords = IP_KEYWORDS
    ip_columns = IP_COLUMNS

    codes = {
        "Q000": {"severity": logging.INFO},
        "Q001": {"severity": logging.INFO},
        "Q002": {"severity": logging.INFO},
        "Q003": {"severity": logging.INFO},
        "Q004": {"severity": logging.INFO},
        "Q005": {"severity": logging.INFO},
    }
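
    # The Q-codes correspond one-to-one to the labelled design principles used
    # as training targets in `train` below: Q000 no_proxy, Q001 success,
    # Q002 thresholded, Q003 exceptions, Q004 generalized_match_content and
    # Q005 generalized_match_location.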

    enabled_by_default = (
        False  # Since the checker is relatively slow, it is disabled by default
    )

    _dtypes: dict[str, Any] | None = None
    _models: dict[str, Pipeline] = {}

    def __new__(
        cls: type["PrincipleMLChecker"],
        filepath: str | None = _PICKLE_PATH,
        *args: tuple,
        **kwargs: dict,
    ) -> "PrincipleMLChecker":
        """Returns a new or unpickled instance of the class."""
        if filepath:
            if os.path.exists(filepath):
                with open(filepath, "rb") as f:
                    inst = pickle.load(f)

                if inst.__class__.__name__ != cls.__name__:
                    _logger.error("Unpickled object is not of type %s", cls)
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                elif not hasattr(inst, "_models") or len(inst._models) == 0:
                    _logger.error("Unpickled object does not have trained models")
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                else:
                    if "include" in kwargs:
                        inst.include = kwargs["include"]
                    _logger.info("Unpickled object with trained models successfully")
            else:
                _logger.warning("No model found for PrincipleMLChecker at %s", filepath)
                inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
        else:
            inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType

        return inst

    def __getnewargs__(
        self: "PrincipleMLChecker",
    ) -> tuple:
        """Returns the arguments to be passed to the __new__ method when unpickling."""
        return (None,)
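
    # Returning `(None,)` means that, during unpickling, `__new__` receives
    # `filepath=None` and therefore does not attempt to load `_PICKLE_PATH`
    # again while an instance is already being restored from it.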

    def _check_rule(
        self: "PrincipleMLChecker",
        rule: Rule,
    ) -> ISSUES_TYPE:
        issues: ISSUES_TYPE = []

        if len(self._models) == 0:
            return issues

        for code, model in self._models.items():
            if model.predict(self._get_features(rule, True))[0]:
                issues.append(
                    Issue(
                        code=code,
                        message=get_message(code),
                    )
                )

        return issues

    def train(  # noqa: C901
        self: "PrincipleMLChecker",
        df: DataFrame,
        rule_col: str = "rule.rule",
        principle_cols: dict[str, str] = {
            "Q000": "labelled.no_proxy",
            "Q001": "labelled.success",
            "Q002": "labelled.thresholded",
            "Q003": "labelled.exceptions",
            "Q004": "labelled.generalized_match_content",
            "Q005": "labelled.generalized_match_location",
        },
        reuse_models: bool = False,
    ) -> None:
        """Train several models for the checker to detect issues in rules.

        The checker instance with trained models is stored in a pickle file (`_PICKLE_PATH`).
        """
        self._dtypes = None
        if not reuse_models:
            self._models = {}

        # Extract features and determine feature dtypes
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.debug("Detected column: %s", col)
            except Exception:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Drop zero-variance columns
        X_train = X_train.drop(  # noqa: N806
            X_train.columns[(X_train.fillna(-1337).var(axis=0) <= 0)].to_list(),  # type: ignore reportAttributeAccessIssue
            axis=1,
        )

        # Drop columns with too few occurrences of possible values
        for col in X_train.columns:
            if (
                not col.endswith(".count")
                and not col.endswith(".num")
                and not col.endswith(".len")
            ):
                if X_train[col].value_counts().min() <= 1:
                    X_train = X_train.drop(  # noqa: N806
                        [col],
                        axis=1,
                    )

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.info("Using column: %s", col)
            except Exception:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Store used features and their dtypes
        self._dtypes = X_train.dtypes.to_dict()
        _logger.debug(self._dtypes)

        # Redo feature extraction now that FE parameters are set
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        _logger.info(
            "Training model with features: [%s]",
            ", ".join([str(x) for x in X_train.columns]),
        )

        _logger.info(X_train)

        for code, col in principle_cols.items():
            y_true = df[col].to_numpy() == 0

            if not reuse_models or code not in self._models:
                # Train a new model, using grid search to find optimal parameters
                gridsearchcv: GridSearchCV = copy.deepcopy(GRIDSEARCHCV)

                gridsearchcv.fit(X_train, y_true)

                _logger.info("Code %s params: %s", code, gridsearchcv.best_params_)
                _logger.info(
                    "Code %s Weighted F1-score: %s", code, gridsearchcv.best_score_
                )

                self._models[code] = gridsearchcv.best_estimator_

            precision = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(precision_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            recall = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(recall_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            f1 = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(f1_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            _logger.info("Code %s Precision score: %s", code, precision)
            _logger.info("Code %s Recall score: %s", code, recall)
            _logger.info("Code %s F1-score: %s", code, f1)

            # Refit the model on the full training data.
            self._models[code].fit(X_train, y_true)

        with open(_PICKLE_PATH, "wb") as f:
            pickle.dump(self, f)
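
    # Note that `train` overwrites the pickle at `_PICKLE_PATH` with this
    # trained checker instance, which `__new__` then transparently reloads the
    # next time a `PrincipleMLChecker` is constructed.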

    def _get_train_df(self: "PrincipleMLChecker", rules: Iterable[str]) -> DataFrame:
        feature_vectors = []
        for rule in rules:
            parsed_rule = suricata_check.utils.rule.parse(rule)
            assert parsed_rule is not None
            feature_vectors.append(self._get_features(parsed_rule, False))

        return DataFrame(feature_vectors)

    def _get_raw_features(  # noqa: C901
        self: "PrincipleMLChecker", rule: Rule
    ) -> Series:
        d: dict[str, str | int | None] = {"proto": get_rule_option(rule, "proto")}

        options = rule.options

        for option in options:
            d[option.name] = option.value

        counter = Counter([option.name for option in options])
        for option, count in counter.items():
            d[option + ".count"] = count

        for option in options:
            if option.name not in self.splittable_features:
                continue

            suboptions = [
                {"name": k, "value": v}
                for k, v in get_rule_suboptions(rule, option.name, warn=False)
            ]

            if len(suboptions) == 0:
                continue

            for suboption in suboptions:
                d[option.name + "." + suboption["name"]] = suboption["value"]

            counter = Counter([suboption["name"] for suboption in suboptions])
            for suboption, count in counter.items():
                d[option.name + "." + suboption + ".count"] = count

        msg = get_rule_option(rule, "msg")
        assert msg is not None
        msg = msg.lower()
        for col, keyword in zip(self.msg_columns, self.msg_keywords):
            d[col] = keyword.lower() in msg

        source_addr = get_rule_option(rule, "source_addr")
        assert source_addr is not None
        source_addr = source_addr.lower()
        for keyword in self.ip_keywords:
            col = "source_addr.contains." + keyword
            d[col] = keyword.lower() in source_addr

        dest_addr = get_rule_option(rule, "dest_addr")
        assert dest_addr is not None
        dest_addr = dest_addr.lower()
        for keyword in self.ip_keywords:
            col = "dest_addr.contains." + keyword
            d[col] = keyword.lower() in dest_addr

        return Series(d)
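
    # For a rule such as
    #   alert tcp $EXTERNAL_NET any -> $HOME_NET any (msg:"Suspicious"; content:"a"; content:"b"; sid:1;)
    # the raw features would include proto="tcp", content.count=2,
    # msg.contains.Suspicious=True and source_addr.contains.$EXTERNAL_NET=True.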

    def _preprocess_features(self: "PrincipleMLChecker", data: Series) -> Series:
        original_cols: set[str] = set(data.index)

        for col in self.string_columns:
            if col not in data:
                continue
            data[col + ".len"] = len(data[col])
            data = data.drop(col)

        for col in self.dropdown_columns:
            if col not in data:
                continue
            data[col + "." + data[col] + ".bool"] = 1
            data = data.drop(col)

        for col in self.numerical_columns:
            if col not in data:
                continue
            data[col + ".num"] = float(data[col])  # type: ignore reportArgumentType
            data = data.drop(col)

        remaining_cols = (
            original_cols
            - set(self.count_columns)
            - set(self.string_columns)
            - set(self.dropdown_columns)
            - set(self.numerical_columns)
            - set(self.msg_columns)
            - set(self.ip_columns)
        )

        for col in remaining_cols:
            data = data.drop(col)

        return data
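
    # Preprocessing example: the dropdown feature proto="tcp" becomes the
    # indicator column "proto.tcp.bool" = 1, and "threshold.count" is cast to
    # float as "threshold.count.num"; any column not declared in the tuples at
    # the top of this module is dropped.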

    def _get_features_frame(self: "PrincipleMLChecker", features: Series) -> DataFrame:
        features_frame = features.to_frame().transpose()

        if self._dtypes is None:
            return features_frame

        for col, dtype in self._dtypes.items():
            if features_frame.dtypes[col] != dtype:
                features_frame[col] = features_frame[col].astype(dtype)

        return features_frame

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: Rule, frame: Literal[True]
    ) -> DataFrame:
        pass

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: Rule, frame: Literal[False]
    ) -> Series:
        pass

    def _get_features(
        self: "PrincipleMLChecker", rule: Rule, frame: bool
    ) -> Series | DataFrame:
        features: Series = self._get_raw_features(rule)
        features = self._preprocess_features(features)

        features["custom.negated.count"] = rule.raw.count(':!"')

        if self._dtypes is None:
            return features

        for col, dtype in self._dtypes.items():
            if col not in features:
                if col.endswith(".count"):
                    features[col] = 0
                elif col.endswith(".bool"):
                    features[col] = 0
                elif col.endswith(".num"):
                    features[col] = -1
                else:
                    _logger.error(
                        "Unsure how to handle missing feature %s of type %s",
                        col,
                        dtype,
                    )

        features = features[list(self._dtypes.keys())]  # type: ignore reportAssignmentType

        if not frame:
            return features

        return self._get_features_frame(features)