| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- import json
- import psycopg2
- import requests
- from datetime import datetime
- from flask import Flask, request, jsonify
- from loguru import logger
- from psycopg2.extras import DictCursor
# --- Logging: one file per day, grouped into monthly sub-directories ---
log_month = datetime.now().strftime("%Y-%m")
logger.add(
    f"serve_log/{log_month}/log_{datetime.now().strftime('%Y-%m-%d')}.log",
    format="{time} {level} {message}",
    compression=None,
)

# Fix: use the import name, not the file path, as Flask's application name
# (Flask uses it to derive root_path and to name the app; `__file__` made the
# app name an absolute filesystem path).
app = Flask(__name__)

# --- Database / upstream configuration ---
# SECURITY: credentials are hard-coded; move them to environment variables or
# a secrets manager before deploying outside a trusted network.
DB_NAME = "dify"
DB_USER = "postgres"
DB_PASSWORD = "difyai123456"
DB_HOST = "127.0.0.1"
DB_PORT = "5432"
SERVER_ADDR = '127.0.0.1:9998/console/api'

# NOTE(review): one module-level connection/cursor is shared by every request;
# psycopg2 connections are not safe for concurrent use — consider a connection
# pool if this app ever runs with more than one worker/thread.
con = psycopg2.connect(database=DB_NAME, user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT)
cur = con.cursor(cursor_factory=DictCursor)

# In-process caches; entries are never evicted, so dataset/document rows are
# assumed immutable for the lifetime of the server.
dataset_cache = {}
document_cache = {}
def get_dataset_info(dataset_id):
    """Return the `datasets` row for *dataset_id*, memoized in dataset_cache.

    Returns None (and caches the miss) when no such dataset exists.
    """
    if dataset_id in dataset_cache:
        return dataset_cache[dataset_id]
    # Fix: parameterized query. The original interpolated the request-supplied
    # dataset_id straight into the SQL string, allowing SQL injection.
    cur.execute('SELECT * FROM datasets WHERE "id" = %s;', (dataset_id,))
    dataset_info = cur.fetchone()
    dataset_cache[dataset_id] = dataset_info
    logger.info(f'dataset_info:{dataset_info}')
    return dataset_info
def get_document_info(document_id):
    """Return the `documents` row for *document_id*, memoized in document_cache.

    Returns None (and caches the miss) when no such document exists.
    """
    if document_id in document_cache:
        return document_cache[document_id]
    # Fix: parameterized query. The original interpolated the request-supplied
    # document_id straight into the SQL string, allowing SQL injection.
    cur.execute('SELECT * FROM documents WHERE "id" = %s;', (document_id,))
    doc_info = cur.fetchone()
    document_cache[document_id] = doc_info
    return doc_info
@app.route('/mult_dataset_invoke', methods=['POST'])
def mult_dataset_invoke():
    """Query several Dify datasets via the hit-testing API and merge results.

    Expects a JSON body with keys: query, dataset_list, top_k, vector_weight,
    keyword_weight, reranking_mode, reranking_model_name,
    reranking_provider_name, score.

    Returns a JSON list of retrieval records sorted by score (descending),
    truncated to top_k when top_k is given. On error returns
    {'result': <exception args>}.
    """
    params = request.get_json()
    logger.info(f'recv param:{params}')

    try:
        headers = {'Authorization': 'Bearer PlatformPassage'}
        # Fix: collect (score, record) pairs in a list. The original used a
        # dict keyed by score, so records with tied scores silently
        # overwrote each other.
        scored_records = []
        for dataset_id in params['dataset_list']:
            dataset_info = get_dataset_info(dataset_id)
            # dataset_info[-5]/[-4] are presumed to be the embedding model
            # name/provider columns — TODO confirm against the datasets schema.
            if not dataset_info or not dataset_info[-5]:
                continue
            payload = _build_payload(params, dataset_info)
            response = requests.post(
                f"http://{SERVER_ADDR}/datasets/{dataset_id}/hit-testing",
                headers=headers, json=payload)
            body = response.json()  # fix: parse once (was parsed 3x per call)
            logger.info(body)
            for segment in body.get('records') or []:
                doc_info = get_document_info(segment['segment']['document_id'])
                if doc_info is None:
                    # Robustness: a missing document row used to raise and
                    # fail the whole request; skip the segment instead.
                    continue
                scored_records.append(
                    (segment['score'],
                     _make_record(dataset_id, dataset_info, doc_info, segment)))

        scored_records.sort(key=lambda item: item[0], reverse=True)
        result_list = [record for _, record in scored_records]
        top_k = params.get('top_k')
        if top_k:
            result_list = result_list[:top_k]
        # Fix: the original returned None (an HTTP 500) when top_k was
        # missing; now the full sorted list is returned instead.
        return jsonify(result_list)
    except Exception as e:
        logger.exception(e)
        return {'result': e.args}


def _build_payload(params, dataset_info):
    """Build the hit-testing request body for one dataset."""
    return {
        'query': params['query'],
        'retrieval_model': {
            "top_k": params['top_k'],
            "weights": {
                "weight_type": "customized",
                "vector_setting": {
                    "vector_weight": params['vector_weight'],
                    "embedding_model_name": dataset_info[-5],
                    "embedding_provider_name": dataset_info[-4],
                },
                "keyword_setting": {"keyword_weight": params['keyword_weight']},
            },
            "search_method": "hybrid_search",
            "reranking_mode": params['reranking_mode'],
            "reranking_model": {
                "reranking_model_name": params['reranking_model_name'],
                "reranking_provider_name": params['reranking_provider_name'],
            },
            "score_threshold": params['score'],
            "reranking_enable": True,
            "score_threshold_enabled": False,
        },
    }


def _make_record(dataset_id, dataset_info, doc_info, segment):
    """Shape one hit-testing segment into a response record dict."""
    seg = segment['segment']
    # doc_info column indices (4, 8, 10, 12, -5) are presumed to be
    # data_source_type, name, created_by, created_at and updated_at —
    # TODO confirm against the documents table schema.
    return {
        "metadata": {
            "_source": "knowledge",
            "dataset_id": dataset_id,
            "dataset_name": dataset_info[2],
            "document_id": seg['document_id'],
            "document_name": doc_info[8],
            "document_data_source_type": doc_info[4],
            "segment_id": seg['index_node_id'],
            "retriever_from": "workflow",
            "score": segment['score'],
            "segment_hit_count": seg['hit_count'],
            "segment_word_count": seg['word_count'],
            "segment_position": seg['position'],
            "segment_index_node_hash": seg['index_node_hash'],
            "doc_metadata": {
                "source": doc_info[4],
                "uploader": doc_info[10],
                "upload_date": int(doc_info[12].timestamp()),
                "document_name": doc_info[8],
                "last_update_date": int(doc_info[-5].timestamp()),
            },
            "position": 1,
        },
        "title": doc_info[8],
        "content": seg['content'],
    }
# Development entry point: Flask's built-in server, not suitable for
# production (use gunicorn/uwsgi behind a reverse proxy there).
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8803)
|