1"""`PrincipleMLChecker`."""
2
3import copy
4import logging
5import os
6import pickle
7from collections import Counter
8from collections.abc import Iterable
9from typing import Any, Literal, Optional, Union, overload
10
11import idstools.rule
12import xgboost
13from pandas import DataFrame, Series
14from sklearn.metrics import f1_score, make_scorer, precision_score, recall_score
15from sklearn.model_selection import (
16 GridSearchCV,
17 RepeatedStratifiedKFold,
18 cross_val_score,
19)
20from sklearn.pipeline import Pipeline
21from suricata_check.checkers.interface.checker import CheckerInterface
22from suricata_check.utils.checker import get_rule_option, get_rule_suboptions
23from suricata_check.utils.checker_typing import ISSUES_TYPE, Issue
24
25from suricata_check_design_principles._version import SURICATA_CHECK_DIR
26from suricata_check_design_principles.checkers.principle._utils import get_message
27
28_PICKLE_PATH = os.path.join(SURICATA_CHECK_DIR, "data", "principle_ml_checker.pkl")
29N_JOBS = 8
30
31
32_logger = logging.getLogger(__name__)
33
34
COUNT_COLUMNS = (
    "flowbits.isset.count",
    "flowbits.isnotset.count",
    "flowint.isset.count",
    "flowint.isnotset.count",
    "xbits.isset.count",
    "xbits.isnotset.count",
    "http.uri.count",
    "http.method.count",
    "dns.query.count",
    "content.count",
    "pcre.count",
    "startswith.count",
    "bsize.count",
    "depth.count",
    "urilen.count",
    "flow.from_server.count",
    "flow.to_server.count",
    "flow.from_client.count",
    "flow.to_client.count",
)
STRING_COLUMNS = ()
DROPDOWN_COLUMNS = (
    "proto",
    "threshold.type",
)
NUMERICAL_COLUMNS = ("threshold.count",)
SPLITTABLE_FEATURES = (
    "metadata",
    "flow",
    "threshold",
)
MSG_KEYWORDS = ("Suspicious", "CVE", "Vulnerability", "Response")
MSG_COLUMNS = tuple("msg.contains." + keyword for keyword in MSG_KEYWORDS)
IP_KEYWORDS = ("$HOME_NET", "$HTTP_SERVERS", "$EXTERNAL_NET", "any")
IP_COLUMNS = tuple(
    ["source_addr.contains." + keyword for keyword in IP_KEYWORDS]
    + ["dest_addr.contains." + keyword for keyword in IP_KEYWORDS]
)


PIPELINE = Pipeline(
    [
        (
            "classify",
            xgboost.XGBClassifier(),
        )
    ]
)
# https://shengyg.github.io/repository/machine%20learning/2017/02/25/Complete-Guide-to-Parameter-Tuning-xgboost.html
PARAM_GRID: list[dict] = [
    {
        # Fixed parameters for problem / desired complexity
        "classify__n_estimators": [1000],
        "classify__objective": ["binary:logistic"],
        ###
        # Parameters to optimize
        ## Learning rate
        "classify__eta": [0.01, 0.1, 0.3],
        ## Tree parameters
        "classify__subsample": [1.0],
        "classify__colsample_bytree": [0.25, 0.5, 0.75, 1.0],
        "classify__scale_pos_weight": [0.1, 0.25, 0.5, 1.0, 2.0, 4.0, 10.0],
        "classify__max_depth": [1, 3],
        "classify__min_child_weight": [1],
        "classify__gamma": [0, 0.1],
        ## Regularization
        "classify__lambda": [0, 0.01, 0.1],
        "classify__alpha": [0, 0.01, 0.1],
    },
]

PRECISION_WEIGHT = 10
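# Weighted harmonic mean of precision and recall: (w + 1) / (w / P + 1 / R)
# with w = PRECISION_WEIGHT is equivalent to the F-beta score with
# beta^2 = 1 / w (beta ~= 0.32), i.e., precision is weighted ten times as
# heavily as recall. The 1e-10 terms guard against division by zero.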
SCORER = make_scorer(
    lambda y, y_pred: (PRECISION_WEIGHT + 1)
    / (
        PRECISION_WEIGHT / (precision_score(y, y_pred, zero_division=1) + 1e-10)  # type: ignore reportArgumentType
        + 1 / (recall_score(y, y_pred, zero_division=0) + 1e-10)  # type: ignore reportArgumentType
    )
)
SPLITTER = RepeatedStratifiedKFold(n_splits=2, n_repeats=10)
GRIDSEARCHCV = GridSearchCV(
    PIPELINE, PARAM_GRID, cv=SPLITTER, scoring=SCORER, error_score="raise", n_jobs=N_JOBS, verbose=1  # type: ignore reportArgumentType
)


class PrincipleMLChecker(CheckerInterface):
    """The `PrincipleMLChecker` contains several checks based on the Ruling the Unruly paper that target rule specificity and coverage.

    Codes Q000-Q005 report on non-adherence to the same rule design principles as the rule-based `PrincipleChecker`.
    Differently, they are the result of machine learning analysis of the rules.
    """

    count_columns = COUNT_COLUMNS
    string_columns = STRING_COLUMNS
    dropdown_columns = DROPDOWN_COLUMNS
    numerical_columns = NUMERICAL_COLUMNS
    splittable_features = SPLITTABLE_FEATURES
    msg_keywords = MSG_KEYWORDS
    msg_columns = MSG_COLUMNS
    ip_keywords = IP_KEYWORDS
    ip_columns = IP_COLUMNS

    codes = {
        "Q000": {"severity": logging.INFO},
        "Q001": {"severity": logging.INFO},
        "Q002": {"severity": logging.INFO},
        "Q003": {"severity": logging.INFO},
        "Q004": {"severity": logging.INFO},
        "Q005": {"severity": logging.INFO},
    }

    enabled_by_default = (
        False  # Since the checker is relatively slow, it is disabled by default
    )

    _dtypes: Optional[dict[str, Any]] = None
    _models: dict[str, Pipeline] = {}

    def __new__(
        cls: type["PrincipleMLChecker"],
        filepath: Optional[str] = _PICKLE_PATH,
        *args: tuple,
        **kwargs: dict,
    ) -> "PrincipleMLChecker":
        """Returns a new or unpickled instance of the class."""
        if filepath:
            if os.path.exists(filepath):
                with open(filepath, "rb") as f:
                    inst = pickle.load(f)

                if inst.__class__.__name__ != cls.__name__:
                    _logger.error("Unpickled object is not of type %s", cls)
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                elif not hasattr(inst, "_models") or len(inst._models) == 0:
                    _logger.error("Unpickled object does not have trained models")
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                else:
                    if "include" in kwargs:
                        inst.include = kwargs["include"]
                    _logger.info("Unpickled object with trained models successfully")
            else:
                _logger.warning("No model found for PrincipleMLChecker at %s", filepath)
                inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
        else:
            inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType

        return inst

    def __getnewargs__(
        self: "PrincipleMLChecker",
    ) -> tuple:
        """Returns the arguments passed to the `__new__` method when unpickling.

        Returning `(None,)` sets `filepath` to `None` so that `__new__` does not
        attempt to load the pickle again while the instance is being unpickled.
        """
        return (None,)

    def _check_rule(
        self: "PrincipleMLChecker",
        rule: idstools.rule.Rule,
    ) -> ISSUES_TYPE:
        issues: ISSUES_TYPE = []

        if len(self._models) == 0:
            return issues

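        # One binary classifier per code: each model predicts whether the rule
        # violates the corresponding design principle.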
        for code, model in self._models.items():
            if model.predict(self._get_features(rule, True))[0]:
                issues.append(
                    Issue(
                        code=code,
                        message=get_message(code),
                    )
                )

        return issues

    def train(  # noqa: C901
        self: "PrincipleMLChecker",
        df: DataFrame,
        rule_col: str = "rule.rule",
        principle_cols: dict[str, str] = {
            "Q000": "labelled.no_proxy",
            "Q001": "labelled.success",
            "Q002": "labelled.thresholded",
            "Q003": "labelled.exceptions",
            "Q004": "labelled.generalized_match_content",
            "Q005": "labelled.generalized_match_location",
        },
        reuse_models: bool = False,
    ) -> None:
        """Train several models for the checker to detect issues in rules.

        One binary classifier is trained per principle code, using grid search
        over `PARAM_GRID` with repeated stratified cross-validation.
        The checker class with trained models is stored in a pickle file (`_PICKLE_PATH`).
        """
        self._dtypes = None
        if not reuse_models:
            self._models = {}

        # Extract features and determine feature dtypes
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.debug("Detected column: %s", col)
            except Exception:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Drop zero-variance columns
        X_train = X_train.drop(  # noqa: N806
            X_train.columns[(X_train.fillna(-1337).var(axis=0) <= 0)].to_list(),  # type: ignore reportAttributeAccessIssue
            axis=1,
        )

        # Drop columns with too few occurrences of possible values
        for col in X_train.columns:
            if (
                not col.endswith(".count")
                and not col.endswith(".num")
                and not col.endswith(".len")
            ):
                if X_train[col].value_counts().min() <= 1:
                    X_train = X_train.drop(  # noqa: N806
                        [col],
                        axis=1,
                    )

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.info("Using column: %s", col)
            except Exception:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Store used features and their dtypes
        self._dtypes = X_train.dtypes.to_dict()
        _logger.debug(self._dtypes)

        # Redo feature extraction now that FE parameters are set
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        _logger.info(
            "Training model with features: [%s]",
            ", ".join([str(x) for x in X_train.columns]),
        )

        _logger.info(X_train)

        for code, col in principle_cols.items():
            y_true = df[col].to_numpy() == 0

            if not reuse_models or code not in self._models:
                # Train new model with grid search to find optimal parameters
                gridsearchcv: GridSearchCV = copy.deepcopy(GRIDSEARCHCV)

                gridsearchcv.fit(X_train, y_true)

                _logger.info("Code %s params: %s", code, gridsearchcv.best_params_)
                _logger.info(
                    "Code %s Weighted F1-score: %s", code, gridsearchcv.best_score_
                )

                self._models[code] = gridsearchcv.best_estimator_

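            # Estimate out-of-sample precision, recall, and F1 for the selected
            # model via repeated stratified cross-validation.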
            precision = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(precision_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            recall = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(recall_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            f1 = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(f1_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            _logger.info("Code %s Precision score: %s", code, precision)
            _logger.info("Code %s Recall score: %s", code, recall)
            _logger.info("Code %s F1-score: %s", code, f1)

            # Refit model with training data.
            self._models[code].fit(X_train, y_true)

        with open(_PICKLE_PATH, "wb") as f:
            pickle.dump(self, f)

    def _get_train_df(self: "PrincipleMLChecker", rules: Iterable[str]) -> DataFrame:
        feature_vectors = []
        for rule in rules:
            parsed_rule = idstools.rule.parse(rule)
            assert parsed_rule is not None
            feature_vectors.append(self._get_features(parsed_rule, False))

        return DataFrame(feature_vectors)

    def _get_raw_features(  # noqa: C901
        self: "PrincipleMLChecker", rule: idstools.rule.Rule
    ) -> Series:
        d: dict[str, Optional[Union[str, int]]] = {
            "proto": get_rule_option(rule, "proto")
        }

        options = rule["options"]

        for option in options:
            d[option["name"]] = option["value"]

        counter = Counter([option["name"] for option in options])
        for option, count in counter.items():
            d[option + ".count"] = count
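        # e.g. a rule with two `content` matches yields {"content.count": 2}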

        for option in options:
            if option["name"] not in self.splittable_features:
                continue

            suboptions = [
                {"name": k, "value": v}
                for k, v in get_rule_suboptions(rule, option["name"], warn=False)
            ]

            if len(suboptions) == 0:
                continue

            for suboption in suboptions:
                d[option["name"] + "." + suboption["name"]] = suboption["value"]

            counter = Counter([suboption["name"] for suboption in suboptions])
            for suboption, count in counter.items():
                d[option["name"] + "." + suboption + ".count"] = count
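        # e.g. `flow:established,to_server;` contributes "flow.established"
        # and "flow.to_server" features (plus their ".count" variants)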

        msg = get_rule_option(rule, "msg")
        assert msg is not None
        msg = msg.lower()
        for col, keyword in zip(self.msg_columns, self.msg_keywords):
            d[col] = keyword.lower() in msg

        source_addr = get_rule_option(rule, "source_addr")
        assert source_addr is not None
        source_addr = source_addr.lower()
        for keyword in self.ip_keywords:
            col = "source_addr.contains." + keyword
            d[col] = keyword.lower() in source_addr

        dest_addr = get_rule_option(rule, "dest_addr")
        assert dest_addr is not None
        dest_addr = dest_addr.lower()
        for keyword in self.ip_keywords:
            col = "dest_addr.contains." + keyword
            d[col] = keyword.lower() in dest_addr

        return Series(d)

    def _preprocess_features(self: "PrincipleMLChecker", data: Series) -> Series:
        original_cols: set[str] = set(data.index)

        for col in self.string_columns:
            if col not in data:
                continue
            data[col + ".len"] = len(data[col])
            data = data.drop(col)

        for col in self.dropdown_columns:
            if col not in data:
                continue
            data[col + "." + data[col] + ".bool"] = 1
            data = data.drop(col)
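        # e.g. proto == "tcp" becomes a one-hot column "proto.tcp.bool" = 1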

        for col in self.numerical_columns:
            if col not in data:
                continue
            data[col + ".num"] = float(data[col])  # type: ignore reportArgumentType
            data = data.drop(col)

        remaining_cols = (
            original_cols
            - set(self.count_columns)
            - set(self.string_columns)
            - set(self.dropdown_columns)
            - set(self.numerical_columns)
            - set(self.msg_columns)
            - set(self.ip_columns)
        )

        for col in remaining_cols:
            data = data.drop(col)

        return data

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[True]
    ) -> DataFrame:
        pass

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[False]
    ) -> Series:
        pass

    def _get_features_frame(self: "PrincipleMLChecker", features: Series) -> DataFrame:
        features_frame = features.to_frame().transpose()

        if self._dtypes is None:
            return features_frame

        for col, dtype in self._dtypes.items():
            if features_frame.dtypes[col] != dtype:
                features_frame[col] = features_frame[col].astype(dtype)

        return features_frame

    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: bool
    ) -> Union[Series, DataFrame]:
        features: Series = self._get_raw_features(rule)
        features = self._preprocess_features(features)

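        # Count negated matches (e.g. `content:!"...";`) by counting
        # occurrences of ':!"' in the raw rule text.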
        features["custom.negated.count"] = rule["raw"].count(':!"')

        if self._dtypes is None:
            return features

        for col, dtype in self._dtypes.items():
            if col not in features:
                if col.endswith(".count"):
                    features[col] = 0
                elif col.endswith(".bool"):
                    features[col] = 0
                elif col.endswith(".num"):
                    features[col] = -1
                else:
                    _logger.error(
                        "Unsure how to handle missing feature %s of type %s",
                        col,
                        dtype,
                    )

        features = features[list(self._dtypes.keys())]  # type: ignore reportAssignmentType

        if not frame:
            return features

        return self._get_features_frame(features)