From 17b66f20510b992576f82d8c70ee08affc13b8a0 Mon Sep 17 00:00:00 2001 From: charlene tau express Date: Fri, 14 Mar 2025 15:33:31 +0800 Subject: [PATCH] rough working recommend endpoint --- requirements.txt | 16 ++++++++ src/app.py | 51 +++++++++++++++++++++++- src/qdrant_services/qdrant.py | 27 +++++++------ src/qdrant_services/settings.py | 8 +++- src/services/embedding_service.py | 6 +-- src/services/recommend_service.py | 66 +++++++++++++++++++++++++++++++ 6 files changed, 155 insertions(+), 19 deletions(-) diff --git a/requirements.txt b/requirements.txt index 98fd90f..c7eae37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ +annotated-types==0.7.0 +anyio==4.8.0 blinker==1.9.0 certifi==2025.1.31 charset-normalizer==3.4.1 @@ -5,7 +7,15 @@ click==8.1.8 filelock==3.17.0 Flask==3.1.0 fsspec==2025.3.0 +grpcio==1.71.0 +grpcio-tools==1.71.0 +h11==0.14.0 +h2==4.2.0 +hpack==4.1.0 +httpcore==1.0.7 +httpx==0.28.1 huggingface-hub==0.29.2 +hyperframe==6.1.0 idna==3.10 itsdangerous==2.2.0 Jinja2==3.1.6 @@ -29,12 +39,18 @@ nvidia-nccl-cu12==2.21.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.4.127 packaging==24.2 +portalocker==2.10.1 +protobuf==5.29.3 +pydantic==2.10.6 +pydantic_core==2.27.2 python-dotenv==1.0.1 PyYAML==6.0.2 +qdrant-client==1.13.3 regex==2024.11.6 requests==2.32.3 safetensors==0.5.3 setuptools==76.0.0 +sniffio==1.3.1 sympy==1.13.1 tokenizers==0.21.0 torch==2.6.0 diff --git a/src/app.py b/src/app.py index cd1ad44..1ffa5ff 100644 --- a/src/app.py +++ b/src/app.py @@ -1,7 +1,11 @@ from flask import Flask, jsonify, request from dotenv import load_dotenv from services.embedding_service import Embedding +import logging +import json +from qdrant_client.http import models +logger = logging.getLogger(__name__) load_dotenv() app = Flask(__name__) @@ -15,6 +19,51 @@ def hello(): return jsonify({"message":"Hello, World!"}) +@app.get("/recommend") +def recommend(): + try: + + from services.recommend_service import Recommend + clause = request.args.get("clause", '') + # filter shall be a json array. each element is a dict with following keys: + # key: field name + # value: field value + # type : operator type, default (and the only supported one for now) is "match" + # example: [{"key": "meta.source", "value": "km", "type": "match"}] + # please note we limit the number of filters to 1 for now + + filter_ = request.args.get("filter", None) + logger.info(f"recommend request received: [{clause}] with filter [{filter_}]") + + updated_clause = clause.replace('"', '') + + + if filter_: + filter_ = json.loads(filter_) + filter_ = filter_[0] + if filter_['type'] != 'match': + raise ValueError('unsupported filter type') + f = models.Filter( + must=[ + models.FieldCondition( + key=filter_['key'], match=models.MatchValue(value=filter_['value']) + ), + ], + ) + filter_ = f + + result = Recommend.call(updated_clause, filter=filter_) + if result is None: + result = {"data": []} + else: + result = {"data": [i.dict() for i in result]} + logger.info("recommender returns %d results" % len(result["data"])) + return jsonify(result), 200 + except Exception as e: + logger.error('recommend error: %s', type(e).__name__, exc_info=True) + return jsonify({"error":"Internal Server Error"}), 500 + + @app.post("/get_embedding") def get_embedding(): # Get the query text from request JSON @@ -29,7 +78,7 @@ def get_embedding(): return jsonify({"query": query, "number of embedding": len(vector), "one embedding":len(vector[0])}) -print(__name__) + if __name__=="__main__": app.run(debug=True, port=8000) \ No newline at end of file diff --git a/src/qdrant_services/qdrant.py b/src/qdrant_services/qdrant.py index d6c2dd6..5fc8d6c 100644 --- a/src/qdrant_services/qdrant.py +++ b/src/qdrant_services/qdrant.py @@ -5,7 +5,7 @@ from functools import wraps from qdrant_client import QdrantClient -from qdrant_services.settings import settings +from qdrant_services.settings import QDRANT_HOST, QDRANT_PORT, QDRANT_HTTPS, QDRANT_SSL_VERIFY,QDRANT_API_KEY, QDRANT_CLUSTER, QDRANT_COLLECTION_NAME from services.service import Singleton from qdrant_services.policy import QdrantClientSelectionPolicy @@ -25,16 +25,16 @@ class Qdrant(metaclass=Singleton): ): self.collection_name = collection_name or self.__class__.get_default_collection_name() - if settings.QDRANT_HTTPS is not None and kwargs.get('https') is None: - kwargs['https'] = settings.QDRANT_HTTPS - if settings.QDRANT_SSL_VERIFY is not None and kwargs.get('verify') is None: - kwargs['verify'] = settings.QDRANT_SSL_VERIFY - if settings.QDRANT_API_KEY is not None and kwargs.get('api_key') is None: - kwargs['api_key'] = settings.QDRANT_API_KEY + if QDRANT_HTTPS is not None and kwargs.get('https') is None: + kwargs['https'] = QDRANT_HTTPS + if QDRANT_SSL_VERIFY is not None and kwargs.get('verify') is None: + kwargs['verify'] = QDRANT_SSL_VERIFY + if QDRANT_API_KEY is not None and kwargs.get('api_key') is None: + kwargs['api_key'] = QDRANT_API_KEY - host = host or settings.QDRANT_HOST - port = port or settings.QDRANT_PORT - cluster = cluster or settings.QDRANT_CLUSTER + host = host or QDRANT_HOST + port = port or QDRANT_PORT + cluster = cluster or QDRANT_CLUSTER if cluster: self.logger.info( "Connecting to qdrant cluster: cluster: %s, collection_name: %s" % (cluster, self.collection_name)) @@ -82,13 +82,16 @@ class Qdrant(metaclass=Singleton): @classmethod def get_default_collection_name(cls) -> str: - return settings.QDRANT_COLLECTION_NAME + return QDRANT_COLLECTION_NAME class QdrantClientProxy(QdrantClient): def __init__(self, outer_instance, *args, **kwargs): self.outer_instance = outer_instance - super().__init__(*args, **kwargs) + super().__init__(url="https://192.168.99.122:6333", + api_key="NS00TXlKUUIweUhWaGFuUUpUVTk6bWNVWXI1VXRSN2VWcFRtaEZ6NmdCUQ==", + https=True, + verify=False) _excluded_methods = ['_dynamic_call_decorator'] diff --git a/src/qdrant_services/settings.py b/src/qdrant_services/settings.py index d033e22..48d6653 100644 --- a/src/qdrant_services/settings.py +++ b/src/qdrant_services/settings.py @@ -1,2 +1,8 @@ QDRANT_HOST = "192.168.99.122" -QDRANT_PORT = 6333 \ No newline at end of file +QDRANT_PORT = 6333 +QDRANT_HTTPS='' +QDRANT_SSL_VERIFY='' +QDRANT_API_KEY='' +QDRANT_CLUSTER='' +QDRANT_COLLECTION_NAME='titan-2502110203' + diff --git a/src/services/embedding_service.py b/src/services/embedding_service.py index acd348a..ff85014 100644 --- a/src/services/embedding_service.py +++ b/src/services/embedding_service.py @@ -80,11 +80,7 @@ class Embedding(Component): clause_segments[-1] = clause_segments[-1] + ' ' + temp_str else: clause_segments.append(temp_str) - # print("in embedding_service line 82, clause_segments: ", clause_segments) - # print("clause_segments len: ", len(clause_segments)) - # print(len(clause_segments[0])) - # print(len(clause_segments[1])) - # print(len(clause_segments[2])) + clause_vectors = self.embedding_sentences(clause_segments) for i in range(len(clause_segments)): if len(clause_segments[i]) < segment_threshold: diff --git a/src/services/recommend_service.py b/src/services/recommend_service.py index e69de29..bb9217b 100644 --- a/src/services/recommend_service.py +++ b/src/services/recommend_service.py @@ -0,0 +1,66 @@ +import logging + +import numpy as np +from qdrant_client.http import models + +from services.service import Component +from services.embedding_service import Embedding +from qdrant_services.qdrant import Qdrant + +logger = logging.getLogger(__name__) + + +class Recommend(Component): + def process(self, *args, **kwargs): + return self.do_query_multiplier_new(*args, **kwargs) + + def do_query_multiplier_new(self, query_clause, filter = None): + top = 10 + + input_query_clause = query_clause.replace('"', '') + input_query_clause = clean_recommended_clause(input_query_clause) + # can perform any times of queries + ########################################### + # note: even set top_k = 10 but system may return < 10, because some clauses have more than 1 vectors + # top_k = 10 : imply return 10 vectors. if a clause has 2 vectors, 1 clause returned only + + return recommend_clauses(input_query_clause, top_k=top * 7, filter=filter); + + +def clean_recommended_clause(clause): + m_clause = clause.replace('\t', ' ') + m_clause = m_clause.strip() + m_clause_list = m_clause.split(' ') + result = "" + for clause_part in m_clause_list: + if len(clause_part) == 0: + continue + result += clause_part + " " + result = result.strip() + return result + + +def recommend_clauses(query_str, top_k=10, filter=None): + # query_vector = embedding_a_clause(query_str, is_query=True) + query_vector = Embedding.call(query_str, is_query=True) + query_vector = np.array(query_vector).astype(np.float32) + + search_queries = [ + models.SearchRequest( + vector=vector.tolist(), + filter = filter, + limit=top_k, + with_payload=True, + ) + for vector in query_vector + ] + logger.info(f"search_queries: {search_queries}") + results = Qdrant.get_client().search_batch(requests=search_queries) + results = [i for j in results for i in j] # flatten + results = sorted(results, key=lambda x: x.score) # sort by score, ascending + # as this is a batch search, same record may appear in different search results + # remove duplicates by id (keep the highest score for each record) + results = list({v.id: v for v in results}.values()) + results.reverse() # sort by score, descending + results = results[:top_k] + return results