import psycopg2
import requests
from datetime import datetime
from flask import Flask, request, jsonify
from loguru import logger
from psycopg2.extras import DictCursor

# Daily log file under a per-month directory, e.g. serve_log/2024-05/log_2024-05-17.log
log_month = datetime.now().strftime("%Y-%m")
logger.add(f"serve_log/{log_month}/log_{datetime.now().strftime('%Y-%m-%d')}.log",
           format="{time} {level} {message}",
           compression=None)

app = Flask(__name__)

# Dify database and console API connection settings
DB_NAME = "dify"
DB_USER = "postgres"
DB_PASSWORD = "difyai123456"
DB_HOST = "127.0.0.1"
DB_PORT = "5432"
SERVER_ADDR = '127.0.0.1:9998/console/api'

# NOTE: a single shared connection/cursor is not thread-safe; run this app
# single-threaded or give each request its own connection.
con = psycopg2.connect(database=DB_NAME, user=DB_USER, password=DB_PASSWORD,
                       host=DB_HOST, port=DB_PORT)
cur = con.cursor(cursor_factory=DictCursor)

# In-memory caches so each dataset/document row is read from Postgres only once
dataset_cache = {}
document_cache = {}


def get_dataset_info(dataset_id):
    """Fetch a row from the datasets table, with caching."""
    if dataset_id in dataset_cache:
        return dataset_cache[dataset_id]
    # Parameterized query instead of f-string interpolation to avoid SQL injection
    cur.execute('SELECT * FROM datasets WHERE "id" = %s;', (dataset_id,))
    dataset_info = cur.fetchone()
    dataset_cache[dataset_id] = dataset_info
    logger.info(f'dataset_info:{dataset_info}')
    return dataset_info


def get_document_info(document_id):
    """Fetch a row from the documents table, with caching."""
    if document_id in document_cache:
        return document_cache[document_id]
    cur.execute('SELECT * FROM documents WHERE "id" = %s;', (document_id,))
    doc_info = cur.fetchone()
    document_cache[document_id] = doc_info
    return doc_info


@app.route('/mult_dataset_invoke', methods=['POST'])  # multi-dataset retrieval endpoint
def mult_dataset_invoke():
    params = request.get_json()
    logger.info(f'recv param:{params}')
    try:
        headers = {'Authorization': 'Bearer PlatformPassage'}
        # Collect (score, record) pairs across all datasets; a list is used so
        # that segments with equal scores do not overwrite each other
        scored_records = []
        for dataset_id in params['dataset_list']:
            dataset_info = get_dataset_info(dataset_id)
            # dataset_info[-5]/[-4] hold the embedding model name/provider;
            # skip datasets without an embedding model configured
            if not dataset_info or not dataset_info[-5]:
                continue
            payload = {
                'query': params['query'],
                'retrieval_model': {
                    "top_k": params['top_k'],
                    "weights": {
                        "weight_type": "customized",
                        "vector_setting": {
                            "vector_weight": params['vector_weight'],
                            "embedding_model_name": dataset_info[-5],
                            "embedding_provider_name": dataset_info[-4],
                        },
                        "keyword_setting": {"keyword_weight": params['keyword_weight']},
                    },
                    "search_method": "hybrid_search",
                    "reranking_mode": params['reranking_mode'],
                    "reranking_model": {
                        "reranking_model_name": params['reranking_model_name'],
                        "reranking_provider_name": params['reranking_provider_name'],
                    },
                    "score_threshold": params['score'],
                    "reranking_enable": True,
                    "score_threshold_enabled": False,
                },
            }
            response = requests.post(f"http://{SERVER_ADDR}/datasets/{dataset_id}/hit-testing",
                                     headers=headers, json=payload)
            result = response.json()
            logger.info(result)
            if not result.get('records'):
                continue
            for segment in result['records']:
                doc_info = get_document_info(segment['segment']['document_id'])
                scored_records.append((segment['score'], {
                    "metadata": {
                        "_source": "knowledge",
                        "dataset_id": dataset_id,
                        "dataset_name": dataset_info[2],
                        "document_id": segment['segment']['document_id'],
                        "document_name": doc_info[8],
                        "document_data_source_type": doc_info[4],
                        "segment_id": segment['segment']['index_node_id'],
                        "retriever_from": "workflow",
                        "score": segment['score'],
                        "segment_hit_count": segment['segment']['hit_count'],
                        "segment_word_count": segment['segment']['word_count'],
                        "segment_position": segment['segment']['position'],
                        "segment_index_node_hash": segment['segment']['index_node_hash'],
                        "doc_metadata": {
                            "source": doc_info[4],
                            "uploader": doc_info[10],
                            "upload_date": int(doc_info[12].timestamp()),
                            "document_name": doc_info[8],
                            "last_update_date": int(doc_info[-5].timestamp()),
                        },
                        "position": 1,
                    },
                    "title": doc_info[8],
                    "content": segment['segment']['content'],
                }))
        # Sort all hits by score, highest first, and truncate to top_k if given;
        # always return a JSON list (the original returned nothing when top_k was absent)
        scored_records.sort(key=lambda pair: pair[0], reverse=True)
        result_list = [record for _, record in scored_records]
        top_k = params.get('top_k')
        if top_k:
            result_list = result_list[:top_k]
        return jsonify(result_list)
    except Exception as e:
        logger.exception(e)
        return {'result': str(e)}


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8803)
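
# Example invocation (a sketch, not from the original source: the payload keys
# mirror what mult_dataset_invoke reads above; the dataset id, model names, and
# weight/score values are placeholders to be replaced with real ones):
#
#   curl -X POST http://127.0.0.1:8803/mult_dataset_invoke \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "how do I reset my password?",
#             "dataset_list": ["<dataset-uuid>"],
#             "top_k": 5,
#             "vector_weight": 0.7,
#             "keyword_weight": 0.3,
#             "reranking_mode": "reranking_model",
#             "reranking_model_name": "<rerank-model>",
#             "reranking_provider_name": "<rerank-provider>",
#             "score": 0.5}'
#
# The response is a JSON list of segment records, sorted by score descending
# and truncated to top_k.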