File size: 1,788 Bytes
2819523 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import jax.numpy as jnp
import jax
import torch
from dataclasses import dataclass
import sympy
import sympy as sp
from sympy import Matrix, Symbol
import math
from sde_redefined_param import SDEDimension
@dataclass
class SDEPolynomialConfig:
    """Settings for a custom scalar SDE with polynomial drift and diffusion.

    All settings are plain class attributes without type annotations, so
    ``@dataclass`` generates no fields and ``__init__`` takes no arguments;
    every instance reads the same class-level values.

    Written in terms of the symbolic variable ``t`` (``variable``):

    * ``drift``     = -| sum_{i=1..drift_degree} f_{i-1} * t**i |
    * ``diffusion`` = t ** ( sum_{i=0..diffusion_degree-1} l_i**2 * t**i )
    """

    name = "Custom"

    # Range of the SDE variable t.
    initial_variable_value = 0
    max_variable_value = 1  # alternative upper bound: math.inf
    min_sample_value = 1e-6

    # NOTE(review): sympy's Symbol appears to treat extra keywords as boolean
    # assumptions, so the Interval passed as ``domain`` probably does not
    # restrict t to [initial_variable_value, max_variable_value]; only the
    # ``nonnegative``/``real`` assumptions are clearly effective. Confirm
    # against the installed sympy version before relying on the domain.
    variable = Symbol(
        't',
        nonnegative=True,
        real=True,
        domain=sympy.Interval(
            initial_variable_value,
            max_variable_value,
            left_open=False,
            right_open=False,
        ),
    )

    # Shape of each SDE component.
    drift_dimension = SDEDimension.SCALAR
    diffusion_dimension = SDEDimension.SCALAR
    diffusion_matrix_dimension = SDEDimension.SCALAR

    # Number of polynomial coefficients used by drift and diffusion.
    drift_degree = 20
    diffusion_degree = 20

    # 1 x degree row vectors of real, nonzero coefficient symbols
    # (f0, f1, ... for the drift; l0, l1, ... for the diffusion).
    drift_parameters = Matrix(
        [sympy.symbols(f"f:{drift_degree}", real=True, nonzero=True)]
    )
    diffusion_parameters = Matrix(
        [sympy.symbols(f"l:{diffusion_degree}", real=True, nonzero=True)]
    )

    @property
    def drift(self):
        """Symbolic drift: minus the absolute value of sum f_{i-1} * t**i, i = 1..drift_degree."""
        t = self.variable
        # Row of monomials t**1 .. t**drift_degree, matching drift_parameters' shape.
        monomials = Matrix([[t**power for power in range(1, self.drift_degree + 1)]])
        weighted_terms = sympy.HadamardProduct(monomials, self.drift_parameters).doit()
        polynomial = sum(weighted_terms)
        return -sympy.Abs(polynomial)

    @property
    def diffusion(self):
        """Symbolic diffusion: t raised to sum l_i**2 * t**i, i = 0..diffusion_degree-1.

        Squaring each coefficient keeps every term of the exponent non-negative
        (t itself is declared nonnegative).
        """
        t = self.variable
        # Row of monomials t**0 .. t**(diffusion_degree - 1).
        monomials = Matrix([[t**power for power in range(0, self.diffusion_degree)]])
        squared_coeffs = self.diffusion_parameters.applyfunc(lambda coeff: coeff**2)
        weighted_terms = sympy.HadamardProduct(monomials, squared_coeffs).doit()
        return t ** sum(weighted_terms)

    # TODO (KLAUS): during SDE sampling, changing Q changes how the noise
    # z ~ N(0, Q * delta_t) is drawn.
    diffusion_matrix = 1
    module = 'jax'
    drift_integral_form = True
    diffusion_integral_form = True
    diffusion_integral_decomposition = 'cholesky'  # alternative: 'ldl'
    target = "epsilon"  # alternative: "x0"
    # NOTE(review): this dict of torch tensors is a single class attribute shared
    # by every instance; in-place changes made through one instance are seen by all.
    non_symbolic_parameters = {'drift': torch.ones(drift_degree), 'diffusion': torch.ones(diffusion_degree)}
|