akshaynayaks9845 committed
Commit 22680c0 · verified · 1 Parent(s): 7413962

Upload rml_ai/memory.py with huggingface_hub

Files changed (1)
  1. rml_ai/memory.py +178 -0
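
The commit message says the file was pushed with the huggingface_hub client. A minimal sketch of such an upload is shown below; the repo_id and repo_type are placeholders assumed for illustration, not values read from this page.

# Sketch only: push a local file to the Hub with huggingface_hub.
# repo_id and repo_type are assumptions, not taken from this commit page.
from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` by default
api.upload_file(
    path_or_fileobj="rml_ai/memory.py",    # local file to upload
    path_in_repo="rml_ai/memory.py",       # destination path inside the repo
    repo_id="akshaynayaks9845/rml-ai",     # placeholder repo id (assumption)
    repo_type="model",                     # assumption; could be "dataset" or "space"
    commit_message="Upload rml_ai/memory.py with huggingface_hub",
)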
rml_ai/memory.py ADDED
@@ -0,0 +1,178 @@
"""
Memory Store for RML System
Handles vector storage and semantic search
"""

import numpy as np
from typing import List, Dict, Any, Optional, Callable
from sklearn.metrics.pairwise import cosine_similarity


class MemoryStore:
    """Vector-based memory store for semantic search"""

    def __init__(self):
        self.entries = []
        self.embeddings = None
        self.encode_query_fn: Optional[Callable] = None

    def add_entries(self, entries: List[Dict[str, Any]], embeddings: np.ndarray):
        """Add entries with their embeddings"""
        self.entries = entries
        self.embeddings = embeddings

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Search for relevant entries using semantic similarity"""
        if not self.entries or self.embeddings is None:
            return []

        if not self.encode_query_fn:
            # Fallback to keyword search
            return self._keyword_search(query, top_k)

        try:
            # Encode query
            query_embedding = self.encode_query_fn(query)

            # Handle empty embeddings
            if self.embeddings is None or len(self.embeddings) == 0:
                return self._keyword_search(query, top_k)

            # Ensure proper dimensions
            if len(self.embeddings.shape) == 1:
                # If embeddings is 1D, reshape to 2D
                embeddings = self.embeddings.reshape(1, -1)
            else:
                embeddings = self.embeddings

            if len(query_embedding.shape) == 1:
                query_embedding = query_embedding.reshape(1, -1)

            # Check dimension compatibility
            if query_embedding.shape[1] != embeddings.shape[1]:
                print(f"Embedding dimension mismatch: query {query_embedding.shape[1]} vs entries {embeddings.shape[1]}")
                return self._keyword_search(query, top_k)

            # Calculate similarities
            similarities = cosine_similarity(query_embedding, embeddings)[0]

            # Get top-k results
            top_indices = np.argsort(similarities)[::-1][:top_k]

            results = []
            for idx in top_indices:
                if similarities[idx] > 0.1:  # Minimum similarity threshold
                    entry = self.entries[idx].copy()
                    entry['text'] = self._extract_text(entry)
                    entry['similarity'] = float(similarities[idx])
                    entry['source'] = entry.get('source', 'internal dataset')
                    results.append(entry)

            return results

        except Exception as e:
            print(f"Error during search: {e}")
            return self._keyword_search(query, top_k)

    def _keyword_search(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """Fallback keyword search with RML-aware scoring"""
        query_lower = query.lower()
        query_words = set(query_lower.split())
        results = []

        for entry in self.entries:
            score = 0
            text = self._extract_text(entry).lower()

            # Check direct text matches
            text_words = set(text.split())
            common_words = query_words.intersection(text_words)
            score += len(common_words) * 2  # Base score for word matches

            # Boost score for matches in specific RML fields
            if 'concepts' in entry and entry['concepts']:
                concepts_text = " ".join(entry['concepts']).lower() if isinstance(entry['concepts'], list) else str(entry['concepts']).lower()
                concept_matches = sum(1 for word in query_words if word in concepts_text)
                score += concept_matches * 3  # Higher weight for concept matches

            if 'tags' in entry and entry['tags']:
                tags_text = " ".join(entry['tags']).lower() if isinstance(entry['tags'], list) else str(entry['tags']).lower()
                tag_matches = sum(1 for word in query_words if word in tags_text)
                score += tag_matches * 2  # Medium weight for tag matches

            if 'summaries' in entry and entry['summaries']:
                summary_text = entry['summaries'][0].lower() if isinstance(entry['summaries'], list) and entry['summaries'] else str(entry['summaries']).lower()
                summary_matches = sum(1 for word in query_words if word in summary_text)
                score += summary_matches * 4  # Highest weight for summary matches

            # Only include results with some relevance
            if score > 0:
                entry_copy = entry.copy()
                entry_copy['text'] = self._extract_text(entry)
                entry_copy['similarity'] = min(0.9, score / 10)  # Normalize score to similarity
                entry_copy['source'] = entry.get('source', 'internal dataset')
                results.append(entry_copy)

        # Sort by similarity score and return top-k
        results.sort(key=lambda x: x['similarity'], reverse=True)
        return results[:top_k]

    def _extract_text(self, entry: Dict[str, Any]) -> str:
        """Extract text content from entry, handling RML-specific structure"""
        # First try standard fields
        for field in ['text', 'content', 'body', 'chunk', 'summary', 'title']:
            if field in entry and entry[field]:
                return str(entry[field])

        # Handle RML-specific structure
        text_parts = []

        # Extract from summaries (first priority for RML data)
        if 'summaries' in entry and entry['summaries']:
            if isinstance(entry['summaries'], list) and entry['summaries']:
                text_parts.append(entry['summaries'][0])
            elif isinstance(entry['summaries'], str):
                text_parts.append(entry['summaries'])

        # Extract from concepts
        if 'concepts' in entry and entry['concepts']:
            if isinstance(entry['concepts'], list):
                text_parts.append(" ".join(entry['concepts'][:10]))  # First 10 concepts
            elif isinstance(entry['concepts'], str):
                text_parts.append(entry['concepts'])

        # Extract from tags
        if 'tags' in entry and entry['tags']:
            if isinstance(entry['tags'], list):
                text_parts.append(" ".join(entry['tags'][:10]))  # First 10 tags
            elif isinstance(entry['tags'], str):
                text_parts.append(entry['tags'])

        # Combine all parts
        if text_parts:
            return " ".join(text_parts)

        # Fallback: convert entire entry to string (excluding large arrays)
        filtered_entry = {}
        for k, v in entry.items():
            if k not in ['vectors', 'embeddings'] and v:
                if isinstance(v, list) and len(v) > 20:
                    filtered_entry[k] = v[:5]  # Only first 5 items of large lists
                else:
                    filtered_entry[k] = v

        return str(filtered_entry) if filtered_entry else "No content available"

    def get_stats(self) -> Dict[str, Any]:
        """Get memory store statistics"""
        embedding_dim = 0
        if self.embeddings is not None and len(self.embeddings.shape) > 1:
            embedding_dim = self.embeddings.shape[1]
        elif self.embeddings is not None and len(self.embeddings.shape) == 1:
            embedding_dim = len(self.embeddings)

        return {
            'total_entries': len(self.entries),
            'embedding_dim': embedding_dim,
            'has_embeddings': self.embeddings is not None
        }
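
For context, a usage sketch of the class above. The toy_encode helper and the sample entries are illustrative stand-ins, not part of the RML codebase; the real pipeline presumably supplies its own embedding function via encode_query_fn.

# Usage sketch for MemoryStore (assumes the package layout rml_ai/memory.py shown above).
import numpy as np
from rml_ai.memory import MemoryStore

def toy_encode(text: str, dim: int = 64) -> np.ndarray:
    """Bag-of-words hash embedding, consistent within one process (stand-in for a real encoder)."""
    vec = np.zeros(dim, dtype=np.float32)
    for word in text.lower().split():
        vec[hash(word) % dim] += 1.0
    return vec

entries = [
    {"summaries": ["RML stores concepts and tags for semantic search"],
     "concepts": ["memory", "semantic search"], "tags": ["rml"]},
    {"summaries": ["Cosine similarity ranks entries against the query"],
     "concepts": ["cosine", "similarity"], "tags": ["ranking"]},
]

store = MemoryStore()
store.add_entries(entries, np.stack([toy_encode(e["summaries"][0]) for e in entries]))
store.encode_query_fn = toy_encode  # without this, search() falls back to keyword matching

print(store.search("how does semantic search rank entries", top_k=2))
print(store.get_stats())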