Commit ae941f68 authored by l_piel01's avatar l_piel01

Merge branch 'develop' into refactor/410_align-host-manager

parents 97782a00 7cc09b39
......@@ -18,4 +18,7 @@
*.py.orig
/barista.conf
/baristalog.txt
caffeVersions
/.vscode/
/sessions
import os
from backend.barista.utils.logger import LogCaller
# Module-wide list of the known caffeVersion objects. The functions below read,
# modify and pickle this list; index 0 is treated as the default version.
versions = []
# NOTE(review): not referenced anywhere in the visible part of this module -
# presumably a "restart required" flag toggled by callers; confirm before use.
restart = False
def saveVersions(path=""):
    """Pickle the module-level ``versions`` list to ``<path>/caffeVersions``.

    path -- directory that receives the file; the default "" resolves to the
            current working directory.

    Returns True if the file was written, False if it could not be opened or
    written (IOError).
    """
    import pickle
    try:
        # A pickle stream is binary data, so the file has to be opened in
        # binary mode; text mode may translate line endings or reject bytes.
        with open(os.path.join(path, "caffeVersions"), "wb") as outfile:
            pickle.dump(versions, outfile)
        return True
    except IOError:
        return False
def loadVersions(path=""):
    """Replace the module-level ``versions`` list with the one stored on disk.

    path -- directory that contains the ``caffeVersions`` file; the default ""
            resolves to the current working directory.

    Returns True if the list was loaded. Returns False, leaving ``versions``
    untouched, if the file cannot be opened or does not contain a complete
    pickle stream (for example an empty or truncated file).
    """
    import pickle
    global versions
    try:
        # Binary mode is required for pickle data (see saveVersions).
        with open(os.path.join(path, "caffeVersions"), "rb") as infile:
            # NOTE(review): pickle.load can execute arbitrary code embedded in
            # the file - only load version files from a trusted location.
            versions = pickle.load(infile)
        return True
    except (IOError, EOFError, pickle.UnpicklingError):
        return False
def getAvailableVersions():
    """Return the list of caffe versions currently known to barista.

    The module-level list itself is handed out, not a copy, so changes made by
    the caller are visible to every other function of this module.
    """
    return versions
def addVersion(version, path=""):
    """Add a caffe version to barista and persist the updated list.

    version -- the version object to register
    path    -- directory of the versions file, passed on to saveVersions()

    Returns True if the version was added and the list was saved.
    Returns False if the version is already registered, or if saving the list
    failed (in that case the version still stays in the in-memory list).
    """
    # A plain membership test replaces the former try/except around
    # list.index(), which treated *any* exception as "not yet registered".
    if version in versions:
        return False
    versions.append(version)
    return saveVersions(path)
def removeVersion(version, path=""):
    """Remove a caffe version from barista and persist the updated list.

    version -- the version object to unregister
    path    -- directory of the versions file, passed on to saveVersions()

    Returns True if the version was removed and the list was saved.
    Returns False if the version is not registered or if saving failed.
    """
    try:
        versions.remove(version)
    except ValueError:
        # list.remove() signals "not in list" with ValueError; nothing else is
        # swallowed here, so real errors are no longer hidden behind False.
        return False
    return saveVersions(path)
def versionCount():
    """Return how many caffe versions are held in the module-level list."""
    return len(versions)
def getVersionByName(name):
    """Return the first stored version whose getName() equals *name*.

    Returns None when no stored version carries that name.
    """
    matching = (candidate for candidate in versions if candidate.getName() == name)
    return next(matching, None)
def getDefaultVersion():
    """Return the default version, i.e. the first entry of the version list.

    Returns None while no version is stored.
    """
    if not versions:
        return None
    return versions[0]
def setDefaultVersion(name, path=""):
    """Make the version called *name* the default caffe version of barista.

    The default version is, by convention of this module, the first element
    of the version list, so the chosen version swaps places with the entry
    that is currently first. The updated list is then saved.

    name -- name of the version that becomes the default
    path -- directory of the versions file, passed on to saveVersions()

    Returns True if the default was changed and the list was saved.
    Returns False if no version with that name exists or if saving failed.
    """
    version = getVersionByName(name)
    if version is None:
        return False
    index = versions.index(version)
    versions[0], versions[index] = versions[index], versions[0]
    return saveVersions(path)
class caffeVersion():
    """This class represents a specific caffe version.

    It is a plain record of a display name plus four paths (root, binary,
    python and proto). Instances are pickled by this module, so the attribute
    names must stay stable.
    """

    def __init__(self, name, root, binary, python, proto):
        self.name = name
        self.root = root
        self.binary = binary
        self.python = python
        self.proto = proto

    def getName(self):
        """Return the name of this version."""
        return self.name

    def getRootpath(self):
        """Return the root path of this version."""
        return self.root

    def getBinarypath(self):
        """Return the binary path of this version."""
        return self.binary

    def getPythonpath(self):
        """Return the python path of this version."""
        return self.python

    def getProtopath(self):
        """Return the proto path of this version."""
        return self.proto

    def setName(self, name):
        """Set the name of this version."""
        self.name = name

    def setRootpath(self, root):
        """Set the root path of this version.

        This setter used to be declared as a second ``setName``, which
        silently replaced the real ``setName`` and wrote the root path.
        """
        self.root = root

    def setBinarypath(self, binary):
        """Set the binary path of this version."""
        self.binary = binary

    def setPythonpath(self, python):
        """Set the python path of this version."""
        self.python = python

    def setProtopath(self, proto):
        """Set the proto path of this version."""
        self.proto = proto
\ No newline at end of file
......@@ -8,6 +8,7 @@ from PyQt5 import QtWidgets
from backend.barista.constraints import common
from backend.caffe.proto_info import CaffeMetaInformation
from gui.network_manager.layer_helper import LayerHelper
from backend.caffe.proto_info import UnknownLayerTypeException
def checkMinimumTrainingRequirements(session, parentGui=None, reportToUser=True):
......@@ -101,13 +102,28 @@ class MinimumTrainingRequirements:
# temporary change the current working dir to allow evaluation of relative paths
if hasattr(self._session, 'checkTraining'):
self._errorMessages.extend(self._session.checkTraining())
else:
elif self._checkLayers():
self._checkSolver()
self._checkDataLayerExistence()
self._checkDataLayerParameters()
self._checkInputLayer()
self._checkUniqueBlobNames()
def _checkLayers(self):
    """Check that CaffeMetaInformation knows the type of every network layer.

    The layer types are read from self._stateData["network"]["layers"]. On the
    first type rejected with UnknownLayerTypeException, its message is
    appended to self._errorMessages (tagged "Unknown Layer") and False is
    returned. Returns True if all types are known or no state data is set.
    """
    if self._stateData is None:
        return True
    layers = self._stateData["network"]["layers"]
    # Collect all type names first so that a malformed layer entry fails
    # before any lookup is attempted (same order as before).
    layerTypes = [layers[layerId]["parameters"]["type"] for layerId in layers]
    try:
        for layerType in layerTypes:
            # Only the raised exception matters; the result is not needed.
            CaffeMetaInformation().getLayerType(layerType)
    except UnknownLayerTypeException as e:
        self._errorMessages.append((e._msg, "Unknown Layer"))
        return False
    return True
def _checkSolver(self):
"""Check whether all solver constraints are valid."""
# the base learning rate should be a positive number
......
import hashlib
import os
def hashFile(path):
    """Return the hexadecimal MD5 digest of the file at *path*.

    The file is read in 4096 byte chunks, so large files are not loaded into
    memory at once.

    If the file cannot be read, the error is printed and the interpreter is
    terminated with exit status -1. The former ``exit -1`` was a subtraction
    expression that raised TypeError instead of exiting.
    """
    digest = hashlib.md5()
    try:
        # Binary mode: the digest is computed over the raw bytes, and the b""
        # sentinel below only matches the end of a binary stream.
        with open(path, "rb") as stream:
            for chunk in iter(lambda: stream.read(4096), b""):
                digest.update(chunk)
        return digest.hexdigest()
    except Exception as e:
        import sys
        print(e)
        sys.exit(-1)
def hashDir(path):
    """Return a fingerprint of all files below the directory *path*.

    The result is the concatenation of the hashFile() digests of every file
    found by os.walk(), with the file names of each directory visited in
    sorted order. An empty or missing directory yields "".

    If hashing a file raises an exception, the error is printed and the
    interpreter is terminated with exit status -1.
    """
    # NOTE(review): the order in which os.walk() visits sub-directories is not
    # sorted here, so the fingerprint may depend on the file system - confirm
    # whether fingerprints are compared across hosts.
    digests = []
    for root, dirs, files in os.walk(path):
        for name in sorted(files):
            try:
                digests.append(hashFile(os.path.join(root, name)))
            except Exception as e:
                import sys
                print(e)
                sys.exit(-1)
    return "".join(digests)
\ No newline at end of file
......@@ -8,6 +8,7 @@ from datetime import datetime
from PyQt5.Qt import QObject
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtWidgets import QMessageBox
from backend.barista.constraints.permanent.project import ensureProjectDataConstraints
from backend.barista.session.session import Session
from backend.barista.session.client_session import ClientSession
......@@ -21,10 +22,11 @@ import backend.caffe.dict_helper as helper
import backend.caffe.loader as loader
import backend.caffe.proto_info as info
import backend.caffe.saver as saver
import backend.barista.caffe_versions as caffeVersions
from backend.barista.session.session_pool import SessionPool
from backend.barista.utils.logger import Log
from backend.barista.utils.logger import LogCaller
from backend.caffe import path_loader
from backend.caffe.proto_info import CaffeMetaInformation
def dirIsProject(dirname):
""" Checks if the dir with the name dirname contains a projectfile """
......@@ -93,6 +95,7 @@ class Project(QObject, LogCaller):
def __init__(self, directory):
super(Project, self).__init__()
# Make sure the needed directories exist
self.caffeVersion = caffeVersions.getDefaultVersion().getName()
self.projectRootDir = directory
self.projectId = None
self.current_sid = None
......@@ -118,6 +121,7 @@ class Project(QObject, LogCaller):
the NetworkManager.
State looks like: {
"projectid": "xyz...",
"caffeVersion": {...},
"inputdb": { .. },
"network": { .. },
"solver": {...},
......@@ -137,12 +141,14 @@ class Project(QObject, LogCaller):
res = json.load(file)
if set_settings and ("environment" in res):
self.settings = res["environment"]
#path_loader.reloadCaffeModule(self.getCaffeRoot())
#path_loader.reloadProtoModule(self.getCaffeRoot())
if "inputdb" in res:
self.__inputManagerState = res["inputdb"]
if "caffeVersion" in res:
self.caffeVersion = res["caffeVersion"]
caffeVersions.setDefaultVersion(self.caffeVersion)
if "projectdata" in res: # this is for backward compatibility, only
# Fill missing type-instances
allLayers = info.CaffeMetaInformation().availableLayerTypes()
......@@ -202,6 +208,7 @@ class Project(QObject, LogCaller):
State should looks like: {
"projectid": "xyz...",
"network": { .. },
"caffeVersion": {...},
"inputdb": { .. },
"solver": {...},
"selection": [],
......@@ -219,6 +226,7 @@ class Project(QObject, LogCaller):
# Serialize
tosave = {
"projectid": self.projectId,
"caffeVersion": self.caffeVersion,
"activeSession": self.getActiveSID(),
"inputdb": self.__inputManagerState,
"environment": self.settings,
......@@ -235,27 +243,19 @@ class Project(QObject, LogCaller):
""" Load settings with sensitive defaults.
"""
self.settings = {
'CAFFE_ROOT': "",
'plotter': {
'logFiles': {},
'checkBoxes': {}
}
}
def changeProjectCaffePath(self, path):
    """Change the project's caffe path and persist it if a config file exists.

    path -- the new value for the "CAFFE_ROOT" entry of self.settings

    Without a project config file the default settings are loaded and only
    the in-memory settings are updated. With a config file, the updated
    settings are written into its "environment" entry.
    """
    configFile = self.projectConfigFileName()
    if not os.path.exists(configFile):
        self.__loadDefaultSettings()
        self.settings["CAFFE_ROOT"] = path
        return
    # The new path has to be part of the settings before they are written;
    # previously this branch saved the settings without applying `path`.
    self.settings["CAFFE_ROOT"] = path
    with open(configFile, "r") as file:
        res = json.load(file)
    res["environment"] = self.settings
    with open(configFile, "w") as file:
        json.dump(res, file, sort_keys=True, indent=4)
def changeProjectCaffeVersion(self, version):
    """Store *version* as the caffe version of this project.

    Only the in-memory attribute ``caffeVersion`` is updated here.
    NOTE(review): presumably *version* is a version name (compare
    getCaffeVersion) - confirm against the callers.
    """
    setattr(self, "caffeVersion", version)
def getCaffeVersion(self):
    """Return the caffe version name stored in ``self.caffeVersion``."""
    return self.caffeVersion
def deletePlotterSettings(self, logId):
"""deletes the plotter settings for a given log, e.g. if a session was reset"""
......@@ -284,16 +284,6 @@ class Project(QObject, LogCaller):
with open(self.projectConfigFileName(), "w") as file:
json.dump(res, file, sort_keys=True, indent=4)
def getCaffeRoot(self):
    """Return the caffe path of this project.

    A non-empty "CAFFE_ROOT" entry of self.settings wins; otherwise the value
    of path_loader.getCaffePath() is returned.
    """
    configured = self.settings.get("CAFFE_ROOT") if self.settings else None
    return configured if configured else path_loader.getCaffePath()
def getCallerId(self):
""" Return the unique caller id for this project
......@@ -312,6 +302,10 @@ class Project(QObject, LogCaller):
"""
return self.projectRootDir
def getProjectId(self):
    """Return the id of this project as stored in ``self.projectId``."""
    return self.projectId
def buildSolverPrototxt(self):
""" Load the current solver dictionary and return the corresponding
message object.
......@@ -467,9 +461,26 @@ class Project(QObject, LogCaller):
def createRemoteSession(self, remote, state_dictionary=None):
"""use this only to create entirely new sessions. to load existing use the loadRemoteSession command"""
msg = {"key": Protocol.GETCAFFEVERSIONS}
reply = sendMsgToHost(remote[0], remote[1], msg)
if reply:
remoteVersions = reply["versions"]
if len(remoteVersions) <= 0:
msgBox = QMessageBox(QMessageBox.Warning, "Error", "Cannot create remote session on a host witout a caffe-version")
msgBox.addButton("Ok", QMessageBox.NoRole)
msgBox.exec_()
return None
sid = self.getNextSessionId()
msg = {"key": Protocol.CREATESESSION, "pid": self.projectId,
"sid": sid}
msg = {"key": Protocol.CREATESESSION, "pid": self.projectId, "sid": sid}
layers = []
for layer in state_dictionary["network"]["layers"]:
layers.append(state_dictionary["network"]["layers"][layer]["parameters"]["type"])
msg["layers"] = layers
ret = sendMsgToHost(remote[0], remote[1], msg)
if ret:
if ret["status"]:
......@@ -492,6 +503,24 @@ class Project(QObject, LogCaller):
def loadRemoteSession(self, remote, uid):
    """Attach an existing remote session to this project.

    A ClientSession is created for (remote, uid) under the next free session
    id. If its network contains layer types that are not among
    CaffeMetaInformation().availableLayerTypes(), a warning box listing them
    is shown, the session is disconnected and None is returned. Otherwise the
    session is registered, newSession is emitted and the session id returned.
    """
    sid = self.getNextSessionId()
    session = ClientSession(self, remote, uid, sid)

    knownTypes = info.CaffeMetaInformation().availableLayerTypes()
    layers = session.state_dictionary["network"]["layers"]
    # Every unknown type is reported once, in order of first appearance.
    missing = []
    for layerID in layers:
        layerType = layers[layerID]["parameters"]["type"]
        if layerType in knownTypes or layerType in missing:
            continue
        missing.append(layerType)

    if missing:
        msg = "Cannot load session. The selected session contains layers unknown to the current caffe-version.\n\nUnknown layers:"
        msg += "".join("\n" + layerType for layerType in missing)
        msgBox = QMessageBox(QMessageBox.Warning, "Warning", msg)
        msgBox.addButton("Ok", QMessageBox.NoRole)
        msgBox.exec_()
        session._disconnect()
        return None

    self.__sessions[sid] = session
    self.newSession.emit(sid)
    return sid
......@@ -588,7 +617,6 @@ class Project(QObject, LogCaller):
self.__sessions[sessionID].setPretrainedWeights(newCaffemodel)
self.__sessions[sessionID].iteration = 0
self.__sessions[sessionID].max_iter = oldSession.max_iter
self.__sessions[sessionID].caffe_root = oldSession.caffe_root
except Exception as e:
Log.error('Failed to copy caffemodel to new session: '+str(e),
self.getCallerId())
......@@ -701,16 +729,6 @@ class Project(QObject, LogCaller):
else:
return 1
#def getCaffeRoot(self):
# """ Return the caffe installation directory.
# """
# if 'CAFFE_ROOT' in self.settings:
# print("bla")
# self.__loadDefaultSettings()
# else:
# print("bla2")
# return self.settings['CAFFE_ROOT']
def __ensureDirectory(self, directory):
""" Creates a directory if it does not exist.
"""
......
......@@ -289,6 +289,7 @@ class ClientSession(QObject):
if ret["status"]:
self.setState(ret["state"], True)
return ret["state"]
else:
self._handleErrors(ret["error"])
self.setState(State.UNDEFINED, True)
......@@ -350,12 +351,11 @@ class ClientSession(QObject):
self.transaction.send(msg)
ret = self.transaction.asyncRead(attr=("subkey", SessionProtocol.SETSTATEDICT))
if ret:
if ret["status"]:
self.stateDictChanged.emit(self, False)
return
if not ret["status"]:
self._handleErrors(ret["error"])
else:
self._handleErrors(["SetStateDict: Failed to send StateDict"]) # TODO improve warnings
self.stateDictChanged.emit(self, False)
def _validateLayerOrder(self, dict):
......
......@@ -27,7 +27,9 @@ from backend.networking.protocol import Protocol, SessionProtocol
from backend.parser.concatenator import Concatenator
from backend.parser.parser import Parser
from backend.parser.parser_listener import ParserListener
import backend.barista.caffe_versions as caffeVersions
from backend.barista.deployed_net import DeployedNet
from backend.caffe.proto_info import UnknownLayerTypeException
from PyQt5.QtCore import QTimer
from threading import Lock
......@@ -282,8 +284,8 @@ class ServerSession(QObject, ParserListener, SessionCommon):
errors.append('Log directory does not exists: ' + self.logs)
if os.path.exists(self.snapshot_dir) is False:
errors.append('Snapshot directory does not exists: ' + self.snapshot_dir)
if os.path.exists(self.manager.parent.caffePath) is False:
errors.append('CAFFE_ROOT directory does not exists: ' + self.manager.parent.caffePath)
if os.path.exists(caffeVersions.getDefaultVersion().getBinaryPath()) is False:
errors.append('CAFFE_BINARY does not exists: ' + caffeVersions.getDefaultVersion().getBinaryPath())
msg["error"] = errors
self.transaction.send(msg)
......@@ -391,9 +393,13 @@ class ServerSession(QObject, ParserListener, SessionCommon):
def _msgSetStateDict(self):
    """Apply the state dictionary received over the transaction and reply.

    The incoming message's "statedict" is passed to setStateDict(). The reply
    is the same message without "statedict", with "status" set to True on
    success. If setStateDict() raises UnknownLayerTypeException, "status" is
    False and "error" holds the exception message.
    """
    msg = self.transaction.asyncRead()
    try:
        self.setStateDict(msg["statedict"])
        msg["status"] = True
    except UnknownLayerTypeException as e:
        msg["status"] = False
        msg["error"] = [e._msg]
    # The state is removed exactly once, after both outcomes; deleting it
    # inside the try block as well made the success path fail with KeyError.
    del msg["statedict"]
    self.transaction.send(msg)
def setStateDict(self, statedict):
......@@ -408,7 +414,7 @@ class ServerSession(QObject, ParserListener, SessionCommon):
if "parameters" in layers[id]:
if "type" in layers[id]["parameters"]:
typename = layers[id]["parameters"]["type"]
layers[id]["type"] = info.CaffeMetaInformation().availableLayerTypes()[typename]
layers[id]["type"] = info.CaffeMetaInformation().getLayerType(typename)
def _msgGetStateDict(self):
msg = self.transaction.asyncRead()
......@@ -425,11 +431,22 @@ class ServerSession(QObject, ParserListener, SessionCommon):
msg["status"] = True
self.transaction.send(msg)
def save(self, includeProtoTxt = False):
def save(self, includeProtoTxt = False, errors = []):
"""Saves the current session to prototxt files and session_settings json file."""
res = self.__ensureDirectory()
if len(res) > 0:
return res
availableTypes = info.CaffeMetaInformation().availableLayerTypes()
unknownLayerTypes = []
for layerID in self.state_dictionary["network"]["layers"]:
type = self.state_dictionary["network"]["layers"][layerID]["parameters"]["type"]
if not type in availableTypes and not type in unknownLayerTypes:
unknownLayerTypes.append(type)
if len(unknownLayerTypes) > 0:
errors.extend(unknownLayerTypes)
return False
else:
toSave = {"SessionState": self.state, "Iteration": self.iteration, "MaxIter": self.max_iter}
toSave["UID"] = self.uid
toSave["SID"] = self.sid
......@@ -461,7 +478,8 @@ class ServerSession(QObject, ParserListener, SessionCommon):
with open(filename, "w") as f:
json.dump(toSave, f, sort_keys=True, indent=4)
return []
return True
def prepairInternalPrototxt(self):
error = []
......@@ -575,7 +593,7 @@ class ServerSession(QObject, ParserListener, SessionCommon):
if "parameters" in layers[id]:
if "type" in layers[id]["parameters"]:
typename = layers[id]["parameters"]["type"]
layers[id]["type"] = info.CaffeMetaInformation().availableLayerTypes()[typename]
layers[id]["type"] = info.CaffeMetaInformation().getLayerType(typename)
solver = self.state_dictionary["solver"]
if solver:
if "snapshot_prefix" in solver:
......@@ -686,9 +704,10 @@ class ServerSession(QObject, ParserListener, SessionCommon):
def _msgSave(self):
    """Save the session on request and report the outcome to the client.

    The reply is the received message extended by "status" (True only if the
    save succeeded) and "error" (list of error details).
    """
    msg = self.transaction.asyncRead()
    errors = []
    # Save exactly once; the former code also kept the pre-refactoring call
    # and applied len() to the boolean result.
    outcome = self.save(errors=errors)
    if isinstance(outcome, list):
        # As far as visible in this file, save() hands back the list of
        # directory errors when the session directory cannot be ensured.
        # Such a list must not be mistaken for a truthy status.
        errors.extend(outcome)
        succeeded = not outcome
    else:
        succeeded = bool(outcome)
    msg["status"] = succeeded
    msg["error"] = errors
    self.transaction.send(msg)
def _takeSnapshot(self):
......@@ -848,7 +867,7 @@ class ServerSession(QObject, ParserListener, SessionCommon):
# TODO only one training per server
error = []
# (re-)write all session files
error.extend(self.save(includeProtoTxt=True))
self.save(includeProtoTxt=True, errors=error)
if not os.path.exists(self.logs):
try:
os.makedirs(self.logs)
......@@ -862,15 +881,12 @@ class ServerSession(QObject, ParserListener, SessionCommon):
error.extend(self.prepairInternalPrototxt())
if len(error) > 0:
return error
caffepath = self.manager.caffePath
if not os.path.exists(caffepath):
return ["No CaffePath"]
caffe_bin = caffeVersions.getDefaultVersion().getBinarypath()
try:
self.getParser().setLogging(True)
# TODO set hardware trainOnHW
cmd = [
os.path.join(caffepath, 'build/tools/caffe'), 'train', '-solver',
caffe_bin, 'train', '-solver',
os.path.join(self.directory, Paths.FILE_NAME_SOLVER)]
if solverstate is not None:
cmd.append('-snapshot')
......@@ -881,7 +897,6 @@ class ServerSession(QObject, ParserListener, SessionCommon):
if self.manager.parent.trainOnHW > 0:
cmd.append('-gpu')
cmd.append(str(self.manager.parent.trainOnHW-1))
self.proc = Popen(
cmd,
stdout=PIPE,
......@@ -955,15 +970,12 @@ class ServerSession(QObject, ParserListener, SessionCommon):
if self.state is State.PAUSED:
if snapshot is None:
snapshot = self.getLastSnapshot()
caffepath = self.manager.caffePath
if not os.path.exists(caffepath):
return ["No CaffePath"]
caffe_bin = caffeVersions.getDefaultVersion().getBinarypath()
self.rid += 1
try:
self.getParser().setLogging(True)
cmd = [
os.path.join(caffepath, 'build/tools/caffe'), 'train', '-solver',
caffe_bin, 'train', '-solver',
os.path.join(self.directory, Paths.FILE_NAME_SOLVER)]
cmd.append('-snapshot')
cmd.append(snapshot)
......
......@@ -29,6 +29,8 @@ import backend.caffe.saver as saver
import backend.caffe.proto_info as info
import backend.caffe.dict_helper as helper
import backend.barista.caffe_versions as caffeVersions
class Session(QObject, LogCaller, ParserListener, SessionCommon):
""" A session is a caffe training process.
......@@ -44,7 +46,7 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
iterationChanged = pyqtSignal()
snapshotAdded = pyqtSignal(object)
def __init__(self, project, directory=None, sid=None, parse_old=False, caffe_root=None,
def __init__(self, project, directory=None, sid=None, parse_old=False, caffe_bin=None,
last_solverstate=None, last_caffemodel=None, state_dictionary=None):
super(Session, self).__init__()
self.caller_id = None
......@@ -70,8 +72,9 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
'provided sid is used.', self.getCallerId())
self.parse_old = parse_old
self.caffe_root = caffe_root # overrides project caffe_root if necessary, i.e. if deployed to another system
self.caffe_bin = caffe_bin # overrides project caffe_root if necessary, i.e. if deployed to another system
self.pretrainedWeights = None
self.last_solverstate = last_solverstate
self.last_caffemodel = last_caffemodel
self.state_dictionary = state_dictionary # state as saved from the network manager, such it can be restored
......@@ -170,9 +173,9 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
if os.path.exists(self.snapshot_dir) is False:
Log.error('Snapshot directory does not exists: '+self.snapshot_dir,
self.getCallerId())
if os.path.exists(self.project.getCaffeRoot()) is False:
Log.error('CAFFE_ROOT directory does not exists: ' +
self.project.getCaffeRoot(),
if os.file.exists(caffeVersions.getVersionByName(self.project.getCaffeVersion())) is False:
Log.error('Caffe binary does not exists: ' +
self.project.getCaffeVersion(),
self.getCallerId())
......@@ -237,6 +240,7 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
"""
if self.state == State.FAILED:
return self.state
if self.proc is not None:
if self.iteration == self.max_iter:
self.state = State.FINISHED
......@@ -299,17 +303,15 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
# (re-)write all session files
self.save(includeProtoTxt=True)
# check if the session has its own caffe binary
caffeRoot = self.caffe_root
if not caffeRoot:
caffeBin = self.caffe_bin
if not caffeBin:
# else take the binary path of the project's caffe version
caffeRoot = self.project.getCaffeRoot()
caffeBin = caffeVersions.getVersionByName(self.project.getCaffeVersion()).getBinarypath()
try:
self.getParser().setLogging(True)
cmd = [
os.path.join(caffeRoot,
'build/tools/caffe'),
cmd = [caffeBin,
'train',
'-solver', self.getSolver()]
if solverstate:
......@@ -341,10 +343,10 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
# check if caffe root exists
Log.error('Failed to start session: '+str(e),
self.getCallerId())
if os.path.exists(self.project.getCaffeRoot()) is False:
Log.error('CAFFE_ROOT directory does not exists: ' +
caffeRoot +
'! Please set CAFFE_ROOT to run a session.',
if os.file.exists(caffeVersions.getVersionByName(self.project.getCaffeVersion()).getBinarypath()) is False:
Log.error('CAFFE_BINARY directory does not exists: ' +
caffe_bin +
'! Please set CAFFE_BINARY to run a session.',
self.getCallerId())
else:
Log.error(
......@@ -418,8 +420,8 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
self.rid += 1
try:
self.getParser().setLogging(True)
self.proc = Popen(
[os.path.join(self.project.getCaffeRoot(), 'build/tools/caffe'),
self.proc = Popen([
caffeVersions.getVersionByName(self.project.getCaffeVersion()).getBinarypath(),
'train',
'-solver', self.getSolver(),
'-snapshot', snapshot],
......@@ -444,10 +446,10 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
# check if caffe root exists
Log.error('Failed to continue session: '+str(e),
self.getCallerId())
if os.path.exists(self.project.getCaffeRoot()) is False:
Log.error('CAFFE_ROOT directory does not exists: ' +
self.project.getCaffeRoot() +
'! Please set CAFFE_ROOT to run a session.',
if os.file.exists(caffeVersions.getVersionByName(self.project.getCaffeVersion()).getBinarypath()) is False:
Log.error('CAFFE_BINARY directory does not exists: ' +
caffe_bin +