Time-series data, such as stock market prices, usually depends on the previous n data points.
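The following minimal sketch shows what "the previous n data points" means in practice. It turns a 1-D series into input windows of length n, each paired with the point that comes next. The helper name `make_windows` and the toy prices are hypothetical and only illustrate the idea.

```python
import numpy as np

def make_windows(series: np.ndarray, n: int):
    """Split a 1-D series into (inputs, targets) pairs.

    Each input row holds n consecutive points; its target is the point
    immediately after that window.
    """
    X = np.array([series[i:i + n] for i in range(len(series) - n)])
    y = series[n:]
    return X, y

# Toy prices, for illustration only.
prices = np.array([10.0, 11.0, 12.0, 13.0, 14.0, 15.0])
X, y = make_windows(prices, n=3)
# X has shape (3, 3): rows [10, 11, 12], [11, 12, 13], [12, 13, 14]
# y is [13, 14, 15]: the value that follows each window
```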
A recurrent neural network (RNN) is applied to sequence data to capture dependencies within a sequence. The long short-term memory (LSTM) network is a variant of the RNN that captures the long-term influence of past inputs on short-term signal behaviour.
Key differences between a simple RNN and an LSTM:
- Simple RNN: the gradient vanishes as it is backpropagated through many time steps, so later outputs of the network become insensitive to the input at $t=1$.
- LSTM: preserves the gradient by carrying information in a cell state controlled by “forget” and “input” gates.
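A scalar simple RNN, $h_t = \tanh(w h_{t-1} + u x_t)$, is enough to illustrate the vanishing gradient. By the chain rule, $\partial h_T / \partial h_1$ is a product of one factor $w(1 - h_t^2)$ per step. When every factor has magnitude below 1, the product shrinks geometrically as $T$ grows. The weights $w = 0.9$ and $u = 0.5$ and the random inputs below are arbitrary assumptions chosen for this sketch.

```python
import numpy as np

def rnn_gradient_wrt_first_state(w, u, xs, h0=0.0):
    """Scalar simple RNN h_t = tanh(w*h_{t-1} + u*x_t).

    Returns dh_T/dh_1 via the chain rule: the product of the per-step
    derivatives dh_t/dh_{t-1} = w * (1 - h_t**2) for t = 2..T.
    """
    h = np.tanh(w * h0 + u * xs[0])  # h_1
    grad = 1.0
    for x in xs[1:]:
        h = np.tanh(w * h + u * x)
        grad *= w * (1.0 - h ** 2)   # multiply in dh_t/dh_{t-1}
    return grad

rng = np.random.default_rng(0)
xs = rng.normal(size=50)             # arbitrary toy inputs
for T in (5, 20, 50):
    # The gradient magnitude shrinks as the sequence gets longer
    print(T, abs(rnn_gradient_wrt_first_state(0.9, 0.5, xs[:T])))
```

Each factor here is at most 0.9 in magnitude, so $|\partial h_T / \partial h_1| \le 0.9^{T-1}$. In an LSTM, the direct path through the cell state contributes a product of forget-gate values $f_t$ instead. That product can stay close to 1 while the gate is open.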
Each LSTM layer has the following components. Below, $[h_{t-1}, x_t]$ is the concatenation of the previous output and the current input, $\cdot$ is matrix multiplication, and $\odot$ is element-wise multiplication:
- Inputs: previous output ($h_{t-1}$) and current input ($x_t$)
- Forget gate:
  - System: $\sigma$ decides what information to throw away from the previous cell state $C_{t-1}$
  - Output: $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$, a number between 0 and 1 for each number in the cell state $C_{t-1}$ (1 keeps it completely, 0 discards it)
- Input gate:
  - System 1: $\sigma$ decides which values will be updated
    - Output: $i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$
  - System 2: $\tanh$ creates a vector of new candidate values
    - Output: $\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$
  - Output: $i_t \odot \tilde{C}_t$
- Summation of forget and input gates (new cell state):
  - $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$
- Final process (output gate):
  - System 1: $\sigma$ decides which parts of the cell state $C_t$ will be output
    - Output: $o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$
  - System 2: $\tanh$ squashes the cell state to values between -1 and 1, and the result is multiplied by $o_t$
    - Output: $h_t = o_t \odot \tanh(C_t)$
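The gate equations above can be written as a single NumPy time step. This is a minimal, unoptimised sketch. It assumes one weight matrix per gate, acting on the concatenated $[h_{t-1}, x_t]$. The `params` dictionary layout, the toy sizes and the random weights are assumptions made for illustration. Library implementations, such as the PyTorch one linked below, typically fuse the four gates into larger matrices.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM time step following the gate equations.

    params maps gate keys ("f", "i", "c", "o") to (W, b). Each W has shape
    (hidden, hidden + input) and multiplies the concatenation [h_{t-1}, x_t].
    """
    hx = np.concatenate([h_prev, x_t])        # [h_{t-1}, x_t]
    W_f, b_f = params["f"]
    W_i, b_i = params["i"]
    W_c, b_c = params["c"]
    W_o, b_o = params["o"]
    f_t = sigmoid(W_f @ hx + b_f)             # forget gate
    i_t = sigmoid(W_i @ hx + b_i)             # input gate
    c_tilde = np.tanh(W_c @ hx + b_c)         # candidate values
    c_t = f_t * c_prev + i_t * c_tilde        # new cell state (element-wise)
    o_t = sigmoid(W_o @ hx + b_o)             # output gate
    h_t = o_t * np.tanh(c_t)                  # new output / hidden state
    return h_t, c_t

# Toy sizes and random weights, for illustration only.
rng = np.random.default_rng(42)
n_in, n_hidden = 1, 4
params = {g: (rng.normal(scale=0.1, size=(n_hidden, n_hidden + n_in)),
              np.zeros(n_hidden))
          for g in "fico"}

h, c = np.zeros(n_hidden), np.zeros(n_hidden)
for x in [0.1, 0.2, 0.3]:                     # a short 1-D input sequence
    h, c = lstm_step(np.array([x]), h, c, params)
print(h.shape, c.shape)                       # (4,) (4,)
```

With all weights and biases set to zero, every gate outputs 0.5 and the candidate is 0. The cell state is then simply halved at each step, which makes the role of $f_t$ easy to see.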
Further reading:
- http://colah.github.io/posts/2015-08-Understanding-LSTMs/
- https://www.superdatascience.com/the-ultimate-guide-to-recurrent-neural-networks-rnn/
Here are two implementations, one using TensorFlow and one using PyTorch.
The TensorFlow version is a bit more convoluted, but it lets you experiment with the architecture:
https://www.datacamp.com/community/tutorials/lstm-python-stock-market
PyTorch has a built-in LSTM module (`torch.nn.LSTM`):
https://github.com/jessicayung/blog-code-snippets/blob/master/lstm-pytorch/lstm-baseline.py
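Below is a minimal sketch of wrapping `torch.nn.LSTM` for one-step-ahead prediction on windows of previous points. It is not taken from the linked code. The class name `PriceLSTM`, the hidden size of 32, the linear head and the random input are illustrative assumptions.

```python
import torch
import torch.nn as nn

class PriceLSTM(nn.Module):
    """Hypothetical regressor: nn.LSTM followed by a linear head."""

    def __init__(self, input_size=1, hidden_size=32, num_layers=1):
        super().__init__()
        # batch_first=True means inputs are (batch, seq_len, input_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, (h_n, c_n) = self.lstm(x)    # out: (batch, seq_len, hidden_size)
        return self.head(out[:, -1, :])   # predict from the last time step

model = PriceLSTM()
x = torch.randn(8, 20, 1)                 # 8 toy windows of 20 previous points
print(model(x).shape)                     # torch.Size([8, 1])
```

Reading only the last time step (`out[:, -1, :]`) matches the "predict the next point from the previous n points" setup. Sequence-to-sequence tasks would use the full `out` tensor instead.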