pokemon-polynomial-2-gbar / scheduler /scheduler_config.py
AltLuv's picture
End of training
8e422c7
raw
history blame
1.79 kB
import jax.numpy as jnp
import jax
import torch
from dataclasses import dataclass
import sympy
import sympy as sp
from sympy import Matrix, Symbol
import math
from sde_redefined_param import SDEDimension
@dataclass
class SDEPolynomialConfig:
    """Settings for an SDE whose drift and diffusion are symbolic polynomials.

    The drift and diffusion terms are exposed as sympy expressions in the
    symbol ``t`` (see ``variable``), with free coefficient symbols
    ``f0..f{drift_degree-1}`` and ``l0..l{diffusion_degree-1}``.

    NOTE: none of the attributes below carries a type annotation, so
    ``@dataclass`` generates no fields for them; they are plain class
    attributes shared by every instance.
    """

    name = "Custom"

    # Bounds of the symbolic variable; both ends are used as closed
    # endpoints of the interval handed to the symbol below.
    initial_variable_value = 0
    max_variable_value = 1  # math.inf was noted as an alternative
    min_sample_value = 1e-6

    variable = Symbol(
        't',
        nonnegative=True,
        real=True,
        domain=sympy.Interval(
            initial_variable_value,
            max_variable_value,
            left_open=False,
            right_open=False,
        ),
    )

    drift_dimension = SDEDimension.SCALAR
    diffusion_dimension = SDEDimension.SCALAR
    diffusion_matrix_dimension = SDEDimension.SCALAR

    # Number of polynomial terms (and of coefficient symbols) per expression.
    drift_degree = 20
    diffusion_degree = 20

    # One-row matrices of coefficient symbols; the "f:N" / "l:N" range syntax
    # yields N symbols numbered from 0.
    drift_parameters = Matrix(
        [sympy.symbols(f"f:{drift_degree}", real=True, nonzero=True)]
    )
    diffusion_parameters = Matrix(
        [sympy.symbols(f"l:{diffusion_degree}", real=True, nonzero=True)]
    )

    @property
    def drift(self):
        """Return the symbolic drift ``-|sum_{i=1..drift_degree} f_{i-1} * t**i|``."""
        t = self.variable
        # Powers start at 1, so the polynomial has no constant term.
        monomials = Matrix([[t**power for power in range(1, self.drift_degree + 1)]])
        weighted_terms = sympy.HadamardProduct(monomials, self.drift_parameters).doit()
        polynomial = sum(weighted_terms)
        return -sympy.Abs(polynomial)

    @property
    def diffusion(self):
        """Return the symbolic diffusion ``t ** (sum_{i=0..diffusion_degree-1} l_i**2 * t**i)``."""
        t = self.variable
        monomials = Matrix([[t**power for power in range(self.diffusion_degree)]])
        # Coefficients are squared before weighting the monomials.
        squared_coefficients = self.diffusion_parameters.applyfunc(
            lambda coefficient: coefficient**2
        )
        exponent = sum(sympy.HadamardProduct(monomials, squared_coefficients).doit())
        return t**exponent

    # TODO (KLAUS): in the SDE sampling, changing Q impacts how we sample
    # z ~ N(0, Q*(delta t)).
    diffusion_matrix = 1

    # NOTE(review): presumably selects the numeric backend used by the
    # consumer of this config -- confirm against the caller.
    module = 'jax'

    drift_integral_form = True
    diffusion_integral_form = True
    diffusion_integral_decomposition = 'cholesky'  # author's note: ldl
    target = "epsilon"  # author's note: x0

    # Concrete values paired with the symbolic coefficients: a tensor of ones
    # per expression, sized by the matching degree.
    non_symbolic_parameters = dict(
        drift=torch.ones(drift_degree),
        diffusion=torch.ones(diffusion_degree),
    )