Add KAN model
hzhangxyz committed Jul 7, 2024
1 parent 51b3082 commit 0abd882
Showing 1 changed file with 117 additions and 0 deletions.
117 changes: 117 additions & 0 deletions tetraku/tetraku/networks/kan/kan.py
@@ -0,0 +1,117 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2024 Hao Zhang<[email protected]>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

# This file implements KAN (Kolmogorov-Arnold Networks, https://arxiv.org/pdf/2404.19756)

import torch


class KanUnit(torch.nn.Module):
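    # A single KAN layer: every (input, output) pair carries its own learnable
    # 1D function, stored as `count` values on a uniform grid over [low, high]
    # and evaluated by piecewise-linear interpolation; the layer output averages
    # the edge functions over the input dimension.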

def __init__(self, dim_input, dim_output, low, high, count, dropout=0):
super().__init__()
self.dim_input = dim_input
self.dim_output = dim_output
self.register_buffer(
'low',
torch.tensor([low for _ in range(dim_input)], dtype=torch.float32),
persistent=True,
)
self.register_buffer(
'high',
torch.tensor([high for _ in range(dim_input)], dtype=torch.float32),
persistent=True,
)
self.count = count
self.dropout = dropout

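        # parameter[i, j, :] holds the grid values of the learnable function on
        # the edge from input feature i to output feature j.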
self.parameter = torch.nn.Parameter(torch.randn([dim_input, dim_output, count]) * 1e-4)

def expand_from(self, model):
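        # Re-sample the edge functions of `model` (typically one with a coarser
        # grid) onto this unit's grid by linear interpolation, so training can
        # continue with a finer grid without discarding the learned functions.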
assert self.dim_input == model.dim_input
assert self.dim_output == model.dim_output
self.low.copy_(model.low)
self.high.copy_(model.high)
x = (torch.arange(self.count) / (self.count - 1)).unsqueeze(-1).expand(-1, self.dim_input)
w1 = x * (model.count - 1)
w2 = w1 + 1
w1 = w1.clip(0, model.count - 2).to(dtype=torch.int64)
w2 = w2.clip(1, model.count - 1).to(dtype=torch.int64)
x1 = w1 / (model.count - 1)
x2 = w2 / (model.count - 1)
w1_expand = w1.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.dim_output, -1)
w2_expand = w2.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.dim_output, -1)
y1 = torch.gather(model.parameter.unsqueeze(0).expand(self.count, -1, -1, -1), -1, w1_expand).squeeze(-1)
y2 = torch.gather(model.parameter.unsqueeze(0).expand(self.count, -1, -1, -1), -1, w2_expand).squeeze(-1)
y = (y2 - y1) * ((x - x1) / (x2 - x1)).unsqueeze(-1) + y1
self.parameter.data.copy_(y.transpose(0, 1).transpose(1, 2))

def forward(self, x, update: float = 0):
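        # `update` in (0, 1] blends the per-feature grid range [low, high]
        # towards the min/max of the current batch (exponential moving average).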
if update != 0:
low = x.min(dim=0).values
high = x.max(dim=0).values
self.low.copy_(low * update + self.low * (1 - update)).detach_().requires_grad_(False)
self.high.copy_(high * update + self.high * (1 - update)).detach_().requires_grad_(False)

assert x.ndim == 2
assert self.dim_input == x.size(1)
batch_size = x.size(0)
# x: batch, dim_input
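        # Fractional position of each input on its grid, in units of grid cells.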
w1 = (x - self.low) / ((self.high - self.low) / (self.count - 1))
w2 = w1 + 1
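        # Grid dropout (training only): with probability `dropout`, widen the
        # interpolation interval by one grid cell on either side.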
if self.training:
w1[torch.rand_like(w1) < self.dropout] -= 1
w2[torch.rand_like(w1) < self.dropout] += 1
w1 = w1.clip(0, self.count - 2).to(dtype=torch.int64)
w2 = w2.clip(1, self.count - 1).to(dtype=torch.int64)
x1 = self.low + w1 * ((self.high - self.low) / (self.count - 1))
x2 = self.low + w2 * ((self.high - self.low) / (self.count - 1))
w1_expand = w1.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.dim_output, -1)
w2_expand = w2.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.dim_output, -1)
y1 = torch.gather(self.parameter.unsqueeze(0).expand(batch_size, -1, -1, -1), -1, w1_expand).squeeze(-1)
y2 = torch.gather(self.parameter.unsqueeze(0).expand(batch_size, -1, -1, -1), -1, w2_expand).squeeze(-1)
y = (y2 - y1) * ((x - x1) / (x2 - x1)).unsqueeze(-1) + y1
y = y.mean(dim=-2)
assert y.ndim == 2
assert batch_size == y.size(0)
assert self.dim_output == y.size(1)
return y


class Kan(torch.nn.Module):
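    # A stack of KanUnit layers with widths given by `dims`, e.g. dims=[4, 16, 1]
    # builds two layers: 4 -> 16 and 16 -> 1.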

def __init__(self, dims, low, high, count, dropout=0):
super().__init__()
self.layers = torch.nn.ModuleList(
KanUnit(
dims[i],
dims[i + 1],
low,
high,
count,
dropout=dropout,
) for i in range(len(dims) - 1))

def expand_from(self, model):
for new_layer, old_layer in zip(self.layers, model.layers):
new_layer.expand_from(old_layer)

def forward(self, x, update: float = 0):
for layer in self.layers:
x = layer(x, update=update)
return x
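
A minimal usage sketch (not part of the committed file; the layer widths, grid size, batch shape, and hyperparameter values below are illustrative assumptions):

# Hypothetical example: a two-layer KAN mapping 4 features to 1 output, with
# each edge function sampled on 8 grid points over [-1, 1].
model = Kan(dims=[4, 16, 1], low=-1.0, high=1.0, count=8, dropout=0.1)
x = torch.randn(32, 4)
y = model(x, update=0.1)  # update > 0 lets [low, high] track the batch min/max
assert y.shape == (32, 1)

# Grid refinement: transfer the learned edge functions onto a denser grid.
fine = Kan(dims=[4, 16, 1], low=-1.0, high=1.0, count=16, dropout=0.1)
fine.expand_from(model)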
