Commit 18bd646a authored by Aaron Scherzinger's avatar Aaron Scherzinger

Merge remote-tracking branch 'origin/develop'

parents 716c2409 480e7712
......@@ -16,4 +16,5 @@
*.xml
*.py.orig
/barista.conf
/baristalog.txt
\ No newline at end of file
/baristalog.txt
/.vscode/
Barista is an open-source graphical high-level interface for the Caffe deep learning framework written in Python.
**Overview**
* Barista is distributed under the MIT license (see LICENSE file).
* For more information, please refer to https://www.uni-muenster.de/PRIA/Barista/
* Please note that Barista is currently mainly developed on Linux and that Windows and Mac OS support is still experimental.
**Requirements**
* Barista uses Python 2.7
* In order to use Barista, please install the Python requirements in the requirements.txt file, e.g. using pip:
`pip install --user -r requirements.txt`
* Barista also requires a running version of Caffe. For more information on installing Caffe and obtaining the newest version, please refer to https://github.com/BVLC/caffe and http://caffe.berkeleyvision.org/
**Using Barista**
* To start Barista, run the file `main.py`
* A user manual as well as a starter's tutorial on Barista can be found in the wiki of the project repository: https://zivgitlab.uni-muenster.de/pria/Barista
......@@ -13,9 +13,20 @@ def ensureProjectDataConstraints(projectData):
This method should be called each time when the project data has been changed. It will adjust some values, which
would be valid for the prototxt-syntax in general, but underlay further constraints especially for Barista.
"""
projectData["solver"] = ensureSolverConstraints(projectData["solver"])
if not hasattr(projectData, '__getItem__'):
Log.error('Project data is empty.', logID)
return projectData
jsonKeys = ("activeSession", "environment", "inputdb", "projectid", "transform")
keyNotFound = False
for key in jsonKeys:
if key not in projectData:
Log.error("Project information is invalid! Key '"
+ key + "' is missing!", logID)
keyNotFound = True
if keyNotFound:
Log.error("Session at %s is invalid. Key %s in 'sessionstate.json' is missing", logID)
# add further constraint types like e.g.:
#projectData["network"] = ensureNetworkConstraints(projectData["network"])
return projectData
This diff is collapsed.
......@@ -39,6 +39,7 @@ class ClientSession(QObject):
self.uid = uid
self.lastStateDict = None
self.state = State.NOTCONNECTED
self.invalidErrorsList = []
self._previousState = State.WAITING
# setup timer
# see details in setStateDict()
......@@ -222,6 +223,14 @@ class ClientSession(QObject):
self._handleErrors(["Failed to connect to remote session to acquire max iteration."])
return 0
def delete(self):
    """Request deletion of the remote session, then mark this client as disconnected.

    If no connection can be established, the failure is reported via
    _handleErrors instead. In either case the local state is set to
    NOTCONNECTED (silently) afterwards.
    """
    if not self._assertConnection():
        self._handleErrors(["Failed to connect to remote session to acquire current state."])
    else:
        deleteMsg = {"key": Protocol.SESSION, "subkey": SessionProtocol.DELETE}
        self.transaction.send(deleteMsg)
    # this client no longer tracks the (now deleted) remote session
    self.setState(State.NOTCONNECTED, True)
def getState(self, local=True):
if local and not self.getStateFirstTime:
return self.state
......@@ -249,10 +258,16 @@ class ClientSession(QObject):
if not silent:
self.stateChanged.emit(self.state)
def setErrorList(self, errorList):
    """Store the given list of session-validation errors for later retrieval via getErrorList()."""
    self.invalidErrorsList = errorList
def _updateState(self):
    """Read the pending state-update message from the transaction and apply the new state."""
    stateMsg = self.transaction.asyncRead(attr=("subkey", SessionProtocol.UPDATESTATE))
    self.setState(stateMsg["state"])
def getErrorList(self):
    """Return the list of session-validation errors previously set via setErrorList()."""
    return self.invalidErrorsList
def getValidatedState(self):
# TODO the getValidPreviousState function seams broken across all sessions and makes no sense the way it is
# TODO implemented for the normal sessions
......@@ -333,13 +348,15 @@ class ClientSession(QObject):
if ret:
if ret["status"]:
self.lastStateDict = ret["statedict"]
layers = self.lastStateDict["network"]["layers"]
for id in layers:
if "parameters" in layers[id]:
if "type" in layers[id]["parameters"]:
typename = layers[id]["parameters"]["type"]
layers[id]["type"] = info.CaffeMetaInformation().availableLayerTypes()[typename]
try:
layers = self.lastStateDict["network"]["layers"]
for id in layers:
if "parameters" in layers[id]:
if "type" in layers[id]["parameters"]:
typename = layers[id]["parameters"]["type"]
layers[id]["type"] = info.CaffeMetaInformation().availableLayerTypes()[typename]
except KeyError:
pass
return self.lastStateDict
else:
self._handleErrors(ret["error"])
......
......@@ -6,9 +6,11 @@ import signal
import sys
import time
import uuid
import shutil
from collections import OrderedDict
from subprocess import Popen, PIPE, STDOUT
from threading import Lock
from PyQt5.Qt import QObject
from PyQt5.QtCore import Qt, pyqtSignal
......@@ -59,7 +61,8 @@ class ServerSession(QObject, ParserListener, SessionCommon):
SessionProtocol.TAKESNAPSHOT: self._takeSnapshot,
SessionProtocol.FETCHPARSERDATA: self._fetchParserData,
SessionProtocol.LOADINTERNALNET: self._loadInternalNet,
SessionProtocol.LOADNETPARAMETER: self._loadNetParameter}
SessionProtocol.LOADNETPARAMETER: self._loadNetParameter,
SessionProtocol.DELETE: self._delete}
self.logsig.connect(self.Log, Qt.AutoConnection)
self.parssig.connect(self.addParserRow, Qt.AutoConnection)
......@@ -276,6 +279,16 @@ class ServerSession(QObject, ParserListener, SessionCommon):
msg["check"] = req.getErrors()
self.transaction.send(msg)
def _delete(self):
    """Remove this session's directory from disk and disconnect.

    The session is paused first so that no running training keeps files
    in the directory open while it is being removed.
    """
    self.pause()  # make sure the session is not running
    try:
        shutil.rmtree(self.getDirectory())
    except OSError as e:
        # python docs say, that rmtree should raise an OSError code 66 if
        # dir is not empty. However, this does not seem to happen.
        sys.stderr.write(str(e)+'\n')
    self.disconnect()
def _getSnapshots(self):
""" Return all snapshot files, keyed by iteration number.
"""
......@@ -312,7 +325,7 @@ class ServerSession(QObject, ParserListener, SessionCommon):
def _setStateDict(self):
msg = self.transaction.asyncRead()
self.state_dictionary = msg["statedict"]
# resore lost types
# restore lost types
layers = self.state_dictionary["network"]["layers"]
for id in layers:
if "parameters" in layers[id]:
......@@ -330,15 +343,18 @@ class ServerSession(QObject, ParserListener, SessionCommon):
msg = self.transaction.asyncRead()
transferDict = copy.deepcopy(self.state_dictionary)
layers = transferDict["network"]["layers"]
for id in layers:
del layers[id]["type"]
msg["statedict"] = transferDict
msg["status"] = True
try:
layers = transferDict["network"]["layers"]
for id in layers:
del layers[id]["type"]
msg["statedict"] = transferDict
msg["status"] = True
except KeyError as e:
msg["statedict"] = transferDict
msg["status"] = True
self.transaction.send(msg)
def save(self):
def save(self, includeProtoTxt = False):
"""Saves the current session to prototxt files and session_settings json file."""
toSave = {"SessionState": self.state, "Iteration": self.iteration, "MaxIter": self.max_iter}
toSave["UID"] = self.uid
......@@ -349,14 +365,16 @@ class ServerSession(QObject, ParserListener, SessionCommon):
if self.state_dictionary:
serializedDict = copy.deepcopy(self.state_dictionary)
if "network" in serializedDict:
netDict = copy.deepcopy(self.state_dictionary["network"])
net = saver.saveNet(netdict=netDict)
with open(os.path.join(self.directory, Paths.FILE_NAME_NET_ORIGINAL), 'w') as f:
f.write(net)
if includeProtoTxt:
netDict = copy.deepcopy(self.state_dictionary["network"])
net = saver.saveNet(netdict=netDict)
with open(os.path.join(self.directory, Paths.FILE_NAME_NET_ORIGINAL), 'w') as f:
f.write(net)
layers = serializedDict["network"]["layers"]
for id in layers:
del layers[id]["type"]
if "layers" in serializedDict["network"]:
layers = serializedDict["network"]["layers"]
for id in layers:
del layers[id]["type"]
toSave["NetworkState"] = serializedDict
......@@ -496,15 +514,16 @@ class ServerSession(QObject, ParserListener, SessionCommon):
regex_rid = re.compile('([\d]+)\.([\d]+)\.log')
run_id = 0
for entry in os.listdir(self.logs):
rid_match = regex_rid.search(entry)
if rid_match:
try:
_run_id = int(rid_match.group(2))
if _run_id > run_id:
run_id = _run_id
except:
pass
if os.path.exists(self.logs):
for entry in os.listdir(self.logs):
rid_match = regex_rid.search(entry)
if rid_match:
try:
_run_id = int(rid_match.group(2))
if _run_id > run_id:
run_id = _run_id
except:
pass
self.rid = run_id
def getSnapshotDirectory(self):
......@@ -743,6 +762,14 @@ class ServerSession(QObject, ParserListener, SessionCommon):
def start(self, solverstate=None, caffemodel=None):
# TODO only one training per server
error = []
# (re-)write all session files
self.save(includeProtoTxt=True)
if not os.path.exists(self.logs):
try:
os.makedirs(self.logs)
except OSError as e:
error.extend(['Failed to create Folder: ' + str(e)])
if self.rid is None:
self._parseSessionAndRunID()
if self.state is State.WAITING:
......
This diff is collapsed.
......@@ -107,11 +107,13 @@ class Logger(QObject):
if callable(log_id):
callerId = caller.getCallerId()()
except Exception:
# TODO: Warn properly that something went wrong
print('From unknown callerID: ' + line)
pass
if callerId is None:
print('Log.log: no callerId')
return
self.appendLine(line, callerId)
elif callerId >= len(self.__callers):
print('From invalid callerID: ' + line)
else:
self.appendLine(line, callerId)
def error(self, line, callerId=None):
'''
......
import caffe.proto.caffe_pb2 as proto
from backend.caffe.loader import _extract_param
import backend.caffe.proto_info as info
import uuid
......@@ -100,6 +99,8 @@ def _areParameterValid(allparams, paramdict, prefix=""):
def bareNet(name):
""" Creates a dictionary of a networks with default values where required. """
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
net = proto.NetParameter()
descr = info.ParameterGroupDescriptor(net)
params = descr.parameter().copy()
......@@ -115,6 +116,8 @@ def _bareLayer(layertype, name):
""" Creates a dictionary of the given layertype with the given name
initialized with default values if required.
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
res = {"type": layertype}
layerParamInst = proto.LayerParameter()
res["parameters"] = _extract_param(layerParamInst, info.CaffeMetaInformation().commonParameters())
......@@ -169,7 +172,7 @@ class DictHelper:
return self._netdic["layers"][id]
def hasLayerWithId(self, id):
    """Return True if a layer with the given id exists in the network dictionary."""
    return id in self._netdic["layers"]
def addLayer(self, layertype, name, idx):
......
import caffe.proto.caffe_pb2 as proto
import uuid
import h5py as h5
......@@ -18,6 +17,8 @@ class ParseException(Exception):
def loadSolver(solverstring):
""" Return a dictionary which represent the caffe-solver-prototxt solverstring """
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
solver = proto.SolverParameter()
# Get DESCRIPTION for meta infos
......@@ -65,6 +66,8 @@ def loadNet(netstring):
}
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
# Load Protoclass for parsing
net = proto.NetParameter()
......@@ -104,6 +107,10 @@ def _load_layers(layerlist):
allLayers = info.CaffeMetaInformation().availableLayerTypes()
order = []
res = {}
dicLayerTypeCounter = {}
for layer in allLayers:
dicLayerTypeCounter[layer] = 1
for layer in layerlist:
typename = layer.type
if not allLayers.has_key(typename):
......@@ -115,6 +122,21 @@ def _load_layers(layerlist):
"parameters": _extract_param(layer,layerinfo.parameters())
}
order.append(id)
for id in order:
if "name" not in res[id]["parameters"]:
typeName = res[id]["parameters"]["type"]
newName = typeName + " #" + str(dicLayerTypeCounter[typeName])
changed = True;
while changed == True:
newName = typeName + " #" + str(dicLayerTypeCounter[typeName])
changed = False
for id2 in order:
if "name" in res[id2]["parameters"] and res[id2]["parameters"]["name"] == newName:
dicLayerTypeCounter[typeName] += 1
changed = True
res[id]["parameters"]["name"] = newName
return res, order
def _extract_param(value,parameters):
......@@ -153,6 +175,8 @@ def extractNetFromSolver(solverstring):
This works only, if the solver specifies a network using the "net_param" parameter. A reference to a file using the
"net" parameter can not be handled by this method.
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
# create empty solver message
solver = proto.SolverParameter()
......@@ -171,6 +195,8 @@ def extractNetFromSolver(solverstring):
def getCaffemodelFromSolverstate(solverstate):
""" Parse the filename of the caffemodel file from the solverstate file.
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
try:
state = proto.SolverState()
......@@ -190,3 +216,17 @@ def getCaffemodelFromSolverstateHdf5(filename):
return None
except:
return None
def getIterFromSolverstate(solverstate):
    """ Parse the iterations from the solverstate file.

    Returns the iteration count stored in the protobuf SolverState, or
    implicitly None when the file cannot be read/parsed (the error is
    printed in that case).
    """
    import backend.caffe.path_loader as caffeLoader
    proto = caffeLoader.importProto()
    try:
        solverState = proto.SolverState()
        with open(solverstate, 'rb') as stateFile:
            solverState.ParseFromString(stateFile.read())
        return solverState.iter
    except Exception as e:
        print(str(e))
......@@ -27,16 +27,28 @@ def checkIfFileIsCaffeRoot(fileName):
else:
return False
def importCaffe():
import importlib
import sys
def reloadCaffeModule(path):
"""reloads all "import caffe" thinks """
if checkIfFileIsCaffeRoot(path):
path = path + "/python/caffe/__init__.py"
imp.load_source('caffe', path)
def reloadProtoModule(path):
"""reloads all "import caffe.proto.caffe_pb2 as proto" thinks """
if checkIfFileIsCaffeRoot(path):
path = path + "/python/caffe/proto/caffe_pb2.py"
imp.load_source('caffe.proto.caffe_pb2', path)
path = getCaffePath()
path += "/python"
sys.path.insert(0, path) # stellt sicher, dass der angegebene Pfad als erstes durchsucht wird
try:
return importlib.import_module("caffe")
except ImportError as e:
print(e)
sys.exit(1)
def importProto():
    """Import and return the caffe_pb2 protobuf module from the configured Caffe path.

    Exits the process with status 1 when the module cannot be imported.
    """
    import importlib
    import sys
    caffePythonPath = getCaffePath() + "/python"
    # make sure the configured path is searched before any other location
    sys.path.insert(0, caffePythonPath)
    try:
        return importlib.import_module("caffe.proto.caffe_pb2")
    except ImportError as e:
        print(e)
        sys.exit(1)
import caffe
import os
import re
......@@ -46,7 +44,8 @@ class CaffeProtoParser:
def _readFile(self):
"""Read the content of the caffe.proto file."""
import backend.caffe.path_loader as caffeLoader
caffe = caffeLoader.importCaffe()
# first of all, try to use the environment variable CAFFE_ROOT
caffeRootPath = path_loader.getCaffePath()
......@@ -190,7 +189,7 @@ class CaffeProtoParser:
comment = comment.capitalize()
return comment
def fieldDescriptions(self):
"""Get all descriptions provided in the caffe.proto file's comments, that belong to a field inside of a message.
......@@ -231,4 +230,3 @@ class CaffeProtoParser:
description = self.messageDescriptions()[fieldMsgDefault]
return description
import caffe
import caffe.proto.caffe_pb2 as proto
from google.protobuf.descriptor import FieldDescriptor
import sys, inspect
from backend.caffe import proto_description, path_loader
......@@ -51,6 +49,8 @@ class CaffeMetaInformation:
See description of self.availableParameterGroupDescriptors().
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
current_module = sys.modules[proto.__name__]
res = {}
for (el,val) in inspect.getmembers(current_module, inspect.isclass):
......@@ -61,6 +61,8 @@ class CaffeMetaInformation:
"""Generate information about available layer types only once.
self.__initAvailableParameterGroupDescriptors() needs to be called before this method."""
import backend.caffe.path_loader as caffeLoader
caffe = caffeLoader.importCaffe()
layerNameMainParts = list(caffe.layer_type_list())
res = {}
......@@ -72,11 +74,11 @@ class CaffeMetaInformation:
commonParams = self._availableParameterGroupDescriptors["LayerParameter"].parameter() #use .parameter() on purpose
layerSpecificParameters = set()
for nameMainPart in layerNameMainParts:
specificParamsName = [nameMainPart + "Parameter"]
specificParamsName = [nameMainPart + "Parameter"]
if moreLayerNameParameter.has_key(nameMainPart):
specificParamsName.append( moreLayerNameParameter[nameMainPart])
paramsPerLayerType[nameMainPart] = {}
for key, value in commonParams.items():
for key, value in commonParams.items():
if value.isParameterGroup() and value.parameterName() in specificParamsName:
paramsPerLayerType[nameMainPart][key] = value
layerSpecificParameters.add(key)
......@@ -109,6 +111,8 @@ class CaffeMetaInformation:
"""Generate information about available solver types only once.
self.__initAvailableParameterGroupDescriptors() needs to be called before this method."""
import backend.caffe.path_loader as caffeLoader
caffe = caffeLoader.importCaffe()
# DO NOT REMOVE the following import statement, although your IDE might say it's unused. It's not!
from caffe._caffe import Solver as SolverBaseClassInfo
......@@ -474,6 +478,8 @@ def resetCaffeProtoModulesvar():
def _caffeProtobufModules():
""" Returns all available Classes of caffe_pb2 in a dictionary """
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
global _caffeprotomodulesvar
if _caffeprotomodulesvar is None:
current_module = sys.modules[proto.__name__]
......
import caffe.proto.caffe_pb2 as proto
from google.protobuf import text_format
import backend.caffe.proto_info as info
......@@ -10,6 +9,8 @@ def saveSolver(solverdict):
def _import_solver(solverdict):
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
solver = proto.SolverParameter()
for entry in solverdict:
......@@ -66,6 +67,8 @@ def saveNet(netdict):
def _import_dictionary(netdict):
"""fill the ProtoTxt-Net with data from the dictionary"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
net = proto.NetParameter()
for entry in netdict:
......
import caffe
import leveldb # pip install leveldb
import os.path
......@@ -52,6 +51,8 @@ class LeveldbInput:
def verifyConsistency(self):
'''check if all entries have the same channel number and size.
This may take a lot of time since every entry has to be checked.'''
import backend.caffe.path_loader as caffeLoader
caffe = caffeLoader.importCaffe()
if self.getDataCount() < 2:
return True
first = self._getFirstDatum()
......@@ -85,6 +86,8 @@ class LeveldbInput:
return self._db.RangeIter()
def _getFirstDatum(self):
import backend.caffe.path_loader as caffeLoader
caffe = caffeLoader.importCaffe()
iter = self._getIter()
if iter:
for key, value in iter:
......
import caffe
import os.path
import lmdb # pip install lmdb
......@@ -97,6 +96,8 @@ class LmdbInput:
return self._getCurrentDatum()
def _getCurrentDatum(self):
import backend.caffe.path_loader as caffeLoader
caffe = caffeLoader.importCaffe()
if self._cursor:
raw_datum = self._cursor.value()
datum = caffe.proto.caffe_pb2.Datum()
......
......@@ -58,3 +58,5 @@ class SessionProtocol:
LOADINTERNALNET = 50
LOADNETPARAMETER = 51
DELETE = 99
import os
import re
import sys
import json
from datetime import datetime
import logging
......@@ -16,6 +17,7 @@ from PyQt5.Qt import QObject
class ServerSessionManager(QObject):
poolEmptyJob = pyqtSignal(str)
jsonKeys = ("Iteration", "MaxIter", "NetworkState", "ProjectID", "SID", "SessionState", "UID")
def __init__(self, parent, sessionPath):
super(ServerSessionManager, self).__init__()
......@@ -155,7 +157,6 @@ class ServerSessionManager(QObject):
if os.path.exists(sdir) is False:
try:
os.makedirs(sdir)
os.makedirs(os.path.join(sdir, "logs"))
logging.info("Directory '%s' created.", sdir)
return True, sdir
except Exception as e:
......@@ -201,28 +202,33 @@ class ServerSessionManager(QObject):
print msg
def _isSession(self, dir):
"""Check if directory is a valid session directory."""
solver_file = os.path.join(dir, Paths.FILE_NAME_SOLVER)
"""Check if directory contains a valid session."""
session_json = os.path.join(dir,Paths.FILE_NAME_SESSION_JSON)
if os.path.isdir(dir) is False:
sys.stderr.write("Session directory " + dir + " is invalid! No directory!\n")
logging.error("Session directory %s is invalid. No directory", dir)
return False
if os.path.exists(solver_file):
snapshots = self._parseSnapshotPrefixFromFile(solver_file)
else:
sys.stderr.write("Session directory " + dir + " is invalid! Solver file '" + solver_file + " does not exist!\n")
logging.error("Session directory %s is invalid. Solver file does not exist!", dir)
return False
if snapshots:
snapshots = os.path.dirname(snapshots)
if snapshots and os.path.isdir(os.path.join(dir, snapshots)) is False:
sys.stderr.write("Session directory " + dir + " is invalid! Snapshot directory is not a valid directory!\n")
logging.error("Session directory %s is invalid! Snapshot directory is not a valid directory!", dir)
return False
if os.path.isdir(os.path.join(dir, "logs")) is False:
sys.stderr.write("Session directory " + dir + " is invalid! Log directory is not a valid directory!\n")
logging.error("Session directory %s is invalid! Log directory is not a valid directory!", dir)
if not os.path.exists(session_json):
sys.stderr.write("Session directory " + os.path.basename(os.path.normpath(dir)) + " is invalid!\n File 'sessionstate.json' does not exist!\n")
logging.error("Session directory %s is invalid. 'sessionstate.json' does not exist!", dir)
return False
with open(session_json,"r") as file:
dict = json.load(file)
for key in self.jsonKeys:
if key not in dict:
sys.stderr.write("Session directory "
+ os.path.basename(os.path.normpath(dir))+" is invalid!\n Key '"
+key +"' is missing in sessionstate!\n")
logging.error("Session at %s is invalid. Key %s in 'sessionstate.json' is missing.",dir,key)
return False
if "LastSnapshot" in dict:
if dict["LastSnapshot"]:
if not os.path.exists(os.path.join(dir,dict["LastSnapshot"])):
sys.stderr.write("Session directory "
+ os.path.basename(os.path.normpath(dir)) + " is invalid!\n Snapshot "
"was set but not found in directory\n")
logging.error("Session at %s is invalid. Snapshot was set but not found.", dir)
return False
logging.debug("Session directory %s is valid!", dir)
return True
......
Barista contributors (sorted alphabetically)
============================================
Project management:
* Aaron Scherzinger
* Dominik Drees
* Sören Klemm
Developers:
* Aleksej Matis
* Alexander Hunger
* Alexandre Stüben
......@@ -11,5 +19,20 @@ Barista contributors (sorted alphabeticaly)
* Michael Fujarski
* Robin Ortkemper
* Robin Rexeisen
* (Til Sander)[til.sander@wwu.de]
* Til Sander
* Tim Hartz
* Alexander Schulze-Schwering