Example #1
0
  decodeSequence(inputSeq) {
    // Encode the inputs state vectors.
    let statesValue = this.encoderModel.predict(inputSeq);

    // Generate empty target sequence of length 1.
    let targetSeq = tf.buffer([1, 1, this.numDecoderTokens]);
    // Populate the first character of the target sequence with the start
    // character.
    targetSeq.set(1, 0, 0, this.targetTokenIndex['\t']);

    // Sample loop for a batch of sequences.
    // (to simplify, here we assume that a batch of size 1).
    let stopCondition = false;
    let decodedSentence = '';
    while (!stopCondition) {
      const predictOutputs =
          this.decoderModel.predict([targetSeq.toTensor()].concat(statesValue));
      const outputTokens = predictOutputs[0];
      const h = predictOutputs[1];
      const c = predictOutputs[2];

      // Sample a token.
      // We know that outputTokens.shape is [1, 1, n], so no need for slicing.
      const logits = outputTokens.reshape([outputTokens.shape[2]]);
      const sampledTokenIndex = logits.argMax().dataSync()[0];
      const sampledChar = this.reverseTargetCharIndex[sampledTokenIndex];
      decodedSentence += sampledChar;

      // Exit condition: either hit max length or find stop character.
      if (sampledChar === '\n' ||
          decodedSentence.length > this.maxDecoderSeqLength) {
        stopCondition = true;
      }

      // Update the target sequence (of length 1).
      targetSeq = tf.buffer([1, 1, this.numDecoderTokens]);
      targetSeq.set(1, 0, 0, sampledTokenIndex);

      // Update states.
      statesValue = [h, c];
    }

    return decodedSentence;
  }
Example #2
0
  /**
   * Encode a string (e.g., a sentence) as a Tensor3D that can be fed directly
   * into the TensorFlow.js model.
   */
  encodeString(str) {
    const strLen = str.length;
    const encoded =
        tf.buffer([1, this.maxEncoderSeqLength, this.numEncoderTokens]);
    for (let i = 0; i < strLen; ++i) {
      if (i >= this.maxEncoderSeqLength) {
        console.error(
            'Input sentence exceeds maximum encoder sequence length: ' +
            this.maxEncoderSeqLength);
      }

      const tokenIndex = this.inputTokenIndex[str[i]];
      if (tokenIndex == null) {
        console.error(
            'Character not found in input token index: "' + tokenIndex + '"');
      }
      encoded.set(1, 0, i, tokenIndex);
    }
    return encoded.toTensor();
  }