class Rope2D:
    """Helper class to apply RoPE2D as well as interpolate on the fly.

    Positions on a ``grid_h x grid_w`` patch grid are encoded by concatenating
    a 1D rotary frequency table for the column index (x) with one for the row
    index (y). A single ``RotaryEmbedding(dim // 2)`` is shared by both axes.
    The combined table is cached and rebuilt only when the grid size changes,
    so inputs of different resolutions can be handled at call time.

    Expected call order: ``init_tensors()`` -> ``update_grid(...)`` ->
    ``__call__(q, k)``.

    Args:
        dim: Rotary feature width; each axis uses ``RotaryEmbedding(dim // 2)``.
        use_cls_token: If True, the table gets one extra leading row of zeros
            (zero rotation angle) for a cls token, and patch positions start
            at 1 instead of 0 so the cls token alone sits at position (0, 0).
    """

    def __init__(self, dim, use_cls_token=False):
        self.dim = dim
        self.use_cls_token = use_cls_token
        # (grid_h, grid_w) that self.freq was built for; None until first build.
        self.grid_size = None
        # Cached frequency table of shape (1, seq, feat); None until first build.
        self.freq = None

    def init_tensors(self):
        """Create the 1D rotary embedding module shared by both grid axes."""
        self.rope = RotaryEmbedding(self.dim // 2)

    def update_grid(self, device, grid_h, grid_w):
        """Ensure ``self.freq`` matches a ``grid_h x grid_w`` grid on ``device``.

        The table is rebuilt only when the grid size differs from the cached
        one; otherwise the cached table is simply moved to ``device``.

        Args:
            device: Device on which the table (and the rope module) should live.
            grid_h: Number of patch rows.
            grid_w: Number of patch columns.

        Raises:
            RuntimeError: If a rebuild is needed but ``init_tensors()`` has not
                been called yet.
        """
        if self.grid_size != (grid_h, grid_w):
            rope = getattr(self, "rope", None)
            if rope is None:
                raise RuntimeError(
                    "Rope2D.init_tensors() must be called before update_grid()"
                )
            self.rope = rope.to(device)

            # With a cls token, shift patch positions by one so that position
            # (0, 0) is reserved for the cls token.
            offset = 1 if self.use_cls_token else 0
            grid_y_range = torch.arange(offset, grid_h + offset, device=device)
            grid_x_range = torch.arange(offset, grid_w + offset, device=device)

            freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
            freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
            # Row-major flattening: token index r * grid_w + c is row r, column c.
            freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)

            if self.use_cls_token:
                # A zero angle leaves the cls token unrotated. new_zeros keeps
                # freq's dtype/device instead of the global default dtype.
                freq = torch.cat([freq.new_zeros(1, freq.shape[-1]), freq], dim=0)

            # Commit the cache key only after the table was built successfully,
            # so a failed rebuild cannot leave grid_size describing a stale table.
            self.freq = freq[None, ...]
            self.grid_size = (grid_h, grid_w)

        self.freq = self.freq.to(device)

    def __call__(self, q, k):
        """Apply the cached 2D rotary embedding to queries and keys.

        The cached table of shape (1, seq, feat) is given an extra axis,
        becoming (1, 1, seq, feat), before being passed to
        ``apply_rotary_emb``. Presumably ``q``/``k`` are laid out as
        (batch, heads, seq, feat) and the table broadcasts over batch and
        heads — TODO confirm against ``apply_rotary_emb``.

        Returns:
            The rotated ``(q, k)`` pair.

        Raises:
            RuntimeError: If ``update_grid()`` has not been called yet.
        """
        if self.freq is None:
            raise RuntimeError(
                "Rope2D.update_grid() must be called before applying RoPE"
            )
        freq = self.freq[:, None, :, :]
        q = apply_rotary_emb(freq, q)
        k = apply_rotary_emb(freq, k)
        return q, k