# invoke_server.py — multi-dataset hit-testing proxy for a local Dify instance
  1. import json
  2. import psycopg2
  3. import requests
  4. from datetime import datetime
  5. from flask import Flask, request, jsonify
  6. from loguru import logger
  7. from psycopg2.extras import DictCursor
  8. log_month = datetime.now().strftime("%Y-%m")
  9. logger.add(f"serve_log/{log_month}/log_{datetime.now().strftime('%Y-%m-%d')}.log", format="{time} {level} {message}",
  10. compression=None)
  11. app = Flask(__file__)
  12. DB_NAME = "dify"
  13. DB_USER = "postgres"
  14. DB_PASSWORD = "difyai123456"
  15. DB_HOST = "127.0.0.1"
  16. DB_PORT = "5432"
  17. SERVER_ADDR = '127.0.0.1:9998/console/api'
  18. con = psycopg2.connect(database=DB_NAME, user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT)
  19. cur = con.cursor(cursor_factory=DictCursor)
  20. dataset_cache = {}
  21. document_cache = {}
  22. def get_dataset_info(dataset_id):
  23. if dataset_id in dataset_cache:
  24. return dataset_cache[dataset_id]
  25. sql = f'''SELECT * FROM datasets WHERE "id" = '{dataset_id}';'''
  26. cur.execute(sql)
  27. dataset_info = cur.fetchone()
  28. dataset_cache[dataset_id] = dataset_info
  29. logger.info(f'dataset_info:{dataset_info}')
  30. return dataset_info
  31. def get_document_info(document_id):
  32. if document_id in document_cache:
  33. return document_cache[document_id]
  34. sql = f'''SELECT * FROM documents WHERE "id" = '{document_id}';'''
  35. cur.execute(sql)
  36. doc_info = cur.fetchone()
  37. document_cache[document_id] = doc_info
  38. return doc_info
  39. @app.route('/mult_dataset_invoke', methods=['post']) # 根路径
  40. def mult_dataset_invoke():
  41. params = request.get_json()
  42. logger.info(f'recv param:{params}')
  43. try:
  44. headers = {'Authorization': 'Bearer PlatformPassage'}
  45. total_list = {}
  46. for dataset_id in params['dataset_list']:
  47. dataset_info = get_dataset_info(dataset_id)
  48. if not dataset_info or not dataset_info[-5]:
  49. continue
  50. payload = {'query': params['query'],
  51. 'retrieval_model': {"top_k": params['top_k'],
  52. "weights": {"weight_type": "customized",
  53. "vector_setting": {"vector_weight": params['vector_weight'],
  54. "embedding_model_name": dataset_info[-5],
  55. "embedding_provider_name": dataset_info[-4]},
  56. "keyword_setting": {"keyword_weight": params['keyword_weight']}},
  57. "search_method": "hybrid_search",
  58. "reranking_mode": params['reranking_mode'],
  59. "reranking_model": {"reranking_model_name": params['reranking_model_name'],
  60. "reranking_provider_name": params[
  61. 'reranking_provider_name']},
  62. "score_threshold": params['score'],
  63. "reranking_enable": True,
  64. "score_threshold_enabled": False}}
  65. response = requests.request("POST",
  66. f"http://{SERVER_ADDR}/datasets/{dataset_id}/hit-testing",
  67. headers=headers, json=payload)
  68. logger.info(response.json())
  69. if not response.json().get('records'):
  70. continue
  71. for segment in response.json()['records']:
  72. doc_info = get_document_info(segment['segment']['document_id'])
  73. total_list[segment['score']] = {"metadata": {
  74. "_source": "knowledge",
  75. "dataset_id": dataset_id,
  76. "dataset_name": dataset_info[2],
  77. "document_id": segment['segment']['document_id'],
  78. "document_name": doc_info[8],
  79. "document_data_source_type": doc_info[4],
  80. "segment_id": segment['segment']['index_node_id'],
  81. "retriever_from": "workflow",
  82. "score": segment['score'],
  83. "segment_hit_count": segment['segment']['hit_count'],
  84. "segment_word_count": segment['segment']['word_count'],
  85. "segment_position": segment['segment']['position'],
  86. "segment_index_node_hash": segment['segment']['index_node_hash'],
  87. "doc_metadata": {
  88. "source": doc_info[4],
  89. "uploader": doc_info[10],
  90. "upload_date": int(doc_info[12].timestamp()),
  91. "document_name": doc_info[8],
  92. "last_update_date": int(doc_info[-5].timestamp())
  93. },
  94. "position": 1
  95. },
  96. "title": doc_info[8],
  97. "content": segment['segment']['content']
  98. }
  99. if params.get('top_k'):
  100. sorted_dict = dict(sorted(total_list.items(), reverse=True))
  101. result_list = list(sorted_dict.values())
  102. if len(result_list) < params.get('top_k'):
  103. return jsonify(result_list)
  104. else:
  105. return jsonify(result_list[:params.get('top_k')])
  106. except Exception as e:
  107. logger.exception(e)
  108. return {'result': e.args}
if __name__ == '__main__':
    # Development entry point: Flask's built-in server, all interfaces,
    # port 8803. Use a production WSGI server (gunicorn/uwsgi) for deployment.
    app.run(host='0.0.0.0', port=8803)