-
Notifications
You must be signed in to change notification settings - Fork 1
/
nn.py
422 lines (327 loc) · 15 KB
/
nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
import math
from functools import partial
from collections import namedtuple
from packaging import version
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from misc import default
from misc import divisible_by
from misc import exists
from misc import cast_tuple
from misc import once
AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
print_once = once(print)
class Attend(nn.Module):
def __init__(self, dropout = 0., flash = False):
super().__init__()
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
self.flash = flash
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
# determine efficient attention configs for cuda and cpu
self.cpu_config = AttentionConfig(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = AttentionConfig(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = AttentionConfig(False, True, True)
def flash_attn(self, q, k, v):
_, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
q, k, v = map(lambda t: t.contiguous(), (q, k, v))
# Check if there is a compatible device for flash attention
config = self.cuda_config if is_cuda else self.cpu_config
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
dropout_p = self.dropout if self.training else 0.
)
return out
def forward(self, q, k, v):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
if self.flash:
return self.flash_attn(q, k, v)
scale = q.shape[-1] ** -0.5
# similarity
sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
# attention
attn = sim.softmax(dim = -1)
attn = self.attn_dropout(attn)
# aggregate values
out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
return out
def Upsample(dim, dim_out = None):
return nn.Sequential(nn.Upsample(scale_factor=2, mode = 'nearest'), nn.Conv2d(dim, default(dim_out, dim), 3, padding=1))
def Downsample(dim, dim_out = None):
return nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2), nn.Conv2d(dim * 4, default(dim_out, dim), 1))
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class RandomOrLearnedSinusoidalPosEmb(nn.Module):
""" following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
def __init__(self, dim, is_random = False):
super().__init__()
assert divisible_by(dim, 2)
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
def forward(self, x):
x = rearrange(x, 'b -> b 1')
freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
fouriered = torch.cat((x, fouriered), dim = -1)
return fouriered
class Block(nn.Module):
def __init__(self, dim, dim_out, groups = 8):
super().__init__()
self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
super().__init__()
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None):
scale_shift = None
if exists(self.mlp) and exists(time_emb):
time_emb = self.mlp(time_emb)
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
scale_shift = time_emb.chunk(2, dim = 1)
h = self.block1(x, scale_shift = scale_shift)
h = self.block2(h)
return h + self.res_conv(x)
class LinearAttention(nn.Module):
def __init__(self, dim, heads = 4, dim_head = 32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(hidden_dim, dim, 1),
RMSNorm(dim)
)
def forward(self, x):
b, c, h, w = x.shape
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
q = q.softmax(dim = -2)
k = k.softmax(dim = -1)
q = q * self.scale
context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
return self.to_out(out)
class Attention(nn.Module):
def __init__(self, dim, heads = 4, dim_head = 32, flash = False):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.norm = RMSNorm(dim)
self.attend = Attend(flash = flash)
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv)
out = self.attend(q, k, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
return self.to_out(out)
class HyperNet(nn.Module):
def __init__(self, layer_num, hidden_dim, latent_dim):
super().__init__()
assert layer_num >= 1
self.param_projs = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(layer_num)])
self.scales = nn.Parameter(torch.randn(layer_num), requires_grad=True)
self.eye = nn.Parameter(torch.eye(latent_dim), requires_grad=False)
self.in_proj = nn.Linear(hidden_dim, latent_dim)
self.out_proj = nn.Linear(latent_dim, hidden_dim)
def get_weights(self, var):
var = var.permute(0, 2, 3, 1)
params = [mat(var) for mat in self.param_projs]
low_ranks = [torch.matmul(vec.unsqueeze(-1), vec.unsqueeze(-2)) for vec in params]
abs_scales = torch.log(1.0 + torch.exp(self.scales))
eye = self.eye.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand_as(low_ranks[0])
full_ranks = [mat + eye * val for mat, val in zip(low_ranks, abs_scales)]
return full_ranks
def forward(self, x, weights):
repr = x.permute(0, 2, 3, 1)
repr = self.in_proj(repr).unsqueeze(-1)
for mat in weights:
repr = torch.matmul(mat, repr)
repr = self.out_proj(repr.squeeze(-1))
return repr.permute(0, 3, 1, 2)
class UNet(nn.Module):
def __init__(
self,
dim,
init_dim=None,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=3,
self_condition=False,
latent_dim=None,
hypernet_layers=None,
resnet_block_groups=8,
learned_variance=False,
learned_sinusoidal_cond=False,
random_fourier_features=False,
learned_sinusoidal_dim=16,
attn_dim_head=32,
attn_heads=4,
full_attn=(False, False, False, True),
flash_attn=False,
):
super().__init__()
# determine dimensions
self.channels = channels
self.self_condition = self_condition
input_channels = channels * (2 if self_condition else 1)
init_dim = default(init_dim, dim)
self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
if hypernet_layers is not None:
assert (latent_dim is not None)
self.in_hypernet = HyperNet(hypernet_layers, init_dim, latent_dim)
else:
self.in_hypernet = None
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
block_klass = partial(ResnetBlock, groups=resnet_block_groups)
# time embeddings
time_dim = dim * 4
self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
if self.random_or_learned_sinusoidal_cond:
sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
fourier_dim = learned_sinusoidal_dim + 1
else:
sinu_pos_emb = SinusoidalPosEmb(dim)
fourier_dim = dim
self.time_mlp = nn.Sequential(sinu_pos_emb, nn.Linear(fourier_dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim))
# attention
num_stages = len(dim_mults)
full_attn = cast_tuple(full_attn, num_stages)
attn_heads = cast_tuple(attn_heads, num_stages)
attn_dim_head = cast_tuple(attn_dim_head, num_stages)
assert len(full_attn) == len(dim_mults)
FullAttention = partial(Attention, flash=flash_attn)
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)):
is_last = ind >= (num_resolutions - 1)
attn_klass = FullAttention if layer_full_attn else LinearAttention
self.downs.append(nn.ModuleList([
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
]))
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1])
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))):
is_last = ind == (len(in_out) - 1)
attn_klass = FullAttention if layer_full_attn else LinearAttention
self.ups.append(nn.ModuleList([
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
]))
default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)
self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
if hypernet_layers is not None:
assert (latent_dim is not None)
self.out_hypernet = HyperNet(hypernet_layers, dim, latent_dim)
else:
self.out_hypernet = None
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
@property
def downsample_factor(self):
return 2 ** (len(self.downs) - 1)
def forward(self, x, time, x_self_cond=None, latent_var=None):
assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]]), f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet'
if self.self_condition:
x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x_self_cond, x), dim=1)
x = self.init_conv(x)
if self.in_hypernet is not None:
in_params = self.in_hypernet.get_weights(latent_var)
x = x + self.in_hypernet(x, in_params)
r = x.clone()
t = self.time_mlp(time)
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
h.append(x)
x = block2(x, t)
x = attn(x) + x
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x) + x
x = self.mid_block2(x, t)
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim = 1)
x = block1(x, t)
x = torch.cat((x, h.pop()), dim = 1)
x = block2(x, t)
x = attn(x) + x
x = upsample(x)
x = torch.cat((x, r), dim = 1)
x = self.final_res_block(x, t)
if self.out_hypernet is not None:
out_params = self.out_hypernet.get_weights(latent_var)
x = x + self.out_hypernet(x, out_params)
return self.final_conv(x)