# -*- coding: utf-8 -*-
import signal
from flask import Flask, request, jsonify, json
from flask_sqlalchemy import SQLAlchemy
from flask_restful import Resource, Api
import Training as pdm_train
import Utils
from OnlinePrediction import OnlinePrediction
import os
from flask_cors import CORS
from multiprocessing import Process
import threading
import queue
import logging
import logging.config
import yaml
MYDIR = os.path.dirname(os.path.realpath(__file__))
LOGGING_CONF_FILE = os.path.join(MYDIR, "logging.yml")
def read_log_conf(yaml_file):
with open(yaml_file) as f:
logging.config.dictConfig(yaml.safe_load(f))
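# A minimal logging.yml that this loader would accept (illustrative sketch; the
# actual file shipped with the project may define more handlers/formatters):
#   version: 1
#   formatters:
#     simple:
#       format: "%(asctime)s %(name)s %(levelname)s %(message)s"
#   handlers:
#     console:
#       class: logging.StreamHandler
#       formatter: simple
#   loggers:
#     mltd-api:
#       level: INFO
#       handlers: [console]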
# REST server port, temporarily added for the demo video (used by app.run() below)
your_rest_server_port = 5000
# We use a Thread to join each subprocess, for two reasons:
# 1) A subprocess that exits without being joined lingers as a zombie, so it can
#    no longer be reaped or signalled cleanly.
# 2) If the parent process joined the subprocess directly, its execution would
#    freeze until the subprocess returned/exited.
# A short-lived joiner thread per child therefore prevents zombie processes
# while letting the parent process continue serving requests.
class Joiner(threading.Thread):
    """Joins one child Process taken from the queue, then exits.

    One Joiner is started per spawned child; the queue makes the hand-off
    thread-safe while the parent continues serving requests.
    """

    def __init__(self, q):
        threading.Thread.__init__(self)
        self.daemon = True  # do not keep the interpreter alive on shutdown
        self.__q = q

    def run(self):
        child = self.__q.get()
        if child is None:  # sentinel value: nothing to join
            return
        child.join()
# stores the spawned child Process objects waiting to be joined (one Joiner per child)
q = queue.Queue()
running_prediction_processes = []
params = {}
app = Flask(__name__)
read_log_conf(LOGGING_CONF_FILE)
app.logger = logging.getLogger("mltd-api")
app.logger.info("MLTD API is running")
# Cross-Origin Resource Sharing (CORS) - accept all origins - needed in order to communicate with the web interface
cors = CORS(app, resources={"/*": {"origins": "*"}})
api = Api(app)
# these details are used for the SQLite connection and handling
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~SQLite~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
basedir = os.path.abspath(os.path.dirname(__file__))
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///" + os.path.join(
basedir, "pdm.sqlite"
)
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db = SQLAlchemy(app)
class Model(db.Model):
__tablename__ = "models"
__table_args__ = {"extend_existing": True}
model_id = db.Column(db.Integer, primary_key=True, autoincrement=True)
description = db.Column(db.String(1000))
timedb_host = db.Column(db.String(180))
timedb_port = db.Column(db.String(5))
timedb_username = db.Column(db.String(180))
timedb_password = db.Column(db.String(180))
timedb_ssl = db.Column(db.String(180))
timedb_dbname = db.Column(db.String(80))
asset_id = db.Column(db.String(300))
timedb_adt_table = db.Column(db.String(300))
timedb_xlsiem_table = db.Column(db.String(300))
timedb_od_table = db.Column(db.String(300))
mp_thres_X = db.Column(db.Integer)
mp_thres_Y = db.Column(db.Integer)
mp_thres_Z = db.Column(db.Integer)
mp_pat_length = db.Column(db.Integer)
rf_s = db.Column(db.Float)
rf_midpoint = db.Column(db.String(5))
hours_before = db.Column(db.String(5))
time_segments = db.Column(db.String(5))
incidents = db.relationship(
"Failure_Incidents",
backref="models",
cascade="all, delete-orphan",
lazy="joined",
)
def __init__(
self,
description,
timedb_host,
timedb_port,
timedb_username,
timedb_password,
timedb_ssl,
timedb_dbname,
asset_id,
timedb_adt_table,
timedb_xlsiem_table,
timedb_od_table,
mp_thres_X,
mp_thres_Y,
mp_thres_Z,
mp_pat_length,
rf_s,
rf_midpoint,
hours_before,
time_segments,
):
self.description = description
self.timedb_host = timedb_host
self.timedb_port = timedb_port
self.timedb_username = timedb_username
self.timedb_password = timedb_password
self.timedb_ssl = timedb_ssl
self.timedb_dbname = timedb_dbname
self.asset_id = asset_id
self.timedb_adt_table = timedb_adt_table
self.timedb_xlsiem_table = timedb_xlsiem_table
self.timedb_od_table = timedb_od_table
self.mp_thres_X = mp_thres_X
self.mp_thres_Y = mp_thres_Y
self.mp_thres_Z = mp_thres_Z
self.mp_pat_length = mp_pat_length
self.rf_s = rf_s
self.rf_midpoint = rf_midpoint
self.hours_before = hours_before
self.time_segments = time_segments
    def toJSON(self):
        # Build the payload as a dict and let the json module handle quoting
        # and escaping; manual string concatenation would produce invalid JSON
        # for values containing quotes or non-numeric strings.
        return {
            "model_id": str(self.model_id),
            "description": self.description,
            "timedb_host": self.timedb_host,
            "timedb_port": self.timedb_port,
            "timedb_username": self.timedb_username,
            "timedb_password": self.timedb_password,
            "timedb_ssl": self.timedb_ssl,
            "timedb_dbname": self.timedb_dbname,
            "asset_id": self.asset_id,
            "timedb_adt_table": self.timedb_adt_table,
            "timedb_xlsiem_table": self.timedb_xlsiem_table,
            "timedb_od_table": self.timedb_od_table,
            "mp_thres_X": self.mp_thres_X,
            "mp_thres_Y": self.mp_thres_Y,
            "mp_thres_Z": self.mp_thres_Z,
            "mp_pat_length": self.mp_pat_length,
            "rf_s": self.rf_s,
            "rf_midpoint": self.rf_midpoint,
            "hours_before": self.hours_before,
            "time_segments": self.time_segments,
            "incidents": [incident.date for incident in self.incidents],
        }

    def toString(self):
        return json.dumps(self.toJSON())
class Failure_Incidents(db.Model):
__tablename__ = "dates"
__table_args__ = {"extend_existing": True}
date_id = db.Column(db.Integer, primary_key=True, autoincrement=True)
model_id = db.Column(db.Integer, db.ForeignKey("models.model_id"), nullable=False)
date = db.Column(db.String(24))
def __init__(self, model_id, date):
self.model_id = model_id
self.date = date
db.create_all()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~end of SQLite~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Training API Endpoints~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
# endpoint to create a new model
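# Example request body (illustrative placeholder values; the keys match those
# read below, hosts/tables/dates are hypothetical):
#   {
#     "description": "hot forming press", "timedb_host": "localhost",
#     "timedb_port": "8086", "timedb_username": "user", "timedb_password": "pw",
#     "timedb_ssl": "True", "timedb_dbname": "Axoom1", "asset_id": "press-3",
#     "timedb_adt_table": "adt", "timedb_xlsiem_table": "xlsiem",
#     "timedb_od_table": "od", "mp_thres_X": 3, "mp_thres_Y": 2, "mp_thres_Z": 1,
#     "mp_pat_length": 10, "rf_s": 0.5, "rf_midpoint": "12", "hours_before": "24",
#     "time_segments": "4", "dates": ["2019-03-01 00:00:00"],
#     "rre": true, "rfe": false, "kofe": false, "mil_over": false, "fs": true
#   }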
@app.route("/api/v1.0/mltd/training", methods=["POST"])
def add_model():
description = request.json["description"]
timedb_host = request.json["timedb_host"]
timedb_port = request.json["timedb_port"]
timedb_username = request.json["timedb_username"]
timedb_password = request.json["timedb_password"]
timedb_ssl = request.json["timedb_ssl"]
timedb_dbname = request.json["timedb_dbname"]
asset_id = request.json["asset_id"]
timedb_adt_table = request.json["timedb_adt_table"]
timedb_xlsiem_table = request.json["timedb_xlsiem_table"]
timedb_od_table = request.json["timedb_od_table"]
mp_thres_X = request.json["mp_thres_X"]
mp_thres_Y = request.json["mp_thres_Y"]
mp_thres_Z = request.json["mp_thres_Z"]
mp_pat_length = request.json["mp_pat_length"]
rf_s = request.json["rf_s"]
rf_midpoint = request.json["rf_midpoint"]
hours_before = request.json["hours_before"]
time_segments = request.json["time_segments"]
dates = request.json["dates"]
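    # NOTE: bool() of any non-empty string is True (e.g. bool("False") is True),
    # so clients must send these flags as JSON booleans, not as strings.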
rre = bool(request.json["rre"])
rfe = bool(request.json["rfe"])
kofe = bool(request.json["kofe"])
mil_over = bool(request.json["mil_over"])
fs = bool(request.json["fs"])
new_model = Model(
description,
timedb_host,
timedb_port,
timedb_username,
timedb_password,
timedb_ssl,
timedb_dbname,
asset_id,
timedb_adt_table,
timedb_xlsiem_table,
timedb_od_table,
mp_thres_X,
mp_thres_Y,
mp_thres_Z,
mp_pat_length,
rf_s,
rf_midpoint,
hours_before,
time_segments,
)
db.session.add(new_model)
db.session.commit()
    for date in dates:
        incident = Failure_Incidents(model_id=new_model.model_id, date=date)
        db.session.add(incident)
        db.session.commit()
app.logger.info("MLTD Training is triggered")
proc = Process(
target=pdm_train.do_the_training,
args=(
new_model.model_id,
timedb_host,
timedb_port,
timedb_username,
timedb_password,
timedb_ssl,
timedb_dbname,
asset_id,
timedb_adt_table,
timedb_xlsiem_table,
timedb_od_table,
mp_thres_X,
mp_thres_Y,
mp_thres_Z,
mp_pat_length,
rf_s,
rf_midpoint,
hours_before,
time_segments,
dates,
False,
rre,
rfe,
kofe,
mil_over,
fs,
),
)
proc.start()
    app.logger.info("Training process started with pid %s", proc.pid)
    q.put(proc)
    # join the child from a dedicated thread (see the comment above Joiner)
    j = Joiner(q)
    j.start()
return jsonify({"model_id": new_model.model_id, "process_id": proc.pid})
# endpoint to check whether the training process is still running
@app.route("/api/v1.0/mltd/training/status/<int:pid>", methods=["GET"])
def is_running(pid):
return jsonify({"is_running": check_pid(pid)})
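# Example (hypothetical pid): GET /api/v1.0/mltd/training/status/4242
# -> {"is_running": true}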
def check_pid(pid):
""" Check For the existence of a unix pid. """
try:
os.kill(pid, 0)
except OSError:
return False
else:
return True
# endpoint to list all models
@app.route("/api/v1.0/mltd/training", methods=["GET"])
def get_model():
    all_models = Model.query.all()
    return jsonify([model.toJSON() for model in all_models])
# endpoint to get a model's details by id
@app.route("/api/v1.0/mltd/training/<id>", methods=["GET"])
def model_detail(id):
    model = Model.query.get(id)
    if model is None:
        return jsonify({"error": "model not found"}), 404
    return jsonify(model.toJSON())
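# Example response (illustrative): GET /api/v1.0/mltd/training/1 returns the
# serialized model fields plus its failure incident dates, e.g.
#   {"model_id": "1", "description": "hot forming press", ...,
#    "incidents": ["2019-03-01 00:00:00"]}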
# endpoint to delete a model
@app.route("/api/v1.0/mltd/training/<id>", methods=["DELETE"])
def model_delete(id):
    model = Model.query.get(id)
    if model is None:
        return jsonify({"error": "model not found"}), 404
    # serialize before the delete commit expires the instance
    payload = model.toJSON()
    db.session.delete(model)
    db.session.commit()
    train_file = os.path.join(basedir, "train_" + str(id) + ".dat")
    if os.path.exists(train_file):
        os.remove(train_file)
    else:
        app.logger.error("Trained model file %s does not exist", train_file)
    return jsonify(payload)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~End of Training API Endpoints~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Threats Identification API Endpoints~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
# endpoint to obtain the most important events (potential threats) for a trained model
@app.route("/api/v1.0/mltd/threat-identification/<trainID>/<top>", methods=["GET"])
def obtain_new_threats(trainID, top):
    sql_conn = Utils.create_sqlite_connection("pdm.sqlite")
    time_segments = Utils.select_model_attribute(sql_conn, trainID, "time_segments")
    sql_conn.close()  # assumes Utils returns a standard sqlite3 connection
    pdm_online = OnlinePrediction()
    (
        pat_length,
        weak_bins_mapping,
        mp,
        dataset_values,
        regr,
        feature_importance,
        artificial_events_generation,
    ) = pdm_online.load_data("train_" + str(trainID) + ".dat")
    imp_events = feature_importance[: int(top)].index.values
    return jsonify({"important_events": str(list(imp_events)), "timeframe": time_segments})
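# Example (hypothetical ids): GET /api/v1.0/mltd/threat-identification/1/5
# returns roughly {"important_events": "['event_a', ...]", "timeframe": "4"};
# note that important_events is the str() of a Python list, not a JSON array.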
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~End of Threats Identification API Endpoints~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Prediction API Endpoints~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
# endpoint to start the online prediction process
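# Example request body (illustrative values; the broker/database fields mirror
# the defaults served by /prediction/defaults below, report_asset_id is a
# hypothetical placeholder):
#   {
#     "model_id": "1", "mqtt_host": "mqtt-broker", "mqtt_port": "1883",
#     "mqtt_topic": "hot-forming-press/meas", "prediction_threshold": "0.5",
#     "report_timedb_host": "https://localhost", "report_timedb_port": "8086",
#     "report_timedb_username": "", "report_timedb_password": "",
#     "report_timedb_database": "Axoom1", "report_timedb_table": "Predicted_failures",
#     "report_timedb_ssl": "True", "report_asset_id": "press-3"
#   }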
@app.route("/api/v1.0/mltd/prediction", methods=["POST"])
def run_prediction():
# get data from the JSON of the POST request
model_id = int(request.json["model_id"])
mqtt_host = request.json["mqtt_host"]
mqtt_port = int(request.json["mqtt_port"])
mqtt_topic = request.json["mqtt_topic"]
prediction_threshold = request.json["prediction_threshold"]
report_timedb_host = request.json["report_timedb_host"]
report_timedb_port = request.json["report_timedb_port"]
report_timedb_username = request.json["report_timedb_username"]
report_timedb_password = request.json["report_timedb_password"]
report_timedb_database = request.json["report_timedb_database"]
report_timedb_table = request.json["report_timedb_table"]
report_timedb_ssl = request.json["report_timedb_ssl"]
report_asset_id = request.json["report_asset_id"]
app.logger.info("MLTD Online is triggered")
# Start the online monitoring process in a new subprocess
pdm_online = OnlinePrediction()
proc = Process(
target=pdm_online.start_online_prediction_MQTT,
args=(
model_id,
mqtt_host,
mqtt_port,
mqtt_topic,
prediction_threshold,
report_timedb_host,
report_timedb_port,
report_timedb_username,
report_timedb_password,
report_timedb_database,
report_timedb_table,
report_timedb_ssl,
report_asset_id,
),
)
proc.start()
    # enqueue the Process object (not just its pid) so a Joiner thread can reap
    # it, and record the pid so the instance can be stopped/queried later
    q.put(proc)
    running_prediction_processes.append(proc.pid)
    # join the new process from a dedicated thread (see the comment above Joiner)
    j = Joiner(q)
    j.start()
# return the process_id
return jsonify({"process_id": proc.pid})
# endpoint to get the default values for the online prediction process
@app.route("/api/v1.0/mltd/prediction/defaults", methods=["GET"])
def get_defaults():
    global params
    if not params:
        default_values = {
            "model_id": "1",
            "mqtt_host": "mqtt-broker",
            "mqtt_port": "1883",
            "mqtt_topic": "hot-forming-press/meas",
            "prediction_threshold": "0.5",
            "report_dss_host": "http://localhost",
            "report_dss_port": "9100",
            "report_timedb_host": "https://localhost",
            "report_timedb_port": "8086",
            "report_timedb_username": "",
            "report_timedb_password": "",
            "report_timedb_database": "Axoom1",
            "report_timedb_table": "Predicted_failures",
            "report_timedb_ssl": "True",
        }
        return jsonify(default_values)
    return jsonify(params)
# endpoint to save parameter values for the online prediction process
@app.route("/api/v1.0/mltd/prediction/saveParams", methods=["POST"])
def save_params():
global params
params = request.json
return jsonify("params saved")
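# Example: POST the same JSON shape returned by /prediction/defaults; subsequent
# GET /prediction/defaults calls then return the saved values instead.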
# endpoint to stop a specific online prediction instance
@app.route("/api/v1.0/mltd/prediction/stop/<int:pid>", methods=["GET"])
def stop_prediction(pid):
    if Utils.check_pid(pid):
        os.kill(pid, signal.SIGTERM)
        if pid in running_prediction_processes:
            running_prediction_processes.remove(pid)
        return jsonify("stopped")
    return jsonify("not running")
# endpoint to get all the online prediction instances
@app.route("/api/v1.0/mltd/prediction/status", methods=["GET"])
def get_running_predictions():
global running_prediction_processes
pids = {}
for pid in running_prediction_processes:
pids[pid] = "/api/v1.0/mltd/prediction/stop/" + str(pid)
return jsonify(pids)
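# Example response (hypothetical pid):
#   {"4242": "/api/v1.0/mltd/prediction/stop/4242"}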
# endpoint to check whether the prediction instance is still running
@app.route("/api/v1.0/mltd/prediction/status/<int:pid>", methods=["GET"])
def is_running_prediction(pid):
return jsonify({"is_running": Utils.check_pid(pid)})
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~End of Prediction API Endpoints~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
if __name__ == "__main__":
    # Bind to all interfaces (host="0.0.0.0"): the API is reachable by everyone
    # on the network, which is unsafe outside a trusted environment.
    app.run(debug=False, host="0.0.0.0", port=your_rest_server_port)