rough working recommend endpoint

This commit is contained in:
charlene tau express 2025-03-14 15:33:31 +08:00
parent 1828ae5693
commit 17b66f2051
6 changed files with 155 additions and 19 deletions

View File

@ -1,3 +1,5 @@
annotated-types==0.7.0
anyio==4.8.0
blinker==1.9.0
certifi==2025.1.31
charset-normalizer==3.4.1
@ -5,7 +7,15 @@ click==8.1.8
filelock==3.17.0
Flask==3.1.0
fsspec==2025.3.0
grpcio==1.71.0
grpcio-tools==1.71.0
h11==0.14.0
h2==4.2.0
hpack==4.1.0
httpcore==1.0.7
httpx==0.28.1
huggingface-hub==0.29.2
hyperframe==6.1.0
idna==3.10
itsdangerous==2.2.0
Jinja2==3.1.6
@ -29,12 +39,18 @@ nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
packaging==24.2
portalocker==2.10.1
protobuf==5.29.3
pydantic==2.10.6
pydantic_core==2.27.2
python-dotenv==1.0.1
PyYAML==6.0.2
qdrant-client==1.13.3
regex==2024.11.6
requests==2.32.3
safetensors==0.5.3
setuptools==76.0.0
sniffio==1.3.1
sympy==1.13.1
tokenizers==0.21.0
torch==2.6.0

View File

@ -1,7 +1,11 @@
from flask import Flask, jsonify, request
from dotenv import load_dotenv
from services.embedding_service import Embedding
import logging
import json
from qdrant_client.http import models
logger = logging.getLogger(__name__)
load_dotenv()
app = Flask(__name__)
@ -15,6 +19,51 @@ def hello():
return jsonify({"message":"Hello, World!"})
@app.get("/recommend")
def recommend():
try:
from services.recommend_service import Recommend
clause = request.args.get("clause", '')
# filter shall be a json array. each element is a dict with following keys:
# key: field name
# value: field value
# type : operator type, default (and the only supported one for now) is "match"
# example: [{"key": "meta.source", "value": "km", "type": "match"}]
# please note we limit the number of filters to 1 for now
filter_ = request.args.get("filter", None)
logger.info(f"recommend request received: [{clause}] with filter [{filter_}]")
updated_clause = clause.replace('"', '')
if filter_:
filter_ = json.loads(filter_)
filter_ = filter_[0]
if filter_['type'] != 'match':
raise ValueError('unsupported filter type')
f = models.Filter(
must=[
models.FieldCondition(
key=filter_['key'], match=models.MatchValue(value=filter_['value'])
),
],
)
filter_ = f
result = Recommend.call(updated_clause, filter=filter_)
if result is None:
result = {"data": []}
else:
result = {"data": [i.dict() for i in result]}
logger.info("recommender returns %d results" % len(result["data"]))
return jsonify(result), 200
except Exception as e:
logger.error('recommend error: %s', type(e).__name__, exc_info=True)
return jsonify({"error":"Internal Server Error"}), 500
@app.post("/get_embedding")
def get_embedding():
# Get the query text from request JSON
@ -29,7 +78,7 @@ def get_embedding():
return jsonify({"query": query, "number of embedding": len(vector), "one embedding":len(vector[0])})
print(__name__)
if __name__=="__main__":
app.run(debug=True, port=8000)

View File

@ -5,7 +5,7 @@ from functools import wraps
from qdrant_client import QdrantClient
from qdrant_services.settings import settings
from qdrant_services.settings import QDRANT_HOST, QDRANT_PORT, QDRANT_HTTPS, QDRANT_SSL_VERIFY,QDRANT_API_KEY, QDRANT_CLUSTER, QDRANT_COLLECTION_NAME
from services.service import Singleton
from qdrant_services.policy import QdrantClientSelectionPolicy
@ -25,16 +25,16 @@ class Qdrant(metaclass=Singleton):
):
self.collection_name = collection_name or self.__class__.get_default_collection_name()
if settings.QDRANT_HTTPS is not None and kwargs.get('https') is None:
kwargs['https'] = settings.QDRANT_HTTPS
if settings.QDRANT_SSL_VERIFY is not None and kwargs.get('verify') is None:
kwargs['verify'] = settings.QDRANT_SSL_VERIFY
if settings.QDRANT_API_KEY is not None and kwargs.get('api_key') is None:
kwargs['api_key'] = settings.QDRANT_API_KEY
if QDRANT_HTTPS is not None and kwargs.get('https') is None:
kwargs['https'] = QDRANT_HTTPS
if QDRANT_SSL_VERIFY is not None and kwargs.get('verify') is None:
kwargs['verify'] = QDRANT_SSL_VERIFY
if QDRANT_API_KEY is not None and kwargs.get('api_key') is None:
kwargs['api_key'] = QDRANT_API_KEY
host = host or settings.QDRANT_HOST
port = port or settings.QDRANT_PORT
cluster = cluster or settings.QDRANT_CLUSTER
host = host or QDRANT_HOST
port = port or QDRANT_PORT
cluster = cluster or QDRANT_CLUSTER
if cluster:
self.logger.info(
"Connecting to qdrant cluster: cluster: %s, collection_name: %s" % (cluster, self.collection_name))
@ -82,13 +82,16 @@ class Qdrant(metaclass=Singleton):
@classmethod
def get_default_collection_name(cls) -> str:
return settings.QDRANT_COLLECTION_NAME
return QDRANT_COLLECTION_NAME
class QdrantClientProxy(QdrantClient):
def __init__(self, outer_instance, *args, **kwargs):
self.outer_instance = outer_instance
super().__init__(*args, **kwargs)
super().__init__(url="https://192.168.99.122:6333",
api_key="NS00TXlKUUIweUhWaGFuUUpUVTk6bWNVWXI1VXRSN2VWcFRtaEZ6NmdCUQ==",
https=True,
verify=False)
_excluded_methods = ['_dynamic_call_decorator']

View File

@ -1,2 +1,8 @@
QDRANT_HOST = "192.168.99.122"
QDRANT_PORT = 6333
QDRANT_PORT = 6333
QDRANT_HTTPS=''
QDRANT_SSL_VERIFY=''
QDRANT_API_KEY=''
QDRANT_CLUSTER=''
QDRANT_COLLECTION_NAME='titan-2502110203'

View File

@ -80,11 +80,7 @@ class Embedding(Component):
clause_segments[-1] = clause_segments[-1] + ' ' + temp_str
else:
clause_segments.append(temp_str)
# print("in embedding_service line 82, clause_segments: ", clause_segments)
# print("clause_segments len: ", len(clause_segments))
# print(len(clause_segments[0]))
# print(len(clause_segments[1]))
# print(len(clause_segments[2]))
clause_vectors = self.embedding_sentences(clause_segments)
for i in range(len(clause_segments)):
if len(clause_segments[i]) < segment_threshold:

View File

@ -0,0 +1,66 @@
import logging
import numpy as np
from qdrant_client.http import models
from services.service import Component
from services.embedding_service import Embedding
from qdrant_services.qdrant import Qdrant
logger = logging.getLogger(__name__)
class Recommend(Component):
def process(self, *args, **kwargs):
return self.do_query_multiplier_new(*args, **kwargs)
def do_query_multiplier_new(self, query_clause, filter = None):
top = 10
input_query_clause = query_clause.replace('"', '')
input_query_clause = clean_recommended_clause(input_query_clause)
# can perform any times of queries
###########################################
# note: even set top_k = 10 but system may return < 10, because some clauses have more than 1 vectors
# top_k = 10 : imply return 10 vectors. if a clause has 2 vectors, 1 clause returned only
return recommend_clauses(input_query_clause, top_k=top * 7, filter=filter);
def clean_recommended_clause(clause):
m_clause = clause.replace('\t', ' ')
m_clause = m_clause.strip()
m_clause_list = m_clause.split(' ')
result = ""
for clause_part in m_clause_list:
if len(clause_part) == 0:
continue
result += clause_part + " "
result = result.strip()
return result
def recommend_clauses(query_str, top_k=10, filter=None):
# query_vector = embedding_a_clause(query_str, is_query=True)
query_vector = Embedding.call(query_str, is_query=True)
query_vector = np.array(query_vector).astype(np.float32)
search_queries = [
models.SearchRequest(
vector=vector.tolist(),
filter = filter,
limit=top_k,
with_payload=True,
)
for vector in query_vector
]
logger.info(f"search_queries: {search_queries}")
results = Qdrant.get_client().search_batch(requests=search_queries)
results = [i for j in results for i in j] # flatten
results = sorted(results, key=lambda x: x.score) # sort by score, ascending
# as this is a batch search, same record may appear in different search results
# remove duplicates by id (keep the highest score for each record)
results = list({v.id: v for v in results}.values())
results.reverse() # sort by score, descending
results = results[:top_k]
return results