46
46
)
47
47
48
48
_GEMINI_PRO_ENDPOINT = "gemini-pro"
49
+ _GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
50
+ _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
51
+ _GEMINI_ENDPOINTS = (
52
+ _GEMINI_PRO_ENDPOINT ,
53
+ _GEMINI_1P5_PRO_PREVIEW_ENDPOINT ,
54
+ _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT ,
55
+ )
56
+
49
57
50
58
_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
51
59
_ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
@@ -547,13 +555,16 @@ def to_gbq(
547
555
class GeminiTextGenerator (base .BaseEstimator ):
548
556
"""Gemini text generator LLM model.
549
557
550
- .. note::
551
- This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
552
- Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
553
- and might have limited support. For more information, see the launch stage descriptions
554
- (https://cloud.google.com/products#product-launch-stages).
555
-
556
558
Args:
559
+ model_name (str, Default to "gemini-pro"):
560
+ The model for natural language tasks. Accepted values are "gemini-pro", "gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514". Default to "gemini-pro".
561
+
562
+ .. note::
563
+ "gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514" is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
564
+ Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
565
+ and might have limited support. For more information, see the launch stage descriptions
566
+ (https://cloud.google.com/products#product-launch-stages).
567
+
557
568
session (bigframes.Session or None):
558
569
BQ session to create the model. If None, use the global default session.
559
570
connection_name (str or None):
@@ -565,9 +576,13 @@ class GeminiTextGenerator(base.BaseEstimator):
565
576
def __init__ (
566
577
self ,
567
578
* ,
579
+ model_name : Literal [
580
+ "gemini-pro" , "gemini-1.5-pro-preview-0514" , "gemini-1.5-flash-preview-0514"
581
+ ] = "gemini-pro" ,
568
582
session : Optional [bigframes .Session ] = None ,
569
583
connection_name : Optional [str ] = None ,
570
584
):
585
+ self .model_name = model_name
571
586
self .session = session or bpd .get_global_session ()
572
587
self ._bq_connection_manager = self .session .bqconnectionmanager
573
588
@@ -601,7 +616,12 @@ def _create_bqml_model(self):
601
616
iam_role = "aiplatform.user" ,
602
617
)
603
618
604
- options = {"endpoint" : _GEMINI_PRO_ENDPOINT }
619
+ if self .model_name not in _GEMINI_ENDPOINTS :
620
+ raise ValueError (
621
+ f"Model name { self .model_name } is not supported. We only support { ', ' .join (_GEMINI_ENDPOINTS )} ."
622
+ )
623
+
624
+ options = {"endpoint" : self .model_name }
605
625
606
626
return self ._bqml_model_factory .create_remote_model (
607
627
session = self .session , connection_name = self .connection_name , options = options
@@ -613,12 +633,17 @@ def _from_bq(
613
633
) -> GeminiTextGenerator :
614
634
assert bq_model .model_type == "MODEL_TYPE_UNSPECIFIED"
615
635
assert "remoteModelInfo" in bq_model ._properties
636
+ assert "endpoint" in bq_model ._properties ["remoteModelInfo" ]
616
637
assert "connection" in bq_model ._properties ["remoteModelInfo" ]
617
638
618
639
# Parse the remote model endpoint
640
+ bqml_endpoint = bq_model ._properties ["remoteModelInfo" ]["endpoint" ]
619
641
model_connection = bq_model ._properties ["remoteModelInfo" ]["connection" ]
642
+ model_endpoint = bqml_endpoint .split ("/" )[- 1 ]
620
643
621
- model = cls (session = session , connection_name = model_connection )
644
+ model = cls (
645
+ model_name = model_endpoint , session = session , connection_name = model_connection
646
+ )
622
647
model ._bqml_model = core .BqmlModel (session , bq_model )
623
648
return model
624
649
0 commit comments