Skip to content

Quick Plots

codechembook.quickPlots

plotFit(fit, resample=10, residual=False, components=False, confidence=0, xlabel=None, ylabel=None, template='simple_white', output='png')

Plot the result of a 1d fit using lmfit

Parameters:

Name Type Description Default
fit lmfit result object

Results from a lmfit fit.

required
resample int

Increase the density of model points on the x axis by resample times to smooth. (default: 10)

10
residual bool

Show the residual. (default: False)

False
components bool

Show the individual components of the model. (default: False)

False
confidence int

Show the confidence-sigma confidence interval of the fit. (default: 0)

0
xlabel str

x-axis title (default: None (independent variable name))

None
ylabel str

y-axis title (default: None (blank))

None
template str

Plotly template for formatting Figure. (default: 'simple_white')

'simple_white'
output str

Method to show the plot. (default: "png", options: valid Plotly output types, None for no output)

'png'

Returns:

Name Type Description
Figure

Plotly Figure object containing the figure.

Source code in src/codechembook/quickPlots.py
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
def plotFit(fit, 
            resample = 10, 
            residual = False, 
            components = False, 
            confidence = 0, 
            xlabel = None, 
            ylabel = None, 
            template = 'simple_white',
            output = 'png'):
    """
    Plot the result of a 1d fit using lmfit

    Args:
        fit (lmfit result object): Results from a lmfit fit.
        resample (int, optional): Increase the density of model points on the x axis by `resample` times to smooth. (default: 10)
        residual (bool, optional): Show the residual. (default: False)
        components (bool, optional): Show the individual components of the model. (default: False)
        confidence (int, optional): Show the `confidence`-sigma confidence interval of the fit. (default: 0)
        xlabel (str, optional): x-axis title (default: None (independent variable name))
        ylabel (str, optional): y-axis title (default: None (blank))
        template (str, optional): Plotly template for formatting Figure. (default: 'simple_white')
        output (str, optional): Method to show the plot. (default: "png", options: valid Plotly output types, None for no output)

    Returns:
        Figure: Plotly Figure object containing the figure.
    """

    # Just making some variables for convenience
    # First figure out what the independent variable name(s) is(are)
    independent_vars = fit.model.independent_vars

    # The x data has to be the same for all the independent variables, so
    # so get it from the first one in the list for safety
    xdata = fit.userkws[independent_vars[0]]
    ydata = fit.data

    # Resampling the fit so that it looks smooth to the eye
    smoothx = np.linspace(xdata.min(), xdata.max(), len(xdata)*resample)

    # Need to handle the fact that there may be multiple names for the 
    # independent variable
    kwargs = {}
    for independent_var in independent_vars:
        kwargs[independent_var] = smoothx
    smoothy = fit.eval(**kwargs)

    # If we are plotting the residual, then we need two subplots
    if residual: # will work as long as this is not False, 0, empty list, etc
        row_heights = [0.8, 0.2] # this is the default
        if residual == 'scaled':
            row_heights = [np.max(ydata) - np.min(ydata) , np.max(fit.residual) - np.min(fit.residual)]

        fig = make_subplots(rows = 2, 
                            cols = 1, 
                            shared_xaxes = True, 
                            row_heights = row_heights,
                            vertical_spacing = 0.05)
    else:
        fig = make_subplots()

    # If we are plotting the confidence interval, then plot +/- N * 1-sigma 
    # and fill between the two curves
    if confidence != 0 and type(confidence) == int:
        fig.add_scatter(x = smoothx, 
                        y = smoothy + confidence * fit.eval_uncertainty(**kwargs), 
                        mode = 'lines',
                        line = {'color': 'lightpink', 'width': 0},
                        row = 1, col = 1)
        fig.add_scatter(x = smoothx, 
                        y = smoothy - confidence * fit.eval_uncertainty(**kwargs), 
                        mode = 'lines',
                        line = {'color': 'lightpink', 'width': 0},
                        row = 1, col = 1,
                        fill = 'tonexty')

    # If we are plotting the individual components, go ahead and plot them first
    if components == True:

        # Generate the components resampled to the smooth x array
        comps = fit.eval_components(**kwargs)
        # Loop through the components and plot each one
        for comp in comps:
            fig.add_scatter(x = smoothx, 
                            y = comps[comp], 
                            line = {'dash': 'dot', 'color':'grey'},
                            row = 1, col = 1) 

    # Plot the raw data
    fig.add_scatter(x = xdata, 
                    y = ydata, 
                    mode = 'markers', 
                    name = 'Data', 
                    legendrank = 1, 
                    marker = {'color': 'blue', 'size': 8},
                    line = {'color': 'blue', 'width' : 8},
                    row = 1, col = 1)

    # Plot the fit curve
    fig.add_scatter(x = smoothx, 
                    y = smoothy, 
                    mode = 'lines', 
                    name = 'Best Fit', 
                    legendrank = 2, 
                    line = {'color': 'red'},
                    row = 1, col = 1)

    # If we are doing residuals, plot the residual
    if residual:
        fig.add_scatter(x = xdata, 
                        y = -1*fit.residual, # we need to multiply this by -1, to get the 'expected' behavior of data - fit. 
                        mode = 'markers+lines', 
                        name = 'Residual', 
                        line = {'color': 'black', 'width':1},
                        marker = {'color': 'black', 'size':2},
                        showlegend = False,
                        row = 2, col = 1)

        # Optionally plot the confidence interval of the residual
        if confidence != 0 and type(confidence) == int:

            fig.add_scatter(x = smoothx, 
                            y = confidence * fit.eval_uncertainty(**kwargs), 
                            mode = 'lines',
                            line = {'color': 'gray', 'width': 0},
                            row = 2, col = 1)
            fig.add_scatter(x = smoothx, 
                            y = -1 * confidence * fit.eval_uncertainty(**kwargs), 
                            mode = 'lines',
                            line = {'color': 'gray', 'width': 0},
                            row = 2, col = 1,
                            fill = 'tonexty')
        # Limit the ticks on the Residual axis so that it is readable
        residual_lim = np.max(np.abs(fit.residual)) * 1.05
        fig.update_yaxes(title = 'Residual', 
                         range = [-residual_lim, residual_lim], 
                         nticks = 3, zeroline = True, row = 2)

        #fig.update_yaxes(title = 'Residual', row = 2)



    # Update the layout
    fig.update_layout(template = template, showlegend = False)

    # Flag the user if the fit did not finish successfully
    fig_full = fig.full_figure_for_development()
    if fit.ier not in (1, 2, 3, 4):
        print(fit.result.lmdif_message)

        # find min max of x and y and average instead
        fig.add_annotation(x = (np.max(xdata) - np.min(xdata))/2,
                           y = (np.max(ydata) - np.min(ydata))/2,
                           text = 'Fit not converged.\nCheck command line for info.')

    # If the user supplied an x axis label, add it
    if xlabel is None: # we can default to the name in the model
        xlabel = fit.model.independent_vars[0]
    fig.update_xaxes(title = xlabel, row = 2 if residual else 1)    

    # If the user supplied a y axis label, add it
    if ylabel is None:
        print('Please enter a string for the y label.')
    fig.update_yaxes(title = ylabel, row = 1)

    # Plot the figure to the specified output
    process_output(fig, output) # check to see how we should be outputting this plot

    return fig

process_output(plot, output)

Handler for different outputs that could be chosen by the user as implemented in Plotly.

Parameters:

Name Type Description Default
plot Figure

Plotly Figure option to plot.

required
output str

Type of output requested by user, among those allowed by Plotly.

required
Source code in src/codechembook/quickPlots.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def process_output(plot, output):
    """
    Handler for different outputs that could be chosen by the user as implemented in Plotly.

    Args:
        plot (Figure): Plotly Figure option to plot.
        output (str):  Type of output requested by user, among those allowed by Plotly.
    """
    # Plot the figure to the specified output
    if output in pio.renderers.keys():
        plot.show(output)
    elif output == "default":
        plot.show()
    elif output is None:
        pass # no need to do anything
    else:
        print("Enter 'png' to plot in Spyder or 'browser' for the browser.")
        print("Use 'None' to show nothing and return the figure object.")

quickBin(x, limits=None, nbins=None, width=None)

Accepts a collection of numbers that can be coerced into a numpy array, and bins these numbers. If none of keyword arguments are specified, this results in a Freeman-Diaconis binning.

Parameters:

Name Type Description Default
x ndarray or list

Data set bin.

required
limits list of numeric

Upper and lower limits of data to bin. (default: min and max of x)

None
nbins int

Number of bins. (default: None (automatically determine))

None
width float

Width of the bins. (default: None (automatically determine))

None

Returns:

Type Description

tuple[ndarray, ndarray]: bin_centers (ndarray) and bin_counts (ndarray).

Source code in src/codechembook/quickPlots.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def quickBin(x, limits = None, nbins = None, width = None):
    '''
    Accepts a collection of numbers that can be coerced into a numpy array, and bins these numbers.
    If none of keyword arguments are specified, this results in a Freeman-Diaconis binning.

    Args:
        x (ndarray or list): Data set bin.
        limits (list of numeric, optional): Upper and lower limits of data to bin. (default: min and max of x)
        nbins (int, optional): Number of bins. (default: None (automatically determine))
        width (float, optional): Width of the bins. (default: None (automatically determine))

    Returns:
        tuple[ndarray, ndarray]: bin_centers (ndarray) and bin_counts (ndarray).
    '''
    try:
        x = np.array(x)
    except:
        raise("the data need to be in a form that can be converted to a numpy array")
    # we need to start by finding the limits and the bin width

    # we can start by getting the iqr, which might prove useful for formatting as well
    q75, q25 = np.percentile(x, [75,25]) # find the places for the inner quartile
    iqr = q75 - q25 # calculate the inner quartile range


    # first thing: make sure we have a range to work with...
    if limits == None: # then set the limis as the min and max of x
        limits = [min(x), max(x)]

    if nbins != None and width != None:
        raise("Specify either the number of bins, or the bin width, but not both.")

    # check to see if the width of the bins was specified...
    if width == None and nbins == None: # then use the Freedman-Diaconis method to calculate bins
        width = 2*iqr*len(x)**(-1/3)

    if nbins != None and width == None: # use the number of bins to determine the width
        width = abs(limits[1] - limits[0]) / int(nbins)

    # the only other option is that width was directly specified.... 
    # so now we are ready to go...

    # Define the bin edges using numpy's arange function
    bin_edges = np.arange(limits[0], limits[1] + width, width)

    # Use numpy's histogram function to bin the data, using the bin edges we have calculated
    bin_counts, _ = np.histogram(x, bins=bin_edges)

    # Calculate the bin centers by averaging each pair of consecutive edges
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2


    return [bin_centers, bin_counts]

quickGrid(x=None, labels=None, template='simple_white', output='png')

Plots correlation between a series of arrays with Plotly.

Work in progress. To do: place label in the diagonals add fitting check to make sure all arrays are the same length

Parameters:

Name Type Description Default
x list of ndarrays or numeric

Data to plot.

None
labels list of str

Labels for the arrays. (default: None)

None
template str

Named plotly template. (default: "simple_white")

'simple_white'
output str

Output format. (default: "png")

'png'

Returns:

Name Type Description
Figure

Plotly Figure object plotting correlations between arrays.

Source code in src/codechembook/quickPlots.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def quickGrid(x = None, labels = None, template = "simple_white", output = "png"):
    '''
    Plots correlation between a series of arrays with Plotly.

    Work in progress.  To do:
        place label in the diagonals
        add fitting
        check to make sure all arrays are the same length

    Args:
        x (list of ndarrays or numeric): Data to plot.
        labels (list of str, optional): Labels for the arrays. (default: None)
        template (str, optional): Named plotly template. (default: "simple_white")
        output (str, optional): Output format. (default: "png")

    Returns:
        Figure: Plotly Figure object plotting correlations between arrays.
    '''
    # first make sure that we have lists of lists... 
    # so this first section makes sure that, if we get a single list, we put it in a list
    if type(x[0]) != np.ndarray and type(x[0]) != list: # then x is not an array or list
        xplot = [x]
    else:
        try: 
            xplot = x # if already an array of arrays, then just keep it
        except:
            raise "You need to supply a list or ndarray of floats or ints"

    narrays = len(x)
    gplot = make_subplots(cols = narrays, rows = narrays) # make a square plot

    for j, x1 in enumerate(x): # go through each y array
        for i, x2 in enumerate(x): # go through each x array
            if i == j:
                pass
            else:
                gplot.add_scatter(x = x1, y = x2, 
                                  showlegend=False, 
                                  row = i+1, col = j+1)
                try:
                    ylabel = labels[j]
                except:
                    ylabel = f"y-series {j}"
                try:
                    xlabel = labels[i]
                except:
                    xlabel = f"x-series {i}"
                gplot.update_xaxes(title = xlabel, row = i+1, col = j+1)
                gplot.update_yaxes(title = ylabel, row = i+1, col = j+1)

    gplot.update_layout(template = template)

    process_output(gplot, output)

    return gplot

quickHist(x, xlabel=None, ylabel=None, limits=None, nbins=None, width=None, mode='counts', orientation='vertical', template='simple_white', output='png')

Plot a histogram of 1D data.

Parameters:

Name Type Description Default
x list or ndarray

Collection of numbers to be histogrammed.

required
xlabel str

Title for the x-axis. (default: None (use variable name))

None
ylabel str

Title for the y-axis. (default: None (use variable name))

None
limits list of numeric

Upper and lower limits of data to bin. (default: min and max of x)

None
nbins int

Number of bins. (default: None (automatically determine))

None
width float

Width of the bins. (default: None (automatically determine))

None
mode str

Y-axis is "counts" or "frequency" (default: "counts")

'counts'
orientation str

Vertical or horizontal orientation. (default: "vertical")

'vertical'
template str

Plotly template to use. (default: "simple_white")

'simple_white'
output str

Method to show the plot. (default: "png", options: valid Plotly output types, None for no output)

'png'

Returns:

Name Type Description
Figure

Plotly Figure object containing the histogram.

Source code in src/codechembook/quickPlots.py
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
def quickHist(x, 
              xlabel = None, ylabel = None, 
              limits = None, nbins = None, width = None, 
              mode = "counts",
              orientation = "vertical", # can also be "horizontal"
              template = "simple_white",
              output = "png"):
    """
    Plot a histogram of 1D data.

    Args:
        x (list or ndarray): Collection of numbers to be histogrammed.
        xlabel (str, optional): Title for the x-axis. (default: None (use variable name))
        ylabel (str, optional): Title for the y-axis. (default: None (use variable name))
        limits (list of numeric, optional): Upper and lower limits of data to bin. (default: min and max of x)
        nbins (int, optional): Number of bins. (default: None (automatically determine))
        width (float, optional): Width of the bins. (default: None (automatically determine))
        mode (str, optional): Y-axis is "counts" or "frequency" (default: "counts")
        orientation (str, optional): Vertical or horizontal orientation. (default: "vertical")
        template (str, optional): Plotly template to use. (default: "simple_white")
        output (str, optional): Method to show the plot. (default: "png", options: valid Plotly output types, None for no output)

    Returns:
        Figure: Plotly Figure object containing the histogram.
    """
    # if the user did not supply axis names, then we can just use the variable names
    if xlabel is None:
        xlabel = _get_name_two_calls_ago(x)


    # we will want the iqr for calculating the buffer space on the plot
    q75, q25 = np.percentile(x, [75,25]) # find the places for the inner quartile
    iqr = q75 - q25 # calculate the inner quartile range

    #default to plotting counts, I guess
    bar_centers, bar_lengths = quickBin(x, limits = limits, nbins = nbins, width = width)

    if "counts" in mode:
        # for the ylabel, we can use the mode, if no label was yet supplied. 
        if ylabel is None:
            ylabel = "counts"

    # adjust if we need to change the counts to frequency
    if "freq" in mode: # then we are doing frequency
        bar_lengths = bar_lengths / np.sum(x)

        # for the ylabel, we can use the mode, if no label was yet supplied. 
        if ylabel is None:
            ylabel = "frequency"

    # work out a buffer for the bars on either side
    # calculate the width of bars
    bar_separation = bar_centers[1] - bar_centers[0] # quickBin should always have adjacent bars
    # calculate a buffer based in iqr
    iqr_buffer = 0.05*iqr
    # take whatever is larger
    buffer = max([bar_separation, iqr_buffer])

    # now we can plot a bar chart that looks like a histogram...
    hist = make_subplots()

    if "v" in orientation:
        hist.add_bar(x = bar_centers, y = bar_lengths)
    elif "h" in orientation:
        hist.add_bar(x = bar_lengths, y = bar_centers)


    hist.update_traces(marker = dict(line = dict(width = 1, color = "black")))

    hist.update_xaxes(title = xlabel, range = [min(bar_centers) - buffer, max(bar_centers) + buffer])
    hist.update_yaxes(title = ylabel, range = [0, max(bar_lengths)*1.02])
    hist.update_layout(bargap = 0, template = template)

    process_output(hist, output)

    return hist

quickScatter(x=None, y=None, xlabel=None, ylabel=None, name=None, template='simple_white', mode=None, output='png')

Quickly plot (x,y) data as a scatter plot.

Users can supply single arrays or lists for x- and y-values, a single set of x-values and multiple sets of y-values, or multiple sets of x-values and y-values. If multiple sets of y-values, name and mode can be lists of len(y).

Parameters:

Name Type Description Default
x ndarray or list of ndarray

x-values to plot.

None
y ndarray or list of ndarray

y-values to plot.

None
xlabel str or list of str

x-axis title. (default: None (use variable name))

None
ylabel str or list of str

y-axis title. (default: None (use variable name))

None
name str or list of str

Names of traces. (default: '')

None
template str

Plotly template for formatting Figure. (default: 'simple_white')

'simple_white'
mode str or list of str

Trace appearance for Plotly Scatter object.

None
output str

Method to show the plot. (default: "png")

'png'

Returns:

Name Type Description
Figure

the figure object created.

Source code in src/codechembook/quickPlots.py
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
def quickScatter(x = None, y = None, xlabel = None, ylabel = None, name = None, template = "simple_white", mode = None, output = "png"):
    """
    Quickly plot (x,y) data as a scatter plot.

    Users can supply single arrays or lists for x- and y-values, a single set of x-values and multiple sets of
    y-values, or multiple sets of x-values and y-values.  If multiple sets of y-values, name and mode can be
    lists of len(y).

    Args:
        x (ndarray or list of ndarray, optional): x-values to plot.
        y (ndarray or list of ndarray, optional): y-values to plot.
        xlabel (str or list of str, optional): x-axis title. (default: None (use variable name))
        ylabel (str or list of str, optional): y-axis title. (default: None (use variable name))
        name (str or list of str, optional): Names of traces. (default: '')
        template (str, optional): Plotly template for formatting Figure. (default: 'simple_white')
        mode (str or list of str, optional): Trace appearance for Plotly Scatter object.
        output (str, optional): Method to show the plot. (default: "png")

    Returns:
        Figure: the figure object created.
    """
    # if the user did not supply axis names, then we can just use the variable names
    if xlabel is None:
        xlabel = _get_name_two_calls_ago(x)
    if ylabel is None:
        ylabel = _get_name_two_calls_ago(y)

    # figure out what the user gave us for x and y
    if isinstance(x, list) or isinstance(x, np.ndarray):
        if all(isinstance(xi, list) or isinstance(xi, np.ndarray) for xi in x):
            xplot = x
        elif all(isinstance(xi, float) or isinstance(xi, int) for xi in x):
            xplot = [x]
        else:
            print("You need to supply a list or array of floats or ints for x")
            return

    if isinstance(y, list) or isinstance(y, np.ndarray):
        if all(isinstance(yi, list) or isinstance(yi, np.ndarray) for yi in y):
            yplot = y
        elif all(isinstance(yi, float) or isinstance(yi, int) for yi in y):
            yplot = [y]
        else:
            print("You need to supply a list or array of floats or ints for y")
            return

    # figure out what the user gave us for name
    if name is None:
        name = ['' for y in yplot]
    elif isinstance(name, str):
        if len(yplot) > 1:
            tempname = name
            name = [f'{tempname}{i}' for i, y in enumerate(yplot)]
        if len(yplot) == 1:
            name = [name]
    elif isinstance(name, list):
        if len(name) != len(yplot):
            print("your list of names is a different length than your list of y values")
            return
    else:
        print("there was a problem with your name argument")
        return

    # ensure that we have the same number of items for x and y
    if len(xplot) == 1:
        xplot = [xplot[0]] * len(yplot)
    elif len(xplot) != len(yplot):
        raise "your x values should be a list or array of length equal to y values, or a list or array of 1"

    # start the plotting
    qplot = make_subplots()

    # loop through the provided lists and add them to the plot
    for xi,yi,ni in zip(xplot, yplot, name):
        if len(xi) != len(yi):
            raise "you do not have the same number of x and y points!"
        if mode is None:
            points = go.Scatter(x = xi, y = yi, name = ni)
        elif "lines" in mode or "markers" in mode:
            points = go.Scatter(x = xi, y = yi, mode = mode, name = ni)
        else:
            raise "please enter either 'lines', 'markers', 'lines+markers', or None for mode"
        qplot.add_trace(points)

    # set up the axes
    qplot.update_xaxes(title = str(xlabel)) # cast as string to handle numeric values if passed
    qplot.update_yaxes(title = str(ylabel))

    # confirm that the specified template is one that we have
    if template not in pio.templates.keys():
        print('Invalid template specified, defaulting to simple_white.')
        template = 'simple_white'
    qplot.update_layout(template = template)

    process_output(qplot, output) # check to see how we should be outputting this plot

    return qplot

quickSubs(childPlots=None, layoutfig=None, nrows=None, ncols=None, output='png')

Plot multiple existing Plotly figures in a single Figure as subplots.

Each figure object supplied is added as as a subplot increasing across, then down, the grid.

Parameters:

Name Type Description Default
childPlots list of Figure

Preexisting Plotly Figure objects to be added. (default: None)

None
layoutfig Figure

Figure object that specifies formatting of new figure. (default: None)

None
nrows int

Number of rows of the subplot. (default: None (automatically determine))

None
ncols int

Number of columns of the subplot. (default: None (automatically determine))

None
output str

Output format. (default: "png")

'png'

Returns:

Name Type Description
Figure

New Plotly Figure object containing subplots of all the supplied child plots.

Source code in src/codechembook/quickPlots.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
def quickSubs(childPlots = None, 
              layoutfig = None, nrows = None, ncols = None,
              output = "png"):
    '''
    Plot multiple existing Plotly figures in a single Figure as subplots.

    Each figure object supplied is added as as a subplot increasing across, then down, the grid.

    Args:
        childPlots (list of Figure): Preexisting Plotly Figure objects to be added. (default: None)
        layoutfig (Figure, optional): Figure object that specifies formatting of new figure. (default: None)
        nrows (int, optional): Number of rows of the subplot. (default: None (automatically determine))
        ncols (int, optional): Number of columns of the subplot. (default: None (automatically determine))
        output (str, optional): Output format. (default: "png")

    Returns:
        Figure: New Plotly Figure object containing subplots of all the supplied child plots.
    '''
    if nrows == None and ncols == None: # we have specified nothing about the grid to use
        ncols = math.ceil(len(childPlots)**0.5)
        nrows = math.ceil(len(childPlots)/ncols)
    elif nrows == None: # we have only specified the number of columns to use
        nrows = math.ceil(len(childPlots)/ncols)
    elif ncols == None: # we have only specified the number of rows to use
        ncols = math.ceil(len(childPlots)/nrows)

    newfig = make_subplots(rows = nrows, cols = ncols)
    newfigdict = json.loads(newfig.to_json()) # add stuff to this one. <-- need to do this, because we will use the 
    # print(newfigdict)
    # print('end of first newfigdict \n')
    #print(nrows, ncols)

    #figdict = {"data":[], "layout":{}}

    for i, cp in enumerate(childPlots):

        if i == 0: # do not with to append the number
            label = ''
        else:
            label = i+1

        # specify which row and column we are working on
        row = int(i/ncols)+1
        col = int(i%ncols)+1

        # now parse the figure...
        oldfigdict = json.loads(cp.to_json()) 
        for entry in oldfigdict["data"]: # get the indiviual dictionaries in the data list
            entry["xaxis"] = f"x{label}"
            entry["yaxis"] = f"y{label}"
            newfigdict["data"].append(entry) # add modified version to the new figure
        # print(oldfigdict)
        # print('\n')
        # print(i, '\nbefore')
        # print(oldfigdict['layout']["xaxis"])       
        # oldfigdict["layout"][f"xaxis{label}"] = oldfigdict["layout"]["xaxis"] #rename x-axis key
        # oldfigdict["layout"][f"yaxis{label}"] = oldfigdict["layout"]["yaxis"] #rename y-axis key

        # oldfigdict["layout"][f"xaxis{label}"]["anchor"] = f"y{label}"
        # oldfigdict["layout"][f"yaxis{label}"]["anchor"] = f"x{label}"

        temp_x_domain = newfigdict["layout"][f"xaxis{label}"]["domain"]
        temp_y_domain = newfigdict["layout"][f"yaxis{label}"]["domain"]

        newfigdict["layout"][f"xaxis{label}"] = oldfigdict["layout"][f"xaxis"]
        newfigdict["layout"][f"yaxis{label}"] = oldfigdict["layout"][f"yaxis"]
        newfigdict["layout"][f"xaxis{label}"]['domain'] = temp_x_domain
        newfigdict["layout"][f"yaxis{label}"]['domain'] = temp_y_domain
        newfigdict["layout"][f"xaxis{label}"]["anchor"] = f"y{label}" # the anchor for x is relative to y-position
        newfigdict["layout"][f"yaxis{label}"]["anchor"] = f"x{label}" # the anchor for y is relative to x-position
        # newfigdict["layout"][f"xaxis{label}"] = oldfigdict["layout"][f"xaxis{label}"]
        # newfigdict["layout"][f"yaxis{label}"] = oldfigdict["layout"][f"yaxis{label}"]
        # print(i, '\nafter')
        # print(oldfigdict['layout'][f"xaxis{label}"])
    # set up the layout....
    if layoutfig == None:
        layoutfig = childPlots[0]
    layoutfigdict = json.loads(layoutfig.to_json())
    for key in layoutfigdict["layout"]:
        if "axis" not in key: #make sure we are not editing axes, only everything else. 
            newfigdict["layout"][key] = layoutfigdict["layout"][key]

    newfigjson = json.dumps(newfigdict)
    # print(newfigdict)
    newfig = pio.from_json(newfigjson)

    process_output(newfig, output)

    return newfig