# RNN implementation in TensorFlow, based on https://github.com/tensorflow/models/blob/0d9a3abdca7be4a855dc769d6d441a5bfcb77c6d/tutorials/rnn/ptb/ptb_word_lm.py
# built with TensorFlow 0.12.0

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import collections
import argparse
import time
import os
import shutil
import datetime
import json

parser = argparse.ArgumentParser()
parser.add_argument("-input_file", default="")
parser.add_argument("-output_dir", default=None)
parser.add_argument("-log", action="store_true")

# sample
parser.add_argument("-sample", default="")
parser.add_argument("-sample_length", type=int, default=200)
parser.add_argument("-sample_minutes", type=int, default=0)
parser.add_argument("-primesymbols", default="")
parser.add_argument("-temperature", type=float, default=1.0)
parser.add_argument("-max", action="store_true", help="always choose the most probable character")

# train
parser.add_argument("-seed", type=int, default=0)
parser.add_argument("-batch_count", type=int, default=64)
parser.add_argument("-seq_length", type=int, default=63) # slightly different from batch_count so that the dimensions are not easily confused
parser.add_argument("-num_layers", type=int, default=1)
parser.add_argument("-rnn_size", type=int, default=128)
parser.add_argument("-dropout", type=float, default=0.0)
parser.add_argument("-max_epochs", type=int, default=50)
parser.add_argument("-learning_rate", type=float, default=2e-3)
parser.add_argument("-learning_rate_decay", type=float, default=0.97)
parser.add_argument("-learning_rate_decay_after", type=int, default=10, help="number of epochs before learning rate decays")
parser.add_argument("-train_frac", type=float, default=0.8, help="fraction of data used for training")
parser.add_argument("-val_frac", type=float, default=0.1, help="fraction of data used for validation")
parser.add_argument("-train_minutes", type=int, default=0)
a = parser.parse_args()
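
# example invocations (script name and paths are placeholders):
#   train:  python rnn.py -input_file data/input.txt -output_dir out_dir
#   sample: python rnn.py -sample out_dir -output_dir out_dir -sample_length 500 -temperature 0.8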

saved_args = ["batch_count", "seq_length", "num_layers", "rnn_size", "train_frac", "val_frac", "input_file"]

Dataset = collections.namedtuple("Dataset", "vocab_size, symbol_to_id, id_to_symbol, data")
Examples = collections.namedtuple("Examples", "vocab_size, steps, inputs, targets")
Model = collections.namedtuple("Model", "examples, loss, learning_rate, train_op, raw_outputs, outputs, initial_state, final_state")

INIT_SCALE = 0.08  # range of the uniform weight initializer
GRAD_CLIP = 5  # per-element gradient clipping bound

# CHAR_MODE treats the input file as text, one symbol per character;
# SYMBOL_MODE loads a .npy array of discrete symbols.
# main() switches to CHAR_MODE when the input file does not end in ".npy".
SYMBOL_MODE = 0
CHAR_MODE = 1
MODE = SYMBOL_MODE

START_TIME = time.time()


def load_dataset(input_file):
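    # Read the corpus and build the symbol <-> integer id mapping.
    # In CHAR_MODE the file is read as text (one symbol per character);
    # in SYMBOL_MODE it is loaded as a .npy array of discrete symbols.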
    if input_file == "":
        raise Exception("no input_file specified")

    if MODE == CHAR_MODE:
        with open(input_file) as f:
            raw = f.read()
        all_symbols = list(sorted(set(raw)))
    else:
        raw = np.load(input_file)
        all_symbols = np.unique(raw)

    symbol_to_id = {s: i for i, s in enumerate(all_symbols)}
    id_to_symbol = all_symbols

    data = np.zeros(len(raw), dtype=np.int32)
    for i, symbol in enumerate(raw):
        data[i] = symbol_to_id[symbol]

    return Dataset(
        vocab_size=len(all_symbols),
        symbol_to_id=symbol_to_id,
        id_to_symbol=id_to_symbol,
        data=data,
    )


def create_examples(name, dataset):
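    # Split the data into batch_count parallel streams; each epoch step reads one
    # [batch_count, seq_length] window. train/val/test use contiguous, non-overlapping
    # ranges of steps along those streams, sized by train_frac and val_frac.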
    batch_length = len(dataset.data) // a.batch_count

    data_tensor = tf.convert_to_tensor(dataset.data, name="data_tensor", dtype=tf.int32)
    # [batch_count, batch_length]
    batched_data = tf.reshape(data_tensor[:a.batch_count * batch_length], [a.batch_count, batch_length])

    # leave at least one item at the end for targets
    epoch_steps = (batch_length - 1) // a.seq_length

    if epoch_steps <= 0:
        raise Exception("epoch_steps == 0, decrease batch_count or seq_length")

    train_steps = int(epoch_steps * a.train_frac)
    val_steps = int(epoch_steps * a.val_frac)
    test_steps = epoch_steps - train_steps - val_steps

    if test_steps < 0:
        raise Exception("test_steps < 0, train_frac + val_frac may be too high")

    if name == "train" or name == "train_eval":
        steps = train_steps
        base = 0
    elif name == "val":
        steps = val_steps
        base = train_steps * a.seq_length
    elif name == "test":
        steps = test_steps
        base = (train_steps + val_steps) * a.seq_length
    else:
        raise Exception("invalid name")

    # a queue yields step indices 0..steps-1 in order; each step slices the next
    # [batch_count, seq_length] window of inputs, and targets are the same window
    # shifted one symbol to the right
    i = tf.train.range_input_producer(steps, shuffle=False).dequeue()
    inputs = tf.strided_slice(batched_data, [0, base + i * a.seq_length], [a.batch_count, base + (i + 1) * a.seq_length], [1, 1])
    inputs.set_shape([a.batch_count, a.seq_length])
    targets = tf.strided_slice(batched_data, [0, base + i * a.seq_length + 1], [a.batch_count, base + (i + 1) * a.seq_length + 1], [1, 1])
    targets.set_shape([a.batch_count, a.seq_length])

    return Examples(
        vocab_size=dataset.vocab_size,
        steps=steps,
        inputs=inputs,
        targets=targets,
    )


def create_model(training, examples):
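    # Embedding -> (optionally dropout-wrapped) stacked LSTM -> softmax over the vocabulary.
    # Pre-softmax logits are exposed as raw_outputs so a sampling temperature can be applied.
    # When training, an Adam train op with value-clipped gradients and an exponentially
    # decaying learning rate is attached to the returned model.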
    batch_count, seq_length = [dim.value for dim in examples.inputs.get_shape()]

    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(a.rnn_size, forget_bias=1.0, state_is_tuple=True)
    if training and a.dropout > 0:
        lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=(1 - a.dropout))
    cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * a.num_layers, state_is_tuple=True)

    initial_state = cell.zero_state(batch_count, tf.float32)

    embedding = tf.get_variable("embedding", [examples.vocab_size, a.rnn_size], dtype=tf.float32)
    # [batch_count, seq_length, rnn_size]
    inputs = tf.nn.embedding_lookup(embedding, examples.inputs)

    if training and a.dropout > 0:
        inputs = tf.nn.dropout(inputs, 1 - a.dropout)

    # list of seq_length tensors, each [batch_count, rnn_size]
    inputs = tf.unstack(inputs, num=seq_length, axis=1)
    # cell_outputs: list of seq_length tensors, each [batch_count, rnn_size]; final_state: per-layer LSTM state tuples
    cell_outputs, final_state = tf.nn.rnn(cell, inputs, initial_state=initial_state)

    # concat along time: [batch_count, seq_length * rnn_size] => reshape: [batch_count * seq_length, rnn_size]
    cell_output = tf.reshape(tf.concat_v2(cell_outputs, 1), [-1, a.rnn_size])

    output_w = tf.get_variable("output_w", [a.rnn_size, examples.vocab_size], dtype=tf.float32)
    output_b = tf.get_variable("output_b", [examples.vocab_size], dtype=tf.float32)
    # [batch_count * seq_length, vocab_size]
    raw_outputs = tf.matmul(cell_output, output_w) + output_b # give out pre-softmax outputs so temperature can be applied
    outputs = tf.nn.softmax(raw_outputs)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(raw_outputs, tf.reshape(examples.targets, [-1])))

    model = Model(
        examples=examples,
        loss=loss,
        initial_state=initial_state,
        final_state=final_state,
        raw_outputs=raw_outputs,
        outputs=outputs,
        learning_rate=None,
        train_op=None,
    )

    if not training:
        return model

    global_step = tf.contrib.framework.get_or_create_global_step()
    decay_step = tf.reduce_max([global_step - examples.steps * (a.learning_rate_decay_after - 1), 0])
    learning_rate = tf.train.exponential_decay(a.learning_rate, decay_step, examples.steps, a.learning_rate_decay, staircase=True)
    tvars = tf.trainable_variables()
    # the PTB tutorial clips gradients by global norm (tf.clip_by_global_norm); here each gradient is clipped by value instead
    grads = tf.gradients(loss, tvars)
    grads = [tf.clip_by_value(g, -GRAD_CLIP, GRAD_CLIP) for g in grads]
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)

    return model._replace(
        train_op=train_op,
        learning_rate=learning_rate,
    )


def run_epoch(epoch, session, model, verbose=False):
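    # Run one full pass over the model's example queue, carrying the LSTM state
    # from step to step; returns the mean loss over the pass.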
    start_time = time.time()
    state = session.run(model.initial_state)
    ex = model.examples

    fetches = {
        "loss": model.loss,
        "final_state": model.final_state,
    }

    if model.train_op is not None:
        fetches["train_op"] = model.train_op

    losses = 0.0
    for step in range(ex.steps):
        feed_dict = feed_dict_from_state(model.initial_state, state)
        result = session.run(fetches, feed_dict)
        state = result["final_state"]

        losses += result["loss"]
        loss = losses/(step+1)

        if verbose and ex.steps >= 10 and step % (ex.steps // 10) == 10:
            elapsed = time.time() - start_time
            rate = (step + 1) * a.seq_length * a.batch_count / elapsed
            if rate > 1000:
                rate_str = "%dk" % (rate / 1000)
            else:
                rate_str = "%d" % (rate)
            print("epoch %.3f  loss %.3f  speed %s samples/sec" % (epoch + step / ex.steps, loss, rate_str))

    return loss


def feed_dict_from_state(initial_state, state):
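    # Feed the numpy state values from the previous session.run back into the
    # tensors of initial_state, so the LSTM state carries over between steps.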
    feed_dict = {}
    for layer, tensors in enumerate(initial_state):
        for i, tensor in enumerate(tensors):
            feed_dict[tensor] = state[layer][i]
    return feed_dict


def run_sample(session, model, dataset):
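    # Generate a sequence one symbol at a time with the single-step "sample" model:
    # optionally prime the state with -primesymbols, then repeatedly feed the last
    # sampled symbol back in, applying -temperature to the logits.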
    state = session.run(model.initial_state)
    start = time.time()

    samples = []
    if len(a.primesymbols) == 0:
        current = np.random.randint(dataset.vocab_size)
    else:
        primesymbols = [int(e) for e in a.primesymbols.split(",")]
        for symbol in primesymbols[:-1]:
            feed_dict = feed_dict_from_state(model.initial_state, state)
            current = dataset.symbol_to_id[symbol]
            feed_dict[model.examples.inputs] = np.array([[current]])
            state = session.run(model.final_state, feed_dict)
            samples.append(current)
        current = dataset.symbol_to_id[primesymbols[-1]]
    samples.append(current)

    while len(samples) < a.sample_length:
        feed_dict = feed_dict_from_state(model.initial_state, state)
        feed_dict[model.examples.inputs] = np.array([[current]])

        raw_outputs, state = session.run([model.raw_outputs, model.final_state], feed_dict)
        scaled_output = np.exp(np.array(raw_outputs[0], dtype=np.float64) / a.temperature)
        probabilities = scaled_output / np.sum(scaled_output)
        if a.max:
            current = np.argmax(probabilities)
        else:
            current = np.random.choice(dataset.vocab_size, p=probabilities)
        samples.append(current)

        if a.sample_minutes > 0 and time.time() - START_TIME > a.sample_minutes * 60:
            print("sample time expired")
            break

    converted_samples = [dataset.id_to_symbol[sample_id] for sample_id in samples]
    if MODE == CHAR_MODE:
        sampled = "".join(converted_samples)
    else:
        sampled = np.array(converted_samples, dtype=np.int64)
    elapsed = time.time() - start
    print("\n== sampled (%d samples/sec) ==\n%s\n==================\n" % (len(samples)/elapsed, sampled))
    return sampled


saved_properties = {}
def save_property(k, v):
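    # Print a named property and, when -output_dir is set, persist all properties
    # collected so far to properties.json in that directory.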
    if isinstance(v, float):
        print("%s = %0.3f" % (k, v))
    else:
        print("%s = %s" % (k, v))

    if a.output_dir is not None:
        def default(obj):
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, np.floating):
                return float(obj)
            raise TypeError
        saved_properties[k] = v
        with open(os.path.join(a.output_dir, "properties.json"), "w") as f:
            f.write(json.dumps(saved_properties, sort_keys=True, indent=4, default=default))


def main():
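    # Build the dataset, the train/val/test/train_eval graphs and a single-step
    # "sample" graph sharing the same variables, then either restore a checkpoint
    # and sample (-sample) or train under a Supervisor-managed session.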
    tf.set_random_seed(a.seed)
    np.random.seed(a.seed)

    # if we are using a checkpoint, load args from there
    if a.sample != "":
        with open(os.path.join(a.sample, "properties.json")) as f:
            properties = json.load(f)
        for k in saved_args:
            v = properties[k]
            print("loaded saved property", k, "=", v)
            setattr(a, k, v)

    if not a.input_file.endswith(".npy"):
        global MODE
        MODE = CHAR_MODE

    for k, v in sorted(vars(a).items()):
        save_property(k, v)

    dataset = load_dataset(a.input_file)
    save_property("data_length", len(dataset.data))
    save_property("vocab_size", dataset.vocab_size)

    initializer = tf.random_uniform_initializer(-INIT_SCALE, INIT_SCALE)

    models = {}
    for name in ["train", "val", "test", "train_eval"]:
        training = name == "train"
        with tf.name_scope(name):
            with tf.name_scope("examples"):
                examples = create_examples(name, dataset)

            with tf.variable_scope("model", reuse=not training, initializer=initializer):
                models[name] = create_model(training=training, examples=examples)

            if name == "train":
                tf.summary.scalar("Training Loss", models[name].loss)
                tf.summary.scalar("Learning Rate", models[name].learning_rate)
            if name == "val":
                tf.summary.scalar("Validation Loss", models[name].loss)

    with tf.name_scope("sample"):
        with tf.name_scope("examples"):
            examples = Examples(
                vocab_size=dataset.vocab_size,
                steps=None,
                inputs=tf.placeholder(tf.int32, shape=[1, 1]),
                targets=tf.zeros([1, 1], dtype=tf.int32)
            )

        with tf.variable_scope("model", reuse=True, initializer=initializer):
            models["sample"] = create_model(training=False, examples=examples)

    parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver()

    # sampling
    if a.sample != "":
        sess = tf.Session()

        checkpoint = tf.train.latest_checkpoint(a.sample)
        saver.restore(sess, checkpoint)

        sampled = run_sample(sess, models["sample"], dataset)
        if a.output_dir is not None:
            if MODE == CHAR_MODE:
                with open(os.path.join(a.output_dir, "sample.txt"), "w") as f:
                    f.write(sampled)
            else:
                np.save(os.path.join(a.output_dir, "sample.npy"), sampled)
        return

    # training
    logdir = None
    if a.log:
        logdir = a.output_dir
    sv = tf.train.Supervisor(logdir=logdir, saver=None)
    with sv.managed_session() as sess:
        save_property("parameter_count", sess.run(parameter_count))

        for i in range(a.max_epochs):
            epoch = i + 1
            print("training for epoch %d learning rate %.3f" % (epoch, sess.run(models["train"].learning_rate)))

            for kind in ["train", "val"]:
                loss = run_epoch(epoch, sess, models[kind], verbose=(kind == "train"))
                save_property(kind + "_loss", loss)

            # save model
            if a.output_dir is not None:
                suffix = ("%08d-%.2f") % (sess.run(sv.global_step), saved_properties.get("val_loss", float("NaN")))
                saver_filename = "model-" + suffix
                saver.save(sess, os.path.join(a.output_dir, saver_filename), write_meta_graph=False)
                print("saved_path", saver_filename)

            save_property("epoch", epoch)
            run_sample(sess, models["sample"], dataset)

            if a.train_minutes > 0 and time.time() - START_TIME > a.train_minutes * 60:
                print("training time expired")
                break

        for kind in ["test", "train_eval"]:
            loss = run_epoch(epoch, sess, models[kind])
            save_property(kind + "_loss", loss)

if __name__ == "__main__":
    main()
