[IR2Vec] Restructuring Vocabulary #145119

Open
wants to merge 1 commit into
base: users/svkeerthy/06-20-overloading_operator_for_embeddngs

@svkeerthy (Contributor) commented Jun 20, 2025

This PR restructures the vocabulary.

  • String-based lookups are removed: the vocabulary is now stored as a vector instead of a map. ([IR2Vec] Storing and Managing Vocabulary #141832)
  • All vocabulary-related methods are grouped under a single class, ir2vec::Vocabulary, which replaces IR2VecVocabResult.
  • ir2vec::Vocabulary abstracts the layout and other internal details of the vector, and exposes the APIs needed to access the vocabulary.

These changes ensure that every known opcode and type has an entry in the vocabulary. The original set of operand kinds is retained; it can be extended later.

(Tracking issue - #141817)
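
For readers who want a picture of the map-to-vector change described above, here is a minimal, standalone sketch of a flat vocabulary laid out as [ opcodes | types | operand kinds ]. It is not the code in this patch: `FlatVocab`, `makeIndexVocab`, the named accessors, and the tiny constant values are hypothetical stand-ins. The patch itself derives the limits from `Instruction.def`, `Type::TypeID`, and its `OperandKind` enum, and exposes the lookups as `operator[]` overloads.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical, deliberately small limits for illustration only.
constexpr unsigned MaxOpcodes = 3;      // opcodes are numbered 1..MaxOpcodes
constexpr unsigned MaxTypes = 2;        // type IDs are numbered 0..MaxTypes-1
constexpr unsigned MaxOperandKinds = 4; // Function, Pointer, Constant, Variable

enum class OperandKind : unsigned { Function, Pointer, Constant, Variable };

using Embedding = std::vector<double>;

// Flat layout: [ opcodes | types | operand kinds ].
// Each lookup is an offset computation plus one vector index; no string keys.
class FlatVocab {
  std::vector<Embedding> Vocab;

public:
  explicit FlatVocab(std::vector<Embedding> V) : Vocab(std::move(V)) {}

  // Valid only if every opcode, type, and operand kind has an entry.
  bool isValid() const {
    return Vocab.size() == MaxOpcodes + MaxTypes + MaxOperandKinds;
  }

  const Embedding &forOpcode(unsigned Opcode) const {
    assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
    return Vocab[Opcode - 1]; // opcodes are 1-based
  }

  const Embedding &forType(unsigned TypeId) const {
    assert(TypeId < MaxTypes && "Invalid type ID");
    return Vocab[MaxOpcodes + TypeId];
  }

  const Embedding &forOperand(OperandKind Kind) const {
    return Vocab[MaxOpcodes + MaxTypes + static_cast<unsigned>(Kind)];
  }
};

// Builds a 1-D vocabulary whose entry at position I holds the value I,
// so every lookup reveals which slot it resolved to.
FlatVocab makeIndexVocab() {
  std::vector<Embedding> V;
  for (unsigned I = 0; I < MaxOpcodes + MaxTypes + MaxOperandKinds; ++I)
    V.push_back(Embedding{static_cast<double>(I)});
  return FlatVocab(std::move(V));
}
```

With these stand-in limits, `makeIndexVocab().forType(0)` resolves to slot 3, the first slot after the three opcode entries, and `forOperand(OperandKind::Function)` resolves to slot 5. The sketch uses named accessors only to keep the three index spaces visibly separate.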

@svkeerthy svkeerthy changed the title Vocab Changes [IR2Vec] Restructuring Vocabulary Jun 20, 2025
@svkeerthy svkeerthy marked this pull request as ready for review June 20, 2025 23:44
@llvmbot (Member) commented Jun 20, 2025

@llvm/pr-subscribers-mlgo

@llvm/pr-subscribers-llvm-analysis

Author: S. VenkataKeerthy (svkeerthy)

Patch is 201.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145119.diff

20 Files Affected:

  • (modified) llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h (+3-5)
  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+85-40)
  • (modified) llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp (+10-10)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+257-103)
  • (modified) llvm/lib/Analysis/models/seedEmbeddingVocab75D.json (+64-64)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json (+83-3)
  • (added) llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json (+91)
  • (added) llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json (+91)
  • (added) llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_type_vocab.json (+91)
  • (removed) llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_vocab.json (-15)
  • (added) llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt (+92)
  • (added) llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt (+92)
  • (added) llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt (+92)
  • (modified) llvm/test/Analysis/IR2Vec/basic.ll (+67-45)
  • (modified) llvm/test/Analysis/IR2Vec/dbg-inst.ll (+3-3)
  • (modified) llvm/test/Analysis/IR2Vec/if-else.ll (+5-5)
  • (modified) llvm/test/Analysis/IR2Vec/unreachable.ll (+5-5)
  • (modified) llvm/test/Analysis/IR2Vec/vocab-test.ll (+8-15)
  • (modified) llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp (+5-26)
  • (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+147-79)
diff --git a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
index 06dbfc35a5294..f0b570dd792df 100644
--- a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
+++ b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
@@ -34,13 +34,13 @@ class FunctionPropertiesInfo {
   void reIncludeBB(const BasicBlock &BB);
 
   ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0);
-  std::optional<ir2vec::Vocab> IR2VecVocab;
+  const ir2vec::Vocabulary *IR2VecVocab = nullptr;
 
 public:
   LLVM_ABI static FunctionPropertiesInfo
   getFunctionPropertiesInfo(const Function &F, const DominatorTree &DT,
                             const LoopInfo &LI,
-                            const IR2VecVocabResult *VocabResult);
+                            const ir2vec::Vocabulary *Vocabulary);
 
   LLVM_ABI static FunctionPropertiesInfo
   getFunctionPropertiesInfo(Function &F, FunctionAnalysisManager &FAM);
@@ -145,9 +145,7 @@ class FunctionPropertiesInfo {
     return FunctionEmbedding;
   }
 
-  const std::optional<ir2vec::Vocab> &getIR2VecVocab() const {
-    return IR2VecVocab;
-  }
+  const ir2vec::Vocabulary *getIR2VecVocab() const { return IR2VecVocab; }
 
   // Helper intended to be useful for unittests
   void setFunctionEmbeddingForTest(const ir2vec::Embedding &Embedding) {
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index f6c40d36f8026..eb9817a39dfd4 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -31,6 +31,7 @@
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorOr.h"
 #include "llvm/Support/JSON.h"
@@ -42,10 +43,10 @@ class Module;
 class BasicBlock;
 class Instruction;
 class Function;
-class Type;
 class Value;
 class raw_ostream;
 class LLVMContext;
+class IR2VecVocabAnalysis;
 
 /// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
 /// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -124,9 +125,73 @@ struct Embedding {
 
 using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
 using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
-// FIXME: Current the keys are strings. This can be changed to
-// use integers for cheaper lookups.
-using Vocab = std::map<std::string, Embedding>;
+
+/// Class for storing and accessing the IR2Vec vocabulary.
+/// Encapsulates all vocabulary-related constants, logic, and access methods.
+class Vocabulary {
+  friend class llvm::IR2VecVocabAnalysis;
+  using VocabVector = std::vector<ir2vec::Embedding>;
+  VocabVector Vocab;
+  bool Valid = false;
+
+/// Operand kinds supported by IR2Vec Vocabulary
+#define OPERAND_KINDS                                                          \
+  OPERAND_KIND(FunctionID, "Function")                                         \
+  OPERAND_KIND(PointerID, "Pointer")                                           \
+  OPERAND_KIND(ConstantID, "Constant")                                         \
+  OPERAND_KIND(VariableID, "Variable")
+
+  enum class OperandKind : unsigned {
+#define OPERAND_KIND(Name, Str) Name,
+    OPERAND_KINDS
+#undef OPERAND_KIND
+        MaxOperandKind
+  };
+
+#undef OPERAND_KINDS
+
+  /// Vocabulary layout constants
+#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
+#include "llvm/IR/Instruction.def"
+#undef LAST_OTHER_INST
+
+  static constexpr unsigned MaxTypes = Type::TypeID::TargetExtTyID + 1;
+  static constexpr unsigned MaxOperandKinds =
+      static_cast<unsigned>(OperandKind::MaxOperandKind);
+
+  /// Helper function to get vocabulary key for a given OperandKind
+  static StringRef getVocabKeyForOperandKind(OperandKind Kind);
+
+  /// Helper function to classify an operand into OperandKind
+  static OperandKind getOperandKind(const Value *Op);
+
+  /// Helper function to get vocabulary key for a given TypeID
+  static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
+
+public:
+  Vocabulary() = default;
+  Vocabulary(VocabVector &&Vocab);
+
+  bool isValid() const;
+  unsigned getDimension() const;
+  unsigned size() const;
+
+  const ir2vec::Embedding &at(unsigned Position) const;
+  const ir2vec::Embedding &operator[](unsigned Opcode) const;
+  const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
+  const ir2vec::Embedding &operator[](const Value *Arg) const;
+
+  /// Returns the string key for a given index position in the vocabulary.
+  /// This is useful for debugging or printing the vocabulary. Do not use this
+  /// for embedding generation as string based lookups are inefficient.
+  static StringRef getStringKey(unsigned Pos);
+
+  /// Create a dummy vocabulary for testing purposes.
+  static VocabVector createDummyVocabForTest(unsigned Dim = 1);
+
+  bool invalidate(Module &M, const PreservedAnalyses &PA,
+                  ModuleAnalysisManager::Invalidator &Inv) const;
+};
 
 /// Embedder provides the interface to generate embeddings (vector
 /// representations) for instructions, basic blocks, and functions. The
@@ -137,7 +202,7 @@ using Vocab = std::map<std::string, Embedding>;
 class Embedder {
 protected:
   const Function &F;
-  const Vocab &Vocabulary;
+  const Vocabulary &Vocab;
 
   /// Dimension of the vector representation; captured from the input vocabulary
   const unsigned Dimension;
@@ -152,7 +217,7 @@ class Embedder {
   mutable BBEmbeddingsMap BBVecMap;
   mutable InstEmbeddingsMap InstVecMap;
 
-  Embedder(const Function &F, const Vocab &Vocabulary);
+  Embedder(const Function &F, const Vocabulary &Vocab);
 
   /// Helper function to compute embeddings. It generates embeddings for all
   /// the instructions and basic blocks in the function F. Logic of computing
@@ -163,16 +228,12 @@ class Embedder {
   /// Specific to the kind of embeddings being computed.
   virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
 
-  /// Lookup vocabulary for a given Key. If the key is not found, it returns a
-  /// zero vector.
-  Embedding lookupVocab(const std::string &Key) const;
-
 public:
   virtual ~Embedder() = default;
 
   /// Factory method to create an Embedder object.
   static std::unique_ptr<Embedder> create(IR2VecKind Mode, const Function &F,
-                                          const Vocab &Vocabulary);
+                                          const Vocabulary &Vocab);
 
   /// Returns a map containing instructions and the corresponding embeddings for
   /// the function F if it has been computed. If not, it computes the embeddings
@@ -198,56 +259,40 @@ class Embedder {
 /// representations obtained from the Vocabulary.
 class SymbolicEmbedder : public Embedder {
 private:
-  /// Utility function to compute the embedding for a given type.
-  Embedding getTypeEmbedding(const Type *Ty) const;
-
-  /// Utility function to compute the embedding for a given operand.
-  Embedding getOperandEmbedding(const Value *Op) const;
-
   void computeEmbeddings() const override;
   void computeEmbeddings(const BasicBlock &BB) const override;
 
 public:
-  SymbolicEmbedder(const Function &F, const Vocab &Vocabulary)
-      : Embedder(F, Vocabulary) {
+  SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
+      : Embedder(F, Vocab) {
     FuncVector = Embedding(Dimension, 0);
   }
 };
 
 } // namespace ir2vec
 
-/// Class for storing the result of the IR2VecVocabAnalysis.
-class IR2VecVocabResult {
-  ir2vec::Vocab Vocabulary;
-  bool Valid = false;
-
-public:
-  IR2VecVocabResult() = default;
-  IR2VecVocabResult(ir2vec::Vocab &&Vocabulary);
-
-  bool isValid() const { return Valid; }
-  const ir2vec::Vocab &getVocabulary() const;
-  unsigned getDimension() const;
-  bool invalidate(Module &M, const PreservedAnalyses &PA,
-                  ModuleAnalysisManager::Invalidator &Inv) const;
-};
-
 /// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
 /// mapping between an entity of the IR (like opcode, type, argument, etc.) and
 /// its corresponding embedding.
 class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
-  ir2vec::Vocab Vocabulary;
+  using VocabVector = std::vector<ir2vec::Embedding>;
+  using VocabMap = std::map<std::string, ir2vec::Embedding>;
+  VocabMap OpcVocab, TypeVocab, ArgVocab;
+  VocabVector Vocab;
+
+  unsigned Dim = 0;
   Error readVocabulary();
   Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
-                          ir2vec::Vocab &TargetVocab, unsigned &Dim);
+                          VocabMap &TargetVocab, unsigned &Dim);
+  void generateNumMappedVocab();
   void emitError(Error Err, LLVMContext &Ctx);
 
 public:
   static AnalysisKey Key;
   IR2VecVocabAnalysis() = default;
-  explicit IR2VecVocabAnalysis(const ir2vec::Vocab &Vocab);
-  explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
-  using Result = IR2VecVocabResult;
+  explicit IR2VecVocabAnalysis(const VocabVector &Vocab);
+  explicit IR2VecVocabAnalysis(VocabVector &&Vocab);
+  using Result = ir2vec::Vocabulary;
   Result run(Module &M, ModuleAnalysisManager &MAM);
 };
 
diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
index dd4eb7f0df053..c52a6de2bb71e 100644
--- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
@@ -242,20 +242,20 @@ FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
   // We use the cached result of the IR2VecVocabAnalysis run by
   // InlineAdvisorAnalysis. If the IR2VecVocabAnalysis is not run, we don't
   // use IR2Vec embeddings.
-  auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
-                         .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
+  auto Vocabulary = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+                        .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
   return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F),
-                                   FAM.getResult<LoopAnalysis>(F), VocabResult);
+                                   FAM.getResult<LoopAnalysis>(F), Vocabulary);
 }
 
 FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
     const Function &F, const DominatorTree &DT, const LoopInfo &LI,
-    const IR2VecVocabResult *VocabResult) {
+    const ir2vec::Vocabulary *Vocabulary) {
 
   FunctionPropertiesInfo FPI;
-  if (VocabResult && VocabResult->isValid()) {
-    FPI.IR2VecVocab = VocabResult->getVocabulary();
-    FPI.FunctionEmbedding = ir2vec::Embedding(VocabResult->getDimension(), 0.0);
+  if (Vocabulary && Vocabulary->isValid()) {
+    FPI.IR2VecVocab = Vocabulary;
+    FPI.FunctionEmbedding = ir2vec::Embedding(Vocabulary->getDimension(), 0.0);
   }
   for (const auto &BB : F)
     if (DT.isReachableFromEntry(&BB))
@@ -588,9 +588,9 @@ bool FunctionPropertiesUpdater::isUpdateValid(Function &F,
     return false;
   DominatorTree DT(F);
   LoopInfo LI(DT);
-  auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
-                         .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
+  auto Vocabulary = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+                        .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
   auto Fresh =
-      FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, VocabResult);
+      FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, Vocabulary);
   return FPI == Fresh;
 }
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index d5d27db8bd2bf..2423179960007 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -14,6 +14,7 @@
 #include "llvm/Analysis/IR2Vec.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Module.h"
@@ -56,6 +57,9 @@ cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
 
 AnalysisKey IR2VecVocabAnalysis::Key;
 
+// ==----------------------------------------------------------------------===//
+// Local helper functions
+//===----------------------------------------------------------------------===//
 namespace llvm::json {
 inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
                      llvm::json::Path P) {
@@ -126,35 +130,21 @@ void Embedding::print(raw_ostream &OS) const {
 // Embedder and its subclasses
 //===----------------------------------------------------------------------===//
 
-Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
-    : F(F), Vocabulary(Vocabulary),
-      Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
-      TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
+Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
+    : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
+      OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
+}
 
 std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
-                                           const Vocab &Vocabulary) {
+                                           const Vocabulary &Vocab) {
   switch (Mode) {
   case IR2VecKind::Symbolic:
-    return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
+    return std::make_unique<SymbolicEmbedder>(F, Vocab);
   }
   llvm_unreachable("Unknown IR2Vec kind");
   return nullptr;
 }
 
-// FIXME: Currently lookups are string based. Use numeric Keys
-// for efficiency
-Embedding Embedder::lookupVocab(const std::string &Key) const {
-  Embedding Vec(Dimension, 0);
-  // FIXME: Use zero vectors in vocab and assert failure for
-  // unknown entities rather than silently returning zeroes here.
-  auto It = Vocabulary.find(Key);
-  if (It != Vocabulary.end())
-    return It->second;
-  LLVM_DEBUG(errs() << "cannot find key in map : " << Key << "\n");
-  ++VocabMissCounter;
-  return Vec;
-}
-
 const InstEmbeddingsMap &Embedder::getInstVecMap() const {
   if (InstVecMap.empty())
     computeEmbeddings();
@@ -182,49 +172,16 @@ const Embedding &Embedder::getFunctionVector() const {
   return FuncVector;
 }
 
-#define RETURN_LOOKUP_IF(CONDITION, KEY_STR)                                   \
-  if (CONDITION)                                                               \
-    return lookupVocab(KEY_STR);
-
-Embedding SymbolicEmbedder::getTypeEmbedding(const Type *Ty) const {
-  RETURN_LOOKUP_IF(Ty->isVoidTy(), "voidTy");
-  RETURN_LOOKUP_IF(Ty->isFloatingPointTy(), "floatTy");
-  RETURN_LOOKUP_IF(Ty->isIntegerTy(), "integerTy");
-  RETURN_LOOKUP_IF(Ty->isFunctionTy(), "functionTy");
-  RETURN_LOOKUP_IF(Ty->isStructTy(), "structTy");
-  RETURN_LOOKUP_IF(Ty->isArrayTy(), "arrayTy");
-  RETURN_LOOKUP_IF(Ty->isPointerTy(), "pointerTy");
-  RETURN_LOOKUP_IF(Ty->isVectorTy(), "vectorTy");
-  RETURN_LOOKUP_IF(Ty->isEmptyTy(), "emptyTy");
-  RETURN_LOOKUP_IF(Ty->isLabelTy(), "labelTy");
-  RETURN_LOOKUP_IF(Ty->isTokenTy(), "tokenTy");
-  RETURN_LOOKUP_IF(Ty->isMetadataTy(), "metadataTy");
-  return lookupVocab("unknownTy");
-}
-
-Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
-  RETURN_LOOKUP_IF(isa<Function>(Op), "function");
-  RETURN_LOOKUP_IF(isa<PointerType>(Op->getType()), "pointer");
-  RETURN_LOOKUP_IF(isa<Constant>(Op), "constant");
-  return lookupVocab("variable");
-}
-
-#undef RETURN_LOOKUP_IF
-
 void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
   Embedding BBVector(Dimension, 0);
 
   // We consider only the non-debug and non-pseudo instructions
   for (const auto &I : BB.instructionsWithoutDebug()) {
-    Embedding InstVector(Dimension, 0);
-
-    // FIXME: Currently lookups are string based. Use numeric Keys
-    // for efficiency.
-    InstVector += lookupVocab(I.getOpcodeName());
-    InstVector += getTypeEmbedding(I.getType());
-    for (const auto &Op : I.operands()) {
-      InstVector += getOperandEmbedding(Op.get());
-    }
+    Embedding ArgEmb(Dimension, 0);
+    for (const auto &Op : I.operands())
+      ArgEmb += Vocab[Op];
+    auto InstVector =
+        Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
     InstVecMap[&I] = InstVector;
     BBVector += InstVector;
   }
@@ -243,33 +200,165 @@ void SymbolicEmbedder::computeEmbeddings() const {
 }
 
 // ==----------------------------------------------------------------------===//
-// IR2VecVocabResult and IR2VecVocabAnalysis
+// Vocabulary
 //===----------------------------------------------------------------------===//
 
-IR2VecVocabResult::IR2VecVocabResult(ir2vec::Vocab &&Vocabulary)
-    : Vocabulary(std::move(Vocabulary)), Valid(true) {}
+Vocabulary::Vocabulary(VocabVector &&Vocab)
+    : Vocab(std::move(Vocab)), Valid(true) {}
 
-const ir2vec::Vocab &IR2VecVocabResult::getVocabulary() const {
+bool Vocabulary::isValid() const {
+  return Vocab.size() == MaxOpcodes + MaxTypes + MaxOperandKinds && Valid;
+}
+
+unsigned Vocabulary::size() const {
   assert(Valid && "IR2Vec Vocabulary is invalid");
-  return Vocabulary;
+  return Vocab.size();
 }
 
-unsigned IR2VecVocabResult::getDimension() const {
+unsigned Vocabulary::getDimension() const {
   assert(Valid && "IR2Vec Vocabulary is invalid");
-  return Vocabulary.begin()->second.size();
+  return Vocab[0].size();
+}
+
+const Embedding &Vocabulary::at(unsigned Position) const {
+  assert(Position < Vocab.size() && "Position out of bounds in vocabulary");
+  return Vocab[Position];
+}
+
+const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+  return Vocab[Opcode - 1];
+}
+
+const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {
+  assert(static_cast<unsigned>(TypeId) < MaxTypes && "Invalid type ID");
+  return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)];
+}
+
+const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
+  OperandKind ArgKind = getOperandKind(Arg);
+  return Vocab[MaxOpcodes + MaxTypes + static_cast<unsigned>(ArgKind)];
+}
+
+StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
+  switch (TypeID) {
+  case Type::VoidTyID:
+    return "VoidTy";
+  case Type::HalfTyID:
+  case Type::BFloatTyID:
+  case Type::FloatTyID:
+  case Type::DoubleTyID:
+  case Type::X86_FP80TyID:
+  case Type::FP128TyID:
+  case Type::PPC_FP128TyID:
+    return "FloatTy";
+  case Type::IntegerTyID:
+    return "IntegerTy";
+  case Type::FunctionTyID:
+    return "FunctionTy";
+  case Type::StructTyID:
+    return "StructTy";
+  case Type::ArrayTyID:
+    return "ArrayTy";
+  case Type::PointerTyID:
+  case Type::TypedPointerTyID:
+    return "PointerTy";
+  case Type::FixedVectorTyID:
+  case Type::ScalableVectorTyID:
+    return "VectorTy";
+  case Type::LabelTyID:
+    return "LabelTy";
+  case Type::TokenTyID:
+    return "TokenTy";
+  case Type::MetadataTyID:
+    return "MetadataTy";
+  case Type::X86_AMXTyID:
+  case Type::TargetExtTyID:
+  default:
+    return "UnknownTy";
+  }
+}
+
+// Operand kinds supported by IR2Vec - string mappings
+#define OPERAND_KINDS                                                          \
+  OPERAND_KIND(FunctionID, "Function")                                         \
+  OPERAND_KIND(PointerID, "Pointer")                                           \
+  OPERAND_KIND(ConstantID, "Constant")                                         \
+  OPERAND_KIND(VariableID, "Variable")
+
+StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
+  switch (Kind) {
+#define OPERAND_KIND(Name, Str)                                                \
+  case Vocabulary::OperandKind::Name:                                          \
+    return Str;
+    OPERAND_KINDS
+#undef OPERAND_KIND
+  case Vocabulary::OperandKind::MaxOperandKind:
+    llvm_unreachable("Invalid OperandKind");
+  }
+  llvm_unreachable("Unknown OperandKind");
+}
+
+#undef OPERAND_KINDS
+
+Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
+  VocabVector DummyVocab;
+  float DummyVal = 0.1f;
+  // Create a dummy vocabulary with entries for all opcodes, types, and
+  // operand
+  for (unsigned _ : seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypes +
+                                Vocabulary::MaxOperandKin...
[truncated]
