Skip to content

Read const_component_dim during conversion of nnet2 to nnet3 model #3071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 4, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions egs/wsj/s5/steps/nnet3/convert_nnet2_to_nnet3.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class Nnet3Model(object):
def __init__(self):
self.input_dim = -1
self.output_dim = -1
self.ivector_dim = -1
self.counts = defaultdict(int)
self.num_components = 0
self.components_read = 0
Expand All @@ -118,7 +119,10 @@ def add_component(self, component, pairs):
Component = namedtuple("Component", "ident component pairs")

if "<InputDim>" in pairs and self.input_dim == -1:
self.input_dim = pairs["<InputDim>"]
self.input_dim = int(pairs["<InputDim>"])

if "<ConstComponentDim>" in pairs and self.ivector_dim == -1:
self.ivector_dim = int(pairs["<ConstComponentDim>"])

# remove nnet2 specific tokens and catch descriptors
if component == "<PnormComponent>" and "<P>" in pairs:
Expand Down Expand Up @@ -159,13 +163,18 @@ def write_config(self, filename):
config_string=config_string))

f.write("\n# Component nodes\n")
f.write("input-node name=input dim={0}\n".format(self.input_dim))
if self.ivector_dim != -1:
f.write("input-node name=input dim={0}\n".format(self.input_dim-self.ivector_dim))
f.write("input-node name=ivector dim={0}\n".format(self.ivector_dim))
else:
f.write("input-node name=input dim={0}\n".format(self.input_dim))
previous_component = "input"
for component in self.components:
if component.ident == "splice":
# Create splice string for the next node
previous_component = make_splice_string(previous_component,
component.pairs["<Context>"])
component.pairs["<Context>"],
component.pairs["<ConstComponentDim>"])
continue
f.write("component-node name={name} component={name} "
"input={inp}\n".format(name=component.ident,
Expand Down Expand Up @@ -264,7 +273,7 @@ def parse_component(line, line_buffer):
pairs = {}

if component in SPLICE_COMPONENTS:
pairs = parse_splice_component(component, line, line_buffer)
line, pairs = parse_splice_component(component, line, line_buffer)
elif component in AFFINE_COMPONENTS:
pairs = parse_affine_component(component, line, line_buffer)
elif component == "<FixedScaleComponent>":
Expand Down Expand Up @@ -335,7 +344,13 @@ def parse_splice_component(component, line, line_buffer):
line = consume_token("<Context>", line)
context = line.strip()[1:-1].split()

return {"<InputDim>" : input_dim, "<Context>" : context}
const_component_dim = 0
line = next(line_buffer) # Context vector adds newline
line = consume_token("<ConstComponentDim>", line)
const_component_dim = int(line.strip().split()[0])

return line, {"<InputDim>" : input_dim, "<Context>" : context,
"<ConstComponentDim>" : const_component_dim}

def parse_end_of_component(component, line, line_buffer):
# Keeps reading until it hits the end tag for component
Expand Down Expand Up @@ -422,14 +437,16 @@ def consume_token(token, line):

return line.partition(token)[2]

def make_splice_string(nodename, context):
def make_splice_string(nodename, context, const_component_dim=0):
"""Generates splice string from a list of context.

E.g. make_splice_string("renorm4", [-4, 4])
returns "Append(Offset(renorm4, -4), Offset(renorm4, 4))"
"""
assert type(context) == list, "context argument must be a list"
string = ["Offset({0}, {1})".format(nodename, i) for i in context]
if const_component_dim > 0:
string.append("ReplaceIndex(ivector, t, 0)")
string = "Append(" + ", ".join(string) + ")"
return string

Expand Down