# -*- coding: utf-8 -*-
import time, json
import dateutil
import numpy as np
import pandas as pd
import Utils
try:
import cPickle as pickle
except:
import pickle
import PredictionKorvesis as pdm
import paho.mqtt.client as paho
import ReportTimeDB
import ReportSyslog
import logging
import logging.config
import yaml
import os
import urllib3
# Suppress urllib3's InsecureRequestWarning for the whole process.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Directory that contains this file (symlinks resolved by realpath).
MYDIR = os.path.dirname(os.path.realpath(__file__))
# The logging configuration is read from "logging.yml" located in MYDIR.
LOGGING_CONF_FILE = os.path.join(MYDIR, "logging.yml")
class OnlinePrediction:
    """Consume security events from MQTT and report predicted incidents.

    Incoming events are buffered until they span at least one model time
    segment (``ts_seconds``).  The buffered window is then passed to the
    trained model; when the highest predicted risk exceeds
    ``sigmoid_threshold`` the incident is reported through ``ReportSyslog``
    and ``ReportTimeDB``.
    """

    def __init__(self):
        """Create empty event buffers and configure logging."""
        # Events collected so far that do not yet span a full time segment.
        self.data_dates = []
        self.data_values = []
        # Defaults; replaced by the addresses of the most recent event.
        self.source_ip = "10.0.2.15"
        self.target_ip = "10.0.2.15"
        # Set to True by on_connect once the broker accepts the connection.
        self.Connected = False
        # Topic (re)subscribed on every successful connect.
        self.mqtt_topic = None
        self.read_log_conf(LOGGING_CONF_FILE)
        self.logger = logging.getLogger("mltd-online")
        self.logger.info("Online MLTD is running")

    def read_log_conf(self, yaml_file):
        """Configure the ``logging`` module from a YAML dictConfig file.

        :param yaml_file: path of the YAML logging configuration
        """
        with open(yaml_file) as f:
            logging.config.dictConfig(yaml.safe_load(f))

    def on_connect(self, client, userdata, flags, rc):
        """MQTT connect callback: subscribe and flag the connection as up.

        Subscribing here (instead of once from the main thread) renews the
        subscription whenever the client reconnects to the broker.

        :param client: MQTT client that triggered the callback
        :param userdata: unused
        :param flags: unused
        :param rc: connection result code, ``0`` means success
        """
        if rc == 0:
            self.logger.debug("Connected to broker")
            topic = getattr(self, "mqtt_topic", None)
            if topic is not None:
                client.subscribe(topic)
            self.Connected = True  # Signal connection
        else:
            self.logger.error("Connection failed")

    def on_message(self, client, userdata, message):
        """MQTT message callback: buffer the received events and monitor.

        Expected payload (UTF-8 encoded JSON)::

            {
              "asset_id": "string",
              "timestamp": "2020-02-27T13:40:18.224Z",
              "event_alarm": [
                {
                  "event_alarm_id": "string",
                  "event_alarm_char": "string",
                  "name": "string",
                  "source_ip": "string",
                  "source_port": 0,
                  "destination_ip": "string",
                  "destination_port": 0,
                  "priority": 0,
                  "confidence": 0
                }
              ]
            }

        Messages for other assets are ignored.  Malformed messages are
        logged and dropped so that a single bad payload cannot kill the
        MQTT network thread.

        :param client: unused
        :param userdata: unused
        :param message: MQTT message; only ``message.payload`` is read
        :return: None
        """
        # Decode and parse exactly once.
        try:
            event_msg = json.loads(message.payload.decode("UTF-8"))
        except ValueError:
            # Covers both invalid UTF-8 and invalid JSON.
            self.logger.error("Discarding malformed event payload", exc_info=True)
            return
        self.logger.debug("Event received: " + str(event_msg))

        data_dates = []
        data_values = []
        source_ip = self.source_ip
        target_ip = self.target_ip
        try:
            if event_msg["asset_id"] != self.asset_id:
                return
            timestamp = event_msg["timestamp"]
            for alarm in event_msg["event_alarm"]:
                source_ip = alarm["source_ip"]
                target_ip = alarm["destination_ip"]
                data_values.append(alarm["event_alarm_id"])
                # Every alarm of a message shares the message timestamp.
                data_dates.append(timestamp)
        except (KeyError, TypeError):
            self.logger.error("Discarding event with unexpected structure", exc_info=True)
            return

        # The addresses of the last alarm are the ones used for reporting.
        self.source_ip = source_ip
        self.target_ip = target_ip
        if data_values:
            self.do_the_monitoring(data_dates, data_values)

    def do_the_monitoring(self, data_dates=None, data_values=None):
        """Buffer new events and run a prediction once a full window exists.

        The new events are appended to the buffered ones.  While the span
        between the first and the last timestamp is shorter than
        ``self.ts_seconds`` the events are only stored.  Otherwise the
        buffer is emptied, the model is queried and the highest predicted
        risk is compared with ``self.sigmoid_threshold``.

        :param data_dates: timestamp strings of the new events
        :param data_values: event ids of the new events
        :return: None
        """
        # Imported here because a bare ``import dateutil`` does not
        # guarantee that the ``parser`` submodule is loaded.
        from dateutil import parser as date_parser

        data_values = self.data_values + list(data_values or [])
        data_dates = self.data_dates + list(data_dates or [])
        if not data_dates:
            # Nothing buffered and nothing new: there is nothing to assess.
            return

        first_date = date_parser.parse(data_dates[0])
        last_date = date_parser.parse(data_dates[-1])
        duration = (last_date - first_date).total_seconds()
        if duration < self.ts_seconds:
            # Window still too short: keep accumulating, no prediction yet.
            self.data_values = data_values
            self.data_dates = data_dates
            return

        self.data_values = []
        self.data_dates = []
        self.logger.info(f"Events Received: {len(data_values)}"
                         f" - Duration: {round(duration,2)} secs")
        self.logger.info("Prediction triggered")
        predictions = self.predict(data_dates, data_values)
        if predictions is None or len(predictions) == 0:
            # max() of an empty sequence would raise; nothing to report.
            self.logger.warning("Prediction returned no risk values")
            return

        risk = max(predictions)
        if risk > self.sigmoid_threshold:
            # NOTE(review): the helper name suggests minutes while the log
            # message below labels the value as secs - confirm the unit.
            timeframe = Utils.sigmoid_mins(
                risk,
                self.rf_s,
                Utils.convert_hours_to_mins(Utils.strtime_to_hours(self.rf_midpoint)),
                self.hours_before,
            )
            self.logger.info(f"A prominent security incident is predicted"
                             f" - Risk level: {round(risk,2)}"
                             f" - Expected timeframe: {round(timeframe,2)} secs")
            ReportSyslog.report(
                self.asset_id, risk * 100, timeframe, self.source_ip, self.target_ip
            )
            ReportTimeDB.report(
                self.time_db_client, self.asset_id, risk * 100, timeframe
            )
            self.logger.info(f"The incident was reportered on TimescaleDB and Rsyslog server")
        else:
            self.logger.info(
                f"The predicted risk {round(risk,2)} is "
                f"below the alarm threshold {round(self.sigmoid_threshold,2)}"
            )

    def predict(self, data_dates=None, data_values=None):
        """Run the trained model on one window of events.

        :param data_dates: timestamp strings, one per event
        :param data_values: event ids, aligned with ``data_dates``
        :return: whatever ``PredictionKorvesis.predict`` returns
        """
        dataset = pd.DataFrame(
            {
                "Timestamps": [] if data_dates is None else data_dates,
                "Event_id": [] if data_values is None else data_values,
            }
        )
        predictions = pdm.predict(
            self.regr, dataset, self.time_segments, self.feature_importance.index
        )
        self.logger.debug(f"Risk predictions: {predictions}")
        return predictions

    def form_dataset(self, dates_list, events_list, feature_importance):
        """Build a timestamp-indexed frame of the events known to the model.

        Only events whose id appears in ``feature_importance.index`` are
        kept.  Rows with a repeated timestamp are dropped (the first one is
        kept) and the frame is indexed by the parsed timestamps.

        :param dates_list: timestamps, aligned with ``events_list``
        :param events_list: event ids
        :param feature_importance: pandas object whose index holds the
            event ids the model was trained on
        :return: DataFrame with columns ``Timestamps`` and ``Event_id``
        """
        # TODO handle differently the zero event ids based on some policy
        rows = []
        # NOTE(review): this guard is kept from the original code; it skips
        # the input when |sum(ids)| equals the number of ids - confirm the
        # intended policy.
        if len(events_list) != abs(sum(events_list)):
            known_ids = feature_importance.index
            # zip() stops at the shorter list, i.e. events without a
            # matching date are ignored.
            for event_date, event_id in zip(dates_list, events_list):
                # ``Index.contains`` was removed from pandas; ``in`` is the
                # supported membership test.
                if event_id in known_ids:
                    rows.append({"Timestamps": event_date, "Event_id": event_id})
        # Built in one go instead of growing the frame row by row.
        dataset = pd.DataFrame(rows, columns=["Timestamps", "Event_id"], dtype=object)
        if not dataset.empty:
            # Keep only the first event of every timestamp.
            dataset = dataset.drop_duplicates(subset="Timestamps", keep="first")
            dataset = dataset.set_index(
                pd.to_datetime(dataset["Timestamps"]), drop=False
            )
        self.logger.debug(f"Formed dataest: {dataset}")
        return dataset

    def load_data(self, filename):
        """Load the seven pickled training artefacts stored in ``filename``.

        The objects are read in the order in which they were dumped.

        :param filename: path of the binary training file
        :return: tuple ``(pat_length, weak_bins_mapping, mp,
            train_dataset_values, regr, feature_importance,
            artificial_events_generation)``; ``train_dataset_values`` is
            converted to a numpy array
        """
        # NOTE(security): unpickling executes code embedded in the file;
        # only load training files that come from a trusted source.
        with open(filename, "rb") as infile:
            pat_length = pickle.load(infile)
            weak_bins_mapping = pickle.load(infile)
            mp = pickle.load(infile)
            train_dataset_values = np.array(pickle.load(infile))
            regr = pickle.load(infile)
            feature_importance = pickle.load(infile)
            artificial_events_generation = pickle.load(infile)
        return (
            pat_length,
            weak_bins_mapping,
            mp,
            train_dataset_values,
            regr,
            feature_importance,
            artificial_events_generation,
        )

    def start_online_prediction_MQTT(
        self,
        trainID,
        broker_address,
        port,
        mqtt_topic,
        prediction_threshold,
        report_time_db_host,
        report_time_db_port,
        report_time_db_username,
        report_time_db_password,
        report_time_db_database,
        report_time_db_table,
        report_time_db_ssl,
        report_asset_id,
    ):
        """Load a trained model and process MQTT events until interrupted.

        Blocks the calling thread; ``KeyboardInterrupt`` ends the loop.  The
        MQTT client is always disconnected and its network loop stopped on
        the way out.

        :param trainID: id of the training run; selects the model
            attributes in ``pdm.sqlite`` and the file ``train_<id>.dat``
        :param broker_address: MQTT broker host
        :param port: MQTT broker port
        :param mqtt_topic: topic that carries the events
        :param prediction_threshold: risk above which an incident is reported
        :param report_time_db_host: time-series database host
        :param report_time_db_port: time-series database port
        :param report_time_db_username: time-series database user
        :param report_time_db_password: time-series database password
        :param report_time_db_database: time-series database name
        :param report_time_db_table: time-series database table
        :param report_time_db_ssl: SSL flag passed to ``ReportTimeDB.connect``
        :param report_asset_id: only events of this asset are processed
        """
        self.sigmoid_threshold = prediction_threshold
        self.time_db_host = report_time_db_host
        self.time_db_port = report_time_db_port
        self.time_db_username = report_time_db_username
        self.time_db_password = report_time_db_password
        self.time_db_database = report_time_db_database
        self.time_db_table = report_time_db_table
        self.time_db_ssl = report_time_db_ssl
        self.asset_id = report_asset_id
        # Read by on_connect, which performs the (re)subscription.
        self.mqtt_topic = mqtt_topic

        # NOTE(review): the connection is never closed here - confirm
        # whether Utils owns its lifetime.
        sql_conn = Utils.create_sqlite_connection("pdm.sqlite")
        self.time_segments = Utils.select_model_attribute(
            sql_conn, trainID, "time_segments"
        )
        self.rf_s = Utils.select_model_attribute(sql_conn, trainID, "rf_s")
        self.rf_midpoint = Utils.select_model_attribute(
            sql_conn, trainID, "rf_midpoint"
        )
        self.hours_before = Utils.select_model_attribute(
            sql_conn, trainID, "hours_before"
        )
        ts_hours = Utils.strtime_to_hours(self.time_segments)
        self.ts_seconds = ts_hours * 3600

        # load the data from pickle (binary) files - should consider to move to a database solution(?)
        (
            self.pat_length,
            self.weak_bins_mapping,
            self.mp,
            self.dataset_values,
            self.regr,
            self.feature_importance,
            self.artificial_events_generation,
        ) = self.load_data(f"train_{trainID}.dat")

        self.time_db_client = ReportTimeDB.connect(
            self.time_db_host,
            self.time_db_port,
            self.time_db_database,
            self.time_db_username,
            self.time_db_password,
            self.time_db_ssl,
        )

        self.Connected = False
        # NOTE(review): positional client id matches the paho-mqtt 1.x
        # constructor - confirm the installed paho version.
        client = paho.Client("Prediction_client" + str(time.time()))
        client.on_connect = self.on_connect  # subscribes to self.mqtt_topic
        client.on_message = self.on_message
        client.connect(broker_address, port=port)
        client.loop_start()  # network loop runs in a background thread
        try:
            while not self.Connected:  # Wait for connection
                time.sleep(0.1)
            # Keep the main thread alive; the work happens in the callbacks.
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            self.logger.debug("exiting")
        finally:
            client.disconnect()
            client.loop_stop()
if __name__ == "__main__":
    # Script entry point: start the online predictor with a fixed local setup.
    # NOTE(review): broker, database host and database credentials are
    # hard-coded below - consider moving them to configuration.
    predictor = OnlinePrediction()
    predictor.start_online_prediction_MQTT(
        trainID=10,
        broker_address="localhost",
        port=1884,
        mqtt_topic="auth/incidents",
        prediction_threshold=0.1,
        report_time_db_host="83.212.116.5",
        report_time_db_port=5432,
        report_time_db_username="postgres",
        report_time_db_password="xs?Z7HsY",
        report_time_db_database="kea",
        report_time_db_table="mltd",
        report_time_db_ssl=False,
        report_asset_id="server",
    )