Diffusion-based poem generator using latent diffusion
This project implements a diffusion-based text generator that creates poems using latent diffusion. Poems were encoded into latent space with the BART encoder-decoder, whose parameters were frozen during training to speed up the process. Because the encoded latent space was length-dependent, a Perceiver sampler was used to convert it into a fixed 32x64 representation, which was then reshaped into 32x16x4, passed through the forward diffusion process, and later denoised with the NVIDIA Sana model. The Sana model's parameter count was reduced to speed up training. Diffusion models were chosen for their ability to learn patterns effectively, as discussed in "Generalization in Diffusion Models Arises from Geometry-Adaptive Harmonic Representations".
Architecture diagrams: BART, the Perceiver sampler, NVIDIA Sana, and the overall architecture.
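A minimal sketch of the shape flow described above, interpreting the 32x16x4 reshape as 32 channels over a 16x4 grid (which matches the 32 latent channels the Sana config below expects); the exact ordering used in the repository may differ:

import torch

batch = 8
# Fixed-size latent produced by the Perceiver sampler from the BART encoding: (batch, 32, 64)
latent = torch.randn(batch, 32, 64)
# Reshape to 32x16x4, i.e. 32 channels over a 16x4 spatial grid, for the 2D diffusion transformer
image_latent = latent.reshape(batch, 32, 16, 4)
print(image_latent.shape)  # torch.Size([8, 32, 16, 4])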
Dataset
The dataset was taken from Kaggle and contains two folders, each containing subfolders of poems. The poems are categorized either by form (e.g. haiku, sonnet) or by topic (e.g. love, nature, joy, peace).
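A minimal sketch of reading such a layout into (category, poem) pairs; the root folder name and the .txt extension are assumptions for illustration, not the actual structure of the Kaggle dataset:

from pathlib import Path

def load_poems(root):
    # root contains two folders (forms and topics), each with one subfolder per category
    samples = []
    for poem_path in Path(root).glob("*/*/*.txt"):
        category = poem_path.parent.name  # e.g. "haiku" or "love"
        samples.append((category, poem_path.read_text(encoding="utf-8")))
    return samples

poems = load_poems("poems_dataset")
print(len(poems))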
Step 1: Fine-tuning the language encoder
For encoding and decoding, BART was used with its parameters frozen during training so that the focus stayed on learning the encoding process (using sequences of 64 tokens). The encoded latent space was then resampled with the Perceiver sampler, which mapped the samples to a fixed 32x64 representation. This representation was later used to train the diffusion transformer.
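A minimal sketch of this step, assuming facebook/bart-base (768-dimensional hidden states) and a single cross-attention block with learned queries as the Perceiver sampler; the actual resampler in bart_latent_model.py may be deeper:

import torch
import torch.nn as nn
from transformers import BartModel, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
bart = BartModel.from_pretrained("facebook/bart-base")
bart.requires_grad_(False)  # BART stays frozen during training

class PerceiverSampler(nn.Module):
    """Maps a variable-length sequence to a fixed 32x64 latent via learned queries."""
    def __init__(self, in_dim=768, num_latents=32, latent_dim=64, num_heads=4):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_latents, latent_dim))
        self.proj = nn.Linear(in_dim, latent_dim)
        self.attn = nn.MultiheadAttention(latent_dim, num_heads, batch_first=True)

    def forward(self, hidden_states):
        keys = self.proj(hidden_states)                               # (B, seq, 64)
        queries = self.queries.expand(hidden_states.size(0), -1, -1)  # (B, 32, 64)
        latents, _ = self.attn(queries, keys, keys)
        return latents                                                # (B, 32, 64)

enc = tokenizer(["a poem about the sea"], max_length=64, padding="max_length",
                truncation=True, return_tensors="pt")
with torch.no_grad():
    hidden = bart.encoder(**enc).last_hidden_state  # (1, 64, 768)
latents = PerceiverSampler()(hidden)                # (1, 32, 64)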
Step 2: Forward diffusion process
Random noise was then added to the latents.
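A minimal sketch of this noising step with a DDPM schedule from diffusers; the scheduler choice and timestep count used in diffusion.py are assumptions:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

latents = torch.randn(8, 32, 16, 4)  # clean poem latents from step 1, reshaped to (batch, C, H, W)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],))
noisy_latents = scheduler.add_noise(latents, noise, timesteps)  # x_t = sqrt(a_bar_t)*x_0 + sqrt(1 - a_bar_t)*eps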
Step 3: Encoding prompts
To encode prompts, SmolLM2-360M was used.
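A minimal sketch of this step, assuming the HuggingFaceTB/SmolLM2-360M checkpoint and its last hidden states (dimension 960, matching caption_channels = 960 in the Sana config below) as the prompt embeddings:

import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-360M")
text_encoder = AutoModel.from_pretrained("HuggingFaceTB/SmolLM2-360M")
text_encoder.requires_grad_(False)  # the prompt encoder is not trained
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

prompts = ["a romantic poem about the moon"]
inputs = tok(prompts, padding=True, return_tensors="pt")
with torch.no_grad():
    prompt_embeds = text_encoder(**inputs).last_hidden_state  # (batch, seq, 960)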
Step 4: Training the NVIDIA Sana model
Here, NVIDIA's Sana transformer was used to predict the added noise and denoise the latents. Its parameter count was reduced in diffusion.py to make training lighter:
from diffusers import SanaTransformer2DModel

# Start from the published Sana 600M config, then shrink it for faster training
config = SanaTransformer2DModel.load_config("Efficient-Large-Model/Sana_600M_1024px_diffusers", subfolder="transformer")
config["num_layers"] = 12
config["num_attention_heads"] = 12
config["attention_head_dim"] = 64
config["cross_attention_dim"] = 768
config["num_cross_attention_heads"] = 12
config["cross_attention_head_dim"] = 64
config["caption_channels"] = 960  # matches the SmolLM2-360M hidden size used for prompt embeddings
transformer = SanaTransformer2DModel.from_config(config)
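Under these assumptions, a single training step could look like the sketch below (DDPM-style noise prediction with an MSE objective; the actual loss and scheduler in diffusion.py may differ), reusing noisy_latents, noise, timesteps and prompt_embeds from the previous steps:

import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(transformer.parameters(), lr=1e-4)

model_pred = transformer(
    hidden_states=noisy_latents,          # (B, 32, 16, 4) noisy poem latents
    encoder_hidden_states=prompt_embeds,  # (B, seq, 960) SmolLM2 prompt embeddings
    timestep=timesteps,
).sample
loss = F.mse_loss(model_pred, noise)  # the transformer learns to predict the added noise
loss.backward()
optimizer.step()
optimizer.zero_grad()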
Example
prompt: romantic
output: Relative to be checked, and asexualerdi nightnightnightnightfalled; I’m notingeded, and live-standstandingly, listened to listen to listen listen to everything else whileoberoberrounddoorkoombombombroundingly, while rubbing stubbordozieBee
References
Pre-training the language encoder and reconstruction network -> Latent Diffusion for Language Generation
Forward-diffusion and noise predictions -> Training a Latent Diffusion Model From Scratch
NVIDIA's Sana -> Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers
Perceiver -> Perceiver: General Perception with Iterative Attention
Future plan
- Increase the token length of the language encoder from the current 64 tokens to 416
- Train the transformer for a larger number of epochs
- Try incorporating Conditional Diffusion Models with Classifier-Free Gibbs-like Guidance to generate higher-quality and more diverse samples (see the sketch after this list)
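As a reference point for the last item, a minimal sketch of standard classifier-free guidance at sampling time (the plain formulation, not the Gibbs-like variant), where null_embeds are assumed to be embeddings of an empty prompt and the guidance_scale default is hypothetical:

def cfg_noise_prediction(transformer, noisy_latents, timesteps, prompt_embeds, null_embeds, guidance_scale=4.5):
    # Blend conditional and unconditional noise predictions
    eps_cond = transformer(hidden_states=noisy_latents, encoder_hidden_states=prompt_embeds, timestep=timesteps).sample
    eps_uncond = transformer(hidden_states=noisy_latents, encoder_hidden_states=null_embeds, timestep=timesteps).sample
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)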
Steps to train
Training the language encoder
run python3 bart_latent_model.py
Training the diffusion model
run python3 diffusion.py