Update src/display/utils.py
Browse files- src/display/utils.py +44 -19
src/display/utils.py
CHANGED
|
@@ -61,11 +61,33 @@ class ModelDetails:
|
|
| 61 |
symbol: str = "" # emoji
|
| 62 |
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
class ModelType(Enum):
|
| 65 |
PT = ModelDetails(name="pretrained", symbol="π’")
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
| 69 |
Unknown = ModelDetails(name="", symbol="?")
|
| 70 |
|
| 71 |
def to_str(self, separator=" "):
|
|
@@ -75,41 +97,44 @@ class ModelType(Enum):
|
|
| 75 |
def from_str(type):
|
| 76 |
if "fine-tuned" in type or "πΆ" in type:
|
| 77 |
return ModelType.FT
|
|
|
|
|
|
|
| 78 |
if "pretrained" in type or "π’" in type:
|
| 79 |
return ModelType.PT
|
| 80 |
-
if "RL-tuned"
|
| 81 |
-
return ModelType.
|
| 82 |
-
if "
|
| 83 |
-
return ModelType.
|
| 84 |
return ModelType.Unknown
|
| 85 |
|
|
|
|
| 86 |
class WeightType(Enum):
|
| 87 |
Adapter = ModelDetails("Adapter")
|
| 88 |
Original = ModelDetails("Original")
|
| 89 |
Delta = ModelDetails("Delta")
|
| 90 |
|
| 91 |
class Precision(Enum):
|
|
|
|
| 92 |
float16 = ModelDetails("float16")
|
| 93 |
bfloat16 = ModelDetails("bfloat16")
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
#qt_GPTQ = ModelDetails("GPTQ")
|
| 98 |
Unknown = ModelDetails("?")
|
| 99 |
|
| 100 |
def from_str(precision):
|
|
|
|
|
|
|
| 101 |
if precision in ["torch.float16", "float16"]:
|
| 102 |
return Precision.float16
|
| 103 |
if precision in ["torch.bfloat16", "bfloat16"]:
|
| 104 |
return Precision.bfloat16
|
| 105 |
-
if precision in ["
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
#if precision in ["GPTQ", "None"]:
|
| 112 |
-
# return Precision.qt_GPTQ
|
| 113 |
return Precision.Unknown
|
| 114 |
|
| 115 |
# Column selection
|
|
|
|
| 61 |
symbol: str = "" # emoji
|
| 62 |
|
| 63 |
|
| 64 |
+
# class ModelType(Enum):
|
| 65 |
+
# PT = ModelDetails(name="pretrained", symbol="π’")
|
| 66 |
+
# FT = ModelDetails(name="fine-tuned", symbol="πΆ")
|
| 67 |
+
# IFT = ModelDetails(name="instruction-tuned", symbol="β")
|
| 68 |
+
# RL = ModelDetails(name="RL-tuned", symbol="π¦")
|
| 69 |
+
# Unknown = ModelDetails(name="", symbol="?")
|
| 70 |
+
|
| 71 |
+
# def to_str(self, separator=" "):
|
| 72 |
+
# return f"{self.value.symbol}{separator}{self.value.name}"
|
| 73 |
+
|
| 74 |
+
# @staticmethod
|
| 75 |
+
# def from_str(type):
|
| 76 |
+
# if "fine-tuned" in type or "πΆ" in type:
|
| 77 |
+
# return ModelType.FT
|
| 78 |
+
# if "pretrained" in type or "π’" in type:
|
| 79 |
+
# return ModelType.PT
|
| 80 |
+
# if "RL-tuned" in type or "π¦" in type:
|
| 81 |
+
# return ModelType.RL
|
| 82 |
+
# if "instruction-tuned" in type or "β" in type:
|
| 83 |
+
# return ModelType.IFT
|
| 84 |
+
# return ModelType.Unknown
|
| 85 |
class ModelType(Enum):
|
| 86 |
PT = ModelDetails(name="pretrained", symbol="π’")
|
| 87 |
+
CPT = ModelDetails(name="continuously pretrained", symbol="π©")
|
| 88 |
+
FT = ModelDetails(name="fine-tuned on domain-specific datasets", symbol="πΆ")
|
| 89 |
+
chat = ModelDetails(name="chat models (RLHF, DPO, IFT, ...)", symbol="π¬")
|
| 90 |
+
merges = ModelDetails(name="base merges and moerges", symbol="π€")
|
| 91 |
Unknown = ModelDetails(name="", symbol="?")
|
| 92 |
|
| 93 |
def to_str(self, separator=" "):
|
|
|
|
| 97 |
def from_str(type):
|
| 98 |
if "fine-tuned" in type or "πΆ" in type:
|
| 99 |
return ModelType.FT
|
| 100 |
+
if "continously pretrained" in type or "π©" in type:
|
| 101 |
+
return ModelType.CPT
|
| 102 |
if "pretrained" in type or "π’" in type:
|
| 103 |
return ModelType.PT
|
| 104 |
+
if any([k in type for k in ["instruction-tuned", "RL-tuned", "chat", "π¦", "β", "π¬"]]):
|
| 105 |
+
return ModelType.chat
|
| 106 |
+
if "merge" in type or "π€" in type:
|
| 107 |
+
return ModelType.merges
|
| 108 |
return ModelType.Unknown
|
| 109 |
|
| 110 |
+
|
| 111 |
class WeightType(Enum):
|
| 112 |
Adapter = ModelDetails("Adapter")
|
| 113 |
Original = ModelDetails("Original")
|
| 114 |
Delta = ModelDetails("Delta")
|
| 115 |
|
| 116 |
class Precision(Enum):
|
| 117 |
+
float32 = ModelDetails("float32")
|
| 118 |
float16 = ModelDetails("float16")
|
| 119 |
bfloat16 = ModelDetails("bfloat16")
|
| 120 |
+
qt_8bit = ModelDetails("8bit")
|
| 121 |
+
qt_4bit = ModelDetails("4bit")
|
| 122 |
+
qt_GPTQ = ModelDetails("GPTQ")
|
|
|
|
| 123 |
Unknown = ModelDetails("?")
|
| 124 |
|
| 125 |
def from_str(precision):
|
| 126 |
+
if precision in ["float32"]:
|
| 127 |
+
return Precision.float32
|
| 128 |
if precision in ["torch.float16", "float16"]:
|
| 129 |
return Precision.float16
|
| 130 |
if precision in ["torch.bfloat16", "bfloat16"]:
|
| 131 |
return Precision.bfloat16
|
| 132 |
+
if precision in ["8bit"]:
|
| 133 |
+
return Precision.qt_8bit
|
| 134 |
+
if precision in ["4bit"]:
|
| 135 |
+
return Precision.qt_4bit
|
| 136 |
+
if precision in ["GPTQ", "None"]:
|
| 137 |
+
return Precision.qt_GPTQ
|
|
|
|
|
|
|
| 138 |
return Precision.Unknown
|
| 139 |
|
| 140 |
# Column selection
|