Coverage for gwcelery/tasks/p_astro.py: 51%
68 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-11-14 05:52 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-11-14 05:52 +0000
1"""Computation of ``p_astro`` by source category and utilities
2related to ``p_astro.json`` source classification files.
3See Kapadia et al (2019), :doi:`10.1088/1361-6382/ab5f2d`, for details.
4"""
5import io
6import json
8from celery.utils.log import get_task_logger
10try:
11 from ligo.p_astro import computation as pastrocomp
12except ImportError: # p_astro older than lscsoft/p_astro!42
13 from ligo import p_astro_computation as pastrocomp
15import numpy as np
16from matplotlib import pyplot as plt
18from .. import app
19from ..util import PromiseProxy, closing_figures, read_json
20from . import gracedb, igwn_alert
# Mean expected counts per source category, read from the ``ligo.data``
# package (keyed like ``counts_Terrestrial``). compute_p_astro works on a
# copy of this mapping and overrides the terrestrial count per event.
MEAN_VALUES_DICT = PromiseProxy(
    read_json,
    ('ligo.data', 'H1L1V1-mean_counts-1126051217-61603201.json'),
)

# Per-pipeline FAR/SNR thresholds, handed to ``choose_snr`` when picking the
# SNR used in the foreground term of the Bayes factor.
THRESHOLDS_DICT = PromiseProxy(
    read_json,
    ('ligo.data', 'H1L1V1-pipeline-far_snr-thresholds.json'),
)

# Livetime (under the ``p_astro_livetime`` key) used to scale the reference
# FAR into an expected terrestrial count.
P_ASTRO_LIVETIME = PromiseProxy(
    read_json,
    ('ligo.data', 'p_astro_livetime.json'),
)

# Celery task logger for this module.
log = get_task_logger(__name__)
@app.task(shared=False)
def compute_p_astro(snr, far, mass1, mass2, pipeline, instruments):
    """
    Task to compute `p_astro` by source category.

    Parameters
    ----------
    snr : float
        The event's SNR.
    far : float
        The event's false alarm rate, in Hz. It is compared against a
        reference rate of one per 30 days, also expressed in Hz.
    mass1 : float
        The event's ``mass1``.
    mass2 : float
        The event's ``mass2``.
    pipeline : str
        Name of the search pipeline that produced the event. It is passed,
        together with ``instruments`` and the pipeline FAR/SNR thresholds
        table, to ``choose_snr`` to select the SNR used below.
    instruments : set
        Set of instruments that detected the event.

    Returns
    -------
    p_astros : str
        JSON encoding of the p_astro values by source category, as returned
        by ``evaluate_p_astro_from_bayesfac``.

    Example
    -------
    >>> p_astros = json.loads(compute_p_astro(
    ...     snr, far, mass1, mass2, pipeline, instruments))
    >>> p_astros  # illustrative output; actual values depend on the inputs
    {'BNS': 0.999, 'BBH': 0.0, 'NSBH': 0.0, 'Terrestrial': 0.001}
    """
    # Ensure SNR does not increase indefinitely beyond limiting FAR
    # for MBTA and PyCBC events
    snr_choice = pastrocomp.choose_snr(far,
                                       snr,
                                       pipeline,
                                       instruments,
                                       THRESHOLDS_DICT)

    # Reference constants for the Bayes factor: a reference SNR, and a
    # reference FAR of one per 30 days expressed in Hz.
    snr_star = 8.5
    far_star = 1 / (30 * 86400)

    # Astrophysical Bayes factor for the event: the foreground term scales
    # as snr**-4 (normalized at snr_star), and the background term is the
    # event's FAR in units of far_star.
    fground = 3 * snr_star**3 / (snr_choice**4)
    bground = far / far_star
    astro_bayesfac = fground / bground

    # Expected terrestrial count at the reference FAR over the livetime.
    lam_0 = far_star * P_ASTRO_LIVETIME['p_astro_livetime']
    # Work on a copy so the shared module-level mapping is never mutated.
    mean_values_dict = dict(MEAN_VALUES_DICT)
    mean_values_dict["counts_Terrestrial"] = lam_0

    # Compute categorical p_astro values
    p_astro_values = \
        pastrocomp.evaluate_p_astro_from_bayesfac(astro_bayesfac,
                                                  mean_values_dict,
                                                  mass1,
                                                  mass2)
    # Serialize the p_astro values to a JSON string (no file is written).
    return json.dumps(p_astro_values)
98def _format_prob(prob):
99 if prob >= 1:
100 return '100%'
101 elif prob <= 0:
102 return '0%'
103 elif prob > 0.99:
104 return '>99%'
105 elif prob < 0.01:
106 return '<1%'
107 else:
108 return '{}%'.format(int(np.round(100 * prob)))
@app.task(shared=False)
@closing_figures()
def plot(contents):
    """Draw a ``p_astro.json`` source classification as a horizontal bar chart.

    Parameters
    ----------
    contents : str, bytes
        Raw JSON text of a ``p_astro.json`` file: an object mapping each
        source category name to its probability.

    Returns
    -------
    png : bytes
        PNG-encoded image of the chart.

    Notes
    -----
    The figure is deliberately small (2.5 x 2 inches) so that it displays
    well in GraceDB's image display widget.

    Examples
    --------
    .. plot::
       :include-source:

       >>> from gwcelery.tasks import p_astro
       >>> contents = '''
       ... {"Terrestrial": 0.001, "BNS": 0.65, "NSBH": 0.20,
       ... "BBH": 0.059}
       ... '''
       >>> p_astro.plot(contents)

    """
    # Render off-screen with a non-interactive Matplotlib backend.
    plt.switch_backend('agg')

    # Order categories by ascending probability (ties broken by name), so
    # bar i along the y axis corresponds to the i-th entry below.
    ranked = sorted(json.loads(contents).items(),
                    key=lambda item: (item[1], item[0]))
    names, probs = zip(*ranked)

    buffer = io.BytesIO()
    with plt.style.context('seaborn-v0_8-white'):
        fig, ax = plt.subplots(figsize=(2.5, 2))
        ax.barh(names, probs)
        for row, prob in enumerate(probs):
            # Put the percentage label 4 points to the right of x=0.
            ax.annotate(_format_prob(prob), xy=(0, row), xytext=(4, 0),
                        textcoords='offset points', ha='left', va='center')
        ax.set_xlim(0, 1)
        # The labels carry the values, so the x axis needs no ticks.
        ax.set_xticks([])
        ax.tick_params(left=False)
        for side in ('top', 'bottom', 'right'):
            ax.spines[side].set_visible(False)
        fig.tight_layout()
        fig.savefig(buffer, format='png')
    return buffer.getvalue()
@igwn_alert.handler('superevent',
                    'mdc_superevent',
                    shared=False)
def handle(alert):
    """IGWN alert handler for superevents and MDC superevents.

    For a ``log`` alert that announces one of the known pipelines'
    ``<pipeline>.p_astro.json`` files, launch a chain that downloads the
    file, renders it with :func:`plot`, and uploads the resulting PNG back
    to the same superevent. All other alerts are ignored.
    """
    if alert['alert_type'] != 'log':
        return

    graceid = alert['uid']
    filename = alert['data'].get('filename')

    known_pipelines = ('cwb', 'gstlal', 'mbta', 'pycbc', 'spiir',
                       'RapidPE_RIFT')
    if filename not in {f'{name}.p_astro.json' for name in known_pipelines}:
        return

    message = ('Source classification visualization from '
               f'<a href="/api/superevents/{graceid}/files/{filename}">'
               f'{filename}</a>')
    workflow = (
        gracedb.download.s(filename, graceid)
        | plot.s()
        | gracedb.upload.s(
            filename.replace('.json', '.png'), graceid,
            message=message,
            tags=['em_follow', 'p_astro', 'public'])
    )
    workflow.delay()