Commit 26bf521c authored by Soeren Klemm's avatar Soeren Klemm

Merge remote-tracking branch 'origin/develop'

parents e7a5c599 8984a428
.idea
/nbproject
/gui/nbproject/private/
/nbproject/private/
/.idea/
*.pyc
baristalog.txt
/gui/nbproject/
*.caffemodel
*.solverstate
/test/examples/mnist/sessions/*
/test/examples/mnist/logs/*
/test/examples/mnist/snapshots/*
/test/examples/mnist/mnist_test_lmdb/*.mdb
/test/examples/mnist/mnist_train_lmdb/*.mdb
/view_settings.txt
/test/examples/mnist/data/*
*.xml
*.py.orig
/barista.conf
/baristalog.txt
\ No newline at end of file
"""
This package groups some additional constraints required for the Barista program.
Those constraints might not be required to create valid caffe prototxt files, but they are required for the Barista
program itself (e.g. because it relies on fixed internal file names and values).
Depending on the subpackages, the included constraints need to be ensured permanently or only at specific times during
runtime.
"""
\ No newline at end of file
"""Common (helper) functions to check constraints."""
def isNumber(input):
    """Check whether input can be parsed as a (float) number."""
    # TypeError covers non-string/non-numeric inputs such as None or lists,
    # which float() rejects with TypeError rather than ValueError; the
    # original only caught ValueError and crashed on e.g. isNumber(None).
    try:
        float(input)
        return True
    except (ValueError, TypeError):
        return False


def isPositiveNumber(input):
    """Check whether input is a valid (float) number > 0."""
    # Compare the parsed float, not the raw input: comparing a string to 0
    # raises TypeError on Python 3 and yields meaningless results on Python 2.
    return isNumber(input) and float(input) > 0


def isPositiveInteger(input):
    """Check whether input is a valid integer > 0."""
    return isPositiveNumber(input) and float(input).is_integer()
\ No newline at end of file
"""
This subpackage contains all constraints that need to be ensured permanently.
So the given methods need to be called each time one of the handled constraints might get broken, which usually
means every situation in which the user could edit the raw prototxt file or equivalent data (as in the project
file).
"""
\ No newline at end of file
"""
This module defines permanent constraints for a Barista project.
"""
from backend.barista.utils.logger import Log
from backend.barista.constraints.permanent.solver import ensureSolverConstraints
logID = Log.getCallerId('constraints/project')
def ensureProjectDataConstraints(projectData):
    """Adjust the given project data to satisfy all Barista-specific constraints.

    Call this whenever the project data may have changed. Values that are
    valid prototxt syntax in general, but violate the additional constraints
    Barista imposes, are adjusted accordingly. Returns the adjusted data.
    """
    solverData = projectData["solver"]
    projectData["solver"] = ensureSolverConstraints(solverData)
    # Further constraint types can be chained here, e.g.:
    # projectData["network"] = ensureNetworkConstraints(projectData["network"])
    return projectData
"""
This module defines permanent solver constraints.
"""
import os
import backend
from backend.barista.utils.logger import Log
logID = Log.getCallerId('constraints/solver')
def ensureSolverConstraints(solverDictionary):
    """Ensure that all permanent solver constraints hold for solverDictionary.

    Static values are (re)set and unsupported values are removed; the
    (possibly modified) dictionary is returned.
    """
    internalNetFile = backend.barista.session.session_utils.Paths.FILE_NAME_NET_INTERNAL

    # The file names inside of a session are static and must not be altered
    # by the user: "net" always points to the generated network file.
    if solverDictionary.get("net") != internalNetFile:
        Log.log("The solver property 'net' must point to the generated network file. "
                "Value has been changed from '{}' to '{}'.".format(
                    solverDictionary.get("net", "None"),
                    internalNetFile
                ), logID)
        solverDictionary["net"] = internalNetFile

    # An embedded net definition would be inconsistent with the separately
    # handled network, so it is dropped entirely.
    if "net_param" in solverDictionary:
        Log.log("The solver property 'net_param' is not supported as it would be inconsistent with the separately "
                "handled network. Property has been removed.", logID)
        del solverDictionary["net_param"]

    # A snapshot_prefix containing a path component is not supported either;
    # only its basename is kept.
    if "snapshot_prefix" in solverDictionary:
        prefixHead, prefixTail = os.path.split(solverDictionary["snapshot_prefix"])
        if prefixHead:
            Log.log("The solver property 'snapshot_prefix' contained an unsupported path. "
                    "Property was shortened from '{}' to '{}'.".format(
                        solverDictionary["snapshot_prefix"],
                        prefixTail
                    ), logID)
            solverDictionary["snapshot_prefix"] = prefixTail

    return solverDictionary
"""
This subpackage contains all constraints that need to be ensured before/during a new session run.
So the given methods need to be called each time a new session is about to be started.
"""
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import os
import re
from abc import abstractmethod
class SessionCommon():
    """Shared snapshot-lookup logic for session implementations.

    Subclasses must implement getLogs() and getSnapshotDirectory().
    """

    def __init__(self):
        # Kept for backward compatibility: the original initialized this
        # attribute, although nothing in this class ever read it.
        self.lastSolverState = None
        # Cached name of the most recent *.solverstate file, or None while
        # undetermined. FIX: the original never initialized this attribute,
        # yet getLastSnapshot() and the helpers below read it.
        self.last_solverstate = None
        # Base directory of this session; expected to be set by subclasses.
        self.directory = None

    @abstractmethod
    def getLogs(self):
        """Return the directory containing this session's log files."""
        pass

    @abstractmethod
    def getSnapshotDirectory(self):
        """Return the directory containing this session's snapshot files."""
        pass

    def getLastSnapshot(self):
        """ Return the last snapshot/solverstate for this session.

        The cached value is tried first, then the log files, then the
        snapshot directory. Returns None if no existing file is found.
        """
        def existsInSession(candidate):
            # Accept only candidates that actually exist on disk.
            return (candidate is not None and candidate != "" and
                    os.path.isfile(os.path.join(self.directory, candidate)))

        if existsInSession(self.last_solverstate):
            return self.last_solverstate
        self.last_solverstate = self._getLastSnapshotFromLogFiles()
        if existsInSession(self.last_solverstate):
            return self.last_solverstate
        self.last_solverstate = self._getLastSnapshotFromSnapshotDirectory()
        if existsInSession(self.last_solverstate):
            return self.last_solverstate
        return None

    def _getLastSnapshotFromLogFiles(self):
        """ Try to find the last snapshot name in the log files.

        Searches runs from newest to oldest and returns the name of the
        solverstate file if one was found, else None.
        """
        # collect all log files, keyed by run id
        log_files = {}
        regex_filename = re.compile(r'[\d]+\.([\d]+)\.log$')
        for entry in os.listdir(self.getLogs()):
            filename_match = regex_filename.search(entry)
            if filename_match:
                # the regex guarantees digits; keep the guard narrow instead
                # of the original silent bare except
                try:
                    run_id = int(filename_match.group(1))
                except ValueError:
                    continue
                log_files[run_id] = entry
        # hoisted out of the loop: the pattern is loop-invariant
        regex_snapshot = re.compile(
            r'Snapshotting solver state to (?:binary proto|HDF5) file (.+\.solverstate[\.\w-]*)')
        for run_id in sorted(log_files.keys(), reverse=True):
            last_solverstate = None
            with open(os.path.join(self.getLogs(), log_files[run_id])) as f:
                # remember the LAST snapshot mentioned in this file
                for line in f:
                    snapshot_match = regex_snapshot.search(line)
                    if snapshot_match:
                        last_solverstate = snapshot_match.group(1)
            if last_solverstate:
                return last_solverstate
        return None

    def _getLastSnapshotFromSnapshotDirectory(self, basename=False):
        """ Try to find the last snapshot in the snapshot directory.

        Picks the solverstate file with the highest iteration number.
        Returns the bare file name if basename is True, otherwise the name
        joined with the snapshot directory; None if nothing was found.
        """
        solver_state = None
        max_iter = -1
        regex_iter = re.compile(r'iter_([\d]+)\.solverstate[\.\w-]*$')
        for entry in os.listdir(self.getSnapshotDirectory()):
            iter_match = regex_iter.search(entry)
            if iter_match:
                try:
                    iter_id = int(iter_match.group(1))
                except ValueError:
                    continue
                if iter_id > max_iter:
                    max_iter = iter_id
                    solver_state = entry
        if solver_state is None:
            return None
        if basename:
            return solver_state
        # NOTE(review): unlike the other lookups this returns a path relative
        # to the snapshot directory, which getLastSnapshot() joins with
        # self.directory again — confirm the directories line up.
        return os.path.join(self.getSnapshotDirectory(), solver_state)
import multiprocessing
import sys
if sys.version_info[0] == 2:
from Queue import PriorityQueue, Empty
else:
from queue import PriorityQueue, Empty
from threading import Thread
class Singleton(type):
    """Metaclass that implements the singleton pattern in a generic way.

    See http://stackoverflow.com/a/6798042 for source and further explanation.
    TODO outsource this class to use the same pattern in the complete project?
    """

    # one shared instance per class using this metaclass
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class SessionPool(Singleton('_SessionPoolBase', (object,), {})):
    """ A simple thread pool implementation for parsing log files of sessions.

    A job is added by adding a session.
    Jobs are ordered by the session_id(priority), so new sessions have a higher
    priority.
    The pool allocates and starts threads on demand. It allocates no more than
    CPU_COUNT - 2 threads, but at least 1.

    FIX: the original declared ``__metaclass__ = Singleton``, which only has
    an effect on Python 2 — on Python 3 (which this module explicitly
    supports) every call created a fresh pool. Inheriting from a throwaway
    base class created by the metaclass applies the singleton pattern on both
    Python versions.
    """

    def __init__(self):
        # sessions waiting to be parsed, ordered by priority
        self.sessions = PriorityQueue()
        # currently allocated worker threads
        self.pool = []
        # jobs are only started once the pool has been activated
        self.__start_jobs = False
        # callback invoked when the queue runs empty; set by activate().
        # Initialized here so __executeJob cannot hit an AttributeError.
        self.emptyJob = None
        self.MAX_THREADS = multiprocessing.cpu_count() - 2
        if self.MAX_THREADS <= 0:
            self.MAX_THREADS = 1

    def addSession(self, session):
        """ Add a session to the queue and start a thread (if activated).
        """
        self.sessions.put(session)
        if self.__start_jobs:
            self.__startJob()

    def activate(self, emptyJob):
        """ Activate this pool by starting worker threads.

        emptyJob is called by a worker once the queue runs empty.
        """
        self.__start_jobs = True
        self.emptyJob = emptyJob
        qs = self.sessions.qsize()
        starts = min(qs, self.MAX_THREADS)
        for i in range(0, starts):
            self.__startJob()

    # private methods

    def __startJob(self):
        """ Remove inactive threads from the pool and create a new one for the
        new job, provided the thread limit allows it.
        """
        if self.__start_jobs is False:
            return
        # clean thread pool: drop threads that have already terminated
        self.pool = [t for t in self.pool if t is not None and t.is_alive()]
        # create a new thread
        if len(self.pool) < self.MAX_THREADS:
            t = Thread(target=self.__executeJob)
            t.start()
            self.pool.append(t)

    def __executeJob(self):
        """ The target method of a thread.

        Poll for sessions and execute their parser. The thread ends once the
        queue has been empty for one second.
        """
        try:
            while True:
                session = self.sessions.get(True, 1)
                if session:
                    session.parseOldLogs()
                    parser = session.getParser()
                    parser.parseLog()
                    self.sessions.task_done()
                else:
                    break
        except Empty:
            if self.emptyJob:
                self.emptyJob()
            return
        except Exception as e:
            # best-effort: a failing parser must not kill the worker silently
            print(str(e))
class State:
    """ Pseudo Enumeration for the states a session could be in.
    """
    UNDEFINED = 0
    WAITING = 1
    RUNNING = 2
    PAUSED = 3
    # NOTE(review): the value 4 is unused — presumably a state was removed;
    # keep the remaining numeric values stable in case they are stored or
    # transmitted elsewhere (confirm before renumbering).
    FAILED = 5
    FINISHED = 6
    INVALID = 7
    NOTCONNECTED = 8
def baristaSessionFile(directory):
    """ Returns the filename of the config-json file in the given directory """
    # FIX: in the original the import preceded the docstring, which silently
    # turned the docstring into a plain (ignored) string expression.
    # The import stays local so this helper has no module-level dependency.
    import os
    return os.path.join(directory, "sessionstate.json")
class Paths:
    # class "constants" defining common file names inside of a session
    # solver definition file
    FILE_NAME_SOLVER = "solver.prototxt"
    # presumably the network as originally provided — TODO confirm
    FILE_NAME_NET_ORIGINAL = "net-original.prototxt"
    # generated network file; the permanent solver constraints force the
    # solver's "net" property to point at this name
    FILE_NAME_NET_INTERNAL = "net-internal.prototxt"
    # serialized session state (see baristaSessionFile())
    FILE_NAME_SESSION_JSON = "sessionstate.json"
class Events:
    """Compiled regular expressions recognizing notable lines in a (caffe)
    training log, keyed by event name. Groups capture the relevant detail
    (error code, iteration count, or snapshot file name)."""
    import re
    events = {
        # exceptions
        'FileNotFound': re.compile('Check failed: mdb_status == 0 \(([\d]+) vs\. 0\) No such file or directory'),
        'OutOfGPU': re.compile('Check failed: error == cudaSuccess \(([\d]+) vs\. 0\) out of memory'),
        'NoSnapshotPrefix': re.compile('Check failed: param_\.has_snapshot_prefix\(\) In solver params, snapshot is specified but snapshot_prefix is not'),
        # session finished
        'OptimizationDone': re.compile('Optimization Done'),
        # iterations
        'max_iter': re.compile('max_iter:[\s]+([\d]+)'),
        # snapshots
        'state_snapshot': re.compile('Snapshotting solver state to (?:binary proto|HDF5) file (.+\.solverstate[\.\w-]*)'),
        'model_snapshot': re.compile('Snapshotting to (?:binary proto|HDF5) file (.+\.caffemodel[\.\w-]*)')
    }
\ No newline at end of file
from PyQt5.QtCore import Qt
from PyQt5.Qt import QObject
import PyQt5.QtCore as QtCore
import inspect
import datetime
import time as tim
'''
Model/Helper-Classes
'''
class MessageType:
    """Describes one category of log messages: a numeric id, a display color
    and a human-readable label."""

    def __init__(self, id, description):
        self.typeId = id
        self.description = description
        # -1 stands for the pseudo type "ALL" (DockElementConsole.ALL), which
        # has no color entry of its own and falls back to the TEXT color
        if id == -1:
            self.color = Logger.COLORS[Logger.TEXT.typeId]
        else:
            self.color = Logger.COLORS[id]
class LogLine:
    """A single log entry: text, originating caller, message type and the
    time it was logged."""

    def __init__(self, line, caller, msgType, time):
        self.line, self.caller = line, caller
        self.msgType, self.time = msgType, time
class Caller:
    """One registered source of log messages, with its display color and a
    flag tracking whether it has logged anything yet."""

    def __init__(self, id, description, customColor):
        # a falsy id (e.g. None) is normalized to 0
        self.callerId = id if id else 0
        # the first two palette entries (red, black) are reserved, hence +2
        colorIndex = self.callerId + 2
        useCustom = customColor and colorIndex < len(Logger.COLORS)
        if useCustom:
            self.color = Logger.COLORS[colorIndex]
        else:
            self.color = Logger.COLORS[Logger.TEXT.typeId]
        self.description = description
        # becomes True once this caller has produced at least one line
        self.used = False

    def setUsed(self):
        """Mark this caller as having logged at least one line."""
        self.used = True
class LogCaller:
    """Interface for users of the Logger class.

    A class implementing this interface can use the Logger methods log and
    error without passing an explicit callerId.
    """

    def getCallerId(self):
        """Return this object's caller id (to be overridden by implementers)."""
        pass
class Logger(QObject):
    """
    Central log collector. Emits Qt signals so GUI consoles can display log
    lines and caller-based filter values.

    TODO: Selection is removed on insertion, fix needed?
    TODO: write the log to the file set via setFile() (not implemented yet)
    """
    # emitted for every new LogLine
    newLine = QtCore.pyqtSignal(object)
    # emitted with the full list of log lines when the console must be rebuilt
    sigRefreshGui = QtCore.pyqtSignal(object)
    # emitted with the list of callers when the filter values change
    sigRefreshCallers = QtCore.pyqtSignal(object)

    # Color Constants
    COLORS = [Qt.red, Qt.black, Qt.blue, Qt.green, Qt.yellow]

    # Type Constants, used for color and combobox index
    # initialized at the bottom of this module
    ERROR = None
    TEXT = None
    ALL = None

    def __init__(self):
        super(Logger, self).__init__()
        self.__loglines = []
        self.__callerIdCount = 0
        self.__removedCallers = []
        self.__callers = []
        self.__guiConsole = None
        self.__filePath = None

    def __inferCallerId(self, callerFrame):
        '''
        Try to derive a callerId from the 'self' local variable of the given
        frame; works for callers implementing the LogCaller interface.
        Returns None if no id could be determined.
        '''
        # inspired by http://stackoverflow.com/a/7272464/2129327
        # FIX: the original looked up the attribute "getCallerId()" (with
        # parentheses in the attribute name) and then called the result a
        # second time — both mistakes meant the lookup could never succeed.
        try:
            caller = callerFrame.f_locals['self']
            if caller:
                getId = getattr(caller, "getCallerId", None)
                if callable(getId):
                    return getId()
        except Exception:
            pass
        return None

    def log(self, line, callerId=None):
        '''
        use to log simple text
        '''
        if callerId is None:
            callerId = self.__inferCallerId(inspect.currentframe().f_back)
        if callerId is None:
            print('Log.log: no callerId')
            return
        self.appendLine(line, callerId)

    def error(self, line, callerId=None):
        '''
        use to log errors
        '''
        if callerId is None:
            callerId = self.__inferCallerId(inspect.currentframe().f_back)
        if callerId is None:
            # FIX: the original printed 'Log.log' here as well (copy/paste)
            print('Log.error: no callerId')
            return
        self.appendLine(line, callerId, Logger.ERROR)

    def appendLine(self, line, callerId, msgType=None):
        '''
        appends a single line to the logger and, if present, to the console
        and file

        line: String, the string to append
        callerId: Int, the id created by getCallerId
        msgType: MessageType, an object indicating which type this logline is
        '''
        if msgType is None:
            msgType = Logger.TEXT
        ts = tim.time()
        st = datetime.datetime.fromtimestamp(ts).strftime("%H:%M:%S")
        # NOTE(review): __callers is indexed by callerId here, but
        # removeCallerId() removes entries from the list, which shifts the
        # indices of the remaining callers — confirm ids and list positions
        # stay in sync once callers have been removed.
        logline = LogLine(line, self.__callers[callerId], msgType, st)
        self.__callers[callerId].setUsed()
        self.__loglines.append(logline)
        # notify all consoles etc. about the new line
        self.newLine.emit(logline)
        # FIX: guard against __guiConsole never being initialized; the
        # original raised AttributeError as soon as a file path was set
        if self.__filePath is not None and self.__guiConsole is not None:
            self.__guiConsole.appendLineToConsole(logline)

    def appendLines(self, lines, callerId):
        '''
        TODO: Possible optimization for big amount of data, like logfiles
        '''
        for line in lines:
            self.appendLine(line, callerId)

    def getCallerId(self, description, customColor=False):
        '''
        creates a new Caller id which identifies the
        using object to group their log messages
        '''
        if self.__removedCallers:  # reuse an id of a removed caller first
            callerId = self.__removedCallers.pop()
        else:
            callerId = self.__callerIdCount
            self.__callerIdCount += 1
        caller = Caller(callerId, description, customColor)
        self.__callers.append(caller)
        self.refreshCallers()
        return callerId

    def removeCallerId(self, callerId, keepLines=True):
        '''
        removes a caller; its id may be handed out again by getCallerId().
        If keepLines is False, all of its log lines are dropped as well.
        '''
        if callerId in self.__removedCallers:
            return
        self.__removedCallers.append(callerId)
        self.__callers = [c for c in self.__callers if c.callerId != callerId]
        if not keepLines:
            self.__loglines = [l for l in self.__loglines
                               if l.caller.callerId != callerId]
            self.refreshGui()
        else:
            self.refreshCallers()

    def setFile(self, filePath):
        # NOTE(review): nothing writes to this file yet (see class TODO)
        self.__filePath = filePath

    def refreshGui(self):
        '''
        refreshes/validates the whole console and the filter values
        '''
        self.refreshConsole()
        self.refreshCallers()

    def refreshConsole(self):
        '''
        refreshes/validates the whole console
        '''
        self.sigRefreshGui.emit(self.__loglines)

    def refreshCallers(self):
        '''
        refreshes/validates the filter values/callers
        '''
        self.sigRefreshCallers.emit(self.__callers)

    def __getPrefix(self, caller, time):
        '''
        the prefix that has to be prepended to a written log line
        '''
        return "[" + time + ", " + caller.description + "]"
# "Singelton" instance
Log = Logger()
# Add Constant Values to Logger
Logger.ERROR = MessageType(0, "ERROR")
Logger.TEXT = MessageType(1, "TEXT")
Logger.ALL = MessageType(-1, "ALL")
from PyQt5.QtCore import QSettings
def applicationQSetting():
    """ Returns the default QSettings-Object for this application """
    # Native format in user scope. An INI-file variant was considered:
    # QSettings(QSettings.IniFormat, QSettings.UserScope, "wwu", "Barista")
    settings = QSettings("wwu", "Barista")
    return settings
from subprocess import Popen, PIPE, STDOUT
from sys import stdout
from os import path
import logging
from backend.networking.protocol import Protocol