@@ -203,7 +203,8 @@ def get_params(
203203
204204def get_kwargs (
205205 plot_kwargs : dict ,
206- ytick_color
206+ ytick_color ,
207+ is_paired : bool = False
207208 ):
208209 """
209210 Extracts the kwargs from the `plot_kwargs` object for use in the plotter function.
@@ -214,6 +215,8 @@ def get_kwargs(
214215 Kwargs passed to the plot function.
215216 ytick_color : str or color list
216217 Color of the yticks.
218+ is_paired : bool, optional
219+ A boolean flag to determine if the plot is for paired data. Default is False.
217220 """
218221 from .misc_tools import merge_two_dicts
219222
@@ -334,7 +337,7 @@ def get_kwargs(
334337 default_group_summaries_kwargs = {
335338 "zorder" : 3 ,
336339 "lw" : 2 ,
337- "alpha" : 1 ,
340+ "alpha" : 1 if not is_paired else 0.6 ,
338341 'gap_width_percent' : 1.5 ,
339342 'offset' : 0.1 ,
340343 'color' : None
@@ -513,7 +516,7 @@ def get_color_palette(
513516 idx : list ,
514517 all_plot_groups : list ,
515518 delta2 : bool ,
516- sankey : bool
519+ proportional : bool
517520 ):
518521 """
519522 Create the color palette to be used in the plotter function.
@@ -534,9 +537,11 @@ def get_color_palette(
534537 A list of all the group names.
535538 delta2 : bool
536539 A boolean flag to determine if the plot will have a delta-delta effect size.
537- sankey : bool
538- A boolean flag to determine if the plot is for a Sankey diagram .
540+ proportional : bool
541+ A boolean flag to determine if the plot is for a proportional plot .
539542 """
543+ sankey = True if proportional and show_pairs else False
544+
540545 # Create color palette that will be shared across subplots.
541546 color_col = plot_kwargs ["color_col" ]
542547 if color_col is None :
@@ -548,7 +553,13 @@ def get_color_palette(
548553 color_groups = pd .unique (plot_data [color_col ])
549554 bootstraps_color_by_group = False
550555 if show_pairs :
551- bootstraps_color_by_group = False
556+ if plot_kwargs ["custom_palette" ] is not None :
557+ if delta2 or sankey :
558+ bootstraps_color_by_group = False
559+ else :
560+ bootstraps_color_by_group = True
561+ else :
562+ bootstraps_color_by_group = False
552563
553564 # Handle the color palette.
554565 filled = True
@@ -599,6 +610,17 @@ def get_color_palette(
599610 groups_in_palette = {
600611 k : custom_pal [k ] for k in color_groups
601612 }
613+ elif proportional and not sankey : # barplots (unpaired proportional data)
614+ keys = list (custom_pal .keys ())
615+ if all (k in keys for k in [1 , 0 ]) and len (keys ) == 2 :
616+ groups_in_palette = {
617+ k : custom_pal [k ] for k in [1 , 0 ]
618+ }
619+ bootstraps_color_by_group = False
620+ else :
621+ groups_in_palette = {
622+ k : custom_pal [k ] for k in all_plot_groups if k in color_groups
623+ }
602624 elif sankey :
603625 groups_in_palette = {
604626 k : custom_pal [k ] for k in [1 , 0 ]
@@ -915,7 +937,7 @@ def initialize_fig(
915937 raw_label = plot_kwargs ["raw_label" ]
916938 if raw_label is None :
917939 if proportional :
918- raw_label = "Proportion of Success " if effect_size_type != "cohens_h" else "Value"
940+ raw_label = "Proportion of success " if effect_size_type != "cohens_h" else "Value"
919941 else :
920942 raw_label = yvar
921943
@@ -929,16 +951,16 @@ def initialize_fig(
929951
930952 # Set contrast axes y-label.
931953 contrast_label_dict = {
932- "mean_diff" : "Mean Difference " ,
933- "median_diff" : "Median Difference " ,
954+ "mean_diff" : "Mean difference " ,
955+ "median_diff" : "Median difference " ,
934956 "cohens_d" : "Cohen's d" ,
935957 "hedges_g" : "Hedges' g" ,
936958 "cliffs_delta" : "Cliff's delta" ,
937959 "cohens_h" : "Cohen's h" ,
938960 }
939961
940962 if proportional and effect_size_type != "cohens_h" :
941- default_contrast_label = "Proportion Difference "
963+ default_contrast_label = "Proportion difference "
942964 else :
943965 default_contrast_label = contrast_label_dict [effect_size_type ]
944966
@@ -1856,13 +1878,15 @@ def color_picker(color_type: str,
18561878 elements : list ,
18571879 color_col : str ,
18581880 show_pairs : bool ,
1859- color_palette : dict ) -> list :
1881+ color_palette : dict ,
1882+ bootstraps_color_by_group : bool ) -> list :
18601883 num_of_elements = len (elements )
18611884 colors = (
18621885 [kwargs .pop ('color' )] * num_of_elements
18631886 if kwargs .get ('color' , None ) is not None
18641887 else ['black' ] * num_of_elements
1865- if color_col is not None or show_pairs
1888+ # if color_col is not None or show_pairs
1889+ if color_col is not None or not bootstraps_color_by_group
18661890 else list (color_palette .values ())
18671891 )
18681892 if color_type in ['contrast' , 'summary' , 'delta_text' ]:
@@ -1877,7 +1901,7 @@ def color_picker(color_type: str,
18771901 return final_colors
18781902
18791903
1880- def prepare_bars_for_plot (bar_type , bar_kwargs , horizontal , plot_palette_raw , color_col , show_pairs ,
1904+ def prepare_bars_for_plot (bar_type , bar_kwargs , horizontal , plot_palette_raw , color_col , show_pairs , bootstraps_color_by_group ,
18811905 plot_data = None , xvar = None , yvar = None , # Raw data
18821906 results = None , ticks_to_plot = None , extra_delta = None , # Contrast data
18831907 reference_band = None , summary_axes = None , ci_type = None # Summary data
@@ -1951,7 +1975,8 @@ def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, co
19511975 elements = ticks_to_plot if bar_type == 'contrast' else ticks ,
19521976 color_col = color_col ,
19531977 show_pairs = show_pairs ,
1954- color_palette = plot_palette_raw
1978+ color_palette = plot_palette_raw ,
1979+ bootstraps_color_by_group = bootstraps_color_by_group
19551980 )
19561981 if bar_type == 'contrast' and extra_delta is not None :
19571982 colors .append ('black' )
0 commit comments