diff --git a/Data_process2.py b/Data_process2.py new file mode 100644 index 0000000..9ef932d --- /dev/null +++ b/Data_process2.py @@ -0,0 +1,135 @@ +import scipy +import gzip,cPickle +import correlation +import os,pdb,glob +import theano +import skimage +import sklearn +import PIL.Image +import pylab + +import scipy.ndimage as ndi +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.cm as cm +import matplotlib.font_manager as font_manager + +from sklearn import preprocessing, cross_validation +from sklearn.feature_extraction import image +from sklearn.decomposition import PCA +from scipy import io as sp_io +from numpy.random import RandomState +from theano import tensor as T +from scipy import misc +from PIL import Image, ImageEnhance +from pylab import * +from theano.tensor.nnet import conv +from scipy.misc import lena +from itertools import product + +################################################# +# Generating fully overlapping patches # +################################################# + +def overlapping_patches(path, patch_size): + + Ols_images = Image.open (path).convert('L') + height, width = np.asarray(Ols_images).shape + + if height < 512: + Ols_images = Ols_images.resize((512, 512), Image.BICUBIC) + print '... image resized to 512,512' + elif height > 512: + Ols_images = Ols_images.resize((512, 512), Image.ANTIALIAS) + print '... image resized to 512,512' + + Ols_images = np.asarray(Ols_images,dtype = 'float') + Ols_images = correlation.normalizeArray(np.asarray(Ols_images)) + + image_height = np.asarray(Ols_images).shape[0] + image_width = np.asarray(Ols_images).shape[1] + + Ols_patche = image.extract_patches_2d(image=Ols_images, patch_size=patch_size,max_patches=None) + Ols_patches = np.reshape(Ols_patche,(Ols_patche.shape[0], -1)) + n_patches, nvis = Ols_patches.shape + rval = (Ols_patches, image_height, image_width) + return rval + +################################################# +# Generating overlapping patches with strides # +################################################# + +def overlapping_patches_strides(path, patch_size, strides): + + Ols_images = Image.open (path).convert('L') + height, width = np.asarray(Ols_images).shape + + '''if height < 512: + Ols_images = Ols_images.resize((512, 512), Image.BICUBIC) + print '... image resized to 512,512' + elif height > 512: + Ols_images = Ols_images.resize((512, 512), Image.ANTIALIAS) + print '... image resized to 512,512' + ''' + + #nrow = 512 + #ncol = 512 #767 #768 for 28x281 767 for 17x17 + + # ROC Pic + nrow = height*1 + ncol = width*1 + + #print ' ... Initial image dimensions: ', nrow, ncol + + Up = (nrow-patch_size[0]-strides[0])/strides[0] + Vp = (ncol-patch_size[1]-strides[1])/strides[1] + + #print ' ... Initial patches: ', '%.2f'%(Up), '%.2f'%(Vp) + + Up = np.floor((nrow-patch_size[0]-strides[0])/strides[0]) + Vp = np.floor((ncol-patch_size[1]-strides[1])/strides[1]) + + #print ' ... Generated patches: ', '%.2f'%(Up), '%.2f'%(Vp) + + nrow = np.int(Up*strides[0] + strides[0] + patch_size[0]) + ncol = np.int(Vp*strides[1] + strides[1] + patch_size[1]) + + #print ' ... 
Resized image dimensions: ', nrow, ncol
+
+    Ols_images = Ols_images.resize((ncol, nrow), Image.BICUBIC)
+
+    Ols_images = np.asarray(Ols_images, dtype='float')
+    Ols_images = correlation.normalizeArray(np.asarray(Ols_images))
+    U = (nrow-patch_size[0]-strides[0])/strides[0]
+    V = (ncol-patch_size[1]-strides[1])/strides[1]
+
+    image_height = np.asarray(Ols_images).shape[0]
+    image_width = np.asarray(Ols_images).shape[1]
+
+    Ols_patche = image.extract_patches(Ols_images, patch_shape=patch_size, extraction_step=strides)
+    Ols_patches = np.reshape(Ols_patche, (Ols_patche.shape[0]*Ols_patche.shape[1], -1))
+
+    n_patches, nvis = Ols_patches.shape
+    rval = (Ols_patches, image_height, image_width)
+    return rval
+
+#################################################
+#    Reconstructing patches with strides        #
+#################################################
+
+def reconstruct_from_patches_with_strides_2d(patches, image_size, strides):
+
+    i_stride = strides[0]
+    j_stride = strides[1]
+    i_h, i_w = image_size[:2]
+    p_h, p_w = patches.shape[1:3]
+    img = np.zeros(image_size)
+    img1 = np.zeros(image_size)
+    n_h = int((i_h - p_h + i_stride)/i_stride)
+    n_w = int((i_w - p_w + j_stride)/j_stride)
+    for p, (i, j) in zip(patches, product(range(n_h), range(n_w))):
+        img[i*i_stride:i*i_stride + p_h, j*j_stride:j*j_stride + p_w] += p
+        img1[i*i_stride:i*i_stride + p_h, j*j_stride:j*j_stride + p_w] += np.ones(p.shape)
+    return img/img1
+
diff --git a/README.md b/README.md
index c637a8d..6b4f3cf 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,53 @@
-# LLNet
-A low light image enhancement with deep learning
+# LLNet: Low-light Image Enhancement with Deep Learning #
+
+This repository includes the code and modules used to run LLNet via a graphical user interface. Users can either train the network from scratch or enhance multiple images with a specific trained model.
+
+NOTE: A trained 17x17 model is included in this repository.
+
+## How do I run the program? ##
+
+Open a terminal and navigate to this directory. Type:
+
+```
+#!bash
+python llnet.py
+```
+
+to launch the program with the GUI. For the command-line-only interface, use the following commands in the terminal.
+
+To train a new model, enter:
+
+```
+#!bash
+python llnet.py train [TRAINING_DATA]
+```
+
+To enhance an image, enter:
+
+```
+#!bash
+python llnet.py test [IMAGE_FILENAME] [MODEL_FILENAME]
+```
+
+For example, you may type:
+
+```
+#!bash
+python llnet.py train datafolder/yourdataset.mat
+python llnet.py test somefolder/darkpicture.png models/save_model.obj
+```
+
+File names do not need to be quoted.
+
+Datasets need to be saved as a .MAT file with the '-v7.3' flag in MATLAB. The saved variables are:
+
+```
+train_set_x    (N x wh)    Noisy, darkened training data
+train_set_y    (N x wh)    Clean, bright training data
+valid_set_x    (N x wh)    Noisy, darkened validation data
+valid_set_y    (N x wh)    Clean, bright validation data
+test_set_x     (N x wh)    Noisy, darkened test data
+test_set_y     (N x wh)    Clean, bright test data
+```
+
+where N is the number of examples and w, h are the width and height of the patches, respectively. The test set is mainly used for plotting sample test patches; to enhance a single image in an actual application, use the test command instead.
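For reference, the patch helpers added in `Data_process2.py` are intended for an extract/enhance/reconstruct workflow: overlapping patches are pulled from a dark image, passed through the trained network, and averaged back into a full-size output. The sketch below is illustrative only (the stride, the file path, and the pass-through "enhancement" step are placeholders, and the exact calls made by `llnet.py` may differ):

```
#!python
import Data_process2

patch_size = (17, 17)   # matches the bundled 17x17 model
strides = (3, 3)        # illustrative stride

# Extract flattened overlapping patches from the (grayscale, normalized) image.
patches, img_h, img_w = Data_process2.overlapping_patches_strides(
    'somefolder/darkpicture.png', patch_size, strides)

# The trained network would enhance the (n_patches, 17*17) matrix here;
# this sketch simply passes the patches through unchanged.
enhanced = patches

# Reshape back to (n_patches, 17, 17) and average the overlapping regions
# into the reconstructed image.
enhanced = enhanced.reshape((-1, patch_size[0], patch_size[1]))
result = Data_process2.reconstruct_from_patches_with_strides_2d(
    enhanced, (img_h, img_w), strides)
```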
diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..c9990ce --- /dev/null +++ b/__init__.py @@ -0,0 +1,3 @@ +from easygui import * +pass +__all__ = easygui.__all__ \ No newline at end of file diff --git a/correlation.py b/correlation.py new file mode 100644 index 0000000..dccec30 --- /dev/null +++ b/correlation.py @@ -0,0 +1,154 @@ +""" +correlation.py +Compute the correlation between two, single-channel, grayscale input images. +The second image must be smaller than the first. + +Author: Brad Montgomery + http://bradmontgomery.net + +This code has been placed in the public domain by the author. + +USAGE: python correlation + +""" +#import Image +import numpy +import math +import sys +import timeit + +def normalizeArray(a): + """ + Normalize the given array to values between 0 and 1. + Return a numpy array of floats (of the same shape as given) + """ + w,h = a.shape + minval = a.min() + if minval < 0: # shift to positive... + a = a + abs(minval) + maxval = a.max() # THEN, get max value! + new_a = numpy.zeros(a.shape, 'd') + for x in range(0,w): + for y in range(0,h): + new_a[x,y] = float(a[x,y])/maxval + return new_a + +def pil2array(im): + """ Convert a 1-channel grayscale PIL image to a numpy ndarray """ + data = list(im.getdata()) + w,h = im.size + A = numpy.zeros((w*h), 'd') + i=0 + for val in data: + A[i] = val + i=i+1 + A=A.reshape(w,h) + return A + +def array2pil(A,mode='L'): + """ + Convert a numpy ndarray to a PIL image. + Only grayscale images (PIL mode 'L') are supported. + """ + w,h = A.shape + # make sure the array only contains values from 0-255 + # if not... fix them. + if A.max() > 255 or A.min() < 0: + A = normalizeArray(A) # normalize between 0-1 + A = A * 255 # shift values to range 0-255 + if A.min() >= 0.0 and A.max() <= 1.0: # values are already between 0-1 + A = A * 255 # shift values to range 0-255 + A = A.flatten() + data = [] + for val in A: + if val is numpy.nan: val = 0 + data.append(int(val)) # make sure they're all int's + im = Image.new(mode, (w,h)) + im.putdata(data) + return im + +def correlation(input, match): + """ + Calculate the correlation coefficients between the given pixel arrays. + + input - an input (numpy) matrix representing an image + match - the (numpy) matrix representing the image for which we are looking + + """ + t = timeit.Timer() + assert match.shape < input.shape, "Match Template must be Smaller than the input" + c = numpy.zeros(input.shape) # store the coefficients... + mfmean = match.mean() + iw, ih = input.shape # get input image width and height + mw, mh = match.shape # get match image width and height + + print "Computing Correleation Coefficients..." + start_time = t.timer() + + for i in range(0, iw): + for j in range(0, ih): + + # find the left, right, top + # and bottom of the sub-image + if i-mw/2 <= 0: + left = 0 + elif iw - i < mw: + left = iw - mw + else: + left = i + + right = left + mw + + if j - mh/2 <= 0: + top = 0 + elif ih - j < mh: + top = ih - mh + else: + top = j + + bottom = top + mh + + # take a slice of the input image as a sub image + sub = input[left:right, top:bottom] + assert sub.shape == match.shape, "SubImages must be same size!" 
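+            # The statements below compute the normalized cross-correlation
+            # between this sub-image and the match template,
+            #     r = s1 / sqrt(s2 * s3)
+            # where s1 = sum((sub - mean(sub)) * (match - mean(match))),
+            #       s2 = sum((sub - mean(sub))**2) and
+            #       s3 = sum((match - mean(match))**2),
+            # so each coefficient c[i, j] lies in [-1, 1].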
+ localmean = sub.mean() + temp = (sub - localmean) * (match - mfmean) + s1 = temp.sum() + temp = (sub - localmean) * (sub - localmean) + s2 = temp.sum() + temp = (match - mfmean) * (match - mfmean) + s3 = temp.sum() + denom = s2*s3 + if denom == 0: + temp = 0 + else: + temp = s1 / math.sqrt(denom) + + c[i,j] = temp + + end_time = t.timer() + print "=> Correlation computed in: ", end_time - start_time + print '\tMax: ', c.max() + print '\tMin: ', c.min() + print '\tMean: ', c.mean() + return c + +def main(f1, f2, output_file="CORRELATION.jpg"): + """ open the image files, and compute their correlation """ + im1 = f1.convert('L') + im2 = f2.convert('L') + # Better way to do PIL-Numpy conversion + f = numpy.asarray(im1) # was f = pil2array(im1) + w = numpy.asarray(im2) # was w = pil2array(im2) + corr = correlation(f,w) # was c = array2pil(correlation(f,w)) + c = Image.fromarray(numpy.uint8(normalizeArray(corr) * 255)) + + print "Saving as: %s" % output_file + c.save(output_file) + +if __name__ == "__main__": + if len(sys.argv) == 3: + main(sys.argv[1], sys.argv[2]) + else: + print 'USAGE: python correlation ' + diff --git a/dA.py b/dA.py new file mode 100644 index 0000000..f1d62e8 --- /dev/null +++ b/dA.py @@ -0,0 +1,435 @@ +""" + This tutorial introduces denoising auto-encoders (dA) using Theano. + + Denoising autoencoders are the building blocks for SdA. + They are based on auto-encoders as the ones used in Bengio et al. 2007. + An autoencoder takes an input x and first maps it to a hidden representation + y = f_{\theta}(x) = s(Wx+b), parameterized by \theta={W,b}. The resulting + latent representation y is then mapped back to a "reconstructed" vector + z \in [0,1]^d in input space z = g_{\theta'}(y) = s(W'y + b'). The weight + matrix W' can optionally be constrained such that W' = W^T, in which case + the autoencoder is said to have tied weights. The network is trained such + that to minimize the reconstruction error (the error between x and z). + + For the denosing autoencoder, during training, first x is corrupted into + \tilde{x}, where \tilde{x} is a partially destroyed version of x by means + of a stochastic mapping. Afterwards y is computed as before (using + \tilde{x}), y = s(W\tilde{x} + b) and z as s(W'y + b'). The reconstruction + error is now measured between z and the uncorrupted input x, which is + computed as the cross-entropy : + - \sum_{k=1}^d[ x_k \log z_k + (1-x_k) \log( 1-z_k)] + + + References : + - P. Vincent, H. Larochelle, Y. Bengio, P.A. Manzagol: Extracting and + Composing Robust Features with Denoising Autoencoders, ICML'08, 1096-1103, + 2008 + - Y. Bengio, P. Lamblin, D. Popovici, H. Larochelle: Greedy Layer-Wise + Training of Deep Networks, Advances in Neural Information Processing + Systems 19, 2007 + +""" + +import os +import sys +import time +import numpy + +import theano +import theano.tensor as T +import theano.sandbox.linalg as TL +from theano.tensor.shared_randomstreams import RandomStreams + +from logistic_sgd import load_data +from utils import tile_raster_images + +try: + import PIL.Image as Image +except ImportError: + import Image + + +# start-snippet-1 +class dA(object): + """Denoising Auto-Encoder class (dA) + + A denoising autoencoders tries to reconstruct the input from a corrupted + version of it by projecting it first in a latent space and reprojecting + it afterwards back in the input space. Please refer to Vincent et al.,2008 + for more details. 
If x is the input then equation (1) computes a partially + destroyed version of x by means of a stochastic mapping q_D. Equation (2) + computes the projection of the input into the latent space. Equation (3) + computes the reconstruction of the input, while equation (4) computes the + reconstruction error. + + .. math:: + + \tilde{x} ~ q_D(\tilde{x}|x) (1) + + y = s(W \tilde{x} + b) (2) + + x = s(W' y + b') (3) + + L(x,z) = -sum_{k=1}^d [x_k \log z_k + (1-x_k) \log( 1-z_k)] (4) + + """ + + def __init__( + self, + numpy_rng, + theano_rng=None, + input=None, + n_visible=784, + n_hidden=500, + W=None, + bhid=None, + bvis=None + ): + """ + Initialize the dA class by specifying the number of visible units (the + dimension d of the input ), the number of hidden units ( the dimension + d' of the latent or hidden space ) and the corruption level. The + constructor also receives symbolic variables for the input, weights and + bias. Such a symbolic variables are useful when, for example the input + is the result of some computations, or when weights are shared between + the dA and an MLP layer. When dealing with SdAs this always happens, + the dA on layer 2 gets as input the output of the dA on layer 1, + and the weights of the dA are used in the second stage of training + to construct an MLP. + + :type numpy_rng: numpy.random.RandomState + :param numpy_rng: number random generator used to generate weights + + :type theano_rng: theano.tensor.shared_randomstreams.RandomStreams + :param theano_rng: Theano random generator; if None is given one is + generated based on a seed drawn from `rng` + + :type input: theano.tensor.TensorType + :param input: a symbolic description of the input or None for + standalone dA + + :type n_visible: int + :param n_visible: number of visible units + + :type n_hidden: int + :param n_hidden: number of hidden units + + :type W: theano.tensor.TensorType + :param W: Theano variable pointing to a set of weights that should be + shared belong the dA and another architecture; if dA should + be standalone set this to None + + :type bhid: theano.tensor.TensorType + :param bhid: Theano variable pointing to a set of biases values (for + hidden units) that should be shared belong dA and another + architecture; if dA should be standalone set this to None + + :type bvis: theano.tensor.TensorType + :param bvis: Theano variable pointing to a set of biases values (for + visible units) that should be shared belong dA and another + architecture; if dA should be standalone set this to None + + + """ + self.n_visible = n_visible + self.n_hidden = n_hidden + + # create a Theano random generator that gives symbolic random values + if not theano_rng: + theano_rng = RandomStreams(numpy_rng.randint(2 ** 30)) + + # note : W' was written as `W_prime` and b' as `b_prime` + if not W: + # W is initialized with `initial_W` which is uniformely sampled + # from -4*sqrt(6./(n_visible+n_hidden)) and + # 4*sqrt(6./(n_hidden+n_visible))the output of uniform if + # converted using asarray to dtype + # theano.config.floatX so that the code is runable on GPU + initial_W = numpy.asarray( + numpy_rng.uniform( + low=-4 * numpy.sqrt(6. / (n_hidden + n_visible)), + high=4 * numpy.sqrt(6. 
/ (n_hidden + n_visible)), + size=(n_visible, n_hidden) + ), + dtype=theano.config.floatX + ) + W = theano.shared(value=initial_W, name='W', borrow=True) + + if not bvis: + bvis = theano.shared( + value=numpy.zeros( + n_visible, + dtype=theano.config.floatX + ), + borrow=True + ) + + if not bhid: + bhid = theano.shared( + value=numpy.zeros( + n_hidden, + dtype=theano.config.floatX + ), + name='b', + borrow=True + ) + + self.W = W + # b corresponds to the bias of the hidden + self.b = bhid + # b_prime corresponds to the bias of the visible + self.b_prime = bvis + # tied weights, therefore W_prime is W transpose + self.W_prime = self.W.T + self.theano_rng = theano_rng + # if no input is given, generate a variable representing the input + if input is None: + # we use a matrix because we expect a minibatch of several + # examples, each example being a row + self.x = T.dmatrix(name='input') + else: + self.x = input + + self.params = [self.W, self.b, self.b_prime] + # end-snippet-1 + + def get_corrupted_input(self, input, corruption_level): + """This function keeps ``1-corruption_level`` entries of the inputs the + same and zero-out randomly selected subset of size ``coruption_level`` + Note : first argument of theano.rng.binomial is the shape(size) of + random numbers that it should produce + second argument is the number of trials + third argument is the probability of success of any trial + + this will produce an array of 0s and 1s where 1 has a + probability of 1 - ``corruption_level`` and 0 with + ``corruption_level`` + + The binomial function return int64 data type by + default. int64 multiplicated by the input + type(floatX) always return float64. To keep all data + in floatX when floatX is float32, we set the dtype of + the binomial to floatX. As in our case the value of + the binomial is always 0 or 1, this don't change the + result. This is needed to allow the gpu to work + correctly as it only support float32 for now. 
+ + """ + # Corruption Noise + #noise = self.theano_rng.binomial(size=input.shape, n=1, + # p=1 - corruption_level, + # dtype=theano.config.floatX) * input + + #noise = self.theano_rng.normal(size=input.shape)*(corruption_level) + input + noise = self.theano_rng.normal(avg=0.0,std=corruption_level,size=input.shape) + input + + return T.clip(noise,0,1) + + def get_hidden_values(self, input): + """ Computes the values of the hidden layer """ + return T.nnet.sigmoid(T.dot(input, self.W) + self.b) + + def get_reconstructed_input(self, hidden): + """Computes the reconstructed input given the values of the + hidden layer + + """ + return T.nnet.sigmoid(T.dot(hidden, self.W_prime) + self.b_prime) + + def get_cost_updates(self, corruption_level, learning_rate): + """ This function computes the cost and the updates for one trainng + step of the dA """ + + tilde_x = self.get_corrupted_input(self.x, corruption_level) + y = self.get_hidden_values(tilde_x) + z = self.get_reconstructed_input(y) + + # note : we sum over the size of a datapoint; if we are using + # minibatches, L will be a vector, with one entry per + # example in minibatch + #L = - T.sum(self.x * T.log(z) + (1 - self.x) * T.log(1 - z), axis=1) + + # Implementation of cost function from the paper + lambda_reg = 0.00001 + beta = 0.01 + rho = 0.2 + + # ---- Error Term + l2norm = T.sqrt(((self.x-z)**2).sum(axis=0,keepdims=False))**2 + errorterm = T.mean(l2norm) + + # ---- KL Divergence Term + rho_j = T.mean(y,axis=0,keepdims=False) #Mean activation of hidden units based on hidden layer, results in 1 x HU matrix/vector + kl = rho*T.log(rho/rho_j) + (1-rho)*T.log((1-rho)/(1-rho_j)) + kl = T.sum(kl) + #T.sum((rho_expression),keepdims=False) + + # ---- Regularization Term + regterm = (T.sqrt((self.W ** 2).sum())**2) + (T.sqrt((self.W_prime ** 2).sum())**2) + + # ---- Final Loss Function + cost = errorterm + beta*kl + lambda_reg/2 * regterm + + # compute the gradients of the cost of the `dA` with respect + # to its parameters + gparams = T.grad(cost, self.params) + # generate the list of updates + updates = [ + (param, param - learning_rate * gparam) + for param, gparam in zip(self.params, gparams) + ] + + return (cost, updates) + + +def test_dA(learning_rate=0.1, training_epochs=15, + dataset='mnist.pkl.gz', + batch_size=20, output_folder='dA_plots'): + + """ + This demo is tested on MNIST + + :type learning_rate: float + :param learning_rate: learning rate used for training the DeNosing + AutoEncoder + + :type training_epochs: int + :param training_epochs: number of epochs used for training + + :type dataset: string + :param dataset: path to the picked dataset + + """ + datasets = load_data(dataset) + train_set_x, train_set_y = datasets[0] + + # compute number of minibatches for training, validation and testing + n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size + + # allocate symbolic variables for the data + index = T.lscalar() # index to a [mini]batch + x = T.matrix('x') # the data is presented as rasterized images + + if not os.path.isdir(output_folder): + os.makedirs(output_folder) + os.chdir(output_folder) + #################################### + # BUILDING THE MODEL NO CORRUPTION # + #################################### + + rng = numpy.random.RandomState(123) + theano_rng = RandomStreams(rng.randint(2 ** 30)) + + da = dA( + numpy_rng=rng, + theano_rng=theano_rng, + input=x, + n_visible=28 * 28, + n_hidden=500 + ) + + cost, updates = da.get_cost_updates( + corruption_level=0., + learning_rate=learning_rate + ) + + 
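+    # Compile a Theano function for one training step: calling train_da(index)
+    # evaluates `cost` on minibatch `index` and applies `updates` (one gradient
+    # descent step), with `givens` substituting the corresponding slice of the
+    # shared training matrix for the symbolic input x.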
train_da = theano.function( + [index], + cost, + updates=updates, + givens={ + x: train_set_x[index * batch_size: (index + 1) * batch_size] + } + ) + + start_time = time.clock() + + ############ + # TRAINING # + ############ + + # go through training epochs + for epoch in xrange(training_epochs): + # go through trainng set + c = [] + for batch_index in xrange(n_train_batches): + c.append(train_da(batch_index)) + + print 'Training epoch %d, cost ' % epoch, numpy.mean(c) + + end_time = time.clock() + + training_time = (end_time - start_time) + + print >> sys.stderr, ('The no corruption code for file ' + + os.path.split(__file__)[1] + + ' ran for %.2fm' % ((training_time) / 60.)) + image = Image.fromarray( + tile_raster_images(X=da.W.get_value(borrow=True).T, + img_shape=(28, 28), tile_shape=(10, 10), + tile_spacing=(1, 1))) + image.save('filters_corruption_0.png') + + ##################################### + # BUILDING THE MODEL CORRUPTION 30% # + ##################################### + + rng = numpy.random.RandomState(123) + theano_rng = RandomStreams(rng.randint(2 ** 30)) + + da = dA( + numpy_rng=rng, + theano_rng=theano_rng, + input=x, + n_visible=28 * 28, + n_hidden=500 + ) + + cost, updates = da.get_cost_updates( + corruption_level=0.3, + learning_rate=learning_rate + ) + + train_da = theano.function( + [index], + cost, + updates=updates, + givens={ + x: train_set_x[index * batch_size: (index + 1) * batch_size] + } + ) + + start_time = time.clock() + + ############ + # TRAINING # + ############ + + # go through training epochs + for epoch in xrange(training_epochs): + # go through trainng set + c = [] + for batch_index in xrange(n_train_batches): + c.append(train_da(batch_index)) + + print 'Training epoch %d, cost ' % epoch, numpy.mean(c) + + end_time = time.clock() + + training_time = (end_time - start_time) + + print >> sys.stderr, ('The 30% corruption code for file ' + + os.path.split(__file__)[1] + + ' ran for %.2fm' % (training_time / 60.)) + + image = Image.fromarray(tile_raster_images( + X=da.W.get_value(borrow=True).T, + img_shape=(28, 28), tile_shape=(10, 10), + tile_spacing=(1, 1))) + image.save('filters_corruption_30.png') + + os.chdir('../') + + +if __name__ == '__main__': + test_dA() diff --git a/easygui.py b/easygui.py new file mode 100644 index 0000000..26fc1f5 --- /dev/null +++ b/easygui.py @@ -0,0 +1,2764 @@ +""" + +.. moduleauthor:: Stephen Raymond Ferg +.. default-domain:: py +.. highlight:: python + +Version |release| + +ABOUT EASYGUI +============= + +EasyGui provides an easy-to-use interface for simple GUI interaction +with a user. It does not require the programmer to know anything about +tkinter, frames, widgets, callbacks or lambda. All GUI interactions are +invoked by simple function calls that return results. + +.. warning:: Using EasyGui with IDLE + + You may encounter problems using IDLE to run programs that use EasyGui. Try it + and find out. EasyGui is a collection of Tkinter routines that run their own + event loops. IDLE is also a Tkinter application, with its own event loop. The + two may conflict, with unpredictable results. If you find that you have + problems, try running your EasyGui program outside of IDLE. + +.. note:: EasyGui requires Tk release 8.0 or greater. + +LICENSE INFORMATION +=================== +EasyGui version |version| + +Copyright (c) 2014, Stephen Raymond Ferg + +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. + + 3. The name of the author may not be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +ABOUT THE EASYGUI LICENSE +------------------------- +| This license is what is generally known as the "modified BSD license", +| aka "revised BSD", "new BSD", "3-clause BSD". +| See http://www.opensource.org/licenses/bsd-license.php +| +| This license is GPL-compatible. +| See ``_ +| See http://www.gnu.org/licenses/license-list.html#GPLCompatibleLicenses +| +| The BSD License is less restrictive than GPL. +| It allows software released under the license to be incorporated into proprietary products. +| Works based on the software may be released under a proprietary license or as closed source software. +| ``_ + +API +=== +""" +eg_version = __doc__.split()[1] + +__all__ = [ + 'ynbox' + , 'ccbox' + , 'boolbox' + , 'indexbox' + , 'msgbox' + , 'buttonbox' + , 'integerbox' + , 'multenterbox' + , 'enterbox' + , 'exceptionbox' + , 'choicebox' + , 'codebox' + , 'textbox' + , 'diropenbox' + , 'fileopenbox' + , 'filesavebox' + , 'passwordbox' + , 'multpasswordbox' + , 'multchoicebox' + , 'abouteasygui' + , 'eg_version' + , 'egdemo' + , 'EgStore' +] + +import os +import sys +import string +import pickle +import traceback + + +# -------------------------------------------------- +# check python version and take appropriate action +# -------------------------------------------------- +""" +From the python documentation: + +sys.hexversion contains the version number encoded as a single integer. This is +guaranteed to increase with each version, including proper support for non- +production releases. For example, to test that the Python interpreter is at +least version 1.5.2, use: + +if sys.hexversion >= 0x010502F0: + # use some advanced feature + ... +else: + # use an alternative implementation or warn the user + ... +""" + +if sys.hexversion >= 0x020600F0: + runningPython26 = True +else: + runningPython26 = False + +if sys.hexversion >= 0x030000F0: + runningPython3 = True +else: + runningPython3 = False + +# Try to import the Python Image Library. If it doesn't exist, only .gif images are supported. 
+try: + from PIL import Image as PILImage + from PIL import ImageTk as PILImageTk +except: + pass + +if runningPython3: + from tkinter import * + import tkinter.filedialog as tk_FileDialog + from io import StringIO +else: + from Tkinter import * + import tkFileDialog as tk_FileDialog + from StringIO import StringIO + +# Set up basestring appropriately +if runningPython3: + basestring = str + +def write(*args): + args = [str(arg) for arg in args] + args = " ".join(args) + sys.stdout.write(args) + + +def writeln(*args): + write(*args) + sys.stdout.write("\n") + + +if TkVersion < 8.0: + stars = "*" * 75 + writeln("""\n\n\n""" + stars + """ +You are running Tk version: """ + str(TkVersion) + """ +You must be using Tk version 8.0 or greater to use EasyGui. +Terminating. +""" + stars + """\n\n\n""") + sys.exit(0) + +rootWindowPosition = "+300+200" + +PROPORTIONAL_FONT_FAMILY = ("MS", "Sans", "Serif") +MONOSPACE_FONT_FAMILY = ("Courier") + +PROPORTIONAL_FONT_SIZE = 10 +MONOSPACE_FONT_SIZE = 9 # a little smaller, because it it more legible at a smaller size +TEXT_ENTRY_FONT_SIZE = 12 # a little larger makes it easier to see + + +STANDARD_SELECTION_EVENTS = ["Return", "Button-1", "space"] + +# Initialize some global variables that will be reset later +__choiceboxMultipleSelect = None +__replyButtonText = None +__choiceboxResults = None +__firstWidget = None +__enterboxText = None +__enterboxDefaultText = "" +__multenterboxText = "" +choiceboxChoices = None +choiceboxWidget = None +entryWidget = None +boxRoot = None + +#------------------------------------------------------------------- +# various boxes built on top of the basic buttonbox +#----------------------------------------------------------------------- + +#----------------------------------------------------------------------- +# ynbox +#----------------------------------------------------------------------- +def ynbox(msg="Shall I continue?" + , title=" " + , choices=("[]Yes", "[]No") + , image=None + , default_choice='[]Yes' + , cancel_choice='[]No'): + """ + Display a msgbox with choices of Yes and No. + + The returned value is calculated this way:: + + if the first choice ("Yes") is chosen, or if the dialog is cancelled: + return True + else: + return False + + If invoked without a msg argument, displays a generic request for a confirmation + that the user wishes to continue. So it can be used this way:: + + if ynbox(): + pass # continue + else: + sys.exit(0) # exit the program + + :param msg: the msg to be displayed + :type msg: str + :param str title: the window title + :param list choices: a list or tuple of the choices to be displayed + :param str image: Filename of image to display + :param str default_choice: The choice you want highlighted when the gui appears + :param str cancel_choice: If the user presses the 'X' close, which button should be pressed + + :return: True if 'Yes' or dialog is cancelled, False if 'No' + """ + return boolbox(msg=msg, + title=title, + choices=choices, + image=image, + default_choice=default_choice, + cancel_choice=cancel_choice) + +#----------------------------------------------------------------------- +# ccbox +#----------------------------------------------------------------------- +def ccbox(msg="Shall I continue?" + , title=" " + , choices=("C[o]ntinue", "C[a]ncel") + , image=None + , default_choice='Continue' + , cancel_choice='Cancel'): + """ + Display a msgbox with choices of Continue and Cancel. 
+ + The returned value is calculated this way:: + + if the first choice ("Continue") is chosen, or if the dialog is cancelled: + return True + else: + return False + + If invoked without a msg argument, displays a generic request for a confirmation + that the user wishes to continue. So it can be used this way:: + + if ccbox(): + pass # continue + else: + sys.exit(0) # exit the program + + :param str msg: the msg to be displayed + :param str title: the window title + :param list choices: a list or tuple of the choices to be displayed + :param str image: Filename of image to display + :param str default_choice: The choice you want highlighted when the gui appears + :param str cancel_choice: If the user presses the 'X' close, which button should be pressed + + :return: True if 'Continue' or dialog is cancelled, False if 'Cancel' + """ + return boolbox(msg=msg, + title=title, + choices=choices, + image=image, + default_choice=default_choice, + cancel_choice=cancel_choice) + +#----------------------------------------------------------------------- +# boolbox +#----------------------------------------------------------------------- +def boolbox(msg="Shall I continue?" + , title=" " + , choices=("[Y]es", "[N]o") + , image=None + , default_choice='Yes' + , cancel_choice='No'): + """ + Display a boolean msgbox. + + The returned value is calculated this way:: + + if the first choice is chosen, or if the dialog is cancelled: + returns True + else: + returns False + + :param str msg: the msg to be displayed + :param str title: the window title + :param list choices: a list or tuple of the choices to be displayed + :param str image: Filename of image to display + :param str default_choice: The choice you want highlighted when the gui appears + :param str cancel_choice: If the user presses the 'X' close, which button should be pressed + :return: True if first button pressed or dialog is cancelled, False if second button is pressed + """ + if len(choices) != 2: + raise AssertionError('boolbox takes exactly 2 choices! Consider using indexbox instead') + + reply = buttonbox(msg=msg, + title=title, + choices=choices, + image=image, + default_choice=default_choice, + cancel_choice=cancel_choice) + if reply == choices[0]: + return True + else: + return False + + +#----------------------------------------------------------------------- +# indexbox +#----------------------------------------------------------------------- +def indexbox(msg="Shall I continue?" + , title=" " + , choices=("Yes", "No") + , image=None + , default_choice='Yes' + , cancel_choice='No'): + """ + Display a buttonbox with the specified choices. 
+ + :param str msg: the msg to be displayed + :param str title: the window title + :param list choices: a list or tuple of the choices to be displayed + :param str image: Filename of image to display + :param str default_choice: The choice you want highlighted when the gui appears + :param str cancel_choice: If the user presses the 'X' close, which button should be pressed + :return: the index of the choice selected, starting from 0 + """ + reply = buttonbox(msg=msg, + title=title, + choices=choices, + image=image, + default_choice=default_choice, + cancel_choice=cancel_choice) + if reply is None: + return None + for i, choice in enumerate(choices): + if reply == choice: + return i + msg = "There is a program logic error in the EasyGui code for indexbox.\nreply={0}, choices={1}".format(reply, choices) + raise AssertionError(msg) + + +#----------------------------------------------------------------------- +# msgbox +#----------------------------------------------------------------------- +def msgbox(msg="(Your message goes here)" + , title=" " + , ok_button="OK" + , image=None + , root=None): + """ + Display a message box + + :param str msg: the msg to be displayed + :param str title: the window title + :param str ok_button: text to show in the button + :param str image: Filename of image to display + :param tk_widget root: Top-level Tk widget + :return: the text of the ok_button + """ + if not isinstance(ok_button, basestring): + raise AssertionError("The 'ok_button' argument to msgbox must be a string.") + + return buttonbox(msg=msg, + title=title, + choices=[ok_button], + image=image, + root=root, + default_choice=ok_button, + cancel_choice=ok_button) + + +#------------------------------------------------------------------- +# buttonbox +#------------------------------------------------------------------- +def buttonbox(msg="" + , title=" " + , choices=("Button[1]", "Button[2]", "Button[3]") + , image=None + , root=None + , default_choice=None + , cancel_choice=None): + """ + Display a msg, a title, an image, and a set of buttons. + The buttons are defined by the members of the choices list. + + :param str msg: the msg to be displayed + :param str title: the window title + :param list choices: a list or tuple of the choices to be displayed + :param str image: Filename of image to display + :param str default_choice: The choice you want highlighted when the gui appears + :param str cancel_choice: If the user presses the 'X' close, which button should be pressed + :return: the text of the button that the user selected + """ + global boxRoot, __replyButtonText, buttonsFrame + + # If default is not specified, select the first button. This matches old behavior. + if default_choice is None: + default_choice = choices[0] + + # Initialize __replyButtonText to the first choice. + # This is what will be used if the window is closed by the close button. 
+ __replyButtonText = choices[0] + + if root: + root.withdraw() + boxRoot = Toplevel(master=root) + boxRoot.withdraw() + else: + boxRoot = Tk() + boxRoot.withdraw() + + + boxRoot.title(title) + boxRoot.iconname('Dialog') + boxRoot.geometry(rootWindowPosition) + boxRoot.minsize(400, 100) + + + # ------------- define the messageFrame --------------------------------- + messageFrame = Frame(master=boxRoot) + messageFrame.pack(side=TOP, fill=BOTH) + + # ------------- define the imageFrame --------------------------------- + if image: + tk_Image = None + try: + tk_Image = __load_tk_image(image) + except Exception as inst: + print(inst) + if tk_Image: + imageFrame = Frame(master=boxRoot) + imageFrame.pack(side=TOP, fill=BOTH) + label = Label(imageFrame, image=tk_Image) + label.image = tk_Image # keep a reference! + label.pack(side=TOP, expand=YES, fill=X, padx='1m', pady='1m') + + # ------------- define the buttonsFrame --------------------------------- + buttonsFrame = Frame(master=boxRoot) + buttonsFrame.pack(side=TOP, fill=BOTH) + + # -------------------- place the widgets in the frames ----------------------- + messageWidget = Message(messageFrame, text=msg, width=400) + messageWidget.configure(font=(PROPORTIONAL_FONT_FAMILY, PROPORTIONAL_FONT_SIZE)) + messageWidget.pack(side=TOP, expand=YES, fill=X, padx='3m', pady='3m') + + __put_buttons_in_buttonframe(choices, default_choice, cancel_choice) + + # -------------- the action begins ----------- + boxRoot.deiconify() + boxRoot.mainloop() + boxRoot.destroy() + if root: + root.deiconify() + return __replyButtonText + + +#------------------------------------------------------------------- +# integerbox +#------------------------------------------------------------------- +def integerbox(msg="" + , title=" " + , default="" + , lowerbound=0 + , upperbound=99 + , image=None + , root=None): + """ + Show a box in which a user can enter an integer. + + In addition to arguments for msg and title, this function accepts + integer arguments for "default", "lowerbound", and "upperbound". + + The default argument may be None. + + When the user enters some text, the text is checked to verify that it + can be converted to an integer between the lowerbound and upperbound. + + If it can be, the integer (not the text) is returned. + + If it cannot, then an error msg is displayed, and the integerbox is + redisplayed. + + If the user cancels the operation, None is returned. + + :param str msg: the msg to be displayed + :param str title: the window title + :param str default: The default value to return + :param int lowerbound: The lower-most value allowed + :param int upperbound: The upper-most value allowed + :param str image: Filename of image to display + :param tk_widget root: Top-level Tk widget + :return: the integer value entered by the user + + """ + + + if not msg: + msg = "Enter an integer between {0} and {1}".format(lowerbound, upperbound) + + # Validate the arguments for default, lowerbound and upperbound and convert to integers + exception_string = 'integerbox "{0}" must be an integer. 
It is >{1}< of type {2}' + if default: + try: + default=int(default) + except ValueError: + raise ValueError(exception_string.format('default', default, type(default))) + try: + lowerbound=int(lowerbound) + except ValueError: + raise ValueError(exception_string.format('lowerbound', lowerbound, type(lowerbound))) + try: + upperbound=int(upperbound) + except ValueError: + raise ValueError(exception_string.format('upperbound', upperbound, type(upperbound))) + + while 1: + reply = enterbox(msg, title, str(default), image=image, root=root) + if reply is None: + return None + try: + reply = int(reply) + except: + msgbox('The value that you entered:\n\t"{}"\nis not an integer.'.format(reply) + , "Error") + continue + if reply < lowerbound: + msgbox('The value that you entered is less than the lower bound of {}.'.format(lowerbound) + , "Error") + continue + if reply > upperbound: + msgbox('The value that you entered is greater than the upper bound of {}.'.format(upperbound) + , "Error") + continue + # reply has passed all validation checks. + # It is an integer between the specified bounds. + return reply + + +#------------------------------------------------------------------- +# multenterbox +#------------------------------------------------------------------- +# TODO RL: Should defaults be list constructors. i think after multiple calls, the value is retained. +# TODO RL: Rename/alias to multienterbox? +# default should be None and then in the logic create an empty list. +def multenterbox(msg="Fill in values for the fields." + , title=" " + , fields=() + , values=()): + r""" + Show screen with multiple data entry fields. + + If there are fewer values than names, the list of values is padded with + empty strings until the number of values is the same as the number of names. + + If there are more values than names, the list of values + is truncated so that there are as many values as names. + + Returns a list of the values of the fields, + or None if the user cancels the operation. + + Here is some example code, that shows how values returned from + multenterbox can be checked for validity before they are accepted:: + + msg = "Enter your personal information" + title = "Credit Card Application" + fieldNames = ["Name","Street Address","City","State","ZipCode"] + fieldValues = [] # we start with blanks for the values + fieldValues = multenterbox(msg,title, fieldNames) + + # make sure that none of the fields was left blank + while 1: + if fieldValues is None: break + errmsg = "" + for i in range(len(fieldNames)): + if fieldValues[i].strip() == "": + errmsg += ('"%s" is a required field.\n\n' % fieldNames[i]) + if errmsg == "": + break # no problems found + fieldValues = multenterbox(errmsg, title, fieldNames, fieldValues) + + writeln("Reply was: %s" % str(fieldValues)) + + :param str msg: the msg to be displayed. + :param str title: the window title + :param list fields: a list of fieldnames. + :param list values: a list of field values + :return: String + """ + return __multfillablebox(msg, title, fields, values, None) + + +#----------------------------------------------------------------------- +# multpasswordbox +#----------------------------------------------------------------------- +def multpasswordbox(msg="Fill in values for the fields." + , title=" " + , fields=tuple() + , values=tuple()): + r""" + Same interface as multenterbox. But in multpassword box, + the last of the fields is assumed to be a password, and + is masked with asterisks. + + :param str msg: the msg to be displayed. 
+ :param str title: the window title + :param list fields: a list of fieldnames. + :param list values: a list of field values + :return: String + + **Example** + + Here is some example code, that shows how values returned from + multpasswordbox can be checked for validity before they are accepted:: + + msg = "Enter logon information" + title = "Demo of multpasswordbox" + fieldNames = ["Server ID", "User ID", "Password"] + fieldValues = [] # we start with blanks for the values + fieldValues = multpasswordbox(msg,title, fieldNames) + + # make sure that none of the fields was left blank + while 1: + if fieldValues is None: break + errmsg = "" + for i in range(len(fieldNames)): + if fieldValues[i].strip() == "": + errmsg = errmsg + ('"%s" is a required field.\n\n' % fieldNames[i]) + if errmsg == "": break # no problems found + fieldValues = multpasswordbox(errmsg, title, fieldNames, fieldValues) + + writeln("Reply was: %s" % str(fieldValues)) + + """ + return __multfillablebox(msg, title, fields, values, "*") + + +def bindArrows(widget): + + widget.bind("", tabRight) + widget.bind("", tabLeft) + + widget.bind("", tabRight) + widget.bind("", tabLeft) + + +def tabRight(event): + boxRoot.event_generate("") + + +def tabLeft(event): + boxRoot.event_generate("") + + +#----------------------------------------------------------------------- +# __multfillablebox +#----------------------------------------------------------------------- +def __multfillablebox(msg="Fill in values for the fields." + , title=" " + , fields=() + , values=() + , mask=None): + global boxRoot, __multenterboxText, __multenterboxDefaultText, cancelButton, entryWidget, okButton + + choices = ["OK", "Cancel"] + if len(fields) == 0: + return None + + fields = list(fields[:]) # convert possible tuples to a list + values = list(values[:]) # convert possible tuples to a list + + #TODO RL: The following seems incorrect when values>fields. Replace below with zip? 
+ if len(values) == len(fields): + pass + elif len(values) > len(fields): + fields = fields[0:len(values)] + else: + while len(values) < len(fields): + values.append("") + + boxRoot = Tk() + + boxRoot.protocol('WM_DELETE_WINDOW', denyWindowManagerClose) + boxRoot.title(title) + boxRoot.iconname('Dialog') + boxRoot.geometry(rootWindowPosition) + boxRoot.bind("", __multenterboxCancel) + + # -------------------- put subframes in the boxRoot -------------------- + messageFrame = Frame(master=boxRoot) + messageFrame.pack(side=TOP, fill=BOTH) + + #-------------------- the msg widget ---------------------------- + messageWidget = Message(messageFrame, width="4.5i", text=msg) + messageWidget.configure(font=(PROPORTIONAL_FONT_FAMILY, PROPORTIONAL_FONT_SIZE)) + messageWidget.pack(side=RIGHT, expand=1, fill=BOTH, padx='3m', pady='3m') + + global entryWidgets + entryWidgets = list() + + lastWidgetIndex = len(fields) - 1 + + for widgetIndex in range(len(fields)): + argFieldName = fields[widgetIndex] + argFieldValue = values[widgetIndex] + entryFrame = Frame(master=boxRoot) + entryFrame.pack(side=TOP, fill=BOTH) + + # --------- entryWidget ---------------------------------------------- + labelWidget = Label(entryFrame, text=argFieldName) + labelWidget.pack(side=LEFT) + + entryWidget = Entry(entryFrame, width=40, highlightthickness=2) + entryWidgets.append(entryWidget) + entryWidget.configure(font=(PROPORTIONAL_FONT_FAMILY, TEXT_ENTRY_FONT_SIZE)) + entryWidget.pack(side=RIGHT, padx="3m") + + bindArrows(entryWidget) + + entryWidget.bind("", __multenterboxGetText) + entryWidget.bind("", __multenterboxCancel) + + # for the last entryWidget, if this is a multpasswordbox, + # show the contents as just asterisks + if widgetIndex == lastWidgetIndex: + if mask: + entryWidgets[widgetIndex].configure(show=mask) + + # put text into the entryWidget + entryWidgets[widgetIndex].insert(0, argFieldValue) + widgetIndex += 1 + + # ------------------ ok button ------------------------------- + buttonsFrame = Frame(master=boxRoot) + buttonsFrame.pack(side=BOTTOM, fill=BOTH) + + okButton = Button(buttonsFrame, takefocus=1, text="OK") + bindArrows(okButton) + okButton.pack(expand=1, side=LEFT, padx='3m', pady='3m', ipadx='2m', ipady='1m') + + # for the commandButton, bind activation events to the activation event handler + commandButton = okButton + handler = __multenterboxGetText + for selectionEvent in STANDARD_SELECTION_EVENTS: + commandButton.bind("<%s>" % selectionEvent, handler) + + + # ------------------ cancel button ------------------------------- + cancelButton = Button(buttonsFrame, takefocus=1, text="Cancel") + bindArrows(cancelButton) + cancelButton.pack(expand=1, side=RIGHT, padx='3m', pady='3m', ipadx='2m', ipady='1m') + + # for the commandButton, bind activation events to the activation event handler + commandButton = cancelButton + handler = __multenterboxCancel + for selectionEvent in STANDARD_SELECTION_EVENTS: + commandButton.bind("<%s>" % selectionEvent, handler) + + + # ------------------- time for action! ----------------- + entryWidgets[0].focus_force() # put the focus on the entryWidget + boxRoot.mainloop() # run it! 
+ + # -------- after the run has completed ---------------------------------- + boxRoot.destroy() # button_click didn't destroy boxRoot, so we do it now + return __multenterboxText + + +#----------------------------------------------------------------------- +# __multenterboxGetText +#----------------------------------------------------------------------- +def __multenterboxGetText(event): + global __multenterboxText + + __multenterboxText = list() + for entryWidget in entryWidgets: + __multenterboxText.append(entryWidget.get()) + boxRoot.quit() + + +def __multenterboxCancel(event): + global __multenterboxText + __multenterboxText = None + boxRoot.quit() + + +#------------------------------------------------------------------- +# enterbox +#------------------------------------------------------------------- +def enterbox(msg="Enter something." + , title=" " + , default="" + , strip=True + , image=None + , root=None): + """ + Show a box in which a user can enter some text. + + You may optionally specify some default text, which will appear in the + enterbox when it is displayed. + + Example:: + + reply = enterbox(....) + if reply: + ... + else: + ... + + :param str msg: the msg to be displayed. + :param str title: the window title + :param str default: value returned if user does not change it + :param bool strip: If True, the return value will have its whitespace stripped before being returned + :return: the text that the user entered, or None if he cancels the operation. + """ + result = __fillablebox(msg, title, default=default, mask=None, image=image, root=root) + if result and strip: + result = result.strip() + return result + + +def passwordbox(msg="Enter your password." + , title=" " + , default="" + , image=None + , root=None): + """ + Show a box in which a user can enter a password. + The text is masked with asterisks, so the password is not displayed. + + :param str msg: the msg to be displayed. + :param str title: the window title + :param str default: value returned if user does not change it + :return: the text that the user entered, or None if he cancels the operation. + """ + return __fillablebox(msg, title, default, mask="*", image=image, root=root) + + +def __load_tk_image(filename): + """ + Load in an image file and return as a tk Image. + + :param filename: image filename to load + :return: tk Image object + """ + + if filename is None: + return None + + if not os.path.isfile(filename): + raise ValueError('Image file {} does not exist.'.format(filename)) + + tk_image = None + + filename = os.path.normpath(filename) + _, ext = os.path.splitext(filename) + + try: + pil_image = PILImage.open(filename) + tk_image = PILImageTk.PhotoImage(pil_image) + except: + try: + tk_image = PhotoImage(file=filename) #Fallback if PIL isn't available + except: + msg = "Cannot load {}. Check to make sure it is an image file.".format(filename) + try: + _ = PILImage + except: + msg += "\nPIL library isn't installed. If it isn't installed, only .gif files can be used." + raise ValueError(msg) + return tk_image + + +def __fillablebox(msg + , title="" + , default="" + , mask=None + , image=None + , root=None): + """ + Show a box in which a user can enter some text. + You may optionally specify some default text, which will appear in the + enterbox when it is displayed. + Returns the text that the user entered, or None if he cancels the operation. 
+ """ + + global boxRoot, __enterboxText, __enterboxDefaultText + global cancelButton, entryWidget, okButton + + if title is None: + title == "" + if default is None: + default = "" + __enterboxDefaultText = default + __enterboxText = __enterboxDefaultText + + if root: + root.withdraw() + boxRoot = Toplevel(master=root) + boxRoot.withdraw() + else: + boxRoot = Tk() + boxRoot.withdraw() + + boxRoot.protocol('WM_DELETE_WINDOW', denyWindowManagerClose) + boxRoot.title(title) + boxRoot.iconname('Dialog') + boxRoot.geometry(rootWindowPosition) + boxRoot.bind("", __enterboxCancel) + + # ------------- define the messageFrame --------------------------------- + messageFrame = Frame(master=boxRoot) + messageFrame.pack(side=TOP, fill=BOTH) + + # ------------- define the imageFrame --------------------------------- + + try: + tk_Image = __load_tk_image(image) + except Exception as inst: + print(inst) + tk_Image = None + if tk_Image: + imageFrame = Frame(master=boxRoot) + imageFrame.pack(side=TOP, fill=BOTH) + label = Label(imageFrame, image=tk_Image) + label.image = tk_Image # keep a reference! + label.pack(side=TOP, expand=YES, fill=X, padx='1m', pady='1m') + + # ------------- define the buttonsFrame --------------------------------- + buttonsFrame = Frame(master=boxRoot) + buttonsFrame.pack(side=TOP, fill=BOTH) + + + # ------------- define the entryFrame --------------------------------- + entryFrame = Frame(master=boxRoot) + entryFrame.pack(side=TOP, fill=BOTH) + + # ------------- define the buttonsFrame --------------------------------- + buttonsFrame = Frame(master=boxRoot) + buttonsFrame.pack(side=TOP, fill=BOTH) + + #-------------------- the msg widget ---------------------------- + messageWidget = Message(messageFrame, width="4.5i", text=msg) + messageWidget.configure(font=(PROPORTIONAL_FONT_FAMILY, PROPORTIONAL_FONT_SIZE)) + messageWidget.pack(side=RIGHT, expand=1, fill=BOTH, padx='3m', pady='3m') + + # --------- entryWidget ---------------------------------------------- + entryWidget = Entry(entryFrame, width=40) + bindArrows(entryWidget) + entryWidget.configure(font=(PROPORTIONAL_FONT_FAMILY, TEXT_ENTRY_FONT_SIZE)) + if mask: + entryWidget.configure(show=mask) + entryWidget.pack(side=LEFT, padx="3m") + entryWidget.bind("", __enterboxGetText) + entryWidget.bind("", __enterboxCancel) + # put text into the entryWidget + entryWidget.insert(0, __enterboxDefaultText) + + # ------------------ ok button ------------------------------- + okButton = Button(buttonsFrame, takefocus=1, text="OK") + bindArrows(okButton) + okButton.pack(expand=1, side=LEFT, padx='3m', pady='3m', ipadx='2m', ipady='1m') + + # for the commandButton, bind activation events to the activation event handler + commandButton = okButton + handler = __enterboxGetText + for selectionEvent in STANDARD_SELECTION_EVENTS: + commandButton.bind("<{}>".format(selectionEvent), handler) + + + # ------------------ cancel button ------------------------------- + cancelButton = Button(buttonsFrame, takefocus=1, text="Cancel") + bindArrows(cancelButton) + cancelButton.pack(expand=1, side=RIGHT, padx='3m', pady='3m', ipadx='2m', ipady='1m') + + # for the commandButton, bind activation events to the activation event handler + commandButton = cancelButton + handler = __enterboxCancel + for selectionEvent in STANDARD_SELECTION_EVENTS: + commandButton.bind("<{}>".format(selectionEvent), handler) + + # ------------------- time for action! 
----------------- + entryWidget.focus_force() # put the focus on the entryWidget + boxRoot.deiconify() + boxRoot.mainloop() # run it! + + # -------- after the run has completed ---------------------------------- + if root: + root.deiconify() + boxRoot.destroy() # button_click didn't destroy boxRoot, so we do it now + return __enterboxText + + +def __enterboxGetText(event): + global __enterboxText + + __enterboxText = entryWidget.get() + boxRoot.quit() + + +def __enterboxRestore(event): + global entryWidget + + entryWidget.delete(0, len(entryWidget.get())) + entryWidget.insert(0, __enterboxDefaultText) + + +def __enterboxCancel(event): + global __enterboxText + + __enterboxText = None + boxRoot.quit() + + +def denyWindowManagerClose(): + """ don't allow WindowManager close + """ + x = Tk() + x.withdraw() + x.bell() + x.destroy() + + +#------------------------------------------------------------------- +# multchoicebox +#------------------------------------------------------------------- +def multchoicebox(msg="Pick as many items as you like." + , title=" " + , choices=() + , **kwargs): + """ + Present the user with a list of choices. + allow him to select multiple items and return them in a list. + if the user doesn't choose anything from the list, return the empty list. + return None if he cancelled selection. + + :param str msg: the msg to be displayed + :param str title: the window title + :param list choices: a list or tuple of the choices to be displayed + :return: List containing choice selected or None if cancelled + + """ + if len(choices) == 0: + choices = ["Program logic error - no choices were specified."] + + global __choiceboxMultipleSelect + __choiceboxMultipleSelect = 1 + return __choicebox(msg, title, choices) + + +#----------------------------------------------------------------------- +# choicebox +#----------------------------------------------------------------------- +def choicebox(msg="Pick something." + , title=" " + , choices=()): + """ + Present the user with a list of choices. + return the choice that he selects. + + :param str msg: the msg to be displayed + :param str title: the window title + :param list choices: a list or tuple of the choices to be displayed + :return: List containing choice selected or None if cancelled + """ + if len(choices) == 0: + choices = ["Program logic error - no choices were specified."] + + global __choiceboxMultipleSelect + __choiceboxMultipleSelect = 0 + return __choicebox(msg, title, choices) + + +#----------------------------------------------------------------------- +# __choicebox +#----------------------------------------------------------------------- +def __choicebox(msg + , title + , choices): + """ + internal routine to support choicebox() and multchoicebox() + """ + global boxRoot, __choiceboxResults, choiceboxWidget, defaultText + global choiceboxWidget, choiceboxChoices + #------------------------------------------------------------------- + # If choices is a tuple, we make it a list so we can sort it. + # If choices is already a list, we make a new list, so that when + # we sort the choices, we don't affect the list object that we + # were given. + #------------------------------------------------------------------- + choices = list(choices[:]) + + if len(choices) == 0: + choices = ["Program logic error - no choices were specified."] + defaultButtons = ["OK", "Cancel"] + + choices = [str(c) for c in choices] + + #TODO RL: lines_to_show is set to a min and then set to 20 right after that. Figure out why. 
+ lines_to_show = min(len(choices), 20) + lines_to_show = 20 + + if title is None: + title = "" + + # Initialize __choiceboxResults + # This is the value that will be returned if the user clicks the close icon + __choiceboxResults = None + + boxRoot = Tk() + #boxRoot.protocol('WM_DELETE_WINDOW', denyWindowManagerClose ) #RL: Removed so top-level program can be closed with an 'x' + screen_width = boxRoot.winfo_screenwidth() + screen_height = boxRoot.winfo_screenheight() + root_width = int((screen_width * 0.8)) + root_height = int((screen_height * 0.5)) + root_xpos = int((screen_width * 0.1)) + root_ypos = int((screen_height * 0.05)) + + boxRoot.title(title) + boxRoot.iconname('Dialog') + rootWindowPosition = "+0+0" + boxRoot.geometry(rootWindowPosition) + boxRoot.expand = NO + boxRoot.minsize(root_width, root_height) + rootWindowPosition = '+{0}+{1}'.format(root_xpos, root_ypos) + boxRoot.geometry(rootWindowPosition) + + # ---------------- put the frames in the window ----------------------------------------- + message_and_buttonsFrame = Frame(master=boxRoot) + message_and_buttonsFrame.pack(side=TOP, fill=X, expand=NO) + + messageFrame = Frame(message_and_buttonsFrame) + messageFrame.pack(side=LEFT, fill=X, expand=YES) + + buttonsFrame = Frame(message_and_buttonsFrame) + buttonsFrame.pack(side=RIGHT, expand=NO, pady=0) + + choiceboxFrame = Frame(master=boxRoot) + choiceboxFrame.pack(side=BOTTOM, fill=BOTH, expand=YES) + + # -------------------------- put the widgets in the frames ------------------------------ + + # ---------- put a msg widget in the msg frame------------------- + messageWidget = Message(messageFrame, anchor=NW, text=msg, width=int(root_width * 0.9)) + messageWidget.configure(font=(PROPORTIONAL_FONT_FAMILY, PROPORTIONAL_FONT_SIZE)) + messageWidget.pack(side=LEFT, expand=YES, fill=BOTH, padx='1m', pady='1m') + + # -------- put the choiceboxWidget in the choiceboxFrame --------------------------- + choiceboxWidget = Listbox(choiceboxFrame + , height=lines_to_show + , borderwidth="1m" + , relief="flat" + , bg="white" + ) + + if __choiceboxMultipleSelect: + choiceboxWidget.configure(selectmode=MULTIPLE) + + choiceboxWidget.configure(font=(PROPORTIONAL_FONT_FAMILY, PROPORTIONAL_FONT_SIZE)) + + # add a vertical scrollbar to the frame + rightScrollbar = Scrollbar(choiceboxFrame, orient=VERTICAL, command=choiceboxWidget.yview) + choiceboxWidget.configure(yscrollcommand=rightScrollbar.set) + + # add a horizontal scrollbar to the frame + bottomScrollbar = Scrollbar(choiceboxFrame, orient=HORIZONTAL, command=choiceboxWidget.xview) + choiceboxWidget.configure(xscrollcommand=bottomScrollbar.set) + + # pack the Listbox and the scrollbars. Note that although we must define + # the textArea first, we must pack it last, so that the bottomScrollbar will + # be located properly. 
+
+    bottomScrollbar.pack(side=BOTTOM, fill=X)
+    rightScrollbar.pack(side=RIGHT, fill=Y)
+
+    choiceboxWidget.pack(side=LEFT, padx="1m", pady="1m", expand=YES, fill=BOTH)
+
+    #---------------------------------------------------
+    # sort the choices
+    # eliminate duplicates
+    # put the choices into the choiceboxWidget
+    #---------------------------------------------------
+
+    if runningPython3:
+        choices.sort(key=str.lower)
+    else:
+        choices.sort(lambda x, y: cmp(x.lower(), y.lower()))  # case-insensitive sort
+
+    lastInserted = None
+    choiceboxChoices = list()
+    for choice in choices:
+        if choice == lastInserted:
+            continue
+        else:
+            choiceboxWidget.insert(END, choice)
+            choiceboxChoices.append(choice)
+            lastInserted = choice
+
+    boxRoot.bind('<Any-Key>', KeyboardListener)
+
+    # put the buttons in the buttonsFrame
+    if len(choices):
+        okButton = Button(buttonsFrame, takefocus=YES, text="OK", height=1, width=6)
+        bindArrows(okButton)
+        okButton.pack(expand=NO, side=TOP, padx='2m', pady='1m', ipady="1m", ipadx="2m")
+
+        # for the commandButton, bind activation events to the activation event handler
+        commandButton = okButton
+        handler = __choiceboxGetChoice
+        for selectionEvent in STANDARD_SELECTION_EVENTS:
+            commandButton.bind("<%s>" % selectionEvent, handler)
+
+        # now bind the keyboard events
+        choiceboxWidget.bind("<Return>", __choiceboxGetChoice)
+        choiceboxWidget.bind("<Double-Button-1>", __choiceboxGetChoice)
+    else:
+        # now bind the keyboard events
+        choiceboxWidget.bind("<Return>", __choiceboxCancel)
+        choiceboxWidget.bind("<Double-Button-1>", __choiceboxCancel)
+
+    cancelButton = Button(buttonsFrame, takefocus=YES, text="Cancel", height=1, width=6)
+    bindArrows(cancelButton)
+    cancelButton.pack(expand=NO, side=BOTTOM, padx='2m', pady='1m', ipady="1m", ipadx="2m")
+
+    # for the commandButton, bind activation events to the activation event handler
+    commandButton = cancelButton
+    handler = __choiceboxCancel
+    for selectionEvent in STANDARD_SELECTION_EVENTS:
+        commandButton.bind("<%s>" % selectionEvent, handler)
+
+    # add special buttons for multiple select features
+    if len(choices) and __choiceboxMultipleSelect:
+        selectionButtonsFrame = Frame(messageFrame)
+        selectionButtonsFrame.pack(side=RIGHT, fill=Y, expand=NO)
+
+        selectAllButton = Button(selectionButtonsFrame, text="Select All", height=1, width=6)
+        bindArrows(selectAllButton)
+
+        selectAllButton.bind("<Button-1>", __choiceboxSelectAll)
+        selectAllButton.pack(expand=NO, side=TOP, padx='2m', pady='1m', ipady="1m", ipadx="2m")
+
+        clearAllButton = Button(selectionButtonsFrame, text="Clear All", height=1, width=6)
+        bindArrows(clearAllButton)
+        clearAllButton.bind("<Button-1>", __choiceboxClearAll)
+        clearAllButton.pack(expand=NO, side=TOP, padx='2m', pady='1m', ipady="1m", ipadx="2m")
+
+    # -------------------- bind some keyboard events ----------------------------
+    boxRoot.bind("<Escape>", __choiceboxCancel)
+
+    # --------------------- the action begins -----------------------------------
+    # put the focus on the choiceboxWidget, and the select highlight on the first item
+    choiceboxWidget.select_set(0)
+    choiceboxWidget.focus_force()
+
+    # --- run it!
----- + boxRoot.mainloop() + try: + boxRoot.destroy() + except: + pass + return __choiceboxResults + + +def __choiceboxGetChoice(event): + global boxRoot, __choiceboxResults, choiceboxWidget + + if __choiceboxMultipleSelect: + __choiceboxResults = [choiceboxWidget.get(index) for index in choiceboxWidget.curselection()] + else: + choice_index = choiceboxWidget.curselection() + __choiceboxResults = choiceboxWidget.get(choice_index) + + boxRoot.quit() + + +def __choiceboxSelectAll(event): + global choiceboxWidget, choiceboxChoices + + choiceboxWidget.selection_set(0, len(choiceboxChoices) - 1) + + +def __choiceboxClearAll(event): + global choiceboxWidget, choiceboxChoices + + choiceboxWidget.selection_clear(0, len(choiceboxChoices) - 1) + + +def __choiceboxCancel(event): + global boxRoot, __choiceboxResults + + __choiceboxResults = None + boxRoot.quit() + + +def KeyboardListener(event): + global choiceboxChoices, choiceboxWidget + key = event.keysym + if len(key) <= 1: + if key in string.printable: + # Find the key in the list. + # before we clear the list, remember the selected member + try: + start_n = int(choiceboxWidget.curselection()[0]) + except IndexError: + start_n = -1 + + ## clear the selection. + choiceboxWidget.selection_clear(0, 'end') + + ## start from previous selection +1 + for n in range(start_n + 1, len(choiceboxChoices)): + item = choiceboxChoices[n] + if item[0].lower() == key.lower(): + choiceboxWidget.selection_set(first=n) + choiceboxWidget.see(n) + return + else: + # has not found it so loop from top + for n, item in enumerate(choiceboxChoices): + if item[0].lower() == key.lower(): + choiceboxWidget.selection_set(first=n) + choiceboxWidget.see(n) + return + + # nothing matched -- we'll look for the next logical choice + for n, item in enumerate(choiceboxChoices): + if item[0].lower() > key.lower(): + if n > 0: + choiceboxWidget.selection_set(first=(n - 1)) + else: + choiceboxWidget.selection_set(first=0) + choiceboxWidget.see(n) + return + + # still no match (nothing was greater than the key) + # we set the selection to the first item in the list + lastIndex = len(choiceboxChoices) - 1 + choiceboxWidget.selection_set(first=lastIndex) + choiceboxWidget.see(lastIndex) + return + + +#----------------------------------------------------------------------- +# exception_format +#----------------------------------------------------------------------- +def exception_format(): + """ + Convert exception info into a string suitable for display. + """ + return "".join(traceback.format_exception( + sys.exc_info()[0] + , sys.exc_info()[1] + , sys.exc_info()[2] + )) + + +#----------------------------------------------------------------------- +# exceptionbox +#----------------------------------------------------------------------- +def exceptionbox(msg=None + ,title=None): + """ + Display a box that gives information about + an exception that has just been raised. + + The caller may optionally pass in a title for the window, or a + msg to accompany the error information. + + Note that you do not need to (and cannot) pass an exception object + as an argument. The latest exception will automatically be used. + + :param str msg: the msg to be displayed + :param str title: the window title + :return: None + + """ + if title is None: + title = "Error Report" + if msg is None: + msg = "An error (exception) has occurred in the program." 
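+    # show the message together with the formatted traceback of the most
+    # recent exception in a monospaced codebox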
+ + codebox(msg, title, exception_format()) + + +#------------------------------------------------------------------- +# codebox +#------------------------------------------------------------------- + +def codebox(msg="" + , title=" " + , text=""): + """ + Display some text in a monospaced font, with no line wrapping. + This function is suitable for displaying code and text that is + formatted using spaces. + + The text parameter should be a string, or a list or tuple of lines to be + displayed in the textbox. + + :param str msg: the msg to be displayed + :param str title: the window title + :param str text: what to display in the textbox + """ + return textbox(msg, title, text, codebox=1) + + +#------------------------------------------------------------------- +# textbox +#------------------------------------------------------------------- +def textbox(msg="" + , title=" " + , text="" + , codebox=0): + """ + Display some text in a proportional font with line wrapping at word breaks. + This function is suitable for displaying general written text. + + The text parameter should be a string, or a list or tuple of lines to be + displayed in the textbox. + + :param str msg: the msg to be displayed + :param str title: the window title + :param str text: what to display in the textbox + :param str codebox: if 1, act as a codebox + """ + + if msg is None: + msg = "" + if title is None: + title = "" + + global boxRoot, __replyButtonText, __widgetTexts, buttonsFrame + global rootWindowPosition + choices = ["OK"] + __replyButtonText = choices[0] + + boxRoot = Tk() + + boxRoot.protocol('WM_DELETE_WINDOW', denyWindowManagerClose) + + screen_width = boxRoot.winfo_screenwidth() + screen_height = boxRoot.winfo_screenheight() + root_width = int((screen_width * 0.8)) + root_height = int((screen_height * 0.5)) + root_xpos = int((screen_width * 0.1)) + root_ypos = int((screen_height * 0.05)) + + boxRoot.title(title) + boxRoot.iconname('Dialog') + rootWindowPosition = "+0+0" + boxRoot.geometry(rootWindowPosition) + boxRoot.expand = NO + boxRoot.minsize(root_width, root_height) + rootWindowPosition = '+{0}+{1}'.format(root_xpos, root_ypos) + boxRoot.geometry(rootWindowPosition) + + mainframe = Frame(master=boxRoot) + mainframe.pack(side=TOP, fill=BOTH, expand=YES) + + # ---- put frames in the window ----------------------------------- + # we pack the textboxFrame first, so it will expand first + textboxFrame = Frame(mainframe, borderwidth=3) + textboxFrame.pack(side=BOTTOM, fill=BOTH, expand=YES) + + message_and_buttonsFrame = Frame(mainframe) + message_and_buttonsFrame.pack(side=TOP, fill=X, expand=NO) + + messageFrame = Frame(message_and_buttonsFrame) + messageFrame.pack(side=LEFT, fill=X, expand=YES) + + buttonsFrame = Frame(message_and_buttonsFrame) + buttonsFrame.pack(side=RIGHT, expand=NO) + + # -------------------- put widgets in the frames -------------------- + + # put a textArea in the top frame + if codebox: + character_width = int((root_width * 0.6) / MONOSPACE_FONT_SIZE) + textArea = Text(textboxFrame, height=25, width=character_width, padx="2m", pady="1m") + textArea.configure(wrap=NONE) + textArea.configure(font=(MONOSPACE_FONT_FAMILY, MONOSPACE_FONT_SIZE)) + + else: + character_width = int((root_width * 0.6) / MONOSPACE_FONT_SIZE) + textArea = Text( + textboxFrame + , height=25 + , width=character_width + , padx="2m" + , pady="1m" + ) + textArea.configure(wrap=WORD) + textArea.configure(font=(PROPORTIONAL_FONT_FAMILY, PROPORTIONAL_FONT_SIZE)) + + + # some simple keybindings for scrolling + 
mainframe.bind("", textArea.yview_scroll(1, PAGES)) + mainframe.bind("", textArea.yview_scroll(-1, PAGES)) + + mainframe.bind("", textArea.xview_scroll(1, PAGES)) + mainframe.bind("", textArea.xview_scroll(-1, PAGES)) + + mainframe.bind("", textArea.yview_scroll(1, UNITS)) + mainframe.bind("", textArea.yview_scroll(-1, UNITS)) + + + # add a vertical scrollbar to the frame + rightScrollbar = Scrollbar(textboxFrame, orient=VERTICAL, command=textArea.yview) + textArea.configure(yscrollcommand=rightScrollbar.set) + + # add a horizontal scrollbar to the frame + bottomScrollbar = Scrollbar(textboxFrame, orient=HORIZONTAL, command=textArea.xview) + textArea.configure(xscrollcommand=bottomScrollbar.set) + + # pack the textArea and the scrollbars. Note that although we must define + # the textArea first, we must pack it last, so that the bottomScrollbar will + # be located properly. + + # Note that we need a bottom scrollbar only for code. + # Text will be displayed with wordwrap, so we don't need to have a horizontal + # scroll for it. + if codebox: + bottomScrollbar.pack(side=BOTTOM, fill=X) + rightScrollbar.pack(side=RIGHT, fill=Y) + + textArea.pack(side=LEFT, fill=BOTH, expand=YES) + + + # ---------- put a msg widget in the msg frame------------------- + messageWidget = Message(messageFrame, anchor=NW, text=msg, width=int(root_width * 0.9)) + messageWidget.configure(font=(PROPORTIONAL_FONT_FAMILY, PROPORTIONAL_FONT_SIZE)) + messageWidget.pack(side=LEFT, expand=YES, fill=BOTH, padx='1m', pady='1m') + + # put the buttons in the buttonsFrame + okButton = Button(buttonsFrame, takefocus=YES, text="OK", height=1, width=6) + okButton.pack(expand=NO, side=TOP, padx='2m', pady='1m', ipady="1m", ipadx="2m") + + # for the commandButton, bind activation events to the activation event handler + commandButton = okButton + handler = __textboxOK + for selectionEvent in ["Return", "Button-1", "Escape"]: + commandButton.bind("<%s>" % selectionEvent, handler) + + + # ----------------- the action begins ---------------------------------------- + try: + # load the text into the textArea + if isinstance(text, basestring): + pass + else: + try: + text = "".join(text) # convert a list or a tuple to a string + except: + msgbox("Exception when trying to convert {} to text in textArea".format(type(text))) + sys.exit(16) + textArea.insert('end', text, "normal") + + except: + msgbox("Exception when trying to load the textArea.") + sys.exit(16) + + try: + okButton.focus_force() + except: + msgbox("Exception when trying to put focus on okButton.") + sys.exit(16) + + boxRoot.mainloop() + + # this line MUST go before the line that destroys boxRoot + areaText = textArea.get(0.0, 'end-1c') + boxRoot.destroy() + return areaText # return __replyButtonText + + +#------------------------------------------------------------------- +# __textboxOK +#------------------------------------------------------------------- +def __textboxOK(event): + global boxRoot + boxRoot.quit() + + +#------------------------------------------------------------------- +# diropenbox +#------------------------------------------------------------------- +def diropenbox(msg=None + , title=None + , default=None): + """ + A dialog to get a directory name. + Note that the msg argument, if specified, is ignored. + + Returns the name of a directory, or None if user chose to cancel. + + If the "default" argument specifies a directory name, and that + directory exists, then the dialog box will start with that directory. 
+ + :param str msg: the msg to be displayed + :param str title: the window title + :param str default: starting directory when dialog opens + :return: Normalized path selected by user + """ + title = getFileDialogTitle(msg, title) + localRoot = Tk() + localRoot.withdraw() + if not default: + default = None + f = tk_FileDialog.askdirectory( + parent=localRoot + , title=title + , initialdir=default + , initialfile=None + ) + localRoot.destroy() + if not f: + return None + return os.path.normpath(f) + + +#------------------------------------------------------------------- +# getFileDialogTitle +#------------------------------------------------------------------- +def getFileDialogTitle(msg + ,title): + """ + Create nicely-formatted string based on arguments msg and title + :param msg: the msg to be displayed + :param title: the window title + :return: None + """ + if msg and title: + return "%s - %s" % (title, msg) + if msg and not title: + return str(msg) + if title and not msg: + return str(title) + return None # no message and no title + + +#------------------------------------------------------------------- +# class FileTypeObject for use with fileopenbox +#------------------------------------------------------------------- +class FileTypeObject: + def __init__(self, filemask): + if len(filemask) == 0: + raise AssertionError('Filetype argument is empty.') + + self.masks = list() + + if isinstance(filemask, basestring): # a str or unicode + self.initializeFromString(filemask) + + elif isinstance(filemask, list): + if len(filemask) < 2: + raise AssertionError('Invalid filemask.\n' + + 'List contains less than 2 members: "{}"'.format(filemask)) + else: + self.name = filemask[-1] + self.masks = list(filemask[:-1]) + else: + raise AssertionError('Invalid filemask: "{}"'.format(filemask)) + + def __eq__(self, other): + if self.name == other.name: + return True + return False + + def add(self, other): + for mask in other.masks: + if mask in self.masks: + pass + else: + self.masks.append(mask) + + def toTuple(self): + return self.name, tuple(self.masks) + + def isAll(self): + if self.name == "All files": + return True + return False + + def initializeFromString(self, filemask): + # remove everything except the extension from the filemask + self.ext = os.path.splitext(filemask)[1] + if self.ext == "": + self.ext = ".*" + if self.ext == ".": + self.ext = ".*" + self.name = self.getName() + self.masks = ["*" + self.ext] + + def getName(self): + e = self.ext + file_types = {".*":"All", ".txt":"Text", ".py":"Python", ".pyc":"Python", ".xls":"Excel"} + if e in file_types: + return '{} files'.format(file_types[e]) + if e.startswith("."): + return '{} files'.format(e[1:].upper()) + return '{} files'.format(e.upper()) + + +#------------------------------------------------------------------- +# fileopenbox +#------------------------------------------------------------------- +def fileopenbox(msg=None + , title=None + , default='*' + , filetypes=None + , multiple=False): + """ + A dialog to get a file name. + + **About the "default" argument** + + The "default" argument specifies a filepath that (normally) + contains one or more wildcards. + fileopenbox will display only files that match the default filepath. + If omitted, defaults to "\*" (all files in the current directory). + + WINDOWS EXAMPLE:: + + ...default="c:/myjunk/*.py" + + will open in directory c:\\myjunk\\ and show all Python files. 
+ + WINDOWS EXAMPLE:: + + ...default="c:/myjunk/test*.py" + + will open in directory c:\\myjunk\\ and show all Python files + whose names begin with "test". + + + Note that on Windows, fileopenbox automatically changes the path + separator to the Windows path separator (backslash). + + **About the "filetypes" argument** + + If specified, it should contain a list of items, + where each item is either: + + - a string containing a filemask # e.g. "\*.txt" + - a list of strings, where all of the strings except the last one + are filemasks (each beginning with "\*.", + such as "\*.txt" for text files, "\*.py" for Python files, etc.). + and the last string contains a filetype description + + EXAMPLE:: + + filetypes = ["*.css", ["*.htm", "*.html", "HTML files"] ] + + .. note:: If the filetypes list does not contain ("All files","*"), it will be added. + + If the filetypes list does not contain a filemask that includes + the extension of the "default" argument, it will be added. + For example, if default="\*abc.py" + and no filetypes argument was specified, then + "\*.py" will automatically be added to the filetypes argument. + + :param str msg: the msg to be displayed. + :param str title: the window title + :param str default: filepath with wildcards + :param object filetypes: filemasks that a user can choose, e.g. "\*.txt" + :param bool multiple: If true, more than one file can be selected + :return: the name of a file, or None if user chose to cancel + """ + localRoot = Tk() + localRoot.withdraw() + + initialbase, initialfile, initialdir, filetypes = fileboxSetup(default, filetypes) + + #------------------------------------------------------------ + # if initialfile contains no wildcards; we don't want an + # initial file. It won't be used anyway. + # Also: if initialbase is simply "*", we don't want an + # initialfile; it is not doing any useful work. + #------------------------------------------------------------ + if (initialfile.find("*") < 0) and (initialfile.find("?") < 0): + initialfile = None + elif initialbase == "*": + initialfile = None + + func = tk_FileDialog.askopenfilenames if multiple else tk_FileDialog.askopenfilename + ret_val = func(parent=localRoot + , title=getFileDialogTitle(msg, title) + , initialdir=initialdir + , initialfile=initialfile + , filetypes=filetypes + ) + + if multiple: + f = [os.path.normpath(x) for x in localRoot.tk.splitlist(ret_val)] + else: + f = os.path.normpath(ret_val) + + localRoot.destroy() + + if not f: + return None + return f + + +#------------------------------------------------------------------- +# filesavebox +#------------------------------------------------------------------- +def filesavebox(msg=None + , title=None + , default="" + , filetypes=None): + """ + A file to get the name of a file to save. + Returns the name of a file, or None if user chose to cancel. + + The "default" argument should contain a filename (i.e. the + current name of the file to be saved). It may also be empty, + or contain a filemask that includes wildcards. + + The "filetypes" argument works like the "filetypes" argument to + fileopenbox. + + :param str msg: the msg to be displayed. + :param str title: the window title + :param str default: default filename to return + :param object filetypes: filemasks that a user can choose, e.g. 
" \*.txt" + :return: the name of a file, or None if user chose to cancel + """ + + localRoot = Tk() + localRoot.withdraw() + + initialbase, initialfile, initialdir, filetypes = fileboxSetup(default, filetypes) + + f = tk_FileDialog.asksaveasfilename(parent=localRoot + , title=getFileDialogTitle(msg, title) + , initialfile=initialfile + , initialdir=initialdir + , filetypes=filetypes + ) + localRoot.destroy() + if not f: + return None + return os.path.normpath(f) + + +#------------------------------------------------------------------- +# +# fileboxSetup +# +#------------------------------------------------------------------- +def fileboxSetup(default + , filetypes): + if not default: + default = os.path.join(".", "*") + initialdir, initialfile = os.path.split(default) + if not initialdir: + initialdir = "." + if not initialfile: + initialfile = "*" + initialbase, initialext = os.path.splitext(initialfile) + initialFileTypeObject = FileTypeObject(initialfile) + + allFileTypeObject = FileTypeObject("*") + ALL_filetypes_was_specified = False + + if not filetypes: + filetypes = list() + filetypeObjects = list() + + for filemask in filetypes: + fto = FileTypeObject(filemask) + + if fto.isAll(): + ALL_filetypes_was_specified = True # remember this + + if fto == initialFileTypeObject: + initialFileTypeObject.add(fto) # add fto to initialFileTypeObject + else: + filetypeObjects.append(fto) + + #------------------------------------------------------------------ + # make sure that the list of filetypes includes the ALL FILES type. + #------------------------------------------------------------------ + if ALL_filetypes_was_specified: + pass + elif allFileTypeObject == initialFileTypeObject: + pass + else: + filetypeObjects.insert(0, allFileTypeObject) + #------------------------------------------------------------------ + # Make sure that the list includes the initialFileTypeObject + # in the position in the list that will make it the default. + # This changed between Python version 2.5 and 2.6 + #------------------------------------------------------------------ + if len(filetypeObjects) == 0: + filetypeObjects.append(initialFileTypeObject) + + if initialFileTypeObject in (filetypeObjects[0], filetypeObjects[-1]): + pass + else: + if runningPython26: + filetypeObjects.append(initialFileTypeObject) + else: + filetypeObjects.insert(0, initialFileTypeObject) + + filetypes = [fto.toTuple() for fto in filetypeObjects] + + return initialbase, initialfile, initialdir, filetypes + + + +#------------------------------------------------------------------- +# utility routines +#------------------------------------------------------------------- +# These routines are used by several other functions in the EasyGui module. + +def uniquify_list_of_strings(input_list): + """ + Ensure that every string within input_list is unique. + :param list input_list: List of strings + :return: New list with unique names as needed. + """ + output_list = list() + for i, item in enumerate(input_list): + tempList = input_list[:i] + input_list[i+1:] + if item not in tempList: + output_list.append(item) + else: + output_list.append('{0}_{1}'.format(item, i)) + return output_list + +import re + +def parse_hotkey(text): + """ + Extract a desired hotkey from the text. The format to enclose the hotkey in square braces + as in Button_[1] which would assign the keyboard key 1 to that button. The one will be included in the + button text. 
To hide they key, use double square braces as in: Ex[[qq]]it , which would assign + the q key to the Exit button. Special keys such as may also be used: Move [] for a full + list of special keys, see this reference: http://infohost.nmt.edu/tcc/help/pubs/tkinter/web/key-names.html + :param text: + :return: list containing cleaned text, hotkey, and hotkey position within cleaned text. + """ + + ret_val = [text, None, None] #Default return values + if text is None: + return ret_val + + # Single character, remain visible + res = re.search('(?<=\[).(?=\])', text) + if res: + start = res.start(0) + end = res.end(0) + caption = text[:start-1]+text[start:end]+text[end+1:] + ret_val = [caption, text[start:end], start-1] + + # Single character, hide it + res = re.search('(?<=\[\[).(?=\]\])', text) + if res: + start = res.start(0) + end = res.end(0) + caption = text[:start-2]+text[end+2:] + ret_val = [caption, text[start:end], None] + + # a Keysym. Always hide it + res = re.search('(?<=\[\<).+(?=\>\])', text) + if res: + start = res.start(0) + end = res.end(0) + caption = text[:start-2]+text[end+2:] + ret_val = [caption, '<{}>'.format(text[start:end]), None] + + return ret_val + + + +def __buttonEvent(event=None, buttons=None, virtual_event=None): + """ + Handle an event that is generated by a person interacting with a button. It may be a button press + or a key press. + """ + # TODO: Replace globals with tkinter variables + global boxRoot, __replyButtonText, rootWindowPosition + + # Determine window location and save to global + m = re.match("(\d+)x(\d+)([-+]\d+)([-+]\d+)", boxRoot.geometry()) + if not m: + raise ValueError("failed to parse geometry string: {}".format(boxRoot.geometry())) + width, height, xoffset, yoffset = [int(s) for s in m.groups()] + rootWindowPosition = '{0:+g}{1:+g}'.format(xoffset, yoffset) + + # print('{0}:{1}:{2}'.format(event, buttons, virtual_event)) + if virtual_event == 'cancel': + for button_name, button in buttons.items(): + if 'cancel_choice' in button: + __replyButtonText = button['original_text'] + __replyButtonText = None + boxRoot.quit() + return + + if virtual_event == 'select': + text = event.widget.config('text')[-1] + if not isinstance(text, basestring): + text = ' '.join(text) + for button_name, button in buttons.items(): + if button['clean_text'] == text: + __replyButtonText = button['original_text'] + boxRoot.quit() + return + + # Hotkeys + if buttons: + for button_name, button in buttons.items(): + hotkey_pressed = event.keysym + if event.keysym != event.char: # A special character + hotkey_pressed = '<{}>'.format(event.keysym) + if button['hotkey'] == hotkey_pressed: + __replyButtonText = button_name + boxRoot.quit() + return + + print("Event not understood") + + +def __put_buttons_in_buttonframe(choices, default_choice, cancel_choice): + """Put the buttons in the buttons frame + """ + global buttonsFrame, cancel_invoke + + #TODO: I'm using a dict to hold buttons, but this could all be cleaned up if I subclass Button to hold + # all the event bindings, etc + #TODO: Break __buttonEvent out into three: regular keyboard, default select, and cancel select. 
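+    # `buttons` maps each (uniquified) button label to a dict holding
+    # 'original_text', 'clean_text', 'hotkey', and the Tkinter 'widget';
+    # 'default_choice' / 'cancel_choice' flags are added further down.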
+ unique_choices = uniquify_list_of_strings(choices) + # Create buttons dictionary and Tkinter widgets + buttons = dict() + for button_text, unique_button_text in zip(choices, unique_choices): + this_button = dict() + this_button['original_text'] = button_text + this_button['clean_text'], this_button['hotkey'], hotkey_position = parse_hotkey(button_text) + this_button['widget'] = Button(buttonsFrame, + takefocus=1, + text=this_button['clean_text'], + underline=hotkey_position) + this_button['widget'].pack(expand=YES, side=LEFT, padx='1m', pady='1m', ipadx='2m', ipady='1m') + buttons[unique_button_text] = this_button + # Bind arrows, Enter, Escape + for this_button in buttons.values(): + bindArrows(this_button['widget']) + for selectionEvent in STANDARD_SELECTION_EVENTS: + this_button['widget'].bind("<{}>".format(selectionEvent), + lambda e: __buttonEvent(e, buttons, virtual_event='select'), + add=True) + + # Assign default and cancel buttons + if cancel_choice in buttons: + buttons[cancel_choice]['cancel_choice'] = True + boxRoot.bind_all('', lambda e: __buttonEvent(e, buttons, virtual_event='cancel'), add=True) + boxRoot.protocol('WM_DELETE_WINDOW', lambda: __buttonEvent(None, buttons, virtual_event='cancel')) + if default_choice in buttons: + buttons[default_choice]['default_choice'] = True + buttons[default_choice]['widget'].focus_force() + # Bind hotkeys + for hk in [button['hotkey'] for button in buttons.values() if button['hotkey']]: + boxRoot.bind_all(hk, lambda e: __buttonEvent(e, buttons), add=True) + + return + + +#----------------------------------------------------------------------- +# +# class EgStore +# +#----------------------------------------------------------------------- +class EgStore: + r""" +A class to support persistent storage. + +You can use EgStore to support the storage and retrieval +of user settings for an EasyGui application. + +**Example A: define a class named Settings as a subclass of EgStore** +:: + + class Settings(EgStore): + def __init__(self, filename): # filename is required + #------------------------------------------------- + # Specify default/initial values for variables that + # this particular application wants to remember. + #------------------------------------------------- + self.userId = "" + self.targetServer = "" + + #------------------------------------------------- + # For subclasses of EgStore, these must be + # the last two statements in __init__ + #------------------------------------------------- + self.filename = filename # this is required + self.restore() # restore values from the storage file if possible + +**Example B: create settings, a persistent Settings object** +:: + + settingsFile = "myApp_settings.txt" + settings = Settings(settingsFile) + + user = "obama_barak" + server = "whitehouse1" + settings.userId = user + settings.targetServer = server + settings.store() # persist the settings + + # run code that gets a new value for userId, and persist the settings + user = "biden_joe" + settings.userId = user + settings.store() + +**Example C: recover the Settings instance, change an attribute, and store it again.** +:: + + settings = Settings(settingsFile) + settings.userId = "vanrossum_g" + settings.store() + +""" + + def __init__(self, filename): # obtaining filename is required + self.filename = None + raise NotImplementedError() + + def restore(self): + """ + Set the values of whatever attributes are recoverable + from the pickle file. 
+ + Populate the attributes (the __dict__) of the EgStore object + from the attributes (the __dict__) of the pickled object. + + If the pickled object has attributes that have been initialized + in the EgStore object, then those attributes of the EgStore object + will be replaced by the values of the corresponding attributes + in the pickled object. + + If the pickled object is missing some attributes that have + been initialized in the EgStore object, then those attributes + of the EgStore object will retain the values that they were + initialized with. + + If the pickled object has some attributes that were not + initialized in the EgStore object, then those attributes + will be ignored. + + IN SUMMARY: + + After the recover() operation, the EgStore object will have all, + and only, the attributes that it had when it was initialized. + + Where possible, those attributes will have values recovered + from the pickled object. + """ + if not os.path.exists(self.filename): return self + if not os.path.isfile(self.filename): return self + + try: + with open(self.filename, "rb") as f: + unpickledObject = pickle.load(f) + + for key in list(self.__dict__.keys()): + default = self.__dict__[key] + self.__dict__[key] = unpickledObject.__dict__.get(key, default) + except: + pass + + return self + + def store(self): + """ + Save the attributes of the EgStore object to a pickle file. + Note that if the directory for the pickle file does not already exist, + the store operation will fail. + """ + with open(self.filename, "wb") as f: + pickle.dump(self, f) + + def kill(self): + """ + Delete my persistent file (i.e. pickle file), if it exists. + """ + if os.path.isfile(self.filename): + os.remove(self.filename) + return + + def __str__(self): + """ + return my contents as a string in an easy-to-read format. + """ + # find the length of the longest attribute name + longest_key_length = 0 + keys = list() + for key in self.__dict__.keys(): + keys.append(key) + longest_key_length = max(longest_key_length, len(key)) + + keys.sort() # sort the attribute names + lines = list() + for key in keys: + value = self.__dict__[key] + key = key.ljust(longest_key_length) + lines.append("%s : %s\n" % (key, repr(value))) + return "".join(lines) # return a string showing the attributes + + +#----------------------------------------------------------------------- +# +# test/demo easygui +# +#----------------------------------------------------------------------- + +package_dir = os.path.dirname(os.path.realpath(__file__)) +def egdemo(): + """ + Run the EasyGui demo. 
+ """ + # clear the console + writeln("\n" * 100) + + msg = list() + msg.append("Pick the kind of box that you wish to demo.") + msg.append(" * Python version {}".format(sys.version)) + msg.append(" * EasyGui version {}".format(eg_version)) + msg.append(" * Tk version {}".format(TkVersion)) + intro_message = "\n".join(msg) + + while 1: # do forever + choices = [ + "msgbox", + "buttonbox", + "buttonbox(image) -- a buttonbox that displays an image", + "choicebox", + "multchoicebox", + "textbox", + "ynbox", + "ccbox", + "enterbox", + "enterbox(image) -- an enterbox that displays an image", + "exceptionbox", + "codebox", + "integerbox", + "boolbox", + "indexbox", + "filesavebox", + "fileopenbox", + "passwordbox", + "multenterbox", + "multpasswordbox", + "diropenbox", + "About EasyGui", + " Help" + ] + choice = choicebox(msg=intro_message + , title="EasyGui " + eg_version + , choices=choices) + + if not choice: + return + + reply = choice.split() + + if reply[0] == "msgbox": + reply = msgbox("short msg", "This is a long title") + writeln("Reply was: {!r}".format(reply)) + + elif reply[0] == "About": + reply = abouteasygui() + + elif reply[0] == "Help": + _demo_help() + + elif reply[0] == "buttonbox": + reply = buttonbox(choices=['one', 'two', 'two', 'three'], default_choice='two') + writeln("Reply was: {!r}".format(reply)) + + title = "Demo of Buttonbox with many, many buttons!" + msg = "This buttonbox shows what happens when you specify too many buttons." + reply = buttonbox(msg=msg, title=title, choices=choices, cancel_choice='msgbox') + writeln("Reply was: {!r}".format(reply)) + + elif reply[0] == "buttonbox(image)": + _demo_buttonbox_with_image() + + elif reply[0] == "boolbox": + reply = boolbox() + writeln("Reply was: {!r}".format(reply)) + + elif reply[0] == "enterbox": + image = os.path.join(package_dir, "python_and_check_logo.gif") + message = "Enter the name of your best friend." \ + "\n(Result will be stripped.)" + reply = enterbox(message, "Love!", " Suzy Smith ") + writeln("Reply was: {!r}".format(reply)) + + message = "Enter the name of your best friend." \ + "\n(Result will NOT be stripped.)" + reply = enterbox(message, "Love!", " Suzy Smith ", strip=False) + writeln("Reply was: {!r}".format(reply)) + + reply = enterbox("Enter the name of your worst enemy:", "Hate!") + writeln("Reply was: {!r}".format(reply)) + + elif reply[0] == "enterbox(image)": + image = os.path.join(package_dir, "python_and_check_logo.gif") + message = "What kind of snake is this?" 
+ reply = enterbox(message, "Quiz", image=image) + writeln("Reply was: {!r}".format(reply)) + + elif reply[0] == "exceptionbox": + try: + thisWillCauseADivideByZeroException = 1 / 0 + except: + exceptionbox() + + elif reply[0] == "integerbox": + reply = integerbox( + "Enter a number between 3 and 333" + ,"Demo: integerbox WITH a default value" + ,222, 3, 333) + writeln("Reply was: {!r}".format(reply)) + + reply = integerbox( + "Enter a number between 0 and 99" + ,"Demo: integerbox WITHOUT a default value" + ) + writeln("Reply was: {!r}".format(reply)) + + elif reply[0] == "diropenbox": + _demo_diropenbox() + elif reply[0] == "fileopenbox": + _demo_fileopenbox() + elif reply[0] == "filesavebox": + _demo_filesavebox() + + elif reply[0] == "indexbox": + title = reply[0] + msg = "Demo of " + reply[0] + choices = ["Choice1", "Choice2", "Choice3", "Choice4"] + reply = indexbox(msg, title, choices) + writeln("Reply was: {!r}".format(reply)) + + elif reply[0] == "passwordbox": + reply = passwordbox("Demo of password box WITHOUT default" + + "\n\nEnter your secret password", "Member Logon") + writeln("Reply was: {!s}".format(reply)) + + reply = passwordbox("Demo of password box WITH default" + + "\n\nEnter your secret password", "Member Logon", "alfie") + writeln("Reply was: {!s}".format(reply)) + + elif reply[0] == "multenterbox": + msg = "Enter your personal information" + title = "Credit Card Application" + fieldNames = ["Name", "Street Address", "City", "State", "ZipCode"] + fieldValues = list() # we start with blanks for the values + fieldValues = multenterbox(msg, title, fieldNames) + + # make sure that none of the fields was left blank + while 1: + if fieldValues is None: + break + errs = list() + for n, v in zip(fieldNames, fieldValues): + if v.strip() == "": + errs.append('"{}" is a required field.'.format(n)) + if not len(errs): + break # no problems found + fieldValues = multenterbox("\n".join(errs), title, fieldNames, fieldValues) + + writeln("Reply was: {}".format(fieldValues)) + + elif reply[0] == "multpasswordbox": + msg = "Enter logon information" + title = "Demo of multpasswordbox" + fieldNames = ["Server ID", "User ID", "Password"] + fieldValues = list() # we start with blanks for the values + fieldValues = multpasswordbox(msg, title, fieldNames) + + # make sure that none of the fields was left blank + while 1: + if fieldValues is None: + break + errs = list() + for n, v in zip(fieldNames, fieldValues): + if v.strip() == "": + errs.append('"{}" is a required field.\n\n'.format(n)) + if not len(errs): + break # no problems found + fieldValues = multpasswordbox("".join(errs), title, fieldNames, fieldValues) + + writeln("Reply was: {!s}".format(fieldValues)) + + elif reply[0] == "ynbox": + title = "Demo of ynbox" + msg = "Were you expecting the Spanish Inquisition?" + reply = ynbox(msg, title) + writeln("Reply was: {!r}".format(reply)) + if reply: + msgbox("NOBODY expects the Spanish Inquisition!", "Wrong!") + + elif reply[0] == "ccbox": + msg = "Insert your favorite message here" + title = "Demo of ccbox" + reply = ccbox(msg, title) + writeln("Reply was: {!r}".format(reply)) + + elif reply[0] == "choicebox": + title = "Demo of choicebox" + longchoice = "This is an example of a very long option which you may or may not wish to choose." * 2 + listChoices = ["nnn", "ddd", "eee", "fff", "aaa", longchoice + , "aaa", "bbb", "ccc", "ggg", "hhh", "iii", "jjj", "kkk", "LLL", "mmm", "nnn", "ooo", "ppp", "qqq", + "rrr", "sss", "ttt", "uuu", "vvv"] + + msg = "Pick something. 
" + ("A wrapable sentence of text ?! " * 30) + "\nA separate line of text." * 6 + reply = choicebox(msg=msg, choices=listChoices) + writeln("Reply was: {!r}".format(reply)) + + msg = "Pick something. " + reply = choicebox(msg=msg, title=title, choices=listChoices) + writeln("Reply was: {!r}".format(reply)) + + msg = "Pick something. " + reply = choicebox(msg="The list of choices is empty!", choices=list()) + writeln("Reply was: {!r}".format(reply)) + + elif reply[0] == "multchoicebox": + listChoices = ["aaa", "bbb", "ccc", "ggg", "hhh", "iii", "jjj", "kkk" + , "LLL", "mmm", "nnn", "ooo", "ppp", "qqq" + , "rrr", "sss", "ttt", "uuu", "vvv"] + + msg = "Pick as many choices as you wish." + reply = multchoicebox(msg, "Demo of multchoicebox", listChoices) + writeln("Reply was: {!r}".format(reply)) + + elif reply[0] == "textbox": + _demo_textbox(reply[0]) + elif reply[0] == "codebox": + _demo_codebox(reply[0]) + + else: + msgbox("Choice\n\n{}\n\nis not recognized".format(choice), "Program Logic Error") + return + + +def _demo_textbox(reply): + text_snippet = (( \ + """It was the best of times, and it was the worst of times. The rich ate cake, and the poor had cake recommended to them, but wished only for enough cash to buy bread. The time was ripe for revolution! """ \ + * 5) + "\n\n") * 10 + title = "Demo of textbox" + msg = "Here is some sample text. " * 16 + reply = textbox(msg, title, text_snippet) + writeln("Reply was: {!s}".format(reply)) + + +def _demo_codebox(reply): + #TODO RL: Turn this sample code into the code in this module, just for fun + code_snippet = ("dafsdfa dasflkj pp[oadsij asdfp;ij asdfpjkop asdfpok asdfpok asdfpok" * 3) + "\n" + \ + """# here is some dummy Python code +for someItem in myListOfStuff: + do something(someItem) + do something() + do something() + if somethingElse(someItem): + doSomethingEvenMoreInteresting() + +""" * 16 + msg = "Here is some sample code. " * 16 + reply = codebox(msg, "Code Sample", code_snippet) + writeln("Reply was: {!r}".format(reply)) + + +def _demo_buttonbox_with_image(): + msg = "Do you like this picture?\nIt is " + choices = ["Yes", "No", "No opinion"] + + for image in [ + os.path.join(package_dir, "python_and_check_logo.gif") + , os.path.join(package_dir, "python_and_check_logo.jpg") + , os.path.join(package_dir, "python_and_check_logo.png") + , os.path.join(package_dir, "zzzzz.gif")]: + reply = buttonbox(msg + image, image=image, choices=choices) + writeln("Reply was: {!r}".format(reply)) + + +def _demo_help(): + savedStdout = sys.stdout # save the sys.stdout file object + sys.stdout = capturedOutput = StringIO() + print(globals()['__doc__']) #help("easygui") + sys.stdout = savedStdout # restore the sys.stdout file object + codebox("EasyGui Help", text=capturedOutput.getvalue()) + + +def _demo_filesavebox(): + filename = "myNewFile.txt" + title = "File SaveAs" + msg = "Save file as:" + + f = filesavebox(msg, title, default=filename) + writeln("You chose to save file: {}".format(f)) + + +def _demo_diropenbox(): + title = "Demo of diropenbox" + msg = "Pick the directory that you wish to open." 
+ d = diropenbox(msg, title) + writeln("You chose directory...: {}".format(d)) + + d = diropenbox(msg, title, default="./") + writeln("You chose directory...: {}".format(d)) + + d = diropenbox(msg, title, default="c:/") + writeln("You chose directory...: {}".format(d)) + + +def _demo_fileopenbox(): + msg = "Python files" + title = "Open files" + default = "*.py" + f = fileopenbox(msg, title, default=default) + writeln("You chose to open file: {}".format(f)) + + default = "./*.gif" + msg = "Some other file types (Multi-select)" + filetypes = ["*.jpg", ["*.zip", "*.tgs", "*.gz", "Archive files"], ["*.htm", "*.html", "HTML files"]] + f = fileopenbox(msg, title, default=default, filetypes=filetypes, multiple=True) + writeln("You chose to open file: %s" % f) + + +EASYGUI_ABOUT_INFORMATION = ''' +======================================================================== +0.97(2014-12-20) +======================================================================== +We are happy to release version 0.97 of easygui. The intent of this release is to address some basic +functionality issues as well as improve easygui in the ways people have asked. + +Robert Lugg (me) was searching for a GUI library for my python work. I saw easygui and liked very much its +paradigm. Stephen Ferg, the creator and developer of easygui, graciously allowed me to start development +back up. With the help of Alexander Zawadzki, Horst Jens, and others I set a goal to release before the +end of 2014. + +We rely on user feedback so please bring up problems, ideas, or just say how you are using easygui. + +BUG FIXES +--------- + * sourceforge #4: easygui docs contain bad references to easygui_pydoc.html + * sourceforge #6: no index.html in docs download file. Updated to sphinx which as autolinking. + * sourceforge #8: unicode issues with file*box. Fixed all that I knew how. + * sourceforge #12: Cannot Exit with 'X'. Now X and escape either return "cancel_button", if set, or None + +ENHANCEMENTS +------------ + * Added ability to specify default_choice and cancel_choice for button widgets (See API docs) + * True and False are returned instead of 1 and 0 for several boxes + * Allow user to map keyboard keys to buttons by enclosing a hotkey in square braces like: "Pick [M]e", which would assign + keyboard key M to that button. Double braces hide that character, and keysyms are allowed: + [[q]]Exit Would show Exit on the button, and the button would be controlled by the q key + []Help Would show Help on the button, and the button would be controlled by the F1 function key + NOTE: We are still working on the exact syntax of these key mappings as Enter, space, and arrows are already being + used. + * Escape and the windows 'X' button always work in buttonboxes. Those return None in that case. + * sourceforge #9: let fileopenbox open multiple files. Added optional argument 'multiple' + * Location of dialogs on screen is preserved. This isn't perfect yet, but now, at least, the dialogs don't + always reset to their default position! + * added some, but not all of the bugs/enhancements developed by Robbie Brook: + http://all-you-need-is-tech.blogspot.com/2013/01/improving-easygui-for-python.html + +KNOWN ISSUES +------------ + * In the documentation, there were previous references to issues when using the IDLE IDE. I haven't + experienced those, but also didn't do anything to fix them, so they may still be there. 
Please report + any problems and we'll try to address them + * I am fairly new to contributing to open source, so I don't understand packaging, pypi, etc. There + are likely problems as well as better ways to do things. Again, I appreciate any help or guidance. + +Other Changes (that you likely don't care about) +------------------------------------------------ + * Restructured loading of image files to try PIL first throw error if file doesn't exist. + * Converted docs to sphinx with just a bit of doctest. Most content was retained from the old site, so + there might be some redundancies still. Please make any suggested improvements. + * Set up a GitHub repository for development: https://github.com/robertlugg/easygui + +EasyGui is licensed under what is generally known as +the "modified BSD license" (aka "revised BSD", "new BSD", "3-clause BSD"). +This license is GPL-compatible but less restrictive than GPL. + +======================================================================== +0.96(2010-08-29) +======================================================================== +This version fixes some problems with version independence. + +BUG FIXES +------------------------------------------------------ + * A statement with Python 2.x-style exception-handling syntax raised + a syntax error when running under Python 3.x. + Thanks to David Williams for reporting this problem. + + * Under some circumstances, PIL was unable to display non-gif images + that it should have been able to display. + The cause appears to be non-version-independent import syntax. + PIL modules are now imported with a version-independent syntax. + Thanks to Horst Jens for reporting this problem. + +LICENSE CHANGE +------------------------------------------------------ +Starting with this version, EasyGui is licensed under what is generally known as +the "modified BSD license" (aka "revised BSD", "new BSD", "3-clause BSD"). +This license is GPL-compatible but less restrictive than GPL. +Earlier versions were licensed under the Creative Commons Attribution License 2.0. + + +======================================================================== +0.95(2010-06-12) +======================================================================== + +ENHANCEMENTS +------------------------------------------------------ + * Previous versions of EasyGui could display only .gif image files using the + msgbox "image" argument. This version can now display all image-file formats + supported by PIL the Python Imaging Library) if PIL is installed. + If msgbox is asked to open a non-gif image file, it attempts to import + PIL and to use PIL to convert the image file to a displayable format. + If PIL cannot be imported (probably because PIL is not installed) + EasyGui displays an error message saying that PIL must be installed in order + to display the image file. + + Note that + http://www.pythonware.com/products/pil/ + says that PIL doesn't yet support Python 3.x. + + +======================================================================== +0.94(2010-06-06) +======================================================================== + +ENHANCEMENTS +------------------------------------------------------ + * The codebox and textbox functions now return the contents of the box, rather + than simply the name of the button ("Yes"). This makes it possible to use + codebox and textbox as data-entry widgets. A big "thank you!" 
to Dominic + Comtois for requesting this feature, patiently explaining his requirement, + and helping to discover the tkinter techniques to implement it. + + NOTE THAT in theory this change breaks backward compatibility. But because + (in previous versions of EasyGui) the value returned by codebox and textbox + was meaningless, no application should have been checking it. So in actual + practice, this change should not break backward compatibility. + + * Added support for SPACEBAR to command buttons. Now, when keyboard + focus is on a command button, a press of the SPACEBAR will act like + a press of the ENTER key; it will activate the command button. + + * Added support for keyboard navigation with the arrow keys (up,down,left,right) + to the fields and buttons in enterbox, multenterbox and multpasswordbox, + and to the buttons in choicebox and all buttonboxes. + + * added highlightthickness=2 to entry fields in multenterbox and + multpasswordbox. Now it is easier to tell which entry field has + keyboard focus. + + +BUG FIXES +------------------------------------------------------ + * In EgStore, the pickle file is now opened with "rb" and "wb" rather than + with "r" and "w". This change is necessary for compatibility with Python 3+. + Thanks to Marshall Mattingly for reporting this problem and providing the fix. + + * In integerbox, the actual argument names did not match the names described + in the docstring. Thanks to Daniel Zingaro of at University of Toronto for + reporting this problem. + + * In integerbox, the "argLowerBound" and "argUpperBound" arguments have been + renamed to "lowerbound" and "upperbound" and the docstring has been corrected. + + NOTE THAT THIS CHANGE TO THE ARGUMENT-NAMES BREAKS BACKWARD COMPATIBILITY. + If argLowerBound or argUpperBound are used, an AssertionError with an + explanatory error message is raised. + + * In choicebox, the signature to choicebox incorrectly showed choicebox as + accepting a "buttons" argument. The signature has been fixed. + + +======================================================================== +0.93(2009-07-07) +======================================================================== + +ENHANCEMENTS +------------------------------------------------------ + + * Added exceptionbox to display stack trace of exceptions + + * modified names of some font-related constants to make it + easier to customize them + + +======================================================================== +0.92(2009-06-22) +======================================================================== + +ENHANCEMENTS +------------------------------------------------------ + + * Added EgStore class to to provide basic easy-to-use persistence. + +BUG FIXES +------------------------------------------------------ + + * Fixed a bug that was preventing Linux users from copying text out of + a textbox and a codebox. This was not a problem for Windows users. 
+ +''' + + +def abouteasygui(): + """ + shows the easygui revision history + """ + codebox("About EasyGui\n{}".format(eg_version), "EasyGui", EASYGUI_ABOUT_INFORMATION) + return None + + +if __name__ == '__main__': + egdemo() + diff --git a/llnet.py b/llnet.py new file mode 100644 index 0000000..a3f3cf9 --- /dev/null +++ b/llnet.py @@ -0,0 +1,652 @@ +import os +import sys +import time +import cPickle +import numpy +import h5py +import scipy.io +import matplotlib.pyplot as plt +import theano +import theano.tensor as T +import PIL.Image +import shutil +import Data_process2 + +from easygui import * + +from theano.tensor.shared_randomstreams import RandomStreams +from logistic_sgd import LogisticRegression, load_data, load_data_overlapped, load_data_overlapped_strides +from mlp import HiddenLayer +from dA import dA +from utils import tile_raster_images +from Data_process2 import reconstruct_from_patches_with_strides_2d +from sklearn.feature_extraction import image as im +from scipy import misc, ndimage +from skimage import color, data, restoration +import nlinalg + +##################################################################################################################### +# # +# Training Code # +# # +##################################################################################################################### + +####################################### +# Hyperparameters / Options # +####################################### + +# Training Dataset +# tr_dataset = 'dataset/20151026_lowlightnoisy_17x17.mat' + +# Hyperparameters +patch_size = (17,17) +prod = patch_size[0]*patch_size[1] +hp_hlsize = [1000,1000,1000,1000,1000] +hp_corruption_levels = [0.1, 0.1, 0.1, 0.1, 0.1] +hp_pretraining_epochs = 3 +hp_batchsize = 10 #llnet1: 50 + +####################################### +# Class Construction # +####################################### + +class SdA(object): + + def __init__( + self, + numpy_rng, + theano_rng=None, + n_ins=prod, + hidden_layers_sizes=[500,500], + n_outs=prod, + corruption_levels=[0.1, 0.1] + ): + + self.sigmoid_layers = [] + self.dA_layers = [] + self.params = [] + self.n_layers = len(hidden_layers_sizes) + + assert self.n_layers > 0 + + if not theano_rng: + theano_rng = RandomStreams(numpy_rng.randint(2 ** 30)) + + self.x = T.matrix('x') + self.y = T.matrix('y') + + for i in xrange(self.n_layers): + if i == 0: + input_size = n_ins + else: + input_size = hidden_layers_sizes[i - 1] + + if i == 0: + layer_input = self.x + else: + layer_input = self.sigmoid_layers[-1].output + + sigmoid_layer = HiddenLayer(rng=numpy_rng, + input=layer_input, + n_in=input_size, + n_out=hidden_layers_sizes[i], + activation=T.nnet.sigmoid) + + self.sigmoid_layers.append(sigmoid_layer) + self.params.extend(sigmoid_layer.params) + + dA_layer = dA(numpy_rng=numpy_rng, + theano_rng=theano_rng, + input=layer_input, + n_visible=input_size, + n_hidden=hidden_layers_sizes[i], + W=sigmoid_layer.W, + bhid=sigmoid_layer.b) + self.dA_layers.append(dA_layer) + + self.logLayer = LogisticRegression( + input=self.sigmoid_layers[-1].output, + n_in=hidden_layers_sizes[-1], + n_out=n_outs + ) + + self.params.extend(self.logLayer.params) + self.finetune_cost = self.logLayer.image_norm(self.y, obj=self) + self.errors = self.logLayer.image_norm(self.y, obj=self) + + def pretraining_functions(self, train_set_x, batch_size): + + index = T.lscalar('index') + corruption_level = T.scalar('corruption') + learning_rate = T.scalar('lr') + batch_begin = index * batch_size + batch_end = batch_begin + batch_size + + 
pretrain_fns = [] + for dA in self.dA_layers: + cost, updates = dA.get_cost_updates(corruption_level, + learning_rate) + + fn = theano.function( + inputs=[ + index, + theano.Param(corruption_level, default=0.2), + theano.Param(learning_rate, default=0.1) + ], + outputs=cost, + updates=updates, + givens={ + self.x: train_set_x[batch_begin: batch_end] + } + ) + pretrain_fns.append(fn) + + return pretrain_fns + + def build_finetune_functions(self, train_set,valid_set,test_set,batch_size, learning_rate): + + (train_set_x, train_set_y) = train_set + (valid_set_x, valid_set_y) = valid_set + (test_set_x, test_set_y) = test_set + + n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] + n_valid_batches /= batch_size + n_test_batches = test_set_x.get_value(borrow=True).shape[0] + n_test_batches /= batch_size + + index = T.lscalar('index') + + gparams = T.grad(self.finetune_cost, self.params) + + updates = [ + (param, param - gparam * learning_rate) + for param, gparam in zip(self.params, gparams) + ] + + train_fn = theano.function( + inputs=[index], + outputs=self.finetune_cost, + updates=updates, + givens={ + self.x: train_set_x[ + index * batch_size: (index + 1) * batch_size + ], + self.y: train_set_y[ + index * batch_size: (index + 1) * batch_size + ] + }, + name='train' + ) + + test_score_i = theano.function( + [index], + self.errors, + givens={ + self.x: test_set_x[ + index * batch_size: (index + 1) * batch_size + ], + self.y: test_set_y[ + index * batch_size: (index + 1) * batch_size + ] + }, + name='test' + ) + + valid_score_i = theano.function( + [index], + self.errors, + givens={ + self.x: valid_set_x[ + index * batch_size: (index + 1) * batch_size + ], + self.y: valid_set_y[ + index * batch_size: (index + 1) * batch_size + ] + }, + name='valid' + ) + + def valid_score(): + return [valid_score_i(i) for i in xrange(n_valid_batches)] + + def test_score(): + return [test_score_i(i) for i in xrange(n_test_batches)] + + return train_fn, valid_score, test_score + +####################################### +# SDA Training # +####################################### + +def test_SdA(finetune_lr=0.1, pretraining_epochs=hp_pretraining_epochs, + pretrain_lr=0.1, training_epochs=100000, batch_size=hp_batchsize, patch_size = patch_size): + + datasets = load_data(tr_dataset) + train_set = datasets[0] + valid_set = datasets[1] + test_set = datasets[2] + train_set_x, train_set_y = datasets[0] + valid_set_x, valid_set_y = datasets[1] + test_set_x, test_set_y = datasets[2] + datasets = [] + + print '... plotting clean images' + image = PIL.Image.fromarray(tile_raster_images( + X=test_set_y.get_value(), + img_shape=patch_size, tile_shape=(50, 40), + tile_spacing=(0, 0),scale_rows_to_unit_interval=False)) + image.save('outputs/LLnet_clean.png') + + print '... plotting noisy images' + image = PIL.Image.fromarray(tile_raster_images( + X=test_set_x.get_value(), + img_shape=patch_size, tile_shape=(50, 40), + tile_spacing=(0, 0),scale_rows_to_unit_interval=False)) + image.save('outputs/LLnet_noisy.png') + + n_train_samples = train_set_x.get_value(borrow=True).shape[0] + n_train_batches = n_train_samples/batch_size + + numpy_rng = numpy.random.RandomState(89677) + print '... building the model' + + sda = SdA( + numpy_rng=numpy_rng, + n_ins=patch_size[0]*patch_size[1], + hidden_layers_sizes= hp_hlsize, + n_outs=patch_size[0]*patch_size[1] + ) + + print '... compiling functions' + + pretraining_fns = sda.pretraining_functions(train_set_x=train_set_y, + batch_size=batch_size) + + print '... 
pre-training the model' + start_time = time.clock() + for i in xrange(sda.n_layers): + if i <= sda.n_layers/2: + + if i == (sda.n_layers - 1): + currentlr = pretrain_lr; + else: + currentlr = pretrain_lr*0.1 + + for epoch in xrange(pretraining_epochs): + c = [] + for batch_index in xrange(n_train_batches): + current_c = pretraining_fns[i](index=batch_index, + corruption=hp_corruption_levels[i], + lr=currentlr) + if (batch_index % (n_train_batches/100 + 1) == 0): + print ' ... Layer %i Epoch %i Progress %i/%i, Cost: %.4f, AvgCost: %.4f' %(i, epoch, batch_index, n_train_batches, current_c, numpy.mean(c)) + c.append(current_c) + print 'Pre-trained layer %i, epoch %d, cost ' % (i, epoch), + print numpy.mean(c) + + print ' model checkpoint for current epoch...' + f = file('outputs/model_checkpoint.obj', 'wb') + cPickle.dump(sda,f, protocol=cPickle.HIGHEST_PROTOCOL) + f.close() + + end_time = time.clock() + + print ('... pretrained bottom half of the SdA in %.2fm' % ((end_time - start_time) / 60.)) + + + layer_all = sda.n_layers + 1 #Number of hidden layers + 1 + print layer_all + + for i in xrange(layer_all/2 - 1): + + #Reverse map 2 to 5 + layer = i+2 + layer_applied = layer_all - layer + 1 + print '... applying weights from SdA layer', layer, 'to SdA layer', (layer_applied) + ww, bb, bbp = [sda.dA_layers[layer-1].W.get_value(), sda.dA_layers[layer-1].b.get_value(), sda.dA_layers[layer-1].b_prime.get_value()] + sda.dA_layers[layer_applied-1].W.set_value(ww.T) + sda.dA_layers[layer_applied-1].b.set_value(bbp) + sda.dA_layers[layer_applied-1].b_prime.set_value(bb) + + #Reverse map 1 to loglayer + layer = 1 + print '... applying weights from SdA layer', layer, 'to loglayer layer' + ww, bb, bbp = [sda.dA_layers[layer-1].W.get_value(), sda.dA_layers[layer-1].b.get_value(), sda.dA_layers[layer-1].b_prime.get_value()] + sda.logLayer.W.set_value(ww.T) + sda.logLayer.b.set_value(bbp) + + + '''#Set sigmoid layer weights equal to dA weights + for i in xrange(sda.n_layers): + sda.sigmoid_layers[i].W.set_value(sda.dA_layers[i].W.get_value()) + sda.sigmoid_layers[i].b.set_value(sda.dA_layers[i].b.get_value())''' + + + print '... compiling functions' + train_fn, validate_model, test_model = sda.build_finetune_functions( + train_set = train_set,valid_set = valid_set,test_set=test_set, + batch_size=batch_size, + learning_rate=finetune_lr + ) + + reconstructed = theano.function([], + sda.logLayer.y_pred,givens={ + sda.x: test_set_x},on_unused_input='ignore') + + w1 = theano.function([], + nlinalg.trace(T.dot(sda.sigmoid_layers[0].W.T,sda.sigmoid_layers[0].W)),givens={ + sda.x: test_set_x},on_unused_input='ignore') + + w2 = theano.function([], + nlinalg.trace(T.dot(sda.sigmoid_layers[1].W.T,sda.sigmoid_layers[1].W)),givens={ + sda.x: test_set_x},on_unused_input='ignore') + + w3 = theano.function([], + nlinalg.trace(T.dot(sda.sigmoid_layers[2].W.T,sda.sigmoid_layers[2].W)),givens={ + sda.x: test_set_x},on_unused_input='ignore') + + w4 = theano.function([], + nlinalg.trace(T.dot(sda.sigmoid_layers[3].W.T,sda.sigmoid_layers[3].W)),givens={ + sda.x: test_set_x},on_unused_input='ignore') + + w5 = theano.function([], + nlinalg.trace(T.dot(sda.sigmoid_layers[4].W.T,sda.sigmoid_layers[4].W)),givens={ + sda.x: test_set_x},on_unused_input='ignore') + + wl = theano.function([], + nlinalg.trace(T.dot(sda.logLayer.W.T,sda.logLayer.W)),givens={ + sda.x: test_set_x},on_unused_input='ignore') + + '''print ' loading previous model...' 
+ f = file('outputs/model_bestpsnr.obj', 'rb') + sda = cPickle.load(f) + f.close()''' + + print '... finetuning the model' + patience = 100000 * n_train_batches + patience_increase = 2. + improvement_threshold = 1 + validation_frequency = min(n_train_batches, patience / 2) + best_validation_loss = numpy.inf + test_score = 0. + start_time = time.clock() + done_looping = False + epoch = 0 + + plot_valid_error = [] + ww1 = [] + ww2 = [] + ww3 = [] + ww4 = [] + ww5 = [] + wwl = [] + psnrs = [] + best_psnr = [] + + while (epoch < training_epochs) and (not done_looping): + epoch = epoch + 1 + + if 1 == 0: ########################################################################## Switch for on-the-fly training data generation + + if epoch % 50 == 0: + print '... calling matlab function!' + call(["/usr/local/MATLAB/R2015a/bin/matlab","-nodesktop","-r",'end2end_datagen_256; exit']) + print '... data regeneration complete, loading new data' + datasets = load_data('dataset/llnet_17x17_OTF.mat') + train_set = datasets[0] + valid_set = datasets[1] + test_set = datasets[2] + train_set_x, train_set_y = datasets[0] + valid_set_x, valid_set_y = datasets[1] + test_set_x, test_set_y = datasets[2] + datasets = [] + + reconstructed = theano.function([], + sda.logLayer.y_pred,givens={ + sda.x: test_set_x},on_unused_input='warn') + + print '... plotting clean images' + image = PIL.Image.fromarray(tile_raster_images( + X=test_set_y.get_value(), + img_shape=patch_size, tile_shape=(50, 40), + tile_spacing=(0, 0),scale_rows_to_unit_interval=False)) + image.save('outputs/LLnet_clean.png') + + print '... plotting noisy images' + image = PIL.Image.fromarray(tile_raster_images( + X=test_set_x.get_value(), + img_shape=patch_size, tile_shape=(50, 40), + tile_spacing=(0, 0),scale_rows_to_unit_interval=False)) + image.save('outputs/LLnet_noisy.png') + + if 1 == 1: ########################################################################## Switch for training rate schedule change + + if epoch % 200 == 0: + tempval = finetune_lr * 0.1 + print '... switching learning rate to %.4f, recompiling function'%(tempval) + train_fn, validate_model, test_model = sda.build_finetune_functions( + train_set = train_set,valid_set = valid_set,test_set=test_set, + batch_size=batch_size, + learning_rate=tempval + ) + + for minibatch_index in xrange(n_train_batches): + minibatch_avg_cost = train_fn(minibatch_index) + if (minibatch_index % (n_train_batches/100 + 1) == 0): + print ' ... FT E%i, %i/%i/%i, aCost: %.4f' %(epoch, minibatch_index, n_train_batches, hp_batchsize, minibatch_avg_cost) + iter = (epoch - 1) * n_train_batches + minibatch_index + + if (iter + 1) % validation_frequency == 0: + validation_losses = validate_model() + this_validation_loss = numpy.mean(validation_losses) + print('epoch %i, minibatch %i/%i, validation loss %f (best: %f)' % + (epoch, minibatch_index + 1, n_train_batches, + this_validation_loss, best_validation_loss)) + + plot_valid_error.append(this_validation_loss) + + # Training monitoring tools ----------------------------------------- + + ww1.append(w1()) + ww2.append(w2()) + ww3.append(w3()) + ww4.append(w4()) + ww5.append(w5()) + wwl.append(wl()) + + psnr = 10*numpy.log10(255**2 / numpy.mean(numpy.sqrt(numpy.sum(((test_set_y.get_value() - reconstructed())*255)**2,axis=1,keepdims=True)))) + psnrs.append( psnr ) + + if psnr >= numpy.max(psnrs): + print ' saving trained model based on highest psnr...' 
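+                    # `psnr` above is a PSNR-style score, 10*log10(255^2 / E),
+                    # where E is the mean per-patch Euclidean error on the
+                    # [0, 255] scale, so it tracks relative progress between
+                    # validation checks rather than a textbook per-pixel PSNR.
+                    # Each new maximum triggers a full pickle of the SdA object
+                    # and a re-plot of the reconstructed test patches below.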
+ f = file('outputs/model_bestpsnr.obj', 'wb') + cPickle.dump(sda,f, protocol=cPickle.HIGHEST_PROTOCOL) + f.close() + print ' plotting reconstructed images based on highest psnr...' + image = PIL.Image.fromarray(tile_raster_images( + X=reconstructed(), + img_shape=patch_size, tile_shape=(50,40), + tile_spacing=(0, 0),scale_rows_to_unit_interval=False)) + image.save('outputs/LLnet_reconstructed_bestpsnr.png') + + plt.clf() + plt.suptitle('Epoch %d'%(epoch)) + plt.subplot(121); plt.plot(plot_valid_error,'-xb'); plt.title('Validation Error, best %.4f'%(numpy.min(plot_valid_error))) + plt.subplot(122); plt.plot(psnrs,'-xb'); plt.title('PSNR, best %.4f dB'%(numpy.max(psnrs))); + if len(psnrs)>2: + plt.xlabel('Rate: %.4f dB/step'%(psnrs[-1] - psnrs[-2])) + plt.savefig('outputs/validation_error.png') + + + plt.clf() + plt.suptitle('Weight Norms, epoch %d'%(epoch)) + plt.subplot(231); plt.plot(ww1,'-xr'); plt.axis('tight'); plt.title('Layer1') + plt.subplot(232); plt.plot(ww2,'-xc'); plt.axis('tight'); plt.title('Layer2') + plt.subplot(233); plt.plot(ww3,'-xy'); plt.axis('tight'); plt.title('Layer3') + plt.subplot(234); plt.plot(ww4,'-xg'); plt.axis('tight'); plt.title('Layer4') + plt.subplot(235); plt.plot(ww5,'-xb'); plt.axis('tight'); plt.title('Layer5') + plt.subplot(236); plt.plot(wwl,'-xm'); plt.axis('tight'); plt.title('Sigmoid Layer') + plt.savefig('outputs/weightnorms.png') + + # Training monitoring tools ----------------------------------------- + + if this_validation_loss < best_validation_loss: + if ( + this_validation_loss < best_validation_loss * + improvement_threshold + ): + patience = max(patience, iter * patience_increase) + + best_validation_loss = this_validation_loss + best_iter = iter + test_losses = test_model() + test_score = numpy.mean(test_losses) + print((' epoch %i, minibatch %i/%i, test loss of ' + 'best model %f') % + (epoch, minibatch_index + 1, n_train_batches, + test_score)) + + print ' saving trained model based on lowest validation error...' + f = file('outputs/model.obj', 'wb') + cPickle.dump(sda,f, protocol=cPickle.HIGHEST_PROTOCOL) + f.close() + + print ' plotting reconstructed images...' + image = PIL.Image.fromarray(tile_raster_images( + X=reconstructed(), + img_shape=patch_size, tile_shape=(50,40), + tile_spacing=(0, 0),scale_rows_to_unit_interval=False)) + image.save('outputs/LLnet_reconstructed.png') + + print ' plotting complete. Training next epoch...' 
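+            # Early stopping: `patience` counts iterations and is extended
+            # (iter * patience_increase) whenever the validation loss beats the
+            # previous best by the `improvement_threshold` factor; the check
+            # below stops finetuning once it runs out. With patience starting
+            # at 100000 * n_train_batches, the epoch limit is normally what
+            # ends training instead.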
+ + if patience <= iter: + done_looping = True + break + + end_time = time.clock() + print( + ( + 'Optimization complete with best validation loss of %f, ' + 'on iteration %i, ' + 'with test performance %f' + ) + % (best_validation_loss, best_iter + 1, test_score) + ) + print >> sys.stderr, ('The training code for file ' + + os.path.split(__file__)[1] + + ' ran for %.2fm' % ((end_time - start_time) / 60.)) + +##################################################################################################################### +# # +# Inference Code # +# # +##################################################################################################################### + +###################################################### +# Overlapping Patches Denoising (With Strides) # (Default) +###################################################### + +def denoise_overlapped_strides(strides=(3,3)): #1 2 4 11 + + #print '=== OVERLAPPING PATCHES',strides,'STRIDES ===============================' + + testdata = misc.imread(te_noisy_image,flatten=True) + fname=te_noisy_image.rsplit('/',1)[-1][:-4] + #scipy.misc.imsave('outputs/LLnet_inference_'+fname+'_test.png',testdata) + shutil.copyfile(te_noisy_image, 'outputs/ori_'+fname+'.png') + + test_set_x, te_h,te_w = load_data_overlapped_strides(te_dataset = te_noisy_image, patch_size = patch_size, strides=strides) + im_ = test_set_x.get_value() + im_noisy = im_.reshape((im_).shape[0], *patch_size) + rec_n = im.reconstruct_from_patches_2d(im_noisy, (te_h,te_w)) + + reconstructed = theano.function([], + sda.logLayer.y_pred,givens={ + sda.x: test_set_x},on_unused_input='warn') + result = reconstructed() + + im_recon = result.reshape((result).shape[0], *patch_size) + rec_r = reconstruct_from_patches_with_strides_2d(im_recon, (te_h,te_w), strides=strides) + + scipy.misc.imsave('outputs/LLnet_inference_'+fname+'_out.png',rec_r) + +# print sda.sigmoid_layers[0].W.get_value().shape +# print sda.sigmoid_layers[1].W.get_value().shape +# print sda.sigmoid_layers[2].W.get_value().shape +# print sda.sigmoid_layers[3].W.get_value().shape +# print sda.sigmoid_layers[4].W.get_value().shape +# print sda.sigmoid_layers[5].W.get_value().shape +# print sda.sigmoid_layers[6].W.get_value().shape + + filters = sda.sigmoid_layers[0].W.get_value() + print filters.shape + image = PIL.Image.fromarray(tile_raster_images( + X=filters.T, + img_shape=(17, 17), tile_shape=(4, 20), + tile_spacing=(1, 1),scale_rows_to_unit_interval=True)) + image.save('outputs/LLnet_filters.png') + +##################################################################################################################### +# # +# Terminal Commands # +# # +##################################################################################################################### + +if __name__ == '__main__': + + print(chr(27) + "[2J") + + # Command line interface -------------------- + if len(sys.argv) > 1: + if len(sys.argv[1])>0: + if sys.argv[1]=='train': + tr_dataset = str(sys.argv[2]) + test_SdA() + exit() + if sys.argv[1]=='test': + print '... Runnning algorithm!' + te_noisy_image = str(sys.argv[2]) + model_to_load = str(sys.argv[3]) + f = file(model_to_load, 'rb') + sda = cPickle.load(f) + f.close() + denoise_overlapped_strides(); + print 'Completed:', te_noisy_image + exit() + # ------------------------------------------- + + msg = "You are currently running the image enhancement program, LLNet, developed by Akintayo, Lore, and Sarkar. What would you like to do?" 
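+    # GUI fallback: with no command-line arguments, a buttonbox offers three
+    # actions -- train a new model from a .mat dataset, enhance one or more
+    # images with a pickled model (.obj), or exit.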
+ choices = ["Train Model","Enhance Single/Multiple Images","Exit Program"] + reply = buttonbox(msg, title="Welcome to LLNet!", choices=choices) + if reply == "Exit Program": + exit() + + if reply == "Train Model": + + if ccbox('You are currently training a new model. The model file might be overwritten. Continue?','Information')==True: + tr_dataset = fileopenbox(title='Select training data.',default='*',filetypes=["*.mat"]) + test_SdA() + else: + msgbox("Program terminated. Goodbye!") + exit() + + if reply == "Enhance Single/Multiple Images": + + # Present model to load + model_to_load = fileopenbox(title='Select your model to load.',default='*',filetypes=["*.obj"]) + f = file(model_to_load, 'rb') + sda = cPickle.load(f) + f.close() + + # Load the test image + te_noisy_image_list = fileopenbox(title='Select an image to enhance. Multiple images are allowed; hold SHIFT and click to select.',default='*',filetypes=["*.png", ["*.jpg", "*.jpeg", "JPEG Files"] , '*.bmp' , '*.gif' ],multiple=True) + print te_noisy_image_list + print '... Runnning algorithm!' + + for f in te_noisy_image_list: + + te_noisy_image = f + denoise_overlapped_strides(); + print 'Completed:', f + diff --git a/logistic_sgd.py b/logistic_sgd.py new file mode 100644 index 0000000..91c5b11 --- /dev/null +++ b/logistic_sgd.py @@ -0,0 +1,142 @@ +import cPickle +import gzip +import os +import sys +import time +import h5py +import numpy +import theano +import theano.tensor as T +import nlinalg + +from Data_process2 import overlapping_patches, overlapping_patches_strides + +################################################# +# Logistic Regression Class # +################################################# + +class LogisticRegression(object): + + def __init__(self, input, n_in, n_out,W=None, b=None): + + if not W : + self.W = theano.shared( + value=numpy.zeros( + (n_in, n_out), + dtype=theano.config.floatX + ), + name='W', + borrow=True + ) + + if not b: + self.b = theano.shared( + value=numpy.zeros( + (n_out,), + dtype=theano.config.floatX + ), + name='b', + borrow=True + ) + + self.p_y_given_x = T.nnet.sigmoid(T.dot(input, self.W) + self.b) + self.y_pred = self.p_y_given_x + self.params = [self.W, self.b] + + def image_norm(self, y, obj): + + y_diff = (y - self.y_pred) + l2norm = (T.sqrt((y_diff**2).sum(axis=1,keepdims=False))**2) + lambda_reg = 0.00001 + weights = 0 + for i in xrange(obj.n_layers): + weight = (T.sqrt((obj.dA_layers[i].W ** 2).sum())**2) + #weight = (nlinalg.trace(T.dot(obj.dA_layers[i].W.T, obj.dA_layers[i].W)))**2 #Frobenius norm + weights = weights + weight + regterm = T.sum(weights,keepdims=False) + + return T.mean(l2norm) + lambda_reg/2 *regterm + + def image_norm_noreg(self, y): + + y_diff = (y - self.y_pred) + l2norm = (T.sqrt((y_diff**2).sum(axis=1,keepdims=False))**2) + + return T.mean(l2norm) + +################################################# +# Loading Data for Training # +################################################# + +def load_data(dataset): + + print '... 
loading h5py mat data' + + f = h5py.File(dataset) + + train_set_x = numpy.transpose(f['train_set_x']) + valid_set_x = numpy.transpose(f['valid_set_x']) + test_set_x = numpy.transpose(f['test_set_x']) + train_set_y = numpy.transpose(f['train_set_y']) + valid_set_y = numpy.transpose(f['valid_set_y']) + test_set_y = numpy.transpose(f['test_set_y']) + + train_set = train_set_x, train_set_y + valid_set = valid_set_x, valid_set_y + test_set = test_set_x, test_set_y + + def shared_dataset(data_xy, borrow=True): + + data_x, data_y = data_xy + shared_x = theano.shared(numpy.asarray(data_x, + dtype=theano.config.floatX), + borrow=borrow) + shared_y = theano.shared(numpy.asarray(data_y, + dtype=theano.config.floatX), + borrow=borrow) + + return shared_x, shared_y #T.cast(shared_y, 'int32') + + test_set_x, test_set_y = shared_dataset(test_set) + valid_set_x, valid_set_y = shared_dataset(valid_set) + train_set_x, train_set_y = shared_dataset(train_set) + + rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y), + (test_set_x, test_set_y)] + return rval + +################################################# +# Loading Data for Testing with Full Overlap # +################################################# + +def load_data_overlapped(te_dataset, patch_size): + + test_set_, te_h, te_w = overlapping_patches(path=te_dataset, patch_size = patch_size) + + def shared_dataset(data_x, borrow=True): + shared_data = theano.shared(numpy.asarray(data_x, + dtype=theano.config.floatX), + borrow=borrow) + return shared_data + + test_set_ = shared_dataset(test_set_) + rval = test_set_ + return rval, te_h, te_w + +################################################# +# Loading Data for Testing with Overlap Strides # +################################################# + +def load_data_overlapped_strides(te_dataset, patch_size, strides): + + test_set_, te_h, te_w = overlapping_patches_strides(path=te_dataset, patch_size = patch_size, strides=strides) + + def shared_dataset(data_x, borrow=True): + shared_data = theano.shared(numpy.asarray(data_x, + dtype=theano.config.floatX), + borrow=borrow) + return shared_data + + test_set_ = shared_dataset(test_set_) + rval = test_set_ + return rval, te_h, te_w diff --git a/mlp.py b/mlp.py new file mode 100644 index 0000000..6701378 --- /dev/null +++ b/mlp.py @@ -0,0 +1,404 @@ +""" +This tutorial introduces the multilayer perceptron using Theano. + + A multilayer perceptron is a logistic regressor where +instead of feeding the input to the logistic regression you insert a +intermediate layer, called the hidden layer, that has a nonlinear +activation function (usually tanh or sigmoid) . One can use many such +hidden layers making the architecture deep. The tutorial will also tackle +the problem of MNIST digit classification. + +.. math:: + + f(x) = G( b^{(2)} + W^{(2)}( s( b^{(1)} + W^{(1)} x))), + +References: + + - textbooks: "Pattern Recognition and Machine Learning" - + Christopher M. Bishop, section 5 + +""" +__docformat__ = 'restructedtext en' + + +import os +import sys +import time + +import numpy + +import theano +import theano.tensor as T + + +from logistic_sgd import LogisticRegression, load_data + + +# start-snippet-1 +class HiddenLayer(object): + def __init__(self, rng, input, n_in, n_out, W=None, b=None, + activation=T.tanh): + """ + Typical hidden layer of a MLP: units are fully-connected and have + sigmoidal activation function. Weight matrix W is of shape (n_in,n_out) + and the bias vector b is of shape (n_out,). 
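+
+        When ``W`` is not given, it is initialized below from a uniform
+        distribution with Glorot-style bounds (scaled by 4 when the
+        activation is sigmoid). A minimal NumPy sketch of those bounds,
+        using illustrative sizes n_in=289, n_out=1000::
+
+            import numpy
+            rng = numpy.random.RandomState(1234)
+            bound = numpy.sqrt(6. / (289 + 1000))
+            W_values = rng.uniform(low=-bound, high=bound, size=(289, 1000))
+            W_values *= 4    # only for sigmoid activations
+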
+ + NOTE : The nonlinearity used here is tanh + + Hidden unit activation is given by: tanh(dot(input,W) + b) + + :type rng: numpy.random.RandomState + :param rng: a random number generator used to initialize weights + + :type input: theano.tensor.dmatrix + :param input: a symbolic tensor of shape (n_examples, n_in) + + :type n_in: int + :param n_in: dimensionality of input + + :type n_out: int + :param n_out: number of hidden units + + :type activation: theano.Op or function + :param activation: Non linearity to be applied in the hidden + layer + """ + self.input = input + # end-snippet-1 + + # `W` is initialized with `W_values` which is uniformely sampled + # from sqrt(-6./(n_in+n_hidden)) and sqrt(6./(n_in+n_hidden)) + # for tanh activation function + # the output of uniform if converted using asarray to dtype + # theano.config.floatX so that the code is runable on GPU + # Note : optimal initialization of weights is dependent on the + # activation function used (among other things). + # For example, results presented in [Xavier10] suggest that you + # should use 4 times larger initial weights for sigmoid + # compared to tanh + # We have no info for other function, so we use the same as + # tanh. + if W is None: + W_values = numpy.asarray( + rng.uniform( + low=-numpy.sqrt(6. / (n_in + n_out)), + high=numpy.sqrt(6. / (n_in + n_out)), + size=(n_in, n_out) + ), + dtype=theano.config.floatX + ) + if activation == theano.tensor.nnet.sigmoid: + W_values *= 4 + + W = theano.shared(value=W_values, name='W', borrow=True) + + if b is None: + b_values = numpy.zeros((n_out,), dtype=theano.config.floatX) + b = theano.shared(value=b_values, name='b', borrow=True) + + self.W = W + self.b = b + + lin_output = T.dot(input, self.W) + self.b + self.output = ( + lin_output if activation is None + else activation(lin_output) + ) + # parameters of the model + self.params = [self.W, self.b] + + +# start-snippet-2 +class MLP(object): + """Multi-Layer Perceptron Class + + A multilayer perceptron is a feedforward artificial neural network model + that has one layer or more of hidden units and nonlinear activations. + Intermediate layers usually have as activation function tanh or the + sigmoid function (defined here by a ``HiddenLayer`` class) while the + top layer is a softamx layer (defined here by a ``LogisticRegression`` + class). 
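+
+    Example construction, mirroring the call made in ``test_mlp`` below
+    (28*28 MNIST inputs, 500 hidden units, 10 classes)::
+
+        import numpy
+        import theano.tensor as T
+        x = T.matrix('x')
+        rng = numpy.random.RandomState(1234)
+        classifier = MLP(rng=rng, input=x, n_in=28 * 28, n_hidden=500, n_out=10)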
+ """ + + def __init__(self, rng, input, n_in, n_hidden, n_out): + """Initialize the parameters for the multilayer perceptron + + :type rng: numpy.random.RandomState + :param rng: a random number generator used to initialize weights + + :type input: theano.tensor.TensorType + :param input: symbolic variable that describes the input of the + architecture (one minibatch) + + :type n_in: int + :param n_in: number of input units, the dimension of the space in + which the datapoints lie + + :type n_hidden: int + :param n_hidden: number of hidden units + + :type n_out: int + :param n_out: number of output units, the dimension of the space in + which the labels lie + + """ + + # Since we are dealing with a one hidden layer MLP, this will translate + # into a HiddenLayer with a tanh activation function connected to the + # LogisticRegression layer; the activation function can be replaced by + # sigmoid or any other nonlinear function + self.hiddenLayer = HiddenLayer( + rng=rng, + input=input, + n_in=n_in, + n_out=n_hidden, + activation=T.tanh + ) + + # The logistic regression layer gets as input the hidden units + # of the hidden layer + self.logRegressionLayer = LogisticRegression( + input=self.hiddenLayer.output, + n_in=n_hidden, + n_out=n_out + ) + # end-snippet-2 start-snippet-3 + # L1 norm ; one regularization option is to enforce L1 norm to + # be small + self.L1 = ( + abs(self.hiddenLayer.W).sum() + + abs(self.logRegressionLayer.W).sum() + ) + + # square of L2 norm ; one regularization option is to enforce + # square of L2 norm to be small + self.L2_sqr = ( + (self.hiddenLayer.W ** 2).sum() + + (self.logRegressionLayer.W ** 2).sum() + ) + + # negative log likelihood of the MLP is given by the negative + # log likelihood of the output of the model, computed in the + # logistic regression layer + self.negative_log_likelihood = ( + self.logRegressionLayer.negative_log_likelihood + ) + # same holds for the function computing the number of errors + self.errors = self.logRegressionLayer.errors + + # the parameters of the model are the parameters of the two layer it is + # made out of + self.params = self.hiddenLayer.params + self.logRegressionLayer.params + # end-snippet-3 + + +def test_mlp(learning_rate=0.01, L1_reg=0.00, L2_reg=0.0001, n_epochs=1000, + dataset='mnist.pkl.gz', batch_size=20, n_hidden=500): + """ + Demonstrate stochastic gradient descent optimization for a multilayer + perceptron + + This is demonstrated on MNIST. 
+ + :type learning_rate: float + :param learning_rate: learning rate used (factor for the stochastic + gradient + + :type L1_reg: float + :param L1_reg: L1-norm's weight when added to the cost (see + regularization) + + :type L2_reg: float + :param L2_reg: L2-norm's weight when added to the cost (see + regularization) + + :type n_epochs: int + :param n_epochs: maximal number of epochs to run the optimizer + + :type dataset: string + :param dataset: the path of the MNIST dataset file from + http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz + + + """ + datasets = load_data(dataset) + + train_set_x, train_set_y = datasets[0] + valid_set_x, valid_set_y = datasets[1] + test_set_x, test_set_y = datasets[2] + + # compute number of minibatches for training, validation and testing + n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size + n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] / batch_size + n_test_batches = test_set_x.get_value(borrow=True).shape[0] / batch_size + + ###################### + # BUILD ACTUAL MODEL # + ###################### + print '... building the model' + + # allocate symbolic variables for the data + index = T.lscalar() # index to a [mini]batch + x = T.matrix('x') # the data is presented as rasterized images + y = T.ivector('y') # the labels are presented as 1D vector of + # [int] labels + + rng = numpy.random.RandomState(1234) + + # construct the MLP class + classifier = MLP( + rng=rng, + input=x, + n_in=28 * 28, + n_hidden=n_hidden, + n_out=10 + ) + + # start-snippet-4 + # the cost we minimize during training is the negative log likelihood of + # the model plus the regularization terms (L1 and L2); cost is expressed + # here symbolically + cost = ( + classifier.negative_log_likelihood(y) + + L1_reg * classifier.L1 + + L2_reg * classifier.L2_sqr + ) + # end-snippet-4 + + # compiling a Theano function that computes the mistakes that are made + # by the model on a minibatch + test_model = theano.function( + inputs=[index], + outputs=classifier.errors(y), + givens={ + x: test_set_x[index * batch_size:(index + 1) * batch_size], + y: test_set_y[index * batch_size:(index + 1) * batch_size] + } + ) + + validate_model = theano.function( + inputs=[index], + outputs=classifier.errors(y), + givens={ + x: valid_set_x[index * batch_size:(index + 1) * batch_size], + y: valid_set_y[index * batch_size:(index + 1) * batch_size] + } + ) + + # start-snippet-5 + # compute the gradient of cost with respect to theta (sotred in params) + # the resulting gradients will be stored in a list gparams + gparams = [T.grad(cost, param) for param in classifier.params] + + # specify how to update the parameters of the model as a list of + # (variable, update expression) pairs + + # given two list the zip A = [a1, a2, a3, a4] and B = [b1, b2, b3, b4] of + # same length, zip generates a list C of same size, where each element + # is a pair formed from the two lists : + # C = [(a1, b1), (a2, b2), (a3, b3), (a4, b4)] + updates = [ + (param, param - learning_rate * gparam) + for param, gparam in zip(classifier.params, gparams) + ] + + # compiling a Theano function `train_model` that returns the cost, but + # in the same time updates the parameter of the model based on the rules + # defined in `updates` + train_model = theano.function( + inputs=[index], + outputs=cost, + updates=updates, + givens={ + x: train_set_x[index * batch_size: (index + 1) * batch_size], + y: train_set_y[index * batch_size: (index + 1) * batch_size] + } + ) + # end-snippet-5 + + 
############### + # TRAIN MODEL # + ############### + print '... training' + + # early-stopping parameters + patience = 10000 # look as this many examples regardless + patience_increase = 2 # wait this much longer when a new best is + # found + improvement_threshold = 0.995 # a relative improvement of this much is + # considered significant + validation_frequency = min(n_train_batches, patience / 2) + # go through this many + # minibatche before checking the network + # on the validation set; in this case we + # check every epoch + + best_validation_loss = numpy.inf + best_iter = 0 + test_score = 0. + start_time = time.clock() + + epoch = 0 + done_looping = False + + while (epoch < n_epochs) and (not done_looping): + epoch = epoch + 1 + for minibatch_index in xrange(n_train_batches): + + minibatch_avg_cost = train_model(minibatch_index) + # iteration number + iter = (epoch - 1) * n_train_batches + minibatch_index + + if (iter + 1) % validation_frequency == 0: + # compute zero-one loss on validation set + validation_losses = [validate_model(i) for i + in xrange(n_valid_batches)] + this_validation_loss = numpy.mean(validation_losses) + + print( + 'epoch %i, minibatch %i/%i, validation error %f %%' % + ( + epoch, + minibatch_index + 1, + n_train_batches, + this_validation_loss * 100. + ) + ) + + # if we got the best validation score until now + if this_validation_loss < best_validation_loss: + #improve patience if loss improvement is good enough + if ( + this_validation_loss < best_validation_loss * + improvement_threshold + ): + patience = max(patience, iter * patience_increase) + + best_validation_loss = this_validation_loss + best_iter = iter + + # test it on the test set + test_losses = [test_model(i) for i + in xrange(n_test_batches)] + test_score = numpy.mean(test_losses) + + print((' epoch %i, minibatch %i/%i, test error of ' + 'best model %f %%') % + (epoch, minibatch_index + 1, n_train_batches, + test_score * 100.)) + + if patience <= iter: + done_looping = True + break + + end_time = time.clock() + print(('Optimization complete. Best validation score of %f %% ' + 'obtained at iteration %i, with test performance %f %%') % + (best_validation_loss * 100., best_iter + 1, test_score * 100.)) + print >> sys.stderr, ('The code for file ' + + os.path.split(__file__)[1] + + ' ran for %.2fm' % ((end_time - start_time) / 60.)) + + +if __name__ == '__main__': + test_mlp() diff --git a/nlinalg.py b/nlinalg.py new file mode 100644 index 0000000..f783c61 --- /dev/null +++ b/nlinalg.py @@ -0,0 +1,701 @@ +import logging +import theano + +logger = logging.getLogger(__name__) +import numpy + +from theano.gof import Op, Apply + +from theano.tensor import as_tensor_variable, dot, DimShuffle, Dot +from theano.tensor.blas import Dot22 +from theano.tensor.opt import (register_stabilize, + register_specialize, register_canonicalize) +from theano.gof import local_optimizer +from theano.gof.opt import Optimizer +from theano.gradient import DisconnectedType +from theano.tensor import basic as tensor + + +class MatrixPinv(Op): + """Computes the pseudo-inverse of a matrix :math:`A`. + + The pseudo-inverse of a matrix A, denoted :math:`A^+`, is + defined as: "the matrix that 'solves' [the least-squares problem] + :math:`Ax = b`," i.e., if :math:`\\bar{x}` is said solution, then + :math:`A^+` is that matrix such that :math:`\\bar{x} = A^+b`. + + Note that :math:`Ax=AA^+b`, so :math:`AA^+` is close to the identity matrix. + This method is not faster then `matrix_inverse`. 
Its strength comes from + that it works for non-square matrices. + If you have a square matrix though, `matrix_inverse` can be both more + exact and faster to compute. Also this op does not get optimized into a + solve op. + """ + + __props__ = () + + def __init__(self): + pass + + def make_node(self, x): + x = as_tensor_variable(x) + assert x.ndim == 2 + return Apply(self, [x], [x.type()]) + + def perform(self, node, (x,), (z, )): + z[0] = numpy.linalg.pinv(x).astype(x.dtype) + +pinv = MatrixPinv() + + +class MatrixInverse(Op): + """Computes the inverse of a matrix :math:`A`. + + Given a square matrix :math:`A`, ``matrix_inverse`` returns a square + matrix :math:`A_{inv}` such that the dot product :math:`A \cdot A_{inv}` + and :math:`A_{inv} \cdot A` equals the identity matrix :math:`I`. + + :note: When possible, the call to this op will be optimized to the call + of ``solve``. + """ + + __props__ = () + + def __init__(self): + pass + + def make_node(self, x): + x = as_tensor_variable(x) + assert x.ndim == 2 + return Apply(self, [x], [x.type()]) + + def perform(self, node, (x,), (z, )): + z[0] = numpy.linalg.inv(x).astype(x.dtype) + + def grad(self, inputs, g_outputs): + r"""The gradient function should return + + .. math:: V\frac{\partial X^{-1}}{\partial X}, + + where :math:`V` corresponds to ``g_outputs`` and :math:`X` to + ``inputs``. Using the `matrix cookbook + `_, + once can deduce that the relation corresponds to + + .. math:: (X^{-1} \cdot V^{T} \cdot X^{-1})^T. + + """ + x, = inputs + xi = self(x) + gz, = g_outputs + #TT.dot(gz.T,xi) + return [-matrix_dot(xi, gz.T, xi).T] + + def R_op(self, inputs, eval_points): + r"""The gradient function should return + + .. math:: \frac{\partial X^{-1}}{\partial X}V, + + where :math:`V` corresponds to ``g_outputs`` and :math:`X` to + ``inputs``. Using the `matrix cookbook + `_, + once can deduce that the relation corresponds to + + .. math:: X^{-1} \cdot V \cdot X^{-1}. + + """ + x, = inputs + xi = self(x) + ev, = eval_points + if ev is None: + return [None] + return [-matrix_dot(xi, ev, xi)] + + def infer_shape(self, node, shapes): + return shapes + +matrix_inverse = MatrixInverse() + + +def matrix_dot(*args): + """ Shorthand for product between several dots + + Given :math:`N` matrices :math:`A_0, A_1, .., A_N`, ``matrix_dot`` will + generate the matrix product between all in the given order, namely + :math:`A_0 \cdot A_1 \cdot A_2 \cdot .. \cdot A_N`. + """ + rval = args[0] + for a in args[1:]: + rval = theano.tensor.dot(rval, a) + return rval + + +class AllocDiag(Op): + """ + Allocates a square matrix with the given vector as its diagonal. + """ + def __eq__(self, other): + return type(self) == type(other) + + def __hash__(self): + return hash(type(self)) + + def make_node(self, _x): + x = as_tensor_variable(_x) + if x.type.ndim != 1: + raise TypeError('AllocDiag only works on vectors', _x) + return Apply(self, [x], [theano.tensor.matrix(dtype=x.type.dtype)]) + + def grad(self, inputs, g_outputs): + return [extract_diag(g_outputs[0])] + + def perform(self, node, (x,), (z,)): + if x.ndim != 1: + raise TypeError(x) + z[0] = numpy.diag(x) + + def infer_shape(self, node, shapes): + x_s, = shapes + return [(x_s[0], x_s[0])] + +alloc_diag = AllocDiag() + + +class ExtractDiag(Op): + """ Return the diagonal of a matrix. + + :note: work on the GPU. 
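+
+    NumPy sketch of the behaviour for a 2-d input::
+
+        import numpy
+        x = numpy.arange(12).reshape(3, 4)
+        x.diagonal()    # array([ 0,  5, 10]) -- what extract_diag(x) computes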
+ """ + def __init__(self, view=False): + self.view = view + if self.view: + self.view_map = {0: [0]} + + def __eq__(self, other): + return type(self) == type(other) and self.view == other.view + + def __hash__(self): + return hash(type(self)) ^ hash(self.view) + + def make_node(self, _x): + if not isinstance(_x, theano.Variable): + x = as_tensor_variable(_x) + else: + x = _x + + if x.type.ndim != 2: + raise TypeError('ExtractDiag only works on matrices', _x) + return Apply(self, [x], [x.type.__class__(broadcastable=(False,), + dtype=x.type.dtype)()]) + + def perform(self, node, ins, outs): + """ For some reason numpy.diag(x) is really slow, so we + implemented our own. """ + x, = ins + z, = outs + # zero-dimensional matrices ... + if x.shape[0] == 0 or x.shape[1] == 0: + z[0] = node.outputs[0].type.value_zeros((0,)) + return + + if x.shape[0] < x.shape[1]: + rval = x[:, 0] + else: + rval = x[0] + + rval.strides = (x.strides[0] + x.strides[1],) + if self.view: + z[0] = rval + else: + z[0] = rval.copy() + + def __str__(self): + return 'ExtractDiag{view=%s}' % self.view + + def grad(self, inputs, g_outputs): + x = theano.tensor.zeros_like(inputs[0]) + xdiag = alloc_diag(g_outputs[0]) + return [theano.tensor.set_subtensor( + x[:xdiag.shape[0], :xdiag.shape[1]], + xdiag)] + + def infer_shape(self, node, shapes): + x_s, = shapes + shp = theano.tensor.min(node.inputs[0].shape) + return [(shp,)] + +extract_diag = ExtractDiag() +#TODO: optimization to insert ExtractDiag with view=True + + +def diag(x): + """ + Numpy-compatibility method + If `x` is a matrix, return its diagonal. + If `x` is a vector return a matrix with it as its diagonal. + + * This method does not support the `k` argument that numpy supports. + """ + xx = as_tensor_variable(x) + if xx.type.ndim == 1: + return alloc_diag(xx) + elif xx.type.ndim == 2: + return extract_diag(xx) + else: + raise TypeError('diag requires vector or matrix argument', x) + + +def trace(X): + """ + Returns the sum of diagonal elements of matrix X. + + :note: work on GPU since 0.6rc4. + """ + return extract_diag(X).sum() + + +class Det(Op): + """Matrix determinant + Input should be a square matrix + """ + def make_node(self, x): + x = as_tensor_variable(x) + assert x.ndim == 2 + o = theano.tensor.scalar(dtype=x.dtype) + return Apply(self, [x], [o]) + + def perform(self, node, (x,), (z, )): + try: + z[0] = numpy.asarray(numpy.linalg.det(x), dtype=x.dtype) + except Exception: + print 'Failed to compute determinant', x + raise + + def grad(self, inputs, g_outputs): + gz, = g_outputs + x, = inputs + return [gz * self(x) * matrix_inverse(x).T] + + def infer_shape(self, node, shapes): + return [()] + + def __str__(self): + return "Det" +det = Det() + + +class Eig(Op): + """Compute the eigenvalues and right eigenvectors of a square array. + + """ + _numop = staticmethod(numpy.linalg.eig) + __props__ = () + + def make_node(self, x): + x = as_tensor_variable(x) + assert x.ndim == 2 + w = theano.tensor.vector(dtype=x.dtype) + v = theano.tensor.matrix(dtype=x.dtype) + return Apply(self, [x], [w, v]) + + def perform(self, node, (x,), (w, v)): + w[0], v[0] = [z.astype(x.dtype) for z in self._numop(x)] + + def infer_shape(self, node, shapes): + n = shapes[0][0] + return [(n,), (n, n)] + +eig = Eig() + + +class Eigh(Eig): + """ + Return the eigenvalues and eigenvectors of a Hermitian or symmetric matrix. 
+ + """ + _numop = staticmethod(numpy.linalg.eigh) + __props__ = ('UPLO',) + + def __init__(self, UPLO='L'): + assert UPLO in ['L', 'U'] + self.UPLO = UPLO + + def make_node(self, x): + x = as_tensor_variable(x) + assert x.ndim == 2 + # Numpy's linalg.eigh may return either double or single + # presision eigenvalues depending on installed version of + # LAPACK. Rather than trying to reproduce the (rather + # involved) logic, we just probe linalg.eigh with a trivial + # input. + w_dtype = self._numop([[numpy.dtype(x.dtype).type()]])[0].dtype.name + w = theano.tensor.vector(dtype=w_dtype) + v = theano.tensor.matrix(dtype=x.dtype) + return Apply(self, [x], [w, v]) + + def perform(self, node, (x,), (w, v)): + w[0], v[0] = self._numop(x, self.UPLO) + + def grad(self, inputs, g_outputs): + r"""The gradient function should return + + .. math:: \sum_n\left(W_n\frac{\partial\,w_n} + {\partial a_{ij}} + + \sum_k V_{nk}\frac{\partial\,v_{nk}} + {\partial a_{ij}}\right), + + where [:math:`W`, :math:`V`] corresponds to ``g_outputs``, + :math:`a` to ``inputs``, and :math:`(w, v)=\mbox{eig}(a)`. + + Analytic formulae for eigensystem gradients are well-known in + perturbation theory: + + .. math:: \frac{\partial\,w_n} + {\partial a_{ij}} = v_{in}\,v_{jn} + + + .. math:: \frac{\partial\,v_{kn}} + {\partial a_{ij}} = + \sum_{m\ne n}\frac{v_{km}v_{jn}}{w_n-w_m} + """ + x, = inputs + w, v = self(x) + # Replace gradients wrt disconnected variables with + # zeros. This is a work-around for issue #1063. + gw, gv = _zero_disconnected([w, v], g_outputs) + return [EighGrad(self.UPLO)(x, w, v, gw, gv)] + + +def _zero_disconnected(outputs, grads): + l = [] + for o, g in zip(outputs, grads): + if isinstance(g.type, DisconnectedType): + l.append(o.zeros_like()) + else: + l.append(g) + return l + + +class EighGrad(Op): + """Gradient of an eigensystem of a Hermitian matrix. + + """ + __props__ = ('UPLO',) + + def __init__(self, UPLO='L'): + assert UPLO in ['L', 'U'] + self.UPLO = UPLO + if UPLO == 'L': + self.tri0 = numpy.tril + self.tri1 = lambda a: numpy.triu(a, 1) + else: + self.tri0 = numpy.triu + self.tri1 = lambda a: numpy.tril(a, -1) + + def make_node(self, x, w, v, gw, gv): + x, w, v, gw, gv = map(as_tensor_variable, (x, w, v, gw, gv)) + assert x.ndim == 2 + assert w.ndim == 1 + assert v.ndim == 2 + assert gw.ndim == 1 + assert gv.ndim == 2 + out_dtype = theano.scalar.upcast(x.dtype, w.dtype, v.dtype, + gw.dtype, gv.dtype) + out = theano.tensor.matrix(dtype=out_dtype) + return Apply(self, [x, w, v, gw, gv], [out]) + + def perform(self, node, inputs, outputs): + """ + Implements the "reverse-mode" gradient for the eigensystem of + a square matrix. + """ + x, w, v, W, V = inputs + N = x.shape[0] + outer = numpy.outer + + G = lambda n: sum(v[:, m] * V.T[n].dot(v[:, m]) / (w[n] - w[m]) + for m in xrange(N) if m != n) + g = sum(outer(v[:, n], v[:, n] * W[n] + G(n)) + for n in xrange(N)) + + # Numpy's eigh(a, 'L') (eigh(a, 'U')) is a function of tril(a) + # (triu(a)) only. This means that partial derivative of + # eigh(a, 'L') (eigh(a, 'U')) with respect to a[i,j] is zero + # for i < j (i > j). At the same time, non-zero components of + # the gradient must account for the fact that variation of the + # opposite triangle contributes to variation of two elements + # of Hermitian (symmetric) matrix. The following line + # implements the necessary logic. + out = self.tri0(g) + self.tri1(g).T + + # The call to self.tri0 in perform upcast from float32 to + # float64 or from int* to int64 in numpy 1.6.1 but not in + # 1.6.2. 
We do not want version dependent dtype in Theano. + # We think it should be the same as the output. + outputs[0][0] = numpy.asarray(out, dtype=node.outputs[0].dtype) + + def infer_shape(self, node, shapes): + return [shapes[0]] + + +def eigh(a, UPLO='L'): + return Eigh(UPLO)(a) + + +class QRFull(Op): + """ + Full QR Decomposition. + Computes the QR decomposition of a matrix. + Factor the matrix a as qr, where q is orthonormal + and r is upper-triangular. + """ + _numop = staticmethod(numpy.linalg.qr) + __props__ = ('mode',) + + def __init__(self, mode): + self.mode = mode + + def make_node(self, x): + x = as_tensor_variable(x) + assert x.ndim == 2, "The input of qr function should be a matrix." + q = theano.tensor.matrix(dtype=x.dtype) + r = theano.tensor.matrix(dtype=x.dtype) + return Apply(self, [x], [q, r]) + + def perform(self, node, (x,), (q, r)): + assert x.ndim == 2, "The input of qr function should be a matrix." + + q[0], r[0] = self._numop(x, + self.mode) + + +class QRIncomplete(Op): + """ + Incomplete QR Decomposition. + Computes the QR decomposition of a matrix. + Factor the matrix a as qr and return a single matrix. + """ + _numop = staticmethod(numpy.linalg.qr) + __props__ = ('mode',) + + def __init__(self, mode): + self.mode = mode + + def make_node(self, x): + x = as_tensor_variable(x) + assert x.ndim == 2, "The input of qr function should be a matrix." + q = theano.tensor.matrix(dtype=x.dtype) + return Apply(self, [x], [q]) + + def perform(self, node, (x,), (q,)): + assert x.ndim == 2, "The input of qr function should be a matrix." + q[0] = self._numop(x, + self.mode) + + +def qr(a, mode="full"): + """ + Computes the QR decomposition of a matrix. + Factor the matrix a as qr, where q + is orthonormal and r is upper-triangular. + + :type a: + array_like, shape (M, N) + :param a: + Matrix to be factored. + + :type mode: + one of 'reduced', 'complete', 'r', 'raw', 'full' and + 'economic', optional + :keyword mode: + If K = min(M, N), then + + 'reduced' + returns q, r with dimensions (M, K), (K, N) + + 'complete' + returns q, r with dimensions (M, M), (M, N) + + 'r' + returns r only with dimensions (K, N) + + 'raw' + returns h, tau with dimensions (N, M), (K,) + + 'full' + alias of 'reduced', deprecated (default) + + 'economic' + returns h from 'raw', deprecated. The options 'reduced', + + 'complete', and 'raw' are new in numpy 1.8, see the notes for more + information. The default is 'reduced' and to maintain backward + compatibility with earlier versions of numpy both it and the old + default 'full' can be omitted. Note that array h returned in 'raw' + mode is transposed for calling Fortran. The 'economic' mode is + deprecated. The modes 'full' and 'economic' may be passed using only + the first letter for backwards compatibility, but all others + must be spelled out. + + Default mode is 'full' which is also default for numpy 1.6.1. + + :note: Default mode was left to full as full and reduced are + both doing the same thing in the new numpy version but only + full works on the old previous numpy version. + + :rtype q: + matrix of float or complex, optional + :return q: + A matrix with orthonormal columns. When mode = 'complete' the + result is an orthogonal/unitary matrix depending on whether or + not a is real/complex. The determinant may be either +/- 1 in + that case. + + :rtype r: + matrix of float or complex, optional + :return r: + The upper-triangular matrix. 
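+
+    Shape check with plain NumPy for the two common modes, K = min(M, N)::
+
+        import numpy
+        a = numpy.random.randn(5, 3)
+        q, r = numpy.linalg.qr(a, mode='reduced')     # q: (5, 3), r: (3, 3)
+        q, r = numpy.linalg.qr(a, mode='complete')    # q: (5, 5), r: (5, 3)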
+ """ + x = [[2, 1], [3, 4]] + if isinstance(numpy.linalg.qr(x,mode), tuple): + return QRFull(mode)(a) + else: + return QRIncomplete(mode)(a) + + +class SVD(Op): + + # See doc in the docstring of the function just after this class. + _numop = staticmethod(numpy.linalg.svd) + __props__ = ('full_matrices', 'compute_uv') + + def __init__(self, full_matrices=True, compute_uv=True): + """ + full_matrices : bool, optional + If True (default), u and v have the shapes (M, M) and (N, N), + respectively. + Otherwise, the shapes are (M, K) and (K, N), respectively, + where K = min(M, N). + compute_uv : bool, optional + Whether or not to compute u and v in addition to s. + True by default. + """ + self.full_matrices = full_matrices + self.compute_uv = compute_uv + + def make_node(self, x): + x = as_tensor_variable(x) + assert x.ndim == 2, "The input of svd function should be a matrix." + w = theano.tensor.matrix(dtype=x.dtype) + u = theano.tensor.vector(dtype=x.dtype) + v = theano.tensor.matrix(dtype=x.dtype) + return Apply(self, [x], [w, u, v]) + + def perform(self, node, (x,), (w, u, v)): + assert x.ndim == 2, "The input of svd function should be a matrix." + w[0], u[0], v[0] = self._numop(x, + self.full_matrices, + self.compute_uv) + + +def svd(a, full_matrices=1, compute_uv=1): + """ + This function performs the SVD on CPU. + + :type full_matrices: bool, optional + :param full_matrices: + If True (default), u and v have the shapes (M, M) and (N, N), + respectively. + Otherwise, the shapes are (M, K) and (K, N), respectively, + where K = min(M, N). + :type compute_uv: bool, optional + :param compute_uv: + Whether or not to compute u and v in addition to s. + True by default. + + :returns: U, V and D matrices. + """ + return SVD(full_matrices, compute_uv)(a) + + +def test_matrix_inverse_solve(): + if not imported_scipy: + raise SkipTest("Scipy needed for the Solve op.") + A = theano.tensor.dmatrix('A') + b = theano.tensor.dmatrix('b') + node = matrix_inverse(A).dot(b).owner + [out] = inv_as_solve.transform(node) + assert isinstance(out.owner.op, Solve) + + +class lstsq(Op): + def __eq__(self, other): + return type(self) == type(other) + + def __hash__(self): + return hash(type(self)) + + def __str__(self): + return self.__class__.__name__ + + def make_node(self, x, y, rcond): + x = theano.tensor.as_tensor_variable(x) + y = theano.tensor.as_tensor_variable(y) + rcond = theano.tensor.as_tensor_variable(rcond) + return theano.Apply(self, [x, y, rcond], + [theano.tensor.matrix(), theano.tensor.dvector(), + theano.tensor.lscalar(), theano.tensor.dvector()]) + + def perform(self, node, inputs, outputs): + x = inputs[0] + y = inputs[1] + rcond = inputs[2] + zz = numpy.linalg.lstsq(inputs[0], inputs[1], inputs[2]) + outputs[0][0] = zz[0] + outputs[1][0] = zz[1] + outputs[2][0] = numpy.array(zz[2]) + outputs[3][0] = zz[3] + + +def matrix_power(M, n): + result = 1 + for i in xrange(n): + result = theano.dot(result, M) + return result + + +def norm(x,ord): + x = as_tensor_variable(x) + ndim = x.ndim + if ndim == 0: + raise ValueError("'axis' entry is out of bounds.") + elif ndim == 1: + if ord is None: + return tensor.sum(x**2)**0.5 + elif ord == 'inf': + return tensor.max(abs(x)) + elif ord == '-inf': + return tensor.min(abs(x)) + elif ord == 0: + return x[x.nonzero()].shape[0] + else: + try: + z = tensor.sum(abs(x**ord))**(1./ord) + except TypeError: + raise ValueError("Invalid norm order for vectors.") + return z + elif ndim == 2: + if ord is None or ord == 'fro': + return tensor.sum(abs(x**2))**(0.5) + 
elif ord == 'inf': + return tensor.max(tensor.sum(abs(x), 1)) + elif ord == '-inf': + return tensor.min(tensor.sum(abs(x), 1)) + elif ord == 1: + return tensor.max(tensor.sum(abs(x), 0)) + elif ord == -1: + return tensor.min(tensor.sum(abs(x),0)) + else: + raise ValueError(0) + elif ndim > 2: + raise NotImplementedError("We don't support norm witn ndim > 2") diff --git a/rbm.py b/rbm.py new file mode 100644 index 0000000..2c821fc --- /dev/null +++ b/rbm.py @@ -0,0 +1,540 @@ +"""This tutorial introduces restricted boltzmann machines (RBM) using Theano. + +Boltzmann Machines (BMs) are a particular form of energy-based model which +contain hidden variables. Restricted Boltzmann Machines further restrict BMs +to those without visible-visible and hidden-hidden connections. +""" +import time + +try: + import PIL.Image as Image +except ImportError: + import Image + +import numpy + +import theano +import theano.tensor as T +import os + +from theano.tensor.shared_randomstreams import RandomStreams + +from utils import tile_raster_images +from logistic_sgd import load_data + + +# start-snippet-1 +class RBM(object): + """Restricted Boltzmann Machine (RBM) """ + def __init__( + self, + input=None, + n_visible=784, + n_hidden=500, + W=None, + hbias=None, + vbias=None, + numpy_rng=None, + theano_rng=None + ): + """ + RBM constructor. Defines the parameters of the model along with + basic operations for inferring hidden from visible (and vice-versa), + as well as for performing CD updates. + + :param input: None for standalone RBMs or symbolic variable if RBM is + part of a larger graph. + + :param n_visible: number of visible units + + :param n_hidden: number of hidden units + + :param W: None for standalone RBMs or symbolic variable pointing to a + shared weight matrix in case RBM is part of a DBN network; in a DBN, + the weights are shared between RBMs and layers of a MLP + + :param hbias: None for standalone RBMs or symbolic variable pointing + to a shared hidden units bias vector in case RBM is part of a + different network + + :param vbias: None for standalone RBMs or a symbolic variable + pointing to a shared visible units bias + """ + + self.n_visible = n_visible + self.n_hidden = n_hidden + + if numpy_rng is None: + # create a number generator + numpy_rng = numpy.random.RandomState(1234) + + if theano_rng is None: + theano_rng = RandomStreams(numpy_rng.randint(2 ** 30)) + + if W is None: + # W is initialized with `initial_W` which is uniformely + # sampled from -4*sqrt(6./(n_visible+n_hidden)) and + # 4*sqrt(6./(n_hidden+n_visible)) the output of uniform if + # converted using asarray to dtype theano.config.floatX so + # that the code is runable on GPU + initial_W = numpy.asarray( + numpy_rng.uniform( + low=-4 * numpy.sqrt(6. / (n_hidden + n_visible)), + high=4 * numpy.sqrt(6. 
/ (n_hidden + n_visible)), + size=(n_visible, n_hidden) + ), + dtype=theano.config.floatX + ) + # theano shared variables for weights and biases + W = theano.shared(value=initial_W, name='W', borrow=True) + + if hbias is None: + # create shared variable for hidden units bias + hbias = theano.shared( + value=numpy.zeros( + n_hidden, + dtype=theano.config.floatX + ), + name='hbias', + borrow=True + ) + + if vbias is None: + # create shared variable for visible units bias + vbias = theano.shared( + value=numpy.zeros( + n_visible, + dtype=theano.config.floatX + ), + name='vbias', + borrow=True + ) + + # initialize input layer for standalone RBM or layer0 of DBN + self.input = input + if not input: + self.input = T.matrix('input') + + self.W = W + self.hbias = hbias + self.vbias = vbias + self.theano_rng = theano_rng + # **** WARNING: It is not a good idea to put things in this list + # other than shared variables created in this function. + self.params = [self.W, self.hbias, self.vbias] + # end-snippet-1 + + def free_energy(self, v_sample): + ''' Function to compute the free energy ''' + wx_b = T.dot(v_sample, self.W) + self.hbias + vbias_term = T.dot(v_sample, self.vbias) + hidden_term = T.sum(T.log(1 + T.exp(wx_b)), axis=1) + return -hidden_term - vbias_term + + def propup(self, vis): + '''This function propagates the visible units activation upwards to + the hidden units + + Note that we return also the pre-sigmoid activation of the + layer. As it will turn out later, due to how Theano deals with + optimizations, this symbolic variable will be needed to write + down a more stable computational graph (see details in the + reconstruction cost function) + + ''' + pre_sigmoid_activation = T.dot(vis, self.W) + self.hbias + return [pre_sigmoid_activation, T.nnet.sigmoid(pre_sigmoid_activation)] + + def sample_h_given_v(self, v0_sample): + ''' This function infers state of hidden units given visible units ''' + # compute the activation of the hidden units given a sample of + # the visibles + pre_sigmoid_h1, h1_mean = self.propup(v0_sample) + # get a sample of the hiddens given their activation + # Note that theano_rng.binomial returns a symbolic sample of dtype + # int64 by default. If we want to keep our computations in floatX + # for the GPU we need to specify to return the dtype floatX + h1_sample = self.theano_rng.binomial(size=h1_mean.shape, + n=1, p=h1_mean, + dtype=theano.config.floatX) + return [pre_sigmoid_h1, h1_mean, h1_sample] + + def propdown(self, hid): + '''This function propagates the hidden units activation downwards to + the visible units + + Note that we return also the pre_sigmoid_activation of the + layer. As it will turn out later, due to how Theano deals with + optimizations, this symbolic variable will be needed to write + down a more stable computational graph (see details in the + reconstruction cost function) + + ''' + pre_sigmoid_activation = T.dot(hid, self.W.T) + self.vbias + return [pre_sigmoid_activation, T.nnet.sigmoid(pre_sigmoid_activation)] + + def sample_v_given_h(self, h0_sample): + ''' This function infers state of visible units given hidden units ''' + # compute the activation of the visible given the hidden sample + pre_sigmoid_v1, v1_mean = self.propdown(h0_sample) + # get a sample of the visible given their activation + # Note that theano_rng.binomial returns a symbolic sample of dtype + # int64 by default. 
+        # for the GPU we need to specify to return the dtype floatX
+        v1_sample = self.theano_rng.binomial(size=v1_mean.shape,
+                                             n=1, p=v1_mean,
+                                             dtype=theano.config.floatX)
+        return [pre_sigmoid_v1, v1_mean, v1_sample]
+
+    def gibbs_hvh(self, h0_sample):
+        ''' This function implements one step of Gibbs sampling,
+            starting from the hidden state'''
+        pre_sigmoid_v1, v1_mean, v1_sample = self.sample_v_given_h(h0_sample)
+        pre_sigmoid_h1, h1_mean, h1_sample = self.sample_h_given_v(v1_sample)
+        return [pre_sigmoid_v1, v1_mean, v1_sample,
+                pre_sigmoid_h1, h1_mean, h1_sample]
+
+    def gibbs_vhv(self, v0_sample):
+        ''' This function implements one step of Gibbs sampling,
+            starting from the visible state'''
+        pre_sigmoid_h1, h1_mean, h1_sample = self.sample_h_given_v(v0_sample)
+        pre_sigmoid_v1, v1_mean, v1_sample = self.sample_v_given_h(h1_sample)
+        return [pre_sigmoid_h1, h1_mean, h1_sample,
+                pre_sigmoid_v1, v1_mean, v1_sample]
+
+    # start-snippet-2
+    def get_cost_updates(self, lr=0.1, persistent=None, k=1):
+        """This function implements one step of CD-k or PCD-k
+
+        :param lr: learning rate used to train the RBM
+
+        :param persistent: None for CD. For PCD, shared variable
+            containing old state of Gibbs chain. This must be a shared
+            variable of size (batch size, number of hidden units).
+
+        :param k: number of Gibbs steps to do in CD-k/PCD-k
+
+        Returns a proxy for the cost and the updates dictionary. The
+        dictionary contains the update rules for weights and biases but
+        also an update of the shared variable used to store the persistent
+        chain, if one is used.
+
+        """
+
+        # compute positive phase
+        pre_sigmoid_ph, ph_mean, ph_sample = self.sample_h_given_v(self.input)
+
+        # decide how to initialize persistent chain:
+        # for CD, we use the newly generated hidden sample
+        # for PCD, we initialize from the old state of the chain
+        if persistent is None:
+            chain_start = ph_sample
+        else:
+            chain_start = persistent
+        # end-snippet-2
+        # perform actual negative phase
+        # in order to implement CD-k/PCD-k we need to scan over the
+        # function that implements one gibbs step k times.
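+        # (theano.scan returns one output sequence per entry of outputs_info
+        # plus an updates dictionary; that dictionary carries the random
+        # stream state and is extended with the parameter updates below.)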
+        # Read Theano tutorial on scan for more information :
+        # http://deeplearning.net/software/theano/library/scan.html
+        # the scan will return the entire Gibbs chain
+        (
+            [
+                pre_sigmoid_nvs,
+                nv_means,
+                nv_samples,
+                pre_sigmoid_nhs,
+                nh_means,
+                nh_samples
+            ],
+            updates
+        ) = theano.scan(
+            self.gibbs_hvh,
+            # the Nones are placeholders, saying that
+            # chain_start is the initial state corresponding to the
+            # 6th output
+            outputs_info=[None, None, None, None, None, chain_start],
+            n_steps=k
+        )
+        # start-snippet-3
+        # determine gradients on RBM parameters
+        # note that we only need the sample at the end of the chain
+        chain_end = nv_samples[-1]
+
+        cost = T.mean(self.free_energy(self.input)) - T.mean(
+            self.free_energy(chain_end))
+        # We must not compute the gradient through the gibbs sampling
+        gparams = T.grad(cost, self.params, consider_constant=[chain_end])
+        # end-snippet-3 start-snippet-4
+        # construct the update dictionary
+        for gparam, param in zip(gparams, self.params):
+            # make sure that the learning rate is of the right dtype
+            updates[param] = param - gparam * T.cast(
+                lr,
+                dtype=theano.config.floatX
+            )
+        if persistent:
+            # Note that this works only if persistent is a shared variable
+            updates[persistent] = nh_samples[-1]
+            # pseudo-likelihood is a better proxy for PCD
+            monitoring_cost = self.get_pseudo_likelihood_cost(updates)
+        else:
+            # reconstruction cross-entropy is a better proxy for CD
+            monitoring_cost = self.get_reconstruction_cost(updates,
+                                                           pre_sigmoid_nvs[-1])
+
+        return monitoring_cost, updates
+        # end-snippet-4
+
+    def get_pseudo_likelihood_cost(self, updates):
+        """Stochastic approximation to the pseudo-likelihood"""
+
+        # index of bit i in expression p(x_i | x_{\i})
+        bit_i_idx = theano.shared(value=0, name='bit_i_idx')
+
+        # binarize the input image by rounding to nearest integer
+        xi = T.round(self.input)
+
+        # calculate free energy for the given bit configuration
+        fe_xi = self.free_energy(xi)
+
+        # flip bit x_i of matrix xi and preserve all other bits x_{\i}
+        # Equivalent to xi[:,bit_i_idx] = 1-xi[:, bit_i_idx], but assigns
+        # the result to xi_flip, instead of working in place on xi.
+        xi_flip = T.set_subtensor(xi[:, bit_i_idx], 1 - xi[:, bit_i_idx])
+
+        # calculate free energy with bit flipped
+        fe_xi_flip = self.free_energy(xi_flip)
+
+        # equivalent to e^(-FE(x_i)) / (e^(-FE(x_i)) + e^(-FE(x_{\i})))
+        cost = T.mean(self.n_visible * T.log(T.nnet.sigmoid(fe_xi_flip -
+                                                            fe_xi)))
+
+        # increment bit_i_idx % number as part of updates
+        updates[bit_i_idx] = (bit_i_idx + 1) % self.n_visible
+
+        return cost
+
+    def get_reconstruction_cost(self, updates, pre_sigmoid_nv):
+        """Approximation to the reconstruction error
+
+        Note that this function requires the pre-sigmoid activation as
+        input. To understand why this is so you need to understand a
+        bit about how Theano works. Whenever you compile a Theano
+        function, the computational graph that you pass as input gets
+        optimized for speed and stability. This is done by changing
+        several parts of the subgraphs with others. One such
+        optimization expresses terms of the form log(sigmoid(x)) in
+        terms of softplus. We need this optimization for the
+        cross-entropy since sigmoid of numbers larger than 30. (or
+        even less than that) turn to 1. and numbers smaller than
+        -30. turn to 0 which in turn will force theano to compute
+        log(0) and therefore we will get either -inf or NaN as
+        cost. If the value is expressed in terms of softplus we do not
+        get this undesirable behaviour.
This optimization usually
+        works fine, but here we have a special case. The sigmoid is
+        applied inside the scan op, while the log is
+        outside. Therefore Theano will only see log(scan(..)) instead
+        of log(sigmoid(..)) and will not apply the wanted
+        optimization. We cannot simply replace the sigmoid in scan
+        with something else either, because this only needs to be done
+        on the last step. Therefore the easiest and most efficient way
+        is to also get the pre-sigmoid activation as an output of
+        scan, and apply both the log and sigmoid outside scan such
+        that Theano can catch and optimize the expression.
+
+        """
+
+        cross_entropy = T.mean(
+            T.sum(
+                self.input * T.log(T.nnet.sigmoid(pre_sigmoid_nv)) +
+                (1 - self.input) * T.log(1 - T.nnet.sigmoid(pre_sigmoid_nv)),
+                axis=1
+            )
+        )
+
+        return cross_entropy
+
+
+def test_rbm(learning_rate=0.1, training_epochs=15,
+             dataset='mnist.pkl.gz', batch_size=20,
+             n_chains=20, n_samples=10, output_folder='rbm_plots',
+             n_hidden=500):
+    """
+    Demonstrate how to train an RBM and afterwards sample from it using Theano.
+
+    This is demonstrated on MNIST.
+
+    :param learning_rate: learning rate used for training the RBM
+
+    :param training_epochs: number of epochs used for training
+
+    :param dataset: path to the pickled dataset
+
+    :param batch_size: size of a batch used to train the RBM
+
+    :param n_chains: number of parallel Gibbs chains to be used for sampling
+
+    :param n_samples: number of samples to plot for each chain
+
+    """
+    datasets = load_data(dataset)
+
+    train_set_x, train_set_y = datasets[0]
+    test_set_x, test_set_y = datasets[2]
+
+    # compute number of minibatches for training, validation and testing
+    n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size
+
+    # allocate symbolic variables for the data
+    index = T.lscalar()    # index to a [mini]batch
+    x = T.matrix('x')  # the data is presented as rasterized images
+
+    rng = numpy.random.RandomState(123)
+    theano_rng = RandomStreams(rng.randint(2 ** 30))
+
+    # initialize storage for the persistent chain (state = hidden
+    # layer of chain)
+    persistent_chain = theano.shared(numpy.zeros((batch_size, n_hidden),
+                                                 dtype=theano.config.floatX),
+                                     borrow=True)
+
+    # construct the RBM class
+    rbm = RBM(input=x, n_visible=28 * 28,
+              n_hidden=n_hidden, numpy_rng=rng, theano_rng=theano_rng)
+
+    # get the cost and the gradient corresponding to one step of CD-15
+    cost, updates = rbm.get_cost_updates(lr=learning_rate,
+                                         persistent=persistent_chain, k=15)
+
+    #################################
+    #     Training the RBM          #
+    #################################
+    if not os.path.isdir(output_folder):
+        os.makedirs(output_folder)
+    os.chdir(output_folder)
+
+    # start-snippet-5
+    # it is ok for a theano function to have no output
+    # the purpose of train_rbm is solely to update the RBM parameters
+    train_rbm = theano.function(
+        [index],
+        cost,
+        updates=updates,
+        givens={
+            x: train_set_x[index * batch_size: (index + 1) * batch_size]
+        },
+        name='train_rbm'
+    )
+
+    plotting_time = 0.
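+    # each call to train_rbm processes one minibatch and returns the
+    # monitoring cost; the loop below averages it over an epoch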
+    start_time = time.clock()
+
+    # go through training epochs
+    for epoch in xrange(training_epochs):
+
+        # go through the training set
+        mean_cost = []
+        for batch_index in xrange(n_train_batches):
+            mean_cost += [train_rbm(batch_index)]
+
+        print 'Training epoch %d, cost is ' % epoch, numpy.mean(mean_cost)
+
+        # Plot filters after each training epoch
+        plotting_start = time.clock()
+        # Construct image from the weight matrix
+        image = Image.fromarray(
+            tile_raster_images(
+                X=rbm.W.get_value(borrow=True).T,
+                img_shape=(28, 28),
+                tile_shape=(10, 10),
+                tile_spacing=(1, 1)
+            )
+        )
+        image.save('filters_at_epoch_%i.png' % epoch)
+        plotting_stop = time.clock()
+        plotting_time += (plotting_stop - plotting_start)
+
+    end_time = time.clock()
+
+    pretraining_time = (end_time - start_time) - plotting_time
+
+    print ('Training took %f minutes' % (pretraining_time / 60.))
+    # end-snippet-5 start-snippet-6
+    #################################
+    #     Sampling from the RBM     #
+    #################################
+    # find out the number of test samples
+    number_of_test_samples = test_set_x.get_value(borrow=True).shape[0]
+
+    # pick random test examples, with which to initialize the persistent chain
+    test_idx = rng.randint(number_of_test_samples - n_chains)
+    persistent_vis_chain = theano.shared(
+        numpy.asarray(
+            test_set_x.get_value(borrow=True)[test_idx:test_idx + n_chains],
+            dtype=theano.config.floatX
+        )
+    )
+    # end-snippet-6 start-snippet-7
+    plot_every = 1000
+    # define one step of Gibbs sampling (mf = mean-field) and define a
+    # function that does `plot_every` steps before returning the
+    # sample for plotting
+    (
+        [
+            presig_hids,
+            hid_mfs,
+            hid_samples,
+            presig_vis,
+            vis_mfs,
+            vis_samples
+        ],
+        updates
+    ) = theano.scan(
+        rbm.gibbs_vhv,
+        outputs_info=[None, None, None, None, None, persistent_vis_chain],
+        n_steps=plot_every
+    )
+
+    # add to updates the shared variable that takes care of our persistent
+    # chain
+    updates.update({persistent_vis_chain: vis_samples[-1]})
+    # construct the function that implements our persistent chain.
+    # we generate the "mean field" activations for plotting and the actual
+    # samples for reinitializing the state of our persistent chain
+    sample_fn = theano.function(
+        [],
+        [
+            vis_mfs[-1],
+            vis_samples[-1]
+        ],
+        updates=updates,
+        name='sample_fn'
+    )
+
+    # create a space to store the image for plotting (we need to leave
+    # room for the tile_spacing as well)
+    image_data = numpy.zeros(
+        (29 * n_samples + 1, 29 * n_chains - 1),
+        dtype='uint8'
+    )
+    for idx in xrange(n_samples):
+        # generate `plot_every` intermediate samples that we discard,
+        # because successive samples in the chain are too correlated
+        vis_mf, vis_sample = sample_fn()
+        print ' ... plotting sample ', idx
+        image_data[29 * idx:29 * idx + 28, :] = tile_raster_images(
+            X=vis_mf,
+            img_shape=(28, 28),
+            tile_shape=(1, n_chains),
+            tile_spacing=(1, 1)
+        )
+
+    # construct image
+    image = Image.fromarray(image_data)
+    image.save('samples.png')
+    # end-snippet-7
+    os.chdir('../')
+
+if __name__ == '__main__':
+    test_rbm()
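The demo above trains on MNIST. For LLNet-style pre-training one would feed darkened patches instead; the sketch below is a minimal, illustrative example (not part of this commit), assuming `train_set_x_np` is a NumPy array of shape (N, 17*17) already produced by your own patch pipeline (e.g. via Data_process2.py), with values scaled to [0, 1]. Names, the 17x17 patch size, and hyperparameters are placeholders.

```
#!python
import numpy
import theano
import theano.tensor as T
from rbm import RBM

# Hypothetical single-RBM pre-training on 17x17 patches (CD-1).
train_set_x = theano.shared(numpy.asarray(train_set_x_np,
                                          dtype=theano.config.floatX),
                            borrow=True)

batch_size = 20
n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size

index = T.lscalar('index')
x = T.matrix('x')

rbm = RBM(input=x, n_visible=17 * 17, n_hidden=500)
cost, updates = rbm.get_cost_updates(lr=0.1, persistent=None, k=1)

train_rbm = theano.function(
    [index], cost, updates=updates,
    givens={x: train_set_x[index * batch_size: (index + 1) * batch_size]}
)

for epoch in xrange(15):
    epoch_cost = numpy.mean([train_rbm(i) for i in xrange(n_train_batches)])
    print 'epoch %d, cost %f' % (epoch, epoch_cost)
```
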
+""" + + +import numpy + + +def scale_to_unit_interval(ndar, eps=1e-8): + """ Scales all values in the ndarray ndar to be between 0 and 1 """ + ndar = ndar.copy() + ndar -= ndar.min() + ndar *= 1.0 / (ndar.max() + eps) + return ndar + + +################## KIN MODIFIED THIS PART, DO NOT USE!!!!!! TRANSPOSED TILE OUTPUT #### +def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0), + scale_rows_to_unit_interval=True, + output_pixel_vals=True): + """ + Transform an array with one flattened image per row, into an array in + which images are reshaped and layed out like tiles on a floor. + + This function is useful for visualizing datasets whose rows are images, + and also columns of matrices for transforming those rows + (such as the first layer of a neural net). + + :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can + be 2-D ndarrays or None; + :param X: a 2-D array in which every row is a flattened image. + + :type img_shape: tuple; (height, width) + :param img_shape: the original shape of each image + + :type tile_shape: tuple; (rows, cols) + :param tile_shape: the number of images to tile (rows, cols) + + :param output_pixel_vals: if output should be pixel values (i.e. int8 + values) or floats + + :param scale_rows_to_unit_interval: if the values need to be scaled before + being plotted to [0,1] or not + + + :returns: array suitable for viewing as an image. + (See:`Image.fromarray`.) + :rtype: a 2-d array with same dtype as X. + + """ + + assert len(img_shape) == 2 + assert len(tile_shape) == 2 + assert len(tile_spacing) == 2 + + # The expression below can be re-written in a more C style as + # follows : + # + # out_shape = [0,0] + # out_shape[0] = (img_shape[0]+tile_spacing[0])*tile_shape[0] - + # tile_spacing[0] + # out_shape[1] = (img_shape[1]+tile_spacing[1])*tile_shape[1] - + # tile_spacing[1] + out_shape = [ + (ishp + tsp) * tshp - tsp + for ishp, tshp, tsp in zip(img_shape, tile_shape, tile_spacing) + ] + + if isinstance(X, tuple): + assert len(X) == 4 + # Create an output numpy ndarray to store the image + if output_pixel_vals: + out_array = numpy.zeros((out_shape[0], out_shape[1], 4), + dtype='uint8') + else: + out_array = numpy.zeros((out_shape[0], out_shape[1], 4), + dtype=X.dtype) + + #colors default to 0, alpha defaults to 1 (opaque) + if output_pixel_vals: + channel_defaults = [0, 0, 0, 255] + else: + channel_defaults = [0., 0., 0., 1.] 
+        for i in xrange(4):
+            if X[i] is None:
+                # if channel is None, fill it with zeros of the correct
+                # dtype
+                dt = out_array.dtype
+                if output_pixel_vals:
+                    dt = 'uint8'
+                out_array[:, :, i] = numpy.zeros(
+                    out_shape,
+                    dtype=dt
+                ) + channel_defaults[i]
+            else:
+                # use a recursive call to compute the channel and store it
+                # in the output
+                out_array[:, :, i] = tile_raster_images(
+                    X[i], img_shape, tile_shape, tile_spacing,
+                    scale_rows_to_unit_interval, output_pixel_vals)
+        return out_array
+
+    else:
+        # if we are dealing with only one channel
+        H, W = img_shape
+        Hs, Ws = tile_spacing
+
+        # generate a matrix to store the output
+        dt = X.dtype
+        if output_pixel_vals:
+            dt = 'uint8'
+        out_array = numpy.zeros(out_shape, dtype=dt)
+
+        for tile_row in xrange(tile_shape[0]):
+            for tile_col in xrange(tile_shape[1]):
+                if tile_row * tile_shape[1] + tile_col < X.shape[0]:
+                    this_x = X[tile_row * tile_shape[1] + tile_col]
+                    if scale_rows_to_unit_interval:
+                        # if we should scale values to be between 0 and 1
+                        # do this by calling the `scale_to_unit_interval`
+                        # function
+                        this_img = scale_to_unit_interval(
+                            this_x.reshape(img_shape).T)
+                    else:
+                        this_img = this_x.reshape(img_shape).T
+                    # add the slice to the corresponding position in the
+                    # output array
+                    c = 1
+                    if output_pixel_vals:
+                        c = 255
+                    out_array[
+                        tile_row * (H + Hs): tile_row * (H + Hs) + H,
+                        tile_col * (W + Ws): tile_col * (W + Ws) + W
+                    ] = this_img * c
+        return out_array
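As a usage note (not part of this commit): this modified `tile_raster_images` transposes each tile, as the warning above indicates, so saved mosaics appear transposed relative to the original patch layout. A minimal sketch for dumping learned filters to an image file follows; the weight file path and the 17x17 patch shape are placeholders, and `W` is assumed to be a weight matrix of shape (n_visible, n_hidden) with one filter per column.

```
#!python
import numpy
import PIL.Image as Image
from utils import tile_raster_images

# Placeholder: load a previously saved weight matrix from disk.
W = numpy.load('some_model_weights.npy')

mosaic = tile_raster_images(
    X=W.T,                # one flattened filter per row
    img_shape=(17, 17),   # patch height and width
    tile_shape=(10, 10),  # show the first 100 filters
    tile_spacing=(1, 1)
)
Image.fromarray(mosaic).save('filters.png')
```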