qdrant service

This commit is contained in:
charlene tau express 2025-03-13 18:03:04 +08:00
parent 61749d94ed
commit 1828ae5693
7 changed files with 263 additions and 9 deletions

View File

@ -1,19 +1,34 @@
from flask import Flask, jsonify
from flask import Flask, jsonify, request
from dotenv import load_dotenv
from services.embedding_service import Embedding
load_dotenv()
app = Flask(__name__)
@app.route("/")
def hello():
return jsonify({"message":"Hello, World!"})
# @app.route("/")
# def hello():
# return jsonify({"message":"Hello, World!"})
@app.get("/")
def hello():
return jsonify({"message":"Hello, World!"})
@app.post("/get_embedding")
def get_embedding():
# Get the query text from request JSON
data = request.get_json()
query = data.get("query", "")
if not query:
return jsonify({"error": "Query text is required"}), 400
# Call the embedding function
vector = Embedding.call(query, is_query=True)
return jsonify({"query": query, "number of embedding": len(vector), "one embedding":len(vector[0])})
print(__name__)
if __name__=="__main__":

View File

@ -0,0 +1,5 @@
from enum import Enum
class QdrantClientSelectionPolicy(Enum):
RANDOM = 1,
ROUND_ROBIN = 2,

View File

@ -0,0 +1,126 @@
import inspect
import logging
import random
from functools import wraps
from qdrant_client import QdrantClient
from qdrant_services.settings import settings
from services.service import Singleton
from qdrant_services.policy import QdrantClientSelectionPolicy
class Qdrant(metaclass=Singleton):
logger = logging.getLogger(__name__)
QDRANT_DEFAULT_PORT = 6333
def __init__(self,
host=None,
port=None,
collection_name=None,
cluster=None,
cluster_selection_policy=QdrantClientSelectionPolicy.ROUND_ROBIN,
**kwargs
):
self.collection_name = collection_name or self.__class__.get_default_collection_name()
if settings.QDRANT_HTTPS is not None and kwargs.get('https') is None:
kwargs['https'] = settings.QDRANT_HTTPS
if settings.QDRANT_SSL_VERIFY is not None and kwargs.get('verify') is None:
kwargs['verify'] = settings.QDRANT_SSL_VERIFY
if settings.QDRANT_API_KEY is not None and kwargs.get('api_key') is None:
kwargs['api_key'] = settings.QDRANT_API_KEY
host = host or settings.QDRANT_HOST
port = port or settings.QDRANT_PORT
cluster = cluster or settings.QDRANT_CLUSTER
if cluster:
self.logger.info(
"Connecting to qdrant cluster: cluster: %s, collection_name: %s" % (cluster, self.collection_name))
hosts = cluster.split(",")
self._clients = []
for host_port in hosts:
host, port = host_port.split(":")
port = port or self.QDRANT_DEFAULT_PORT
port = int(port)
self._clients.append(self.QdrantClientProxy(self, host=host, port=port, **kwargs))
if len(self._clients) == 0:
raise ValueError("No qdrant hosts provided")
if len(self._clients) == 1:
self.logger.warning("Only one host provided for cluster, using it as a single client")
self._client = self._clients[0]
self._clients = None
self.cluster_selection_policy = cluster_selection_policy
self._round_robin_index = -1
else:
self.logger.info(
"Connecting to qdrant: host: %s, port: %s, collection_name: %s" % (host, port, self.collection_name))
self._client = self.QdrantClientProxy(self, host=host, port=port, **kwargs)
def _get_client(self) -> QdrantClient:
if hasattr(self, "_client"):
return self._client
clients = self._clients
if self.cluster_selection_policy == QdrantClientSelectionPolicy.RANDOM:
return random.choice(clients)
else:
self._round_robin_index = (self._round_robin_index + 1) % len(clients)
return clients[self._round_robin_index]
@classmethod
def get_client(cls, host=None, port=None, collection_name=None, cluster=None) -> QdrantClient:
proxy = cls(host, port, collection_name, cluster)
return proxy._get_client()
@classmethod
def get_all_clients(cls, cluster=None, collection_name=None) -> [QdrantClient]:
proxy = cls(collection_name=collection_name, cluster=cluster)
if hasattr(proxy, "_client"):
return [proxy._client]
return proxy._clients
@classmethod
def get_default_collection_name(cls) -> str:
return settings.QDRANT_COLLECTION_NAME
class QdrantClientProxy(QdrantClient):
def __init__(self, outer_instance, *args, **kwargs):
self.outer_instance = outer_instance
super().__init__(*args, **kwargs)
_excluded_methods = ['_dynamic_call_decorator']
def __getattribute__(self, item):
attr = super().__getattribute__(item)
if callable(attr) and item not in self._excluded_methods:
return self._dynamic_call_decorator(attr)
return attr
def _dynamic_call_decorator(self, func):
METHODS_MODIFIED = []
METHODS_EXCLUDED = []
def check_need_collection_name():
sig = inspect.signature(func)
params = sig.parameters
kwargs = [param.name for param in params.values() if param.default == inspect.Parameter.empty]
return "collection_name" in kwargs
@wraps(func)
def wrapper(*args, **kwargs):
while True:
if "collection_name" in kwargs or \
func.__name__ in METHODS_EXCLUDED:
break
if check_need_collection_name():
METHODS_MODIFIED.append(func.__name__)
else:
METHODS_EXCLUDED.append(func.__name__)
break
kwargs["collection_name"] = self.outer_instance.collection_name
break
return func(*args, **kwargs)
return wrapper

View File

@ -0,0 +1,2 @@
QDRANT_HOST = "192.168.99.122"
QDRANT_PORT = 6333

View File

@ -7,7 +7,7 @@ import os
from transformers import AutoTokenizer, AutoModel
# from app.settings import MODEL_FOLDER, TOKENIZER_FOLDER
from core.service import Component
from services.service import Component
TOKENIZER_FOLDER = os.getenv("TOKENIZER_FOLDER")
MODEL_FOLDER = os.getenv("MODEL_FOLDER")
class Embedding(Component):
@ -23,10 +23,13 @@ class Embedding(Component):
# model = AutoModel.from_pretrained('sentence-transformers/msmarco-distilbert-base-v2')
# Load model from HuggingFace Hub
self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_FOLDER)
self.model = AutoModel.from_pretrained(MODEL_FOLDER)
def process(self, input_clause, is_query=False):
return self.embedding_a_clause(input_clause, is_query)
def embedding_a_clause(self, input_clause, is_query=False):
@ -52,6 +55,7 @@ class Embedding(Component):
short_strs = [str]
"""
str_list = str.split('\n')
short_strs = []
for str_item in str_list:
str_item = re.sub(r'\([a-z]\)\.?', '', str_item)
@ -60,6 +64,7 @@ class Embedding(Component):
short_strs.extend(nltk.sent_tokenize(str_item))
else:
short_strs.append(str_item)
clause_segments = []
temp_str = ""
for short_str in short_strs:
@ -75,7 +80,11 @@ class Embedding(Component):
clause_segments[-1] = clause_segments[-1] + ' ' + temp_str
else:
clause_segments.append(temp_str)
# print("in embedding_service line 82, clause_segments: ", clause_segments)
# print("clause_segments len: ", len(clause_segments))
# print(len(clause_segments[0]))
# print(len(clause_segments[1]))
# print(len(clause_segments[2]))
clause_vectors = self.embedding_sentences(clause_segments)
for i in range(len(clause_segments)):
if len(clause_segments[i]) < segment_threshold:
@ -83,6 +92,7 @@ class Embedding(Component):
###
clause_vectors[i] /= np.linalg.norm(clause_vectors[i])
###
return clause_vectors
def embedding_sentences(self, input_sentences):
@ -95,11 +105,11 @@ class Embedding(Component):
"""
# Tokenize sentences
encoded_input = self.tokenizer(input_sentences, padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
with torch.no_grad():
model_output = self.model(**encoded_input)
# Perform pooling. In this case, max pooling.
sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])

View File

96
src/services/service.py Normal file
View File

@ -0,0 +1,96 @@
import logging
from abc import ABCMeta, abstractmethod
logger = logging.getLogger(__name__)
# container contains components, components contains class
class Container:
def __init__(self):
self.instances = {}
def get_component(self, name):
return self.instances[name]
def has_component(self, name):
return name in self.instances
def set_component(self, name, component):
self.instances[name] = component
container = Container()
class Component(object):
__metaclass__ = ABCMeta
def __init__(self):
pass
@classmethod
def get_class(cls):
return cls
@classmethod
def get_name(cls):
return "%s.%s" % (cls.__module__, cls.__name__)
@classmethod
def _get_instance_from_container(cls):
# container = get_container_instance()
name = cls.get_name()
if not container.has_component(name):
instance = cls()
container.set_component(name, instance)
instance.load()
return container.get_component(name)
@classmethod
def prepare(cls):
"""
Prepare this component.
It calls `load` if this hasn't been done.
"""
return cls._get_instance_from_container()
@classmethod
def call(cls, *args, **kwargs):
"""
Call `process` and return the results.
It calls `load` if this hasn't been done before calling `process`
"""
return cls._get_instance_from_container().process(*args, **kwargs)
def load(self):
"""
Called when instance initialized, usually to load something can be shared between different calls.
It's recommend for each module/app itself to load if it takes very long time
It's guaranteed that the function MUST be called before any call to process.
"""
pass
@abstractmethod
def process(self, *args, **kwargs):
pass
def shutdown(self):
"""
called when application shutdown
"""
pass
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]