Upload rml_ai/memory.py with huggingface_hub
Browse files- rml_ai/memory.py +178 -0
rml_ai/memory.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Memory Store for RML System
|
3 |
+
Handles vector storage and semantic search
|
4 |
+
"""
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from typing import List, Dict, Any, Optional, Callable
|
8 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
9 |
+
|
10 |
+
|
11 |
+
class MemoryStore:
    """Vector-based memory store for semantic search.

    Holds a list of entry dicts and a matrix of their embeddings (row i
    corresponds to entry i). `search` performs cosine-similarity retrieval
    and degrades to keyword matching when no query encoder is configured.
    """

    def __init__(self):
        # Entry dicts; RML records may carry 'text', 'summaries',
        # 'concepts', 'tags', 'source', ... keys.
        self.entries: List[Dict[str, Any]] = []
        # 2-D array aligned row-for-row with self.entries; None until
        # add_entries() has been called.
        self.embeddings: Optional[np.ndarray] = None
        # Caller-injected function mapping a query string to an embedding
        # vector; when None, search() falls back to keyword matching.
        self.encode_query_fn: Optional[Callable] = None
|
18 |
+
|
19 |
+
def add_entries(self, entries: List[Dict[str, Any]], embeddings: np.ndarray):
|
20 |
+
"""Add entries with their embeddings"""
|
21 |
+
self.entries = entries
|
22 |
+
self.embeddings = embeddings
|
23 |
+
|
24 |
+
def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
25 |
+
"""Search for relevant entries using semantic similarity"""
|
26 |
+
if not self.entries or self.embeddings is None:
|
27 |
+
return []
|
28 |
+
|
29 |
+
if not self.encode_query_fn:
|
30 |
+
# Fallback to keyword search
|
31 |
+
return self._keyword_search(query, top_k)
|
32 |
+
|
33 |
+
try:
|
34 |
+
# Encode query
|
35 |
+
query_embedding = self.encode_query_fn(query)
|
36 |
+
|
37 |
+
# Handle empty embeddings
|
38 |
+
if self.embeddings is None or len(self.embeddings) == 0:
|
39 |
+
return self._keyword_search(query, top_k)
|
40 |
+
|
41 |
+
# Ensure proper dimensions
|
42 |
+
if len(self.embeddings.shape) == 1:
|
43 |
+
# If embeddings is 1D, reshape to 2D
|
44 |
+
embeddings = self.embeddings.reshape(1, -1)
|
45 |
+
else:
|
46 |
+
embeddings = self.embeddings
|
47 |
+
|
48 |
+
if len(query_embedding.shape) == 1:
|
49 |
+
query_embedding = query_embedding.reshape(1, -1)
|
50 |
+
|
51 |
+
# Check dimension compatibility
|
52 |
+
if query_embedding.shape[1] != embeddings.shape[1]:
|
53 |
+
print(f"Embedding dimension mismatch: query {query_embedding.shape[1]} vs entries {embeddings.shape[1]}")
|
54 |
+
return self._keyword_search(query, top_k)
|
55 |
+
|
56 |
+
# Calculate similarities
|
57 |
+
similarities = cosine_similarity(query_embedding, embeddings)[0]
|
58 |
+
|
59 |
+
# Get top-k results
|
60 |
+
top_indices = np.argsort(similarities)[::-1][:top_k]
|
61 |
+
|
62 |
+
results = []
|
63 |
+
for idx in top_indices:
|
64 |
+
if similarities[idx] > 0.1: # Minimum similarity threshold
|
65 |
+
entry = self.entries[idx].copy()
|
66 |
+
entry['text'] = self._extract_text(entry)
|
67 |
+
entry['similarity'] = float(similarities[idx])
|
68 |
+
entry['source'] = entry.get('source', 'internal dataset')
|
69 |
+
results.append(entry)
|
70 |
+
|
71 |
+
return results
|
72 |
+
|
73 |
+
except Exception as e:
|
74 |
+
print(f"Error during search: {e}")
|
75 |
+
return self._keyword_search(query, top_k)
|
76 |
+
|
77 |
+
def _keyword_search(self, query: str, top_k: int) -> List[Dict[str, Any]]:
|
78 |
+
"""Fallback keyword search with RML-aware scoring"""
|
79 |
+
query_lower = query.lower()
|
80 |
+
query_words = set(query_lower.split())
|
81 |
+
results = []
|
82 |
+
|
83 |
+
for entry in self.entries:
|
84 |
+
score = 0
|
85 |
+
text = self._extract_text(entry).lower()
|
86 |
+
|
87 |
+
# Check direct text matches
|
88 |
+
text_words = set(text.split())
|
89 |
+
common_words = query_words.intersection(text_words)
|
90 |
+
score += len(common_words) * 2 # Base score for word matches
|
91 |
+
|
92 |
+
# Boost score for matches in specific RML fields
|
93 |
+
if 'concepts' in entry and entry['concepts']:
|
94 |
+
concepts_text = " ".join(entry['concepts']).lower() if isinstance(entry['concepts'], list) else str(entry['concepts']).lower()
|
95 |
+
concept_matches = sum(1 for word in query_words if word in concepts_text)
|
96 |
+
score += concept_matches * 3 # Higher weight for concept matches
|
97 |
+
|
98 |
+
if 'tags' in entry and entry['tags']:
|
99 |
+
tags_text = " ".join(entry['tags']).lower() if isinstance(entry['tags'], list) else str(entry['tags']).lower()
|
100 |
+
tag_matches = sum(1 for word in query_words if word in tags_text)
|
101 |
+
score += tag_matches * 2 # Medium weight for tag matches
|
102 |
+
|
103 |
+
if 'summaries' in entry and entry['summaries']:
|
104 |
+
summary_text = entry['summaries'][0].lower() if isinstance(entry['summaries'], list) and entry['summaries'] else str(entry['summaries']).lower()
|
105 |
+
summary_matches = sum(1 for word in query_words if word in summary_text)
|
106 |
+
score += summary_matches * 4 # Highest weight for summary matches
|
107 |
+
|
108 |
+
# Only include results with some relevance
|
109 |
+
if score > 0:
|
110 |
+
entry_copy = entry.copy()
|
111 |
+
entry_copy['text'] = self._extract_text(entry)
|
112 |
+
entry_copy['similarity'] = min(0.9, score / 10) # Normalize score to similarity
|
113 |
+
entry_copy['source'] = entry.get('source', 'internal dataset')
|
114 |
+
results.append(entry_copy)
|
115 |
+
|
116 |
+
# Sort by similarity score and return top-k
|
117 |
+
results.sort(key=lambda x: x['similarity'], reverse=True)
|
118 |
+
return results[:top_k]
|
119 |
+
|
120 |
+
def _extract_text(self, entry: Dict[str, Any]) -> str:
|
121 |
+
"""Extract text content from entry, handling RML-specific structure"""
|
122 |
+
# First try standard fields
|
123 |
+
for field in ['text', 'content', 'body', 'chunk', 'summary', 'title']:
|
124 |
+
if field in entry and entry[field]:
|
125 |
+
return str(entry[field])
|
126 |
+
|
127 |
+
# Handle RML-specific structure
|
128 |
+
text_parts = []
|
129 |
+
|
130 |
+
# Extract from summaries (first priority for RML data)
|
131 |
+
if 'summaries' in entry and entry['summaries']:
|
132 |
+
if isinstance(entry['summaries'], list) and entry['summaries']:
|
133 |
+
text_parts.append(entry['summaries'][0])
|
134 |
+
elif isinstance(entry['summaries'], str):
|
135 |
+
text_parts.append(entry['summaries'])
|
136 |
+
|
137 |
+
# Extract from concepts
|
138 |
+
if 'concepts' in entry and entry['concepts']:
|
139 |
+
if isinstance(entry['concepts'], list):
|
140 |
+
text_parts.append(" ".join(entry['concepts'][:10])) # First 10 concepts
|
141 |
+
elif isinstance(entry['concepts'], str):
|
142 |
+
text_parts.append(entry['concepts'])
|
143 |
+
|
144 |
+
# Extract from tags
|
145 |
+
if 'tags' in entry and entry['tags']:
|
146 |
+
if isinstance(entry['tags'], list):
|
147 |
+
text_parts.append(" ".join(entry['tags'][:10])) # First 10 tags
|
148 |
+
elif isinstance(entry['tags'], str):
|
149 |
+
text_parts.append(entry['tags'])
|
150 |
+
|
151 |
+
# Combine all parts
|
152 |
+
if text_parts:
|
153 |
+
return " ".join(text_parts)
|
154 |
+
|
155 |
+
# Fallback: convert entire entry to string (excluding large arrays)
|
156 |
+
filtered_entry = {}
|
157 |
+
for k, v in entry.items():
|
158 |
+
if k not in ['vectors', 'embeddings'] and v:
|
159 |
+
if isinstance(v, list) and len(v) > 20:
|
160 |
+
filtered_entry[k] = v[:5] # Only first 5 items of large lists
|
161 |
+
else:
|
162 |
+
filtered_entry[k] = v
|
163 |
+
|
164 |
+
return str(filtered_entry) if filtered_entry else "No content available"
|
165 |
+
|
166 |
+
def get_stats(self) -> Dict[str, Any]:
|
167 |
+
"""Get memory store statistics"""
|
168 |
+
embedding_dim = 0
|
169 |
+
if self.embeddings is not None and len(self.embeddings.shape) > 1:
|
170 |
+
embedding_dim = self.embeddings.shape[1]
|
171 |
+
elif self.embeddings is not None and len(self.embeddings.shape) == 1:
|
172 |
+
embedding_dim = len(self.embeddings)
|
173 |
+
|
174 |
+
return {
|
175 |
+
'total_entries': len(self.entries),
|
176 |
+
'embedding_dim': embedding_dim,
|
177 |
+
'has_embeddings': self.embeddings is not None
|
178 |
+
}
|