##### I/ Imports

import numpy as np
import dedalus.public as d3
import scipy.integrate as integrate
import matplotlib.pyplot as plt

########################################################################################################################################################

##### II/ Parameters

# Domain: the box spans [x_min, x_max]
x_min, x_max = -1, +1

# Mountain geometry
w1 = 0.25  # summit half-width: the summit zone is (-w1, w1]
w2 = 0.25  # horizontal extent of each slope

# Discretisation: number of Legendre modes and the dealiasing factor
Nx = 128
dealias = 3/2

########################################################################################################################################################

##### III/ Box setup

#### 1) Coordinate, distributor, basis, and grid

### a) Coordinate and distributor
# NOTE(review): dtype=np.complex128 makes every field complex even though f is
# real-valued; take .real wherever a real array is needed (e.g. plotting).
coord = d3.Coordinate('x')
dist = d3.Distributor(coord, dtype=np.complex128)

### b) Legendre basis of Nx modes on [x_min, x_max], dealiased by `dealias`
basis = d3.Legendre(
    coord,
    size=Nx,
    bounds=(x_min, x_max),
    dealias=dealias,
)

### c) Grid points of the basis, as returned by the distributor
x = dist.local_grid(basis)


#### 2) Fields

### a) Smoothing functions that define the sides of the mountain

def smoothing_function_in(r, a=None, b=None, sharpness=0.1):
    """Smooth 0 -> 1 ramp on [a, b], used for the left (rising) slope.

    The ramp is the normalised running integral of a bump function:

        F(r) = int_a^r g(s) ds / int_a^b g(s) ds,
        g(s) = exp(-sharpness / ((s - a) * (b - s))),

    so F rises from 0 at ``a`` to 1 at ``b``.

    Parameters
    ----------
    r : float
        Point at which the ramp is evaluated.
    a, b : float, optional
        Start and end of the ramp. They default to the left slope of the
        mountain, ``a = -(w1 + w2)`` and ``b = -w1``.
    sharpness : float, optional
        Numerator of the bump exponent; larger values concentrate g more
        tightly around the middle of [a, b]. Defaults to 0.1.

    Returns
    -------
    float
        0.0 for ``r <= a``, 1.0 for ``r >= b``, the normalised integral
        in between.
    """
    if a is None:
        a = -(w1 + w2)
    if b is None:
        b = -w1

    # Outside (a, b) the product (s - a) * (b - s) turns negative, the
    # exponent becomes a large positive number and exp() overflows, so
    # return the ramp's limiting values instead of integrating there.
    if r <= a:
        return 0.0
    if r >= b:
        return 1.0

    def bump(s):
        return np.exp(-sharpness / ((s - a) * (b - s)))

    total = integrate.quad(bump, a, b)[0]
    partial = integrate.quad(bump, a, r)[0]
    return partial / total

def smoothing_function_out(r, a=None, b=None, sharpness=0.1):
    """Smooth 1 -> 0 ramp on [a, b], used for the right (descending) slope.

    The ramp is one minus the normalised running integral of a bump:

        G(r) = 1 - int_a^r g(s) ds / int_a^b g(s) ds,
        g(s) = exp(-sharpness / ((s - a) * (b - s))),

    so G falls from 1 at ``a`` to 0 at ``b``.

    Parameters
    ----------
    r : float
        Point at which the ramp is evaluated.
    a, b : float, optional
        Start and end of the ramp. They default to the right slope of the
        mountain, ``a = w1`` and ``b = w1 + w2``.
    sharpness : float, optional
        Numerator of the bump exponent; larger values concentrate g more
        tightly around the middle of [a, b]. Defaults to 0.1.

    Returns
    -------
    float
        1.0 for ``r <= a``, 0.0 for ``r >= b``, one minus the normalised
        integral in between.
    """
    if a is None:
        a = w1
    if b is None:
        b = w1 + w2

    # Outside (a, b) the product (s - a) * (b - s) turns negative, the
    # exponent becomes a large positive number and exp() overflows, so
    # return the ramp's limiting values instead of integrating there.
    if r <= a:
        return 1.0
    if r >= b:
        return 0.0

    def bump(s):
        return np.exp(-sharpness / ((s - a) * (b - s)))

    total = integrate.quad(bump, a, b)[0]
    partial = integrate.quad(bump, a, r)[0]
    return 1 - partial / total

# Element-wise wrappers so the scalar ramps can be applied to grid arrays.
# otypes=[float] fixes the output dtype up front. Without it, np.vectorize
# calls the function on the first element to infer the dtype and raises on
# size-0 input. Size-0 input occurs when no grid point lands inside a slope
# (e.g. small Nx or narrow w2).
smoothing_function_in_vec = np.vectorize(smoothing_function_in, otypes=[float])
smoothing_function_out_vec = np.vectorize(smoothing_function_out, otypes=[float])

### b) Define 5 zones: the summit, the two slopes, and the two flat plains
# Boolean masks over the grid x. Every interior boundary belongs to exactly
# one zone (intervals are half-open, closed on the right), so no grid point
# is assigned twice when f is filled below.
F1 = (x_min <= x) & (x <= -(w1 + w2))      # left plain
S1 = (-(w1 + w2) < x) & (x <= -w1)         # left (rising) slope
S = (-w1 < x) & (x <= w1)                  # summit
S2 = (w1 < x) & (x <= (w1 + w2))           # right (descending) slope
F2 = ((w1 + w2) < x) & (x <= x_max)        # right plain

### c) Function f: 0 on the plains, 1 on the summit, smooth ramps on the slopes
f = dist.Field(name = 'f', bases = basis)

f['g'][F1] = 0
f['g'][S1] = smoothing_function_in_vec(x[S1])
f['g'][S] = 1
f['g'][S2] = smoothing_function_out_vec(x[S2])
f['g'][F2] = 0

### d) Spatial derivative of f, evaluated into a new field
df_dx = d3.Differentiate(f, coord).evaluate()

# Put both fields back on the scale-1 grid before they are plotted against x.
for fld in (f, df_dx):
    fld.change_scales(1)

########################################################################################################################################################

##### IV/ Plot the results

### 1) Grid space
# The distributor was created with dtype=np.complex128, so grid data are
# complex. Plot the real part explicitly rather than letting matplotlib
# drop the imaginary part with a ComplexWarning.
plt.plot(x, f['g'].real, 'g', label = "f['g']")
plt.plot(x, df_dx['g'].real, 'b', label = "df_dx['g']")

plt.xlim(x_min - 0.1, x_max + 0.1)
# No fixed y-range. f climbs by 1 across a slope of width w2, so |df_dx|
# averages 1/w2 over each slope. A fixed +/-1.5 window would clip most of
# the derivative curve, so let matplotlib autoscale the y-axis.
plt.legend()
plt.show()

### 2) Coefficient space
# Magnitudes of the spectral coefficients of f, normalised by the largest one.
f_coeff_mag = np.abs(f['c'])
plt.plot(range(Nx), f_coeff_mag / np.max(f_coeff_mag), 'g', label = "f['c']")
plt.yscale('log')
plt.legend()
plt.show()