Skip to content

Commit 454c4dc

Browse files
authored
Merge pull request #212 from sangyu/feat-multicontrast_whorlmap_forest_plot_integration
Feature: Multicontrast object and new whorlmap
2 parents 5d83bb1 + 4e1c212 commit 454c4dc

File tree

309 files changed

+3389
-199
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

309 files changed

+3389
-199
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# DABEST-Python
22

3-
43
<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->
54

65
[![minimal Python

dabest/_delta_objects.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,17 +435,19 @@ def __init__(self, effectsizedataframe, permutation_count,
435435
self.__control_N,
436436
self.__test_var,
437437
self.__test_N)
438+
439+
self.__bootstraps_variance = ci2g.calculate_bootstraps_var(self.__bootstraps)
438440

439441
# Compute the weighted average mean differences of the bootstrap data
440442
# using the variance of each bootstrap distribution as the inverse of
441443
# weights
442444
self.__bootstraps_weighted_delta = ci2g.calculate_weighted_delta(
443-
self.__group_var,
445+
self.__bootstraps_variance,
444446
self.__bootstraps)
445447

446448
# Compute the weighted average mean difference based on the raw data
447449
self.__difference = es.weighted_delta(np.array(self.__effsizedf["difference"]),
448-
self.__group_var)
450+
self.__bootstraps_variance)
449451

450452
sorted_weighted_deltas = npsort(self.__bootstraps_weighted_delta)
451453

@@ -753,6 +755,14 @@ def group_var(self):
753755
in order.
754756
'''
755757
return self.__group_var
758+
759+
@property
760+
def bootstraps_var(self):
761+
'''
762+
Return the variance of each bootstrapped mean difference distribution
763+
in order.
764+
'''
765+
return self.__bootstraps_variance
756766

757767

758768
@property

dabest/_effsize_objects.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,7 @@ def plot(
11431143
face_color=None,
11441144

11451145
raw_desat=0.5, # swarm_desat=0.5, OLD # bar_desat=0.5, OLD
1146-
contrast_desat=1, # halfviolin_desat=1, OLD
1146+
contrast_desat=1.0, # halfviolin_desat=1, OLD
11471147

11481148
raw_alpha=None, # NEW
11491149
contrast_alpha=0.8, # halfviolin_alpha=0.8, OLD
@@ -1478,7 +1478,8 @@ def plot(
14781478

14791479
if raw_alpha is None:
14801480
raw_alpha = (0.4 if self.is_proportional and self.is_paired
1481-
else 0.5 if self.is_paired
1481+
else 0.5 if self.is_paired and (color_col is not None or self.__delta2)
1482+
else 0.2 if self.is_paired and color_col is None
14821483
else 1.0
14831484
)
14841485

dabest/_modidx.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
'dabest/_stats_tools/confint_2group_diff.py'),
2828
'dabest._stats_tools.confint_2group_diff.bootstrap_indices': ( 'API/confint_2group_diff.html#bootstrap_indices',
2929
'dabest/_stats_tools/confint_2group_diff.py'),
30+
'dabest._stats_tools.confint_2group_diff.calculate_bootstraps_var': ( 'API/confint_2group_diff.html#calculate_bootstraps_var',
31+
'dabest/_stats_tools/confint_2group_diff.py'),
3032
'dabest._stats_tools.confint_2group_diff.calculate_group_var': ( 'API/confint_2group_diff.html#calculate_group_var',
3133
'dabest/_stats_tools/confint_2group_diff.py'),
3234
'dabest._stats_tools.confint_2group_diff.calculate_weighted_delta': ( 'API/confint_2group_diff.html#calculate_weighted_delta',
@@ -107,6 +109,36 @@
107109
'dabest/misc_tools.py'),
108110
'dabest.misc_tools.show_legend': ('API/misc_tools.html#show_legend', 'dabest/misc_tools.py'),
109111
'dabest.misc_tools.unpack_and_add': ('API/misc_tools.html#unpack_and_add', 'dabest/misc_tools.py')},
112+
'dabest.multi': { 'dabest.multi.MultiContrast': ('API/multi.html#multicontrast', 'dabest/multi.py'),
113+
'dabest.multi.MultiContrast.__init__': ('API/multi.html#multicontrast.__init__', 'dabest/multi.py'),
114+
'dabest.multi.MultiContrast.__repr__': ('API/multi.html#multicontrast.__repr__', 'dabest/multi.py'),
115+
'dabest.multi.MultiContrast._extract_data': ('API/multi.html#multicontrast._extract_data', 'dabest/multi.py'),
116+
'dabest.multi.MultiContrast._extract_single_contrast': ( 'API/multi.html#multicontrast._extract_single_contrast',
117+
'dabest/multi.py'),
118+
'dabest.multi.MultiContrast._validate_and_parse_structure': ( 'API/multi.html#multicontrast._validate_and_parse_structure',
119+
'dabest/multi.py'),
120+
'dabest.multi.MultiContrast._validate_ci_type': ( 'API/multi.html#multicontrast._validate_ci_type',
121+
'dabest/multi.py'),
122+
'dabest.multi.MultiContrast._validate_contrast_consistency': ( 'API/multi.html#multicontrast._validate_contrast_consistency',
123+
'dabest/multi.py'),
124+
'dabest.multi.MultiContrast._validate_effect_size': ( 'API/multi.html#multicontrast._validate_effect_size',
125+
'dabest/multi.py'),
126+
'dabest.multi.MultiContrast._validate_effect_size_compatibility': ( 'API/multi.html#multicontrast._validate_effect_size_compatibility',
127+
'dabest/multi.py'),
128+
'dabest.multi.MultiContrast._validate_individual_dabest_obj': ( 'API/multi.html#multicontrast._validate_individual_dabest_obj',
129+
'dabest/multi.py'),
130+
'dabest.multi.MultiContrast.bootstraps': ('API/multi.html#multicontrast.bootstraps', 'dabest/multi.py'),
131+
'dabest.multi.MultiContrast.confidence_intervals': ( 'API/multi.html#multicontrast.confidence_intervals',
132+
'dabest/multi.py'),
133+
'dabest.multi.MultiContrast.effect_sizes': ('API/multi.html#multicontrast.effect_sizes', 'dabest/multi.py'),
134+
'dabest.multi.MultiContrast.forest_plot': ('API/multi.html#multicontrast.forest_plot', 'dabest/multi.py'),
135+
'dabest.multi.MultiContrast.get_bootstrap_by_position': ( 'API/multi.html#multicontrast.get_bootstrap_by_position',
136+
'dabest/multi.py'),
137+
'dabest.multi.MultiContrast.whorlmap': ('API/multi.html#multicontrast.whorlmap', 'dabest/multi.py'),
138+
'dabest.multi._sample_bootstrap': ('API/multi.html#_sample_bootstrap', 'dabest/multi.py'),
139+
'dabest.multi._spiralize': ('API/multi.html#_spiralize', 'dabest/multi.py'),
140+
'dabest.multi.combine': ('API/multi.html#combine', 'dabest/multi.py'),
141+
'dabest.multi.whorlmap': ('API/multi.html#whorlmap', 'dabest/multi.py')},
110142
'dabest.plot_tools': { 'dabest.plot_tools.SwarmPlot': ('API/plot_tools.html#swarmplot', 'dabest/plot_tools.py'),
111143
'dabest.plot_tools.SwarmPlot.__init__': ( 'API/plot_tools.html#swarmplot.__init__',
112144
'dabest/plot_tools.py'),

dabest/_stats_tools/confint_2group_diff.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
__all__ = ['create_jackknife_indexes', 'create_repeated_indexes', 'compute_meandiff_jackknife', 'bootstrap_indices',
77
'compute_bootstrapped_diff', 'delta2_bootstrap_loop', 'compute_delta2_bootstrapped_diff',
88
'compute_meandiff_bias_correction', 'compute_interval_limits', 'calculate_group_var',
9-
'calculate_weighted_delta']
9+
'calculate_bootstraps_var', 'calculate_weighted_delta']
1010

1111
# %% ../../nbs/API/confint_2group_diff.ipynb 4
1212
import numpy as np
@@ -319,15 +319,23 @@ def calculate_group_var(control_var, control_N, test_var, test_N):
319319

320320
return pooled_var
321321

322+
def calculate_bootstraps_var(bootstraps):
322323

323-
def calculate_weighted_delta(group_var, differences):
324+
bootstraps_var_list = [np.var(x, ddof=1) for x in bootstraps]
325+
bootstraps_var_array = np.array(bootstraps_var_list)
326+
327+
return bootstraps_var_array
328+
329+
330+
331+
def calculate_weighted_delta(bootstrap_dist_var, differences):
324332
"""
325333
Compute the weighted deltas.
326334
"""
327335

328-
weight = 1 / group_var
336+
weight = np.true_divide(1, bootstrap_dist_var)
329337
denom = np.sum(weight)
330338
num = 0.0
331339
for i in range(len(weight)):
332340
num += weight[i] * differences[i]
333-
return num / denom
341+
return np.true_divide(num, denom)

dabest/_stats_tools/effsize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,11 +392,11 @@ def _compute_hedges_correction_factor(n1,
392392

393393
# %% ../../nbs/API/effsize.ipynb 13
394394
@njit(cache=True)
395-
def weighted_delta(difference, group_var):
395+
def weighted_delta(difference, bootstrap_dist_var):
396396
'''
397397
Compute the weighted deltas where the weight is the inverse of the
398398
bootstrap distribution variance.
399399
'''
400400

401-
weight = np.true_divide(1, group_var)
401+
weight = np.true_divide(1, bootstrap_dist_var)
402402
return np.sum(difference*weight)/np.sum(weight)

dabest/misc_tools.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ def get_params(
203203

204204
def get_kwargs(
205205
plot_kwargs: dict,
206-
ytick_color
206+
ytick_color,
207+
is_paired: bool = False
207208
):
208209
"""
209210
Extracts the kwargs from the `plot_kwargs` object for use in the plotter function.
@@ -214,6 +215,8 @@ def get_kwargs(
214215
Kwargs passed to the plot function.
215216
ytick_color : str or color list
216217
Color of the yticks.
218+
is_paired : bool, optional
219+
A boolean flag to determine if the plot is for paired data. Default is False.
217220
"""
218221
from .misc_tools import merge_two_dicts
219222

@@ -334,7 +337,7 @@ def get_kwargs(
334337
default_group_summaries_kwargs = {
335338
"zorder": 3,
336339
"lw": 2,
337-
"alpha": 1,
340+
"alpha": 1 if not is_paired else 0.6,
338341
'gap_width_percent': 1.5,
339342
'offset': 0.1,
340343
'color': None
@@ -513,7 +516,7 @@ def get_color_palette(
513516
idx: list,
514517
all_plot_groups: list,
515518
delta2: bool,
516-
sankey: bool
519+
proportional: bool
517520
):
518521
"""
519522
Create the color palette to be used in the plotter function.
@@ -534,9 +537,11 @@ def get_color_palette(
534537
A list of all the group names.
535538
delta2 : bool
536539
A boolean flag to determine if the plot will have a delta-delta effect size.
537-
sankey : bool
538-
A boolean flag to determine if the plot is for a Sankey diagram.
540+
proportional : bool
541+
A boolean flag to determine if the plot is for proportional data.
539542
"""
543+
sankey = True if proportional and show_pairs else False
544+
540545
# Create color palette that will be shared across subplots.
541546
color_col = plot_kwargs["color_col"]
542547
if color_col is None:
@@ -548,7 +553,13 @@ def get_color_palette(
548553
color_groups = pd.unique(plot_data[color_col])
549554
bootstraps_color_by_group = False
550555
if show_pairs:
551-
bootstraps_color_by_group = False
556+
if plot_kwargs["custom_palette"] is not None:
557+
if delta2 or sankey:
558+
bootstraps_color_by_group = False
559+
else:
560+
bootstraps_color_by_group = True
561+
else:
562+
bootstraps_color_by_group = False
552563

553564
# Handle the color palette.
554565
filled = True
@@ -599,6 +610,17 @@ def get_color_palette(
599610
groups_in_palette = {
600611
k: custom_pal[k] for k in color_groups
601612
}
613+
elif proportional and not sankey: # barplots (unpaired proportional data)
614+
keys = list(custom_pal.keys())
615+
if all(k in keys for k in [1, 0]) and len(keys) == 2:
616+
groups_in_palette = {
617+
k: custom_pal[k] for k in [1, 0]
618+
}
619+
bootstraps_color_by_group = False
620+
else:
621+
groups_in_palette = {
622+
k: custom_pal[k] for k in all_plot_groups if k in color_groups
623+
}
602624
elif sankey:
603625
groups_in_palette = {
604626
k: custom_pal[k] for k in [1, 0]
@@ -915,7 +937,7 @@ def initialize_fig(
915937
raw_label = plot_kwargs["raw_label"]
916938
if raw_label is None:
917939
if proportional:
918-
raw_label = "Proportion of Success" if effect_size_type != "cohens_h" else "Value"
940+
raw_label = "Proportion of success" if effect_size_type != "cohens_h" else "Value"
919941
else:
920942
raw_label = yvar
921943

@@ -929,16 +951,16 @@ def initialize_fig(
929951

930952
# Set contrast axes y-label.
931953
contrast_label_dict = {
932-
"mean_diff": "Mean Difference",
933-
"median_diff": "Median Difference",
954+
"mean_diff": "Mean difference",
955+
"median_diff": "Median difference",
934956
"cohens_d": "Cohen's d",
935957
"hedges_g": "Hedges' g",
936958
"cliffs_delta": "Cliff's delta",
937959
"cohens_h": "Cohen's h",
938960
}
939961

940962
if proportional and effect_size_type != "cohens_h":
941-
default_contrast_label = "Proportion Difference"
963+
default_contrast_label = "Proportion difference"
942964
else:
943965
default_contrast_label = contrast_label_dict[effect_size_type]
944966

@@ -1856,13 +1878,15 @@ def color_picker(color_type: str,
18561878
elements: list,
18571879
color_col: str,
18581880
show_pairs: bool,
1859-
color_palette: dict) -> list:
1881+
color_palette: dict,
1882+
bootstraps_color_by_group: bool) -> list:
18601883
num_of_elements = len(elements)
18611884
colors = (
18621885
[kwargs.pop('color')] * num_of_elements
18631886
if kwargs.get('color', None) is not None
18641887
else ['black'] * num_of_elements
1865-
if color_col is not None or show_pairs
1888+
# if color_col is not None or show_pairs
1889+
if color_col is not None or not bootstraps_color_by_group
18661890
else list(color_palette.values())
18671891
)
18681892
if color_type in ['contrast', 'summary', 'delta_text']:
@@ -1877,7 +1901,7 @@ def color_picker(color_type: str,
18771901
return final_colors
18781902

18791903

1880-
def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, color_col, show_pairs,
1904+
def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, color_col, show_pairs, bootstraps_color_by_group,
18811905
plot_data = None, xvar = None, yvar = None, # Raw data
18821906
results = None, ticks_to_plot = None, extra_delta = None, # Contrast data
18831907
reference_band = None, summary_axes = None, ci_type = None # Summary data
@@ -1951,7 +1975,8 @@ def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, co
19511975
elements = ticks_to_plot if bar_type=='contrast' else ticks,
19521976
color_col = color_col,
19531977
show_pairs = show_pairs,
1954-
color_palette = plot_palette_raw
1978+
color_palette = plot_palette_raw,
1979+
bootstraps_color_by_group = bootstraps_color_by_group
19551980
)
19561981
if bar_type == 'contrast' and extra_delta is not None:
19571982
colors.append('black')

0 commit comments

Comments
 (0)