Commit a7596542 authored by Soeren Klemm's avatar Soeren Klemm

Merge remote-tracking branch 'origin/master' into masterzivgitlab

parents 6c45a6d7 1290f125
......@@ -33,20 +33,16 @@ class DeployedNet:
(LMDB or LEVELDB) or "HDF5Data". Otherwise, a warning will inform the user about necessary manual changes.
"""
def __init__(self, originalNetworkFilePath):
"""Create a deployment version of the given network."""
def __init__(self, netPrototxtContents):
"""Create a deployment version of the given network.
netPrototxtContents: string
The contents of the prototxt file to deploy.
"""
# create a logger id
self._logId = Log.getCallerId('deployment')
# save params
if os.path.isfile(originalNetworkFilePath):
with open(originalNetworkFilePath, 'r') as file:
networkString = file.read()
self._originalNetworkDictionary = loader.loadNet(networkString)
else:
raise ValueError("The file {} does not exist.".format(originalNetworkFilePath))
self._originalNetworkDictionary = loader.loadNet(netPrototxtContents)
# init further attributes
self._dataLayers = dict() # a dictionary containing all data layers. keys are the layer ids.
self._labelBlobNames = [] # a list containing all names of blobs, that represent any labels
......@@ -73,17 +69,10 @@ class DeployedNet:
self._insertInputLayers()
self._addSoftmax()
def saveProtoTxtFile(self, prototxtDestination, caffemodelSource, caffemodelDestination):
"""Save the content of self._deployedNetworkDictionary as a prototxt file to prototxtDestination and copy
caffemodelSource to the same folder."""
# save prototxt
with open(prototxtDestination, 'w') as file:
file.write(saver.saveNet(self._deployedNetworkDictionary))
# copy caffemodel file
copyfile(caffemodelSource, caffemodelDestination)
def getProtoTxt(self):
"""return the content of self._deployedNetworkDictionary as a prototxt string."""
Log.log("Deployment files have been saved successfully to {}.".format(prototxtDestination), self._logId)
return saver.saveNet(self._deployedNetworkDictionary)
def _searchDataLayers(self):
"""Search for all data layers in self._originalNetworkDictionary and store them in self._dataLayers.
......
......@@ -472,6 +472,9 @@ class ClientSession(QObject):
def getLogId(self):
return self.uid
def isRemote(self):
return True
def fetchParserData(self):
if self._assertConnection() and self.transaction is not None:
msg = {"key": Protocol.SESSION, "subkey": SessionProtocol.FETCHPARSERDATA}
......@@ -505,6 +508,30 @@ class ClientSession(QObject):
self._handleErrors(ret["error"])
self._handleErrors(["Failed to fetch parser data from host."])
def readInternalNetFile(self):
""" Returns the contents of the internal net prototxt file.
Wrapper around loadInternalNetFile, to provide a coherent interface
together with the Session class.
"""
return self.loadInternalNetFile()
def readDeployedNetAsString(self):
""" Returns the contents of the deployable net prototxt file.
Wrapper around loadDeployedNetFile, to provide a coherent interface
together with the Session class.
"""
return self.loadDeployedNetAsString()
def readCaffemodelFile(self, snapshot):
""" Returns the contents of the .caffemodel file that belongs to snapshot.
snapshot: string
Filename without path of the snapshot file.
"""
return self.loadCaffemodel(snapshot)
def loadInternalNetFile(self):
if self._assertConnection():
msg = {"key": Protocol.SESSION, "subkey": SessionProtocol.LOADINTERNALNET}
......@@ -518,6 +545,19 @@ class ClientSession(QObject):
self._handleErrors(["Failed to load InternalNet Prototxt"])
return ""
def loadDeployedNetAsString(self):
if self._assertConnection():
msg = {"key": Protocol.SESSION, "subkey": SessionProtocol.LOADDEPLOYEDNET}
self.transaction.send(msg)
ret = self.transaction.asyncRead(staging=True, attr=("subkey", SessionProtocol.LOADDEPLOYEDNET))
if ret:
if ret["status"]:
return ret["Net"]
else:
self._handleErrors(ret["error"])
self._handleErrors(["Failed to load Deployed Net"])
return ""
def loadNetParameter(self, snapshot):
if self._assertConnection():
msg = {"key": Protocol.SESSION, "subkey": SessionProtocol.LOADNETPARAMETER}
......@@ -540,3 +580,24 @@ class ClientSession(QObject):
self._handleErrors(ret["error"])
self._handleErrors(["Failed to load NetParameter for snapshot '" + snapshot + "'"])
return None
def getCaffemodelContents(self, snapshot):
""" Wrapper around the loadCaffemodel function to provide a coherent interface
between local and remote sessions. """
return self.loadCaffemodel(snapshot)
def loadCaffemodel(self, snapshot):
""" Loads the contents of the caffemodel file belonging to the given snapshot
from the server. """
if self._assertConnection():
msg = {"key": Protocol.SESSION, "subkey": SessionProtocol.LOADCAFFEMODEL}
msg["snapshot"] = snapshot
self.transaction.send(msg)
ret = self.transaction.asyncRead(staging=True, attr=("subkey", SessionProtocol.LOADCAFFEMODEL))
if ret:
if ret["status"]:
return ret["caffemodel"]
else:
self._handleErrors(ret["error"])
self._handleErrors(["Failed to load caffemodel contents for snapshot '" + snapshot + "'"])
return None
......@@ -27,6 +27,7 @@ from backend.networking.protocol import Protocol, SessionProtocol
from backend.parser.concatenator import Concatenator
from backend.parser.parser import Parser
from backend.parser.parser_listener import ParserListener
from backend.barista.deployed_net import DeployedNet
from PyQt5.QtCore import QTimer
from threading import Lock
......@@ -61,7 +62,9 @@ class ServerSession(QObject, ParserListener, SessionCommon):
SessionProtocol.TAKESNAPSHOT: self._takeSnapshot,
SessionProtocol.FETCHPARSERDATA: self._msgFetchParserData,
SessionProtocol.LOADINTERNALNET: self._msgLoadInternalNet,
SessionProtocol.LOADDEPLOYEDNET: self._msgLoadDeployedNet,
SessionProtocol.LOADNETPARAMETER: self._msgLoadNetParameter,
SessionProtocol.LOADCAFFEMODEL: self._msgLoadCaffemodel,
SessionProtocol.RESET: self._reset,
SessionProtocol.DELETE: self.delete}
......@@ -990,6 +993,18 @@ class ServerSession(QObject, ParserListener, SessionCommon):
msg["error"] = ["Failed to read " + path]
self.transaction.send(msg)
def _msgLoadDeployedNet(self):
msg = self.transaction.asyncRead()
try:
path = os.path.join(self.directory, Paths.FILE_NAME_NET_ORIGINAL)
file = open(path, 'r')
dn = DeployedNet(file.read())
msg["Net"] = dn.getProtoTxt()
except IOError:
msg["status"] = False
msg["error"] = ["Failed to read " + path]
self.transaction.send(msg)
def _msgLoadNetParameter(self):
msg = self.transaction.asyncRead()
msg["status"] = False
......@@ -1010,6 +1025,28 @@ class ServerSession(QObject, ParserListener, SessionCommon):
msg["error"] = ["No Snapshot provided"]
self.transaction.send(msg)
def _msgLoadCaffemodel(self):
msg = self.transaction.asyncRead()
msg["status"] = False
print("Load caffe model received.")
if "snapshot" in msg.keys():
path = msg["snapshot"]
print(path)
path = os.path.join(self.directory, path)
if os.path.exists(path):
try:
with open(path, 'rb') as f:
msg["caffemodel"] = f.read()
msg["status"] = True
except:
msg["error"] = ["Failed to load " + str(path)]
else:
msg["error"] = ["File not found " + str(path)]
else:
msg["error"] = ["No Snapshot provided"]
self.transaction.send(msg)
def reset(self):
self.pause()
for dirpath, dirnames, filenames in os.walk(self.directory, topdown=True):
......
......@@ -21,6 +21,7 @@ from backend.barista.session.session_pool import SessionPool
from backend.barista.session.session_utils import *
from backend.barista.utils.logger import Log
from backend.barista.utils.logger import LogCaller
from backend.barista.deployed_net import DeployedNet
from backend.barista.session.session_common import SessionCommon
......@@ -508,6 +509,9 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
os.makedirs(self.logs)
return self.logs
def isRemote(self):
return False
def getSnapshotDirectory(self):
""" Return the snapshot directory.
"""
......@@ -572,6 +576,30 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
Log.log("This sessions net file: " + self.__netInternalFile + " does not exist.", self.caller_id)
return self.__netInternalFile
def readInternalNetFile(self):
""" Returns the contents of the internal net prototxt file. """
path = self.getInternalNetFile()
with open(path, 'r') as f:
contents = f.read()
return contents
def readDeployedNetAsString(self):
""" Returns the contents of the deployable net prototxt file. """
path = self.getInternalNetFile()
dn = DeployedNet(open(path).read())
return dn.getProtoTxt()
def readCaffemodelFile(self, snapshot):
""" Returns the contents of the .caffemodel file that belongs to snapshot.
snapshot: string
Filename without path of the snapshot file.
"""
path = os.path.join(self.getSnapshotDirectory(), snapshot)
with open(path, 'r') as f:
contents = f.read()
return contents
def getRunLogFileName(self, basename=False):
""" Return the name of the logfile with session and run id.
"""
......
......@@ -59,7 +59,9 @@ class SessionProtocol:
TAKESNAPSHOT = 43
LOADINTERNALNET = 50
LOADNETPARAMETER = 51
LOADDEPLOYEDNET = 51
LOADNETPARAMETER = 52
LOADCAFFEMODEL = 53
RESET = 98
DELETE = 99
This diff is collapsed.
......@@ -40,9 +40,9 @@ class HostManager(QWidget):
self._buttonlayout.addWidget(self._pb_add)
# delete button
self._pb_del = QPushButton("Delete selected Host")
self._pb_del.setEnabled(False)
self._buttonlayout.addWidget(self._pb_del)
self._pb_remove = QPushButton("Remove selected Host")
self._pb_remove.setEnabled(False)
self._buttonlayout.addWidget(self._pb_remove)
# listwidget with all dbs
self._hostscroll = QListWidget()
......@@ -54,7 +54,7 @@ class HostManager(QWidget):
self.resize(800, 600)
self._pb_add.clicked.connect(self._addNewHost)
self._pb_del.clicked.connect(self.onDelete)
self._pb_remove.clicked.connect(self.onDelete)
self._hostscroll.itemSelectionChanged.connect(self._selectionChange)
self._loadFromSettings()
......@@ -103,9 +103,9 @@ class HostManager(QWidget):
def _selectionChange(self):
"""disable and enable the delete button on selection"""
if len(self._hostscroll.selectedItems()) <= 0:
self._pb_del.setEnabled(False)
self._pb_remove.setEnabled(False)
else:
self._pb_del.setEnabled(True)
self._pb_remove.setEnabled(True)
def getActiveHostList(self):
ret = []
......@@ -131,12 +131,15 @@ class HostManager(QWidget):
return ret["path"]
def onDelete(self):
"""delete selected hosts"""
"""remove selected hosts"""
ic = len(self._hostscroll.selectedItems())
if ic > 0:
gramNum = ""
if ic > 1:
gramNum = "s"
# confirm delete
ret = QMessageBox.question(self, "Delete selected Host",
"Do you really want to delete " + str(ic) + " selected host(s)?",
ret = QMessageBox.question(self, "Remove {0} selected Host{1}".format(str(ic), gramNum),
"Do you really want to delete the selected host{}?".format(gramNum),
QMessageBox.Ok, QMessageBox.No)
if ret == QMessageBox.Ok:
for item in self._hostscroll.selectedItems():
......
This diff is collapsed.
......@@ -43,7 +43,7 @@ class LayerProperties(DockElement):
""" Set the layers whose properties should be shown by its id.
ids is a list of string.
"""
self.tab.clear()
self.clearProperties()
for id in ids:
self.addTab(id)
......
......@@ -46,14 +46,15 @@ class NodeItem(QGraphicsItem):
def __updateConnectorNames(self):
""" Sets the internal connector names to correspond to the blob names in the underling data structure """
for index in range(0, len(self.__topConnectors)):
blobName = self.__layerData["parameters"]["top"][index]
self.__topConnectors[index].setBlobName(blobName)
for index in range(0, len(self.__bottomConnectors)):
blobName = self.__layerData["parameters"]["bottom"][index]
self.__bottomConnectors[index].setBlobName(blobName)
if "top" in self.__layerData["parameters"]:
for index in range(0, len(self.__topConnectors)):
blobName = self.__layerData["parameters"]["top"][index]
self.__topConnectors[index].setBlobName(blobName)
if "bottom" in self.__layerData["parameters"]:
for index in range(0, len(self.__bottomConnectors)):
blobName = self.__layerData["parameters"]["bottom"][index]
self.__bottomConnectors[index].setBlobName(blobName)
def addTopConnector(self, blobName):
""" Adds a top connector with the given name to the node item """
......
......@@ -536,10 +536,11 @@ class NetworkManager(QObject):
# TODO: Instead update in onStateUpdate
#self.nodeEditor.addBottomBlob(layerID, "")
# update property dock
self.dockLayerProperties.updateDock()
self.historyWriter.makeEntry(intern)
# update property dock
#self.dockLayerProperties.updateDock()
......@@ -556,7 +557,7 @@ class NetworkManager(QObject):
self.updateOrder()
# update property dock
# self.dockLayerProperties.updateDock()
self.dockLayerProperties.updateDock()
self.historyWriter.makeEntry(intern)
......@@ -575,7 +576,7 @@ class NetworkManager(QObject):
#self.nodeEditor.addTopBlob(layerID, blobName)
# update property dock
#self.dockLayerProperties.updateDock()
self.dockLayerProperties.updateDock()
self.historyWriter.makeEntry(intern)
......@@ -593,7 +594,7 @@ class NetworkManager(QObject):
#self.nodeEditor.renameTopBlob(layerID, blobIndex, newName)
# update property dock
#self.dockLayerProperties.updateDock()
self.dockLayerProperties.updateDock()
self.historyWriter.makeEntry(intern)
......@@ -607,8 +608,9 @@ class NetworkManager(QObject):
# TODO: Instead update in onStateUpdate
#self.nodeEditor.renameBottomBlob(layerID, blobIndex, newName)
# update property dock
#self.dockLayerProperties.updateDock()
self.dockLayerProperties.updateDock()
self.historyWriter.makeEntry(intern)
......@@ -626,7 +628,7 @@ class NetworkManager(QObject):
self.updateOrder()
# update property dock
#self.dockLayerProperties.updateDock()
self.dockLayerProperties.updateDock()
self.historyWriter.makeEntry(intern)
......@@ -647,7 +649,7 @@ class NetworkManager(QObject):
self.updateOrder()
# update property dock
#self.dockLayerProperties.updateDock()
self.dockLayerProperties.updateDock()
self.historyWriter.makeEntry(intern)
......@@ -663,7 +665,7 @@ class NetworkManager(QObject):
#self.updateNet()
self.updateOrder()
# update property dock
#self.dockLayerProperties.updateDock()
self.dockLayerProperties.updateDock()
self.historyWriter.makeEntry(intern)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment