-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
117 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# Copyright (C) 2024 Hao Zhang<[email protected]> | ||
# | ||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU General Public License | ||
# along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
# | ||
|
||
# This file implements KAN (https://arxiv.org/pdf/2404.19756) | ||
|
||
import torch | ||
|
||
|
||
class KanUnit(torch.nn.Module): | ||
|
||
def __init__(self, dim_input, dim_output, low, high, count, dropout=0): | ||
super().__init__() | ||
self.dim_input = dim_input | ||
self.dim_output = dim_output | ||
self.register_buffer( | ||
'low', | ||
torch.tensor([low for _ in range(dim_input)], dtype=torch.float32), | ||
persistent=True, | ||
) | ||
self.register_buffer( | ||
'high', | ||
torch.tensor([high for _ in range(dim_input)], dtype=torch.float32), | ||
persistent=True, | ||
) | ||
self.count = count | ||
self.dropout = dropout | ||
|
||
self.parameter = torch.nn.Parameter(torch.randn([dim_input, dim_output, count]) * 1e-4) | ||
|
||
def expand_from(self, model): | ||
assert self.dim_input == model.dim_input | ||
assert self.dim_output == model.dim_output | ||
self.low.copy_(model.low) | ||
self.high.copy_(model.high) | ||
x = (torch.arange(self.count) / (self.count - 1)).unsqueeze(-1).expand(-1, self.dim_input) | ||
w1 = x * (model.count - 1) | ||
w2 = w1 + 1 | ||
w1 = w1.clip(0, model.count - 2).to(dtype=torch.int64) | ||
w2 = w2.clip(1, model.count - 1).to(dtype=torch.int64) | ||
x1 = w1 / (model.count - 1) | ||
x2 = w2 / (model.count - 1) | ||
w1_expand = w1.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.dim_output, -1) | ||
w2_expand = w2.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.dim_output, -1) | ||
y1 = torch.gather(model.parameter.unsqueeze(0).expand(self.count, -1, -1, -1), -1, w1_expand).squeeze(-1) | ||
y2 = torch.gather(model.parameter.unsqueeze(0).expand(self.count, -1, -1, -1), -1, w2_expand).squeeze(-1) | ||
y = (y2 - y1) * ((x - x1) / (x2 - x1)).unsqueeze(-1) + y1 | ||
self.parameter.data.copy_(y.transpose(0, 1).transpose(1, 2)) | ||
|
||
def forward(self, x, update: float = 0): | ||
if update != 0: | ||
low = x.min(dim=0).values | ||
high = x.max(dim=0).values | ||
self.low.copy_(low * update + self.low * (1 - update)).detach_().requires_grad_(False) | ||
self.high.copy_(high * update + self.high * (1 - update)).detach_().requires_grad_(False) | ||
|
||
assert x.ndim == 2 | ||
assert self.dim_input == x.size(1) | ||
batch_size = x.size(0) | ||
# x: batch, dim_input | ||
w1 = (x - self.low) / ((self.high - self.low) / (self.count - 1)) | ||
w2 = w1 + 1 | ||
if self.training: | ||
w1[torch.rand_like(w1) < self.dropout] -= 1 | ||
w2[torch.rand_like(w1) < self.dropout] += 1 | ||
w1 = w1.clip(0, self.count - 2).to(dtype=torch.int64) | ||
w2 = w2.clip(1, self.count - 1).to(dtype=torch.int64) | ||
x1 = self.low + w1 * ((self.high - self.low) / (self.count - 1)) | ||
x2 = self.low + w2 * ((self.high - self.low) / (self.count - 1)) | ||
w1_expand = w1.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.dim_output, -1) | ||
w2_expand = w2.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.dim_output, -1) | ||
y1 = torch.gather(self.parameter.unsqueeze(0).expand(batch_size, -1, -1, -1), -1, w1_expand).squeeze(-1) | ||
y2 = torch.gather(self.parameter.unsqueeze(0).expand(batch_size, -1, -1, -1), -1, w2_expand).squeeze(-1) | ||
y = (y2 - y1) * ((x - x1) / (x2 - x1)).unsqueeze(-1) + y1 | ||
y = y.mean(dim=-2) | ||
assert y.ndim == 2 | ||
assert batch_size == y.size(0) | ||
assert self.dim_output == y.size(1) | ||
return y | ||
|
||
|
||
class Kan(torch.nn.Module): | ||
|
||
def __init__(self, dims, low, high, count, dropout=0): | ||
super().__init__() | ||
self.layers = torch.nn.ModuleList( | ||
KanUnit( | ||
dims[i], | ||
dims[i + 1], | ||
low, | ||
high, | ||
count, | ||
dropout=dropout, | ||
) for i in range(len(dims) - 1)) | ||
|
||
def expand_from(self, model): | ||
for new_layer, old_layer in zip(self.layers, model.layers): | ||
new_layer.expand_from(old_layer) | ||
|
||
def forward(self, x, update: float = 0): | ||
for layer in self.layers: | ||
x = layer(x, update=update) | ||
return x |