Fixed scipy dependency and ported code to tensorflow 2 #46

Open · wants to merge 1 commit into master

8 changes: 4 additions & 4 deletions extract_features.py
@@ -7,8 +7,8 @@
import argparse, os, json
import h5py
import numpy as np
from scipy.misc import imread, imresize

# from scipy.misc import imread, imresize
from cv2 import imread, resize as imresize
import torch
import torchvision

@@ -86,8 +86,8 @@ def main(args):
i0 = 0
cur_batch = []
for i, (path, idx) in enumerate(input_paths):
img = imread(path, mode='RGB')
img = imresize(img, img_size, interp='bicubic')
img = imread(path)
img = imresize(img, img_size)
img = img.transpose(2, 0, 1)[None]
cur_batch.append(img)
if len(cur_batch) == args.batch_size:
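
A note on the cv2 replacement: cv2.imread returns images in BGR channel order and has no mode='RGB' argument, and cv2.resize defaults to bilinear interpolation and expects the target size as (width, height). A closer equivalent of the removed scipy.misc calls might look like the sketch below; the helper name and the (height, width) convention for img_size are assumptions for illustration, not part of this PR.

```python
# Hedged sketch of a closer drop-in for the removed scipy.misc calls.
# Assumes img_size is (height, width), as scipy.misc.imresize expected.
import cv2

def load_and_resize_rgb(path, img_size):
    img = cv2.imread(path)                      # loads as BGR, uint8
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # match the old mode='RGB'
    h, w = img_size
    # cv2.resize takes (width, height); INTER_CUBIC mirrors interp='bicubic'
    return cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)
```
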
28 changes: 14 additions & 14 deletions mac_cell.py
@@ -27,7 +27,7 @@
3. The Write Unit integrates the retrieved information to the previous hidden memory state,
given the value of the control state, to perform the current reasoning operation.
'''
class MACCell(tf.nn.rnn_cell.RNNCell):
class MACCell(tf.compat.v1.nn.rnn_cell.RNNCell):

'''Initialize the MAC cell.
(Note that in the current version the cell is stateful --
@@ -133,7 +133,7 @@ def output_size(self):
def control(self, controlInput, inWords, outWords, questionLengths,
control, contControl = None, name = "", reuse = None):

with tf.variable_scope("control" + name, reuse = reuse):
with tf.compat.v1.variable_scope("control" + name, reuse = reuse):
dim = config.ctrlDim

## Step 1: compute "continuous" control state given previous control and question.
@@ -207,14 +207,14 @@ def control(self, controlInput, inWords, outWords, questionLengths,
[batchSize, memDim]
'''
def read(self, knowledgeBase, memory, control, name = "", reuse = None):
with tf.variable_scope("read" + name, reuse = reuse):
with tf.compat.v1.variable_scope("read" + name, reuse = reuse):
dim = config.memDim

## memory dropout
if config.memoryVariationalDropout:
memory = ops.applyVarDpMask(memory, self.memDpMask, self.dropouts["memory"])
else:
memory = tf.nn.dropout(memory, self.dropouts["memory"])
memory = tf.compat.v1.nn.dropout(memory, self.dropouts["memory"])

## Step 1: knowledge base / memory interactions
# parameters for knowledge base and memory projection
@@ -303,7 +303,7 @@ def read(self, knowledgeBase, memory, control, name = "", reuse = None):
[batchSize, memDim]
'''
def write(self, memory, info, control, contControl = None, name = "", reuse = None):
with tf.variable_scope("write" + name, reuse = reuse):
with tf.compat.v1.variable_scope("write" + name, reuse = reuse):

# optionally project info
if config.writeInfoProj:
@@ -374,8 +374,8 @@ def write(self, memory, info, control, contControl = None, name = "", reuse = No

return newMemory

def memAutoEnc(newMemory, info, control, name = "", reuse = None):
with tf.variable_scope("memAutoEnc" + name, reuse = reuse):
def memAutoEnc(self, newMemory, info, control, name = "", reuse = None):
with tf.compat.v1.variable_scope("memAutoEnc" + name, reuse = reuse):
# inputs to auto encoder
features = info if config.autoEncMemInputs == "INFO" else newMemory
features = ops.linear(features, config.memDim, config.ctrlDim,
@@ -419,7 +419,7 @@ def memAutoEnc(newMemory, info, control, name = "", reuse = None):
'''
def __call__(self, inputs, state, scope = None):
scope = scope or type(self).__name__
with tf.variable_scope(scope, reuse = self.reuse): # as tfscope
with tf.compat.v1.variable_scope(scope, reuse = self.reuse): # as tfscope
control = state.control
memory = state.memory

@@ -460,7 +460,7 @@ def __call__(self, inputs, state, scope = None):

if config.writeDropout < 1.0:
# write unit
info = tf.nn.dropout(info, self.dropouts["write"])
info = tf.compat.v1.nn.dropout(info, self.dropouts["write"])

newMemory = self.write(memory, info, newControl, self.contControl, name = cellName, reuse = cellReuse)

@@ -495,9 +495,9 @@ def __call__(self, inputs, state, scope = None):
'''
def initState(self, name, dim, initType, batchSize):
if initType == "PRM":
prm = tf.get_variable(name, shape = (dim, ),
prm = tf.compat.v1.get_variable(name, shape = (dim, ),
initializer = tf.random_normal_initializer())
initState = tf.tile(tf.expand_dims(prm, axis = 0), [batchSize, 1])
initState = tf.compat.v1.tile(tf.expand_dims(prm, axis = 0), [batchSize, 1])
elif initType == "ZERO":
initState = tf.zeros((batchSize, dim), dtype = tf.float32)
else: # "Q"
@@ -516,8 +516,8 @@ def initState(self, name, dim, initType, batchSize):

Returns the updated word sequence and lengths.
'''
def addNullWord(words, lengths):
nullWord = tf.get_variable("zeroWord", shape = (1 , config.ctrlDim), initializer = tf.random_normal_initializer())
def addNullWord(self, words, lengths):
nullWord = tf.compat.v1.get_variable("zeroWord", shape = (1 , config.ctrlDim), initializer = tf.random_normal_initializer())
nullWord = tf.tile(tf.expand_dims(nullWord, axis = 0), [self.batchSize, 1, 1])
words = tf.concat([nullWord, words], axis = 1)
lengths += 1
@@ -582,7 +582,7 @@ def zero_state(self, batchSize, dtype = tf.float32):

# if config.controlCoverage:
# self.coverage = tf.zeros((batchSize, tf.shape(words)[1]), dtype = tf.float32)
# self.coverageBias = tf.get_variable("coverageBias", shape = (),
# self.coverageBias = tf.compat.v1.get_variable("coverageBias", shape = (),
# initializer = config.controlCoverageBias)

## initialize memory variational dropout mask
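
The mac_cell.py changes route every TF1 API the cell relies on (RNNCell, variable_scope, get_variable, dropout) through tf.compat.v1, so the graph-mode semantics are preserved as long as eager execution is disabled (see main.py below). One detail worth keeping in mind: tf.compat.v1.nn.dropout still interprets its second argument as keep_prob, while native tf.nn.dropout in TF2 takes a drop rate. A minimal sketch of the difference, with an illustrative keep_prob value:

```python
import tensorflow as tf

x = tf.ones([4, 8])
keep_prob = 0.85  # illustrative; in this repo the value comes from the dropouts dict

# Compat call used in this PR: second argument is the *keep* probability
y_v1 = tf.compat.v1.nn.dropout(x, keep_prob)

# Native TF2 equivalent inverts the argument: rate is the *drop* probability
y_v2 = tf.nn.dropout(x, rate=1.0 - keep_prob)
```
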
20 changes: 11 additions & 9 deletions main.py
@@ -23,6 +23,8 @@
from model import MACnet
from collections import defaultdict

tf.compat.v1.disable_eager_execution()

############################################# loggers #############################################

# Writes log header to file
@@ -151,7 +153,7 @@ def writePreds(preprocessor, evalRes, extraEvalRes):
############################################# session #############################################
# Initializes TF session. Sets GPU memory configuration.
def setSession():
sessionConfig = tf.ConfigProto(allow_soft_placement = True, log_device_placement = False)
sessionConfig = tf.compat.v1.ConfigProto(allow_soft_placement = True, log_device_placement = False)
if config.allowGrowth:
sessionConfig.gpu_options.allow_growth = True
if config.maxMemory < 1.0:
@@ -161,17 +163,17 @@ def setSession():
############################################## savers #############################################
# Initializes savers (standard, optional exponential-moving-average and optional for subset of variables)
def setSavers(model):
saver = tf.train.Saver(max_to_keep = config.weightsToKeep)
saver = tf.compat.v1.train.Saver(max_to_keep = config.weightsToKeep)

subsetSaver = None
if config.saveSubset:
isRelevant = lambda var: any(s in var.name for s in config.varSubset)
relevantVars = [var for var in tf.global_variables() if isRelevant(var)]
subsetSaver = tf.train.Saver(relevantVars, max_to_keep = config.weightsToKeep, allow_empty = True)
relevantVars = [var for var in tf.compat.v1.global_variables() if isRelevant(var)]
subsetSaver = tf.compat.v1.train.Saver(relevantVars, max_to_keep = config.weightsToKeep, allow_empty = True)

emaSaver = None
if config.useEMA:
emaSaver = tf.train.Saver(model.emaDict, max_to_keep = config.weightsToKeep)
emaSaver = tf.compat.v1.train.Saver(model.emaDict, max_to_keep = config.weightsToKeep)

return {
"saver": saver,
@@ -657,7 +659,7 @@ def main():
config.gpusNum = len(config.gpus.split(","))
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus

tf.logging.set_verbosity(tf.logging.ERROR)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# process data
print(bold("Preprocess data..."))
@@ -673,7 +675,7 @@
print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue")))

# initializer
init = tf.global_variables_initializer()
init = tf.compat.v1.global_variables_initializer()

# savers
savers = setSavers(model)
@@ -682,7 +684,7 @@
# sessionConfig
sessionConfig = setSession()

with tf.Session(config = sessionConfig) as sess:
with tf.compat.v1.Session(config = sessionConfig) as sess:

# ensure no more ops are added after model is built
sess.graph.finalize()
@@ -711,7 +713,7 @@ def main():
# save weights
saver.save(sess, config.weightsFile(epoch))
if config.saveSubset:
subsetSaver.save(sess, config.subsetWeightsFile(epoch))
subsetSaver.save(sess, config.subsetWeightsFile(epoch))

# load EMA weights
if config.useEMA:
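
main.py keeps the TF1 execution model: eager execution is disabled up front, and the graph is still built, initialized, and run through the tf.compat.v1 Session machinery. A self-contained sketch of that pattern; the toy placeholder/variable graph is illustrative and stands in for the MAC model.

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # required for the compat.v1 graph-mode APIs

# Toy graph standing in for the model
x = tf.compat.v1.placeholder(tf.float32, shape=[None, 4])
w = tf.compat.v1.get_variable("w", shape=[4, 2])
y = tf.matmul(x, w)

init = tf.compat.v1.global_variables_initializer()

sessionConfig = tf.compat.v1.ConfigProto(allow_soft_placement=True,
                                         log_device_placement=False)
sessionConfig.gpu_options.allow_growth = True

with tf.compat.v1.Session(config=sessionConfig) as sess:
    sess.run(init)
    out = sess.run(y, feed_dict={x: [[1.0, 2.0, 3.0, 4.0]]})
```
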
18 changes: 9 additions & 9 deletions mi_gru_cell.py
@@ -1,7 +1,7 @@
import tensorflow as tf
import numpy as np

class MiGRUCell(tf.nn.rnn_cell.RNNCell):
class MiGRUCell(tf.compat.v1.nn.rnn_cell.RNNCell):
def __init__(self, num_units, input_size = None, activation = tf.tanh, reuse = None):
self.numUnits = num_units
self.activation = activation
@@ -16,19 +16,19 @@ def output_size(self):
return self.numUnits

def mulWeights(self, inp, inDim, outDim, name = ""):
with tf.variable_scope("weights" + name):
W = tf.get_variable("weights", shape = (inDim, outDim),
initializer = tf.contrib.layers.xavier_initializer())
with tf.compat.v1.variable_scope("weights" + name):
W = tf.compat.v1.get_variable("weights", shape = (inDim, outDim),
initializer = tf.compat.v1.keras.initializers.glorot_normal())

output = tf.matmul(inp, W)
return output

def addBiases(self, inp1, inp2, dim, bInitial = 0, name = ""):
with tf.variable_scope("additiveBiases" + name):
b = tf.get_variable("biases", shape = (dim,),
with tf.compat.v1.variable_scope("additiveBiases" + name):
b = tf.compat.v1.get_variable("biases", shape = (dim,),
initializer = tf.zeros_initializer()) + bInitial
with tf.variable_scope("multiplicativeBias" + name):
beta = tf.get_variable("biases", shape = (3 * dim,),
with tf.compat.v1.variable_scope("multiplicativeBias" + name):
beta = tf.compat.v1.get_variable("biases", shape = (3 * dim,),
initializer = tf.ones_initializer())

Wx, Uh, inter = tf.split(beta * tf.concat([inp1, inp2, inp1 * inp2], axis = 1),
@@ -38,7 +38,7 @@ def addBiases(self, inp1, inp2, dim, bInitial = 0, name = ""):

def __call__(self, inputs, state, scope = None):
scope = scope or type(self).__name__
with tf.variable_scope(scope, reuse = self.reuse):
with tf.compat.v1.variable_scope(scope, reuse = self.reuse):
inputSize = int(inputs.shape[1])

Wxr = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxr")
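
One behavioral difference in the initializer swap: tf.contrib.layers.xavier_initializer() defaulted to the uniform Glorot variant, so glorot_uniform would be the closest drop-in; glorot_normal, used here, keeps the same fan-in/fan-out variance scaling but samples from a truncated normal instead. A small sketch of the two options, both available under tf.compat.v1.keras.initializers:

```python
import tensorflow as tf

# Closest drop-in for tf.contrib.layers.xavier_initializer() (uniform by default)
xavier_uniform = tf.compat.v1.keras.initializers.glorot_uniform()

# What the ported cells use: same variance scaling, truncated-normal samples
xavier_normal = tf.compat.v1.keras.initializers.glorot_normal()

# Initializers can be called directly to inspect the values they produce
sample = xavier_uniform(shape=(128, 256))
```
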
24 changes: 12 additions & 12 deletions mi_lstm_cell.py
@@ -1,7 +1,7 @@
import tensorflow as tf
import numpy as np

class MiLSTMCell(tf.nn.rnn_cell.RNNCell):
class MiLSTMCell(tf.compat.v1.nn.rnn_cell.RNNCell):
def __init__(self, num_units, forget_bias = 1.0, input_size = None,
state_is_tuple = True, activation = tf.tanh, reuse = None):
self.numUnits = num_units
@@ -11,25 +11,25 @@ def __init__(self, num_units, forget_bias = 1.0, input_size = None,

@property
def state_size(self):
return tf.nn.rnn_cell.LSTMStateTuple(self.numUnits, self.numUnits)
return tf.compat.v1.nn.rnn_cell.LSTMStateTuple(self.numUnits, self.numUnits)

@property
def output_size(self):
return self.numUnits

def mulWeights(self, inp, inDim, outDim, name = ""):
with tf.variable_scope("weights" + name):
W = tf.get_variable("weights", shape = (inDim, outDim),
initializer = tf.contrib.layers.xavier_initializer())
with tf.compat.v1.variable_scope("weights" + name):
W = tf.compat.v1.get_variable("weights", shape = (inDim, outDim),
initializer = tf.compat.v1.keras.initializers.glorot_normal())
output = tf.matmul(inp, W)
return output

def addBiases(self, inp1, inp2, dim, name = ""):
with tf.variable_scope("additiveBiases" + name):
b = tf.get_variable("biases", shape = (dim,),
with tf.compat.v1.variable_scope("additiveBiases" + name):
b = tf.compat.v1.get_variable("biases", shape = (dim,),
initializer = tf.zeros_initializer())
with tf.variable_scope("multiplicativeBias" + name):
beta = tf.get_variable("biases", shape = (3 * dim,),
with tf.compat.v1.variable_scope("multiplicativeBias" + name):
beta = tf.compat.v1.get_variable("biases", shape = (3 * dim,),
initializer = tf.ones_initializer())

Wx, Uh, inter = tf.split(beta * tf.concat([inp1, inp2, inp1 * inp2], axis = 1),
@@ -39,7 +39,7 @@ def addBiases(self, inp1, inp2, dim, name = ""):

def __call__(self, inputs, state, scope = None):
scope = scope or type(self).__name__
with tf.variable_scope(scope, reuse = self.reuse):
with tf.compat.v1.variable_scope(scope, reuse = self.reuse):
c, h = state
inputSize = int(inputs.shape[1])

@@ -68,10 +68,10 @@ def __call__(self, inputs, state, scope = None):
self.activation(j))
newH = self.activation(newC) * tf.nn.sigmoid(o)

newState = tf.nn.rnn_cell.LSTMStateTuple(newC, newH)
newState = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(newC, newH)
return newH, newState

def zero_state(self, batchSize, dtype = tf.float32):
return tf.nn.rnn_cell.LSTMStateTuple(tf.zeros((batchSize, self.numUnits), dtype = dtype),
return tf.compat.v1.nn.rnn_cell.LSTMStateTuple(tf.zeros((batchSize, self.numUnits), dtype = dtype),
tf.zeros((batchSize, self.numUnits), dtype = dtype))
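
Since MiLSTMCell (like MiGRUCell and MACCell above) still subclasses the v1 RNNCell, it is driven by the v1 recurrent machinery rather than Keras layers. A hedged usage sketch, with illustrative shapes and assuming eager execution has been disabled as in main.py:

```python
import tensorflow as tf
from mi_lstm_cell import MiLSTMCell

tf.compat.v1.disable_eager_execution()

seqLen, inDim, numUnits = 20, 300, 512  # illustrative sizes

# [batch, time, features] input fed through a compat.v1 placeholder
inputs = tf.compat.v1.placeholder(tf.float32, [None, seqLen, inDim])
cell = MiLSTMCell(numUnits)

# tf.compat.v1.nn.dynamic_rnn unrolls the v1-style cell over the time axis;
# with dtype given, it builds the initial state via cell.zero_state
outputs, finalState = tf.compat.v1.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
```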
