Emil25 commited on
Commit
fd429a6
1 Parent(s): 4802bdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -28
app.py CHANGED
@@ -12,17 +12,19 @@ import xgboost as xgb
12
 
13
  # Setting up the page configuration for Streamlit App
14
  st.set_page_config(
15
- page_title="Taxi",
16
  # layout="wide",
17
  initial_sidebar_state="expanded"
18
  )
19
 
 
20
  # Load the XGBoost model
21
  #@st.cache_data()
22
  def get_model():
23
  model = pickle.load(open("models/model_xgb.pkl", "rb"))
24
  return model
25
 
 
26
  # Function to make prediction using the model and input data
27
  def make_prediction(data):
28
  model = get_model()
@@ -38,21 +40,17 @@ def make_prediction(data):
38
  return model.predict(data_matrix)
39
 
40
 
 
41
  def get_coordinates(address):
42
- # Создание экземпляра геокодера
43
  geolocator = Nominatim(user_agent="my_app")
44
-
45
- # Получение координат по адресу
46
  location = geolocator.geocode(address)
47
-
48
- # Вывод широты и долготы
49
  return (location.longitude, location.latitude)
50
 
51
 
52
  def show_map(lon_from, lat_from, lon_to, lat_to):
53
  # Creating a map
54
  fig = go.Figure(go.Scattermapbox(
55
- mode = "markers",
56
  marker = {'size': 15, 'color': 'red'}
57
  ))
58
 
@@ -63,11 +61,11 @@ def show_map(lon_from, lat_from, lon_to, lat_to):
63
  lat = [lat_from, lat_to],
64
  marker = go.scattermapbox.Marker(
65
  size=25,
66
- color='red'
67
  )
68
  ))
69
 
70
- # Добавление линии между точками
71
  fig.add_trace(go.Scattermapbox(
72
  mode = "lines",
73
  lon = [lon_from, lon_to],
@@ -75,19 +73,19 @@ def show_map(lon_from, lat_from, lon_to, lat_to):
75
  line = dict(width=2, color='green')
76
  ))
77
 
78
- # Настройка отображения карты
79
  fig.update_layout(
80
  mapbox = {
81
- 'style': "open-street-map", # Стиль карты
82
- 'center': {'lon': (lon_from + lon_to) / 2, 'lat': (lat_from + lat_to) / 2}, # Центр карты
83
- 'zoom': 9, # Уровень масштабирования карты
84
  },
85
  showlegend = False,
86
- height = 600, # Изменение высоты карты
87
- width = 1200 # Изменение ширины карты
88
  )
89
-
90
- # Отображение карты
91
  return fig
92
 
93
 
@@ -123,18 +121,18 @@ def get_haversine_distance(lat1, lng1, lat2, lng2):
123
 
124
 
125
  # User input features
126
- def user_input_features(lon_from, lat_from, lon_to, lat_to, passenger_count):
127
- current_time = datetime.now()
128
  pickup_hour= current_time.hour
129
  today = datetime.today()
130
  pickup_holiday = 1 if today in holidays.USA() else 0
131
  total_distance, total_travel_time, number_of_steps = get_total_distance(lon_from, lat_from, lon_to, lat_to)
132
- haversine_distance = get_haversine_distance(lat_from, lon_from, lat_to, lon_to)
133
  weekday_number = current_time.weekday()
134
 
135
  data = {'vendor_id': 1,
136
  'passenger_count': passenger_count,
137
- 'pickup_longitude': lon_from,
138
  'pickup_latitude': lat_from,
139
  'dropoff_longitude': lon_to,
140
  'dropoff_latitude': lat_to,
@@ -168,19 +166,19 @@ def min_max_scaler(data):
168
  data_scaled = scaler.transform(data)
169
  return data_scaled
170
 
171
- # Main function
172
- def main():
173
 
 
 
174
  if 'btn_predict' not in st.session_state:
175
- st.session_state['btn_predict'] = False
176
-
177
  # Sidebar
178
  st.sidebar.markdown(''' # New York City Taxi Trip Duration''')
179
  st.sidebar.image("img/taxi_img.png")
180
  address_from = st.sidebar.text_input("Откуда:", value="New York, 11 Wall Street")
181
- address_to = st.sidebar.text_input("Куда:", value="New York, 740 Park Avenue")
182
  passenger_count = st.sidebar.slider("Количество пассажиров", 1, 4, 1)
183
-
184
  st.session_state['btn_predict'] = st.sidebar.button('Start')
185
 
186
  if st.session_state['btn_predict']:
@@ -190,7 +188,7 @@ def main():
190
  user_data = user_input_features(lon_from, lat_from, lon_to, lat_to, passenger_count)
191
  # st.write(user_data)
192
  data_scaled = min_max_scaler(user_data)
193
- trip_duration = np.exp(make_prediction(data_scaled)) - 1
194
  trip_duration = round(float(trip_duration) / 60)
195
  st.markdown(f"""
196
  <div style='background-color: lightgreen; padding: 10px;'>
@@ -198,6 +196,7 @@ def main():
198
  </div>
199
  """, unsafe_allow_html=True)
200
 
 
201
  # Running the main function
202
  if __name__ == "__main__":
203
  main()
 
12
 
13
  # Setting up the page configuration for Streamlit App
14
  st.set_page_config(
15
+ page_title="Taxi",
16
  # layout="wide",
17
  initial_sidebar_state="expanded"
18
  )
19
 
20
+
21
  # Load the XGBoost model
22
  #@st.cache_data()
23
  def get_model():
24
  model = pickle.load(open("models/model_xgb.pkl", "rb"))
25
  return model
26
 
27
+
28
  # Function to make prediction using the model and input data
29
  def make_prediction(data):
30
  model = get_model()
 
40
  return model.predict(data_matrix)
41
 
42
 
43
+ # Get coordinates from address
44
  def get_coordinates(address):
 
45
  geolocator = Nominatim(user_agent="my_app")
 
 
46
  location = geolocator.geocode(address)
 
 
47
  return (location.longitude, location.latitude)
48
 
49
 
50
  def show_map(lon_from, lat_from, lon_to, lat_to):
51
  # Creating a map
52
  fig = go.Figure(go.Scattermapbox(
53
+ mode = "markers",
54
  marker = {'size': 15, 'color': 'red'}
55
  ))
56
 
 
61
  lat = [lat_from, lat_to],
62
  marker = go.scattermapbox.Marker(
63
  size=25,
64
+ color='red'
65
  )
66
  ))
67
 
68
+ # Adding a line
69
  fig.add_trace(go.Scattermapbox(
70
  mode = "lines",
71
  lon = [lon_from, lon_to],
 
73
  line = dict(width=2, color='green')
74
  ))
75
 
76
+ # Configuring the display of a map
77
  fig.update_layout(
78
  mapbox = {
79
+ 'style': "open-street-map",
80
+ 'center': {'lon': (lon_from + lon_to) / 2, 'lat': (lat_from + lat_to) / 2},
81
+ 'zoom': 9,
82
  },
83
  showlegend = False,
84
+ height = 600,
85
+ width = 1200
86
  )
87
+
88
+ # Display the map
89
  return fig
90
 
91
 
 
121
 
122
 
123
  # User input features
124
+ def user_input_features(lon_from, lat_from, lon_to, lat_to, passenger_count):
125
+ current_time = datetime.now()
126
  pickup_hour= current_time.hour
127
  today = datetime.today()
128
  pickup_holiday = 1 if today in holidays.USA() else 0
129
  total_distance, total_travel_time, number_of_steps = get_total_distance(lon_from, lat_from, lon_to, lat_to)
130
+ haversine_distance = get_haversine_distance(lat_from, lon_from, lat_to, lon_to)
131
  weekday_number = current_time.weekday()
132
 
133
  data = {'vendor_id': 1,
134
  'passenger_count': passenger_count,
135
+ 'pickup_longitude': lon_from,
136
  'pickup_latitude': lat_from,
137
  'dropoff_longitude': lon_to,
138
  'dropoff_latitude': lat_to,
 
166
  data_scaled = scaler.transform(data)
167
  return data_scaled
168
 
 
 
169
 
170
+ # Main function
171
+ def main():
172
  if 'btn_predict' not in st.session_state:
173
+ st.session_state['btn_predict'] = False
174
+
175
  # Sidebar
176
  st.sidebar.markdown(''' # New York City Taxi Trip Duration''')
177
  st.sidebar.image("img/taxi_img.png")
178
  address_from = st.sidebar.text_input("Откуда:", value="New York, 11 Wall Street")
179
+ address_to = st.sidebar.text_input("Куда:", value="New York, 740 Park Avenue")
180
  passenger_count = st.sidebar.slider("Количество пассажиров", 1, 4, 1)
181
+
182
  st.session_state['btn_predict'] = st.sidebar.button('Start')
183
 
184
  if st.session_state['btn_predict']:
 
188
  user_data = user_input_features(lon_from, lat_from, lon_to, lat_to, passenger_count)
189
  # st.write(user_data)
190
  data_scaled = min_max_scaler(user_data)
191
+ trip_duration = np.exp(make_prediction(data_scaled)) - 1
192
  trip_duration = round(float(trip_duration) / 60)
193
  st.markdown(f"""
194
  <div style='background-color: lightgreen; padding: 10px;'>
 
196
  </div>
197
  """, unsafe_allow_html=True)
198
 
199
+
200
  # Running the main function
201
  if __name__ == "__main__":
202
  main()