From 1828ae569328baeb9ababa11e63b26bbfdba3cb0 Mon Sep 17 00:00:00 2001 From: charlene tau express Date: Thu, 13 Mar 2025 18:03:04 +0800 Subject: [PATCH] qdrant service --- src/app.py | 25 ++++-- src/qdrant_services/policy.py | 5 ++ src/qdrant_services/qdrant.py | 126 ++++++++++++++++++++++++++++++ src/qdrant_services/settings.py | 2 + src/services/embedding_service.py | 18 ++++- src/services/recommend_service.py | 0 src/services/service.py | 96 +++++++++++++++++++++++ 7 files changed, 263 insertions(+), 9 deletions(-) create mode 100644 src/qdrant_services/policy.py create mode 100644 src/qdrant_services/qdrant.py create mode 100644 src/qdrant_services/settings.py create mode 100644 src/services/recommend_service.py create mode 100644 src/services/service.py diff --git a/src/app.py b/src/app.py index 7853068..cd1ad44 100644 --- a/src/app.py +++ b/src/app.py @@ -1,19 +1,34 @@ -from flask import Flask, jsonify +from flask import Flask, jsonify, request from dotenv import load_dotenv - +from services.embedding_service import Embedding load_dotenv() app = Flask(__name__) -@app.route("/") -def hello(): - return jsonify({"message":"Hello, World!"}) +# @app.route("/") +# def hello(): +# return jsonify({"message":"Hello, World!"}) @app.get("/") def hello(): return jsonify({"message":"Hello, World!"}) + +@app.post("/get_embedding") +def get_embedding(): + # Get the query text from request JSON + data = request.get_json() + query = data.get("query", "") + + if not query: + return jsonify({"error": "Query text is required"}), 400 + + # Call the embedding function + vector = Embedding.call(query, is_query=True) + + return jsonify({"query": query, "number of embedding": len(vector), "one embedding":len(vector[0])}) + print(__name__) if __name__=="__main__": diff --git a/src/qdrant_services/policy.py b/src/qdrant_services/policy.py new file mode 100644 index 0000000..2376ad1 --- /dev/null +++ b/src/qdrant_services/policy.py @@ -0,0 +1,5 @@ +from enum import Enum + +class QdrantClientSelectionPolicy(Enum): + RANDOM = 1, + ROUND_ROBIN = 2, \ No newline at end of file diff --git a/src/qdrant_services/qdrant.py b/src/qdrant_services/qdrant.py new file mode 100644 index 0000000..d6c2dd6 --- /dev/null +++ b/src/qdrant_services/qdrant.py @@ -0,0 +1,126 @@ +import inspect +import logging +import random +from functools import wraps + +from qdrant_client import QdrantClient + +from qdrant_services.settings import settings +from services.service import Singleton + +from qdrant_services.policy import QdrantClientSelectionPolicy + + +class Qdrant(metaclass=Singleton): + logger = logging.getLogger(__name__) + QDRANT_DEFAULT_PORT = 6333 + + def __init__(self, + host=None, + port=None, + collection_name=None, + cluster=None, + cluster_selection_policy=QdrantClientSelectionPolicy.ROUND_ROBIN, + **kwargs + ): + self.collection_name = collection_name or self.__class__.get_default_collection_name() + + if settings.QDRANT_HTTPS is not None and kwargs.get('https') is None: + kwargs['https'] = settings.QDRANT_HTTPS + if settings.QDRANT_SSL_VERIFY is not None and kwargs.get('verify') is None: + kwargs['verify'] = settings.QDRANT_SSL_VERIFY + if settings.QDRANT_API_KEY is not None and kwargs.get('api_key') is None: + kwargs['api_key'] = settings.QDRANT_API_KEY + + host = host or settings.QDRANT_HOST + port = port or settings.QDRANT_PORT + cluster = cluster or settings.QDRANT_CLUSTER + if cluster: + self.logger.info( + "Connecting to qdrant cluster: cluster: %s, collection_name: %s" % (cluster, self.collection_name)) + hosts = cluster.split(",") + self._clients = [] + for host_port in hosts: + host, port = host_port.split(":") + port = port or self.QDRANT_DEFAULT_PORT + port = int(port) + self._clients.append(self.QdrantClientProxy(self, host=host, port=port, **kwargs)) + if len(self._clients) == 0: + raise ValueError("No qdrant hosts provided") + if len(self._clients) == 1: + self.logger.warning("Only one host provided for cluster, using it as a single client") + self._client = self._clients[0] + self._clients = None + self.cluster_selection_policy = cluster_selection_policy + self._round_robin_index = -1 + else: + self.logger.info( + "Connecting to qdrant: host: %s, port: %s, collection_name: %s" % (host, port, self.collection_name)) + self._client = self.QdrantClientProxy(self, host=host, port=port, **kwargs) + + def _get_client(self) -> QdrantClient: + if hasattr(self, "_client"): + return self._client + clients = self._clients + if self.cluster_selection_policy == QdrantClientSelectionPolicy.RANDOM: + return random.choice(clients) + else: + self._round_robin_index = (self._round_robin_index + 1) % len(clients) + return clients[self._round_robin_index] + + @classmethod + def get_client(cls, host=None, port=None, collection_name=None, cluster=None) -> QdrantClient: + proxy = cls(host, port, collection_name, cluster) + return proxy._get_client() + + @classmethod + def get_all_clients(cls, cluster=None, collection_name=None) -> [QdrantClient]: + proxy = cls(collection_name=collection_name, cluster=cluster) + if hasattr(proxy, "_client"): + return [proxy._client] + return proxy._clients + + @classmethod + def get_default_collection_name(cls) -> str: + return settings.QDRANT_COLLECTION_NAME + + class QdrantClientProxy(QdrantClient): + + def __init__(self, outer_instance, *args, **kwargs): + self.outer_instance = outer_instance + super().__init__(*args, **kwargs) + + _excluded_methods = ['_dynamic_call_decorator'] + + def __getattribute__(self, item): + attr = super().__getattribute__(item) + if callable(attr) and item not in self._excluded_methods: + return self._dynamic_call_decorator(attr) + return attr + + def _dynamic_call_decorator(self, func): + METHODS_MODIFIED = [] + METHODS_EXCLUDED = [] + + def check_need_collection_name(): + sig = inspect.signature(func) + params = sig.parameters + kwargs = [param.name for param in params.values() if param.default == inspect.Parameter.empty] + return "collection_name" in kwargs + + @wraps(func) + def wrapper(*args, **kwargs): + while True: + if "collection_name" in kwargs or \ + func.__name__ in METHODS_EXCLUDED: + break + if check_need_collection_name(): + METHODS_MODIFIED.append(func.__name__) + else: + METHODS_EXCLUDED.append(func.__name__) + break + kwargs["collection_name"] = self.outer_instance.collection_name + break + return func(*args, **kwargs) + + return wrapper diff --git a/src/qdrant_services/settings.py b/src/qdrant_services/settings.py new file mode 100644 index 0000000..d033e22 --- /dev/null +++ b/src/qdrant_services/settings.py @@ -0,0 +1,2 @@ +QDRANT_HOST = "192.168.99.122" +QDRANT_PORT = 6333 \ No newline at end of file diff --git a/src/services/embedding_service.py b/src/services/embedding_service.py index 6af4210..acd348a 100644 --- a/src/services/embedding_service.py +++ b/src/services/embedding_service.py @@ -7,7 +7,7 @@ import os from transformers import AutoTokenizer, AutoModel # from app.settings import MODEL_FOLDER, TOKENIZER_FOLDER -from core.service import Component +from services.service import Component TOKENIZER_FOLDER = os.getenv("TOKENIZER_FOLDER") MODEL_FOLDER = os.getenv("MODEL_FOLDER") class Embedding(Component): @@ -23,10 +23,13 @@ class Embedding(Component): # model = AutoModel.from_pretrained('sentence-transformers/msmarco-distilbert-base-v2') # Load model from HuggingFace Hub + self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_FOLDER) self.model = AutoModel.from_pretrained(MODEL_FOLDER) + def process(self, input_clause, is_query=False): + return self.embedding_a_clause(input_clause, is_query) def embedding_a_clause(self, input_clause, is_query=False): @@ -52,6 +55,7 @@ class Embedding(Component): short_strs = [str] """ str_list = str.split('\n') + short_strs = [] for str_item in str_list: str_item = re.sub(r'\([a-z]\)\.?', '', str_item) @@ -60,6 +64,7 @@ class Embedding(Component): short_strs.extend(nltk.sent_tokenize(str_item)) else: short_strs.append(str_item) + clause_segments = [] temp_str = "" for short_str in short_strs: @@ -75,7 +80,11 @@ class Embedding(Component): clause_segments[-1] = clause_segments[-1] + ' ' + temp_str else: clause_segments.append(temp_str) - + # print("in embedding_service line 82, clause_segments: ", clause_segments) + # print("clause_segments len: ", len(clause_segments)) + # print(len(clause_segments[0])) + # print(len(clause_segments[1])) + # print(len(clause_segments[2])) clause_vectors = self.embedding_sentences(clause_segments) for i in range(len(clause_segments)): if len(clause_segments[i]) < segment_threshold: @@ -83,6 +92,7 @@ class Embedding(Component): ### clause_vectors[i] /= np.linalg.norm(clause_vectors[i]) ### + return clause_vectors def embedding_sentences(self, input_sentences): @@ -95,11 +105,11 @@ class Embedding(Component): """ # Tokenize sentences encoded_input = self.tokenizer(input_sentences, padding=True, truncation=True, return_tensors='pt') - + # Compute token embeddings with torch.no_grad(): model_output = self.model(**encoded_input) - + # Perform pooling. In this case, max pooling. sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask']) diff --git a/src/services/recommend_service.py b/src/services/recommend_service.py new file mode 100644 index 0000000..e69de29 diff --git a/src/services/service.py b/src/services/service.py new file mode 100644 index 0000000..113dc3d --- /dev/null +++ b/src/services/service.py @@ -0,0 +1,96 @@ +import logging +from abc import ABCMeta, abstractmethod + +logger = logging.getLogger(__name__) + + +# container contains components, components contains class +class Container: + def __init__(self): + self.instances = {} + + def get_component(self, name): + return self.instances[name] + + def has_component(self, name): + return name in self.instances + + def set_component(self, name, component): + self.instances[name] = component + + +container = Container() + + +class Component(object): + __metaclass__ = ABCMeta + + def __init__(self): + pass + + @classmethod + def get_class(cls): + return cls + + @classmethod + def get_name(cls): + return "%s.%s" % (cls.__module__, cls.__name__) + + @classmethod + def _get_instance_from_container(cls): + # container = get_container_instance() + + name = cls.get_name() + + if not container.has_component(name): + + instance = cls() + container.set_component(name, instance) + instance.load() + + + return container.get_component(name) + + @classmethod + def prepare(cls): + """ + Prepare this component. + It calls `load` if this hasn't been done. + """ + return cls._get_instance_from_container() + + @classmethod + def call(cls, *args, **kwargs): + + """ + Call `process` and return the results. + It calls `load` if this hasn't been done before calling `process` + """ + return cls._get_instance_from_container().process(*args, **kwargs) + + def load(self): + """ + Called when instance initialized, usually to load something can be shared between different calls. + It's recommend for each module/app itself to load if it takes very long time + It's guaranteed that the function MUST be called before any call to process. + """ + pass + + @abstractmethod + def process(self, *args, **kwargs): + pass + + def shutdown(self): + """ + called when application shutdown + """ + pass + + +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls]