__init__.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import TYPE_CHECKING
  15. from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
  16. _import_structure = {
  17. "configuration_utils": [
  18. "BaseWatermarkingConfig",
  19. "CompileConfig",
  20. "ContinuousBatchingConfig",
  21. "GenerationConfig",
  22. "GenerationMode",
  23. "SynthIDTextWatermarkingConfig",
  24. "WatermarkingConfig",
  25. ],
  26. "streamers": ["AsyncTextIteratorStreamer", "BaseStreamer", "TextIteratorStreamer", "TextStreamer"],
  27. }
  28. try:
  29. if not is_torch_available():
  30. raise OptionalDependencyNotAvailable()
  31. except OptionalDependencyNotAvailable:
  32. pass
  33. else:
  34. _import_structure["candidate_generator"] = [
  35. "AssistedCandidateGenerator",
  36. "CandidateGenerator",
  37. "EarlyExitCandidateGenerator",
  38. "PromptLookupCandidateGenerator",
  39. ]
  40. _import_structure["logits_process"] = [
  41. "AlternatingCodebooksLogitsProcessor",
  42. "ClassifierFreeGuidanceLogitsProcessor",
  43. "EncoderNoRepeatNGramLogitsProcessor",
  44. "EncoderRepetitionPenaltyLogitsProcessor",
  45. "EpsilonLogitsWarper",
  46. "EtaLogitsWarper",
  47. "ExponentialDecayLengthPenalty",
  48. "ForcedBOSTokenLogitsProcessor",
  49. "ForcedEOSTokenLogitsProcessor",
  50. "InfNanRemoveLogitsProcessor",
  51. "LogitNormalization",
  52. "LogitsProcessor",
  53. "LogitsProcessorList",
  54. "MinLengthLogitsProcessor",
  55. "MinNewTokensLengthLogitsProcessor",
  56. "MinPLogitsWarper",
  57. "NoBadWordsLogitsProcessor",
  58. "NoRepeatNGramLogitsProcessor",
  59. "PrefixConstrainedLogitsProcessor",
  60. "RepetitionPenaltyLogitsProcessor",
  61. "SequenceBiasLogitsProcessor",
  62. "SuppressTokensLogitsProcessor",
  63. "SuppressTokensAtBeginLogitsProcessor",
  64. "SynthIDTextWatermarkLogitsProcessor",
  65. "TemperatureLogitsWarper",
  66. "TopHLogitsWarper",
  67. "TopKLogitsWarper",
  68. "TopPLogitsWarper",
  69. "TypicalLogitsWarper",
  70. "UnbatchedClassifierFreeGuidanceLogitsProcessor",
  71. "WhisperTimeStampLogitsProcessor",
  72. "WatermarkLogitsProcessor",
  73. ]
  74. _import_structure["stopping_criteria"] = [
  75. "MaxLengthCriteria",
  76. "MaxTimeCriteria",
  77. "ConfidenceCriteria",
  78. "EosTokenCriteria",
  79. "StoppingCriteria",
  80. "StoppingCriteriaList",
  81. "validate_stopping_criteria",
  82. "StopStringCriteria",
  83. ]
  84. _import_structure["continuous_batching"] = [
  85. "ContinuousBatchingManager",
  86. "ContinuousMixin",
  87. "FIFOScheduler",
  88. "PrefillFirstScheduler",
  89. "Scheduler",
  90. ]
  91. _import_structure["utils"] = [
  92. "GenerationMixin",
  93. "GenerateBeamDecoderOnlyOutput",
  94. "GenerateBeamEncoderDecoderOutput",
  95. "GenerateDecoderOnlyOutput",
  96. "GenerateEncoderDecoderOutput",
  97. ]
  98. _import_structure["watermarking"] = [
  99. "WatermarkDetector",
  100. "WatermarkDetectorOutput",
  101. "BayesianDetectorModel",
  102. "BayesianDetectorConfig",
  103. "SynthIDTextWatermarkDetector",
  104. ]
  105. if TYPE_CHECKING:
  106. from .configuration_utils import (
  107. BaseWatermarkingConfig,
  108. CompileConfig,
  109. ContinuousBatchingConfig,
  110. GenerationConfig,
  111. GenerationMode,
  112. SynthIDTextWatermarkingConfig,
  113. WatermarkingConfig,
  114. )
  115. from .streamers import AsyncTextIteratorStreamer, BaseStreamer, TextIteratorStreamer, TextStreamer
  116. try:
  117. if not is_torch_available():
  118. raise OptionalDependencyNotAvailable()
  119. except OptionalDependencyNotAvailable:
  120. pass
  121. else:
  122. from .candidate_generator import (
  123. AssistedCandidateGenerator,
  124. CandidateGenerator,
  125. EarlyExitCandidateGenerator,
  126. PromptLookupCandidateGenerator,
  127. )
  128. from .continuous_batching import (
  129. ContinuousBatchingManager,
  130. ContinuousMixin,
  131. FIFOScheduler,
  132. PrefillFirstScheduler,
  133. Scheduler,
  134. )
  135. from .logits_process import (
  136. AlternatingCodebooksLogitsProcessor,
  137. ClassifierFreeGuidanceLogitsProcessor,
  138. EncoderNoRepeatNGramLogitsProcessor,
  139. EncoderRepetitionPenaltyLogitsProcessor,
  140. EpsilonLogitsWarper,
  141. EtaLogitsWarper,
  142. ExponentialDecayLengthPenalty,
  143. ForcedBOSTokenLogitsProcessor,
  144. ForcedEOSTokenLogitsProcessor,
  145. InfNanRemoveLogitsProcessor,
  146. LogitNormalization,
  147. LogitsProcessor,
  148. LogitsProcessorList,
  149. MinLengthLogitsProcessor,
  150. MinNewTokensLengthLogitsProcessor,
  151. MinPLogitsWarper,
  152. NoBadWordsLogitsProcessor,
  153. NoRepeatNGramLogitsProcessor,
  154. PrefixConstrainedLogitsProcessor,
  155. RepetitionPenaltyLogitsProcessor,
  156. SequenceBiasLogitsProcessor,
  157. SuppressTokensAtBeginLogitsProcessor,
  158. SuppressTokensLogitsProcessor,
  159. SynthIDTextWatermarkLogitsProcessor,
  160. TemperatureLogitsWarper,
  161. TopHLogitsWarper,
  162. TopKLogitsWarper,
  163. TopPLogitsWarper,
  164. TypicalLogitsWarper,
  165. UnbatchedClassifierFreeGuidanceLogitsProcessor,
  166. WatermarkLogitsProcessor,
  167. WhisperTimeStampLogitsProcessor,
  168. )
  169. from .stopping_criteria import (
  170. ConfidenceCriteria,
  171. EosTokenCriteria,
  172. MaxLengthCriteria,
  173. MaxTimeCriteria,
  174. StoppingCriteria,
  175. StoppingCriteriaList,
  176. StopStringCriteria,
  177. validate_stopping_criteria,
  178. )
  179. from .utils import (
  180. GenerateBeamDecoderOnlyOutput,
  181. GenerateBeamEncoderDecoderOutput,
  182. GenerateDecoderOnlyOutput,
  183. GenerateEncoderDecoderOutput,
  184. GenerationMixin,
  185. )
  186. from .watermarking import (
  187. BayesianDetectorConfig,
  188. BayesianDetectorModel,
  189. SynthIDTextWatermarkDetector,
  190. WatermarkDetector,
  191. WatermarkDetectorOutput,
  192. )
  193. else:
  194. import sys
  195. sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)