-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot.py
1682 lines (1383 loc) · 55.1 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Wrapper functions with boilerplate code for making plots the way I like them
"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
from builtins import zip
from builtins import map
from builtins import range
from past.utils import old_div
import matplotlib
import matplotlib.patheffects as pe
import numpy as np, warnings
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
import scipy.stats
from . import misc
import my
import pandas
import itertools
import collections
def alpha_blend_with_mask(rgb0, rgb1, alpha0, mask0):
    """Alpha-blend two RGB images, masking out one image.

    rgb0 : first image, to be masked
        Must be 3-dimensional, and rgb0.shape[-1] must be 3 or 4.
        If rgb0.shape[-1] == 4, the 4th channel will be dropped.
    rgb1 : second image, will not be masked
        Must be 3-dimensional, and rgb1.shape[-1] must be 3 or 4.
        If rgb1.shape[-1] == 4, the 4th channel will be dropped.
        Then, must have same shape as rgb0.
    alpha0 : the alpha to apply to rgb0. (1 - alpha0) will be applied to
        rgb1.
    mask0 : True where to ignore rgb0
        Must have dimension 2 or 3.
        If 2-dimensional, will be replicated along the channel dimension.
        If 3-dimensional with 4 channels, the 4th channel will be dropped.
        Then, must have same shape as rgb0.
        Non-boolean masks (e.g. 0/1 integers) are interpreted by truthiness.

    Returns : array of same shape as rgb0 and rgb1 (after dropping alpha)
        Where mask0 is True, the result is the same as rgb1
        Where mask0 is False, the result is rgb0 * alpha0 + rgb1 * (1 - alpha0)
    """
    rgb0 = np.asarray(rgb0)
    rgb1 = np.asarray(rgb1)

    # Coerce to boolean: an integer 0/1 mask would otherwise be used as
    # positional (fancy) indices instead of as a mask
    mask0 = np.asarray(mask0, dtype=bool)

    # Replicate mask along color channel if necessary
    if mask0.ndim == 2:
        mask0 = np.stack([mask0] * 3, axis=-1)

    # Check 3-dimensional
    assert mask0.ndim == 3
    assert rgb0.ndim == 3
    assert rgb1.ndim == 3

    # Drop alpha if present
    if rgb0.shape[-1] == 4:
        rgb0 = rgb0[:, :, :3]
    if rgb1.shape[-1] == 4:
        rgb1 = rgb1[:, :, :3]
    if mask0.shape[-1] == 4:
        mask0 = mask0[:, :, :3]

    # Error check
    assert rgb0.shape == rgb1.shape
    assert mask0.shape == rgb0.shape

    # Blend
    blended = alpha0 * rgb0 + (1 - alpha0) * rgb1

    # Where masked, take rgb1 unchanged; keep the dtype of the blend
    return np.where(mask0, rgb1, blended).astype(blended.dtype, copy=False)
def custom_RdBu_r():
    """Return a custom RdBu_r colormap named 'myrdbu' with pure white at center.

    The control points are matplotlib's RdBu points (from
    lib/matplotlib/_cm.py), except that the middle point
    (0.9686..., 0.9686..., 0.9686...) is replaced by (1, 1, 1). The points
    are reversed (blue end first), and the map is built with
    LinearSegmentedColormap.from_list using rcParams['image.lut'] levels.
    """
    # RdBu control points, red end first, as in matplotlib's _cm.py
    red_to_blue = [
        (0.40392156862745099, 0.0, 0.12156862745098039),
        (0.69803921568627447, 0.09411764705882353, 0.16862745098039217),
        (0.83921568627450982, 0.37647058823529411, 0.30196078431372547),
        (0.95686274509803926, 0.6470588235294118, 0.50980392156862742),
        (0.99215686274509807, 0.85882352941176465, 0.7803921568627451),
        # True white instead of matplotlib's light gray
        (1, 1, 1),
        (0.81960784313725488, 0.89803921568627454, 0.94117647058823528),
        (0.5725490196078431, 0.77254901960784317, 0.87058823529411766),
        (0.2627450980392157, 0.57647058823529407, 0.76470588235294112),
        (0.12941176470588237, 0.4, 0.67450980392156867),
        (0.0196078431372549, 0.18823529411764706, 0.38039215686274508),
    ]

    # "_r" means reversed: build the map from the blue end
    blue_to_red = list(reversed(red_to_blue))
    n_levels = matplotlib.rcParams['image.lut']
    return matplotlib.colors.LinearSegmentedColormap.from_list(
        'myrdbu', blue_to_red, n_levels)
def smooth_and_plot_versus_depth(
        data,
        colname,
        ax=None,
        NS_sigma=40,
        RS_sigma=20,
        n_depth_bins=101,
        depth_min=0,
        depth_max=1600,
        datapoint_plot_kwargs=None,
        smoothed_plot_kwargs=None,
        plot_layer_boundaries=True,
        layer_boundaries_ylim=None,
        ):
    """Plot individual datapoints and smoothed versus depth.

    data : DataFrame
        Must have columns "Z_corrected", "NS", and `colname`.
        "Z_corrected" becomes the x-coordinate and `colname` the
        y-coordinate. Rows are grouped by "NS": truthy NS groups are
        plotted in blue, the others in red.
    colname : string
        Name of column containing data
    ax : Axis, or None
        if None, creates ax
    NS_sigma, RS_sigma : float
        The standard deviation of the smoothing kernel applied to the
        NS and RS (non-NS) groups, respectively
    depth_min, depth_max, n_depth_bins : float, float, int
        The x-coordinates at which the smoothed results are evaluated:
        np.linspace(depth_min, depth_max, n_depth_bins)
    datapoint_plot_kwargs : dict
        Plot kwargs for individual data points, updating the defaults:
        'marker': 'o', 'ls': 'none', 'ms': 1, 'mew': 1, 'alpha': .3,
        'mfc': 'none'
    smoothed_plot_kwargs : dict
        Plot kwargs for smoothed line, updating the defaults:
        'lw': 1.5, 'path_effects': a black stroke of width 3 under the line
    plot_layer_boundaries: bool
        If True, plot layer boundaries and layer names
    layer_boundaries_ylim : tuple of length 2, or None
        If not None, layer boundaries are plotted to these ylim
        If None, ax.get_ylim() is used after plotting everything else
        In either case the axis ylim is then set to this value, and a
        warning is printed if any datapoint falls outside it.

    Returns: ax
    """
    ## Set up defaults
    # Bins at which to evaluate smoothed
    depth_bins = np.linspace(depth_min, depth_max, n_depth_bins)

    # datapoint_plot_kwargs
    default_datapoint_plot_kwargs = {
        'marker': 'o', 'ls': 'none', 'ms': 1, 'mew': 1,
        'alpha': .3, 'mfc': 'none',
    }
    if datapoint_plot_kwargs is not None:
        default_datapoint_plot_kwargs.update(datapoint_plot_kwargs)
    use_datapoint_plot_kwargs = default_datapoint_plot_kwargs

    # smoothed_plot_kwargs
    path_effects = [pe.Stroke(linewidth=3, foreground='k'), pe.Normal()]
    default_smoothed_plot_kwargs = {
        'lw': 1.5,
        'path_effects': path_effects,
    }
    if smoothed_plot_kwargs is not None:
        default_smoothed_plot_kwargs.update(smoothed_plot_kwargs)
    use_smoothed_plot_kwargs = default_smoothed_plot_kwargs

    ## Plot versus depth
    if ax is None:
        f, ax = plt.subplots()

    # Iterate over NS
    for NS, sub_data in data.groupby('NS'):
        if NS:
            color = 'b'
            sigma = NS_sigma
        else:
            color = 'r'
            sigma = RS_sigma

        # Get the data to smooth
        to_smooth = sub_data.set_index('Z_corrected')[colname]

        # Smooth
        smoothed = my.misc.gaussian_sum_smooth_pandas(
            to_smooth, depth_bins, sigma=sigma)

        # Plot the individual data points
        ax.plot(
            to_smooth.index,
            to_smooth.values,
            color=color,
            zorder=0,
            **use_datapoint_plot_kwargs
        )

        # Plot the smoothed
        ax.plot(
            smoothed, color=color,
            **use_smoothed_plot_kwargs)

    ## Pretty
    my.plot.despine(ax)
    ax.set_xticks((0, 500, 1000, 1500))
    ax.set_xlim((0, 1500))
    ax.set_xticklabels(('0.0', '0.5', '1.0', '1.5'))
    ax.set_xlabel('depth in cortex (mm)')

    ## Add layer boundaries
    if plot_layer_boundaries:
        # ylim for the boundaries
        if layer_boundaries_ylim is None:
            layer_boundaries_ylim = ax.get_ylim()

        # Layer boundaries
        layer_boundaries = [128, 419, 626, 1006, 1366]
        layer_names = ['L1', 'L2/3', 'L4', 'L5', 'L6', 'L6b']

        # Centers of layers (for naming)
        # The builtin `float` replaces `np.float`, removed in NumPy 1.24
        layer_depth_bins = np.concatenate(
            [[-50], layer_boundaries, [1500]]).astype(float)
        layer_centers = (layer_depth_bins[:-1] + layer_depth_bins[1:]) / 2.0

        # Nudge the label positions of L2/3, L4, L5 and L6 slightly
        layer_centers[1] = layer_centers[1] - 50
        layer_centers[2] = layer_centers[2] + 10
        layer_centers[3] = layer_centers[3] + 25
        layer_centers[-2] = layer_centers[-2] + 50

        # Plot each (but not top of L1 or bottom of L6)
        for lb in layer_boundaries[1:-1]:
            ax.plot(
                [lb, lb], layer_boundaries_ylim,
                color='gray', lw=.8, zorder=-1)

        # Set the boundaries tight
        ax.set_ylim(layer_boundaries_ylim)

        # Warn
        if data[colname].max() > layer_boundaries_ylim[1]:
            print(
                "warning: max datapoint {} ".format(data[colname].max()) +
                "greater than layer_boundaries_ylim[1]")
        if data[colname].min() < layer_boundaries_ylim[0]:
            print(
                "warning: min datapoint {} ".format(data[colname].min()) +
                "less than layer_boundaries_ylim[0]")

        # Label the layer names
        # x in data, y in figure
        blended_transform = matplotlib.transforms.blended_transform_factory(
            ax.transData, ax.figure.transFigure)

        # Name each (but not L1 or L6b)
        zobj = zip(layer_names[1:-1], layer_centers[1:-1])
        for layer_name, layer_center in zobj:
            ax.text(
                layer_center, .98, layer_name,
                ha='center', va='center', size=12, transform=blended_transform)

    ## Return ax
    return ax
def plot_by_depth_and_layer(df, column, combine_layer_5=True, aggregate='median',
        ax=None, ylim=None, agg_plot_kwargs=None, point_alpha=.5, point_ms=3,
        layer_label_offset=-.1, agg_plot_meth='rectangle'):
    """Plot values by depth and layer

    df : DataFrame
        Should have columns 'Z_corrected', 'layer', 'NS', and `column`.
        It is not modified (a copy is made).
    column : name of column in `df` to plot
    combine_layer_5 : whether to combine 5a and 5b into '5'
    aggregate : None, 'mean', or 'median'
        How to aggregate `column` within each layer, separately for NS
        (blue) and RS (red). Layers with 3 or fewer datapoints are not
        aggregated (set to NaN and not drawn).
    ax : where to plot. If None, a new figure is created.
    ylim : desired ylim (affects layer name position). If None, it is
        taken from the axis after plotting the raw datapoints.
    agg_plot_kwargs : how to plot aggregated when agg_plot_meth is
        'markers'. Updates the defaults
        {'marker': '_', 'ls': 'none', 'ms': 16, 'mew': 4, 'alpha': .5}
    point_alpha, point_ms : alpha and marker size of the raw datapoints
    layer_label_offset : vertical position of the layer names, as a
        fraction of (ylim[1] - ylim[0]) added to ylim[1]
    agg_plot_meth : 'markers' or 'rectangle'
        'markers' plots each aggregate at its layer center;
        'rectangle' draws a short filled bar at the aggregate value for
        layers 2/3, 4, 5, and 6.

    Returns : ax
    """
    # Set agg_plot_kwargs
    default_agg_plot_kwargs = {'marker': '_', 'ls': 'none', 'ms': 16,
        'mew': 4, 'alpha': .5}
    if agg_plot_kwargs is not None:
        default_agg_plot_kwargs.update(agg_plot_kwargs)
    agg_plot_kwargs = default_agg_plot_kwargs

    # Layer boundaries
    # The builtin `float` replaces `np.float`, removed in NumPy 1.24
    layer_boundaries = [128, 419, 626, 1006, 1366]
    layer_names = ['L1', 'L2/3', 'L4', 'L5', 'L6', 'L6b']
    layer_depth_bins = np.concatenate(
        [[-50], layer_boundaries, [1500]]).astype(float)
    layer_centers = (layer_depth_bins[:-1] + layer_depth_bins[1:]) / 2.0

    # Make a copy
    df = df.copy()

    # Optionally combine layers 5a and 5b
    if combine_layer_5:
        df['layer'] = df['layer'].astype(str)
        df.loc[df['layer'].isin(['5a', '5b']), 'layer'] = '5'

    # Optionally create figure
    if ax is None:
        f, ax = plt.subplots(figsize=(4.5, 3.5))

    # Plot datapoints for NS and RS separately
    for NS, sub_df in df.groupby('NS'):
        # Color by NS
        color = 'b' if NS else 'r'

        # Plot raw data
        ax.plot(
            sub_df.loc[:, 'Z_corrected'].values,
            sub_df.loc[:, column].values,
            color=color, marker='o', mfc='white',
            ls='none', alpha=point_alpha, ms=point_ms, clip_on=False,
        )

    # Keep track of this
    if ylim is None:
        ylim = ax.get_ylim()

    # Plot aggregates of NS and RS separately
    if aggregate is not None:
        for NS, sub_df in df.groupby('NS'):
            # Color by NS
            color = 'b' if NS else 'r'

            # Aggregate over bins
            gobj = sub_df.groupby('layer')[column]
            counts_by_bin = gobj.size()

            # Aggregate
            if aggregate == 'mean':
                agg_by_bin = gobj.mean()
            elif aggregate == 'median':
                agg_by_bin = gobj.median()
            else:
                raise ValueError(
                    "unrecognized aggregated method: {}".format(aggregate))

            # Block out aggregates with too few data points
            agg_by_bin[counts_by_bin <= 3] = np.nan

            # Reindex to ensure this matches layer_centers
            # TODO: Make this match the way it was aggregated
            agg_by_bin = agg_by_bin.reindex(['1', '2/3', '4', '5', '6', '6b'])
            assert len(agg_by_bin) == len(layer_centers)

            if agg_plot_meth == 'markers':
                # Plot aggregates as individual markers
                ax.plot(
                    layer_centers,
                    agg_by_bin.values,
                    color=color,
                    **agg_plot_kwargs
                )

            elif agg_plot_meth == 'rectangle':
                # Plot aggregates as a rectangle
                for n_layer, layer in enumerate(['2/3', '4', '5', '6']):
                    lo_depth = layer_depth_bins[n_layer + 1]
                    hi_depth = layer_depth_bins[n_layer + 2]
                    value = agg_by_bin.loc[layer]

                    # Nothing to draw for blocked-out / missing layers
                    if pandas.isnull(value):
                        continue

                    # zorder brings the patch on top of the datapoints
                    # The height is a scalar 3% of the ylim range
                    patch = plt.Rectangle(
                        (lo_depth + .1 * (hi_depth - lo_depth), value),
                        width=((hi_depth - lo_depth) * .8),
                        height=(.03 * (ylim[1] - ylim[0])),
                        ec='k', fc=color, alpha=.5, lw=1.5, zorder=20)
                    ax.add_patch(patch)

    # Plot layer boundaries, skipping L1 and L6b
    for lb in layer_boundaries[1:-1]:
        ax.plot([lb, lb], [ylim[0], ylim[1]], color='gray', ls='-', lw=1)

    # Name the layers (without the leading "L")
    text_ypos = ylim[1] + layer_label_offset * (ylim[1] - ylim[0])
    for layer_name, layer_center in zip(layer_names, layer_centers):
        if layer_name in ['L1', 'L6b']:
            continue
        ax.text(layer_center, text_ypos, layer_name[1:], ha='center',
            va='bottom', color='k')

    # Reset the ylim
    ax.set_ylim(ylim)

    # xticks
    ax.set_xticks((200, 600, 1000, 1400))
    ax.set_xticklabels([])
    ax.set_xlim((100, 1500))
    my.plot.despine(ax)
    ax.set_xlabel('depth in cortex')

    return ax
def connected_pairs(v1, v2, p=None, signif=None, shapes=None, colors=None,
        labels=None, ax=None):
    """Plot columns of (v1, v2) as connected pairs

    v1, v2 : 2d array-like of the same shape (n_rows, n_columns)
        Column n is plotted at x = 2n (v1) and x = 2n + 1 (v2); each row
        is drawn as a line from v1[row, n] to v2[row, n].
    p : unused; kept for backward compatibility
    signif : array-like with the same shape as v1, or None
        Where truthy, the corresponding line is drawn thicker.
    shapes, colors : marker and color for each row. Defaults: 'o' and 'k'.
    labels : xtick label for each column. Default: empty strings.
    ax : where to plot. If None, a new figure is created.

    For each column, the NaN-ignoring medians of v1 and v2 are drawn as a
    thick black line, and '*' is drawn at y=1.0 if my.stats.r_utest
    (called with paired='TRUE') gives p < 0.05.

    Returns : ax, xvals (the x-coordinates of every plotted column)
    """
    import my.stats

    if ax is None:
        f, ax = plt.subplots()

    # Arrayify
    v1 = np.asarray(v1)
    v2 = np.asarray(v2)
    if signif is None:
        signif = np.zeros_like(v1)
    else:
        signif = np.asarray(signif)

    # Defaults
    if shapes is None:
        shapes = ['o'] * v1.shape[0]
    if colors is None:
        colors = ['k'] * v1.shape[0]
    if labels is None:
        # One (empty) label per column. The previous `['' * n]` gave a
        # single-element list, which made the zip below stop after the
        # first column.
        labels = [''] * v1.shape[1]

    # Store location of each pair
    xvals = []
    xvalcenters = []

    # Iterate over columns
    for n, (col1, col2, signifcol, label) in enumerate(
            zip(v1.T, v2.T, signif.T, labels)):
        # Where to plot this pair
        x1 = n * 2
        x2 = n * 2 + 1
        xvals += [x1, x2]
        xvalcenters.append(np.mean([x1, x2]))

        # Iterate over specific pairs
        for val1, val2, sigval, shape, color in zip(
                col1, col2, signifcol, shapes, colors):
            lw = 2 if sigval else 0.5
            ax.plot([x1, x2], [val1, val2], marker=shape, color=color,
                ls='-', mec=color, mfc='none', lw=lw)

        # Plot the median in black (not in the color of the last row)
        median1 = np.median(col1[~np.isnan(col1)])
        median2 = np.median(col2[~np.isnan(col2)])
        ax.plot([x1, x2], [median1, median2], marker='o', color='k', ls='-',
            mec='k', mfc='none', lw=4)

        # Sigtest on pop
        utest_res = my.stats.r_utest(col1[~np.isnan(col1)], col2[~np.isnan(col2)],
            paired='TRUE', fix_float=1e6)
        if utest_res['p'] < 0.05:
            ax.text(np.mean([x1, x2]), 1.0, '*', va='top', ha='center')

    # Label center of each pair
    ax.set_xlim([xvals[0] - 1, xvals[-1] + 1])
    if labels:
        ax.set_xticks(xvalcenters)
        ax.set_xticklabels(labels)

    return ax, xvals
def radar_by_stim(evoked_resp, ax=None, label_stim=True):
    """Given a df of spikes by stim, plot radar

    evoked_resp should have arrays of counts indexed by all the stimulus
    names in LBPB.mixed_stimnames. Each name is a stimulus ('ri_hi',
    'le_hi', 'le_lo', 'ri_lo') plus a block suffix ('_lc' or '_pc').

    The mean of each stimulus is plotted at 45, 135, 225, and 315 degrees
    (ri_hi, le_hi, le_lo, ri_lo), with +/- SEM (misc.sem) shaded. The '_lc'
    block is plotted in blue and the '_pc' block in red.

    ax : polar axis to plot in. If None, a new 3x3 polar figure is made.
    label_stim : if True, write the stimulus name beside each direction

    Returns : ax
    """
    from ns5_process import LBPB

    if ax is None:
        f, ax = plt.subplots(figsize=(3, 3), subplot_kw={'polar': True})

    # Heights of the bars
    # `.loc` replaces the `.ix` indexer removed in pandas 1.0
    evoked_resp = evoked_resp.loc[LBPB.mixed_stimnames]
    barmeans = evoked_resp.apply(np.mean)
    barstderrs = evoked_resp.apply(misc.sem)

    # Set up the radar
    radar_dists = [[barmeans[sname + block]
        for sname in ['ri_hi', 'le_hi', 'le_lo', 'ri_lo']]
        for block in ['_lc', '_pc']]

    # make it circular by repeating the first point at the end
    circle_meansLB = np.array(radar_dists[0] + [radar_dists[0][0]])
    circle_meansPB = np.array(radar_dists[1] + [radar_dists[1][0]])
    circle_errsLB = np.array([barstderrs[sname + '_lc'] for sname in
        ['ri_hi', 'le_hi', 'le_lo', 'ri_lo', 'ri_hi']])
    circle_errsPB = np.array([barstderrs[sname + '_pc'] for sname in
        ['ri_hi', 'le_hi', 'le_lo', 'ri_lo', 'ri_hi']])

    # x-values (really theta values)
    xts = np.array([45, 135, 225, 315, 405]) * np.pi / 180.0

    # Plot LB means and errs
    ax.plot(xts, circle_meansLB, color='b')
    ax.fill_between(x=xts, y1=circle_meansLB - circle_errsLB,
        y2=circle_meansLB + circle_errsLB, color='b', alpha=.5)

    # Plot PB means and errs
    ax.plot(xts, circle_meansPB, color='r')
    ax.fill_between(x=xts, y1=circle_meansPB - circle_errsPB,
        y2=circle_meansPB + circle_errsPB, color='r', alpha=.5)

    # Tick labels
    xtls = ['right\nhigh', 'left\nhigh', 'left\nlow', 'right\nlow']
    ax.set_xticks(xts)
    ax.set_xticklabels([])  # if xtls, will overlap
    ax.set_yticks([])

    # manual tick labels, placed just outside the outer radius
    if label_stim:
        for xt, xtl in zip(xts, xtls):
            ax.text(xt, ax.get_ylim()[1] * 1.25, xtl, size='large',
                ha='center', va='center')

    return ax
def despine(ax, detick=True, which_ticks='both', which=('right', 'top')):
    """Hide the named spines of `ax` (by default the top and right ones).

    ax : axis to modify
    detick : if True, also turn off the ticks on every hidden side
    which_ticks : 'major', 'minor', or 'both'; forwarded to ax.tick_params
    which : iterable of spine names to hide

    Returns : ax
    """
    for side in which:
        ax.spines[side].set_visible(False)
        if not detick:
            continue
        tick_switch = {side: False}
        ax.tick_params(which=which_ticks, **tick_switch)
    return ax
def font_embed():
    """Set rcParams so saved figures import usefully into Illustrator (AI).

    - ps.useafm = True (PDF imports; effect not fully understood)
    - pdf.fonttype = 42, so that text stays editable
    - svg.fonttype = 'none', which worked better than 'svgfont' (with
      'svgfont', AI can edit the text but cannot import the font itself)
    - font.family = 'Arial', set explicitly because the default sans-serif
      was read as Arial by AI but as something else by Inkscape
    """
    settings = (
        ('ps.useafm', True),
        ('pdf.fonttype', 42),
        ('svg.fonttype', 'none'),
        ('font.family', 'Arial'),
    )
    for key, value in settings:
        matplotlib.rcParams[key] = value
def manuscript_defaults():
    """Set rcParams for figures that go into a word document.

    A typical figure is built from approx 3"x3" panels and then scaled by
    50%. Sets the sans-serif font to Arial and every text size (axis
    labels, titles, tick labels, ax.text objects, legend) to 14pt.
    """
    matplotlib.rcParams['font.sans-serif'] = 'Arial'
    size_keys = (
        'axes.labelsize',
        'axes.titlesize',
        'xtick.labelsize',
        'ytick.labelsize',
        'font.size',  # ax.text objects
        'legend.fontsize',
    )
    for key in size_keys:
        matplotlib.rcParams[key] = 14
def poster_defaults():
    """Set rcParams for poster figures.

    Target poster sizes:
        Title: 80pt
        Section headers: 60pt
        Body text: 40pt
        Axis labels, tick marks, subplot titles: 32pt
        Typical panel size: 6"

    These are reached by using the same 14pt settings as
    manuscript_defaults() and doubling the size of the figure. Sets the
    sans-serif font to Arial and every text size to 14pt.
    """
    matplotlib.rcParams['font.sans-serif'] = 'Arial'
    for key in ('axes.labelsize', 'axes.titlesize', 'xtick.labelsize',
            'ytick.labelsize', 'font.size', 'legend.fontsize'):
        # 'font.size' controls ax.text objects
        matplotlib.rcParams[key] = 14
def presentation_defaults():
    """Set rcParams for figures imported into a presentation.

    A typical figure is 11" wide and 7" tall, so no scaling should be
    necessary. Presentation figures usually have more whitespace and fewer
    panels than manuscript figures, and fonts should not be below 18pt
    unless really necessary. Sets the sans-serif font to Arial and every
    text size to 18pt.
    """
    matplotlib.rcParams['font.sans-serif'] = 'Arial'
    size_keys = [
        'axes.labelsize',
        'axes.titlesize',
        'xtick.labelsize',
        'ytick.labelsize',
        'font.size',  # ax.text objects
        'legend.fontsize',
    ]
    for key in size_keys:
        matplotlib.rcParams[key] = 18
def figure_1x1_small():
    """Return (f, ax): a small 2.2" x 2" figure with one nearly square axis.

    The left margin of .3 leaves room for yticklabels with two
    significant digits.
    """
    margins = dict(bottom=.28, left=.3, right=.95, top=.95)
    fig, axis = plt.subplots(figsize=(2.2, 2))
    fig.subplots_adjust(**margins)
    return fig, axis
def figure_1x1_square():
    """Return (f, ax): a standard 3" x 3" figure with one square axis.

    The margins leave room for an xlabel, ylabel, and title in 16pt font.
    """
    margins = dict(bottom=.23, left=.26, right=.9, top=.87)
    fig, axis = plt.subplots(figsize=(3, 3))
    fig.subplots_adjust(**margins)
    return fig, axis
def figure_1x1_standard():
    """Return (f, ax): a standard 3" x 2.5" figure with one slightly wide axis.

    The margins leave room for an xlabel, ylabel, and title in 16pt font.
    """
    margins = dict(bottom=.24, left=.26, right=.93, top=.89)
    fig, axis = plt.subplots(figsize=(3, 2.5))
    fig.subplots_adjust(**margins)
    return fig, axis
def figure_1x2_standard(**kwargs):
    """Return (f, axa): a standard 6" x 2.5" figure with two side-by-side axes.

    kwargs are passed to plt.subplots (e.g. sharey=True).
    The margins leave room for xlabels, ylabels, and titles in 16pt font.
    """
    margins = dict(left=.15, right=.9, wspace=.2, bottom=.22, top=.85)
    fig, axes = plt.subplots(1, 2, figsize=(6, 2.5), **kwargs)
    fig.subplots_adjust(**margins)
    return fig, axes
def figure_1x2_small(**kwargs):
    """Return (f, axa): a small 4" x 2" figure with two side-by-side axes.

    kwargs are passed to plt.subplots.
    """
    margins = dict(left=.2, right=.975, wspace=.3, bottom=.225, top=.8)
    fig, axes = plt.subplots(1, 2, figsize=(4, 2), **kwargs)
    fig.subplots_adjust(**margins)
    return fig, axes
def rescue_tick(ax=None, f=None, x=3, y=3):
    """Reduce the number of major ticks with plt.MaxNLocator.

    ax : a single axis to process; takes precedence over `f`
    f : a figure whose axes (f.axes) are all processed, used if ax is None
    x, y : argument passed to plt.MaxNLocator for the x- and y-axis.
        If None, that axis is left unchanged.

    Raises ValueError if both ax and f are None.
    """
    if ax is not None:
        targets = [ax]
    elif f is not None:
        targets = f.axes
    else:
        raise ValueError("either ax or f must not be None")

    for target in targets:
        for n_ticks, axis_name in ((x, 'xaxis'), (y, 'yaxis')):
            if n_ticks is None:
                continue
            getattr(target, axis_name).set_major_locator(
                plt.MaxNLocator(n_ticks))
def crucifix(x, y, xerr=None, yerr=None, relative_CIs=False, p=None,
        ax=None, factor=None, below=None, above=None, null=None,
        data_range=None, axtype=None, zero_substitute=1e-6,
        suppress_null_error_bars=False):
    """Crucifix plot y vs x around the unity line

    x, y : array-like, length N, paired data
    xerr, yerr : array-like, Nx2, confidence intervals around x and y
        A 1-d array of length N is converted to Nx2 as (-err, +err).
    relative_CIs : if True, then add x to xerr (and ditto yerr)
    p : array-like, length N, p-values for each point. If None, every
        point is treated as nonsignificant.
    ax : graphical object. If None, a new figure and axis are created.
    factor : multiply x, y, and errors by this value
    below : dict of point specs for points significantly (p < .05) below
        the line
    above : dict of point specs for points significantly above the line
    null : dict of point specs for points nonsignificant
        Error bars use the same specs with 'marker' removed (if present).
    data_range : re-adjust the data limits to this. A None entry is
        replaced by the min (or max) of all plotted values.
    axtype : if set (e.g. 'symlog'), passed to set_xscale and set_yscale
    zero_substitute : error-bar endpoints exactly equal to 0 are replaced
        by this value (presumably to keep them finite on log axes)
    suppress_null_error_bars : if True, no error bars are drawn for
        nonsignificant points

    Returns : ax
    """
    # Set up point specs
    if below is None:
        below = {'color': 'b', 'marker': '.', 'ls': '-', 'alpha': 1.0,
            'mec': 'b', 'mfc': 'b'}
    if above is None:
        above = {'color': 'r', 'marker': '.', 'ls': '-', 'alpha': 1.0,
            'mec': 'r', 'mfc': 'r'}
    if null is None:
        null = {'color': 'gray', 'marker': '.', 'ls': '-', 'alpha': 0.5,
            'mec': 'gray', 'mfc': 'gray'}

    # Defaults for data range (a list so None entries can be filled in)
    if data_range is None:
        data_range = [None, None]
    else:
        data_range = list(data_range)

    # Convert to array and optionally multiply
    if factor is None:
        factor = 1
    x = np.asarray(x) * factor
    y = np.asarray(y) * factor

    # p-values
    if p is not None:
        p = np.asarray(p)

    # Same with errors but optionally also reshape and recenter.
    # Errors are made float so that `zero_substitute` is not truncated to 0
    # when integer errors are passed.
    if xerr is not None:
        xerr = np.asarray(xerr, dtype=float) * factor
        if xerr.ndim == 1:
            xerr = np.array([-xerr, xerr]).T
        if relative_CIs:
            xerr = xerr + x[:, None]
    if yerr is not None:
        yerr = np.asarray(yerr, dtype=float) * factor
        if yerr.ndim == 1:
            yerr = np.array([-yerr, yerr]).T
        if relative_CIs:
            yerr = yerr + y[:, None]

    # Create figure handles
    if ax is None:
        f = plt.figure()
        ax = f.add_subplot(111)

    # Plot each point, collecting every plotted coordinate for the range
    plotted_values = []
    for n, (xval, yval) in enumerate(zip(x, y)):
        # Get p-value and error bars for this point
        pval = 1.0 if p is None else p[n]
        xerrval = xerr[n] if xerr is not None else None
        yerrval = yerr[n] if yerr is not None else None

        # Replace zero-valued endpoints
        if xerrval is not None:
            xerrval[xerrval == 0] = zero_substitute
        if yerrval is not None:
            yerrval[yerrval == 0] = zero_substitute

        # What color
        if pval < .05 and yval < xval:
            pkwargs = below
        elif pval < .05 and yval > xval:
            pkwargs = above
        else:
            pkwargs = null

        # Error bars use the point specs without a marker
        lkwargs = pkwargs.copy()
        lkwargs.pop('marker', None)

        # Now actually plot the point
        ax.plot([xval], [yval], **pkwargs)

        # plot error bars, keep track of data range
        draw_bars = not (suppress_null_error_bars and pkwargs is null)
        if xerrval is not None and draw_bars:
            ax.plot(xerrval, [yval, yval], **lkwargs)
            plotted_values += list(xerrval)
        else:
            plotted_values.append(xval)

        # same for y (previously xval was appended here by mistake, so the
        # y values never influenced the data range)
        if yerrval is not None and draw_bars:
            ax.plot([xval, xval], yerrval, **lkwargs)
            plotted_values += list(yerrval)
        else:
            plotted_values.append(yval)

    # Plot the unity line
    if data_range[0] is None:
        data_range[0] = np.min(plotted_values)
    if data_range[1] is None:
        data_range[1] = np.max(plotted_values)
    ax.plot(data_range, data_range, 'k:')
    ax.set_xlim(data_range)
    ax.set_ylim(data_range)

    # symlog
    if axtype:
        ax.set_xscale(axtype)
        ax.set_yscale(axtype)

    ax.axis('scaled')
    return ax
def scatter_with_trend(x, y, xname='X', yname='Y', ax=None,
        legend_font_size='medium', **kwargs):
    """Scatter plot `y` vs `x`, also linear regression line

    Pairs where x or y is NaN are dropped before plotting and fitting.
    The fit from scipy.stats.linregress is drawn as a dotted black line
    spanning [x.min(), x.max()], labeled 'r=... p=...' in a legend.

    x, y : array-like, same length
    xname, yname : x- and y-axis labels
    ax : axis to plot in. If None, a new figure and axis are created.
    legend_font_size : font size of the legend
    kwargs : sent to the point plotting (ax.plot). Defaults:
        marker='.', ls='', color='g'

    Returns : ax
    """
    kwargs.setdefault('marker', '.')
    kwargs.setdefault('ls', '')
    kwargs.setdefault('color', 'g')

    # Drop pairs with a NaN in either coordinate
    x = np.asarray(x)
    y = np.asarray(y)
    dropna = np.isnan(x) | np.isnan(y)
    x = x[~dropna]
    y = y[~dropna]

    if ax is None:
        f = plt.figure()
        ax = f.add_subplot(111)

    ax.plot(x, y, **kwargs)

    # Public API; `scipy.stats.stats` is a deprecated private module
    m, b, rval, pval, stderr = scipy.stats.linregress(
        x.flatten(), y.flatten())
    trend_line_label = 'r=%0.3f p=%0.3f' % (rval, pval)
    xends = np.array([x.min(), x.max()])
    ax.plot([x.min(), x.max()], m * xends + b, 'k:',
        label=trend_line_label)
    ax.legend(loc='best', prop={'size': legend_font_size})

    ax.set_xlabel(xname)
    ax.set_ylabel(yname)
    return ax
def vert_bar(bar_lengths, bar_labels=None, bar_positions=None, ax=None,
        bar_errs=None, bar_colors=None, bar_hatches=None, tick_labels_rotation=90,
        plot_bar_ends='ks', bar_width=.8, mpl_ebar=False,
        yerr_is_absolute=True):
    """Vertical bar plot with nicer defaults

    bar_lengths : heights of the bars, length N
    bar_labels : text labels
    bar_positions : x coordinates of the bar centers. Default is range(N)
    ax : axis to plot in. If None, a new figure is created.
    bar_errs : error bars. Will be cast to array
        If 1d, then these are drawn +/-
        If 2d, then (UNLIKE MATPLOTLIB) they are interpreted as the exact
        locations of the endpoints. Transposed as necessary. If mpl_ebar=True,
        then it is passed directly to `errorbar`, and it needs to be 2xN and
        the bars are drawn at -row0 and +row1.
    bar_colors : colors of bars. If longer than N, then the first N are taken
    bar_hatches : set the hatches like this. length N
    tick_labels_rotation : rotation of the xtick labels
    plot_bar_ends : if not None, then this is plotted at the tops of the bars
    bar_width : passed as width to ax.bar
    mpl_ebar : controls behavior of errorbars
    yerr_is_absolute : if not mpl_ebar, and you are independently specifying
        the locations of each end exactly, set this to True
        Does nothing if yerr is 1d

    Returns : ax
    """
    # Default bar positions
    if bar_positions is None:
        bar_positions = list(range(len(bar_lengths)))
    bar_centers = bar_positions

    # Arrayify bar lengths
    bar_lengths = np.asarray(bar_lengths)
    N = len(bar_lengths)

    # Truncate bar colors to N (previously assigned to a misspelled,
    # unused name, so no truncation happened)
    if bar_colors is not None:
        bar_colors = np.asarray(bar_colors)
        if len(bar_colors) > N:
            bar_colors = bar_colors[:N]

    # Deal with errorbars (if specified, and not mpl_ebar behavior)
    if bar_errs is not None and not mpl_ebar:
        bar_errs = np.asarray(bar_errs)

        # Transpose as necessary
        if bar_errs.ndim == 2 and bar_errs.shape[0] != 2:
            if bar_errs.shape[1] == 2:
                bar_errs = bar_errs.T
            else:
                raise ValueError("weird shape for bar_errs: %r" % bar_errs)

        if bar_errs.ndim == 2 and yerr_is_absolute:
            # Put into MPL syntax: -row0, +row1
            assert bar_errs.shape[1] == N
            bar_errs = np.array([
                bar_lengths - bar_errs[0],
                bar_errs[1] - bar_lengths])

    # Create axis objects
    if ax is None:
        f, ax = plt.subplots()

    # Make the bar plot. Positions are passed positionally because the
    # `left` keyword was removed in matplotlib 3.0.
    ax.bar(bar_centers, bar_lengths, width=bar_width, bottom=0,
        align='center', yerr=bar_errs, capsize=0,
        ecolor='k', color=bar_colors)

    # Hatch it
    if bar_hatches is not None:
        for patch, hatch in zip(ax.patches, bar_hatches):
            patch.set_hatch(hatch)

    # Plot squares on the bar tops
    if plot_bar_ends:
        ax.plot(bar_centers, bar_lengths, plot_bar_ends)

    # Labels (explicit check: truth value of an array is ambiguous)
    ax.set_xticks(bar_centers)
    ax.set_xlim(bar_centers[0] - bar_width, bar_centers[-1] + bar_width)
    if bar_labels is not None and len(bar_labels) > 0:
        ax.set_xticklabels(bar_labels, rotation=tick_labels_rotation)

    return ax
def horiz_bar(bar_lengths, bar_labels=None, bar_positions=None, ax=None,
        bar_errs=None, bar_colors=None, bar_hatches=None):
    """Horizontal bar plot

    bar_lengths : lengths of the bars, length N
    bar_labels : ytick labels. If None, the tick labels are left unchanged.
    bar_positions : y coordinates of the bar centers. Default is range(N)
    ax : axis to plot in. If None, a new figure is created.
    bar_errs : passed as xerr to ax.barh
    bar_colors : passed as color to ax.barh
    bar_hatches : hatch for each bar, applied in order to ax.patches

    A black square ('ks') is plotted at the end of each bar.

    Returns : ax
    """
    # Defaults
    if bar_positions is None:
        bar_positions = list(range(len(bar_lengths)))
    bar_centers = bar_positions
    if ax is None:
        f, ax = plt.subplots()

    # Make the bar plot. `barh` replaces `bar(left=..., orientation=...)`,
    # whose `left` keyword was removed in matplotlib 3.0.
    ax.barh(bar_centers, bar_lengths, height=.8, left=0,
        align='center', xerr=bar_errs, capsize=0,
        ecolor='k', color=bar_colors)

    # Hatch it
    if bar_hatches is not None:
        for patch, hatch in zip(ax.patches, bar_hatches):
            patch.set_hatch(hatch)

    # Plot squares on the bar tops
    ax.plot(bar_lengths, bar_centers, 'ks')

    # Labels
    ax.set_yticks(bar_centers)
    if bar_labels is not None:
        ax.set_yticklabels(bar_labels)

    return ax
def auto_subplot(n, return_fig=True, squeeze=False, **kwargs):
    """Choose a near-square grid with room for `n` subplots.

    The grid has nx = floor(sqrt(n)) rows (at least 1) and
    ny = ceil(n / nx) columns, so nx * ny >= n.

    n : number of subplots needed
    return_fig : if True, return plt.subplots(nx, ny, squeeze=squeeze,
        **kwargs); otherwise return the tuple (nx, ny)
    squeeze, kwargs : passed to plt.subplots when return_fig is True

    Returns : (f, axa) if return_fig, else (nx, ny)
    """
    # At least one row, so n < 1 does not divide by zero
    nx = max(1, int(np.floor(np.sqrt(n))))
    # True division; replaces past.utils.old_div
    ny = int(np.ceil(n / float(nx)))
    if return_fig:
        return plt.subplots(nx, ny, squeeze=squeeze, **kwargs)
    else:
        return nx, ny
def imshow(C, x=None, y=None, ax=None,
extent=None, xd_range=None, yd_range=None,
cmap=plt.cm.RdBu_r, origin='upper', interpolation='nearest', aspect='auto',
axis_call='tight', clim=None, center_clim=False,
skip_coerce=False, **kwargs):
"""Wrapper around imshow with better defaults.
Plots "right-side up" with the first pixel C[0, 0] in the upper left,
not the lower left. So it's like an image or a matrix, not like
a graph. This done by setting the `origin` to 'upper', and by
appropriately altering `extent` to account for this flip.
C must be regularly-spaced. See this example for how to handle irregular:
http://stackoverflow.com/questions/14120222/matplotlib-imshow-with-irregular-spaced-data-points
C - Two-dimensional array of data
If C has shape (m, n), then the image produced will have m rows
and n columns.
x - array of values corresonding to the x-coordinates of the columns
Only use this if you want numerical values on the columns.