Commit febc77b4 authored by Philipp Torben Jüres's avatar Philipp Torben Jüres

Merge branch 'develop' into feature/415_change-location-of-a-database

parents 004a4499 91c8cab5
......@@ -19,3 +19,4 @@
/barista.conf
/baristalog.txt
/.vscode/
/sessions
......@@ -16,7 +16,14 @@ Barista is an open-source graphical high-level interface for the Caffe deep lear
**Using Barista**
* To start Barista, run the file `main.py`
* A user manual as well as a starter's tutorial on Barista can be found in the wiki of the project repository: https://zivgitlab.uni-muenster.de/pria/Barista
* A user manual as well as a starter's tutorial on Barista can be found in the wiki of the project repository: https://zivgitlab.uni-muenster.de/pria/Barista/wikis/home
If you have any questions regarding Barista, please feel free to contact the developers at barista(at)uni-muenster.de
**Citing Barista**
If you use Barista in your research, please cite it using the following reference:
Klemm, S. and Scherzinger, A. and Drees, D. and Jiang, X. Barista - a Graphical Tool for Designing and Training Deep Neural Networks. arXiv preprint. arXiv:1802.04626. http://arxiv.org/abs/1802.04626
#! /usr/bin/env python
"""
This is a small self-contained application that demonstrates usage of networks deployed from Barista
in other applications. If net definition and weights files obtained by training the caffe mnist example
are supplied via command line parameters, the user can draw digits [0-9] in the window, use the net
to classify the result and print the result to stdout.
- Drawing works by pressing the left mouse button and moving the mouse
- The "Return"-Button grabs the image and starts the classification
- The "Backspace"-Button clears the image.
"""
#
#
# Python package dependencies:
# pygame
# caffe
# numpy
import pygame
import caffe
import numpy as np
import argparse
# Argument parsing
# The module docstring above doubles as the --help epilog; RawTextHelpFormatter
# keeps its line breaks intact.
parser = argparse.ArgumentParser(description='Interactively classify handwritten digits using neural nets.', epilog=__doc__, formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument('-m', '--model', help='Model (.prototxt) file', type=str, required=True)
parser.add_argument('-w', '--weights', help='Weights (.caffemodel) file', type=str, required=True)
args = parser.parse_args()
#######################################################
# Setup caffe for classification
#######################################################
# load the model in test (inference) mode
net = caffe.Net(args.model, args.weights, caffe.TEST)
# load input and configure preprocessing
# The transformer converts (H, W, C) images into the net's input blob layout.
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
# Move the channel axis first: (H, W, C) -> (C, H, W).
transformer.set_transpose('data', (2,0,1))
# Scale factor 1.0: pixel values are fed to the net without rescaling.
transformer.set_raw_scale('data', 1.0)
# Change the batch size to a single image
# (1 image, 1 channel, 28x28 pixels — the input size of the caffe mnist example).
net.blobs['data'].reshape(1,1,28,28)
#######################################################
# Preprocessing helper functions
#######################################################
# Find the required offset (i.e. difference between center of weight and center of the image)
def find_offset(grid):
    """Return the offset (dy, dx) between the center of mass of the "on"
    pixels (value > 0) and the geometric center of the image.

    grid: 2D array of pixel intensities (a trailing singleton channel
          axis, shape (h, w, 1), is tolerated as well).
    Returns (0, 0) if the grid contains no positive pixel.

    Replaces the original per-pixel double loop with a vectorized
    numpy computation; the returned values are identical:
    sum(coord - center) / n == mean(coord) - center.
    """
    h = grid.shape[0]
    w = grid.shape[1]
    # Row/column indices of every pixel that is switched on.
    ys, xs = np.nonzero(grid > 0)[:2]
    if ys.size == 0:
        return (0, 0)
    # Center of mass minus image center, per axis (note the (y, x) order).
    return (ys.mean() - h / 2.0, xs.mean() - w / 2.0)
# Shift and resample values in grid and thus centering the center of weight in the center of the image
def shift(grid):
    """Translate the image so its center of mass lands on the image center.

    Pixels whose source position falls outside the grid become zero.
    Note: the rounding is applied per destination pixel (np.round uses
    round-half-to-even), so the shift is deliberately not hoisted into a
    single integer translation.
    """
    dy, dx = find_offset(grid)
    height = grid.shape[0]
    width = grid.shape[1]
    centered = np.zeros((height, width, 1))
    for row in np.arange(height):
        src_row = int(np.round(row + dy))
        row_in_bounds = 0 <= src_row < height
        for col in np.arange(width):
            src_col = int(np.round(col + dx))
            # Out-of-range sources keep the zero initialization.
            if row_in_bounds and 0 <= src_col < width:
                centered[row, col] = grid[src_row, src_col]
    return centered
# Classify a given image and output the index of the class with the highest probability according to the net and caffe
def classify(pixels):
    """Feed a 2D pixel array through the net and return the index of the
    class with the highest probability."""
    h = pixels.shape[0]
    w = pixels.shape[1]
    image = np.zeros((h, w, 1))
    image[:, :, 0] = pixels[:, :]
    # Swap the first two axes, then center the digit before preprocessing.
    image = shift(np.transpose(image, (1, 0, 2)))
    batch = np.asarray([transformer.preprocess('data', image)])
    output = net.forward_all(data=batch)
    return output['probabilities'].argmax()
#######################################################
# Pygame application stuff
#######################################################
# Create screen of specified size
# (112 x 112 is 4x the 28x28 net input; presumably the transformer resizes
# the grabbed image down during preprocessing — TODO confirm)
screen_size = (112,112)
screen = pygame.display.set_mode(screen_size)
# Global variables/constants used for drawing
currently_drawing = False   # True while the left mouse button is held down
last_pos = (0, 0)           # previous mouse position, used to join stroke segments
draw_color = (255, 255, 255)  # white brush
clear_color = (0, 0, 0)       # black background
brush_radius = 3              # brush radius in pixels
# draw a line of circles from start to finish
def roundline(srf, color, start, end, radius):
    """Approximate a thick line by stamping filled circles from start to end."""
    dx = end[0] - start[0]
    dy = end[1] - start[1]
    # One stamp per pixel along the longer axis; zero steps when the
    # points coincide (the caller already drew a circle at the cursor).
    steps = max(abs(dx), abs(dy))
    for step in range(steps):
        t = float(step) / steps
        pos = (int(start[0] + t * dx), int(start[1] + t * dy))
        pygame.draw.circle(srf, color, pos, radius)
# Main event loop: draw with the mouse, Return classifies, Backspace clears.
try:
    while True:
        e = pygame.event.wait()
        if e.type == pygame.QUIT:
            # StopIteration is (ab)used as a local "break out of the loop" signal.
            raise StopIteration
        if e.type == pygame.MOUSEBUTTONDOWN:
            pygame.draw.circle(screen, draw_color, e.pos, brush_radius)
            currently_drawing = True
            # Bug fix: anchor the stroke at the click position. Previously
            # last_pos kept its old value, so the first motion event drew a
            # line from the end of the previous stroke (or from (0, 0)).
            last_pos = e.pos
        if e.type == pygame.MOUSEBUTTONUP:
            currently_drawing = False
        if e.type == pygame.MOUSEMOTION:
            if currently_drawing:
                pygame.draw.circle(screen, draw_color, e.pos, brush_radius)
                roundline(screen, draw_color, e.pos, last_pos, brush_radius)
            last_pos = e.pos
        if e.type == pygame.KEYDOWN:
            if e.key == pygame.K_RETURN:
                # Grab the current screen contents and classify the drawing.
                # NOTE(review): the raw PixelArray holds mapped 32-bit color
                # values, not grayscale — classify only tests > 0, so any
                # drawn pixel counts as "on".
                array = pygame.PixelArray(screen)
                cls = classify(array)
                # Fixed duplicated word in the output ("The net says says").
                print("The net says: {}".format(cls))
            if e.key == pygame.K_BACKSPACE:
                screen.fill(clear_color)
        pygame.display.flip()
except StopIteration:
    pass
pygame.quit()
......@@ -128,35 +128,49 @@ class MinimumTrainingRequirements:
def _checkUniqueBlobNames(self):
"""Checks for duplicate top blob names in all layers and emits error message if duplicates found except
if in-place is permitted (Activation Layer)"""
nonActivationTopBlobNames = []
activationTopBlobNames = []
if in-place is permitted"""
# find all 'real' sources of blobs, i.e. layers that produce a blob as output and are not in-place
blobGenerators = {} # dictionary that saves the sources for every blob name
if hasattr(self._stateData, '__getitem__'):
# inefficient double looping but activation layers can only be checked after non-activation layers
for layer_id, layer in self._stateData["network"]["layers"].iteritems():
if LayerHelper.isLayerIncludedInTrainingPhase(layer) and ("top" in layer["parameters"]):
if not layer["type"].isActivationLayer():
for name in layer["parameters"]["top"]:
if name not in nonActivationTopBlobNames:
nonActivationTopBlobNames.append(name)
else:
self._errorMessages.append((name + " is used more than once as connection name!",
"Duplicate connector name in " + layer["parameters"]["name"]))
for layer_id, layer in self._stateData["network"]["layers"].iteritems():
if LayerHelper.isLayerIncludedInTrainingPhase(layer) and ("top" in layer["parameters"]):
if layer["type"].isActivationLayer():
for name in layer["parameters"]["top"]:
if name not in activationTopBlobNames:
activationTopBlobNames.append(name)
if name in nonActivationTopBlobNames:
if not LayerHelper.isLayerInPlace(layer):
self._errorMessages.append(
(name + " is used more than once as connection name!",
"Duplicate connector name in " + layer["parameters"]["name"]))
else:
self._errorMessages.append((name + " is used more than once as connection name!",
"Duplicate connector name in " + layer["parameters"]["name"]))
parameters = layer.get("parameters", {})
tops = parameters.get("top", [])
bottoms = parameters.get("bottom", [])
# if at least one top blob is also a bottom, check if layer allows in-place
sourced = [blob for blob in tops if blob not in bottoms]
inPlace = [blob for blob in tops if blob in bottoms]
if len(inPlace) > 0 and not layer["type"].allowsInPlace():
for name in inPlace:
self._errorMessages.append((
"{} is reproduced by {}".format(name, parameters.get("name", "[NO NAME]")),
"{} does not support in-place operation".format(parameters.get("name",
"[NO NAME]"))
))
# check all blobs that are generated in one layer if they are generated only once in each phase.
for name in sourced:
phase = [] # this list can hold train, test and both
p = LayerHelper.getLayerPhase(layer)
if p == "":
phase = [LayerHelper.PHASE_TEST, LayerHelper.PHASE_TRAIN]
else:
phase.append(p)
# if a blob already exists, check if it was generated before in the same Phase
if name in blobGenerators:
found_match = False
for candidate in blobGenerators[name]:
intersection = set(phase).intersection(candidate[1])
if len(intersection) > 0:
found_match = True
self._errorMessages.append(("Sources are {} and {} in phase {}".format(
parameters.get("name", "[NO NAME]"),
candidate[0],
list(intersection)[0]),
"{} is generated by multiple layers".format(name)))
if not found_match:
blobGenerators[name].append((parameters.get("name", "[NO NAME]"), phase))
else:
blobGenerators[name] = [(parameters.get("name", "[NO NAME]"), phase)]
def _checkDataLayerExistence(self):
"""Check whether at least one data layer exists (during the training phase)."""
......
......@@ -496,82 +496,111 @@ class Project(QObject, LogCaller):
self.newSession.emit(sid)
return sid
def cloneSession(self, solverstate, old_session=None):
""" Return a new session cloned from the solverstate.
If old_session is None the solverstate is expected to be a valid path
to a solverstate.
If the old_session exists the solverstate is searched within the
sessions snapshot directory.
Copies the last snapshot to the new sessions snapshot directory.
def cloneRemoteSession(self, oldSolverstate, oldSession):
"""
if old_session:
snapshot_dir = old_session.getSnapshotDirectory()
solverstate = os.path.basename(solverstate)
old_lss = None
if os.path.isabs(snapshot_dir):
old_lss = os.path.join(snapshot_dir, solverstate)
Starts the cloning process for a remote session and creates the corr. local session upon success
oldSolverstate: solverstate produced by the snapshot from which the clone should be created
oldSession: session from which a clone should be created (type ClientSession)
"""
# validate the given session and solverstate
if oldSolverstate is None:
Log.error('Could not find solver',
self.getCallerId())
return None
if oldSession is None:
Log.error('Failed to create session!', self.getCallerId())
return None
sid = self.getNextSessionId()
# call the remote host to invoke cloning; @see cloneSession in server_session_manager.py
msg = {"key": Protocol.CLONESESSION, "pid": self.projectId,
"sid": sid, "old_uid": oldSession.uid, "old_solverstate": oldSolverstate}
ret = sendMsgToHost(oldSession.remote[0], oldSession.remote[1], msg)
# receive and validate answer
if ret:
if ret["status"]:
uid = ret["uid"]
else:
for e in ret["error"]:
Log.error(e, self.getCallerId())
return None
else:
Log.error('Failed to clone remote session! No connection to Host', self.getCallerId())
return None
# Create a corr. local session and copy (if available) the state-dictionary to maintain
# solver/net etc.
session = ClientSession(self, oldSession.remote, uid, sid)
if hasattr(oldSession, 'state_dictionary'):
session.state_dictionary = oldSession.state_dictionary
self.__sessions[sid] = session
self.newSession.emit(sid)
return sid
def cloneSession(self, oldSolverstate, oldSession):
"""
Creates a new session with the same net/solver etc. and a additional .caffemodel file with pretrained weights
oldSolverstate: solverstate produced by the snapshot from which the clone should be created
oldSession: session from which a clone should be created (type ClientSession)
"""
if type(oldSolverstate) is not str and type(oldSolverstate) is not unicode:
# if no valid solverstate is given, take the last model created for this session
oldCaffemodel = oldSession.getLastModel()
else:
snapshotDir = os.path.join(oldSession.getSnapshotDirectory(), oldSolverstate)
oldCaffemodel = loader.getCaffemodelFromSolverstate(snapshotDir)
if oldCaffemodel is None:
Log.error('Could not find model',
self.getCallerId())
return None
if oldSession:
oldSnapshotDir = oldSession.getSnapshotDirectory()
if os.path.isabs(oldSnapshotDir):
# locate caffemodel
oldModelPath = os.path.join(oldSnapshotDir, oldCaffemodel)
else:
old_sdir = old_session.getDirectory()
old_lss = os.path.join(old_sdir, snapshot_dir, solverstate)
oldSessionDir = oldSession.getDirectory()
oldModelPath = os.path.join(oldSessionDir, oldSnapshotDir, oldCaffemodel)
else:
old_lss = solverstate
Log.error('Failed to create session!', self.getCallerId())
return None
# create new session
sessionID = self.createSession()
self.__sessions[sessionID].setParserInitialized()
if sessionID is None:
Log.error('Failed to create session!', self.getCallerId())
return None
new_sdir = self.__sessions[sessionID].getSnapshotDirectory()
self.__sessions[sessionID].setParserInitialized()
# create directories for the new session
newSnapshotDir = self.__sessions[sessionID].getSnapshotDirectory()
self.__ensureDirectory(self.__sessions[sessionID].getDirectory())
self.__ensureDirectory(new_sdir)
if os.path.isdir(new_sdir):
# copy solverstate
new_lss = os.path.join(new_sdir, os.path.basename(solverstate))
try:
shutil.copy2(old_lss, new_lss)
self.__sessions[sessionID].setLastSnapshot(new_lss)
except Exception as e:
Log.error('Failed to copy solverstate to new session: '+str(e),
self.getCallerId())
# copy caffemodel
model_file = loader.getCaffemodelFromSolverstate(old_lss)
iterations = loader.getIterFromSolverstate(old_lss)
if model_file is None:
model_file = loader.getCaffemodelFromSolverstateHdf5(old_lss)
if model_file is None:
Log.error('Could not load model from solverstate '+old_lss,
self.getCallerId())
self.__sessions[sessionID].delete()
return None
model_dir = os.path.dirname(old_lss)
old_lcm = os.path.join(model_dir, model_file)
new_caffemodel = os.path.basename(model_file)
new_lcm = os.path.join(new_sdir, new_caffemodel)
self.__ensureDirectory(newSnapshotDir)
if os.path.isdir(newSnapshotDir):
newCaffemodel = 'pretrained.caffemodel'
newModelPath = os.path.join(newSnapshotDir, newCaffemodel)
try:
shutil.copy2(old_lcm, new_lcm)
self.__sessions[sessionID].setLastModel(new_lcm)
self.__sessions[sessionID].iteration = iterations
if old_session is not None:
self.__sessions[sessionID].max_iter = old_session.max_iter
self.__sessions[sessionID].caffe_root = old_session.caffe_root
# copy the old caffemodel to the new location
shutil.copy2(oldModelPath, newModelPath)
# initialize new session
self.__sessions[sessionID].setPretrainedWeights(newCaffemodel)
self.__sessions[sessionID].iteration = 0
self.__sessions[sessionID].max_iter = oldSession.max_iter
self.__sessions[sessionID].caffe_root = oldSession.caffe_root
except Exception as e:
Log.error('Failed to copy caffemodel to new session: '+str(e),
self.getCallerId())
# copy state dictionary
if old_session is not None:
#TODO: This will cause troubles if the current session was changed during training (See #270)
self.__sessions[sessionID].setStateDict(old_session.state_dictionary)
if self.__sessions[sessionID].iteration > 0:
self.__sessions[sessionID].setState(State.PAUSED)
else:
self.__sessions[sessionID].setState(State.WAITING)
# copy the old state-dict into the new session
self.__sessions[sessionID].setStateDict(oldSession.state_dictionary)
self.__sessions[sessionID].setState(State.WAITING)
self.__sessions[sessionID].save(includeProtoTxt=True)
return sessionID
else:
self.__sessions[sessionID].delete()
Log.error('Snapshot directory '+new_sdir+' does not exist!',
Log.error('Snapshot directory '+newSnapshotDir+' does not exist!',
self.getCallerId())
return None
......@@ -612,6 +641,13 @@ class Project(QObject, LogCaller):
self.__sessions.pop(sid)
del session
def closeSession(self, remoteSession):
    """ Close a remote session and remove it from the available sessions"""
    sid = remoteSession.getSessionId()
    # Disconnect from the remote host before dropping our bookkeeping entry.
    remoteSession.close()
    self.__sessions.pop(sid)
    # NOTE(review): 'del' only removes the local parameter binding; the
    # object itself is freed once no other references remain.
    del remoteSession
def isSession(self, directory):
""" Checks if the directory contains a valid session."""
if not os.path.exists(directory):
......
......@@ -211,8 +211,10 @@ class ClientSession(QObject):
return 0
def getMaxIteration(self):
# maybe let this be local + signal
if self._assertConnection():
if self.lastMaxIter > 0 \
and self.getState() in [State.RUNNING, State.FAILED, State.FINISHED, State.NOTCONNECTED]:
return self.lastMaxIter
elif self._assertConnection():
msg = {"key": Protocol.SESSION, "subkey": SessionProtocol.GETMAXITERATION}
self.transaction.send(msg)
ret = self.transaction.asyncRead(staging=True, attr=("subkey", SessionProtocol.GETMAXITERATION))
......@@ -220,7 +222,28 @@ class ClientSession(QObject):
self.lastMaxIter = ret["iteration"]
return self.lastMaxIter
self._handleErrors(["Failed to connect to remote session to acquire max iteration."])
return 0
return 1
def getPretrainedWeights(self):
    """Query the remote session for its pre-trained weights file.

    Returns the value reported by the host, or None (after reporting an
    error) when the connection or the reply fails.
    """
    if self._assertConnection():
        request = {"key": Protocol.SESSION, "subkey": SessionProtocol.GETPRETRAINED}
        self.transaction.send(request)
        reply = self.transaction.asyncRead(staging=True,
                                           attr=("subkey", SessionProtocol.GETPRETRAINED))
        if reply:
            return reply.get("pretrained", None)
    self._handleErrors(["Failed to connect to remote session to acquire pre-trained weights."])
    return None
def setMaxIteration(self, maxIteration):
    """Push a new maximum training iteration to the remote session.

    Non-positive values are ignored. On a connection failure the session
    is flagged as NOTCONNECTED.
    """
    if maxIteration <= 0:
        return
    if not self._assertConnection():
        self._handleErrors(["Failed to connect to remote session to update max iteration."])
        self.setState(State.NOTCONNECTED, True)
        return
    self.transaction.send({
        "key": Protocol.SESSION,
        "subkey": SessionProtocol.SETMAXITERATION,
        "iteration": maxIteration,
    })
    # Mirror the value locally so getMaxIteration can answer without a round trip.
    self.lastMaxIter = maxIteration
def delete(self):
if self._assertConnection():
......@@ -234,6 +257,11 @@ class ClientSession(QObject):
self._handleErrors(["Failed to connect to remote session to acquire current state."])
self.setState(State.NOTCONNECTED, True)
def close(self):
    """Ask the host to disconnect this session; silently a no-op when no
    connection can be established."""
    if self._assertConnection():
        self.transaction.send({"key": Protocol.DISCONNECTSESSION, "uid": self.uid})
def reset(self):
if self._assertConnection():
msg = {"key": Protocol.SESSION, "subkey": SessionProtocol.RESET}
......@@ -298,6 +326,7 @@ class ClientSession(QObject):
self._handleErrors(["Could not save session."])
return False
def setStateDict(self, stateDict):
    """Cache the given state dictionary locally on the client."""
    self.lastStateDict = stateDict
......@@ -393,10 +422,10 @@ class ClientSession(QObject):
return [(msg["error"], "")]
else:
self._handleErrors(["CheckTraining: Did not receive a reply from host."])
return [("CheckTraining: Did not receive a reply from host.", "No reply")]
return [("CheckTraining: Did not receive a reply from host.", "No reply from host")]
else:
self._handleErrors(["CheckTraining: Failed to connect to host to check session validity!"])
return [("CheckTraining: Failed to connect to host to check session validity!", "Connection failed")]
return [("CheckTraining: Failed to connect to host to check session validity!", "Failed to connect")]
def start(self, solverstate=None, caffemodel=None):
if self._assertConnection():
......
......@@ -52,7 +52,9 @@ class ServerSession(QObject, ParserListener, SessionCommon):
SessionProtocol.GETSNAPSHOTS: self._msgGetSnapshots,
SessionProtocol.GETITERATION: self._msgGetIteration,
SessionProtocol.GETMAXITERATION: self._msgGetMaxIteration,
SessionProtocol.SETMAXITERATION: self._msgSetMaxIteration,
SessionProtocol.GETSTATE: self._msgGetState,
SessionProtocol.GETPRETRAINED: self._msgGetPretrainedWeights,
SessionProtocol.SETSTATEDICT: self._msgSetStateDict,
SessionProtocol.GETSTATEDICT: self._msgGetStateDict,
SessionProtocol.SAVE: self._msgSave,
......@@ -90,6 +92,7 @@ class ServerSession(QObject, ParserListener, SessionCommon):
self.state = State.WAITING
self.invalidErrorsList = []
self.last_solverstate = None
self.pretrainedWeights = None
self.logs = os.path.join(directory, 'logs')
# run stuff
......@@ -334,18 +337,36 @@ class ServerSession(QObject, ParserListener, SessionCommon):
msg["iteration"] = self.getMaxIteration()
self.transaction.send(msg)
def _msgSetMaxIteration(self):
    """Handle a SETMAXITERATION message: update the session's max iteration.

    Ignores non-positive values, mirroring the guard in
    Session.setMaxIteration and in the client-side sender — the server
    should not trust the wire value blindly.
    """
    msg = self.transaction.asyncRead()
    if msg["iteration"] > 0:
        self.max_iter = msg["iteration"]
def _msgGetState(self):
    """Handle a GETSTATE request: reply with the current session state."""
    msg = self.transaction.asyncRead()
    msg["state"] = self._getState()
    msg["status"] = True
    self.transaction.send(msg)
def _msgGetPretrainedWeights(self):
    """Handle a GETPRETRAINED request: reply with the pre-trained weights
    file name (or None if none was set)."""
    msg = self.transaction.asyncRead()
    msg["pretrained"] = self.getPretrainedWeights()
    self.transaction.send(msg)
def getLastModel(self):
    """Return the file name of the last saved .caffemodel snapshot."""
    return self.last_caffemodel
def getIteration(self):
    """Return the current training iteration."""
    return self.iteration
def getMaxIteration(self):
    """Return the maximum training iteration."""
    return self.max_iter
def getPretrainedWeights(self):
    """Return the pre-trained weights file name, or None if not set."""
    return self.pretrainedWeights
def setPretrainedWeights(self, weights):
    """Set the pre-trained weights file name for this session."""
    self.pretrainedWeights = weights
def _getState(self):
""" Return the state of the session.
"""
......@@ -370,7 +391,13 @@ class ServerSession(QObject, ParserListener, SessionCommon):
def _msgSetStateDict(self):
    """Handle a SETSTATEDICT message: install the received state dict and
    acknowledge with status True (the bulky dict is stripped from the reply)."""
    msg = self.transaction.asyncRead()
    # NOTE(review): this direct assignment looks redundant — setStateDict
    # below assigns self.state_dictionary again; likely leftover from a merge.
    self.state_dictionary = msg["statedict"]
    self.setStateDict(msg["statedict"])
    del msg["statedict"]
    msg["status"] = True
    self.transaction.send(msg)
def setStateDict(self, statedict):
self.state_dictionary = statedict
self._parseSetting(self.state_dictionary)
# restore lost types
if hasattr(self.state_dictionary, '__getitem__'):
......@@ -383,10 +410,6 @@ class ServerSession(QObject, ParserListener, SessionCommon):
typename = layers[id]["parameters"]["type"]
layers[id]["type"] = info.CaffeMetaInformation().availableLayerTypes()[typename]
del msg["statedict"]
msg["status"] = True
self.transaction.send(msg)
def _msgGetStateDict(self):
msg = self.transaction.asyncRead()
......@@ -413,6 +436,8 @@ class ServerSession(QObject, ParserListener, SessionCommon):
toSave["ProjectID"] = self.pid
if self.last_solverstate:
toSave["LastSnapshot"] = self.last_solverstate
if self.pretrainedWeights:
toSave["PretrainedWeights"] = self.pretrainedWeights
if self.state_dictionary:
serializedDict = copy.deepcopy(self.state_dictionary)
if "network" in serializedDict:
......@@ -541,6 +566,8 @@ class ServerSession(QObject, ParserListener, SessionCommon):
if "LastSnapshot" in settings:
self.last_solverstate = settings["LastSnapshot"]
if "PretrainedWeights" in settings:
self.pretrainedWeights = settings["PretrainedWeights"]
if "NetworkState" in settings:
self.state_dictionary = settings["NetworkState"]
layers = self.state_dictionary["network"]["layers"]
......@@ -1058,10 +1085,11 @@ class ServerSession(QObject, ParserListener, SessionCommon):
self.logsig.emit('Failed to delete logs folder: ' + str(e), self.getCallerId(), True)
for filename in filenames:
if filename.endswith(".solverstate") or filename.endswith(".caffemodel"):
try:
os.remove(os.path.join(dirpath, filename))
except OSError as e:
self.logsig.emit('Failed to delete ' + str(filename) + ': ' + str(e), self.getCallerId(), True)
if not filename == self.getPretrainedWeights():
try:
os.remove(os.path.join(dirpath, filename))
except OSError as e:
self.logsig.emit('Failed to delete ' + str(filename) + ': ' + str(e), self.getCallerId(), True)
if filename in ["net-internal.prototxt", "net-original.prototxt", "solver.prototxt"]:
try:
os.remove(os.path.join(dirpath, filename))
......
......@@ -71,6 +71,7 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
self.parse_old = parse_old
self.caffe_root = caffe_root # overrides project caffe_root if necessary, i.e. if deployed to another system
self.pretrainedWeights = None
self.last_solverstate = last_solverstate
self.last_caffemodel = last_caffemodel
self.state_dictionary = state_dictionary # state as saved from the network manager, such it can be restored
......@@ -546,6 +547,12 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
def setLastModel(self, lcm):
    """Remember the path of the most recent .caffemodel snapshot."""
    self.last_caffemodel = lcm
def getPretrainedWeights(self):
    """Return the pre-trained weights file name, or None if not set."""
    return self.pretrainedWeights
def setPretrainedWeights(self, weights):
    """Set the pre-trained weights file name for this session."""
    self.pretrainedWeights = weights
def setLastSnapshot(self, lss):
    """Remember the path of the most recent .solverstate snapshot."""
    self.last_solverstate = lss
......@@ -687,6 +694,13 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
"""
return self.max_iter
def setMaxIteration(self, maxIteration):
    """ Set the maximum training iteration of this session.

    Non-positive values are ignored.
    """
    if maxIteration > 0:
        self.max_iter = maxIteration
        # NOTE(review): emits stateChanged (with the unchanged state) so
        # listeners re-read max_iter — confirm observers tolerate this.
        self.stateChanged.emit(self.getState())
def setParserInitialized(self):
""" Should be called after the parser finished the inital parsing of
log files.
......@@ -783,6 +797,8 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
Log.log("Saving current Session status to disk.", self.getCallerId())
if self.last_solverstate:
toSave["LastSnapshot"] = self.last_solverstate
if self.getPretrainedWeights():
toSave["PretrainedWeights"] = self.getPretrainedWeights()
if self.state_dictionary:
serializedDict = copy.deepcopy(self.state_dictionary)
if includeProtoTxt:
......@@ -913,6 +929,9 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
if "LastSnapshot" in settings:
self.last_solverstate = settings["LastSnapshot"]
if "PretrainedWeights" in settings:
self.setPretrainedWeights(settings["PretrainedWeights"])
if "NetworkState" in settings:
self.state_dictionary = settings["NetworkState"]
layers = self.state_dictionary["network"]["layers"]
......@@ -920,12 +939,20 @@ class Session(QObject, LogCaller, ParserListener, SessionCommon):
if "parameters" in layers[id]:
if "type" in layers[id]["parameters"]:
typename = layers[id]["parameters"]["type"]