Coverage for gwcelery/tasks/gwskynet.py: 97%

78 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-11-14 05:52 +0000

1"""GWSkyNet annotation with GWSkyNet model""" 

2import json 

3import re 

4from functools import cache 

5 

6import numpy as np 

7 

8from .. import app 

9from ..util.tempfile import NamedTemporaryFile 

10from . import gracedb, igwn_alert, superevents 

11 

# Matches log comments built from the configured
# ``views_manual_preferred_event_log_message`` template; it is used in
# ``_should_annotate`` to detect a manual change of the preferred event.
# The whole template is passed through ``re.escape`` so that every regex
# metacharacter in it (not only ``.``) is matched literally; each escaped
# ``{}`` placeholder is then turned into ``.+`` to match arbitrary text.
manual_pref_event_change_regexp = re.compile(
    re.escape(app.conf['views_manual_preferred_event_log_message'])
    .replace(re.escape('{}'), '.+')
)

16 

17 

@cache
def GWSkyNet_model():
    """Return the GWSkyNet model, loading it only on the first call.

    ``functools.cache`` memoizes the return value, so later calls in the
    same process reuse the model object returned by the first call.
    """
    # FIXME Remove import from function scope once importing GWSkyNet is not a
    # slow operation
    from GWSkyNet import GWSkyNet as gwskynet_lib

    model = gwskynet_lib.load_GWSkyNet_model()
    return model

25 

26 

# FIXME: run GWSkyNet on general-purpose workers
# once https://git.ligo.org/manleong.chan/gwskynet/-/issues/6 is fixed.
@app.task(queue='openmp', shared=False)
def gwskynet_annotation(input_list, SNRs, superevent_id):
    """Run GWSkyNet on a sky map and return its annotation as a JSON string.

    Parameters
    ----------
    input_list : list
        The output of :func:`_download_and_return_file_name`: a two-element
        list ``[filecontents, skymap_filename]`` holding the downloaded sky
        map contents and the versioned file name of that sky map.
    SNRs : array-like of floats
        Detector SNRs ordered as H1, L1, V1 (as produced by
        :func:`get_cbc_event_snr`).
    superevent_id : str
        Superevent uid.

    Returns
    -------
    str
        JSON-encoded dictionary with the keys ``superevent_id``, ``file``,
        ``GWSkyNet_score``, ``GWSkyNet_FAP`` and ``GWSkyNet_FNP``.
    """
    # FIXME Remove import from function scope once importing GWSkyNet is not a
    # slow operation
    from GWSkyNet import GWSkyNet

    # The element-wise comparison below needs an ndarray; accept any
    # array-like (e.g. a list, should the task arguments ever be round-tripped
    # through a serializer that does not preserve numpy arrays).
    SNRs = np.asarray(SNRs)

    filecontents, skymap_filename = input_list
    with NamedTemporaryFile(content=filecontents) as fitsfile:
        GWSkyNet_input = GWSkyNet.prepare_data(fitsfile.name)
    # One of the inputs from BAYESTAR to GWSkyNet is the list of instruments,
    # i.e., metadata['instruments'], which is converted to a binary array with
    # three elements, i.e. GWSkyNet_input[2], for H1, L1 and V1.
    # GWSkyNet 2.4.0 uses this array to indicate detector with SNR >= 4.5
    GWSkyNet_input[2][0] = np.where(
        SNRs >= app.conf['gwskynet_snr_threshold'], 1, 0)
    gwskynet_score = GWSkyNet.predict(GWSkyNet_model(), GWSkyNet_input)
    FAP, FNP = GWSkyNet.get_rates(gwskynet_score)
    # Cast to built-in floats: numpy scalar types such as float32 are not
    # accepted by json.dumps.
    gwskynet_output = {'superevent_id': superevent_id,
                       'file': skymap_filename,
                       'GWSkyNet_score': float(gwskynet_score[0]),
                       'GWSkyNet_FAP': float(FAP[0]),
                       'GWSkyNet_FNP': float(FNP[0])}
    return json.dumps(gwskynet_output)

70 

71 

def get_cbc_event_snr(event):
    """Extract per-detector SNRs from an event's SingleInspiral records.

    Parameters
    ----------
    event : dict
        Event dictionary (e.g., the return value from
        :meth:`gwcelery.tasks.gracedb.get_event`, or
        ``preferred_event_data`` in igwn-alert packet.)

    Returns
    -------
    snr : numpy array of floats
        Three SNRs in the fixed order H1, L1, V1. A detector without a
        SingleInspiral record keeps an SNR of 0; records for any other
        detector are ignored.

    """
    # GWSkyNet 2.4.0 uses this SNR array to modify one of its inputs, so the
    # array slots must correspond to H1, L1 and V1 in that order.
    slot_of = {'H1': 0, 'L1': 1, 'V1': 2}
    snrs = np.zeros(len(slot_of))
    for single in event['extra_attributes']['SingleInspiral']:
        slot = slot_of.get(single['ifo'])
        if slot is not None:
            snrs[slot] = single['snr']
    return snrs

101 

102 

@gracedb.task(shared=False)
def _download_and_return_file_name(filename, graceid):
    """Download ``filename`` from ``graceid`` and pair it with its name.

    Returns a two-element list ``[filecontents, filename]`` so that the next
    task in the chain receives both the file contents and the file name.
    """
    return [gracedb.download(filename, graceid), filename]

108 

109 

@gracedb.task(shared=False)
def _unpack_gwskynet_annotation_and_upload(gwskynet_output, graceid):
    """Upload a GWSkyNet annotation to GraceDB with a summary log message.

    Parameters
    ----------
    gwskynet_output : str
        JSON string as returned by :func:`gwskynet_annotation`; it is
        uploaded verbatim as ``gwskynet.json``.
    graceid : str
        ID of the GraceDB object to upload to.

    Returns
    -------
    The return value of :func:`gracedb.upload`.
    """
    annotation = json.loads(gwskynet_output)
    skymap_filename = annotation['file']
    # Round the reported numbers to three decimals for the log message only;
    # the uploaded JSON keeps full precision.
    score = np.round(annotation['GWSkyNet_score'], 3)
    fap = np.round(annotation['GWSkyNet_FAP'], 3)
    fnp = np.round(annotation['GWSkyNet_FNP'], 3)

    template = ('GWSkyNet annotation from <a href='
                '"/api/events/{graceid}/files/'
                '{skymap_filename}">'
                '{skymap_filename}</a>.'
                ' GWSkyNet score: {cs},'
                ' GWSkyNet FAP: {GWSkyNet_FAP},'
                ' GWSkyNet FNP: {GWSkyNet_FNP}.')
    message = template.format(graceid=graceid,
                              skymap_filename=skymap_filename,
                              cs=score,
                              GWSkyNet_FAP=fap,
                              GWSkyNet_FNP=fnp)

    return gracedb.upload(gwskynet_output, 'gwskynet.json', graceid,
                          message=message, tags=['em_follow', 'public'])

131 

132 

def _should_annotate(preferred_event, new_label, new_log_comment, labels,
                     alert_type):
    """Decide whether an IGWN alert should trigger a GWSkyNet annotation.

    Parameters
    ----------
    preferred_event : dict
        Preferred-event data of the superevent.
    new_label : str
        Label name carried by a ``label_added`` alert.
    new_log_comment : str
        Comment carried by a log alert.
    labels : list
        Labels currently applied to the superevent.
    alert_type : str
        IGWN alert type. ``'label_added'`` is handled on its own; any other
        value is treated as a log alert.

    Returns
    -------
    bool
        True if GWSkyNet should annotate the superevent now.
    """
    snrs = get_cbc_event_snr(preferred_event)

    # Event-level criteria. They are combined with ``and`` so that later
    # terms (e.g. the FAR comparison) are evaluated only if earlier ones pass.
    passes_cuts = (
        preferred_event['search'].lower() == 'allsky'
        and preferred_event['far'] <= app.conf['gwskynet_upper_far_threshold']
        and (snrs >= app.conf['gwskynet_snr_threshold']).sum() >= 2
        and np.sqrt(sum(snrs**2))
        >= app.conf['gwskynet_network_snr_threshold']
    )
    if not passes_cuts:
        return False

    prelim_sent_labels = ('GCN_PRELIM_SENT', 'LOW_SIGNIF_PRELIM_SENT')

    if alert_type == 'label_added':
        not_publishable = superevents.should_publish(
            preferred_event, significant=False) is False
        if not_publishable and new_label == 'SKYMAP_READY':
            # FAR is above the preliminary-alert threshold, so no alert will
            # be sent: annotate as soon as the sky map is ready.
            return True
        # Otherwise annotate once a preliminary alert has been sent.
        return new_label in prelim_sent_labels

    # Log alerts: when the FAR passes the alert threshold, nothing is
    # annotated until an initial preliminary alert has been sent.
    if not any(label in labels for label in prelim_sent_labels):
        return False
    # Re-annotate when the sky map changes: a sky map from a new g-event was
    # copied, or the preferred event was changed manually.
    if new_log_comment.startswith('Localization copied from '):
        return True
    return manual_pref_event_change_regexp.match(new_log_comment) is not None

174 

175 

@igwn_alert.handler('superevent',
                    shared=False)
def handle_cbc_superevent(alert):
    """Annotate the CBC preferred events of superevents using GWSkyNet.

    Only ``label_added`` and ``log`` alerts whose preferred event belongs to
    the CBC group are considered. If :func:`_should_annotate` approves, a
    Celery chain is launched that:

    1. looks up the latest ``bayestar.multiorder.fits``;
    2. downloads it;
    3. runs :func:`gwskynet_annotation`;
    4. uploads the result.

    Parameters
    ----------
    alert : dict
        IGWN alert packet for a superevent.
    """
    # Filter on alert type first: it is the cheapest check and avoids
    # inspecting the payload of alert types that can never trigger an
    # annotation.
    if alert['alert_type'] not in ('label_added', 'log'):
        return

    preferred_event = alert['object']['preferred_event_data']
    if preferred_event['group'] != 'CBC':
        return

    superevent_id = alert['uid']
    new_label = alert['data'].get('name', '')
    new_log_comment = alert['data'].get('comment', '')
    labels = alert['object'].get('labels', [])

    if not _should_annotate(preferred_event, new_label, new_log_comment,
                            labels, alert['alert_type']):
        return

    # Only needed as an argument to the annotation task, so compute the SNRs
    # once we know the chain will actually run.
    SNRs = get_cbc_event_snr(preferred_event)
    (
        gracedb.get_latest_file.s(superevent_id,
                                  'bayestar.multiorder.fits')
        |
        _download_and_return_file_name.s(superevent_id)
        |
        gwskynet_annotation.s(SNRs, superevent_id)
        |
        _unpack_gwskynet_annotation_and_upload.s(superevent_id)
    ).apply_async()