9
9
from requests_aws4auth import AWS4Auth
10
10
11
11
from filter_builder import *
12
+ from dashboard import *
12
13
13
14
#Global variables for prod
14
15
region = environ ['MY_AWS_REGION' ]
15
16
aos_host = environ ['OS_ENDPOINT' ]
16
17
sagemaker_endpoint = environ ['SAGEMAKER_ENDPOINT' ]
17
18
os_secret_id = environ ['OS_SECRET_ID' ]
18
19
model_name = environ ['MODEL_NAME' ]
20
+ search_index_name = environ ['NEW_INDEX_NAME' ]
19
21
20
22
def get_awsauth_from_secret (region , secret_id ):
21
23
"""
@@ -57,7 +59,7 @@ def invoke_sagemaker_endpoint(sagemaker_endpoint, payload, region):
57
59
print (f"Error invoking SageMaker endpoint { sagemaker_endpoint } : { e } " )
58
60
59
61
60
- def semantic_search_neighbors (lang , search_text , features , os_client , sort_param , k_neighbors = 30 , from_param = 0 , idx_name = model_name , filters = None , size = 10 ):
62
+ def semantic_search_neighbors (lang , search_text , features , os_client , sort_param , k_neighbors = 50 , from_param = 0 , idx_name = model_name , filters = None , size = 10 ):
61
63
"""
62
64
Perform semantic search and get neighbots using the cosine similarity of the vectors
63
65
output: a list of json, each json contains _id, _score, title, and uuid
@@ -97,6 +99,8 @@ def semantic_search_neighbors(lang, search_text, features, os_client, sort_param
97
99
request_timeout = 55 ,
98
100
index = idx_name ,
99
101
body = query )
102
+
103
+ print (res )
100
104
101
105
# # Return a dataframe of the searched results, including title and uuid
102
106
# query_result = [
@@ -159,7 +163,7 @@ def add_to_top_of_dict(original_dict, key, value):
159
163
new_dict .update (original_dict )
160
164
161
165
return new_dict
162
-
166
+ """
163
167
def create_api_response(search_results):
164
168
response = {
165
169
"total_hits": len(search_results['hits']['hits']),
@@ -176,7 +180,7 @@ def create_api_response(search_results):
176
180
except Exception as e:
177
181
print(f"Error processing hit: {e}")
178
182
return response
179
-
183
+ """
180
184
def create_api_response_geojson (search_results , lang ):
181
185
182
186
total_hits = search_results ['hits' ]['total' ]['value' ] if 'total' in search_results ['hits' ] else 0
@@ -256,7 +260,12 @@ def lambda_handler(event, context):
256
260
/postText: Uses semantic search to find similar records based on vector similarity.
257
261
Other paths: Uses a direct keyword text match to find matched records .
258
262
"""
259
- awsauth = get_awsauth_from_secret (region , secret_id = os_secret_id )
263
+ #awsauth = get_awsauth_from_secret(region, secret_id=os_secret_id)
264
+ #print(awsauth)
265
+
266
+ credentials = boto3 .Session ().get_credentials ()
267
+ awsauth = AWS4Auth (credentials .access_key , credentials .secret_key , region , 'es' , session_token = credentials .token )
268
+
260
269
os_client = OpenSearch (
261
270
hosts = [{'host' : aos_host , 'port' : 443 }],
262
271
http_auth = awsauth ,
@@ -268,10 +277,12 @@ def lambda_handler(event, context):
268
277
#print(event)
269
278
270
279
k = 10
271
- payload = event [ 'q' ]
280
+ payload = event . get ( 'q' , '' ) or ''
272
281
273
282
# Debug event
274
283
#print("event", event)
284
+
285
+
275
286
276
287
filter_config = load_config ()
277
288
@@ -348,12 +359,55 @@ def lambda_handler(event, context):
348
359
filters .append (build_spatial_filter (spatial_field , spatial_filter , relation ))
349
360
350
361
# Sort param
351
- sort_param = build_sort_filter (sort_field = sort_param , sort_order = order_param )
362
+ sort_param_final = build_sort_filter (sort_field = sort_param , sort_order = order_param )
352
363
353
364
# If no filters are specified, set filters to None
354
365
filters = filters if filters else None
355
366
356
367
#print("filters : ", filters)
368
+
369
+ ####
370
+ #OpenSearch DashBoard code
371
+ ####
372
+ ip_address = event .get ('ip_address' , '' ) or ''
373
+ timestamp = event .get ('timestamp' , '' ) or ''
374
+ user_agent = event .get ('user_agent' , '' ) or ''
375
+ http_method = event .get ('http_method' , '' ) or ''
376
+
377
+ create_opensearch_index (os_client , search_index_name )
378
+ ip2geo_data = {}
379
+ ip2geo_data = ip2geo_handler (os_client , ip_address )
380
+ document = [
381
+ {
382
+ "timestamp" : timestamp ,
383
+ "lang" : lang_filter ,
384
+ "q" : payload ,
385
+ "ip_address" : ip_address ,
386
+ "user_agent" : user_agent ,
387
+ "http_method" : http_method ,
388
+ "sort_param" : sort_param ,
389
+ "order_param" : order_param ,
390
+ "organization_filter" : organization_filter ,
391
+ "metadata_source_filter" : metadata_source_filter ,
392
+ "theme_filter" : theme_filter ,
393
+ "type_filter" : type_filter ,
394
+ #"start_date_filter": start_date_filter,
395
+ #"end_date_filter": end_date_filter,
396
+ #"spatial_filter": spatial_filter,
397
+ "relation" : relation ,
398
+ "size" : size ,
399
+ "ip2geo" : ip2geo_data
400
+ }
401
+ ]
402
+
403
+ print (f"Document to be indexed: { document } " )
404
+
405
+ save_to_opensearch (os_client , search_index_name , document )
406
+
407
+ ### End of OpenSearch DashBoard code
408
+
409
+ if event ['method' ] == 'postText' :
410
+ payload = json .loads (event ['body' ])['text' ]
357
411
358
412
if event ['method' ] == 'SemanticSearch' :
359
413
#print(f'This is payload {payload}')
@@ -367,11 +421,11 @@ def lambda_handler(event, context):
367
421
search_text = payload ,
368
422
features = features ,
369
423
os_client = os_client ,
370
- k_neighbors = 10 ,
424
+ k_neighbors = k ,
371
425
from_param = from_param ,
372
426
idx_name = model_name ,
373
427
filters = filters ,
374
- sort_param = sort_param ,
428
+ sort_param = sort_param_final ,
375
429
size = size
376
430
)
377
431
0 commit comments