qdrant service
This commit is contained in:
parent
61749d94ed
commit
1828ae5693
25
src/app.py
25
src/app.py
@ -1,19 +1,34 @@
|
|||||||
from flask import Flask, jsonify
|
from flask import Flask, jsonify, request
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from services.embedding_service import Embedding
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
@app.route("/")
|
# @app.route("/")
|
||||||
def hello():
|
# def hello():
|
||||||
return jsonify({"message":"Hello, World!"})
|
# return jsonify({"message":"Hello, World!"})
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
def hello():
|
def hello():
|
||||||
return jsonify({"message":"Hello, World!"})
|
return jsonify({"message":"Hello, World!"})
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/get_embedding")
|
||||||
|
def get_embedding():
|
||||||
|
# Get the query text from request JSON
|
||||||
|
data = request.get_json()
|
||||||
|
query = data.get("query", "")
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
return jsonify({"error": "Query text is required"}), 400
|
||||||
|
|
||||||
|
# Call the embedding function
|
||||||
|
vector = Embedding.call(query, is_query=True)
|
||||||
|
|
||||||
|
return jsonify({"query": query, "number of embedding": len(vector), "one embedding":len(vector[0])})
|
||||||
|
|
||||||
print(__name__)
|
print(__name__)
|
||||||
|
|
||||||
if __name__=="__main__":
|
if __name__=="__main__":
|
||||||
|
|||||||
5
src/qdrant_services/policy.py
Normal file
5
src/qdrant_services/policy.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class QdrantClientSelectionPolicy(Enum):
|
||||||
|
RANDOM = 1,
|
||||||
|
ROUND_ROBIN = 2,
|
||||||
126
src/qdrant_services/qdrant.py
Normal file
126
src/qdrant_services/qdrant.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
|
||||||
|
from qdrant_services.settings import settings
|
||||||
|
from services.service import Singleton
|
||||||
|
|
||||||
|
from qdrant_services.policy import QdrantClientSelectionPolicy
|
||||||
|
|
||||||
|
|
||||||
|
class Qdrant(metaclass=Singleton):
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
QDRANT_DEFAULT_PORT = 6333
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
host=None,
|
||||||
|
port=None,
|
||||||
|
collection_name=None,
|
||||||
|
cluster=None,
|
||||||
|
cluster_selection_policy=QdrantClientSelectionPolicy.ROUND_ROBIN,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
self.collection_name = collection_name or self.__class__.get_default_collection_name()
|
||||||
|
|
||||||
|
if settings.QDRANT_HTTPS is not None and kwargs.get('https') is None:
|
||||||
|
kwargs['https'] = settings.QDRANT_HTTPS
|
||||||
|
if settings.QDRANT_SSL_VERIFY is not None and kwargs.get('verify') is None:
|
||||||
|
kwargs['verify'] = settings.QDRANT_SSL_VERIFY
|
||||||
|
if settings.QDRANT_API_KEY is not None and kwargs.get('api_key') is None:
|
||||||
|
kwargs['api_key'] = settings.QDRANT_API_KEY
|
||||||
|
|
||||||
|
host = host or settings.QDRANT_HOST
|
||||||
|
port = port or settings.QDRANT_PORT
|
||||||
|
cluster = cluster or settings.QDRANT_CLUSTER
|
||||||
|
if cluster:
|
||||||
|
self.logger.info(
|
||||||
|
"Connecting to qdrant cluster: cluster: %s, collection_name: %s" % (cluster, self.collection_name))
|
||||||
|
hosts = cluster.split(",")
|
||||||
|
self._clients = []
|
||||||
|
for host_port in hosts:
|
||||||
|
host, port = host_port.split(":")
|
||||||
|
port = port or self.QDRANT_DEFAULT_PORT
|
||||||
|
port = int(port)
|
||||||
|
self._clients.append(self.QdrantClientProxy(self, host=host, port=port, **kwargs))
|
||||||
|
if len(self._clients) == 0:
|
||||||
|
raise ValueError("No qdrant hosts provided")
|
||||||
|
if len(self._clients) == 1:
|
||||||
|
self.logger.warning("Only one host provided for cluster, using it as a single client")
|
||||||
|
self._client = self._clients[0]
|
||||||
|
self._clients = None
|
||||||
|
self.cluster_selection_policy = cluster_selection_policy
|
||||||
|
self._round_robin_index = -1
|
||||||
|
else:
|
||||||
|
self.logger.info(
|
||||||
|
"Connecting to qdrant: host: %s, port: %s, collection_name: %s" % (host, port, self.collection_name))
|
||||||
|
self._client = self.QdrantClientProxy(self, host=host, port=port, **kwargs)
|
||||||
|
|
||||||
|
def _get_client(self) -> QdrantClient:
|
||||||
|
if hasattr(self, "_client"):
|
||||||
|
return self._client
|
||||||
|
clients = self._clients
|
||||||
|
if self.cluster_selection_policy == QdrantClientSelectionPolicy.RANDOM:
|
||||||
|
return random.choice(clients)
|
||||||
|
else:
|
||||||
|
self._round_robin_index = (self._round_robin_index + 1) % len(clients)
|
||||||
|
return clients[self._round_robin_index]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_client(cls, host=None, port=None, collection_name=None, cluster=None) -> QdrantClient:
|
||||||
|
proxy = cls(host, port, collection_name, cluster)
|
||||||
|
return proxy._get_client()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_all_clients(cls, cluster=None, collection_name=None) -> [QdrantClient]:
|
||||||
|
proxy = cls(collection_name=collection_name, cluster=cluster)
|
||||||
|
if hasattr(proxy, "_client"):
|
||||||
|
return [proxy._client]
|
||||||
|
return proxy._clients
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_default_collection_name(cls) -> str:
|
||||||
|
return settings.QDRANT_COLLECTION_NAME
|
||||||
|
|
||||||
|
class QdrantClientProxy(QdrantClient):
|
||||||
|
|
||||||
|
def __init__(self, outer_instance, *args, **kwargs):
|
||||||
|
self.outer_instance = outer_instance
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
_excluded_methods = ['_dynamic_call_decorator']
|
||||||
|
|
||||||
|
def __getattribute__(self, item):
|
||||||
|
attr = super().__getattribute__(item)
|
||||||
|
if callable(attr) and item not in self._excluded_methods:
|
||||||
|
return self._dynamic_call_decorator(attr)
|
||||||
|
return attr
|
||||||
|
|
||||||
|
def _dynamic_call_decorator(self, func):
|
||||||
|
METHODS_MODIFIED = []
|
||||||
|
METHODS_EXCLUDED = []
|
||||||
|
|
||||||
|
def check_need_collection_name():
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = sig.parameters
|
||||||
|
kwargs = [param.name for param in params.values() if param.default == inspect.Parameter.empty]
|
||||||
|
return "collection_name" in kwargs
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
while True:
|
||||||
|
if "collection_name" in kwargs or \
|
||||||
|
func.__name__ in METHODS_EXCLUDED:
|
||||||
|
break
|
||||||
|
if check_need_collection_name():
|
||||||
|
METHODS_MODIFIED.append(func.__name__)
|
||||||
|
else:
|
||||||
|
METHODS_EXCLUDED.append(func.__name__)
|
||||||
|
break
|
||||||
|
kwargs["collection_name"] = self.outer_instance.collection_name
|
||||||
|
break
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
2
src/qdrant_services/settings.py
Normal file
2
src/qdrant_services/settings.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
QDRANT_HOST = "192.168.99.122"
|
||||||
|
QDRANT_PORT = 6333
|
||||||
@ -7,7 +7,7 @@ import os
|
|||||||
from transformers import AutoTokenizer, AutoModel
|
from transformers import AutoTokenizer, AutoModel
|
||||||
|
|
||||||
# from app.settings import MODEL_FOLDER, TOKENIZER_FOLDER
|
# from app.settings import MODEL_FOLDER, TOKENIZER_FOLDER
|
||||||
from core.service import Component
|
from services.service import Component
|
||||||
TOKENIZER_FOLDER = os.getenv("TOKENIZER_FOLDER")
|
TOKENIZER_FOLDER = os.getenv("TOKENIZER_FOLDER")
|
||||||
MODEL_FOLDER = os.getenv("MODEL_FOLDER")
|
MODEL_FOLDER = os.getenv("MODEL_FOLDER")
|
||||||
class Embedding(Component):
|
class Embedding(Component):
|
||||||
@ -23,10 +23,13 @@ class Embedding(Component):
|
|||||||
# model = AutoModel.from_pretrained('sentence-transformers/msmarco-distilbert-base-v2')
|
# model = AutoModel.from_pretrained('sentence-transformers/msmarco-distilbert-base-v2')
|
||||||
|
|
||||||
# Load model from HuggingFace Hub
|
# Load model from HuggingFace Hub
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_FOLDER)
|
self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_FOLDER)
|
||||||
self.model = AutoModel.from_pretrained(MODEL_FOLDER)
|
self.model = AutoModel.from_pretrained(MODEL_FOLDER)
|
||||||
|
|
||||||
|
|
||||||
def process(self, input_clause, is_query=False):
|
def process(self, input_clause, is_query=False):
|
||||||
|
|
||||||
return self.embedding_a_clause(input_clause, is_query)
|
return self.embedding_a_clause(input_clause, is_query)
|
||||||
|
|
||||||
def embedding_a_clause(self, input_clause, is_query=False):
|
def embedding_a_clause(self, input_clause, is_query=False):
|
||||||
@ -52,6 +55,7 @@ class Embedding(Component):
|
|||||||
short_strs = [str]
|
short_strs = [str]
|
||||||
"""
|
"""
|
||||||
str_list = str.split('\n')
|
str_list = str.split('\n')
|
||||||
|
|
||||||
short_strs = []
|
short_strs = []
|
||||||
for str_item in str_list:
|
for str_item in str_list:
|
||||||
str_item = re.sub(r'\([a-z]\)\.?', '', str_item)
|
str_item = re.sub(r'\([a-z]\)\.?', '', str_item)
|
||||||
@ -60,6 +64,7 @@ class Embedding(Component):
|
|||||||
short_strs.extend(nltk.sent_tokenize(str_item))
|
short_strs.extend(nltk.sent_tokenize(str_item))
|
||||||
else:
|
else:
|
||||||
short_strs.append(str_item)
|
short_strs.append(str_item)
|
||||||
|
|
||||||
clause_segments = []
|
clause_segments = []
|
||||||
temp_str = ""
|
temp_str = ""
|
||||||
for short_str in short_strs:
|
for short_str in short_strs:
|
||||||
@ -75,7 +80,11 @@ class Embedding(Component):
|
|||||||
clause_segments[-1] = clause_segments[-1] + ' ' + temp_str
|
clause_segments[-1] = clause_segments[-1] + ' ' + temp_str
|
||||||
else:
|
else:
|
||||||
clause_segments.append(temp_str)
|
clause_segments.append(temp_str)
|
||||||
|
# print("in embedding_service line 82, clause_segments: ", clause_segments)
|
||||||
|
# print("clause_segments len: ", len(clause_segments))
|
||||||
|
# print(len(clause_segments[0]))
|
||||||
|
# print(len(clause_segments[1]))
|
||||||
|
# print(len(clause_segments[2]))
|
||||||
clause_vectors = self.embedding_sentences(clause_segments)
|
clause_vectors = self.embedding_sentences(clause_segments)
|
||||||
for i in range(len(clause_segments)):
|
for i in range(len(clause_segments)):
|
||||||
if len(clause_segments[i]) < segment_threshold:
|
if len(clause_segments[i]) < segment_threshold:
|
||||||
@ -83,6 +92,7 @@ class Embedding(Component):
|
|||||||
###
|
###
|
||||||
clause_vectors[i] /= np.linalg.norm(clause_vectors[i])
|
clause_vectors[i] /= np.linalg.norm(clause_vectors[i])
|
||||||
###
|
###
|
||||||
|
|
||||||
return clause_vectors
|
return clause_vectors
|
||||||
|
|
||||||
def embedding_sentences(self, input_sentences):
|
def embedding_sentences(self, input_sentences):
|
||||||
|
|||||||
0
src/services/recommend_service.py
Normal file
0
src/services/recommend_service.py
Normal file
96
src/services/service.py
Normal file
96
src/services/service.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
import logging
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# container contains components, components contains class
|
||||||
|
class Container:
|
||||||
|
def __init__(self):
|
||||||
|
self.instances = {}
|
||||||
|
|
||||||
|
def get_component(self, name):
|
||||||
|
return self.instances[name]
|
||||||
|
|
||||||
|
def has_component(self, name):
|
||||||
|
return name in self.instances
|
||||||
|
|
||||||
|
def set_component(self, name, component):
|
||||||
|
self.instances[name] = component
|
||||||
|
|
||||||
|
|
||||||
|
container = Container()
|
||||||
|
|
||||||
|
|
||||||
|
class Component(object):
|
||||||
|
__metaclass__ = ABCMeta
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_class(cls):
|
||||||
|
return cls
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_name(cls):
|
||||||
|
return "%s.%s" % (cls.__module__, cls.__name__)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_instance_from_container(cls):
|
||||||
|
# container = get_container_instance()
|
||||||
|
|
||||||
|
name = cls.get_name()
|
||||||
|
|
||||||
|
if not container.has_component(name):
|
||||||
|
|
||||||
|
instance = cls()
|
||||||
|
container.set_component(name, instance)
|
||||||
|
instance.load()
|
||||||
|
|
||||||
|
|
||||||
|
return container.get_component(name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def prepare(cls):
|
||||||
|
"""
|
||||||
|
Prepare this component.
|
||||||
|
It calls `load` if this hasn't been done.
|
||||||
|
"""
|
||||||
|
return cls._get_instance_from_container()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def call(cls, *args, **kwargs):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Call `process` and return the results.
|
||||||
|
It calls `load` if this hasn't been done before calling `process`
|
||||||
|
"""
|
||||||
|
return cls._get_instance_from_container().process(*args, **kwargs)
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
"""
|
||||||
|
Called when instance initialized, usually to load something can be shared between different calls.
|
||||||
|
It's recommend for each module/app itself to load if it takes very long time
|
||||||
|
It's guaranteed that the function MUST be called before any call to process.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def process(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
"""
|
||||||
|
called when application shutdown
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Singleton(type):
|
||||||
|
_instances = {}
|
||||||
|
|
||||||
|
def __call__(cls, *args, **kwargs):
|
||||||
|
if cls not in cls._instances:
|
||||||
|
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
|
||||||
|
return cls._instances[cls]
|
||||||
Loading…
Reference in New Issue
Block a user