Skip to content

Commit c014093

Browse files
committed
Trimmed data load function in forest plot
1 parent 95d56c3 commit c014093

File tree

2 files changed

+78
-100
lines changed

2 files changed

+78
-100
lines changed

dabest/forest_plot.py

Lines changed: 39 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,44 @@ def load_plot_data(
4545
"""
4646
# Effect size and contrast types
4747
effect_attr = "hedges_g" if effect_size == 'delta_g' else effect_size
48+
contrast_attr = {"delta2": "delta_delta", "mini_meta": "mini_meta"}.get(contrast_type)
4849

49-
if contrast_type == "delta":
50-
if idx is not None:
51-
bootstraps, differences, bcalows, bcahighs = [], [], [], []
52-
for current_idx, index_group in enumerate(idx):
53-
current_contrast = data[current_idx]
54-
if len(index_group)>0:
55-
for index in index_group:
56-
current_plot_data = getattr(current_contrast, effect_attr)
57-
bootstraps.append(current_plot_data.results.bootstraps[index])
58-
differences.append(current_plot_data.results.difference[index])
59-
bcalows.append(current_plot_data.results.get(ci_type+'_low')[index])
60-
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[index])
61-
else:
50+
# Testing
51+
if idx is not None:
52+
bootstraps, differences, bcalows, bcahighs = [], [], [], []
53+
for current_idx, index_group in enumerate(idx):
54+
current_contrast = data[current_idx]
55+
if len(index_group)>0:
56+
for index in index_group:
57+
current_plot_data = getattr(current_contrast, effect_attr)
58+
if contrast_type == 'delta2':
59+
if index == 2:
60+
current_plot_data = getattr(current_plot_data, contrast_attr)
61+
bootstrap_name, index_val = "bootstraps_delta_delta", 0
62+
elif index == 0 or index == 1:
63+
bootstrap_name, index_val = "bootstraps", index
64+
else:
65+
raise ValueError("The selected indices must be 0, 1, or 2.")
66+
elif contrast_type == "mini_meta":
67+
num_of_groups = len(getattr(current_contrast, effect_attr).results)
68+
if index == num_of_groups:
69+
current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)
70+
bootstrap_name, index_val = "bootstraps_weighted_delta", 0
71+
elif index < num_of_groups:
72+
bootstrap_name, index_val = "bootstraps", index
73+
else:
74+
msg1 = "There are only {} groups (starting from zero) in this dabest object. ".format(num_of_groups)
75+
msg2 = "The idx given is {}.".format(index)
76+
raise ValueError(msg1+msg2)
77+
else: # contrast_type == 'delta'
78+
bootstrap_name, index_val = "bootstraps", index
79+
80+
bootstraps.append(getattr(current_plot_data.results, bootstrap_name)[index_val])
81+
differences.append(current_plot_data.results.difference[index_val])
82+
bcalows.append(current_plot_data.results.get(ci_type+'_low')[index_val])
83+
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[index_val])
84+
else:
85+
if contrast_type == 'delta':
6286
contrast_plot_data = [getattr(contrast, effect_attr) for contrast in data]
6387
bootstraps_nested = [result.results.bootstraps.to_list() for result in contrast_plot_data]
6488
differences_nested = [result.results.difference.to_list() for result in contrast_plot_data]
@@ -69,41 +93,8 @@ def load_plot_data(
6993
differences = [element for innerList in differences_nested for element in innerList]
7094
bcalows = [element for innerList in bcalows_nested for element in innerList]
7195
bcahighs = [element for innerList in bcahighs_nested for element in innerList]
72-
else:
73-
contrast_attr = {"delta2": "delta_delta", "mini_meta": "mini_meta"}.get(contrast_type)
74-
if idx is not None:
75-
bootstraps, differences, bcalows, bcahighs = [], [], [], []
76-
for current_idx, index_group in enumerate(idx):
77-
current_contrast = data[current_idx]
78-
if len(index_group)>0:
79-
for index in index_group:
80-
if contrast_type == 'delta2':
81-
if index == 2:
82-
current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)
83-
bootstrap_name, index_val = "bootstraps_delta_delta", 0
84-
elif index == 0 or index == 1:
85-
current_plot_data = getattr(current_contrast, effect_attr)
86-
bootstrap_name, index_val = "bootstraps", index
87-
else:
88-
raise ValueError("The selected indices must be 0, 1, or 2.")
89-
else:
90-
num_of_groups = len(getattr(current_contrast, effect_attr).results)
91-
if index == num_of_groups:
92-
current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)
93-
bootstrap_name, index_val = "bootstraps_weighted_delta", 0
94-
elif index < num_of_groups:
95-
current_plot_data = getattr(current_contrast, effect_attr)
96-
bootstrap_name, index_val = "bootstraps", index
97-
else:
98-
msg1 = "There are only {} groups (starting from zero) in this dabest object. ".format(num_of_groups)
99-
msg2 = "The idx given is {}.".format(index)
100-
raise ValueError(msg1+msg2)
101-
102-
bootstraps.append(getattr(current_plot_data.results, bootstrap_name)[index_val])
103-
differences.append(current_plot_data.results.difference[index_val])
104-
bcalows.append(current_plot_data.results.get(ci_type+'_low')[index_val])
105-
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[index_val])
106-
else:
96+
97+
else: # contrast_type == 'delta2' or 'mini_meta'
10798
contrast_plot_data = [getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in data]
10899
attribute_suffix = "weighted_delta" if contrast_type == "mini_meta" else "delta_delta"
109100

@@ -121,7 +112,6 @@ def check_for_errors(**kwargs):
121112
raise ValueError("The `data` argument must be a non-empty list of dabest objects.")
122113

123114
## Check if all contrasts are delta-delta or all are mini-meta
124-
125115
contrast_type = ("delta2" if data[0].delta2
126116
else "mini_meta" if data[0].is_mini_meta
127117
else "delta"
@@ -399,7 +389,6 @@ def get_kwargs(
399389
else:
400390
summary_bars_kwargs = merge_two_dicts(default_summary_bars_kwargs, summary_bars_kwargs)
401391

402-
403392
return (violin_kwargs, zeroline_kwargs, marker_kwargs, errorbar_kwargs,
404393
delta_text_kwargs, contrast_bars_kwargs, summary_bars_kwargs)
405394

nbs/API/forest_plot.ipynb

Lines changed: 39 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -105,20 +105,44 @@
105105
" \"\"\"\n",
106106
" # Effect size and contrast types\n",
107107
" effect_attr = \"hedges_g\" if effect_size == 'delta_g' else effect_size\n",
108+
" contrast_attr = {\"delta2\": \"delta_delta\", \"mini_meta\": \"mini_meta\"}.get(contrast_type)\n",
108109
"\n",
109-
" if contrast_type == \"delta\":\n",
110-
" if idx is not None:\n",
111-
" bootstraps, differences, bcalows, bcahighs = [], [], [], []\n",
112-
" for current_idx, index_group in enumerate(idx):\n",
113-
" current_contrast = data[current_idx]\n",
114-
" if len(index_group)>0:\n",
115-
" for index in index_group:\n",
116-
" current_plot_data = getattr(current_contrast, effect_attr)\n",
117-
" bootstraps.append(current_plot_data.results.bootstraps[index])\n",
118-
" differences.append(current_plot_data.results.difference[index])\n",
119-
" bcalows.append(current_plot_data.results.get(ci_type+'_low')[index])\n",
120-
" bcahighs.append(current_plot_data.results.get(ci_type+'_high')[index])\n",
121-
" else:\n",
110+
" # Testing\n",
111+
" if idx is not None:\n",
112+
" bootstraps, differences, bcalows, bcahighs = [], [], [], []\n",
113+
" for current_idx, index_group in enumerate(idx):\n",
114+
" current_contrast = data[current_idx]\n",
115+
" if len(index_group)>0:\n",
116+
" for index in index_group:\n",
117+
" current_plot_data = getattr(current_contrast, effect_attr)\n",
118+
" if contrast_type == 'delta2':\n",
119+
" if index == 2:\n",
120+
" current_plot_data = getattr(current_plot_data, contrast_attr)\n",
121+
" bootstrap_name, index_val = \"bootstraps_delta_delta\", 0\n",
122+
" elif index == 0 or index == 1:\n",
123+
" bootstrap_name, index_val = \"bootstraps\", index\n",
124+
" else:\n",
125+
" raise ValueError(\"The selected indices must be 0, 1, or 2.\")\n",
126+
" elif contrast_type == \"mini_meta\":\n",
127+
" num_of_groups = len(getattr(current_contrast, effect_attr).results)\n",
128+
" if index == num_of_groups:\n",
129+
" current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)\n",
130+
" bootstrap_name, index_val = \"bootstraps_weighted_delta\", 0\n",
131+
" elif index < num_of_groups:\n",
132+
" bootstrap_name, index_val = \"bootstraps\", index\n",
133+
" else:\n",
134+
" msg1 = \"There are only {} groups (starting from zero) in this dabest object. \".format(num_of_groups)\n",
135+
" msg2 = \"The idx given is {}.\".format(index)\n",
136+
" raise ValueError(msg1+msg2) \n",
137+
" else: # contrast_type == 'delta'\n",
138+
" bootstrap_name, index_val = \"bootstraps\", index \n",
139+
"\n",
140+
" bootstraps.append(getattr(current_plot_data.results, bootstrap_name)[index_val])\n",
141+
" differences.append(current_plot_data.results.difference[index_val])\n",
142+
" bcalows.append(current_plot_data.results.get(ci_type+'_low')[index_val])\n",
143+
" bcahighs.append(current_plot_data.results.get(ci_type+'_high')[index_val]) \n",
144+
" else:\n",
145+
" if contrast_type == 'delta':\n",
122146
" contrast_plot_data = [getattr(contrast, effect_attr) for contrast in data]\n",
123147
" bootstraps_nested = [result.results.bootstraps.to_list() for result in contrast_plot_data]\n",
124148
" differences_nested = [result.results.difference.to_list() for result in contrast_plot_data]\n",
@@ -129,41 +153,8 @@
129153
" differences = [element for innerList in differences_nested for element in innerList]\n",
130154
" bcalows = [element for innerList in bcalows_nested for element in innerList]\n",
131155
" bcahighs = [element for innerList in bcahighs_nested for element in innerList]\n",
132-
" else:\n",
133-
" contrast_attr = {\"delta2\": \"delta_delta\", \"mini_meta\": \"mini_meta\"}.get(contrast_type)\n",
134-
" if idx is not None:\n",
135-
" bootstraps, differences, bcalows, bcahighs = [], [], [], []\n",
136-
" for current_idx, index_group in enumerate(idx):\n",
137-
" current_contrast = data[current_idx]\n",
138-
" if len(index_group)>0:\n",
139-
" for index in index_group:\n",
140-
" if contrast_type == 'delta2':\n",
141-
" if index == 2:\n",
142-
" current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)\n",
143-
" bootstrap_name, index_val = \"bootstraps_delta_delta\", 0\n",
144-
" elif index == 0 or index == 1:\n",
145-
" current_plot_data = getattr(current_contrast, effect_attr)\n",
146-
" bootstrap_name, index_val = \"bootstraps\", index\n",
147-
" else:\n",
148-
" raise ValueError(\"The selected indices must be 0, 1, or 2.\")\n",
149-
" else:\n",
150-
" num_of_groups = len(getattr(current_contrast, effect_attr).results)\n",
151-
" if index == num_of_groups:\n",
152-
" current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)\n",
153-
" bootstrap_name, index_val = \"bootstraps_weighted_delta\", 0\n",
154-
" elif index < num_of_groups:\n",
155-
" current_plot_data = getattr(current_contrast, effect_attr)\n",
156-
" bootstrap_name, index_val = \"bootstraps\", index\n",
157-
" else:\n",
158-
" msg1 = \"There are only {} groups (starting from zero) in this dabest object. \".format(num_of_groups)\n",
159-
" msg2 = \"The idx given is {}.\".format(index)\n",
160-
" raise ValueError(msg1+msg2)\n",
161-
"\n",
162-
" bootstraps.append(getattr(current_plot_data.results, bootstrap_name)[index_val])\n",
163-
" differences.append(current_plot_data.results.difference[index_val])\n",
164-
" bcalows.append(current_plot_data.results.get(ci_type+'_low')[index_val])\n",
165-
" bcahighs.append(current_plot_data.results.get(ci_type+'_high')[index_val]) \n",
166-
" else:\n",
156+
"\n",
157+
" else: # contrast_type == 'delta2' or 'mini_meta'\n",
167158
" contrast_plot_data = [getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in data]\n",
168159
" attribute_suffix = \"weighted_delta\" if contrast_type == \"mini_meta\" else \"delta_delta\"\n",
169160
"\n",
@@ -181,7 +172,6 @@
181172
" raise ValueError(\"The `data` argument must be a non-empty list of dabest objects.\")\n",
182173
" \n",
183174
" ## Check if all contrasts are delta-delta or all are mini-meta\n",
184-
"\n",
185175
" contrast_type = (\"delta2\" if data[0].delta2 \n",
186176
" else \"mini_meta\" if data[0].is_mini_meta\n",
187177
" else \"delta\"\n",
@@ -459,7 +449,6 @@
459449
" else:\n",
460450
" summary_bars_kwargs = merge_two_dicts(default_summary_bars_kwargs, summary_bars_kwargs)\n",
461451
"\n",
462-
"\n",
463452
" return (violin_kwargs, zeroline_kwargs, marker_kwargs, errorbar_kwargs, \n",
464453
" delta_text_kwargs, contrast_bars_kwargs, summary_bars_kwargs)\n",
465454
"\n",

0 commit comments

Comments
 (0)