rough working recommend endpoint
This commit is contained in:
parent
1828ae5693
commit
17b66f2051
@ -1,3 +1,5 @@
|
||||
annotated-types==0.7.0
|
||||
anyio==4.8.0
|
||||
blinker==1.9.0
|
||||
certifi==2025.1.31
|
||||
charset-normalizer==3.4.1
|
||||
@ -5,7 +7,15 @@ click==8.1.8
|
||||
filelock==3.17.0
|
||||
Flask==3.1.0
|
||||
fsspec==2025.3.0
|
||||
grpcio==1.71.0
|
||||
grpcio-tools==1.71.0
|
||||
h11==0.14.0
|
||||
h2==4.2.0
|
||||
hpack==4.1.0
|
||||
httpcore==1.0.7
|
||||
httpx==0.28.1
|
||||
huggingface-hub==0.29.2
|
||||
hyperframe==6.1.0
|
||||
idna==3.10
|
||||
itsdangerous==2.2.0
|
||||
Jinja2==3.1.6
|
||||
@ -29,12 +39,18 @@ nvidia-nccl-cu12==2.21.5
|
||||
nvidia-nvjitlink-cu12==12.4.127
|
||||
nvidia-nvtx-cu12==12.4.127
|
||||
packaging==24.2
|
||||
portalocker==2.10.1
|
||||
protobuf==5.29.3
|
||||
pydantic==2.10.6
|
||||
pydantic_core==2.27.2
|
||||
python-dotenv==1.0.1
|
||||
PyYAML==6.0.2
|
||||
qdrant-client==1.13.3
|
||||
regex==2024.11.6
|
||||
requests==2.32.3
|
||||
safetensors==0.5.3
|
||||
setuptools==76.0.0
|
||||
sniffio==1.3.1
|
||||
sympy==1.13.1
|
||||
tokenizers==0.21.0
|
||||
torch==2.6.0
|
||||
|
||||
51
src/app.py
51
src/app.py
@ -1,7 +1,11 @@
|
||||
from flask import Flask, jsonify, request
|
||||
from dotenv import load_dotenv
|
||||
from services.embedding_service import Embedding
|
||||
import logging
|
||||
import json
|
||||
from qdrant_client.http import models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
load_dotenv()
|
||||
app = Flask(__name__)
|
||||
@ -15,6 +19,51 @@ def hello():
|
||||
return jsonify({"message":"Hello, World!"})
|
||||
|
||||
|
||||
@app.get("/recommend")
|
||||
def recommend():
|
||||
try:
|
||||
|
||||
from services.recommend_service import Recommend
|
||||
clause = request.args.get("clause", '')
|
||||
# filter shall be a json array. each element is a dict with following keys:
|
||||
# key: field name
|
||||
# value: field value
|
||||
# type : operator type, default (and the only supported one for now) is "match"
|
||||
# example: [{"key": "meta.source", "value": "km", "type": "match"}]
|
||||
# please note we limit the number of filters to 1 for now
|
||||
|
||||
filter_ = request.args.get("filter", None)
|
||||
logger.info(f"recommend request received: [{clause}] with filter [{filter_}]")
|
||||
|
||||
updated_clause = clause.replace('"', '')
|
||||
|
||||
|
||||
if filter_:
|
||||
filter_ = json.loads(filter_)
|
||||
filter_ = filter_[0]
|
||||
if filter_['type'] != 'match':
|
||||
raise ValueError('unsupported filter type')
|
||||
f = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=filter_['key'], match=models.MatchValue(value=filter_['value'])
|
||||
),
|
||||
],
|
||||
)
|
||||
filter_ = f
|
||||
|
||||
result = Recommend.call(updated_clause, filter=filter_)
|
||||
if result is None:
|
||||
result = {"data": []}
|
||||
else:
|
||||
result = {"data": [i.dict() for i in result]}
|
||||
logger.info("recommender returns %d results" % len(result["data"]))
|
||||
return jsonify(result), 200
|
||||
except Exception as e:
|
||||
logger.error('recommend error: %s', type(e).__name__, exc_info=True)
|
||||
return jsonify({"error":"Internal Server Error"}), 500
|
||||
|
||||
|
||||
@app.post("/get_embedding")
|
||||
def get_embedding():
|
||||
# Get the query text from request JSON
|
||||
@ -29,7 +78,7 @@ def get_embedding():
|
||||
|
||||
return jsonify({"query": query, "number of embedding": len(vector), "one embedding":len(vector[0])})
|
||||
|
||||
print(__name__)
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
app.run(debug=True, port=8000)
|
||||
@ -5,7 +5,7 @@ from functools import wraps
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
from qdrant_services.settings import settings
|
||||
from qdrant_services.settings import QDRANT_HOST, QDRANT_PORT, QDRANT_HTTPS, QDRANT_SSL_VERIFY,QDRANT_API_KEY, QDRANT_CLUSTER, QDRANT_COLLECTION_NAME
|
||||
from services.service import Singleton
|
||||
|
||||
from qdrant_services.policy import QdrantClientSelectionPolicy
|
||||
@ -25,16 +25,16 @@ class Qdrant(metaclass=Singleton):
|
||||
):
|
||||
self.collection_name = collection_name or self.__class__.get_default_collection_name()
|
||||
|
||||
if settings.QDRANT_HTTPS is not None and kwargs.get('https') is None:
|
||||
kwargs['https'] = settings.QDRANT_HTTPS
|
||||
if settings.QDRANT_SSL_VERIFY is not None and kwargs.get('verify') is None:
|
||||
kwargs['verify'] = settings.QDRANT_SSL_VERIFY
|
||||
if settings.QDRANT_API_KEY is not None and kwargs.get('api_key') is None:
|
||||
kwargs['api_key'] = settings.QDRANT_API_KEY
|
||||
if QDRANT_HTTPS is not None and kwargs.get('https') is None:
|
||||
kwargs['https'] = QDRANT_HTTPS
|
||||
if QDRANT_SSL_VERIFY is not None and kwargs.get('verify') is None:
|
||||
kwargs['verify'] = QDRANT_SSL_VERIFY
|
||||
if QDRANT_API_KEY is not None and kwargs.get('api_key') is None:
|
||||
kwargs['api_key'] = QDRANT_API_KEY
|
||||
|
||||
host = host or settings.QDRANT_HOST
|
||||
port = port or settings.QDRANT_PORT
|
||||
cluster = cluster or settings.QDRANT_CLUSTER
|
||||
host = host or QDRANT_HOST
|
||||
port = port or QDRANT_PORT
|
||||
cluster = cluster or QDRANT_CLUSTER
|
||||
if cluster:
|
||||
self.logger.info(
|
||||
"Connecting to qdrant cluster: cluster: %s, collection_name: %s" % (cluster, self.collection_name))
|
||||
@ -82,13 +82,16 @@ class Qdrant(metaclass=Singleton):
|
||||
|
||||
@classmethod
|
||||
def get_default_collection_name(cls) -> str:
|
||||
return settings.QDRANT_COLLECTION_NAME
|
||||
return QDRANT_COLLECTION_NAME
|
||||
|
||||
class QdrantClientProxy(QdrantClient):
|
||||
|
||||
def __init__(self, outer_instance, *args, **kwargs):
|
||||
self.outer_instance = outer_instance
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(url="https://192.168.99.122:6333",
|
||||
api_key="NS00TXlKUUIweUhWaGFuUUpUVTk6bWNVWXI1VXRSN2VWcFRtaEZ6NmdCUQ==",
|
||||
https=True,
|
||||
verify=False)
|
||||
|
||||
_excluded_methods = ['_dynamic_call_decorator']
|
||||
|
||||
|
||||
@ -1,2 +1,8 @@
|
||||
QDRANT_HOST = "192.168.99.122"
|
||||
QDRANT_PORT = 6333
|
||||
QDRANT_PORT = 6333
|
||||
QDRANT_HTTPS=''
|
||||
QDRANT_SSL_VERIFY=''
|
||||
QDRANT_API_KEY=''
|
||||
QDRANT_CLUSTER=''
|
||||
QDRANT_COLLECTION_NAME='titan-2502110203'
|
||||
|
||||
|
||||
@ -80,11 +80,7 @@ class Embedding(Component):
|
||||
clause_segments[-1] = clause_segments[-1] + ' ' + temp_str
|
||||
else:
|
||||
clause_segments.append(temp_str)
|
||||
# print("in embedding_service line 82, clause_segments: ", clause_segments)
|
||||
# print("clause_segments len: ", len(clause_segments))
|
||||
# print(len(clause_segments[0]))
|
||||
# print(len(clause_segments[1]))
|
||||
# print(len(clause_segments[2]))
|
||||
|
||||
clause_vectors = self.embedding_sentences(clause_segments)
|
||||
for i in range(len(clause_segments)):
|
||||
if len(clause_segments[i]) < segment_threshold:
|
||||
|
||||
@ -0,0 +1,66 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from qdrant_client.http import models
|
||||
|
||||
from services.service import Component
|
||||
from services.embedding_service import Embedding
|
||||
from qdrant_services.qdrant import Qdrant
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Recommend(Component):
|
||||
def process(self, *args, **kwargs):
|
||||
return self.do_query_multiplier_new(*args, **kwargs)
|
||||
|
||||
def do_query_multiplier_new(self, query_clause, filter = None):
|
||||
top = 10
|
||||
|
||||
input_query_clause = query_clause.replace('"', '')
|
||||
input_query_clause = clean_recommended_clause(input_query_clause)
|
||||
# can perform any times of queries
|
||||
###########################################
|
||||
# note: even set top_k = 10 but system may return < 10, because some clauses have more than 1 vectors
|
||||
# top_k = 10 : imply return 10 vectors. if a clause has 2 vectors, 1 clause returned only
|
||||
|
||||
return recommend_clauses(input_query_clause, top_k=top * 7, filter=filter);
|
||||
|
||||
|
||||
def clean_recommended_clause(clause):
|
||||
m_clause = clause.replace('\t', ' ')
|
||||
m_clause = m_clause.strip()
|
||||
m_clause_list = m_clause.split(' ')
|
||||
result = ""
|
||||
for clause_part in m_clause_list:
|
||||
if len(clause_part) == 0:
|
||||
continue
|
||||
result += clause_part + " "
|
||||
result = result.strip()
|
||||
return result
|
||||
|
||||
|
||||
def recommend_clauses(query_str, top_k=10, filter=None):
|
||||
# query_vector = embedding_a_clause(query_str, is_query=True)
|
||||
query_vector = Embedding.call(query_str, is_query=True)
|
||||
query_vector = np.array(query_vector).astype(np.float32)
|
||||
|
||||
search_queries = [
|
||||
models.SearchRequest(
|
||||
vector=vector.tolist(),
|
||||
filter = filter,
|
||||
limit=top_k,
|
||||
with_payload=True,
|
||||
)
|
||||
for vector in query_vector
|
||||
]
|
||||
logger.info(f"search_queries: {search_queries}")
|
||||
results = Qdrant.get_client().search_batch(requests=search_queries)
|
||||
results = [i for j in results for i in j] # flatten
|
||||
results = sorted(results, key=lambda x: x.score) # sort by score, ascending
|
||||
# as this is a batch search, same record may appear in different search results
|
||||
# remove duplicates by id (keep the highest score for each record)
|
||||
results = list({v.id: v for v in results}.values())
|
||||
results.reverse() # sort by score, descending
|
||||
results = results[:top_k]
|
||||
return results
|
||||
Loading…
Reference in New Issue
Block a user