Commit ae941f68 authored by l_piel01's avatar l_piel01

Merge branch 'develop' into refactor/410_align-host-manager

parents 97782a00 7cc09b39
......@@ -18,4 +18,7 @@
*.py.orig
/barista.conf
/baristalog.txt
caffeVersions
/.vscode/
/sessions
import os
from backend.barista.utils.logger import LogCaller
versions = []
restart = False
def saveVersions(path=""):
"""Saves the versions to file."""
import pickle
try:
with open(os.path.join(path, "caffeVersions"), "w") as outfile:
global versions
pickle.dump(versions, outfile)
return True
except IOError as e:
return False
def loadVersions(path=""):
"""Reloads the versions from file."""
import pickle
try:
with open(os.path.join(path, "caffeVersions"), "r") as infile:
global versions
versions = pickle.load(infile)
return True
except IOError as e:
return False
def getAvailableVersions():
"""Returns all to barista available caffe versions."""
global versions
return versions
def addVersion(version, path=""):
"""Adds a caffe version to barista.
Returns True if the version was successfully added
Returns False if version could not be added"""
try:
global versions
versions.index(version)
return False
except Exception as e:
versions.append(version)
return saveVersions(path)
def removeVersion(version, path=""):
"""Removes a caffe version from barista.
Returns True if the version was successfully removed
Returns False if version could not be removed"""
try:
global versions
versions.remove(version)
return saveVersions(path)
except Exception as e:
return False
def versionCount():
"""Returns the number of stored caffe versions"""
global versions
return len(versions)
def getVersionByName(name):
"""Returns the version object depending on the given name"""
global versions
for version in versions:
if version.getName() == name:
return version
return None
def getDefaultVersion():
"""Returns the default version object"""
if versionCount() > 0:
global versions
return versions[0]
else:
return None
def setDefaultVersion(name, path=""):
"""Sets the default caffe version of barista,
implemented as beeing the first element of the list
Returns True if the default version was successfully set
Returns False if not"""
version = getVersionByName(name)
if version != None:
global versions
index = versions.index(version)
temp = versions[0]
versions[0] = versions[index]
versions[index] = temp
return saveVersions(path)
else:
return False
class caffeVersion():
"""This class represents a specific caffe version."""
def __init__(self, name, root, binary, python, proto):
self.name = name
self.root = root
self.binary = binary
self.python = python
self.proto = proto
def getName(self):
return self.name
def getRootpath(self):
return self.root
def getBinarypath(self):
return self.binary
def getPythonpath(self):
return self.python
def getProtopath(self):
return self.proto
def setName(self, name):
self.name = name
def setName(self, root):
self.root = root
def setBinarypath(self, binary):
self.binary = binary
def setPythonpath(self, python):
self.python = python
def setProtopath(self, proto):
self.proto = proto
\ No newline at end of file
......@@ -8,6 +8,7 @@ from PyQt5 import QtWidgets
from backend.barista.constraints import common
from backend.caffe.proto_info import CaffeMetaInformation
from gui.network_manager.layer_helper import LayerHelper
from backend.caffe.proto_info import UnknownLayerTypeException
def checkMinimumTrainingRequirements(session, parentGui=None, reportToUser=True):
......@@ -101,13 +102,28 @@ class MinimumTrainingRequirements:
# temporary change the current working dir to allow evaluation of relative paths
if hasattr(self._session, 'checkTraining'):
self._errorMessages.extend(self._session.checkTraining())
else:
elif self._checkLayers():
self._checkSolver()
self._checkDataLayerExistence()
self._checkDataLayerParameters()
self._checkInputLayer()
self._checkUniqueBlobNames()
def _checkLayers(self):
"""Checks ify the network contains layers that are not compatible with the current caffe-version"""
if self._stateData is not None:
layers = []
for layer in self._stateData["network"]["layers"]:
layers.append(self._stateData["network"]["layers"][layer]["parameters"]["type"])
try:
for layer in layers:
typename = CaffeMetaInformation().getLayerType(layer)
except UnknownLayerTypeException as e:
self._errorMessages.append((e._msg, "Unknown Layer"))
return False
return True
def _checkSolver(self):
"""Check whether all solver constraints are valid."""
# the base learning rate should be a positive number
......
import hashlib
import os
def hashFile(path):
"""Hashes a file and returns the hashvalue"""
m = hashlib.md5()
try:
file = open(path, "r")
for chunk in iter(lambda: file.read(4096), b""):
m.update(chunk)
return m.hexdigest()
except Exception as e:
print e
exit -1
def hashDir(path):
"""Hashes a directory and returns the hashvalue"""
hash = ""
for root, dirs, files in os.walk(path):
for names in sorted(files):
try:
filepath = os.path.join(root, names)
hash += hashFile(filepath)
except Exception as e:
print e
exit -1
return hash
\ No newline at end of file
......@@ -8,6 +8,7 @@ from datetime import datetime
from PyQt5.Qt import QObject
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtWidgets import QMessageBox
from backend.barista.constraints.permanent.project import ensureProjectDataConstraints
from backend.barista.session.session import Session
from backend.barista.session.client_session import ClientSession
......@@ -21,10 +22,11 @@ import backend.caffe.dict_helper as helper
import backend.caffe.loader as loader
import backend.caffe.proto_info as info
import backend.caffe.saver as saver
import backend.barista.caffe_versions as caffeVersions
from backend.barista.session.session_pool import SessionPool
from backend.barista.utils.logger import Log
from backend.barista.utils.logger import LogCaller
from backend.caffe import path_loader
from backend.caffe.proto_info import CaffeMetaInformation
def dirIsProject(dirname):
""" Checks if the dir with the name dirname contains a projectfile """
......@@ -93,6 +95,7 @@ class Project(QObject, LogCaller):
def __init__(self, directory):
super(Project, self).__init__()
# Make sure needed directoies exists
self.caffeVersion = caffeVersions.getDefaultVersion().getName()
self.projectRootDir = directory
self.projectId = None
self.current_sid = None
......@@ -118,6 +121,7 @@ class Project(QObject, LogCaller):
the NetworkManager.
State looks like: {
"projectid": "xyz...",
"caffeVersion": {...},
"inputdb": { .. },
"network": { .. },
"solver": {...},
......@@ -137,12 +141,14 @@ class Project(QObject, LogCaller):
res = json.load(file)
if set_settings and ("environment" in res):
self.settings = res["environment"]
#path_loader.reloadCaffeModule(self.getCaffeRoot())
#path_loader.reloadProtoModule(self.getCaffeRoot())
if "inputdb" in res:
self.__inputManagerState = res["inputdb"]
if "caffeVersion" in res:
self.caffeVersion = res["caffeVersion"]
caffeVersions.setDefaultVersion(self.caffeVersion)
if "projectdata" in res: # this is for backward compatibility, only
# Fill missing type-instances
allLayers = info.CaffeMetaInformation().availableLayerTypes()
......@@ -202,6 +208,7 @@ class Project(QObject, LogCaller):
State should looks like: {
"projectid": "xyz...",
"network": { .. },
"caffeVersion": {...},
"inputdb": { .. },
"solver": {...},
"selection": [],
......@@ -219,6 +226,7 @@ class Project(QObject, LogCaller):
# Serialize
tosave = {
"projectid": self.projectId,
"caffeVersion": self.caffeVersion,
"activeSession": self.getActiveSID(),
"inputdb": self.__inputManagerState,
"environment": self.settings,
......@@ -235,27 +243,19 @@ class Project(QObject, LogCaller):
""" Load settings with sensitive defaults.
"""
self.settings = {
'CAFFE_ROOT': "",
'plotter': {
'logFiles': {},
'checkBoxes': {}
}
}
def changeProjectCaffePath(self, path):
"""change the project caffe Path, and if a project settings jason exists, the path is saved in this"""
if not os.path.exists(self.projectConfigFileName()):
self.__loadDefaultSettings()
inputManagerState = dict()
transform = None
self.settings["CAFFE_ROOT"] = path
else:
with open(self.projectConfigFileName(), "r") as file:
res = json.load(file)
res["environment"] = self.settings
with open(self.projectConfigFileName(), "w") as file:
json.dump(res, file, sort_keys=True, indent=4)
def changeProjectCaffeVersion(self, version):
"""change the project caffe version"""
self.caffeVersion = version
def getCaffeVersion(self):
"""returns the caffe version name, saved in the project file or, if this isn't set, from Barista default settings"""
return self.caffeVersion
def deletePlotterSettings(self, logId):
"""deletes the plotter settings for a given log, e.g. if a session was reset"""
......@@ -284,16 +284,6 @@ class Project(QObject, LogCaller):
with open(self.projectConfigFileName(), "w") as file:
json.dump(res, file, sort_keys=True, indent=4)
def getCaffeRoot(self):
"""returns the caffe path, saved in the project file or, if this isn't set, from Barista default settings"""
path = None
if self.settings:
if "CAFFE_ROOT" in self.settings:
if self.settings["CAFFE_ROOT"] != "":
path = self.settings["CAFFE_ROOT"]
if not path:
path = path_loader.getCaffePath()
return path
def getCallerId(self):
""" Return the unique caller id for this project
......@@ -312,6 +302,10 @@ class Project(QObject, LogCaller):
"""
return self.projectRootDir
def getProjectId(self):
""" Return the ProjectId"""
return self.projectId
def buildSolverPrototxt(self):
""" Load the current solver dictionary and return the corresponding
message object.
......@@ -467,9 +461,26 @@ class Project(QObject, LogCaller):
def createRemoteSession(self, remote, state_dictionary=None):
"""use this only to create entirely new sessions. to load existing use the loadRemoteSession command"""
msg = {"key": Protocol.GETCAFFEVERSIONS}
reply = sendMsgToHost(remote[0], remote[1], msg)
if reply:
remoteVersions = reply["versions"]
if len(remoteVersions) <= 0:
msgBox = QMessageBox(QMessageBox.Warning, "Error", "Cannot create remote session on a host witout a caffe-version")
msgBox.addButton("Ok", QMessageBox.NoRole)
msgBox.exec_()
return None
sid = self.getNextSessionId()
msg = {"key": Protocol.CREATESESSION, "pid": self.projectId,
"sid": sid}
msg = {"key": Protocol.CREATESESSION, "pid": self.projectId, "sid": sid}
layers = []
for layer in state_dictionary["network"]["layers"]:
layers.append(state_dictionary["network"]["layers"][layer]["parameters"]["type"])
msg["layers"] = layers
ret = sendMsgToHost(remote[0], remote[1], msg)
if ret:
if ret["status"]:
......@@ -492,6 +503,24 @@ class Project(QObject, LogCaller):
def loadRemoteSession(self, remote, uid):
sid = self.getNextSessionId()
session = ClientSession(self, remote, uid, sid)
availableTypes = info.CaffeMetaInformation().availableLayerTypes()
unknownLayerTypes = []
for layerID in session.state_dictionary["network"]["layers"]:
type = session.state_dictionary["network"]["layers"][layerID]["parameters"]["type"]
if not type in availableTypes and not type in unknownLayerTypes:
unknownLayerTypes.append(type)
if len(unknownLayerTypes) > 0:
msg = "Cannot load session. The selected session contains layers unknown to the current caffe-version.\n\nUnknown layers:"
for type in unknownLayerTypes:
msg += "\n" + type
msgBox = QMessageBox(QMessageBox.Warning, "Warning", msg)
msgBox.addButton("Ok", QMessageBox.NoRole)
msgBox.exec_()
session._disconnect()
return None
self.__sessions[sid] = session
self.newSession.emit(sid)
return sid
......@@ -588,7 +617,6 @@ class Project(QObject, LogCaller):
self.__sessions[sessionID].setPretrainedWeights(newCaffemodel)
self.__sessions[sessionID].iteration = 0
self.__sessions[sessionID].max_iter = oldSession.max_iter
self.__sessions[sessionID].caffe_root = oldSession.caffe_root
except Exception as e:
Log.error('Failed to copy caffemodel to new session: '+str(e),
self.getCallerId())
......@@ -701,16 +729,6 @@ class Project(QObject, LogCaller):
else:
return 1
#def getCaffeRoot(self):
# """ Return the caffe installation directory.
# """
# if 'CAFFE_ROOT' in self.settings:
# print("bla")
# self.__loadDefaultSettings()
# else:
# print("bla2")
# return self.settings['CAFFE_ROOT']
def __ensureDirectory(self, directory):
""" Creates a directory if it does not exist.
"""
......@@ -733,4 +751,4 @@ class Project(QObject, LogCaller):
for line in f:
prefix_match = regex_prefix.search(line)
if prefix_match:
return prefix_match.group(1)
return prefix_match.group(1)
\ No newline at end of file
......@@ -279,9 +279,9 @@ class ClientSession(QObject):
def getState(self, local=True):
if local and not self.getStateFirstTime:
return self.state
return self.state
self.getStateFirstTime = False
if self._assertConnection():
if self._assertConnection():
msg = {"key": Protocol.SESSION, "subkey": SessionProtocol.GETSTATE}
self.transaction.send(msg)
ret = self.transaction.asyncRead(staging=True, attr=("subkey", SessionProtocol.GETSTATE))
......@@ -289,10 +289,11 @@ class ClientSession(QObject):
if ret["status"]:
self.setState(ret["state"], True)
return ret["state"]
else:
self._handleErrors(ret["error"])
self.setState(State.UNDEFINED, True)
return State.UNDEFINED
return State.UNDEFINED
self._handleErrors(["Failed to connect to remote session to acquire current state."])
self.setState(State.NOTCONNECTED, True)
return State.NOTCONNECTED
......@@ -350,12 +351,11 @@ class ClientSession(QObject):
self.transaction.send(msg)
ret = self.transaction.asyncRead(attr=("subkey", SessionProtocol.SETSTATEDICT))
if ret:
if ret["status"]:
self.stateDictChanged.emit(self, False)
return
self._handleErrors(ret["error"])
if not ret["status"]:
self._handleErrors(ret["error"])
else:
self._handleErrors(["SetStateDict: Failed to send StateDict"]) # TODO improve warnings
self._handleErrors(["SetStateDict: Failed to send StateDict"]) # TODO improve warnings
self.stateDictChanged.emit(self, False)
def _validateLayerOrder(self, dict):
......
This diff is collapsed.
......@@ -29,6 +29,8 @@ import backend.caffe.saver as saver
import backend.caffe.proto_info as info
import backend.caffe.dict_helper as helper
import backend.barista.caffe_versions as caffeVersions
class Session(QObject, LogCaller, ParserListener, SessionCommon):
""" A session is a caffe training process.
......@@ -44,7 +46,7 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
iterationChanged = pyqtSignal()
snapshotAdded = pyqtSignal(object)
def __init__(self, project, directory=None, sid=None, parse_old=False, caffe_root=None,
def __init__(self, project, directory=None, sid=None, parse_old=False, caffe_bin=None,
last_solverstate=None, last_caffemodel=None, state_dictionary=None):
super(Session, self).__init__()
self.caller_id = None
......@@ -70,8 +72,9 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
'provided sid is used.', self.getCallerId())
self.parse_old = parse_old
self.caffe_root = caffe_root # overrides project caffe_root if necessary, i.e. if deployed to another system
self.caffe_bin = caffe_bin # overrides project caffe_root if necessary, i.e. if deployed to another system
self.pretrainedWeights = None
self.last_solverstate = last_solverstate
self.last_caffemodel = last_caffemodel
self.state_dictionary = state_dictionary # state as saved from the network manager, such it can be restored
......@@ -170,9 +173,9 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
if os.path.exists(self.snapshot_dir) is False:
Log.error('Snapshot directory does not exists: '+self.snapshot_dir,
self.getCallerId())
if os.path.exists(self.project.getCaffeRoot()) is False:
Log.error('CAFFE_ROOT directory does not exists: ' +
self.project.getCaffeRoot(),
if os.file.exists(caffeVersions.getVersionByName(self.project.getCaffeVersion())) is False:
Log.error('Caffe binary does not exists: ' +
self.project.getCaffeVersion(),
self.getCallerId())
......@@ -237,6 +240,7 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
"""
if self.state == State.FAILED:
return self.state
if self.proc is not None:
if self.iteration == self.max_iter:
self.state = State.FINISHED
......@@ -299,17 +303,15 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
# (re-)write all session files
self.save(includeProtoTxt=True)
# check if the session has its own caffeRoot
caffeRoot = self.caffe_root
if not caffeRoot:
caffeBin = self.caffe_bin
if not caffeBin:
# else take the project's caffeRoot path
caffeRoot = self.project.getCaffeRoot()
caffeBin = caffeVersions.getVersionByName(self.project.getCaffeVersion()).getBinarypath()
try:
self.getParser().setLogging(True)
cmd = [
os.path.join(caffeRoot,
'build/tools/caffe'),
cmd = [caffeBin,
'train',
'-solver', self.getSolver()]
if solverstate:
......@@ -341,10 +343,10 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
# check if caffe root exists
Log.error('Failed to start session: '+str(e),
self.getCallerId())
if os.path.exists(self.project.getCaffeRoot()) is False:
Log.error('CAFFE_ROOT directory does not exists: ' +
caffeRoot +
'! Please set CAFFE_ROOT to run a session.',
if os.file.exists(caffeVersions.getVersionByName(self.project.getCaffeVersion()).getBinarypath()) is False:
Log.error('CAFFE_BINARY directory does not exists: ' +
caffe_bin +
'! Please set CAFFE_BINARY to run a session.',
self.getCallerId())
else:
Log.error(
......@@ -418,8 +420,8 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
self.rid += 1
try:
self.getParser().setLogging(True)
self.proc = Popen(
[os.path.join(self.project.getCaffeRoot(), 'build/tools/caffe'),
self.proc = Popen([
caffeVersions.getVersionByName(self.project.getCaffeVersion()).getBinarypath(),
'train',
'-solver', self.getSolver(),
'-snapshot', snapshot],
......@@ -444,11 +446,11 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
# check if caffe root exists
Log.error('Failed to continue session: '+str(e),
self.getCallerId())
if os.path.exists(self.project.getCaffeRoot()) is False:
Log.error('CAFFE_ROOT directory does not exists: ' +
self.project.getCaffeRoot() +
'! Please set CAFFE_ROOT to run a session.',
self.getCallerId())
if os.file.exists(caffeVersions.getVersionByName(self.project.getCaffeVersion()).getBinarypath()) is False:
Log.error('CAFFE_BINARY directory does not exists: ' +
caffe_bin +
'! Please set CAFFE_BINARY to run a session.',
self.getCallerId())
elif self.getState() in (State.FAILED, State.FINISHED):
Log.error('Could not continue a session in state ' +
str(self.getState()), self.getCallerId())
......@@ -1052,4 +1054,4 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
self.setLastModel(None)
self.setLastSnapshot(None)
self.project.resetSession.emit(self.getSessionId())
self.save()
self.save()
\ No newline at end of file
......@@ -88,7 +88,8 @@ returns a list containing the paths to all .h5 or .hdf5 files that are specified
exists
@author j_stru18
@param path : a string that is supposed to be a valid path for a HDF5TXT file
@ensure: if filePath points to a valid file, it has to a textfile
@param filePath : a string that is supposed to be a valid path for a HDF5TXT file
@return a list that contains all .h5 and .hdf5 files in the HDF5TXT
"""
def getHdf5List(filePath):
......@@ -105,6 +106,23 @@ def getHdf5List(filePath):
Log.error("HDF5 textfile contained invalid path {}".format(line), DBUTIL_LOGGER_ID)
return pathList
"""
returns a list containing all lines of the file that filePath points at
@author: j_stru18
@param filePath : a string that is supposed to be a path for a txt file
@return: a list that contains all lines of that file
"""
def getLinesAsList(filePath):
pathList = []
if os.path.exists(filePath):
pathList = [line for line in open(filePath)]
return pathList
"""
Given the name of a file, this method returns its type
......
......@@ -4,7 +4,7 @@ from os import path
import logging
from backend.networking.protocol import Protocol
def checkHardware(cafferoot, silent=False, transaction=None):
def checkHardware(binary, silent=False, transaction=None):
"""
probe caffe continuously for incrementing until missing id
structure:
......@@ -35,7 +35,7 @@ def checkHardware(cafferoot, silent=False, transaction=None):
msg = {"key": Protocol.SCANHARDWARE, "finished": False, "name": name}
transaction.send(msg)
while True:
log = _getId(gid, cafferoot)
log = _getId(gid, binary)
if not _isValid(log) or _isCpuOnly(log):
if not silent and gid is 0:
stdout.write("No GPU found, CPU mode\n")
......@@ -52,10 +52,9 @@ def checkHardware(cafferoot, silent=False, transaction=None):
return hw
def _getId(gid, cafferoot):
def _getId(gid, binary):
"""probe caffe for gpu id"""
caffe = path.join(cafferoot, "build/tools/caffe")
proc = Popen([caffe, "device_query", "-gpu", str(gid)], stdout=PIPE, stderr=STDOUT)
proc = Popen([binary, "device_query", "-gpu", str(gid)], stdout=PIPE, stderr=STDOUT)
it = iter(proc.stdout.readline, '')
log = []
for i in it:
......
......@@ -99,8 +99,8 @@ def _areParameterValid(allparams, paramdict, prefix=""):
def bareNet(name):
""" Creates a dictionary of a networks with default values where required. """
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
net = proto.NetParameter()
descr = info.ParameterGroupDescriptor(net)
params = descr.parameter().copy()
......@@ -116,8 +116,8 @@ def _bareLayer(layertype, name):
""" Creates a dictionary of the given layertype with the given name
initialized with default values if required.
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
res = {"type": layertype}
layerParamInst = proto.LayerParameter()
res["parameters"] = _extract_param(layerParamInst, info.CaffeMetaInformation().commonParameters())
......
......@@ -17,8 +17,8 @@ class ParseException(Exception):
def loadSolver(solverstring):
""" Return a dictionary which represent the caffe-solver-prototxt solverstring """
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
solver = proto.SolverParameter()
# Get DESCRIPTION for meta infos
......@@ -66,8 +66,8 @@ def loadNet(netstring):
}
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
# Load Protoclass for parsing
net = proto.NetParameter()
......@@ -172,8 +172,8 @@ def extractNetFromSolver(solverstring):
This works only, if the solver specifies a network using the "net_param" parameter. A reference to a file using the
"net" parameter can not be handled by this method.
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader