Commit 9256d122 authored by Johannes Stricker

Merge branch 'develop' into feature/455-sessions

parents 12b6b671 c4edb4ec
......@@ -16,7 +16,14 @@ Barista is an open-source graphical high-level interface for the Caffe deep lear
**Using Barista**
* To start Barista, run the file `main.py`
* A user manual as well as a starter's tutorial on Barista can be found in the wiki of the project repository: https://zivgitlab.uni-muenster.de/pria/Barista/wikis/home
If you have any questions regarding Barista, please feel free to contact the developers at barista(at)uni-muenster.de
**Citing Barista**
If you use Barista in your research, please cite it using the following reference:
Klemm, S. and Scherzinger, A. and Drees, D. and Jiang, X. Barista - a Graphical Tool for Designing and Training Deep Neural Networks. arXiv preprint. arXiv:1802.04626. http://arxiv.org/abs/1802.04626
......@@ -128,35 +128,49 @@ class MinimumTrainingRequirements:
def _checkUniqueBlobNames(self):
    """Checks for duplicate top blob names in all layers and emits an error message if duplicates
    are found, except if in-place operation is permitted.

    NOTE(review): this span contained the pre-merge and post-merge implementations interleaved
    (merge-diff residue); only the post-merge version is kept here.
    """
    # find all 'real' sources of blobs, i.e. layers that produce a blob as output and are not in-place
    blobGenerators = {}  # dictionary that saves the sources for every blob name
    if hasattr(self._stateData, '__getitem__'):
        for layer_id, layer in self._stateData["network"]["layers"].iteritems():
            if LayerHelper.isLayerIncludedInTrainingPhase(layer) and ("top" in layer["parameters"]):
                parameters = layer.get("parameters", {})
                tops = parameters.get("top", [])
                bottoms = parameters.get("bottom", [])
                # if at least one top blob is also a bottom, check if the layer allows in-place
                sourced = [blob for blob in tops if blob not in bottoms]
                inPlace = [blob for blob in tops if blob in bottoms]
                if len(inPlace) > 0 and not layer["type"].allowsInPlace():
                    for name in inPlace:
                        self._errorMessages.append((
                            "{} is reproduced by {}".format(name, parameters.get("name", "[NO NAME]")),
                            "{} does not support in-place operation".format(parameters.get("name",
                                                                                           "[NO NAME]"))
                        ))
                # check all blobs that are generated in one layer if they are generated only once in each phase.
                for name in sourced:
                    phase = []  # this list can hold train, test and both
                    p = LayerHelper.getLayerPhase(layer)
                    if p == "":
                        phase = [LayerHelper.PHASE_TEST, LayerHelper.PHASE_TRAIN]
                    else:
                        phase.append(p)
                    # if a blob already exists, check if it was generated before in the same Phase
                    if name in blobGenerators:
                        found_match = False
                        for candidate in blobGenerators[name]:
                            intersection = set(phase).intersection(candidate[1])
                            if len(intersection) > 0:
                                found_match = True
                                self._errorMessages.append(("Sources are {} and {} in phase {}".format(
                                    parameters.get("name", "[NO NAME]"),
                                    candidate[0],
                                    list(intersection)[0]),
                                    "{} is generated by multiple layers".format(name)))
                        if not found_match:
                            blobGenerators[name].append((parameters.get("name", "[NO NAME]"), phase))
                    else:
                        blobGenerators[name] = [(parameters.get("name", "[NO NAME]"), phase)]
def _checkDataLayerExistence(self):
"""Check whether at least one data layer exists (during the training phase)."""
......
......@@ -496,82 +496,111 @@ class Project(QObject, LogCaller):
self.newSession.emit(sid)
return sid
def cloneRemoteSession(self, oldSolverstate, oldSession):
    """
    Starts the cloning process for a remote session and creates the corr. local session upon success

    oldSolverstate: solverstate produced by the snapshot from which the clone should be created
    oldSession: session from which a clone should be created (type ClientSession)

    NOTE(review): this span contained the pre-merge cloneSession interleaved with this method
    (merge-diff residue); only the post-merge cloneRemoteSession is kept here.
    """
    # validate the given session and solverstate
    if oldSolverstate is None:
        Log.error('Could not find solver',
                  self.getCallerId())
        return None
    if oldSession is None:
        Log.error('Failed to create session!', self.getCallerId())
        return None
    sid = self.getNextSessionId()
    # call the remote host to invoke cloning; @see cloneSession in server_session_manager.py
    msg = {"key": Protocol.CLONESESSION, "pid": self.projectId,
           "sid": sid, "old_uid": oldSession.uid, "old_solverstate": oldSolverstate}
    ret = sendMsgToHost(oldSession.remote[0], oldSession.remote[1], msg)
    # receive and validate answer
    if ret:
        if ret["status"]:
            uid = ret["uid"]
        else:
            for e in ret["error"]:
                Log.error(e, self.getCallerId())
            return None
    else:
        Log.error('Failed to clone remote session! No connection to Host', self.getCallerId())
        return None
    # Create a corr. local session and copy (if available) the state-dictionary to maintain
    # solver/net etc.
    session = ClientSession(self, oldSession.remote, uid, sid)
    if hasattr(oldSession, 'state_dictionary'):
        session.state_dictionary = oldSession.state_dictionary
    self.__sessions[sid] = session
    self.newSession.emit(sid)
    return sid
def cloneSession(self, oldSolverstate, oldSession):
"""
Creates a new session with the same net/solver etc. and a additional .caffemodel file with pretrained weights
oldSolverstate: solverstate produced by the snapshot from which the clone should be created
oldSession: session from which a clone should be created (type ClientSession)
"""
if type(oldSolverstate) is not str and type(oldSolverstate) is not unicode:
# if no valid solverstate is given, take the last model created for this session
oldCaffemodel = oldSession.getLastModel()
else:
snapshotDir = os.path.join(oldSession.getSnapshotDirectory(), oldSolverstate)
oldCaffemodel = loader.getCaffemodelFromSolverstate(snapshotDir)
if oldCaffemodel is None:
Log.error('Could not find model',
self.getCallerId())
return None
if oldSession:
oldSnapshotDir = oldSession.getSnapshotDirectory()
if os.path.isabs(oldSnapshotDir):
# locate caffemodel
oldModelPath = os.path.join(oldSnapshotDir, oldCaffemodel)
else:
old_sdir = old_session.getDirectory()
old_lss = os.path.join(old_sdir, snapshot_dir, solverstate)
oldSessionDir = oldSession.getDirectory()
oldModelPath = os.path.join(oldSessionDir, oldSnapshotDir, oldCaffemodel)
else:
old_lss = solverstate
Log.error('Failed to create session!', self.getCallerId())
return None
# create new session
sessionID = self.createSession()
self.__sessions[sessionID].setParserInitialized()
if sessionID is None:
Log.error('Failed to create session!', self.getCallerId())
return None
new_sdir = self.__sessions[sessionID].getSnapshotDirectory()
self.__sessions[sessionID].setParserInitialized()
# create directories for the new session
newSnapshotDir = self.__sessions[sessionID].getSnapshotDirectory()
self.__ensureDirectory(self.__sessions[sessionID].getDirectory())
self.__ensureDirectory(new_sdir)
if os.path.isdir(new_sdir):
# copy solverstate
new_lss = os.path.join(new_sdir, os.path.basename(solverstate))
try:
shutil.copy2(old_lss, new_lss)
self.__sessions[sessionID].setLastSnapshot(new_lss)
except Exception as e:
Log.error('Failed to copy solverstate to new session: '+str(e),
self.getCallerId())
# copy caffemodel
model_file = loader.getCaffemodelFromSolverstate(old_lss)
iterations = loader.getIterFromSolverstate(old_lss)
if model_file is None:
model_file = loader.getCaffemodelFromSolverstateHdf5(old_lss)
if model_file is None:
Log.error('Could not load model from solverstate '+old_lss,
self.getCallerId())
self.__sessions[sessionID].delete()
return None
model_dir = os.path.dirname(old_lss)
old_lcm = os.path.join(model_dir, model_file)
new_caffemodel = os.path.basename(model_file)
new_lcm = os.path.join(new_sdir, new_caffemodel)
self.__ensureDirectory(newSnapshotDir)
if os.path.isdir(newSnapshotDir):
newCaffemodel = 'pretrained.caffemodel'
newModelPath = os.path.join(newSnapshotDir, newCaffemodel)
try:
shutil.copy2(old_lcm, new_lcm)
self.__sessions[sessionID].setLastModel(new_lcm)
self.__sessions[sessionID].iteration = iterations
if old_session is not None:
self.__sessions[sessionID].max_iter = old_session.max_iter
self.__sessions[sessionID].caffe_root = old_session.caffe_root
# copy the old caffemodel to the new location
shutil.copy2(oldModelPath, newModelPath)
# initialize new session
self.__sessions[sessionID].setPretrainedWeights(newCaffemodel)
self.__sessions[sessionID].iteration = 0
self.__sessions[sessionID].max_iter = oldSession.max_iter
self.__sessions[sessionID].caffe_root = oldSession.caffe_root
except Exception as e:
Log.error('Failed to copy caffemodel to new session: '+str(e),
self.getCallerId())
# copy state dictionary
if old_session is not None:
#TODO: This will cause troubles if the current session was changed during training (See #270)
self.__sessions[sessionID].setStateDict(old_session.state_dictionary)
if self.__sessions[sessionID].iteration > 0:
self.__sessions[sessionID].setState(State.PAUSED)
else:
self.__sessions[sessionID].setState(State.WAITING)
# copy the old state-dict into the new session
self.__sessions[sessionID].setStateDict(oldSession.state_dictionary)
self.__sessions[sessionID].setState(State.WAITING)
self.__sessions[sessionID].save(includeProtoTxt=True)
return sessionID
else:
self.__sessions[sessionID].delete()
Log.error('Snapshot directory '+new_sdir+' does not exist!',
Log.error('Snapshot directory '+newSnapshotDir+' does not exist!',
self.getCallerId())
return None
......@@ -612,6 +641,13 @@ class Project(QObject, LogCaller):
self.__sessions.pop(sid)
del session
def closeSession(self, remoteSession):
    """Close a remote session and remove it from this project's session table."""
    sessionId = remoteSession.getSessionId()
    remoteSession.close()
    del self.__sessions[sessionId]
def isSession(self, directory):
""" Checks if the directory contains a valid session."""
if not os.path.exists(directory):
......
......@@ -211,8 +211,10 @@ class ClientSession(QObject):
return 0
def getMaxIteration(self):
# maybe let this be local + signal
if self._assertConnection():
if self.lastMaxIter > 0 \
and self.getState() in [State.RUNNING, State.FAILED, State.FINISHED, State.NOTCONNECTED]:
return self.lastMaxIter
elif self._assertConnection():
msg = {"key": Protocol.SESSION, "subkey": SessionProtocol.GETMAXITERATION}
self.transaction.send(msg)
ret = self.transaction.asyncRead(staging=True, attr=("subkey", SessionProtocol.GETMAXITERATION))
......@@ -220,7 +222,17 @@ class ClientSession(QObject):
self.lastMaxIter = ret["iteration"]
return self.lastMaxIter
self._handleErrors(["Failed to connect to remote session to acquire max iteration."])
return 0
return 1
def getPretrainedWeights(self):
    """Ask the remote host for the name of this session's pre-trained weights file.

    Returns the value reported by the host, or None if no connection could be
    established or the host did not answer.
    """
    if self._assertConnection():
        request = {"key": Protocol.SESSION, "subkey": SessionProtocol.GETPRETRAINED}
        self.transaction.send(request)
        answer = self.transaction.asyncRead(staging=True, attr=("subkey", SessionProtocol.GETPRETRAINED))
        if answer:
            return answer.get("pretrained", None)
    self._handleErrors(["Failed to connect to remote session to acquire pre-trained weights."])
    return None
def setMaxIteration(self, maxIteration):
if maxIteration <= 0:
......@@ -245,6 +257,11 @@ class ClientSession(QObject):
self._handleErrors(["Failed to connect to remote session to acquire current state."])
self.setState(State.NOTCONNECTED, True)
def close(self):
    """Notify the remote host that this session disconnects (best effort, no reply expected)."""
    if not self._assertConnection():
        return
    self.transaction.send({"key": Protocol.DISCONNECTSESSION, "uid": self.uid})
def reset(self):
if self._assertConnection():
msg = {"key": Protocol.SESSION, "subkey": SessionProtocol.RESET}
......@@ -309,6 +326,7 @@ class ClientSession(QObject):
self._handleErrors(["Could not save session."])
return False
def setStateDict(self, stateDict):
self.lastStateDict = stateDict
......
......@@ -54,6 +54,7 @@ class ServerSession(QObject, ParserListener, SessionCommon):
SessionProtocol.GETMAXITERATION: self._msgGetMaxIteration,
SessionProtocol.SETMAXITERATION: self._msgSetMaxIteration,
SessionProtocol.GETSTATE: self._msgGetState,
SessionProtocol.GETPRETRAINED: self._msgGetPretrainedWeights,
SessionProtocol.SETSTATEDICT: self._msgSetStateDict,
SessionProtocol.GETSTATEDICT: self._msgGetStateDict,
SessionProtocol.SAVE: self._msgSave,
......@@ -91,6 +92,7 @@ class ServerSession(QObject, ParserListener, SessionCommon):
self.state = State.WAITING
self.invalidErrorsList = []
self.last_solverstate = None
self.pretrainedWeights = None
self.logs = os.path.join(directory, 'logs')
# run stuff
......@@ -345,12 +347,26 @@ class ServerSession(QObject, ParserListener, SessionCommon):
msg["status"] = True
self.transaction.send(msg)
def _msgGetPretrainedWeights(self):
msg = self.transaction.asyncRead()
msg["pretrained"] = self.getPretrainedWeights()
self.transaction.send(msg)
def getLastModel(self):
    """Return the most recent .caffemodel snapshot of this session (may be None)."""
    return self.last_caffemodel
def getIteration(self):
    """Return the current training iteration of this session."""
    return self.iteration
def getMaxIteration(self):
    """Return the configured maximum number of training iterations."""
    return self.max_iter
def getPretrainedWeights(self):
    """Return the filename of the pre-trained weights this session was cloned with (may be None)."""
    return self.pretrainedWeights
def setPretrainedWeights(self, weights):
    """Remember the filename of the pre-trained weights for this session."""
    self.pretrainedWeights = weights
def _getState(self):
""" Return the state of the session.
"""
......@@ -375,7 +391,13 @@ class ServerSession(QObject, ParserListener, SessionCommon):
def _msgSetStateDict(self):
msg = self.transaction.asyncRead()
self.state_dictionary = msg["statedict"]
self.setStateDict(msg["statedict"])
del msg["statedict"]
msg["status"] = True
self.transaction.send(msg)
def setStateDict(self, statedict):
self.state_dictionary = statedict
self._parseSetting(self.state_dictionary)
# restore lost types
if hasattr(self.state_dictionary, '__getitem__'):
......@@ -388,10 +410,6 @@ class ServerSession(QObject, ParserListener, SessionCommon):
typename = layers[id]["parameters"]["type"]
layers[id]["type"] = info.CaffeMetaInformation().availableLayerTypes()[typename]
del msg["statedict"]
msg["status"] = True
self.transaction.send(msg)
def _msgGetStateDict(self):
msg = self.transaction.asyncRead()
......@@ -418,6 +436,8 @@ class ServerSession(QObject, ParserListener, SessionCommon):
toSave["ProjectID"] = self.pid
if self.last_solverstate:
toSave["LastSnapshot"] = self.last_solverstate
if self.pretrainedWeights:
toSave["PretrainedWeights"] = self.pretrainedWeights
if self.state_dictionary:
serializedDict = copy.deepcopy(self.state_dictionary)
if "network" in serializedDict:
......@@ -546,6 +566,8 @@ class ServerSession(QObject, ParserListener, SessionCommon):
if "LastSnapshot" in settings:
self.last_solverstate = settings["LastSnapshot"]
if "PretrainedWeights" in settings:
self.pretrainedWeights = settings["PretrainedWeights"]
if "NetworkState" in settings:
self.state_dictionary = settings["NetworkState"]
layers = self.state_dictionary["network"]["layers"]
......@@ -1063,10 +1085,11 @@ class ServerSession(QObject, ParserListener, SessionCommon):
self.logsig.emit('Failed to delete logs folder: ' + str(e), self.getCallerId(), True)
for filename in filenames:
if filename.endswith(".solverstate") or filename.endswith(".caffemodel"):
try:
os.remove(os.path.join(dirpath, filename))
except OSError as e:
self.logsig.emit('Failed to delete ' + str(filename) + ': ' + str(e), self.getCallerId(), True)
if not filename == self.getPretrainedWeights():
try:
os.remove(os.path.join(dirpath, filename))
except OSError as e:
self.logsig.emit('Failed to delete ' + str(filename) + ': ' + str(e), self.getCallerId(), True)
if filename in ["net-internal.prototxt", "net-original.prototxt", "solver.prototxt"]:
try:
os.remove(os.path.join(dirpath, filename))
......
......@@ -71,6 +71,7 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
self.parse_old = parse_old
self.caffe_root = caffe_root # overrides project caffe_root if necessary, i.e. if deployed to another system
self.pretrainedWeights = None
self.last_solverstate = last_solverstate
self.last_caffemodel = last_caffemodel
self.state_dictionary = state_dictionary # state as saved from the network manager, such it can be restored
......@@ -546,6 +547,12 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
def setLastModel(self, lcm):
    """Remember the most recent .caffemodel snapshot of this session."""
    self.last_caffemodel = lcm
def getPretrainedWeights(self):
    """Return the filename of the pre-trained weights of this session (may be None)."""
    return self.pretrainedWeights
def setPretrainedWeights(self, weights):
    """Remember the filename of the pre-trained weights for this session."""
    self.pretrainedWeights = weights
def setLastSnapshot(self, lss):
    """Remember the most recent .solverstate snapshot of this session."""
    self.last_solverstate = lss
......@@ -790,6 +797,8 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
Log.log("Saving current Session status to disk.", self.getCallerId())
if self.last_solverstate:
toSave["LastSnapshot"] = self.last_solverstate
if self.getPretrainedWeights():
toSave["PretrainedWeights"] = self.getPretrainedWeights()
if self.state_dictionary:
serializedDict = copy.deepcopy(self.state_dictionary)
if includeProtoTxt:
......@@ -920,6 +929,9 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
if "LastSnapshot" in settings:
self.last_solverstate = settings["LastSnapshot"]
if "PretrainedWeights" in settings:
self.setPretrainedWeights(settings["PretrainedWeights"])
if "NetworkState" in settings:
self.state_dictionary = settings["NetworkState"]
layers = self.state_dictionary["network"]["layers"]
......@@ -927,7 +939,8 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
if "parameters" in layers[id]:
if "type" in layers[id]["parameters"]:
typename = layers[id]["parameters"]["type"]
layers[id]["type"] = info.CaffeMetaInformation().availableLayerTypes()[typename]
layers[id]["type"] = info.CaffeMetaInformation().getLayerType(typename)
# layers[id]["type"] = info.CaffeMetaInformation().availableLayerTypes()[typename]
solver = self.state_dictionary["solver"]
if solver:
if "snapshot_prefix" in solver:
......@@ -1021,10 +1034,11 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
Log.error('Failed to delete logs folder: ' + str(e), self.getCallerId())
for filename in filenames:
if filename.endswith(".solverstate") or filename.endswith(".caffemodel"):
try:
os.remove(os.path.join(dirpath, filename))
except OSError as e:
Log.error('Failed to delete ' + str(filename) + ': ' + str(e), self.getCallerId())
if not filename == self.getPretrainedWeights():
try:
os.remove(os.path.join(dirpath, filename))
except OSError as e:
Log.error('Failed to delete ' + str(filename) + ': ' + str(e), self.getCallerId())
if filename in ["net-internal.prototxt", "net-original.prototxt", "solver.prototxt"]:
try:
os.remove(os.path.join(dirpath, filename))
......
......@@ -45,6 +45,13 @@ def getMultipleHash(dbPaths):
hashValue += getHash(path)
return hashValue
"""
returns the sha256 hash value of a single string
@author j_stru18
@param inputStr : the string that has to be hashed
@return : the hashValue of the string
"""
def getStringHash(inputStr):
hashObject = hashlib.sha256()
hashObject.update(inputStr)
......@@ -81,7 +88,8 @@ returns a list containing the paths to all .h5 or .hdf5 files that are specified
exists
@author j_stru18
@param filePath : a string that is supposed to be a valid path for a HDF5TXT file
@ensure: if filePath points to a valid file, it has to be a textfile
@return a list that contains all .h5 and .hdf5 files in the HDF5TXT
"""
def getHdf5List(filePath):
......@@ -98,6 +106,23 @@ def getHdf5List(filePath):
Log.error("HDF5 textfile contained invalid path {}".format(line), DBUTIL_LOGGER_ID)
return pathList
"""
returns a list containing all lines of the file that filePath points at
@author: j_stru18
@param filePath : a string that is supposed to be a path for a txt file
@return: a list that contains all lines of that file
"""
def getLinesAsList(filePath):
    # Use a context manager so the file handle is closed deterministically
    # (the original left the handle to the garbage collector).
    pathList = []
    if os.path.exists(filePath):
        with open(filePath) as txtFile:
            pathList = list(txtFile)
    return pathList
"""
Given the name of a file, this method returns its type
......
......@@ -98,7 +98,6 @@ def loadNet(netstring):
res["layers"], res["layerOrder"] = _load_layers(net.layer)
res = copy.deepcopy(res)
return res
def _load_layers(layerlist):
......@@ -113,9 +112,7 @@ def _load_layers(layerlist):
for layer in layerlist:
typename = layer.type
if not allLayers.has_key(typename):
raise ParseException("Layer with type {} not available in caffe".format(typename))
layerinfo = allLayers[typename]
layerinfo = info.CaffeMetaInformation().getLayerType(typename)
id = str(uuid.uuid4())
res[id]={
"type": layerinfo,
......@@ -127,7 +124,7 @@ def _load_layers(layerlist):
if "name" not in res[id]["parameters"]:
typeName = res[id]["parameters"]["type"]
newName = typeName + " #" + str(dicLayerTypeCounter[typeName])
changed = True;
changed = True
while changed == True:
newName = typeName + " #" + str(dicLayerTypeCounter[typeName])
changed = False
......
......@@ -27,6 +27,15 @@ class Singleton(type):
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class UnknownLayerTypeException(Exception):
    """Raised when a network references a layer type that caffe does not provide."""

    def __init__(self, msg):
        # Forward to Exception so args, repr and pickling behave like a standard
        # exception (the original skipped the base-class initializer).
        super(UnknownLayerTypeException, self).__init__(msg)
        self._msg = msg

    def __str__(self):
        return self._msg
class CaffeMetaInformation:
"""This class is the main interface for any usage outside of this module.
......@@ -158,6 +167,13 @@ class CaffeMetaInformation:
"""
return self._availableLayerTypes
def getLayerType(self, typename):
    """Look up the LayerType registered under typename.

    Raises UnknownLayerTypeException if no layer type with that name is known.
    """
    knownTypes = self._availableLayerTypes
    if typename in knownTypes:
        return knownTypes[typename]
    raise UnknownLayerTypeException("Network contains an unknown layer type '{}'.".format(typename))
def commonParameters(self):
""" Return parameters all layers share.
It is missing specific paremeters like "inner_product_param"
......@@ -256,12 +272,16 @@ class LayerType(TopLevelEntityType):
def isActivationLayer(self):
    """True iff this layer type belongs to the activation category."""
    return self.CATEGORY_ACTIVATION == self._category
def allowsInPlace(self):
    """Whether this layer type may operate in-place (top blob overwrites bottom blob).

    Dropout is the only non-activation layer known to support this.
    """
    return (self._name in ('Dropout',)) or self.isActivationLayer()
@staticmethod
def getCategoryByName(layerTypeName):
"""Given its name, determine which category a LayerType does belong to."""
# !!! hard coded list of activation layers could cause problems when more layers are added in caffe
if layerTypeName in ["ReLU", "PReLU", "ELU", "Sigmoid", "TanH", "AbsVal", "Power",
"Exp", "Log", "BNLL", "Threshold", "Bias", "Scale" ]:
"Exp", "Log", "BNLL", "Threshold", "Bias", "Scale"]:
return LayerType.CATEGORY_ACTIVATION
elif "Loss" in layerTypeName:
return LayerType.CATEGORY_LOSS
......
......@@ -7,10 +7,11 @@ from backend.barista.utils.logger import Log
class Hdf5Input:
def __init__(self, pathOfHdf5Txt=False):
    """Create an unopened HDF5 input.

    pathOfHdf5Txt: True when the path originates from a HDF5TXT listing; such
    listings can contain commentary lines that are no paths, so a missing file
    is then not reported as an error.

    NOTE(review): this span contained both the pre-merge and post-merge
    constructor lines (merge-diff residue); only the post-merge version is kept.
    """
    self._db = None
    self._path = None
    self._pathOfHdf5Txt = pathOfHdf5Txt  # HDF5TXT files can contain commentary lines that are no paths
    self._logid = Log.getCallerId('HDF5 Input')
def __del__(self):
    # Release the underlying HDF5 handle when the object is garbage collected.
    self.close()
......@@ -19,6 +20,10 @@ class Hdf5Input:
'''set the path of the database.'''
self._path = path
def getPath(self):
    """Return the database path configured via setPath (None if never set)."""
    return self._path
def open(self):
'''open the database from the set path.'''
if self._db:
......@@ -28,10 +33,10 @@ class Hdf5Input:
try:
self._db = h5.File(self._path, 'r')
except:
Log.error("File not valid HDF5: " + self._path, self.logid)
Log.error("File not valid HDF5: " + self._path, self._logid)
self._db = None
else:
Log.error("File does not exist: " + self._path, self.logid)
elif not self._pathOfHdf5Txt:
Log.error("File does not exist: " + self._path, self._logid)
def close(self):
if self._db:
......
......@@ -8,7 +8,7 @@ class Hdf5TxtInput:
def __init__(self):
    """Create an unopened HDF5TXT input (a text file listing HDF5 databases).

    NOTE(review): this span contained both the pre-merge self.logid and the
    post-merge self._logid line (merge-diff residue); only the post-merge
    version is kept.
    """
    self._db = []
    self._path = None
    self._logid = Log.getCallerId("HDF5TXT Input")
    self._projectPath = None
def __del__(self):
......@@ -23,20 +23,21 @@ class Hdf5TxtInput:
if self._path is not None:
if os.path.exists(self._path):
lines = [line.rstrip('\n') for line in open(self._path)]
hdf5Count = 0
for line in lines:
if line is not "":
if line[:1] == '.':
line = self._makepath(line)
i = len(self._db)
self._db.append(Hdf5Input())
self._db.append(Hdf5Input(pathOfHdf5Txt=True))
self._db[i].setPath(line)
self._db[i].open()
if not self._db[i].isOpen():
Log.error("File contains invalid HDF5: " + self._path, self.logid)
self._db = None
return
else:
Log.error("File does not exist: " + self._path, self.logid)
if self._db[i].isOpen():
hdf5Count += 1