Reference implementation of a deep RNN that captures sequential dependencies with a non-diagonal linear state-space model (SSM) over our implementation of generalized orders of magnitude (GOOMs). GOOMs allow recurrent states to fluctuate freely over a greater dynamic range of real values than was previously possible, enabling computation of non-diagonal recurrences in parallel, via a prefix scan, without any form of stabilization.

## Installing

1. Clone this repository.
2. Install the Python dependencies in `requirements.txt`.
3. There is no third step.

## Instantiating the RNN

The following code instantiates a small RNN for generative language modeling tasks with GPT-2's vocabulary:

```python
import torch
import tiktoken
import goom_ssm_rnn

DEVICE = 'cuda'  # change as needed

# Get GPT-2 encoder:
enc = tiktoken.get_encoding('gpt2')

# Instantiate an RNN for natural language generation:
model = goom_ssm_rnn.GenerativeRNN(
    vocab_sz=enc.n_vocab,
    d_emb=768,
    n_hid=24,
    d_hid=32,
    n_res=24,
)

# Move model to cuda device:
model.to(device=DEVICE)

# You must provide your own training code
# (a sketch appears at the end of this README).
```

## Use of Complex-Typed GOOMs

Recurrent layers in the model capture sequential dependencies with a non-diagonal linear SSM, executed via a parallel prefix scan, over GOOMs, implemented as `torch.complex64` tensors (i.e., with `torch.float32` real and imaginary components). As we explain in our paper, the use of complex-typed GOOMs makes it possible for each layer to compute non-diagonal recurrent states in parallel without requiring any form of stabilization. Otherwise, the rest of the model operates conventionally, over `torch.float32` tensors, optionally autocasting to `torch.float16` if you specify it.

As we explain in our paper, each recurrent layer scales complex-typed GOOMs before exponentiating them back to real `torch.float32` tensors, because the GOOM magnitudes can lie outside the bounds representable by `torch.float32`.

## Convenience Methods

Besides the standard PyTorch `forward()` method, the model provides three additional convenience methods.
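
## Illustrative Sketches

The sketches below are ours, not part of the repository's API: every function and variable name in them is hypothetical, and the actual implementation may differ. First, a minimal sketch of the GOOM representation itself, under the assumption stated above that GOOMs are complex logarithms stored as `torch.complex64`: encoding a real tensor as its complex log keeps sign information in the imaginary part, so negative values survive the round trip through log space.

```python
import torch

x = torch.tensor([2.0, -3.0, 0.5])

# Encode: the complex log of a real number serves as its GOOM.
# Negative reals get imaginary part pi, so sign survives the transform.
z = torch.log(x.to(torch.complex64))

# Decode: exponentiate and take the real part.
x_back = torch.exp(z).real
print(torch.allclose(x, x_back, atol=1e-6))  # True
```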
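Next, a hedged sketch of how a non-diagonal linear recurrence `h_t = A_t @ h_{t-1}` can be composed in parallel over GOOMs: prefix products of the transition matrices are accumulated with a log-matmul-exp combine inside a Hillis-Steele scan. This is our illustration of the technique the paper describes, not the repository's kernel; the `log_matmul_exp` and `prefix_scan` names are made up for this sketch.

```python
import torch

def log_matmul_exp(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
    # Log-space matrix product: (A @ B)[i, j] = sum_k A[i, k] * B[k, j],
    # computed as a complex logsumexp over k, so entries may be negative
    # and may span magnitudes far beyond float32 range.
    z = log_A.unsqueeze(-1) + log_B.unsqueeze(-3)  # [..., i, k, j]
    m = z.real.amax(dim=-2, keepdim=True)          # shift by max real part
    return m.squeeze(-2) + torch.log(torch.exp(z - m).sum(dim=-2))

def prefix_scan(log_A: torch.Tensor) -> torch.Tensor:
    # Inclusive scan over a [T, d, d] stack of log-transition matrices,
    # so that P[t] = log(A_t @ A_{t-1} @ ... @ A_1). Hillis-Steele
    # doubling takes O(log T) steps, each fully parallel over t.
    P, k, T = log_A, 1, log_A.shape[0]
    while k < T:
        # Prefixes shorter than k are already complete; combine the rest
        # with the prefix ending k steps earlier (newer matrix on the left).
        P = torch.cat([P[:k], log_matmul_exp(P[k:], P[:-k])], dim=0)
        k *= 2
    return P

T, d = 128, 8
A = torch.randn(T, d, d)
P = prefix_scan(torch.log(A.to(torch.complex64)))  # GOOMs of prefix products
```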
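The scaling step described above can be sketched the same way: before exponentiating GOOMs back to `torch.float32`, subtract the largest real part so every magnitude lands inside float32 range; the subtracted log-scale can be tracked separately if downstream code needs it. Continuing from `P` in the previous sketch:

```python
# Per-matrix log-scale: the largest real part within each [d, d] prefix.
m = P.real.amax(dim=(-2, -1), keepdim=True)

# Exponentiating the shifted GOOMs is now safe in float32;
# the true prefix products are scaled * exp(m), tracked in log space.
scaled = torch.exp(P - m).real
```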
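Finally, since the repository leaves training to you, here is a minimal next-token training step, continuing from the instantiation block earlier in this README. It assumes `model(ids)` returns logits of shape `[batch, seq_len, vocab_sz]`; check the actual `forward()` signature before relying on this.

```python
import torch
import torch.nn.functional as F

opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

def train_step(ids: torch.Tensor) -> float:
    # ids: [batch, seq_len] int64 token ids, e.g. from enc.encode(...).
    logits = model(ids[:, :-1])  # assumed: predict token t+1 from tokens <= t
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # [batch * (seq_len - 1), vocab_sz]
        ids[:, 1:].reshape(-1),               # targets shifted one step left
    )
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
```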