Blame view

MLTD/src/MLTD_API.py 16.4 KB
0d8c0f816   Thanasis Naskos   initial commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
  # -*- coding: utf-8 -*-
  import signal
  
  from flask import Flask, request, jsonify, json
  from flask_sqlalchemy import SQLAlchemy
  from flask_restful import Resource, Api
  import Training as pdm_train
  import Utils
  from OnlinePrediction import OnlinePrediction
  import os
  from flask_cors import CORS
  from multiprocessing import Process
  import threading
  import queue
  import logging
  import logging.config
  import yaml
  
  # Directory holding this module (symlinks resolved) and the logging config kept next to it.
  MYDIR = os.path.split(os.path.realpath(__file__))[0]
  LOGGING_CONF_FILE = os.path.join(MYDIR, "logging.yml")
  
  
  def read_log_conf(yaml_file):
      """Configure the ``logging`` module from a YAML ``dictConfig`` file.

      The file is decoded as UTF-8 explicitly, so parsing does not depend on the
      platform's locale encoding. ``yaml.safe_load`` keeps the file from
      constructing arbitrary Python objects.
      """
      with open(yaml_file, encoding="utf-8") as f:
          logging.config.dictConfig(yaml.safe_load(f))
  
  # REST server port; added temporarily for the demo video.
  your_rest_server_port: int = 5000
  
  # Subprocesses are joined from a helper thread for two reasons:
  # 1) A child process that is never joined can linger as a zombie after it exits,
  #    so it cannot be cleaned up properly.
  # 2) If the parent process joined the child itself, execution would freeze until
  #    the child returned/exited.
  # Joining from a thread therefore avoids zombies while the parent process keeps
  # serving requests.
  class Joiner(threading.Thread):
      """Thread that joins, one by one, the processes it takes from a queue.

      Each item obtained from ``q`` is ``join()``-ed in turn; a ``None`` item is a
      sentinel that makes the thread return.

      The thread is a daemon, so a joiner that is still blocked on ``q.get()``
      never prevents the interpreter from exiting.
      """

      def __init__(self, q):
          super().__init__(daemon=True)
          self.__q = q

      def run(self):
          while True:
              child = self.__q.get()
              print(child)
              if child is None:
                  # Sentinel: nothing more to wait for.
                  return
              child.join()
  
  
  # Queue of multiprocessing.Process objects handed to Joiner threads for joining.
  q = queue.Queue()
  # PIDs of the online-prediction subprocesses started through this API.
  running_prediction_processes: list = list()
  # Parameters last stored via the saveParams endpoint (empty until then).
  params: dict = dict()
  
  # Flask application that serves every endpoint defined in this module.
  app = Flask(__name__)
  
  # Apply logging.yml before fetching the named logger so its handlers are in place.
  read_log_conf(LOGGING_CONF_FILE)
  app.logger = logging.getLogger("mltd-api")
  app.logger.info("MLTD API is running")
  
  # Cross-Origin Resource Sharing (CORS) - accept all origins - needed in order to communicate with the web interface
  # NOTE(review): any origin is allowed and the endpoints have no authentication;
  # consider restricting origins to the web interface's host.
  cors = CORS(app, resources={"/*": {"origins": "*"}})
  # Flask-RESTful wrapper; the visible endpoints are registered with @app.route instead.
  api = Api(app)
  
  # these details are used for the SQLite connection and handling
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~SQLite~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  # Directory of this module (abspath only; MYDIR additionally resolves symlinks).
  basedir = os.path.abspath(os.path.dirname(__file__))
  # The SQLite database file pdm.sqlite lives next to this module.
  app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///" + os.path.join(
      basedir, "pdm.sqlite"
  )
  # Turn off Flask-SQLAlchemy's object-modification tracking; nothing here uses it.
  app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
  # Database handle bound to the app; the models below derive from db.Model.
  db = SQLAlchemy(app)
  
  
  class Model(db.Model):
      """A stored training configuration plus its failure-incident dates.

      Rows live in the ``models`` table. The related ``Failure_Incidents`` rows are
      loaded eagerly (``lazy="joined"``) and are deleted together with the model
      (``cascade="all, delete-orphan"``).
      """

      __tablename__ = "models"
      __table_args__ = {"extend_existing": True}
      model_id = db.Column(db.Integer, primary_key=True, autoincrement=True)
      description = db.Column(db.String(1000))
      timedb_host = db.Column(db.String(180))
      timedb_port = db.Column(db.String(5))
      timedb_username = db.Column(db.String(180))
      timedb_password = db.Column(db.String(180))
      timedb_ssl = db.Column(db.String(180))
      timedb_dbname = db.Column(db.String(80))
      asset_id = db.Column(db.String(300))
      timedb_adt_table = db.Column(db.String(300))
      timedb_xlsiem_table = db.Column(db.String(300))
      timedb_od_table = db.Column(db.String(300))
      mp_thres_X = db.Column(db.Integer)
      mp_thres_Y = db.Column(db.Integer)
      mp_thres_Z = db.Column(db.Integer)
      mp_pat_length = db.Column(db.Integer)
      rf_s = db.Column(db.Float)
      rf_midpoint = db.Column(db.String(5))
      hours_before = db.Column(db.String(5))
      time_segments = db.Column(db.String(5))
      incidents = db.relationship(
          "Failure_Incidents",
          backref="models",
          cascade="all, delete-orphan",
          lazy="joined",
      )

      def __init__(
          self,
          description,
          timedb_host,
          timedb_port,
          timedb_username,
          timedb_password,
          timedb_ssl,
          timedb_dbname,
          asset_id,
          timedb_adt_table,
          timedb_xlsiem_table,
          timedb_od_table,
          mp_thres_X,
          mp_thres_Y,
          mp_thres_Z,
          mp_pat_length,
          rf_s,
          rf_midpoint,
          hours_before,
          time_segments,
      ):
          self.description = description
          self.timedb_host = timedb_host
          self.timedb_port = timedb_port
          self.timedb_username = timedb_username
          self.timedb_password = timedb_password
          self.timedb_ssl = timedb_ssl
          self.timedb_dbname = timedb_dbname
          self.asset_id = asset_id
          self.timedb_adt_table = timedb_adt_table
          self.timedb_xlsiem_table = timedb_xlsiem_table
          self.timedb_od_table = timedb_od_table
          self.mp_thres_X = mp_thres_X
          self.mp_thres_Y = mp_thres_Y
          self.mp_thres_Z = mp_thres_Z
          self.mp_pat_length = mp_pat_length
          self.rf_s = rf_s
          self.rf_midpoint = rf_midpoint
          self.hours_before = hours_before
          self.time_segments = time_segments

      @staticmethod
      def _json_scalar(value):
          """Reproduce the legacy *unquoted* emission of a column value.

          The former hand-built serializer pasted ``str(value)`` into the JSON text
          without quotes, so ``5`` and ``"0.5"`` came out as numbers. That is kept.
          A value that is not a JSON literal now falls back to a plain string
          instead of producing invalid JSON, and ``None`` becomes ``null``.
          """
          if value is None:
              return None
          try:
              return json.loads(str(value))
          except ValueError:
              return str(value)

      def _to_dict(self):
          """Build the mapping behind ``toString``/``toJSON``.

          Keys and value types follow the original hand-built JSON:
          - ``model_id``, ``time_segments`` and the text columns are strings.
          - The numeric/threshold columns are emitted as JSON literals.
          - ``incidents`` lists the incident dates.

          ``mp_pat_length`` is included as well; it is a stored training parameter
          that the old serializer omitted.
          """

          def text(value):
              # These columns were always quoted; a missing value is now null
              # instead of raising TypeError on string concatenation.
              return None if value is None else str(value)

          return {
              "model_id": str(self.model_id),
              "description": text(self.description),
              "timedb_host": text(self.timedb_host),
              "timedb_port": text(self.timedb_port),
              "timedb_username": text(self.timedb_username),
              # NOTE(review): the stored time-DB password is returned to every API
              # client; consider dropping it from this representation.
              "timedb_password": text(self.timedb_password),
              "timedb_ssl": text(self.timedb_ssl),
              "timedb_dbname": text(self.timedb_dbname),
              "asset_id": text(self.asset_id),
              "timedb_adt_table": text(self.timedb_adt_table),
              "timedb_xlsiem_table": text(self.timedb_xlsiem_table),
              "timedb_od_table": text(self.timedb_od_table),
              "mp_thres_X": self._json_scalar(self.mp_thres_X),
              "mp_thres_Y": self._json_scalar(self.mp_thres_Y),
              "mp_thres_Z": self._json_scalar(self.mp_thres_Z),
              "mp_pat_length": self._json_scalar(self.mp_pat_length),
              "rf_s": self._json_scalar(self.rf_s),
              "rf_midpoint": self._json_scalar(self.rf_midpoint),
              "hours_before": self._json_scalar(self.hours_before),
              "time_segments": str(self.time_segments),
              "incidents": [incident.date for incident in self.incidents],
          }

      def toString(self):
          """Return this model as a JSON string.

          The string is produced with ``json.dumps``, so quotes, backslashes and
          control characters in stored values are escaped correctly.
          """
          return json.dumps(self._to_dict())

      def toJSON(self):
          """Return this model as a JSON-compatible ``dict``."""
          return self._to_dict()
  
  
  class Failure_Incidents(db.Model):
      """One failure-incident date that belongs to a row of the ``models`` table."""

      __tablename__ = "dates"
      __table_args__ = {"extend_existing": True}
      # Surrogate primary key.
      date_id = db.Column(db.Integer, primary_key=True, autoincrement=True)
      # Owning model (foreign key to models.model_id); required.
      model_id = db.Column(db.Integer, db.ForeignKey("models.model_id"), nullable=False)
      # Incident date as supplied by the client, stored as text.
      date = db.Column(db.String(24))

      def __init__(self, model_id, date):
          self.model_id, self.date = model_id, date
  
  
  # Create any missing tables. Flask-SQLAlchemy 3.x only allows this inside an
  # application context; pushing one here is harmless with 2.x as well.
  with app.app_context():
      db.create_all()
  
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~end of SQLite~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Training API Endpoints~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  
  # endpoint to create a new model and start its training
  @app.route("/api/v1.0/mltd/training", methods=["POST"])
  def add_model():
      """Persist a new model definition and train it in a background subprocess.

      The JSON body supplies:
      - every ``Model`` column;
      - the list of failure-incident ``dates``;
      - the training switches ``rre``, ``rfe``, ``kofe``, ``mil_over`` and ``fs``
        (each coerced with ``bool``).

      The model and one ``Failure_Incidents`` row per date are committed, then
      ``Training.do_the_training`` is started in a new ``multiprocessing.Process``.

      Returns:
          JSON ``{"model_id": <new model id>, "process_id": <training pid>}``.
      """
      body = request.json
      description = body["description"]
      timedb_host = body["timedb_host"]
      timedb_port = body["timedb_port"]
      timedb_username = body["timedb_username"]
      timedb_password = body["timedb_password"]
      timedb_ssl = body["timedb_ssl"]
      timedb_dbname = body["timedb_dbname"]
      asset_id = body["asset_id"]
      timedb_adt_table = body["timedb_adt_table"]
      timedb_xlsiem_table = body["timedb_xlsiem_table"]
      timedb_od_table = body["timedb_od_table"]
      mp_thres_X = body["mp_thres_X"]
      mp_thres_Y = body["mp_thres_Y"]
      mp_thres_Z = body["mp_thres_Z"]
      mp_pat_length = body["mp_pat_length"]
      rf_s = body["rf_s"]
      rf_midpoint = body["rf_midpoint"]
      hours_before = body["hours_before"]
      time_segments = body["time_segments"]
      dates = body["dates"]
      rre = bool(body["rre"])
      rfe = bool(body["rfe"])
      kofe = bool(body["kofe"])
      mil_over = bool(body["mil_over"])
      fs = bool(body["fs"])

      new_model = Model(
          description,
          timedb_host,
          timedb_port,
          timedb_username,
          timedb_password,
          timedb_ssl,
          timedb_dbname,
          asset_id,
          timedb_adt_table,
          timedb_xlsiem_table,
          timedb_od_table,
          mp_thres_X,
          mp_thres_Y,
          mp_thres_Z,
          mp_pat_length,
          rf_s,
          rf_midpoint,
          hours_before,
          time_segments,
      )
      db.session.add(new_model)
      # Committing assigns new_model.model_id, which the incident rows reference.
      db.session.commit()

      for incident_date in dates:
          db.session.add(
              Failure_Incidents(model_id=new_model.model_id, date=incident_date)
          )
      db.session.commit()

      app.logger.info("MLTD Training is triggered")
      proc = Process(
          target=pdm_train.do_the_training,
          args=(
              new_model.model_id,
              timedb_host,
              timedb_port,
              timedb_username,
              timedb_password,
              timedb_ssl,
              timedb_dbname,
              asset_id,
              timedb_adt_table,
              timedb_xlsiem_table,
              timedb_od_table,
              mp_thres_X,
              mp_thres_Y,
              mp_thres_Z,
              mp_pat_length,
              rf_s,
              rf_midpoint,
              hours_before,
              time_segments,
              dates,
              False,  # positional flag of do_the_training; meaning defined in Training
              rre,
              rfe,
              kofe,
              mil_over,
              fs,
          ),
      )
      proc.start()
      app.logger.info("MLTD Training process pid: %s", proc.pid)

      # Give this process its own Joiner. The private queue holds the process
      # followed by the None sentinel, so the thread exits right after joining it.
      # A Joiner on a shared queue would instead block on get() forever, leaking
      # one thread per request.
      joiner_queue = queue.Queue()
      joiner_queue.put(proc)
      joiner_queue.put(None)
      Joiner(joiner_queue).start()

      return jsonify({"model_id": new_model.model_id, "process_id": proc.pid})
  
  
  # endpoint to check whether a training process is still running
  @app.route("/api/v1.0/mltd/training/status/<int:pid>", methods=["GET"])
  def is_running(pid):
      """Answer ``{"is_running": ...}`` for ``pid`` using ``check_pid``."""
      alive = check_pid(pid)
      return jsonify(is_running=alive)
  
  def check_pid(pid):
      """Check for the existence of a Unix process with the given PID.

      Sends signal 0, which performs only the existence/permission checks and
      delivers nothing.

      Returns:
          True if the process exists, even when it belongs to another user and
          cannot be signalled (EPERM). False if it does not exist. False for
          non-positive values, which ``os.kill`` would treat as process-group or
          "all processes" targets rather than a single PID.
      """
      if pid <= 0:
          return False
      try:
          os.kill(pid, 0)
      except ProcessLookupError:
          # ESRCH: no such process.
          return False
      except PermissionError:
          # EPERM: the process exists; we are just not allowed to signal it.
          return True
      except OSError:
          return False
      return True
  
  # endpoint to show all models
  @app.route("/api/v1.0/mltd/training", methods=["GET"])
  def get_model():
      """Return every stored model, serialized with ``Model.toJSON``, as a JSON array.

      The array is built from dicts rather than by concatenating JSON text. The
      former debug ``print`` is gone: it wrote every model, credentials
      included, to stdout.
      """
      return jsonify([model.toJSON() for model in Model.query.all()])
  
  
  # endpoint to get model detail by id
  @app.route("/api/v1.0/mltd/training/<id>", methods=["GET"])
  def model_detail(id):
      """Return one model as JSON.

      Responds 404 when no model has this id. Previously ``None.toString()``
      raised and produced a 500.
      """
      model = Model.query.get_or_404(id)
      return jsonify(model.toJSON())
  
  
  # endpoint to delete model
  @app.route("/api/v1.0/mltd/training/<id>", methods=["DELETE"])
  def model_delete(id):
      """Delete a model, its incident dates (via the relationship cascade) and its training file.

      Responds 404 when no model has this id. Otherwise returns the deleted
      model's JSON, captured before the row was removed. A missing
      ``train_<id>.dat`` file is logged, not treated as an error.
      """
      model = Model.query.get_or_404(id)
      # Serialize while the instance is still persistent and fully loaded.
      deleted = model.toJSON()
      db.session.delete(model)
      db.session.commit()
      train_file = os.path.join(basedir, "train_" + str(id) + ".dat")
      try:
          os.remove(train_file)
      except FileNotFoundError:
          app.logger.error("The file does not exist")
      return jsonify(deleted)
  
  
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~End of Training API Endpoints~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  
  
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Threats Identification API Endpoints~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  # endpoint to list the most important events learned by a trained model
  @app.route("/api/v1.0/mltd/threat-identification/<trainID>/<top>", methods=["GET"])
  def obtain_new_threats(trainID, top):
      """Return the ``top`` most important events of training run ``trainID``.

      The response contains:
      - ``important_events``: the first ``int(top)`` index labels of the
        ``feature_importance`` object loaded from ``train_<trainID>.dat``,
        rendered with ``str(list(...))``;
      - ``timeframe``: the model's ``time_segments`` attribute.
      """
      # NOTE(review): "pdm.sqlite" and the .dat file below are resolved against the
      # current working directory, whereas SQLAlchemy and model_delete use basedir;
      # confirm the server always runs from that directory.
      sql_conn = Utils.create_sqlite_connection("pdm.sqlite")
      try:
          time_segments = Utils.select_model_attribute(
              sql_conn, trainID, "time_segments"
          )
      finally:
          # Release the connection instead of leaking it. Presumably a DB-API
          # connection (TODO confirm in Utils); getattr tolerates a helper that
          # returns None on failure.
          close = getattr(sql_conn, "close", None)
          if close is not None:
              close()
      pdm_online = OnlinePrediction()
      (
          pat_length,
          weak_bins_mapping,
          mp,
          dataset_values,
          regr,
          feature_importance,
          artificial_events_generation,
      ) = pdm_online.load_data("train_" + str(trainID) + ".dat")
      imp_events = feature_importance[: int(top)].index.values
      return jsonify({"important_events": str(list(imp_events)), "timeframe": time_segments})
  
  
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~End of Threats Identification API Endpoints~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  
  
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Prediction API Endpoints~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  
  # endpoint to start the online prediction process
  @app.route("/api/v1.0/mltd/prediction", methods=["POST"])
  def run_prediction():
      """Start an online (MQTT) prediction subprocess for a trained model.

      The JSON body names:
      - the model (``model_id``);
      - the MQTT source (``mqtt_host``, ``mqtt_port``, ``mqtt_topic``);
      - the ``prediction_threshold``;
      - the ``report_*`` reporting settings.

      All of these are passed positionally to
      ``OnlinePrediction.start_online_prediction_MQTT``. The new PID is recorded
      in ``running_prediction_processes`` for the status and stop endpoints.

      Returns:
          JSON ``{"process_id": <pid>}``.
      """
      # get data from the JSON of the POST request
      body = request.json
      model_id = int(body["model_id"])
      mqtt_host = body["mqtt_host"]
      mqtt_port = int(body["mqtt_port"])
      mqtt_topic = body["mqtt_topic"]
      prediction_threshold = body["prediction_threshold"]
      report_timedb_host = body["report_timedb_host"]
      report_timedb_port = body["report_timedb_port"]
      report_timedb_username = body["report_timedb_username"]
      report_timedb_password = body["report_timedb_password"]
      report_timedb_database = body["report_timedb_database"]
      report_timedb_table = body["report_timedb_table"]
      report_timedb_ssl = body["report_timedb_ssl"]
      report_asset_id = body["report_asset_id"]
      app.logger.info("MLTD Online is triggered")
      # Start the online monitoring process in a new subprocess
      pdm_online = OnlinePrediction()
      proc = Process(
          target=pdm_online.start_online_prediction_MQTT,
          args=(
              model_id,
              mqtt_host,
              mqtt_port,
              mqtt_topic,
              prediction_threshold,
              report_timedb_host,
              report_timedb_port,
              report_timedb_username,
              report_timedb_password,
              report_timedb_database,
              report_timedb_table,
              report_timedb_ssl,
              report_asset_id,
          ),
      )
      proc.start()

      # Remember the PID so the status/stop endpoints can refer to this process.
      running_prediction_processes.append(proc.pid)
      # Join the process from its own Joiner thread, which avoids zombies without
      # blocking this request. The trailing None sentinel lets the thread exit once
      # the process is joined, instead of waiting on a shared queue forever.
      joiner_queue = queue.Queue()
      joiner_queue.put(proc)
      joiner_queue.put(None)
      Joiner(joiner_queue).start()

      # return the process_id
      return jsonify({"process_id": proc.pid})
  
  
  # endpoint to get the default values for the online prediction process
  @app.route("/api/v1.0/mltd/prediction/defaults", methods=["GET"])
  def get_defaults():
      """Return the saved prediction parameters, or built-in defaults if none are saved.

      NOTE(review): the built-in defaults carry report_dss_* keys that
      run_prediction never reads, and lack report_asset_id, which it requires.
      """
      if len(params) > 0:
          return jsonify(params)
      return jsonify(
          {
              "model_id": "1",
              "mqtt_host": "mqtt-broker",
              "mqtt_port": "1883",
              "mqtt_topic": "hot-forming-press/meas",
              "prediction_threshold": "0.5",
              "report_dss_host": "http://localhost",
              "report_dss_port": "9100",
              "report_timedb_host": "https://localhost",
              "report_timedb_port": "8086",
              "report_timedb_username": "",
              "report_timedb_password": "",
              "report_timedb_database": "Axoom1",
              "report_timedb_table": "Predicted_failures",
              "report_timedb_ssl": "True",
          }
      )
  
  
  # endpoint to save the parameters later returned by the defaults endpoint
  @app.route("/api/v1.0/mltd/prediction/saveParams", methods=["POST"])
  def save_params():
      """Store the posted JSON object as the prediction parameters.

      From then on, ``GET /api/v1.0/mltd/prediction/defaults`` returns the stored
      object. A missing, malformed or non-object body is rejected with 400 rather
      than stored: a stored ``None``, for example, made the defaults endpoint
      fail on ``len(params)``.
      """
      global params
      payload = request.get_json(silent=True)
      if not isinstance(payload, dict):
          return jsonify("params must be a JSON object"), 400
      params = payload
      return jsonify("params saved")
  
  
  # endpoint to stop a specific online prediction instance
  @app.route("/api/v1.0/mltd/prediction/stop/<int:pid>", methods=["GET"])
  def stop_prediction(pid):
      """Send SIGTERM to an online-prediction subprocess started by this API.

      Only PIDs recorded by ``run_prediction`` are accepted. Previously any
      process the server user could signal would be terminated, and an untracked
      live PID then raised ValueError from ``list.remove``. A tracked PID is
      forgotten whether or not its process was still alive.

      Returns:
          ``"stopped"``, or a 404 response for an unknown PID.
      """
      if pid not in running_prediction_processes:
          return jsonify("unknown prediction process"), 404
      if Utils.check_pid(pid):
          os.kill(pid, signal.SIGTERM)
      running_prediction_processes.remove(pid)
      return jsonify("stopped")
  
  
  # endpoint to list the online prediction instances started through this API
  @app.route("/api/v1.0/mltd/prediction/status", methods=["GET"])
  def get_running_predictions():
      """Map every recorded prediction PID to the URL that stops it."""
      stop_urls = {
          pid: "/api/v1.0/mltd/prediction/stop/" + str(pid)
          for pid in running_prediction_processes
      }
      return jsonify(stop_urls)
  
  
  # endpoint to check whether a prediction instance is still running
  @app.route("/api/v1.0/mltd/prediction/status/<int:pid>", methods=["GET"])
  def is_running_prediction(pid):
      """Answer ``{"is_running": ...}`` for ``pid`` using ``Utils.check_pid``."""
      alive = Utils.check_pid(pid)
      return jsonify(is_running=alive)
  
  
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~End of Prediction API Endpoints~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  
  if __name__ == "__main__":
      # Listen on every interface (host 0.0.0.0): the API becomes reachable from any
      # host that can route to this machine, which is unsafe without other protection.
      # The port comes from your_rest_server_port (5000, same as Flask's default).
      app.run(debug=False, host="0.0.0.0", port=your_rest_server_port)