CipherInputStream.java

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.crypto.stream;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.util.Properties;
import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.ShortBufferException;

import org.apache.commons.crypto.cipher.Cipher;
import org.apache.commons.crypto.cipher.CipherTransformation;
import org.apache.commons.crypto.stream.input.ChannelInput;
import org.apache.commons.crypto.stream.input.Input;
import org.apache.commons.crypto.stream.input.StreamInput;
import org.apache.commons.crypto.utils.Utils;

/**
 * CipherInputStream reads input data and decrypts data in stream manner. It supports
 * any mode of operations such as AES CBC/CTR/GCM mode in concept.It is not thread-safe.
 *
 */

public class CipherInputStream extends InputStream implements
    ReadableByteChannel {
  private final byte[] oneByteBuf = new byte[1];

  protected final Cipher cipher;
  protected final int bufferSize;

  protected final byte[] key;
  protected final byte[] initIV;
  protected byte[] iv;

  protected boolean closed;
  protected boolean finalDone = false;

  protected Input input;

  /**
   * Input data buffer. The data starts at inBuffer.position() and ends at
   * to inBuffer.limit().
   */
  protected ByteBuffer inBuffer;

  /**
   * The decrypted data buffer. The data starts at outBuffer.position() and
   * ends at outBuffer.limit();
   */
  protected ByteBuffer outBuffer;

  public CipherInputStream(CipherTransformation transformation,
                           Properties props, InputStream in, byte[] key, byte[] iv)
      throws IOException {
    this(in, Utils.getCipherInstance(transformation, props),
        Utils.getBufferSize(props), key, iv);
  }

  public CipherInputStream(CipherTransformation transformation,
                           Properties props, ReadableByteChannel in, byte[] key, byte[] iv)
      throws IOException {
    this(in, Utils.getCipherInstance(transformation, props),
        Utils.getBufferSize(props), key, iv);
  }

  public CipherInputStream(InputStream in, Cipher cipher, int bufferSize,
                           byte[] key, byte[] iv) throws IOException {
    this(new StreamInput(in, bufferSize), cipher, bufferSize, key, iv);
  }

  public CipherInputStream(ReadableByteChannel in, Cipher cipher,
                           int bufferSize, byte[] key, byte[] iv) throws IOException {
    this(new ChannelInput(in), cipher, bufferSize, key, iv);
  }

  public CipherInputStream(
      Input input,
      Cipher cipher,
      int bufferSize,
      byte[] key,
      byte[] iv) throws IOException {
    this.input = input;
    this.cipher = cipher;
    this.bufferSize = Utils.checkBufferSize(cipher, bufferSize);
    this.key = key.clone();
    this.initIV = iv.clone();
    this.iv = iv.clone();

    inBuffer = ByteBuffer.allocateDirect(this.bufferSize);
    outBuffer = ByteBuffer.allocateDirect(this.bufferSize +
        cipher.getTransformation().getAlgorithmBlockSize());
    outBuffer.limit(0);

    initCipher();
  }

  @Override
  public int read() throws IOException {
    int n;
    while ((n = read(oneByteBuf, 0, 1)) == 0) ;
    return (n == -1) ? -1 : oneByteBuf[0] & 0xff;
  }

  /**
   * Decryption is buffer based.
   * If there is data in {@link #outBuffer}, then read it out of this buffer.
   * If there is no data in {@link #outBuffer}, then read more from the
   * underlying stream and do the decryption.
   * @param b the buffer into which the decrypted data is read.
   * @param off the buffer offset.
   * @param len the maximum number of decrypted data bytes to read.
   * @return int the total number of decrypted data bytes read into the buffer.
   * @throws IOException
   */
  @Override
  public int read(byte[] b, int off, int len) throws IOException {
    checkStream();
    if (b == null) {
      throw new NullPointerException();
    } else if (off < 0 || len < 0 || len > b.length - off) {
      throw new IndexOutOfBoundsException();
    } else if (len == 0) {
      return 0;
    }

    int remaining = outBuffer.remaining();
    if (remaining > 0) {
      // Satisfy the read with the existing data
      int n = Math.min(len, remaining);
      outBuffer.get(b, off, n);
      return n;
    } else {
      // No data in the out buffer, try read new data and decrypt it
      int nd = decryptMore();
      if(nd <= 0)
        return nd;

      int n = Math.min(len, outBuffer.remaining());
      outBuffer.get(b, off, n);
      return n;
    }
  }

  @Override
  public long skip(long n) throws IOException {
    Utils.checkArgument(n >= 0, "Negative skip length.");
    checkStream();

    if (n == 0) {
      return 0;
    }

    long remaining = n;
    int nd;

    while (remaining > 0) {
      if(remaining <= outBuffer.remaining()) {
        // Skip in the remaining buffer
        int pos = outBuffer.position() + (int) remaining;
        outBuffer.position(pos);

        remaining = 0;
        break;
      } else {
        remaining -= outBuffer.remaining();
        outBuffer.clear();
      }

      nd = decryptMore();
      if (nd < 0) {
        break;
      }
    }

    return n - remaining;
  }

  @Override
  public int available() throws IOException {
    checkStream();

    return input.available() + outBuffer.remaining();
  }

  @Override
  public void close() throws IOException {
    if (closed) {
      return;
    }

    input.close();
    freeBuffers();
    cipher.close();
    super.close();
    closed = true;
  }

  @Override
  public void mark(int readlimit) {
  }

  @Override
  public void reset() throws IOException {
    throw new IOException("Mark/reset not supported");
  }

  @Override
  public boolean markSupported() {
    return false;
  }

  @Override
  public boolean isOpen() {
    return !closed;
  }

  @Override
  public int read(ByteBuffer dst) throws IOException {
    checkStream();
    int remaining = outBuffer.remaining();
    if (remaining <= 0) {
      // Decrypt more data
      int nd = decryptMore();
      if(nd < 0) {
        return -1;
      }
    }

    // Copy decrypted data from outBuffer to dst
    remaining = outBuffer.remaining();
    final int toRead = dst.remaining();
    if (toRead <= remaining) {
      final int limit = outBuffer.limit();
      outBuffer.limit(outBuffer.position() + toRead);
      dst.put(outBuffer);
      outBuffer.limit(limit);
      return toRead;
    } else {
      dst.put(outBuffer);
      return remaining;
    }
  }

  /**
   * Get the buffer size
   */
  protected int getBufferSize() {
    return bufferSize;
  }

  /**
   * Get the key
   */
  protected byte[] getKey() {
    return key;
  }

  /**
   * Get the initialization vector
   */
  protected byte[] getInitIV() {
    return initIV;
  }

  /**
   * Get the internal Cipher
   */
  protected Cipher getCipher() {
    return cipher;
  }

  /** Initialize the cipher. */
  protected void initCipher()
      throws IOException {
    try {
      cipher.init(Cipher.DECRYPT_MODE, key, iv);
    } catch (InvalidKeyException e) {
      throw new IOException(e);
    } catch(InvalidAlgorithmParameterException e) {
      throw new IOException(e);
    }
  }

  /**
   * Decrypt more data by reading the under layer stream. The decrypted data will
   * be put in the output buffer. If the end of the under stream reached, we will
   * do final of the cipher to finish all the decrypting of data.
   *
   * @return The number of decrypted data. -1 if end of the decrypted stream
   */
  protected int decryptMore() throws IOException {
    if(finalDone) {
      return -1;
    }

    int n = input.read(inBuffer);
    if (n < 0) {
      // The stream is end, finalize the cipher stream
      decryptFinal();

      // Satisfy the read with the remaining
      int remaining = outBuffer.remaining();
      if (remaining > 0) {
        return remaining;
      }

      // End of the stream
      return -1;
    } else if(n == 0) {
      // No data is read, but the stream is not end yet
      return 0;
    } else {
      decrypt();
      return outBuffer.remaining();
    }
  }

  /**
   * Do the decryption using inBuffer as input and outBuffer as output.
   * Upon return, inBuffer is cleared; the decrypted data starts at
   * outBuffer.position() and ends at outBuffer.limit();
   */
  protected void decrypt() throws IOException {
    // Prepare the input buffer and clear the out buffer
    inBuffer.flip();
    outBuffer.clear();

    try {
      cipher.update(inBuffer, outBuffer);
    } catch (ShortBufferException e) {
      throw new IOException(e);
    }

    // Clear the input buffer and prepare out buffer
    inBuffer.clear();
    outBuffer.flip();
  }

  /**
   * Do final of the cipher to end the decrypting stream
   */
  protected void decryptFinal() throws IOException {
    // Prepare the input buffer and clear the out buffer
    inBuffer.flip();
    outBuffer.clear();

    try {
      cipher.doFinal(inBuffer, outBuffer);
      finalDone = true;
    } catch (ShortBufferException e) {
      throw new IOException(e);
    } catch (IllegalBlockSizeException e) {
      throw new IOException(e);
    } catch( BadPaddingException e) {
      throw new IOException(e);
    }

    // Clear the input buffer and prepare out buffer
    inBuffer.clear();
    outBuffer.flip();
  }

  protected void checkStream() throws IOException {
    if (closed) {
      throw new IOException("Stream closed");
    }
  }

  /** Forcibly free the direct buffers. */
  protected void freeBuffers() {
    Utils.freeDirectBuffer(inBuffer);
    Utils.freeDirectBuffer(outBuffer);
  }
}