# Copyright 2024 Lnyan (https://github.com/lkwq007). All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax import Array as Tensor

import flax
import flax.linen
from flax import nnx


def fake_init(key, feature_shape, param_dtype):
    # Return an abstract shape/dtype placeholder instead of materializing the
    # parameter, so modules can be declared without allocating weights; the
    # real values are expected to be filled in later (e.g. from a checkpoint).
    return jax.ShapeDtypeStruct(feature_shape, param_dtype)


def wrap_LayerNorm(dim, *, eps=1e-5, elementwise_affine=True, bias=True, rngs: nnx.Rngs):
    return nnx.LayerNorm(
        dim,
        epsilon=eps,
        use_bias=elementwise_affine and bias,
        use_scale=elementwise_affine,
        bias_init=fake_init,
        scale_init=fake_init,
        rngs=rngs,
    )


def wrap_Linear(dim, inner_dim, *, bias=True, rngs: nnx.Rngs):
    return nnx.Linear(
        dim,
        inner_dim,
        use_bias=bias,
        kernel_init=fake_init,
        bias_init=fake_init,
        rngs=rngs,
    )


def wrap_GroupNorm(num_groups, num_channels, *, eps=1e-5, affine=True, rngs: nnx.Rngs):
    return nnx.GroupNorm(
        num_channels,
        num_groups=num_groups,
        epsilon=eps,
        use_bias=affine,
        use_scale=affine,
        bias_init=fake_init,
        scale_init=fake_init,
        rngs=rngs,
    )


def wrap_Conv(in_channels, out_channels, kernel_size, *, stride=1, padding=0,
              dilation=1, groups=1, bias=True, padding_mode='zeros',
              rngs: nnx.Rngs, conv_dim: int):
    if isinstance(kernel_size, int):
        kernel_tuple = (kernel_size,) * conv_dim
    else:
        # assume a tuple/sequence of per-dimension kernel sizes
        assert len(kernel_size) == conv_dim
        kernel_tuple = kernel_size
    # NOTE: dilation, groups, and padding_mode are accepted for PyTorch API
    # compatibility but are not forwarded to nnx.Conv yet; see the TODO below.
    return nnx.Conv(in_channels, out_channels, kernel_tuple, strides=stride,
                    padding=padding, use_bias=bias, kernel_init=fake_init,
                    bias_init=fake_init, rngs=rngs)
    # TODO: full argument mapping, roughly:
    # return nnx.Conv(in_channels, out_channels, kernel_tuple, strides=stride,
    #                 padding=padding, kernel_dilation=dilation,
    #                 feature_group_count=groups, use_bias=bias, rngs=rngs)


class nn_GELU(nnx.Module):
    def __init__(self, approximate="none") -> None:
        # PyTorch passes "none"/"tanh"; flax's gelu takes a boolean flag.
        self.approximate = approximate == "tanh"

    def __call__(self, x):
        return nnx.gelu(x, approximate=self.approximate)


class nn_SiLU(nnx.Module):
    def __init__(self) -> None:
        pass

    def __call__(self, x):
        return nnx.silu(x)


class nn_AvgPool(nnx.Module):
    def __init__(self, window_shape, strides=None, padding="VALID") -> None:
        self.window_shape = window_shape
        self.strides = strides
        self.padding = padding

    def __call__(self, x):
        return flax.linen.avg_pool(x, window_shape=self.window_shape,
                                   strides=self.strides, padding=self.padding)


class TorchWrapper:
    """A wrapper class that mimics the torch.nn namespace while building
    flax.nnx modules, threading a shared nnx.Rngs through every constructor."""

    def __init__(self, rngs: nnx.Rngs, dtype=jnp.float32):
        self.rngs = rngs
        self.dtype = dtype

    def declare_with_rng(self, *args):
        # Bind this wrapper's dtype and rngs to each factory function.
        ret = [partial(f, dtype=self.dtype, rngs=self.rngs) for f in args]
        return ret if len(ret) > 1 else ret[0]

    def conv_nd(self, dims, *args, **kwargs):
        return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=dims)

    def avg_pool(self, *args, **kwargs):
        return nn_AvgPool(*args, **kwargs)

    def linear(self, *args, **kwargs):
        return self.Linear(*args, **kwargs)

    def SiLU(self):
        return nn_SiLU()

    def GELU(self, approximate="none"):
        return nn_GELU(approximate)

    def Identity(self):
        return lambda x: x

    def LayerNorm(self, dim, eps=1e-5, elementwise_affine=True, bias=True):
        return wrap_LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine,
                              bias=bias, rngs=self.rngs)

    def GroupNorm(self, *args, **kwargs):
        return wrap_GroupNorm(*args, **kwargs, rngs=self.rngs)

    def Linear(self, *args, **kwargs):
        return wrap_Linear(*args, **kwargs, rngs=self.rngs)

    def Parameter(self, value):
        return nnx.Param(value)

    def Dropout(self, p):
        return nnx.Dropout(rate=p, rngs=self.rngs)

    def Sequential(self, *args):
        return nnx.Sequential(*args)

    def Conv1d(self, *args, **kwargs):
        return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=1)

    def Conv2d(self, *args, **kwargs):
        return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=2)

    def Conv3d(self, *args, **kwargs):
        return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=3)

    def ModuleList(self, lst=None):
        # nnx tracks submodules held in plain Python lists.
        if lst is None:
            return []
        return list(lst)

    def Module(self, *args, **kwargs):
        # Empty attribute container standing in for a bare torch.nn.Module.
        return nnx.Dict()
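

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the wrapper itself).
# Because fake_init returns jax.ShapeDtypeStruct placeholders, the modules
# built below carry abstract parameters (shape/dtype only): they describe the
# network structure and are assumed to be populated with real weights (e.g.
# converted from a PyTorch checkpoint) before being called on data.
if __name__ == "__main__":
    nn = TorchWrapper(rngs=nnx.Rngs(0))
    proj = nn.Linear(4, 8)
    norm = nn.LayerNorm(8)
    conv = nn.Conv2d(3, 16, 3, stride=2, padding=1)
    print(proj.kernel.value)  # ShapeDtypeStruct with shape (4, 8)
    print(conv.kernel.value)  # ShapeDtypeStruct with shape (3, 3, 3, 16)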