-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlambda_function.py
More file actions
289 lines (231 loc) · 9.95 KB
/
lambda_function.py
File metadata and controls
289 lines (231 loc) · 9.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
''' Implements a lambda function that works as an API Gateway endpoint.
The function will receive a JSON object (event) with the following format:
{
'function': 'spam_or_ham',
'messages': [
'message 1',
'message 2',
'message 3',]
}
The function will return a JSON object with the following format:
{
'status_code': 200,
'function': 'spam_or_ham',
'responses': {
'message 1': 'ham',
'message 2': 'spam',
'message 3': 'ham'
},
'errors': []
}
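NOTES:
1. The function takes in the original text messages as a list of strings.
2. The responses are stored in a dictionary, with the original text message as
the key and the predicted class ('ham' or 'spam') as the value. Duplicate
messages therefore appear only once in 'responses'.
3. If the request is invalid or cannot be processed, 'status_code' is 400,
'responses' is empty, and 'errors' lists the reasons.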
'''
# To run locally, install the required libraries
# !python -m pip install -r Resources/lambda_requirements.txt
# Import the required libraries
import argparse # For parsing input arguments
import functools # For caching the loaded model and vectorizer between invocations
import json # For JSON encoding/decoding
import pickle # For loading the vectorizer

import keras # For deep learning model (Still requires TensorFlow backend)
import nltk # For text processing
from numpy import argmax # For finding the index of the maximum value
# Import the NLTK libraries
# The following NLTK data packages are needed for processing the text messages:
# 1. punkt_tab: Used to split the text into individual words or sentences.
# 2. stopwords: Stop words are common words that are removed from text processing tasks to
#    improve the accuracy and efficiency of the results. Examples: 'the', 'is', 'and', 'a', etc.
# 3. wordnet: Used for lemmatization, which is the process of converting a word to its
#    base form. Example: The lemma of the word 'running' is 'run'.
#
# NOTE: When deployed to AWS, these packages should already be available in an AWS Lambda Layer.
# Lambda extracts layer contents under the absolute directory /opt, so '/opt/nltk_data' is
# added to the search path. The relative './opt/nltk_data' (resolved against the current
# working directory) is kept for setups that bundle the data alongside the code.
nltk.data.path.append('./opt/nltk_data')
nltk.data.path.append('/opt/nltk_data')

# (download name, resource path used by nltk.data.find) for every required package
_NLTK_PACKAGES = (
    ('punkt_tab', 'tokenizers/punkt_tab'),
    ('stopwords', 'corpora/stopwords'),
    ('wordnet', 'corpora/wordnet'),
)


def _ensure_nltk_data():
    """ Download each required NLTK package only if it cannot already be found.
    Skipping packages that are already present (e.g. provided by the Lambda Layer) avoids
    network access on cold start. If a lookup fails, this falls back to downloading,
    which is what the code always did before.
    """
    for package, resource_path in _NLTK_PACKAGES:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)


_ensure_nltk_data()

from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
lemmatizer = WordNetLemmatizer()
# Import the custom tokenizer - Used for encoding the messages with the same vectorizer as the trained model
from Resources.CustomTokenizer import Custom_Tokenizer
tokenizer_Instance = Custom_Tokenizer()
# The model file and its pickled vectorizer share one file stem; only the extension differs
_ARTIFACT_STEM = 'Resources/Sequential_0.99875_10_64_1_64_relu_softmax'
MODEL_NAME = _ARTIFACT_STEM + '.keras'
VECTORIZER_NAME = _ARTIFACT_STEM + '.pkl'

# Keys used in the request (event) and response dictionaries
FUNCTION = 'function'
MESSAGES = 'messages'
STATUS_CODE = 'status_code'
RESPONSES = 'responses'
ERRORS = 'errors'

# Status codes placed under STATUS_CODE in a response
STATUS_CODE_SUCCESS = 200
STATUS_CODE_BAD_REQUEST = 400

# Supported function names, and the event fields accepted by spam_or_ham
SPAM_OR_HAM = 'spam_or_ham'
SUPPORTED_FUNCTIONS = [SPAM_OR_HAM]
SPAM_OR_HAM_FIELDS = [FUNCTION, MESSAGES]

# Placeholder values.
# NOTE: EMPTY_DICT and EMPTY_LIST are single module-level mutable objects shared by every
# user; never mutate them (e.g. append to them) - use a fresh {} / [] where one is needed.
EMPTY_STRING = ''
EMPTY_DICT = {}
EMPTY_LIST = []
@functools.lru_cache(maxsize=1)
def _load_artifacts():
    """ Load the trained Keras model and its vectorizer, memoized for the process lifetime.
    Warm Lambda invocations reuse the already-loaded objects instead of re-reading both
    files from disk on every request. lru_cache does not store a result when loading
    raises, so a failed load is retried on the next call.
    Returns:
        tuple: (model, vectorizer)
    """
    model = keras.models.load_model(MODEL_NAME)
    # NOTE: pickle.load can execute arbitrary code embedded in the file. This is only
    # acceptable because VECTORIZER_NAME is a bundled, trusted artifact; never point it at
    # user-supplied data.
    with open(VECTORIZER_NAME, 'rb') as file:
        vectorizer = pickle.load(file)
    # Attach the project's custom tokenizer to the unpickled vectorizer (presumably the
    # same tokenizer used when the vectorizer was fitted - confirm against training code)
    vectorizer.tokenizer = tokenizer_Instance.custom_tokenizer
    return model, vectorizer


def lambda_handler(event, context=None):
    """ Main Lambda function
    Args:
        event (dict): A dictionary containing the input data and parameters
        context (object): An object containing information about the invocation,
            function, and execution environment (unused)
    Returns:
        dict: A JSON object containing the status code with function name and
            response data for successful queries or error information if the
            request fails. Unexpected exceptions are not raised; they are returned
            as a 400 response whose errors list holds the exception text.
    """
    try:
        # Validate that the event has the required fields and contains at least one message
        validation_errors = event_is_valid(event)
        if validation_errors is not None:
            return validation_errors
        # Model and vectorizer are loaded once and reused across invocations
        model, vectorizer = _load_artifacts()
        # Get the messages from the event
        messages = event[MESSAGES]
        # Convert the messages into TF-IDF vectors using the same vectorizer as the trained model
        encoded_messages = vectorizer.transform(messages).toarray()
        # Check all messages to see if they are spam or ham
        return spam_or_ham(model, messages, encoded_messages)
    except Exception as e:
        # Top-level boundary of the Lambda: report any failure in the response body.
        # A fresh dict is used so callers never share (or mutate) a module-level object.
        return {
            STATUS_CODE: STATUS_CODE_BAD_REQUEST,
            FUNCTION: SPAM_OR_HAM,
            RESPONSES: {},
            ERRORS: [str(e)]
        }
def event_is_valid(event):
    """ Validate the event object
    Args:
        event (dict): A dictionary containing the input data and parameters
    Returns:
        dict: A JSON object containing the status code and error information if the
            request fails. Otherwise, None is returned.
    Top-level problems (missing/non-dict/empty event, bad function field) stop at the
    first error. For the spam_or_ham function, all field problems are collected before
    returning.
    """
    # Build fresh containers on every call. Appending into a shared module-level list
    # would make errors accumulate across calls (and across warm Lambda invocations).
    errors = []
    error_response = {
        STATUS_CODE: STATUS_CODE_BAD_REQUEST,
        FUNCTION: EMPTY_STRING,
        RESPONSES: {},
        ERRORS: errors
    }
    # Validate that the event was provided, is a dictionary, and is not empty
    if event is None:
        errors.append('Event object is missing')
        return error_response
    if not isinstance(event, dict):
        errors.append('Event object is not a dictionary')
        return error_response
    if not event:
        errors.append('Event object is empty')
        return error_response
    # Validate the function field exists, is a string, and contains a valid function name
    if FUNCTION not in event:
        errors.append('Function field is required')
        return error_response
    function_name = event[FUNCTION]
    if not isinstance(function_name, str):
        errors.append('Function field must be a string')
        return error_response
    if function_name not in SUPPORTED_FUNCTIONS:
        errors.append(f'Invalid function name: {function_name}')
        return error_response
    # Validate the remaining event fields based on the selected function
    if function_name == SPAM_OR_HAM:
        error_response[FUNCTION] = SPAM_OR_HAM
        # The messages field must exist, be a non-empty list, and hold only strings
        if MESSAGES not in event:
            errors.append('Messages field is required')
        elif not isinstance(event[MESSAGES], list):
            errors.append('Messages field must be a list')
        elif not event[MESSAGES]:
            errors.append('Messages field must contain at least one message')
        elif not all(isinstance(message, str) for message in event[MESSAGES]):
            errors.append('Messages field contains an invalid message; all messages must be strings')
        # Verify that no additional fields are present
        for key in event:
            if key not in SPAM_OR_HAM_FIELDS:
                errors.append(f'Invalid field in event: {key}')
        # errors started empty and earlier checks returned early, so anything here is a field error
        if errors:
            return error_response
    return None
def spam_or_ham(model, messages, encoded_messages):
    """ Determine if the messages are spam or ham
    Args:
        model (keras.models.Sequential): A trained deep learning model
        messages (list): A list of messages to classify
        encoded_messages (numpy.ndarray): A 2D array of encoded messages, one row per message
    Returns:
        dict: A JSON object containing the status code with function name and
            response data. Responses are keyed by message text, so duplicate
            messages collapse into one entry (the last occurrence wins).
    Raises:
        ValueError: If the model returns a different number of predictions than messages.
    """
    # Per-message class scores; argmax over axis 1 selects the predicted class index
    raw_predictions = model.predict(encoded_messages)
    predictions = argmax(raw_predictions, axis=1)
    # Pairing by position is only meaningful when there is one prediction per message
    if len(predictions) != len(messages):
        raise ValueError(f'Expected {len(messages)} predictions, got {len(predictions)}')
    # Class index 0 = ham; any other index is reported as spam
    responses = {
        message: 'ham' if prediction == 0 else 'spam'
        for message, prediction in zip(messages, predictions)
    }
    # Return the responses as a JSON object
    return {
        STATUS_CODE: STATUS_CODE_SUCCESS,
        FUNCTION: SPAM_OR_HAM,
        RESPONSES: responses,
        ERRORS: []
    }
def main():
    """ Main function for running the AWS Lambda function locally
    Parses the --event command-line argument as JSON and passes it to lambda_handler.
    Returns:
        dict: The lambda_handler response, or a 400 error response when --event is
            missing/empty or is not valid JSON.
    """
    parser = argparse.ArgumentParser(description='Run the AWS Lambda function locally')
    parser.add_argument('--event', type=str, help='JSON string representing the event object')
    args = parser.parse_args()
    if not args.event:
        return {
            STATUS_CODE: STATUS_CODE_BAD_REQUEST,
            FUNCTION: EMPTY_STRING,
            RESPONSES: {},
            ERRORS: ['No JSON-encoded event object provided. Use the --event argument to provide an event object']
        }
    # Parse the JSON string from the command line argument. Only malformed JSON is
    # reported as a parse error; any other exception is a real bug and propagates.
    try:
        event = json.loads(args.event)
    except json.JSONDecodeError as e:
        return {
            STATUS_CODE: STATUS_CODE_BAD_REQUEST,
            FUNCTION: EMPTY_STRING,
            RESPONSES: {},
            ERRORS: [f'Unable to parse event object: {str(e)}']
        }
    return lambda_handler(event)
if __name__ == '__main__':
    # Invoke the handler locally and echo its response dictionary
    print(main())