import numpy as np

from refl1d.names import Parameter, SLD, Slab, Experiment, FitProblem, load4
from refl1d.flayer import FunctionalProfile
from molgroups import mol
from molgroups import components as cmp

## === Probes/data files ===
# Load the two measured reflectivity curves; the `name` labels identify each
# data set as the D2O or the H2O contrast.
# NOTE(review): back_reflectivity=True presumably marks the beam as incident
# through the substrate (see the Refl1D probe documentation) -- confirm.
probe_d2o = load4('ch061_d2o_ph7.refl', back_reflectivity=True, name='D2O')
probe_h2o = load4('ch060_h2o_ph7.refl', back_reflectivity=True, name='H2O')

# Probe parameters
# Backgrounds: fitted independently for each probe, with identical bounds.
# The lower bound is slightly negative.
probe_d2o.background.range(-1e-7, 1e-5)
probe_h2o.background.range(-1e-7, 1e-5)

# Intensity: one fitted parameter shared by both probes. The H2O probe is
# pointed at the D2O probe's parameter object, so only the D2O range applies.
probe_d2o.intensity.range(0.95,1.05)
probe_h2o.intensity = probe_d2o.intensity

# Sample broadening: shared between both probes in the same way as intensity.
probe_d2o.sample_broadening.range(-0.003, 0.02)
probe_h2o.sample_broadening = probe_d2o.sample_broadening

# Theta offset: shared between both probes in the same way as intensity.
probe_d2o.theta_offset.range(-0.02, 0.02)
probe_h2o.theta_offset = probe_d2o.theta_offset

## === Structural parameters ===
# Fit parameters describing the bilayer and the substrate layers. Each is
# created with a starting value and then given fit bounds via .range().
# NOTE(review): lengths are presumably in Angstrom (the plot axis below is
# labelled 'z (Ang)') -- confirm.

# Passed to blm.fnSet() through the bilayer() profile function.
vf_bilayer = Parameter(name='volume fraction bilayer', value=0.9).range(0.0, 1.0)
# NOTE(review): the inner-leaflet upper bound (30) is wider than the
# outer-leaflet one (18) -- confirm this asymmetry is intended.
l_lipid1 = Parameter(name='inner acyl chain thickness', value=10.0).range(8, 30)
l_lipid2 = Parameter(name='outer acyl chain thickness', value=10.0).range(8, 18)
l_submembrane = Parameter(name='submembrane thickness', value=10.0).range(0, 50)
sigma = Parameter(name='bilayer roughness', value=5).range(0.5, 9)
# Used as the interface of both the silicon and the silicon oxide slabs.
global_rough = Parameter(name ='substrate roughness', value=5).range(2, 9)
# Handed to the bilayer profile function as `substrate_rough`.
tiox_rough = Parameter(name='titanium oxide roughness', value=4).range(2, 9)
d_oxide = Parameter(name='silicon oxide layer thickness', value=10).range(5, 30)
d_tiox =  Parameter(name='titanium oxide layer thickness', value=110).range(100, 200)

# Define size of Z space = DIMENSION * STEPSIZE
# The product is the thickness of the FunctionalProfile layer created in
# make_samples(); STEPSIZE is also passed as `dz` to each Experiment.
DIMENSION = 300 # number of steps
STEPSIZE = 0.5 # step size

## === Materials ===

# Material definitions
# rho values are in Refl1D SLD units (1e-6 Ang^-2); bilayer() rescales them
# by 1e-6 before handing them to molgroups.
d2o = SLD(name='d2o', rho=6.3000, irho=0.0000)
h2o = SLD(name='h2o', rho=-0.56, irho=0.0000)
tiox = SLD(name='tiox', rho=2.1630, irho=0.0000)
siox = SLD(name='siox', rho=4.1000, irho=0.0000)
silicon = SLD(name='silicon', rho=2.0690, irho=0.0000)

# Material SLD parameters
# Only these four SLDs are given fit ranges; silicon gets none here.
d2o.rho.range(5.3000, 6.36)
# The H2O starting value (-0.56) sits exactly on the lower bound.
h2o.rho.range(-0.56, 0.6)
tiox.rho.range(1.2, 3.2)
siox.rho.range(2.8, 4.8)

## === Molecular groups ===
# A DOPC lipid assembled from molgroups components: PC headgroup, two oleoyl
# tails and one methyl entry.
DOPC = cmp.Lipid(name='DOPC', headgroup=cmp.pc, tails=2 * [cmp.oleoyl], methyls=[cmp.methyl])
# Single-lipid ssBLM object (number fraction 1.0). This module-level instance
# is reconfigured on every call to bilayer().
blm = mol.ssBLM(lipids=[DOPC],
                lipid_nf=[1.0])

# == Sample stack helper functions ==
def bilayer(z, sigma, bulknsld, substrate_rough, rho_substrate, l_lipid1, l_lipid2, l_submembrane, vf_bilayer):
    """Evaluate the nSLD profile of the module-level ``blm`` bilayer on grid ``z``.

    Used as the ``profile`` callable of a FunctionalProfile (see
    ``make_samples``).

    Inputs:
    z: positions at which the profile is evaluated
    sigma: bilayer roughness, forwarded to ``blm.fnSet``
    bulknsld: buffer SLD in Refl1D units (1e-6 Ang^-2)
    substrate_rough: forwarded to ``blm.fnSet`` as ``global_rough``
    rho_substrate: substrate SLD in Refl1D units (1e-6 Ang^-2)
    l_lipid1, l_lipid2, l_submembrane, vf_bilayer: forwarded unchanged to
        ``blm.fnSet``

    Returns the nSLD profile in Refl1D units (1e-6 Ang^-2).

    Side effects: reconfigures the shared ``blm`` object and overwrites
    ``problem.moldat`` / ``problem.results`` on every call.
    """

    # molgroups works in Ang^-2, Refl1D supplies SLDs in 1e-6 Ang^-2
    bulk_ang = bulknsld * 1e-6
    substrate_ang = rho_substrate * 1e-6

    # Push the current parameter values into the shared bilayer object.
    # l_siox and radius_defect are pinned to fixed values here.
    blm.fnSet(
        sigma=sigma,
        bulknsld=bulk_ang,
        global_rough=substrate_rough,
        rho_substrate=substrate_ang,
        l_lipid1=l_lipid1,
        l_lipid2=l_lipid2,
        l_submembrane=l_submembrane,
        l_siox=0.0,
        vf_bilayer=vf_bilayer,
        radius_defect=1e8,
    )

    # Area and scattering-length profiles of the volume the bilayer occupies.
    # The normalising area returned here is discarded in favour of the value
    # stored on the object.
    _, area, nsl = blm.fnWriteProfile(z)
    normarea = blm.normarea

    # Keep per-group data on the problem for statistical analysis and plotting
    group_data, group_results = write_groups([blm], ['bilayer'])
    problem.moldat = group_data
    problem.results = group_results

    # Fill the unoccupied volume with buffer, then convert back to Refl1D units
    return apply_bulknsld(z, bulk_ang, normarea, area, nsl) * 1e6

def write_groups(groups, labels):
    """Collect the dictionary output of several molgroups objects.

        Inputs:
        groups: list of Molgroups objects to process
        labels: list (same length as groups) of labels

        Returns a tuple (moldict, resdict) holding the merged results of
        fnWriteGroup2Dict and fnWriteResults2Dict respectively. When two
        groups emit the same key, the later group wins."""

    moldict, resdict = {}, {}
    for group, label in zip(groups, labels):
        # Groups are written on the fixed module grid, DIMENSION points spaced
        # STEPSIZE apart and starting at zero.
        zgrid = np.arange(DIMENSION) * STEPSIZE
        moldict.update(group.fnWriteGroup2Dict({}, label, zgrid))
        resdict.update(group.fnWriteResults2Dict({}, label))

    return moldict, resdict

def apply_bulknsld(z, bulknsld, normarea, area, nsl):
    """Combine area and nSL profiles into an nSLD profile, topping up with bulk.

        Inputs:
        z: positions of the profile points; the local spacing is taken from
           np.gradient(z)
        bulknsld: nSLD of the material filling the unoccupied volume
        normarea: normalising area
        area: area profile of the molecular groups
        nsl: scattering-length profile of the molecular groups

        Returns the nSLD profile in the same units as bulknsld. No unit
        conversion happens here; the caller rescales to Refl1D units."""

    # Volume of each slice is normarea * dz; dividing nSL by it gives nSLD
    dz = np.gradient(z)
    group_nsld = nsl / (normarea * dz)

    # Whatever fraction of the slice the groups leave empty is filled by bulk
    bulk_fraction = 1.0 - area / normarea

    return group_nsld + bulk_fraction * bulknsld

def make_samples(func, substrate, contrasts, **kwargs):
    """Build one sample stack per buffer contrast.

        Each stack is substrate | functional molgroups layer | bulk buffer slab.

        Inputs:
        func: profile function for the FunctionalProfile layer. Must have form
              func(z, bulknsld, *args)
        substrate: Refl1D Stack or Layer object representing the substrate
        contrasts: list of buffer materials, e.g. [d2o, h2o]; the result has
                   one sample per entry, in the same order
        **kwargs: one keyword argument for each arg in func(..., *args).
                  bulknsld must NOT be given; it is taken from each contrast's
                  rho.

        Returns a list of sample stacks."""

    def _stack_for(buffer):
        # Functional layer spanning the whole z grid, with zero interface width
        profile_layer = FunctionalProfile(DIMENSION * STEPSIZE, 0, profile=func,
                                          bulknsld=buffer.rho, **kwargs)
        # Bulk buffer layer: zero thickness, fixed interface of 5
        bulk_layer = Slab(material=buffer, thickness=0.0000, interface=5.0000)
        return substrate | profile_layer | bulk_layer

    return [_stack_for(buffer) for buffer in contrasts]

## == Sample layer stack ==

# Silicon and silicon oxide share the same roughness parameter (global_rough).
layer_silicon = Slab(material=silicon, thickness=0.0000, interface=global_rough)
layer_siox = Slab(material=siox, thickness=d_oxide, interface=global_rough)
# The TiOx slab thickness is d_tiox minus half of blm.substrate.length, and its
# interface is zero; the TiOx roughness is instead handed to the bilayer
# profile as `substrate_rough` below.
# NOTE(review): presumably the molgroups substrate group supplies the remaining
# TiOx inside the functional layer, and blm.substrate.length is a plain number
# read once when this line runs -- confirm.
layer_tiox = Slab(material=tiox, thickness=d_tiox - 0.5 * blm.substrate.length, interface=0.00)

# Use the bilayer definition function to generate the bilayer SLD profile, passing in the relevant parameters.
# One sample is built per contrast; both share every structural parameter and
# differ only in the buffer material.
substrate = layer_silicon  | layer_siox | layer_tiox
sample_d2o, sample_h2o = make_samples(bilayer, substrate, [d2o, h2o], sigma=sigma, 
                             substrate_rough=tiox_rough, rho_substrate=tiox.rho,
                             l_lipid1=l_lipid1, l_lipid2=l_lipid2, l_submembrane=l_submembrane,
                             vf_bilayer=vf_bilayer)

## === Critical edge sampling ===
# Applied to the D2O probe only; the H2O probe is left as loaded.
probe_d2o.critical_edge(substrate=silicon, surface=d2o)

## === Problem definition ===
## step = True corresponds to a calculation of the reflectivity from an actual profile
## with microslabbed interfaces.  When step = False, the Nevot-Croce
## approximation is used to account for roughness.  This approximation speeds up
## the calculation tremendously, and is reasonably accurate as long as the
## roughness is much less than the layer thickness
step = False

# One Experiment per contrast, both using the z-grid step size as dz.
model_d2o = Experiment(sample=sample_d2o, probe=probe_d2o, dz=STEPSIZE, step_interfaces = step)
model_h2o = Experiment(sample=sample_h2o, probe=probe_h2o, dz=STEPSIZE, step_interfaces = step)

# Both models are fitted together. bilayer() and custom_plot() refer to this
# object by its module-level name `problem`.
problem = FitProblem([model_d2o, model_h2o])

problem.name = "tiox_dopc_d2o_h2o"

## === Custom plotting code ===

def custom_plot():
    """Build a Plotly figure of the volume occupancy of the bilayer components.

    Reads the per-group profiles stored in ``problem.moldat`` (written by
    ``bilayer``), sums the clipped-at-zero areas of the groups belonging to
    each named component, and draws each component as an outlined,
    semi-transparent filled curve normalised by 'bilayer.normarea'. The
    remaining volume up to 1.0 is shaded as 'buffer'.

    Returns the plotly Figure.
    """

    import plotly.graph_objs as go
    from refl1d.webview.server.colors import COLORS

    moldat = problem.moldat

    def hex_to_rgb(hex_string):
        """Split a '#rrggbb' string into a tuple of three ints."""
        return tuple(int(hex_string[start:start + 2], 16) for start in (1, 3, 5))

    def outlined_fill(x, y, label, hex_color, fill_mode):
        """Return [outline trace, legend-less translucent fill trace] for one curve."""
        rgb_csv = ','.join(str(channel) for channel in hex_to_rgb(hex_color))
        outline = go.Scatter(x=x,
                             y=y,
                             mode='lines',
                             name=label,
                             line=dict(color=hex_color))
        shading = go.Scatter(x=x,
                             y=y,
                             mode='lines',
                             line=dict(width=0),
                             fill=fill_mode,
                             fillcolor=f'rgba({rgb_csv},0.3)',
                             showlegend=False)
        return [outline, shading]

    n_lipids = 1
    lipid_ids = range(1, n_lipids + 1)

    def keys_for(stem):
        """moldat keys of one group type, one per lipid species."""
        return [f'bilayer.{stem}_{i}' for i in lipid_ids]

    # Display label -> moldat keys summed into that component.
    # ('bilayer.siox' is deliberately left out of the plot.)
    group_names = {
        'TiO2 substrate': ['bilayer.substrate'],
        'inner headgroups': keys_for('headgroup1'),
        'inner acyl chains': keys_for('methylene1') + keys_for('methyl1'),
        'outer acyl chains': keys_for('methylene2') + keys_for('methyl2'),
        'outer headgroups': keys_for('headgroup2'),
    }

    normarea = moldat['bilayer.normarea']['area']

    fig = go.Figure()
    traces = []
    # COLORS[0] is reserved for the buffer; components cycle through the rest,
    # starting from the second entry of the remainder.
    palette = COLORS[1:]
    total_area = 0
    for slot, (label, members) in enumerate(group_names.items(), start=1):
        component_area = 0
        for key in members:
            if key not in moldat:
                print(f'Warning: {key} not found')
                continue
            zaxis = moldat[key]['zaxis']
            # Negative areas are clipped to zero before summing
            component_area += np.maximum(0, moldat[key]['area'])

        shade = palette[slot % len(palette)]
        traces.extend(outlined_fill(zaxis, component_area / normarea, label, shade, 'tozeroy'))
        total_area += component_area

    # Buffer: shaded between the summed occupancy and a hidden line at 1.0.
    # 'tonexty' fills towards the trace drawn just before it, which after the
    # reversal below is that hidden line.
    buffer_color = COLORS[0]
    traces.extend(outlined_fill(zaxis, total_area / normarea, 'buffer', buffer_color, 'tonexty'))
    traces.append(go.Scatter(x=zaxis,
                             y=[1.0] * len(zaxis),
                             mode='lines',
                             line=dict(color=buffer_color, width=0),
                             showlegend=False))

    # Traces are added in reverse of the order they were built
    fig.add_traces(list(reversed(traces)))

    fig.update_layout(
        title='Component Volume Occupancy',
        template='plotly_white',
        xaxis_title=dict(text='z (Ang)'),
        yaxis_title=dict(text='volume occupancy'),
    )

    return fig

# Attach the plotting function to the problem object under the attribute name
# 'custom_plot'.
problem.custom_plot = custom_plot

if __name__ == '__main__':
    # Run as a script: evaluate the problem first, then show the figure.
    # NOTE(review): the chisq_str() call presumably triggers a model
    # evaluation so that bilayer() has populated problem.moldat before
    # plotting -- confirm.
    problem.chisq_str()
    figure = problem.custom_plot()
    figure.show()

