rough working recommend endpoint
This commit is contained in:
parent
1828ae5693
commit
17b66f2051
@ -1,3 +1,5 @@
|
|||||||
|
annotated-types==0.7.0
|
||||||
|
anyio==4.8.0
|
||||||
blinker==1.9.0
|
blinker==1.9.0
|
||||||
certifi==2025.1.31
|
certifi==2025.1.31
|
||||||
charset-normalizer==3.4.1
|
charset-normalizer==3.4.1
|
||||||
@ -5,7 +7,15 @@ click==8.1.8
|
|||||||
filelock==3.17.0
|
filelock==3.17.0
|
||||||
Flask==3.1.0
|
Flask==3.1.0
|
||||||
fsspec==2025.3.0
|
fsspec==2025.3.0
|
||||||
|
grpcio==1.71.0
|
||||||
|
grpcio-tools==1.71.0
|
||||||
|
h11==0.14.0
|
||||||
|
h2==4.2.0
|
||||||
|
hpack==4.1.0
|
||||||
|
httpcore==1.0.7
|
||||||
|
httpx==0.28.1
|
||||||
huggingface-hub==0.29.2
|
huggingface-hub==0.29.2
|
||||||
|
hyperframe==6.1.0
|
||||||
idna==3.10
|
idna==3.10
|
||||||
itsdangerous==2.2.0
|
itsdangerous==2.2.0
|
||||||
Jinja2==3.1.6
|
Jinja2==3.1.6
|
||||||
@ -29,12 +39,18 @@ nvidia-nccl-cu12==2.21.5
|
|||||||
nvidia-nvjitlink-cu12==12.4.127
|
nvidia-nvjitlink-cu12==12.4.127
|
||||||
nvidia-nvtx-cu12==12.4.127
|
nvidia-nvtx-cu12==12.4.127
|
||||||
packaging==24.2
|
packaging==24.2
|
||||||
|
portalocker==2.10.1
|
||||||
|
protobuf==5.29.3
|
||||||
|
pydantic==2.10.6
|
||||||
|
pydantic_core==2.27.2
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
PyYAML==6.0.2
|
PyYAML==6.0.2
|
||||||
|
qdrant-client==1.13.3
|
||||||
regex==2024.11.6
|
regex==2024.11.6
|
||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
safetensors==0.5.3
|
safetensors==0.5.3
|
||||||
setuptools==76.0.0
|
setuptools==76.0.0
|
||||||
|
sniffio==1.3.1
|
||||||
sympy==1.13.1
|
sympy==1.13.1
|
||||||
tokenizers==0.21.0
|
tokenizers==0.21.0
|
||||||
torch==2.6.0
|
torch==2.6.0
|
||||||
|
|||||||
51
src/app.py
51
src/app.py
@ -1,7 +1,11 @@
|
|||||||
from flask import Flask, jsonify, request
|
from flask import Flask, jsonify, request
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from services.embedding_service import Embedding
|
from services.embedding_service import Embedding
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from qdrant_client.http import models
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
@ -15,6 +19,51 @@ def hello():
|
|||||||
return jsonify({"message":"Hello, World!"})
|
return jsonify({"message":"Hello, World!"})
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/recommend")
|
||||||
|
def recommend():
|
||||||
|
try:
|
||||||
|
|
||||||
|
from services.recommend_service import Recommend
|
||||||
|
clause = request.args.get("clause", '')
|
||||||
|
# filter shall be a json array. each element is a dict with following keys:
|
||||||
|
# key: field name
|
||||||
|
# value: field value
|
||||||
|
# type : operator type, default (and the only supported one for now) is "match"
|
||||||
|
# example: [{"key": "meta.source", "value": "km", "type": "match"}]
|
||||||
|
# please note we limit the number of filters to 1 for now
|
||||||
|
|
||||||
|
filter_ = request.args.get("filter", None)
|
||||||
|
logger.info(f"recommend request received: [{clause}] with filter [{filter_}]")
|
||||||
|
|
||||||
|
updated_clause = clause.replace('"', '')
|
||||||
|
|
||||||
|
|
||||||
|
if filter_:
|
||||||
|
filter_ = json.loads(filter_)
|
||||||
|
filter_ = filter_[0]
|
||||||
|
if filter_['type'] != 'match':
|
||||||
|
raise ValueError('unsupported filter type')
|
||||||
|
f = models.Filter(
|
||||||
|
must=[
|
||||||
|
models.FieldCondition(
|
||||||
|
key=filter_['key'], match=models.MatchValue(value=filter_['value'])
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
filter_ = f
|
||||||
|
|
||||||
|
result = Recommend.call(updated_clause, filter=filter_)
|
||||||
|
if result is None:
|
||||||
|
result = {"data": []}
|
||||||
|
else:
|
||||||
|
result = {"data": [i.dict() for i in result]}
|
||||||
|
logger.info("recommender returns %d results" % len(result["data"]))
|
||||||
|
return jsonify(result), 200
|
||||||
|
except Exception as e:
|
||||||
|
logger.error('recommend error: %s', type(e).__name__, exc_info=True)
|
||||||
|
return jsonify({"error":"Internal Server Error"}), 500
|
||||||
|
|
||||||
|
|
||||||
@app.post("/get_embedding")
|
@app.post("/get_embedding")
|
||||||
def get_embedding():
|
def get_embedding():
|
||||||
# Get the query text from request JSON
|
# Get the query text from request JSON
|
||||||
@ -29,7 +78,7 @@ def get_embedding():
|
|||||||
|
|
||||||
return jsonify({"query": query, "number of embedding": len(vector), "one embedding":len(vector[0])})
|
return jsonify({"query": query, "number of embedding": len(vector), "one embedding":len(vector[0])})
|
||||||
|
|
||||||
print(__name__)
|
|
||||||
|
|
||||||
if __name__=="__main__":
|
if __name__=="__main__":
|
||||||
app.run(debug=True, port=8000)
|
app.run(debug=True, port=8000)
|
||||||
@ -5,7 +5,7 @@ from functools import wraps
|
|||||||
|
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
|
|
||||||
from qdrant_services.settings import settings
|
from qdrant_services.settings import QDRANT_HOST, QDRANT_PORT, QDRANT_HTTPS, QDRANT_SSL_VERIFY,QDRANT_API_KEY, QDRANT_CLUSTER, QDRANT_COLLECTION_NAME
|
||||||
from services.service import Singleton
|
from services.service import Singleton
|
||||||
|
|
||||||
from qdrant_services.policy import QdrantClientSelectionPolicy
|
from qdrant_services.policy import QdrantClientSelectionPolicy
|
||||||
@ -25,16 +25,16 @@ class Qdrant(metaclass=Singleton):
|
|||||||
):
|
):
|
||||||
self.collection_name = collection_name or self.__class__.get_default_collection_name()
|
self.collection_name = collection_name or self.__class__.get_default_collection_name()
|
||||||
|
|
||||||
if settings.QDRANT_HTTPS is not None and kwargs.get('https') is None:
|
if QDRANT_HTTPS is not None and kwargs.get('https') is None:
|
||||||
kwargs['https'] = settings.QDRANT_HTTPS
|
kwargs['https'] = QDRANT_HTTPS
|
||||||
if settings.QDRANT_SSL_VERIFY is not None and kwargs.get('verify') is None:
|
if QDRANT_SSL_VERIFY is not None and kwargs.get('verify') is None:
|
||||||
kwargs['verify'] = settings.QDRANT_SSL_VERIFY
|
kwargs['verify'] = QDRANT_SSL_VERIFY
|
||||||
if settings.QDRANT_API_KEY is not None and kwargs.get('api_key') is None:
|
if QDRANT_API_KEY is not None and kwargs.get('api_key') is None:
|
||||||
kwargs['api_key'] = settings.QDRANT_API_KEY
|
kwargs['api_key'] = QDRANT_API_KEY
|
||||||
|
|
||||||
host = host or settings.QDRANT_HOST
|
host = host or QDRANT_HOST
|
||||||
port = port or settings.QDRANT_PORT
|
port = port or QDRANT_PORT
|
||||||
cluster = cluster or settings.QDRANT_CLUSTER
|
cluster = cluster or QDRANT_CLUSTER
|
||||||
if cluster:
|
if cluster:
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"Connecting to qdrant cluster: cluster: %s, collection_name: %s" % (cluster, self.collection_name))
|
"Connecting to qdrant cluster: cluster: %s, collection_name: %s" % (cluster, self.collection_name))
|
||||||
@ -82,13 +82,16 @@ class Qdrant(metaclass=Singleton):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_collection_name(cls) -> str:
|
def get_default_collection_name(cls) -> str:
|
||||||
return settings.QDRANT_COLLECTION_NAME
|
return QDRANT_COLLECTION_NAME
|
||||||
|
|
||||||
class QdrantClientProxy(QdrantClient):
|
class QdrantClientProxy(QdrantClient):
|
||||||
|
|
||||||
def __init__(self, outer_instance, *args, **kwargs):
|
def __init__(self, outer_instance, *args, **kwargs):
|
||||||
self.outer_instance = outer_instance
|
self.outer_instance = outer_instance
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(url="https://192.168.99.122:6333",
|
||||||
|
api_key="NS00TXlKUUIweUhWaGFuUUpUVTk6bWNVWXI1VXRSN2VWcFRtaEZ6NmdCUQ==",
|
||||||
|
https=True,
|
||||||
|
verify=False)
|
||||||
|
|
||||||
_excluded_methods = ['_dynamic_call_decorator']
|
_excluded_methods = ['_dynamic_call_decorator']
|
||||||
|
|
||||||
|
|||||||
@ -1,2 +1,8 @@
|
|||||||
QDRANT_HOST = "192.168.99.122"
|
QDRANT_HOST = "192.168.99.122"
|
||||||
QDRANT_PORT = 6333
|
QDRANT_PORT = 6333
|
||||||
|
QDRANT_HTTPS=''
|
||||||
|
QDRANT_SSL_VERIFY=''
|
||||||
|
QDRANT_API_KEY=''
|
||||||
|
QDRANT_CLUSTER=''
|
||||||
|
QDRANT_COLLECTION_NAME='titan-2502110203'
|
||||||
|
|
||||||
|
|||||||
@ -80,11 +80,7 @@ class Embedding(Component):
|
|||||||
clause_segments[-1] = clause_segments[-1] + ' ' + temp_str
|
clause_segments[-1] = clause_segments[-1] + ' ' + temp_str
|
||||||
else:
|
else:
|
||||||
clause_segments.append(temp_str)
|
clause_segments.append(temp_str)
|
||||||
# print("in embedding_service line 82, clause_segments: ", clause_segments)
|
|
||||||
# print("clause_segments len: ", len(clause_segments))
|
|
||||||
# print(len(clause_segments[0]))
|
|
||||||
# print(len(clause_segments[1]))
|
|
||||||
# print(len(clause_segments[2]))
|
|
||||||
clause_vectors = self.embedding_sentences(clause_segments)
|
clause_vectors = self.embedding_sentences(clause_segments)
|
||||||
for i in range(len(clause_segments)):
|
for i in range(len(clause_segments)):
|
||||||
if len(clause_segments[i]) < segment_threshold:
|
if len(clause_segments[i]) < segment_threshold:
|
||||||
|
|||||||
@ -0,0 +1,66 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from qdrant_client.http import models
|
||||||
|
|
||||||
|
from services.service import Component
|
||||||
|
from services.embedding_service import Embedding
|
||||||
|
from qdrant_services.qdrant import Qdrant
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Recommend(Component):
|
||||||
|
def process(self, *args, **kwargs):
|
||||||
|
return self.do_query_multiplier_new(*args, **kwargs)
|
||||||
|
|
||||||
|
def do_query_multiplier_new(self, query_clause, filter = None):
|
||||||
|
top = 10
|
||||||
|
|
||||||
|
input_query_clause = query_clause.replace('"', '')
|
||||||
|
input_query_clause = clean_recommended_clause(input_query_clause)
|
||||||
|
# can perform any times of queries
|
||||||
|
###########################################
|
||||||
|
# note: even set top_k = 10 but system may return < 10, because some clauses have more than 1 vectors
|
||||||
|
# top_k = 10 : imply return 10 vectors. if a clause has 2 vectors, 1 clause returned only
|
||||||
|
|
||||||
|
return recommend_clauses(input_query_clause, top_k=top * 7, filter=filter);
|
||||||
|
|
||||||
|
|
||||||
|
def clean_recommended_clause(clause):
|
||||||
|
m_clause = clause.replace('\t', ' ')
|
||||||
|
m_clause = m_clause.strip()
|
||||||
|
m_clause_list = m_clause.split(' ')
|
||||||
|
result = ""
|
||||||
|
for clause_part in m_clause_list:
|
||||||
|
if len(clause_part) == 0:
|
||||||
|
continue
|
||||||
|
result += clause_part + " "
|
||||||
|
result = result.strip()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def recommend_clauses(query_str, top_k=10, filter=None):
|
||||||
|
# query_vector = embedding_a_clause(query_str, is_query=True)
|
||||||
|
query_vector = Embedding.call(query_str, is_query=True)
|
||||||
|
query_vector = np.array(query_vector).astype(np.float32)
|
||||||
|
|
||||||
|
search_queries = [
|
||||||
|
models.SearchRequest(
|
||||||
|
vector=vector.tolist(),
|
||||||
|
filter = filter,
|
||||||
|
limit=top_k,
|
||||||
|
with_payload=True,
|
||||||
|
)
|
||||||
|
for vector in query_vector
|
||||||
|
]
|
||||||
|
logger.info(f"search_queries: {search_queries}")
|
||||||
|
results = Qdrant.get_client().search_batch(requests=search_queries)
|
||||||
|
results = [i for j in results for i in j] # flatten
|
||||||
|
results = sorted(results, key=lambda x: x.score) # sort by score, ascending
|
||||||
|
# as this is a batch search, same record may appear in different search results
|
||||||
|
# remove duplicates by id (keep the highest score for each record)
|
||||||
|
results = list({v.id: v for v in results}.values())
|
||||||
|
results.reverse() # sort by score, descending
|
||||||
|
results = results[:top_k]
|
||||||
|
return results
|
||||||
Loading…
Reference in New Issue
Block a user